net: unify TCP keepalive behavior

CL 107196 introduced a default TCP keepalive interval for Dialer and TCPListener (used by both ListenConfig and ListenTCP). Leaving DialTCP out was likely an oversight.

DialTCP's documentation says it "acts like Dial". Therefore it's natural to also expect DialTCP to enable TCP keepalive by default.

This commit addresses this disparity by moving the enablement logic down to the newTCPConn function, which is used by both dialer and listener.

Fixes #49345

Change-Id: I99c08b161c468ed0b993d1dbd2bd0d7e803f3826
GitHub-Last-Rev: 5c2f1cb0fb
GitHub-Pull-Request: golang/go#56565
Reviewed-on: https://go-review.googlesource.com/c/go/+/447917
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Ian Lance Taylor <iant@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>
This commit is contained in:
database64128 2022-11-10 08:20:29 +00:00 committed by Gopher Robot
parent 531ba0c8aa
commit fbf763fd1d
7 changed files with 39 additions and 43 deletions

View File

@ -437,21 +437,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
primaries = addrs
}
c, err := sd.dialParallel(ctx, primaries, fallbacks)
if err != nil {
return nil, err
}
if tc, ok := c.(*TCPConn); ok && d.KeepAlive >= 0 {
setKeepAlive(tc.fd, true)
ka := d.KeepAlive
if d.KeepAlive == 0 {
ka = defaultTCPKeepAlive
}
setKeepAlivePeriod(tc.fd, ka)
testHookSetKeepAlive(ka)
}
return c, nil
return sd.dialParallel(ctx, primaries, fallbacks)
}
// dialParallel races two copies of dialSerial, giving the first a

View File

@ -100,7 +100,7 @@ func fileConn(f *os.File) (Conn, error) {
switch fd.laddr.(type) {
case *TCPAddr:
return newTCPConn(fd), nil
return newTCPConn(fd, defaultTCPKeepAlive, testHookSetKeepAlive), nil
case *UDPAddr:
return newUDPConn(fd), nil
}

View File

@ -74,7 +74,7 @@ func fileConn(f *os.File) (Conn, error) {
}
switch fd.laddr.(type) {
case *TCPAddr:
return newTCPConn(fd), nil
return newTCPConn(fd, defaultTCPKeepAlive, testHookSetKeepAlive), nil
case *UDPAddr:
return newUDPConn(fd), nil
case *IPAddr:

View File

@ -217,10 +217,19 @@ func (c *TCPConn) SetNoDelay(noDelay bool) error {
return nil
}
func newTCPConn(fd *netFD) *TCPConn {
c := &TCPConn{conn{fd}}
setNoDelay(c.fd, true)
return c
func newTCPConn(fd *netFD, keepAlive time.Duration, keepAliveHook func(time.Duration)) *TCPConn {
setNoDelay(fd, true)
if keepAlive == 0 {
keepAlive = defaultTCPKeepAlive
}
if keepAlive > 0 {
setKeepAlive(fd, true)
setKeepAlivePeriod(fd, keepAlive)
if keepAliveHook != nil {
keepAliveHook(keepAlive)
}
}
return &TCPConn{conn{fd}}
}
// DialTCP acts like Dial for TCP networks.

View File

@ -42,7 +42,7 @@ func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCP
if err != nil {
return nil, err
}
return newTCPConn(fd), nil
return newTCPConn(fd, sd.Dialer.KeepAlive, testHookSetKeepAlive), nil
}
func (ln *TCPListener) ok() bool { return ln != nil && ln.fd != nil && ln.fd.ctl != nil }
@ -52,16 +52,7 @@ func (ln *TCPListener) accept() (*TCPConn, error) {
if err != nil {
return nil, err
}
tc := newTCPConn(fd)
if ln.lc.KeepAlive >= 0 {
setKeepAlive(fd, true)
ka := ln.lc.KeepAlive
if ln.lc.KeepAlive == 0 {
ka = defaultTCPKeepAlive
}
setKeepAlivePeriod(fd, ka)
}
return tc, nil
return newTCPConn(fd, ln.lc.KeepAlive, nil), nil
}
func (ln *TCPListener) close() error {

View File

@ -107,7 +107,7 @@ func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCP
if err != nil {
return nil, err
}
return newTCPConn(fd), nil
return newTCPConn(fd, sd.Dialer.KeepAlive, testHookSetKeepAlive), nil
}
func selfConnect(fd *netFD, err error) bool {
@ -149,16 +149,7 @@ func (ln *TCPListener) accept() (*TCPConn, error) {
if err != nil {
return nil, err
}
tc := newTCPConn(fd)
if ln.lc.KeepAlive >= 0 {
setKeepAlive(fd, true)
ka := ln.lc.KeepAlive
if ln.lc.KeepAlive == 0 {
ka = defaultTCPKeepAlive
}
setKeepAlivePeriod(fd, ka)
}
return tc, nil
return newTCPConn(fd, ln.lc.KeepAlive, nil), nil
}
func (ln *TCPListener) close() error {

View File

@ -808,3 +808,22 @@ func BenchmarkSetReadDeadline(b *testing.B) {
deadline = deadline.Add(1)
}
}
func TestDialTCPDefaultKeepAlive(t *testing.T) {
ln := newLocalListener(t, "tcp")
defer ln.Close()
got := time.Duration(-1)
testHookSetKeepAlive = func(d time.Duration) { got = d }
defer func() { testHookSetKeepAlive = func(time.Duration) {} }()
c, err := DialTCP("tcp", nil, ln.Addr().(*TCPAddr))
if err != nil {
t.Fatal(err)
}
defer c.Close()
if got != defaultTCPKeepAlive {
t.Errorf("got keepalive %v; want %v", got, defaultTCPKeepAlive)
}
}