From 92aeaf5b42e1122f25a1ee93db3e9426eee05d61 Mon Sep 17 00:00:00 2001 From: Mateusz Poliwczak Date: Fri, 20 Oct 2023 10:58:16 +0200 Subject: [PATCH] use fakeDNSServer Change-Id: I56dfdd7457782419801b208c0a6894f9e88a8a7f --- src/net/dnsclient_unix_test.go | 90 -------- src/net/resolverdialfunc_test.go | 356 +++++++++++-------------------- 2 files changed, 124 insertions(+), 322 deletions(-) diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go index dfc9773a66..a358e8b6e5 100644 --- a/src/net/dnsclient_unix_test.go +++ b/src/net/dnsclient_unix_test.go @@ -867,96 +867,6 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) { } } -type fakeDNSServer struct { - rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) - alwaysTCP bool -} - -func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) { - if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" { - return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil - } - return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil -} - -type fakeDNSConn struct { - Conn - tcp bool - server *fakeDNSServer - n string - s string - q dnsmessage.Message - t time.Time - buf []byte -} - -func (f *fakeDNSConn) Close() error { - return nil -} - -func (f *fakeDNSConn) Read(b []byte) (int, error) { - if len(f.buf) > 0 { - n := copy(b, f.buf) - f.buf = f.buf[n:] - return n, nil - } - - resp, err := f.server.rh(f.n, f.s, f.q, f.t) - if err != nil { - return 0, err - } - - bb := make([]byte, 2, 514) - bb, err = resp.AppendPack(bb) - if err != nil { - return 0, fmt.Errorf("cannot marshal DNS message: %v", err) - } - - if f.tcp { - l := len(bb) - 2 - bb[0] = byte(l >> 8) - bb[1] = byte(l) - f.buf = bb - return f.Read(b) - } - - bb = bb[2:] - if len(b) < len(bb) { - return 0, errors.New("read would fragment DNS message") - } - - copy(b, bb) - return len(bb), nil -} - -func (f *fakeDNSConn) Write(b []byte) (int, error) { - if f.tcp && len(b) >= 2 { - b = b[2:] - } - if f.q.Unpack(b) != nil { - return 0, fmt.Errorf("cannot unmarshal DNS message fake %s (%d)", f.n, len(b)) - } - return len(b), nil -} - -func (f *fakeDNSConn) SetDeadline(t time.Time) error { - f.t = t - return nil -} - -type fakeDNSPacketConn struct { - PacketConn - fakeDNSConn -} - -func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error { - return f.fakeDNSConn.SetDeadline(t) -} - -func (f *fakeDNSPacketConn) Close() error { - return f.fakeDNSConn.Close() -} - // UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281). func TestIgnoreDNSForgeries(t *testing.T) { c, s := Pipe() diff --git a/src/net/resolverdialfunc_test.go b/src/net/resolverdialfunc_test.go index 1af4199269..17af27e9ad 100644 --- a/src/net/resolverdialfunc_test.go +++ b/src/net/resolverdialfunc_test.go @@ -8,7 +8,6 @@ package net import ( - "bytes" "context" "errors" "fmt" @@ -21,36 +20,62 @@ import ( ) func TestResolverDialFunc(t *testing.T) { - r := &Resolver{ - PreferGo: true, - Dial: newResolverDialFunc(&resolverDialHandler{ - StartDial: func(network, address string) error { - t.Logf("StartDial(%q, %q) ...", network, address) - return nil - }, - Question: func(h dnsmessage.Header, q dnsmessage.Question) { - t.Logf("Header: %+v for %q (type=%v, class=%v)", h, - q.Name.String(), q.Type, q.Class) - }, - // TODO: add test without HandleA* hooks specified at all, that Go - // doesn't issue retries; map to something terminal. - HandleA: func(w AWriter, name string) error { - w.AddIP([4]byte{1, 2, 3, 4}) - w.AddIP([4]byte{5, 6, 7, 8}) - return nil - }, - HandleAAAA: func(w AAAAWriter, name string) error { - w.AddIP([16]byte{1: 1, 15: 15}) - w.AddIP([16]byte{2: 2, 14: 14}) - return nil - }, - HandleSRV: func(w SRVWriter, name string) error { - w.AddSRV(1, 2, 80, "foo.bar.") - w.AddSRV(2, 3, 81, "bar.baz.") - return nil - }, - }), + fake := fakeDNSServer{ + rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + r := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.Header.ID, + Response: true, + RCode: dnsmessage.RCodeSuccess, + }, + Questions: q.Questions, + Answers: []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: q.Questions[0].Name, + Type: q.Questions[0].Type, + Class: dnsmessage.ClassINET, + }, + }, + { + Header: dnsmessage.ResourceHeader{ + Name: q.Questions[0].Name, + Type: q.Questions[0].Type, + Class: dnsmessage.ClassINET, + }, + }, + }, + } + + switch q.Questions[0].Type { + case dnsmessage.TypeA: + r.Answers[0].Body = &dnsmessage.AResource{A: [4]byte{1, 2, 3, 4}} + r.Answers[1].Body = &dnsmessage.AResource{A: [4]byte{5, 6, 7, 8}} + case dnsmessage.TypeAAAA: + r.Answers[0].Body = &dnsmessage.AAAAResource{AAAA: [16]byte{1: 1, 15: 15}} + r.Answers[1].Body = &dnsmessage.AAAAResource{AAAA: [16]byte{2: 2, 14: 14}} + case dnsmessage.TypeSRV: + r.Answers[0].Body = &dnsmessage.SRVResource{ + Priority: 1, + Weight: 2, + Port: 80, + Target: dnsmessage.MustNewName("foo.bar."), + } + r.Answers[1].Body = &dnsmessage.SRVResource{ + Priority: 2, + Weight: 3, + Port: 81, + Target: dnsmessage.MustNewName("bar.baz."), + } + default: + panic("unexpected DNS type") + } + return r, nil + }, } + + r := &Resolver{PreferGo: true, Dial: fake.DialContext} + ctx := context.Background() const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld." @@ -101,225 +126,92 @@ func sortedIPStrings(ips []IP) []string { return ret } -func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) { - return func(ctx context.Context, network, address string) (Conn, error) { - a := &resolverFuncConn{ - h: h, - network: network, - address: address, - ttl: 10, // 10 second default if unset - } - if h.StartDial != nil { - if err := h.StartDial(network, address); err != nil { - return nil, err - } - } - return a, nil - } +type fakeDNSServer struct { + rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) + alwaysTCP bool } -type resolverDialHandler struct { - // StartDial, if non-nil, is called when Go first calls Resolver.Dial. - // Any error returned aborts the dial and is returned unwrapped. - StartDial func(network, address string) error - - Question func(dnsmessage.Header, dnsmessage.Question) - - // err may be ErrNotExist or ErrRefused; others map to SERVFAIL (RCode2). - // A nil error means success. - HandleA func(w AWriter, name string) error - HandleAAAA func(w AAAAWriter, name string) error - HandleSRV func(w SRVWriter, name string) error +func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) { + if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" { + return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil + } + return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil } -type ResponseWriter struct{ a *resolverFuncConn } - -func (w ResponseWriter) header() dnsmessage.ResourceHeader { - q := w.a.q - return dnsmessage.ResourceHeader{ - Name: q.Name, - Type: q.Type, - Class: q.Class, - TTL: w.a.ttl, - } +type fakeDNSConn struct { + Conn + tcp bool + server *fakeDNSServer + n string + s string + q dnsmessage.Message + t time.Time + buf []byte } -// SetTTL sets the TTL for subsequent written resources. -// Once a resource has been written, SetTTL calls are no-ops. -// That is, it can only be called at most once, before anything -// else is written. -func (w ResponseWriter) SetTTL(seconds uint32) { - // ... intention is last one wins and mutates all previously - // written records too, but that's a little annoying. - // But it's also annoying if the requirement is it needs to be set - // last. - // And it's also annoying if it's possible for users to set - // different TTLs per Answer. - if w.a.wrote { - return - } - w.a.ttl = seconds - -} - -type AWriter struct{ ResponseWriter } - -func (w AWriter) AddIP(v4 [4]byte) { - w.a.wrote = true - err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4}) - if err != nil { - panic(err) - } -} - -type AAAAWriter struct{ ResponseWriter } - -func (w AAAAWriter) AddIP(v6 [16]byte) { - w.a.wrote = true - err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6}) - if err != nil { - panic(err) - } -} - -type SRVWriter struct{ ResponseWriter } - -// AddSRV adds a SRV record. The target name must end in a period and -// be 63 bytes or fewer. -func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error { - targetName, err := dnsmessage.NewName(target) - if err != nil { - return err - } - w.a.wrote = true - err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{ - Priority: priority, - Weight: weight, - Port: port, - Target: targetName, - }) - if err != nil { - panic(err) // internal fault, not user - } +func (f *fakeDNSConn) Close() error { return nil } -var ( - ErrNotExist = errors.New("name does not exist") // maps to RCode3, NXDOMAIN - ErrRefused = errors.New("refused") // maps to RCode5, REFUSED -) - -type resolverFuncConn struct { - h *resolverDialHandler - network string - address string - builder *dnsmessage.Builder - q dnsmessage.Question - ttl uint32 - wrote bool - - rbuf bytes.Buffer -} - -func (*resolverFuncConn) Close() error { return nil } -func (*resolverFuncConn) LocalAddr() Addr { return someaddr{} } -func (*resolverFuncConn) RemoteAddr() Addr { return someaddr{} } -func (*resolverFuncConn) SetDeadline(t time.Time) error { return nil } -func (*resolverFuncConn) SetReadDeadline(t time.Time) error { return nil } -func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil } - -func (a *resolverFuncConn) Read(p []byte) (n int, err error) { - return a.rbuf.Read(p) -} - -func (a *resolverFuncConn) Write(packet []byte) (n int, err error) { - if len(packet) < 2 { - return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet)) - } - reqLen := int(packet[0])<<8 | int(packet[1]) - req := packet[2:] - if len(req) != reqLen { - return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req)) +func (f *fakeDNSConn) Read(b []byte) (int, error) { + if len(f.buf) > 0 { + n := copy(b, f.buf) + f.buf = f.buf[n:] + return n, nil } - var parser dnsmessage.Parser - h, err := parser.Start(req) + resp, err := f.server.rh(f.n, f.s, f.q, f.t) if err != nil { - // TODO: hook - return 0, err - } - q, err := parser.Question() - hadQ := (err == nil) - if err == nil && a.h.Question != nil { - a.h.Question(h, q) - } - if err != nil && err != dnsmessage.ErrSectionDone { return 0, err } - resh := h - resh.Response = true - resh.Authoritative = true - if hadQ { - resh.RCode = dnsmessage.RCodeSuccess - } else { - resh.RCode = dnsmessage.RCodeNotImplemented - } - a.rbuf.Grow(514) - a.rbuf.WriteByte('X') // reserved header for beu16 length - a.rbuf.WriteByte('Y') // reserved header for beu16 length - builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh) - a.builder = &builder - if hadQ { - a.q = q - a.builder.StartQuestions() - err := a.builder.Question(q) - if err != nil { - return 0, fmt.Errorf("Question: %w", err) - } - a.builder.StartAnswers() - switch q.Type { - case dnsmessage.TypeA: - if a.h.HandleA != nil { - resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String())) - } - case dnsmessage.TypeAAAA: - if a.h.HandleAAAA != nil { - resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String())) - } - case dnsmessage.TypeSRV: - if a.h.HandleSRV != nil { - resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String())) - } - } - } - tcpRes, err := builder.Finish() + bb := make([]byte, 2, 514) + bb, err = resp.AppendPack(bb) if err != nil { - return 0, fmt.Errorf("Finish: %w", err) + return 0, fmt.Errorf("cannot marshal DNS message: %v", err) } - n = len(tcpRes) - 2 - tcpRes[0] = byte(n >> 8) - tcpRes[1] = byte(n) - a.rbuf.Write(tcpRes[2:]) - - return len(packet), nil -} - -type someaddr struct{} - -func (someaddr) Network() string { return "unused" } -func (someaddr) String() string { return "unused-someaddr" } - -func mapRCode(err error) dnsmessage.RCode { - switch err { - case nil: - return dnsmessage.RCodeSuccess - case ErrNotExist: - return dnsmessage.RCodeNameError - case ErrRefused: - return dnsmessage.RCodeRefused - default: - return dnsmessage.RCodeServerFailure + if f.tcp { + l := len(bb) - 2 + bb[0] = byte(l >> 8) + bb[1] = byte(l) + f.buf = bb + return f.Read(b) } + + bb = bb[2:] + if len(b) < len(bb) { + return 0, errors.New("read would fragment DNS message") + } + + copy(b, bb) + return len(bb), nil +} + +func (f *fakeDNSConn) Write(b []byte) (int, error) { + if f.tcp && len(b) >= 2 { + b = b[2:] + } + if f.q.Unpack(b) != nil { + return 0, fmt.Errorf("cannot unmarshal DNS message fake %s (%d)", f.n, len(b)) + } + return len(b), nil +} + +func (f *fakeDNSConn) SetDeadline(t time.Time) error { + f.t = t + return nil +} + +type fakeDNSPacketConn struct { + PacketConn + fakeDNSConn +} + +func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error { + return f.fakeDNSConn.SetDeadline(t) +} + +func (f *fakeDNSPacketConn) Close() error { + return f.fakeDNSConn.Close() }