mirror of https://github.com/golang/go.git
net: support all PacketConn and Conn returned by Resolver.Dial
Allow the Resolver.Dial func to return instances of Conn other than *TCPConn and *UDPConn. If the Conn is also a PacketConn, assume DNS messages transmitted over the Conn adhere to section 4.2.1. "UDP usage". Otherwise, follow section 4.2.2. "TCP usage". Provides a hook mechanism so that DNS queries generated by the net package may be answered or modified before being sent to over the network. Updates #19910 Change-Id: Ib089a28ad4a1848bbeaf624ae889f1e82d56655b Reviewed-on: https://go-review.googlesource.com/45153 Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
d55d7b9397
commit
d8a7990ffa
|
|
@ -36,14 +36,14 @@ type dnsConn interface {
|
|||
dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
|
||||
}
|
||||
|
||||
func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
|
||||
return dnsRoundTripUDP(c, query)
|
||||
// dnsPacketConn implements the dnsConn interface for RFC 1035's
|
||||
// "UDP usage" transport mechanism. Conn is a packet-oriented connection,
|
||||
// such as a *UDPConn.
|
||||
type dnsPacketConn struct {
|
||||
Conn
|
||||
}
|
||||
|
||||
// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's
|
||||
// "UDP usage" transport mechanism. c should be a packet-oriented connection,
|
||||
// such as a *UDPConn.
|
||||
func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
|
||||
func (c *dnsPacketConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
|
||||
b, ok := query.Pack()
|
||||
if !ok {
|
||||
return nil, errors.New("cannot marshal DNS message")
|
||||
|
|
@ -69,14 +69,14 @@ func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) {
|
||||
return dnsRoundTripTCP(c, out)
|
||||
// dnsStreamConn implements the dnsConn interface for RFC 1035's
|
||||
// "TCP usage" transport mechanism. Conn is a stream-oriented connection,
|
||||
// such as a *TCPConn.
|
||||
type dnsStreamConn struct {
|
||||
Conn
|
||||
}
|
||||
|
||||
// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's
|
||||
// "TCP usage" transport mechanism. c should be a stream-oriented connection,
|
||||
// such as a *TCPConn.
|
||||
func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
|
||||
func (c *dnsStreamConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
|
||||
b, ok := query.Pack()
|
||||
if !ok {
|
||||
return nil, errors.New("cannot marshal DNS message")
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ package net
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"internal/poll"
|
||||
"io/ioutil"
|
||||
|
|
@ -43,11 +44,14 @@ var dnsTransportFallbackTests = []struct {
|
|||
|
||||
func TestDNSTransportFallback(t *testing.T) {
|
||||
fake := fakeDNSServer{
|
||||
rh: func(n, _ string, _ *dnsMsg, _ time.Time) (*dnsMsg, error) {
|
||||
rh: func(n, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
|
||||
r := &dnsMsg{
|
||||
dnsMsgHdr: dnsMsgHdr{
|
||||
rcode: dnsRcodeSuccess,
|
||||
id: q.id,
|
||||
response: true,
|
||||
rcode: dnsRcodeSuccess,
|
||||
},
|
||||
question: q.question,
|
||||
}
|
||||
if n == "udp" {
|
||||
r.truncated = true
|
||||
|
|
@ -98,8 +102,10 @@ func TestSpecialDomainName(t *testing.T) {
|
|||
fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
|
||||
r := &dnsMsg{
|
||||
dnsMsgHdr: dnsMsgHdr{
|
||||
id: q.id,
|
||||
id: q.id,
|
||||
response: true,
|
||||
},
|
||||
question: q.question,
|
||||
}
|
||||
|
||||
switch q.question[0].Name {
|
||||
|
|
@ -612,8 +618,10 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
|
|||
fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
|
||||
r := &dnsMsg{
|
||||
dnsMsgHdr: dnsMsgHdr{
|
||||
id: q.id,
|
||||
id: q.id,
|
||||
response: true,
|
||||
},
|
||||
question: q.question,
|
||||
}
|
||||
|
||||
switch q.question[0].Name {
|
||||
|
|
@ -751,7 +759,7 @@ type fakeDNSServer struct {
|
|||
}
|
||||
|
||||
func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
|
||||
return &fakeDNSConn{nil, server, n, s, time.Time{}}, nil
|
||||
return &fakeDNSConn{nil, server, n, s, nil, time.Time{}}, nil
|
||||
}
|
||||
|
||||
type fakeDNSConn struct {
|
||||
|
|
@ -759,6 +767,7 @@ type fakeDNSConn struct {
|
|||
server *fakeDNSServer
|
||||
n string
|
||||
s string
|
||||
q *dnsMsg
|
||||
t time.Time
|
||||
}
|
||||
|
||||
|
|
@ -766,15 +775,45 @@ func (f *fakeDNSConn) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) Read(b []byte) (int, error) {
|
||||
resp, err := f.server.rh(f.n, f.s, f.q, f.t)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
bb, ok := resp.Pack()
|
||||
if !ok {
|
||||
return 0, errors.New("cannot marshal DNS message")
|
||||
}
|
||||
if len(b) < len(bb) {
|
||||
return 0, errors.New("read would fragment DNS message")
|
||||
}
|
||||
|
||||
copy(b, bb)
|
||||
return len(bb), nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) {
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) Write(b []byte) (int, error) {
|
||||
f.q = new(dnsMsg)
|
||||
if !f.q.Unpack(b) {
|
||||
return 0, errors.New("cannot unmarshal DNS message")
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) SetDeadline(t time.Time) error {
|
||||
f.t = t
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) {
|
||||
return f.server.rh(f.n, f.s, q, f.t)
|
||||
}
|
||||
|
||||
// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
|
||||
func TestIgnoreDNSForgeries(t *testing.T) {
|
||||
c, s := Pipe()
|
||||
|
|
@ -837,7 +876,8 @@ func TestIgnoreDNSForgeries(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
resp, err := dnsRoundTripUDP(c, msg)
|
||||
dc := &dnsPacketConn{c}
|
||||
resp, err := dc.dnsRoundTrip(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("dnsRoundTripUDP failed: %v", err)
|
||||
}
|
||||
|
|
@ -1113,7 +1153,14 @@ func TestStrictErrorsLookupIP(t *testing.T) {
|
|||
case resolveOpError:
|
||||
return nil, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
|
||||
case resolveServfail:
|
||||
return &dnsMsg{dnsMsgHdr: dnsMsgHdr{id: q.id, rcode: dnsRcodeServerFailure}}, nil
|
||||
return &dnsMsg{
|
||||
dnsMsgHdr: dnsMsgHdr{
|
||||
id: q.id,
|
||||
response: true,
|
||||
rcode: dnsRcodeServerFailure,
|
||||
},
|
||||
question: q.question,
|
||||
}, nil
|
||||
case resolveTimeout:
|
||||
return nil, poll.ErrTimeout
|
||||
default:
|
||||
|
|
@ -1123,7 +1170,14 @@ func TestStrictErrorsLookupIP(t *testing.T) {
|
|||
switch q.question[0].Name {
|
||||
case searchX, name + ".":
|
||||
// Return NXDOMAIN to utilize the search list.
|
||||
return &dnsMsg{dnsMsgHdr: dnsMsgHdr{id: q.id, rcode: dnsRcodeNameError}}, nil
|
||||
return &dnsMsg{
|
||||
dnsMsgHdr: dnsMsgHdr{
|
||||
id: q.id,
|
||||
response: true,
|
||||
rcode: dnsRcodeNameError,
|
||||
},
|
||||
question: q.question,
|
||||
}, nil
|
||||
case searchY:
|
||||
// Return records below.
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -111,9 +111,11 @@ type Resolver struct {
|
|||
// Go's built-in DNS resolver to make TCP and UDP connections
|
||||
// to DNS services. The provided addr will always be an IP
|
||||
// address and not a hostname.
|
||||
// The Conn returned must be a *TCPConn or *UDPConn as
|
||||
// requested by the network parameter. If nil, the default
|
||||
// dialer is used.
|
||||
// If the Conn returned is also a PacketConn, sent and received DNS
|
||||
// messages must adhere to section 4.2.1. "UDP usage" of RFC 1035.
|
||||
// Otherwise, DNS messages transmitted over Conn must adhere to section
|
||||
// 4.2.2. "TCP usage".
|
||||
// If nil, the default dialer is used.
|
||||
Dial func(ctx context.Context, network, addr string) (Conn, error)
|
||||
|
||||
// TODO(bradfitz): optional interface impl override hook
|
||||
|
|
|
|||
|
|
@ -8,8 +8,6 @@ package net
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
|
|
@ -70,12 +68,10 @@ func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, e
|
|||
if err != nil {
|
||||
return nil, mapErr(err)
|
||||
}
|
||||
dc, ok := c.(dnsConn)
|
||||
if !ok {
|
||||
c.Close()
|
||||
return nil, errors.New("net: Resolver.Dial returned unsupported connection type " + reflect.TypeOf(c).String())
|
||||
if _, ok := c.(PacketConn); ok {
|
||||
return &dnsPacketConn{c}, nil
|
||||
}
|
||||
return dc, nil
|
||||
return &dnsStreamConn{c}, nil
|
||||
}
|
||||
|
||||
func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue