mirror of https://github.com/golang/go.git
use fakeDNSServer
Change-Id: I56dfdd7457782419801b208c0a6894f9e88a8a7f
This commit is contained in:
parent
cf52e70997
commit
92aeaf5b42
|
|
@ -867,96 +867,6 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
|
|||
}
|
||||
}
|
||||
|
||||
type fakeDNSServer struct {
|
||||
rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
|
||||
alwaysTCP bool
|
||||
}
|
||||
|
||||
func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
|
||||
if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" {
|
||||
return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil
|
||||
}
|
||||
return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil
|
||||
}
|
||||
|
||||
type fakeDNSConn struct {
|
||||
Conn
|
||||
tcp bool
|
||||
server *fakeDNSServer
|
||||
n string
|
||||
s string
|
||||
q dnsmessage.Message
|
||||
t time.Time
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) Read(b []byte) (int, error) {
|
||||
if len(f.buf) > 0 {
|
||||
n := copy(b, f.buf)
|
||||
f.buf = f.buf[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
resp, err := f.server.rh(f.n, f.s, f.q, f.t)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
bb := make([]byte, 2, 514)
|
||||
bb, err = resp.AppendPack(bb)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("cannot marshal DNS message: %v", err)
|
||||
}
|
||||
|
||||
if f.tcp {
|
||||
l := len(bb) - 2
|
||||
bb[0] = byte(l >> 8)
|
||||
bb[1] = byte(l)
|
||||
f.buf = bb
|
||||
return f.Read(b)
|
||||
}
|
||||
|
||||
bb = bb[2:]
|
||||
if len(b) < len(bb) {
|
||||
return 0, errors.New("read would fragment DNS message")
|
||||
}
|
||||
|
||||
copy(b, bb)
|
||||
return len(bb), nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) Write(b []byte) (int, error) {
|
||||
if f.tcp && len(b) >= 2 {
|
||||
b = b[2:]
|
||||
}
|
||||
if f.q.Unpack(b) != nil {
|
||||
return 0, fmt.Errorf("cannot unmarshal DNS message fake %s (%d)", f.n, len(b))
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) SetDeadline(t time.Time) error {
|
||||
f.t = t
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeDNSPacketConn struct {
|
||||
PacketConn
|
||||
fakeDNSConn
|
||||
}
|
||||
|
||||
func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error {
|
||||
return f.fakeDNSConn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (f *fakeDNSPacketConn) Close() error {
|
||||
return f.fakeDNSConn.Close()
|
||||
}
|
||||
|
||||
// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
|
||||
func TestIgnoreDNSForgeries(t *testing.T) {
|
||||
c, s := Pipe()
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
package net
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -21,36 +20,62 @@ import (
|
|||
)
|
||||
|
||||
func TestResolverDialFunc(t *testing.T) {
|
||||
r := &Resolver{
|
||||
PreferGo: true,
|
||||
Dial: newResolverDialFunc(&resolverDialHandler{
|
||||
StartDial: func(network, address string) error {
|
||||
t.Logf("StartDial(%q, %q) ...", network, address)
|
||||
return nil
|
||||
},
|
||||
Question: func(h dnsmessage.Header, q dnsmessage.Question) {
|
||||
t.Logf("Header: %+v for %q (type=%v, class=%v)", h,
|
||||
q.Name.String(), q.Type, q.Class)
|
||||
},
|
||||
// TODO: add test without HandleA* hooks specified at all, that Go
|
||||
// doesn't issue retries; map to something terminal.
|
||||
HandleA: func(w AWriter, name string) error {
|
||||
w.AddIP([4]byte{1, 2, 3, 4})
|
||||
w.AddIP([4]byte{5, 6, 7, 8})
|
||||
return nil
|
||||
},
|
||||
HandleAAAA: func(w AAAAWriter, name string) error {
|
||||
w.AddIP([16]byte{1: 1, 15: 15})
|
||||
w.AddIP([16]byte{2: 2, 14: 14})
|
||||
return nil
|
||||
},
|
||||
HandleSRV: func(w SRVWriter, name string) error {
|
||||
w.AddSRV(1, 2, 80, "foo.bar.")
|
||||
w.AddSRV(2, 3, 81, "bar.baz.")
|
||||
return nil
|
||||
},
|
||||
}),
|
||||
fake := fakeDNSServer{
|
||||
rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
|
||||
r := dnsmessage.Message{
|
||||
Header: dnsmessage.Header{
|
||||
ID: q.Header.ID,
|
||||
Response: true,
|
||||
RCode: dnsmessage.RCodeSuccess,
|
||||
},
|
||||
Questions: q.Questions,
|
||||
Answers: []dnsmessage.Resource{
|
||||
{
|
||||
Header: dnsmessage.ResourceHeader{
|
||||
Name: q.Questions[0].Name,
|
||||
Type: q.Questions[0].Type,
|
||||
Class: dnsmessage.ClassINET,
|
||||
},
|
||||
},
|
||||
{
|
||||
Header: dnsmessage.ResourceHeader{
|
||||
Name: q.Questions[0].Name,
|
||||
Type: q.Questions[0].Type,
|
||||
Class: dnsmessage.ClassINET,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
switch q.Questions[0].Type {
|
||||
case dnsmessage.TypeA:
|
||||
r.Answers[0].Body = &dnsmessage.AResource{A: [4]byte{1, 2, 3, 4}}
|
||||
r.Answers[1].Body = &dnsmessage.AResource{A: [4]byte{5, 6, 7, 8}}
|
||||
case dnsmessage.TypeAAAA:
|
||||
r.Answers[0].Body = &dnsmessage.AAAAResource{AAAA: [16]byte{1: 1, 15: 15}}
|
||||
r.Answers[1].Body = &dnsmessage.AAAAResource{AAAA: [16]byte{2: 2, 14: 14}}
|
||||
case dnsmessage.TypeSRV:
|
||||
r.Answers[0].Body = &dnsmessage.SRVResource{
|
||||
Priority: 1,
|
||||
Weight: 2,
|
||||
Port: 80,
|
||||
Target: dnsmessage.MustNewName("foo.bar."),
|
||||
}
|
||||
r.Answers[1].Body = &dnsmessage.SRVResource{
|
||||
Priority: 2,
|
||||
Weight: 3,
|
||||
Port: 81,
|
||||
Target: dnsmessage.MustNewName("bar.baz."),
|
||||
}
|
||||
default:
|
||||
panic("unexpected DNS type")
|
||||
}
|
||||
return r, nil
|
||||
},
|
||||
}
|
||||
|
||||
r := &Resolver{PreferGo: true, Dial: fake.DialContext}
|
||||
|
||||
ctx := context.Background()
|
||||
const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld."
|
||||
|
||||
|
|
@ -101,225 +126,92 @@ func sortedIPStrings(ips []IP) []string {
|
|||
return ret
|
||||
}
|
||||
|
||||
func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) {
|
||||
return func(ctx context.Context, network, address string) (Conn, error) {
|
||||
a := &resolverFuncConn{
|
||||
h: h,
|
||||
network: network,
|
||||
address: address,
|
||||
ttl: 10, // 10 second default if unset
|
||||
}
|
||||
if h.StartDial != nil {
|
||||
if err := h.StartDial(network, address); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
type fakeDNSServer struct {
|
||||
rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
|
||||
alwaysTCP bool
|
||||
}
|
||||
|
||||
type resolverDialHandler struct {
|
||||
// StartDial, if non-nil, is called when Go first calls Resolver.Dial.
|
||||
// Any error returned aborts the dial and is returned unwrapped.
|
||||
StartDial func(network, address string) error
|
||||
|
||||
Question func(dnsmessage.Header, dnsmessage.Question)
|
||||
|
||||
// err may be ErrNotExist or ErrRefused; others map to SERVFAIL (RCode2).
|
||||
// A nil error means success.
|
||||
HandleA func(w AWriter, name string) error
|
||||
HandleAAAA func(w AAAAWriter, name string) error
|
||||
HandleSRV func(w SRVWriter, name string) error
|
||||
func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
|
||||
if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" {
|
||||
return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil
|
||||
}
|
||||
return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil
|
||||
}
|
||||
|
||||
type ResponseWriter struct{ a *resolverFuncConn }
|
||||
|
||||
func (w ResponseWriter) header() dnsmessage.ResourceHeader {
|
||||
q := w.a.q
|
||||
return dnsmessage.ResourceHeader{
|
||||
Name: q.Name,
|
||||
Type: q.Type,
|
||||
Class: q.Class,
|
||||
TTL: w.a.ttl,
|
||||
}
|
||||
type fakeDNSConn struct {
|
||||
Conn
|
||||
tcp bool
|
||||
server *fakeDNSServer
|
||||
n string
|
||||
s string
|
||||
q dnsmessage.Message
|
||||
t time.Time
|
||||
buf []byte
|
||||
}
|
||||
|
||||
// SetTTL sets the TTL for subsequent written resources.
|
||||
// Once a resource has been written, SetTTL calls are no-ops.
|
||||
// That is, it can only be called at most once, before anything
|
||||
// else is written.
|
||||
func (w ResponseWriter) SetTTL(seconds uint32) {
|
||||
// ... intention is last one wins and mutates all previously
|
||||
// written records too, but that's a little annoying.
|
||||
// But it's also annoying if the requirement is it needs to be set
|
||||
// last.
|
||||
// And it's also annoying if it's possible for users to set
|
||||
// different TTLs per Answer.
|
||||
if w.a.wrote {
|
||||
return
|
||||
}
|
||||
w.a.ttl = seconds
|
||||
|
||||
}
|
||||
|
||||
type AWriter struct{ ResponseWriter }
|
||||
|
||||
func (w AWriter) AddIP(v4 [4]byte) {
|
||||
w.a.wrote = true
|
||||
err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
type AAAAWriter struct{ ResponseWriter }
|
||||
|
||||
func (w AAAAWriter) AddIP(v6 [16]byte) {
|
||||
w.a.wrote = true
|
||||
err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
type SRVWriter struct{ ResponseWriter }
|
||||
|
||||
// AddSRV adds a SRV record. The target name must end in a period and
|
||||
// be 63 bytes or fewer.
|
||||
func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error {
|
||||
targetName, err := dnsmessage.NewName(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.a.wrote = true
|
||||
err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{
|
||||
Priority: priority,
|
||||
Weight: weight,
|
||||
Port: port,
|
||||
Target: targetName,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err) // internal fault, not user
|
||||
}
|
||||
func (f *fakeDNSConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrNotExist = errors.New("name does not exist") // maps to RCode3, NXDOMAIN
|
||||
ErrRefused = errors.New("refused") // maps to RCode5, REFUSED
|
||||
)
|
||||
|
||||
type resolverFuncConn struct {
|
||||
h *resolverDialHandler
|
||||
network string
|
||||
address string
|
||||
builder *dnsmessage.Builder
|
||||
q dnsmessage.Question
|
||||
ttl uint32
|
||||
wrote bool
|
||||
|
||||
rbuf bytes.Buffer
|
||||
}
|
||||
|
||||
func (*resolverFuncConn) Close() error { return nil }
|
||||
func (*resolverFuncConn) LocalAddr() Addr { return someaddr{} }
|
||||
func (*resolverFuncConn) RemoteAddr() Addr { return someaddr{} }
|
||||
func (*resolverFuncConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (*resolverFuncConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
func (a *resolverFuncConn) Read(p []byte) (n int, err error) {
|
||||
return a.rbuf.Read(p)
|
||||
}
|
||||
|
||||
func (a *resolverFuncConn) Write(packet []byte) (n int, err error) {
|
||||
if len(packet) < 2 {
|
||||
return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet))
|
||||
}
|
||||
reqLen := int(packet[0])<<8 | int(packet[1])
|
||||
req := packet[2:]
|
||||
if len(req) != reqLen {
|
||||
return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req))
|
||||
func (f *fakeDNSConn) Read(b []byte) (int, error) {
|
||||
if len(f.buf) > 0 {
|
||||
n := copy(b, f.buf)
|
||||
f.buf = f.buf[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
var parser dnsmessage.Parser
|
||||
h, err := parser.Start(req)
|
||||
resp, err := f.server.rh(f.n, f.s, f.q, f.t)
|
||||
if err != nil {
|
||||
// TODO: hook
|
||||
return 0, err
|
||||
}
|
||||
q, err := parser.Question()
|
||||
hadQ := (err == nil)
|
||||
if err == nil && a.h.Question != nil {
|
||||
a.h.Question(h, q)
|
||||
}
|
||||
if err != nil && err != dnsmessage.ErrSectionDone {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
resh := h
|
||||
resh.Response = true
|
||||
resh.Authoritative = true
|
||||
if hadQ {
|
||||
resh.RCode = dnsmessage.RCodeSuccess
|
||||
} else {
|
||||
resh.RCode = dnsmessage.RCodeNotImplemented
|
||||
}
|
||||
a.rbuf.Grow(514)
|
||||
a.rbuf.WriteByte('X') // reserved header for beu16 length
|
||||
a.rbuf.WriteByte('Y') // reserved header for beu16 length
|
||||
builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh)
|
||||
a.builder = &builder
|
||||
if hadQ {
|
||||
a.q = q
|
||||
a.builder.StartQuestions()
|
||||
err := a.builder.Question(q)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("Question: %w", err)
|
||||
}
|
||||
a.builder.StartAnswers()
|
||||
switch q.Type {
|
||||
case dnsmessage.TypeA:
|
||||
if a.h.HandleA != nil {
|
||||
resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String()))
|
||||
}
|
||||
case dnsmessage.TypeAAAA:
|
||||
if a.h.HandleAAAA != nil {
|
||||
resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String()))
|
||||
}
|
||||
case dnsmessage.TypeSRV:
|
||||
if a.h.HandleSRV != nil {
|
||||
resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String()))
|
||||
}
|
||||
}
|
||||
}
|
||||
tcpRes, err := builder.Finish()
|
||||
bb := make([]byte, 2, 514)
|
||||
bb, err = resp.AppendPack(bb)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("Finish: %w", err)
|
||||
return 0, fmt.Errorf("cannot marshal DNS message: %v", err)
|
||||
}
|
||||
|
||||
n = len(tcpRes) - 2
|
||||
tcpRes[0] = byte(n >> 8)
|
||||
tcpRes[1] = byte(n)
|
||||
a.rbuf.Write(tcpRes[2:])
|
||||
|
||||
return len(packet), nil
|
||||
}
|
||||
|
||||
type someaddr struct{}
|
||||
|
||||
func (someaddr) Network() string { return "unused" }
|
||||
func (someaddr) String() string { return "unused-someaddr" }
|
||||
|
||||
func mapRCode(err error) dnsmessage.RCode {
|
||||
switch err {
|
||||
case nil:
|
||||
return dnsmessage.RCodeSuccess
|
||||
case ErrNotExist:
|
||||
return dnsmessage.RCodeNameError
|
||||
case ErrRefused:
|
||||
return dnsmessage.RCodeRefused
|
||||
default:
|
||||
return dnsmessage.RCodeServerFailure
|
||||
if f.tcp {
|
||||
l := len(bb) - 2
|
||||
bb[0] = byte(l >> 8)
|
||||
bb[1] = byte(l)
|
||||
f.buf = bb
|
||||
return f.Read(b)
|
||||
}
|
||||
|
||||
bb = bb[2:]
|
||||
if len(b) < len(bb) {
|
||||
return 0, errors.New("read would fragment DNS message")
|
||||
}
|
||||
|
||||
copy(b, bb)
|
||||
return len(bb), nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) Write(b []byte) (int, error) {
|
||||
if f.tcp && len(b) >= 2 {
|
||||
b = b[2:]
|
||||
}
|
||||
if f.q.Unpack(b) != nil {
|
||||
return 0, fmt.Errorf("cannot unmarshal DNS message fake %s (%d)", f.n, len(b))
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) SetDeadline(t time.Time) error {
|
||||
f.t = t
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeDNSPacketConn struct {
|
||||
PacketConn
|
||||
fakeDNSConn
|
||||
}
|
||||
|
||||
func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error {
|
||||
return f.fakeDNSConn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (f *fakeDNSPacketConn) Close() error {
|
||||
return f.fakeDNSConn.Close()
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue