net: unify TCP keepalive behavior

CL 107196 introduced a default TCP keepalive interval for Dialer and
TCPListener (used by both ListenConfig and ListenTCP). Leaving DialTCP
out was likely an oversight.

DialTCP's documentation says it "acts like Dial". Therefore it's natural
to also expect DialTCP to enable TCP keepalive by default.

This commit addresses this disparity by moving the enablement logic down
to the newTCPConn function, which is used by both dialer and listener.

Fixes #49345
This commit is contained in:
database64128 2022-11-04 16:40:08 +08:00
parent a11cd6f69a
commit 5c2f1cb0fb
No known key found for this signature in database
GPG Key ID: 1CA27546BEDB8B01
7 changed files with 39 additions and 43 deletions

View File

@ -437,21 +437,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
primaries = addrs
}
c, err := sd.dialParallel(ctx, primaries, fallbacks)
if err != nil {
return nil, err
}
if tc, ok := c.(*TCPConn); ok && d.KeepAlive >= 0 {
setKeepAlive(tc.fd, true)
ka := d.KeepAlive
if d.KeepAlive == 0 {
ka = defaultTCPKeepAlive
}
setKeepAlivePeriod(tc.fd, ka)
testHookSetKeepAlive(ka)
}
return c, nil
return sd.dialParallel(ctx, primaries, fallbacks)
}
// dialParallel races two copies of dialSerial, giving the first a

View File

@ -100,7 +100,7 @@ func fileConn(f *os.File) (Conn, error) {
switch fd.laddr.(type) {
case *TCPAddr:
return newTCPConn(fd), nil
return newTCPConn(fd, defaultTCPKeepAlive, testHookSetKeepAlive), nil
case *UDPAddr:
return newUDPConn(fd), nil
}

View File

@ -74,7 +74,7 @@ func fileConn(f *os.File) (Conn, error) {
}
switch fd.laddr.(type) {
case *TCPAddr:
return newTCPConn(fd), nil
return newTCPConn(fd, defaultTCPKeepAlive, testHookSetKeepAlive), nil
case *UDPAddr:
return newUDPConn(fd), nil
case *IPAddr:

View File

@ -217,10 +217,19 @@ func (c *TCPConn) SetNoDelay(noDelay bool) error {
return nil
}
func newTCPConn(fd *netFD) *TCPConn {
c := &TCPConn{conn{fd}}
setNoDelay(c.fd, true)
return c
func newTCPConn(fd *netFD, keepAlive time.Duration, keepAliveHook func(time.Duration)) *TCPConn {
setNoDelay(fd, true)
if keepAlive == 0 {
keepAlive = defaultTCPKeepAlive
}
if keepAlive > 0 {
setKeepAlive(fd, true)
setKeepAlivePeriod(fd, keepAlive)
if keepAliveHook != nil {
keepAliveHook(keepAlive)
}
}
return &TCPConn{conn{fd}}
}
// DialTCP acts like Dial for TCP networks.

View File

@ -42,7 +42,7 @@ func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCP
if err != nil {
return nil, err
}
return newTCPConn(fd), nil
return newTCPConn(fd, sd.Dialer.KeepAlive, testHookSetKeepAlive), nil
}
func (ln *TCPListener) ok() bool { return ln != nil && ln.fd != nil && ln.fd.ctl != nil }
@ -52,16 +52,7 @@ func (ln *TCPListener) accept() (*TCPConn, error) {
if err != nil {
return nil, err
}
tc := newTCPConn(fd)
if ln.lc.KeepAlive >= 0 {
setKeepAlive(fd, true)
ka := ln.lc.KeepAlive
if ln.lc.KeepAlive == 0 {
ka = defaultTCPKeepAlive
}
setKeepAlivePeriod(fd, ka)
}
return tc, nil
return newTCPConn(fd, ln.lc.KeepAlive, nil), nil
}
func (ln *TCPListener) close() error {

View File

@ -107,7 +107,7 @@ func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCP
if err != nil {
return nil, err
}
return newTCPConn(fd), nil
return newTCPConn(fd, sd.Dialer.KeepAlive, testHookSetKeepAlive), nil
}
func selfConnect(fd *netFD, err error) bool {
@ -149,16 +149,7 @@ func (ln *TCPListener) accept() (*TCPConn, error) {
if err != nil {
return nil, err
}
tc := newTCPConn(fd)
if ln.lc.KeepAlive >= 0 {
setKeepAlive(fd, true)
ka := ln.lc.KeepAlive
if ln.lc.KeepAlive == 0 {
ka = defaultTCPKeepAlive
}
setKeepAlivePeriod(fd, ka)
}
return tc, nil
return newTCPConn(fd, ln.lc.KeepAlive, nil), nil
}
func (ln *TCPListener) close() error {

View File

@ -808,3 +808,22 @@ func BenchmarkSetReadDeadline(b *testing.B) {
deadline = deadline.Add(1)
}
}
func TestDialTCPDefaultKeepAlive(t *testing.T) {
ln := newLocalListener(t, "tcp")
defer ln.Close()
got := time.Duration(-1)
testHookSetKeepAlive = func(d time.Duration) { got = d }
defer func() { testHookSetKeepAlive = func(time.Duration) {} }()
c, err := DialTCP("tcp", nil, ln.Addr().(*TCPAddr))
if err != nil {
t.Fatal(err)
}
defer c.Close()
if got != defaultTCPKeepAlive {
t.Errorf("got keepalive %v; want %v", got, defaultTCPKeepAlive)
}
}