diff --git a/src/pkg/net/dial.go b/src/pkg/net/dial.go index f0f47b2155..fb47795d79 100644 --- a/src/pkg/net/dial.go +++ b/src/pkg/net/dial.go @@ -37,6 +37,13 @@ type Dialer struct { // network being dialed. // If nil, a local address is automatically chosen. LocalAddr Addr + + // DualStack allows a single dial to attempt to establish + // multiple IPv4 and IPv6 connections and to return the first + // established connection when the network is "tcp" and the + // destination is a host name that has multiple address family + // DNS records. + DualStack bool } // Return either now+Timeout or Deadline, whichever comes first. @@ -143,10 +150,72 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) { // See func Dial for a description of the network and address // parameters. func (d *Dialer) Dial(network, address string) (Conn, error) { - return resolveAndDial(network, address, d.LocalAddr, d.deadline()) + ra, err := resolveAddr("dial", network, address, d.deadline()) + if err != nil { + return nil, &OpError{Op: "dial", Net: network, Addr: nil, Err: err} + } + dialer := func(deadline time.Time) (Conn, error) { + return dialSingle(network, address, d.LocalAddr, ra.toAddr(), deadline) + } + if ras, ok := ra.(addrList); ok && d.DualStack && network == "tcp" { + dialer = func(deadline time.Time) (Conn, error) { + return dialMulti(network, address, d.LocalAddr, ras, deadline) + } + } + return dial(network, ra.toAddr(), dialer, d.deadline()) } -func dial(net, addr string, la, ra Addr, deadline time.Time) (Conn, error) { +// dialMulti attempts to establish connections to each destination of +// the list of addresses. It will return the first established +// connection and close the other connections. Otherwise it returns +// error on the last attempt. +func dialMulti(net, addr string, la Addr, ras addrList, deadline time.Time) (Conn, error) { + type racer struct { + Conn + Addr + error + } + // Sig controls the flow of dial results on lane. It passes a + // token to the next racer and also indicates the end of flow + // by using closed channel. + sig := make(chan bool, 1) + lane := make(chan racer, 1) + for _, ra := range ras { + go func(ra Addr) { + c, err := dialSingle(net, addr, la, ra, deadline) + if _, ok := <-sig; ok { + lane <- racer{c, ra, err} + } else if err == nil { + // We have to return the resources + // that belong to the other + // connections here for avoiding + // unnecessary resource starvation. + c.Close() + } + }(ra.toAddr()) + } + defer close(sig) + var failAddr Addr + lastErr := errTimeout + nracers := len(ras) + for nracers > 0 { + sig <- true + select { + case racer := <-lane: + if racer.error == nil { + return racer.Conn, nil + } + failAddr = racer.Addr + lastErr = racer.error + nracers-- + } + } + return nil, &OpError{Op: "dial", Net: net, Addr: failAddr, Err: lastErr} +} + +// dialSingle attempts to establish and returns a single connection to +// the destination address. +func dialSingle(net, addr string, la, ra Addr, deadline time.Time) (Conn, error) { if la != nil && la.Network() != ra.Network() { return nil, &OpError{Op: "dial", Net: net, Addr: ra, Err: errors.New("mismatched local address type " + la.Network())} } @@ -168,13 +237,6 @@ func dial(net, addr string, la, ra Addr, deadline time.Time) (Conn, error) { } } -type stringAddr struct { - net, addr string -} - -func (a stringAddr) Network() string { return a.net } -func (a stringAddr) String() string { return a.addr } - // Listen announces on the local network address laddr. // The network net must be a stream-oriented network: "tcp", "tcp4", // "tcp6", "unix" or "unixpacket". diff --git a/src/pkg/net/dial_gen.go b/src/pkg/net/dial_gen.go index f051cdaa84..ada6233003 100644 --- a/src/pkg/net/dial_gen.go +++ b/src/pkg/net/dial_gen.go @@ -12,62 +12,35 @@ import ( var testingIssue5349 bool // used during tests -// resolveAndDialChannel is the simple pure-Go implementation of -// resolveAndDial, still used on operating systems where the deadline -// hasn't been pushed down into the pollserver. (Plan 9 and some old -// versions of Windows) -func resolveAndDialChannel(net, addr string, localAddr Addr, deadline time.Time) (Conn, error) { +// dialChannel is the simple pure-Go implementation of dial, still +// used on operating systems where the deadline hasn't been pushed +// down into the pollserver. (Plan 9 and some old versions of Windows) +func dialChannel(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) { var timeout time.Duration if !deadline.IsZero() { timeout = deadline.Sub(time.Now()) } if timeout <= 0 { - ra, err := resolveAddr("dial", net, addr, noDeadline) - if err != nil { - return nil, &OpError{Op: "dial", Net: net, Addr: nil, Err: err} - } - return dial(net, addr, localAddr, ra.toAddr(), noDeadline) + return dialer(noDeadline) } t := time.NewTimer(timeout) defer t.Stop() - type pair struct { + type racer struct { Conn error } - ch := make(chan pair, 1) - resolvedAddr := make(chan Addr, 1) + ch := make(chan racer, 1) go func() { if testingIssue5349 { time.Sleep(time.Millisecond) } - ra, err := resolveAddr("dial", net, addr, noDeadline) - if err != nil { - ch <- pair{nil, &OpError{Op: "dial", Net: net, Addr: nil, Err: err}} - return - } - resolvedAddr <- ra.toAddr() // in case we need it for OpError - c, err := dial(net, addr, localAddr, ra.toAddr(), noDeadline) - ch <- pair{c, err} + c, err := dialer(noDeadline) + ch <- racer{c, err} }() select { case <-t.C: - // Try to use the real Addr in our OpError, if we resolved it - // before the timeout. Otherwise we just use stringAddr. - var ra Addr - select { - case a := <-resolvedAddr: - ra = a - default: - ra = &stringAddr{net, addr} - } - err := &OpError{ - Op: "dial", - Net: net, - Addr: ra, - Err: errTimeout, - } - return nil, err - case p := <-ch: - return p.Conn, p.error + return nil, &OpError{Op: "dial", Net: net, Addr: ra, Err: errTimeout} + case racer := <-ch: + return racer.Conn, racer.error } } diff --git a/src/pkg/net/dial_test.go b/src/pkg/net/dial_test.go index d79c8a536f..74391bbde7 100644 --- a/src/pkg/net/dial_test.go +++ b/src/pkg/net/dial_test.go @@ -5,13 +5,17 @@ package net import ( + "bytes" "flag" "fmt" "io" "os" + "os/exec" "reflect" "regexp" "runtime" + "strconv" + "sync" "testing" "time" ) @@ -314,6 +318,96 @@ func TestDialTimeoutFDLeak(t *testing.T) { } } +func numTCP() (ntcp, nopen, nclose int, err error) { + lsof, err := exec.Command("lsof", "-n", "-p", strconv.Itoa(os.Getpid())).Output() + if err != nil { + return 0, 0, 0, err + } + ntcp += bytes.Count(lsof, []byte("TCP")) + for _, state := range []string{"LISTEN", "SYN_SENT", "SYN_RECEIVED", "ESTABLISHED"} { + nopen += bytes.Count(lsof, []byte(state)) + } + for _, state := range []string{"CLOSED", "CLOSE_WAIT", "LAST_ACK", "FIN_WAIT_1", "FIN_WAIT_2", "CLOSING", "TIME_WAIT"} { + nclose += bytes.Count(lsof, []byte(state)) + } + return ntcp, nopen, nclose, nil +} + +func TestDialMultiFDLeak(t *testing.T) { + if !supportsIPv4 || !supportsIPv6 { + t.Skip("neither ipv4 nor ipv6 is supported") + } + + halfDeadServer := func(dss *dualStackServer, ln Listener) { + for { + if c, err := ln.Accept(); err != nil { + return + } else { + // It just keeps established + // connections like a half-dead server + // does. + dss.putConn(c) + } + } + } + dss, err := newDualStackServer([]streamListener{ + {net: "tcp4", addr: "127.0.0.1"}, + {net: "tcp6", addr: "[::1]"}, + }) + if err != nil { + t.Fatalf("newDualStackServer failed: %v", err) + } + defer dss.teardown() + if err := dss.buildup(halfDeadServer); err != nil { + t.Fatalf("dualStackServer.buildup failed: %v", err) + } + + _, before, _, err := numTCP() + if err != nil { + t.Skipf("skipping test; error finding or running lsof: %v", err) + } + + var wg sync.WaitGroup + portnum, _, _ := dtoi(dss.port, 0) + ras := addrList{ + // Losers that will fail to connect, see RFC 6890. + &TCPAddr{IP: IPv4(198, 18, 0, 254), Port: portnum}, + &TCPAddr{IP: ParseIP("2001:2::254"), Port: portnum}, + + // Winner candidates of this race. + &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: portnum}, + &TCPAddr{IP: IPv6loopback, Port: portnum}, + + // Losers that will have established connections. + &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: portnum}, + &TCPAddr{IP: IPv6loopback, Port: portnum}, + } + const T1 = 10 * time.Millisecond + const T2 = 2 * T1 + const N = 10 + for i := 0; i < N; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if c, err := dialMulti("tcp", "fast failover test", nil, ras, time.Now().Add(T1)); err == nil { + c.Close() + } + }() + } + wg.Wait() + time.Sleep(T2) + + ntcp, after, nclose, err := numTCP() + if err != nil { + t.Skipf("skipping test; error finding or running lsof: %v", err) + } + t.Logf("tcp sessions: %v, open sessions: %v, closing sessions: %v", ntcp, after, nclose) + + if after != before { + t.Fatalf("got %v open sessions; expected %v", after, before) + } +} + func numFD() int { if runtime.GOOS == "linux" { f, err := os.Open("/proc/self/fd") @@ -404,3 +498,46 @@ func TestDialer(t *testing.T) { t.Error(err) } } + +func TestDialDualStackLocalhost(t *testing.T) { + if ips, err := LookupIP("localhost"); err != nil { + t.Fatalf("LookupIP failed: %v", err) + } else if len(ips) < 2 || !supportsIPv4 || !supportsIPv6 { + t.Skip("localhost doesn't have a pair of different address family IP addresses") + } + + touchAndByeServer := func(dss *dualStackServer, ln Listener) { + for { + if c, err := ln.Accept(); err != nil { + return + } else { + c.Close() + } + } + } + dss, err := newDualStackServer([]streamListener{ + {net: "tcp4", addr: "127.0.0.1"}, + {net: "tcp6", addr: "[::1]"}, + }) + if err != nil { + t.Fatalf("newDualStackServer failed: %v", err) + } + defer dss.teardown() + if err := dss.buildup(touchAndByeServer); err != nil { + t.Fatalf("dualStackServer.buildup failed: %v", err) + } + + d := &Dialer{DualStack: true} + for _ = range dss.lns { + if c, err := d.Dial("tcp", "localhost:"+dss.port); err != nil { + t.Errorf("Dial failed: %v", err) + } else { + if addr := c.LocalAddr().(*TCPAddr); addr.IP.To4() != nil { + dss.teardownNetwork("tcp4") + } else if addr.IP.To16() != nil && addr.IP.To4() == nil { + dss.teardownNetwork("tcp6") + } + c.Close() + } + } +} diff --git a/src/pkg/net/dialgoogle_test.go b/src/pkg/net/dialgoogle_test.go index 000e1c323a..b4ebad0e0d 100644 --- a/src/pkg/net/dialgoogle_test.go +++ b/src/pkg/net/dialgoogle_test.go @@ -41,6 +41,34 @@ func TestResolveGoogle(t *testing.T) { } } +func TestDialGoogle(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + + d := &Dialer{DualStack: true} + for _, network := range []string{"tcp", "tcp4", "tcp6"} { + if network == "tcp" && !supportsIPv4 && !supportsIPv6 { + t.Logf("skipping test; both ipv4 and ipv6 are not supported") + continue + } else if network == "tcp4" && !supportsIPv4 { + t.Logf("skipping test; ipv4 is not supported") + continue + } else if network == "tcp6" && !supportsIPv6 { + t.Logf("skipping test; ipv6 is not supported") + continue + } else if network == "tcp6" && !*testIPv6 { + t.Logf("test disabled; use -ipv6 to enable") + continue + } + if c, err := d.Dial(network, "www.google.com:http"); err != nil { + t.Errorf("Dial failed: %v", err) + } else { + c.Close() + } + } +} + // fd is already connected to the destination, port 80. // Run an HTTP request to fetch the appropriate page. func fetchGoogle(t *testing.T, fd Conn, network, addr string) { diff --git a/src/pkg/net/fd_plan9.go b/src/pkg/net/fd_plan9.go index 0d9dc54408..38515f20e3 100644 --- a/src/pkg/net/fd_plan9.go +++ b/src/pkg/net/fd_plan9.go @@ -21,10 +21,10 @@ type netFD struct { func sysInit() { } -func resolveAndDial(net, addr string, localAddr Addr, deadline time.Time) (Conn, error) { +func dial(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) { // On plan9, use the relatively inefficient // goroutine-racing implementation. - return resolveAndDialChannel(net, addr, localAddr, deadline) + return dialChannel(net, ra, dialer, deadline) } func newFD(proto, name string, ctl, data *os.File, laddr, raddr Addr) *netFD { diff --git a/src/pkg/net/fd_unix.go b/src/pkg/net/fd_unix.go index 457c1d18e2..2e62ba0ec4 100644 --- a/src/pkg/net/fd_unix.go +++ b/src/pkg/net/fd_unix.go @@ -36,12 +36,8 @@ type netFD struct { func sysInit() { } -func resolveAndDial(net, addr string, localAddr Addr, deadline time.Time) (Conn, error) { - ra, err := resolveAddr("dial", net, addr, deadline) - if err != nil { - return nil, &OpError{Op: "dial", Net: net, Addr: nil, Err: err} - } - return dial(net, addr, localAddr, ra.toAddr(), deadline) +func dial(network string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) { + return dialer(deadline) } func newFD(sysfd, family, sotype int, net string) (*netFD, error) { diff --git a/src/pkg/net/fd_windows.go b/src/pkg/net/fd_windows.go index 6f344057c7..d480fb4057 100644 --- a/src/pkg/net/fd_windows.go +++ b/src/pkg/net/fd_windows.go @@ -83,17 +83,13 @@ func canUseConnectEx(net string) bool { return syscall.LoadConnectEx() == nil } -func resolveAndDial(net, addr string, localAddr Addr, deadline time.Time) (Conn, error) { +func dial(net string, ra Addr, dialer func(time.Time) (Conn, error), deadline time.Time) (Conn, error) { if !canUseConnectEx(net) { // Use the relatively inefficient goroutine-racing // implementation of DialTimeout. - return resolveAndDialChannel(net, addr, localAddr, deadline) + return dialChannel(net, ra, dialer, deadline) } - ra, err := resolveAddr("dial", net, addr, deadline) - if err != nil { - return nil, &OpError{Op: "dial", Net: net, Addr: nil, Err: err} - } - return dial(net, addr, localAddr, ra.toAddr(), deadline) + return dialer(deadline) } // operation contains superset of data necessary to perform all async IO. diff --git a/src/pkg/net/mockserver_test.go b/src/pkg/net/mockserver_test.go new file mode 100644 index 0000000000..68ded5d757 --- /dev/null +++ b/src/pkg/net/mockserver_test.go @@ -0,0 +1,82 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import "sync" + +type streamListener struct { + net, addr string + ln Listener +} + +type dualStackServer struct { + lnmu sync.RWMutex + lns []streamListener + port string + + cmu sync.RWMutex + cs []Conn // established connections at the passive open side +} + +func (dss *dualStackServer) buildup(server func(*dualStackServer, Listener)) error { + for i := range dss.lns { + go server(dss, dss.lns[i].ln) + } + return nil +} + +func (dss *dualStackServer) putConn(c Conn) error { + dss.cmu.Lock() + dss.cs = append(dss.cs, c) + dss.cmu.Unlock() + return nil +} + +func (dss *dualStackServer) teardownNetwork(net string) error { + dss.lnmu.Lock() + for i := range dss.lns { + if net == dss.lns[i].net && dss.lns[i].ln != nil { + dss.lns[i].ln.Close() + dss.lns[i].ln = nil + } + } + dss.lnmu.Unlock() + return nil +} + +func (dss *dualStackServer) teardown() error { + dss.lnmu.Lock() + for i := range dss.lns { + if dss.lns[i].ln != nil { + dss.lns[i].ln.Close() + } + } + dss.lnmu.Unlock() + dss.cmu.Lock() + for _, c := range dss.cs { + c.Close() + } + dss.cmu.Unlock() + return nil +} + +func newDualStackServer(lns []streamListener) (*dualStackServer, error) { + dss := &dualStackServer{lns: lns, port: "0"} + for i := range dss.lns { + ln, err := Listen(dss.lns[i].net, dss.lns[i].addr+":"+dss.port) + if err != nil { + dss.teardown() + return nil, err + } + dss.lns[i].ln = ln + if dss.port == "0" { + if _, dss.port, err = SplitHostPort(ln.Addr().String()); err != nil { + dss.teardown() + return nil, err + } + } + } + return dss, nil +}