diff --git a/src/net/dial.go b/src/net/dial.go index b24bd2f5f4..c538342566 100644 --- a/src/net/dial.go +++ b/src/net/dial.go @@ -341,6 +341,7 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) { type sysDialer struct { Dialer network, address string + testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) } // Dial connects to the address on the named network. diff --git a/src/net/dial_test.go b/src/net/dial_test.go index 0550acb01d..e49b4a61d6 100644 --- a/src/net/dial_test.go +++ b/src/net/dial_test.go @@ -234,9 +234,7 @@ func TestDialParallel(t *testing.T) { for i, tt := range testCases { i, tt := i, tt t.Run(fmt.Sprint(i), func(t *testing.T) { - origTestHookDialTCP := testHookDialTCP - defer func() { testHookDialTCP = origTestHookDialTCP }() - testHookDialTCP = func(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConn, error) { + dialTCP := func(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConn, error) { n := "tcp6" if raddr.IP.To4() != nil { n = "tcp4" @@ -262,9 +260,10 @@ func TestDialParallel(t *testing.T) { } startTime := time.Now() sd := &sysDialer{ - Dialer: d, - network: "tcp", - address: "?", + Dialer: d, + network: "tcp", + address: "?", + testHookDialTCP: dialTCP, } c, err := sd.dialParallel(context.Background(), primaries, fallbacks) elapsed := time.Since(startTime) diff --git a/src/net/tcpsock_plan9.go b/src/net/tcpsock_plan9.go index 768d03b06c..435335e92e 100644 --- a/src/net/tcpsock_plan9.go +++ b/src/net/tcpsock_plan9.go @@ -15,8 +15,11 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) { } func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) { - if testHookDialTCP != nil { - return testHookDialTCP(ctx, sd.network, laddr, raddr) + if h := sd.testHookDialTCP; h != nil { + return h(ctx, sd.network, laddr, raddr) + } + if h := testHookDialTCP; h != nil { + return h(ctx, sd.network, laddr, raddr) } return sd.doDialTCP(ctx, laddr, raddr) } diff --git a/src/net/tcpsock_posix.go b/src/net/tcpsock_posix.go index bc3d324e6b..1c91170c50 100644 --- a/src/net/tcpsock_posix.go +++ b/src/net/tcpsock_posix.go @@ -55,8 +55,11 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) { } func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) { - if testHookDialTCP != nil { - return testHookDialTCP(ctx, sd.network, laddr, raddr) + if h := sd.testHookDialTCP; h != nil { + return h(ctx, sd.network, laddr, raddr) + } + if h := testHookDialTCP; h != nil { + return h(ctx, sd.network, laddr, raddr) } return sd.doDialTCP(ctx, laddr, raddr) }