diff --git a/src/net/dial.go b/src/net/dial.go index 0461ab12ca..e243f45ba3 100644 --- a/src/net/dial.go +++ b/src/net/dial.go @@ -437,21 +437,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn primaries = addrs } - c, err := sd.dialParallel(ctx, primaries, fallbacks) - if err != nil { - return nil, err - } - - if tc, ok := c.(*TCPConn); ok && d.KeepAlive >= 0 { - setKeepAlive(tc.fd, true) - ka := d.KeepAlive - if d.KeepAlive == 0 { - ka = defaultTCPKeepAlive - } - setKeepAlivePeriod(tc.fd, ka) - testHookSetKeepAlive(ka) - } - return c, nil + return sd.dialParallel(ctx, primaries, fallbacks) } // dialParallel races two copies of dialSerial, giving the first a diff --git a/src/net/file_plan9.go b/src/net/file_plan9.go index dfb23d2e84..64aabf93ee 100644 --- a/src/net/file_plan9.go +++ b/src/net/file_plan9.go @@ -100,7 +100,7 @@ func fileConn(f *os.File) (Conn, error) { switch fd.laddr.(type) { case *TCPAddr: - return newTCPConn(fd), nil + return newTCPConn(fd, defaultTCPKeepAlive, testHookSetKeepAlive), nil case *UDPAddr: return newUDPConn(fd), nil } diff --git a/src/net/file_unix.go b/src/net/file_unix.go index 0df67db501..8b9fc38916 100644 --- a/src/net/file_unix.go +++ b/src/net/file_unix.go @@ -74,7 +74,7 @@ func fileConn(f *os.File) (Conn, error) { } switch fd.laddr.(type) { case *TCPAddr: - return newTCPConn(fd), nil + return newTCPConn(fd, defaultTCPKeepAlive, testHookSetKeepAlive), nil case *UDPAddr: return newUDPConn(fd), nil case *IPAddr: diff --git a/src/net/tcpsock.go b/src/net/tcpsock.go index 6bad0e8f8b..672170e681 100644 --- a/src/net/tcpsock.go +++ b/src/net/tcpsock.go @@ -217,10 +217,19 @@ func (c *TCPConn) SetNoDelay(noDelay bool) error { return nil } -func newTCPConn(fd *netFD) *TCPConn { - c := &TCPConn{conn{fd}} - setNoDelay(c.fd, true) - return c +func newTCPConn(fd *netFD, keepAlive time.Duration, keepAliveHook func(time.Duration)) *TCPConn { + setNoDelay(fd, true) + if keepAlive == 0 { + keepAlive = defaultTCPKeepAlive + } + if keepAlive > 0 { + setKeepAlive(fd, true) + setKeepAlivePeriod(fd, keepAlive) + if keepAliveHook != nil { + keepAliveHook(keepAlive) + } + } + return &TCPConn{conn{fd}} } // DialTCP acts like Dial for TCP networks. diff --git a/src/net/tcpsock_plan9.go b/src/net/tcpsock_plan9.go index 435335e92e..d55948f69e 100644 --- a/src/net/tcpsock_plan9.go +++ b/src/net/tcpsock_plan9.go @@ -42,7 +42,7 @@ func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCP if err != nil { return nil, err } - return newTCPConn(fd), nil + return newTCPConn(fd, sd.Dialer.KeepAlive, testHookSetKeepAlive), nil } func (ln *TCPListener) ok() bool { return ln != nil && ln.fd != nil && ln.fd.ctl != nil } @@ -52,16 +52,7 @@ func (ln *TCPListener) accept() (*TCPConn, error) { if err != nil { return nil, err } - tc := newTCPConn(fd) - if ln.lc.KeepAlive >= 0 { - setKeepAlive(fd, true) - ka := ln.lc.KeepAlive - if ln.lc.KeepAlive == 0 { - ka = defaultTCPKeepAlive - } - setKeepAlivePeriod(fd, ka) - } - return tc, nil + return newTCPConn(fd, ln.lc.KeepAlive, nil), nil } func (ln *TCPListener) close() error { diff --git a/src/net/tcpsock_posix.go b/src/net/tcpsock_posix.go index 463b456173..0b3fa1ae0c 100644 --- a/src/net/tcpsock_posix.go +++ b/src/net/tcpsock_posix.go @@ -107,7 +107,7 @@ func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCP if err != nil { return nil, err } - return newTCPConn(fd), nil + return newTCPConn(fd, sd.Dialer.KeepAlive, testHookSetKeepAlive), nil } func selfConnect(fd *netFD, err error) bool { @@ -149,16 +149,7 @@ func (ln *TCPListener) accept() (*TCPConn, error) { if err != nil { return nil, err } - tc := newTCPConn(fd) - if ln.lc.KeepAlive >= 0 { - setKeepAlive(fd, true) - ka := ln.lc.KeepAlive - if ln.lc.KeepAlive == 0 { - ka = defaultTCPKeepAlive - } - setKeepAlivePeriod(fd, ka) - } - return tc, nil + return newTCPConn(fd, ln.lc.KeepAlive, nil), nil } func (ln *TCPListener) close() error { diff --git a/src/net/tcpsock_test.go b/src/net/tcpsock_test.go index ae65788a73..990d34706f 100644 --- a/src/net/tcpsock_test.go +++ b/src/net/tcpsock_test.go @@ -808,3 +808,22 @@ func BenchmarkSetReadDeadline(b *testing.B) { deadline = deadline.Add(1) } } + +func TestDialTCPDefaultKeepAlive(t *testing.T) { + ln := newLocalListener(t, "tcp") + defer ln.Close() + + got := time.Duration(-1) + testHookSetKeepAlive = func(d time.Duration) { got = d } + defer func() { testHookSetKeepAlive = func(time.Duration) {} }() + + c, err := DialTCP("tcp", nil, ln.Addr().(*TCPAddr)) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + if got != defaultTCPKeepAlive { + t.Errorf("got keepalive %v; want %v", got, defaultTCPKeepAlive) + } +}