diff --git a/src/net/dial.go b/src/net/dial.go index 193776fe41..0661c3ecdf 100644 --- a/src/net/dial.go +++ b/src/net/dial.go @@ -6,6 +6,7 @@ package net import ( "errors" + "runtime" "time" ) @@ -225,8 +226,10 @@ func (d *Dialer) Dial(network, address string) (Conn, error) { finalDeadline: finalDeadline, } + // DualStack mode requires that dialTCP support cancelation. This is + // not available on plan9 (golang.org/issue/11225), so we ignore it. var primaries, fallbacks addrList - if d.DualStack && network == "tcp" { + if d.DualStack && network == "tcp" && runtime.GOOS != "plan9" { primaries, fallbacks = addrs.partition(isIPv4) } else { primaries = addrs @@ -236,9 +239,9 @@ func (d *Dialer) Dial(network, address string) (Conn, error) { if len(fallbacks) == 0 { // dialParallel can accept an empty fallbacks list, // but this shortcut avoids the goroutine/channel overhead. - c, err = dialSerial(ctx, primaries, nil) + c, err = dialSerial(ctx, primaries, ctx.Cancel) } else { - c, err = dialParallel(ctx, primaries, fallbacks) + c, err = dialParallel(ctx, primaries, fallbacks, ctx.Cancel) } if d.KeepAlive > 0 && err == nil { @@ -255,10 +258,9 @@ func (d *Dialer) Dial(network, address string) (Conn, error) { // head start. It returns the first established connection and // closes the others. Otherwise it returns an error from the first // primary address. -func dialParallel(ctx *dialContext, primaries, fallbacks addrList) (Conn, error) { - results := make(chan dialResult) // unbuffered, so dialSerialAsync can detect race loss & cleanup +func dialParallel(ctx *dialContext, primaries, fallbacks addrList, userCancel <-chan struct{}) (Conn, error) { + results := make(chan dialResult, 2) cancel := make(chan struct{}) - defer close(cancel) // Spawn the primary racer. go dialSerialAsync(ctx, primaries, nil, cancel, results) @@ -267,28 +269,59 @@ func dialParallel(ctx *dialContext, primaries, fallbacks addrList) (Conn, error) fallbackTimer := time.NewTimer(ctx.fallbackDelay()) go dialSerialAsync(ctx, fallbacks, fallbackTimer, cancel, results) - var primaryErr error - for nracers := 2; nracers > 0; nracers-- { - res := <-results - // If we're still waiting for a connection, then hasten the delay. - // Otherwise, disable the Timer and let cancel take over. - if fallbackTimer.Stop() && res.error != nil { - fallbackTimer.Reset(0) - } - if res.error == nil { - return res.Conn, nil - } - if res.primary { - primaryErr = res.error + // Wait for both racers to succeed or fail. + var primaryResult, fallbackResult dialResult + for !primaryResult.done || !fallbackResult.done { + select { + case <-userCancel: + // Forward an external cancelation request. + if cancel != nil { + close(cancel) + cancel = nil + } + userCancel = nil + case res := <-results: + // Drop the result into its assigned bucket. + if res.primary { + primaryResult = res + } else { + fallbackResult = res + } + // On success, cancel the other racer (if one exists.) + if res.error == nil && cancel != nil { + close(cancel) + cancel = nil + } + // If the fallbackTimer was pending, then either we've canceled the + // fallback because we no longer want it, or we haven't canceled yet + // and therefore want it to wake up immediately. + if fallbackTimer.Stop() && cancel != nil { + fallbackTimer.Reset(0) + } } } - return nil, primaryErr + + // Return, in order of preference: + // 1. The primary connection (but close the other if we got both.) + // 2. The fallback connection. + // 3. The primary error. + if primaryResult.error == nil { + if fallbackResult.error == nil { + fallbackResult.Conn.Close() + } + return primaryResult.Conn, nil + } else if fallbackResult.error == nil { + return fallbackResult.Conn, nil + } else { + return nil, primaryResult.error + } } type dialResult struct { Conn error primary bool + done bool } // dialSerialAsync runs dialSerial after some delay, and returns the @@ -300,19 +333,11 @@ func dialSerialAsync(ctx *dialContext, ras addrList, timer *time.Timer, cancel < select { case <-timer.C: case <-cancel: - return + // dialSerial will immediately return errCanceled in this case. } } c, err := dialSerial(ctx, ras, cancel) - select { - case results <- dialResult{c, err, timer == nil}: - // We won the race. - case <-cancel: - // The other goroutine won the race. - if c != nil { - c.Close() - } - } + results <- dialResult{Conn: c, error: err, primary: timer == nil, done: true} } // dialSerial connects to a list of addresses in sequence, returning @@ -336,11 +361,11 @@ func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, e break } - // dialTCP does not support cancelation (see golang.org/issue/11225), - // so if cancel fires, we'll continue trying to connect until the next - // timeout, or return a spurious connection for the caller to close. + // If this dial is canceled, the implementation is expected to complete + // quickly, but it's still possible that we could return a spurious Conn, + // which the caller must Close. dialer := func(d time.Time) (Conn, error) { - return dialSingle(ctx, ra, d) + return dialSingle(ctx, ra, d, cancel) } c, err := dial(ctx.network, ra, dialer, partialDeadline) if err == nil { @@ -360,7 +385,7 @@ func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, e // dialSingle attempts to establish and returns a single connection to // the destination address. This must be called through the OS-specific // dial function, because some OSes don't implement the deadline feature. -func dialSingle(ctx *dialContext, ra Addr, deadline time.Time) (c Conn, err error) { +func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan struct{}) (c Conn, err error) { la := ctx.LocalAddr if la != nil && la.Network() != ra.Network() { return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: errors.New("mismatched local address type " + la.Network())} @@ -368,7 +393,7 @@ func dialSingle(ctx *dialContext, ra Addr, deadline time.Time) (c Conn, err erro switch ra := ra.(type) { case *TCPAddr: la, _ := la.(*TCPAddr) - c, err = testHookDialTCP(ctx.network, la, ra, deadline, ctx.Cancel) + c, err = testHookDialTCP(ctx.network, la, ra, deadline, cancel) case *UDPAddr: la, _ := la.(*UDPAddr) c, err = dialUDP(ctx.network, la, ra, deadline) diff --git a/src/net/dial_test.go b/src/net/dial_test.go index 1a9dfb26d3..1df923f14b 100644 --- a/src/net/dial_test.go +++ b/src/net/dial_test.go @@ -228,9 +228,8 @@ func TestDialerDualStackFDLeak(t *testing.T) { // expected to hang until the timeout elapses. These addresses are reserved // for benchmarking by RFC 6890. const ( - slowDst4 = "192.18.0.254" - slowDst6 = "2001:2::254" - slowTimeout = 1 * time.Second + slowDst4 = "198.18.0.254" + slowDst6 = "2001:2::254" ) // In some environments, the slow IPs may be explicitly unreachable, and fail @@ -239,7 +238,10 @@ const ( func slowDialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) { c, err := dialTCP(net, laddr, raddr, deadline, cancel) if ParseIP(slowDst4).Equal(raddr.IP) || ParseIP(slowDst6).Equal(raddr.IP) { - time.Sleep(deadline.Sub(time.Now())) + select { + case <-cancel: + case <-time.After(deadline.Sub(time.Now())): + } } return c, err } @@ -283,6 +285,9 @@ func TestDialParallel(t *testing.T) { if !supportsIPv4 || !supportsIPv6 { t.Skip("both IPv4 and IPv6 are required") } + if runtime.GOOS == "plan9" { + t.Skip("skipping on plan9; cannot cancel dialTCP, golang.org/issue/11225") + } closedPortDelay, expectClosedPortDelay := dialClosedPort() if closedPortDelay > expectClosedPortDelay { @@ -388,7 +393,6 @@ func TestDialParallel(t *testing.T) { fallbacks := makeAddrs(tt.fallbacks, dss.port) d := Dialer{ FallbackDelay: fallbackDelay, - Timeout: slowTimeout, } ctx := &dialContext{ Dialer: d, @@ -397,7 +401,7 @@ func TestDialParallel(t *testing.T) { finalDeadline: d.deadline(time.Now()), } startTime := time.Now() - c, err := dialParallel(ctx, primaries, fallbacks) + c, err := dialParallel(ctx, primaries, fallbacks, nil) elapsed := time.Now().Sub(startTime) if c != nil { @@ -417,9 +421,27 @@ func TestDialParallel(t *testing.T) { } else if !(elapsed <= expectElapsedMax) { t.Errorf("#%d: got %v; want <= %v", i, elapsed, expectElapsedMax) } + + // Repeat each case, ensuring that it can be canceled quickly. + cancel := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) + go func() { + time.Sleep(5 * time.Millisecond) + close(cancel) + wg.Done() + }() + startTime = time.Now() + c, err = dialParallel(ctx, primaries, fallbacks, cancel) + if c != nil { + c.Close() + } + elapsed = time.Now().Sub(startTime) + if elapsed > 100*time.Millisecond { + t.Errorf("#%d (cancel): got %v; want <= 100ms", i, elapsed) + } + wg.Wait() } - // Wait for any slowDst4/slowDst6 connections to timeout. - time.Sleep(slowTimeout * 3 / 2) } func lookupSlowFast(fn func(string) ([]IPAddr, error), host string) ([]IPAddr, error) { @@ -462,8 +484,6 @@ func TestDialerFallbackDelay(t *testing.T) { {true, 200 * time.Millisecond, 200 * time.Millisecond}, // The default is 300ms. {true, 0, 300 * time.Millisecond}, - // This case is last, in order to wait for hanging slowDst6 connections. - {false, 0, slowTimeout}, } handler := func(dss *dualStackServer, ln Listener) { @@ -487,7 +507,7 @@ func TestDialerFallbackDelay(t *testing.T) { } for i, tt := range testCases { - d := &Dialer{DualStack: tt.dualstack, FallbackDelay: tt.delay, Timeout: slowTimeout} + d := &Dialer{DualStack: tt.dualstack, FallbackDelay: tt.delay} startTime := time.Now() c, err := d.Dial("tcp", JoinHostPort("slow6loopback4", dss.port)) @@ -508,17 +528,58 @@ func TestDialerFallbackDelay(t *testing.T) { } } -func TestDialSerialAsyncSpuriousConnection(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping on plan9; no deadline support, golang.org/issue/11932") +func TestDialParallelSpuriousConnection(t *testing.T) { + if !supportsIPv4 || !supportsIPv6 { + t.Skip("both IPv4 and IPv6 are required") } - ln, err := newLocalListener("tcp") + if runtime.GOOS == "plan9" { + t.Skip("skipping on plan9; cannot cancel dialTCP, golang.org/issue/11225") + } + + var wg sync.WaitGroup + wg.Add(2) + handler := func(dss *dualStackServer, ln Listener) { + // Accept one connection per address. + c, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + // The client should close itself, without sending data. + c.SetReadDeadline(time.Now().Add(1 * time.Second)) + var b [1]byte + if _, err := c.Read(b[:]); err != io.EOF { + t.Errorf("got %v; want %v", err, io.EOF) + } + c.Close() + wg.Done() + } + dss, err := newDualStackServer([]streamListener{ + {network: "tcp4", address: "127.0.0.1"}, + {network: "tcp6", address: "::1"}, + }) if err != nil { t.Fatal(err) } - defer ln.Close() + defer dss.teardown() + if err := dss.buildup(handler); err != nil { + t.Fatal(err) + } - d := Dialer{} + const fallbackDelay = 100 * time.Millisecond + + origTestHookDialTCP := testHookDialTCP + defer func() { testHookDialTCP = origTestHookDialTCP }() + testHookDialTCP = func(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) { + // Sleep long enough for Happy Eyeballs to kick in, and inhibit cancelation. + // This forces dialParallel to juggle two successful connections. + time.Sleep(fallbackDelay * 2) + cancel = nil + return dialTCP(net, laddr, raddr, deadline, cancel) + } + + d := Dialer{ + FallbackDelay: fallbackDelay, + } ctx := &dialContext{ Dialer: d, network: "tcp", @@ -526,28 +587,23 @@ func TestDialSerialAsyncSpuriousConnection(t *testing.T) { finalDeadline: d.deadline(time.Now()), } - results := make(chan dialResult) - cancel := make(chan struct{}) + makeAddr := func(ip string) addrList { + addr, err := ResolveTCPAddr("tcp", JoinHostPort(ip, dss.port)) + if err != nil { + t.Fatal(err) + } + return addrList{addr} + } - // Spawn a connection in the background. - go dialSerialAsync(ctx, addrList{ln.Addr()}, nil, cancel, results) - - // Receive it at the server. - c, err := ln.Accept() + // dialParallel returns one connection (and closes the other.) + c, err := dialParallel(ctx, makeAddr("127.0.0.1"), makeAddr("::1"), nil) if err != nil { t.Fatal(err) } - defer c.Close() + c.Close() - // Tell dialSerialAsync that someone else won the race. - close(cancel) - - // The connection should close itself, without sending data. - c.SetReadDeadline(time.Now().Add(1 * time.Second)) - var b [1]byte - if _, err := c.Read(b[:]); err != io.EOF { - t.Errorf("got %v; want %v", err, io.EOF) - } + // The server should've seen both connections. + wg.Wait() } func TestDialerPartialDeadline(t *testing.T) { @@ -676,7 +732,6 @@ func TestDialerDualStack(t *testing.T) { c.Close() } } - time.Sleep(timeout * 3 / 2) // wait for the dial racers to stop } func TestDialerKeepAlive(t *testing.T) {