net/http: don't cancel Dials when requests are canceled

Currently, when a Transport creates a new connection for a request,
it uses the request's Context to make the Dial. If a request
times out or is canceled before a Dial completes, the Dial is
canceled.

Change this so that the lifetime of a Dial call is not bound
by the request that originated it.

This change avoids a scenario where a Transport can start and
then cancel many Dial calls in rapid succession:

  - Request starts a Dial.
  - A previous request completes, making its connection available.
  - The new request uses the now-idle connection, and completes.
  - The request Context is canceled, and the Dial is aborted.

Fixes #59017

Change-Id: I996ffabc56d3b1b43129cbfd9b3e9ea7d53d263c
Reviewed-on: https://go-review.googlesource.com/c/go/+/576555
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
This commit is contained in:
Damien Neil 2024-04-04 11:01:28 -07:00
parent 4742c52e10
commit 334ce51004
5 changed files with 371 additions and 58 deletions

View File

@ -1938,21 +1938,25 @@ func TestClientCloseIdleConnections(t *testing.T) {
}
}
// testRoundTripper adapts an ordinary function so that it can be installed
// as a Client's Transport in tests.
type testRoundTripper func(*Request) (*Response, error)

// RoundTrip invokes the wrapped function with req and returns its result.
func (fn testRoundTripper) RoundTrip(req *Request) (*Response, error) {
	return fn(req)
}
func TestClientPropagatesTimeoutToContext(t *testing.T) {
errDial := errors.New("not actually dialing")
c := &Client{
Timeout: 5 * time.Second,
Transport: &Transport{
DialContext: func(ctx context.Context, netw, addr string) (net.Conn, error) {
deadline, ok := ctx.Deadline()
if !ok {
t.Error("no deadline")
} else {
t.Logf("deadline in %v", deadline.Sub(time.Now()).Round(time.Second/10))
}
return nil, errDial
},
},
Transport: testRoundTripper(func(req *Request) (*Response, error) {
ctx := req.Context()
deadline, ok := ctx.Deadline()
if !ok {
t.Error("no deadline")
} else {
t.Logf("deadline in %v", deadline.Sub(time.Now()).Round(time.Second/10))
}
return nil, errors.New("not actually making a request")
}),
}
c.Get("https://example.tld/")
}

View File

@ -86,6 +86,14 @@ func SetPendingDialHooks(before, after func()) {
// SetTestHookServerServe sets the testHookServerServe hook to fn.
func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
// SetTestHookProxyConnectTimeout replaces testHookProxyConnectTimeout with f.
// The hook that was installed before the call is put back by a cleanup
// function registered on t.
func SetTestHookProxyConnectTimeout(t *testing.T, f func(context.Context, time.Duration) (context.Context, context.CancelFunc)) {
	prev := testHookProxyConnectTimeout
	testHookProxyConnectTimeout = f
	t.Cleanup(func() { testHookProxyConnectTimeout = prev })
}
func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler {
return &timeoutHandler{
handler: handler,

View File

@ -108,6 +108,7 @@ type Transport struct {
connsPerHostMu sync.Mutex
connsPerHost map[connectMethodKey]int
connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns
dialsInProgress wantConnQueue
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
@ -807,6 +808,13 @@ func (t *Transport) CloseIdleConnections() {
pconn.close(errCloseIdleConns)
}
}
t.connsPerHostMu.Lock()
t.dialsInProgress.all(func(w *wantConn) {
if w.cancelCtx != nil && !w.waiting() {
w.cancelCtx()
}
})
t.connsPerHostMu.Unlock()
if t2 := t.h2transport; t2 != nil {
t2.CloseIdleConnections()
}
@ -1116,7 +1124,7 @@ func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) {
t.idleConnWait = make(map[connectMethodKey]wantConnQueue)
}
q := t.idleConnWait[w.key]
q.cleanFront()
q.cleanFrontNotWaiting()
q.pushBack(w)
t.idleConnWait[w.key] = q
return false
@ -1230,10 +1238,11 @@ type wantConn struct {
beforeDial func()
afterDial func()
mu sync.Mutex // protects ctx, done and sending of the result
ctx context.Context // context for dial, cleared after delivered or canceled
done bool // true after delivered or canceled
result chan connOrError // channel to deliver connection or error
mu sync.Mutex // protects ctx, done and sending of the result
ctx context.Context // context for dial, cleared after delivered or canceled
cancelCtx context.CancelFunc
done bool // true after delivered or canceled
result chan connOrError // channel to deliver connection or error
}
type connOrError struct {
@ -1352,9 +1361,9 @@ func (q *wantConnQueue) peekFront() *wantConn {
return nil
}
// cleanFront pops any wantConns that are no longer waiting from the head of the
// cleanFrontNotWaiting pops any wantConns that are no longer waiting from the head of the
// queue, reporting whether any were popped.
func (q *wantConnQueue) cleanFront() (cleaned bool) {
func (q *wantConnQueue) cleanFrontNotWaiting() (cleaned bool) {
for {
w := q.peekFront()
if w == nil || w.waiting() {
@ -1365,6 +1374,28 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) {
}
}
// cleanFrontCanceled discards wantConns from the head of the queue for as
// long as the head entry has a nil cancelCtx. It stops at the first entry
// that still holds a cancel function, or when the queue is empty.
func (q *wantConnQueue) cleanFrontCanceled() {
	for front := q.peekFront(); front != nil && front.cancelCtx == nil; front = q.peekFront() {
		q.popFront()
	}
}
// all calls f once for every wantConn in the queue: first the unconsumed
// part of head (from headPos onward), then tail, each in slice order.
// The caller must not modify the queue while iterating.
func (q *wantConnQueue) all(f func(*wantConn)) {
	front := q.head[q.headPos:]
	for i := 0; i < len(front); i++ {
		f(front[i])
	}
	back := q.tail
	for i := 0; i < len(back); i++ {
		f(back[i])
	}
}
func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) {
if t.DialTLSContext != nil {
conn, err = t.DialTLSContext(ctx, network, addr)
@ -1389,10 +1420,18 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis
trace.GetConn(cm.addr())
}
// Detach from the request context's cancellation signal.
// The dial should proceed even if the request is canceled,
// because a future request may be able to make use of the connection.
//
// We retain the request context's values.
dialCtx, dialCancel := context.WithCancel(context.WithoutCancel(ctx))
w := &wantConn{
cm: cm,
key: cm.key(),
ctx: ctx,
ctx: dialCtx,
cancelCtx: dialCancel,
result: make(chan connOrError, 1),
beforeDial: testHookPrePendingDial,
afterDial: testHookPostPendingDial,
@ -1470,20 +1509,21 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis
// Once w receives permission to dial, it will do so in a separate goroutine.
func (t *Transport) queueForDial(w *wantConn) {
w.beforeDial()
if t.MaxConnsPerHost <= 0 {
go t.dialConnFor(w)
return
}
t.connsPerHostMu.Lock()
defer t.connsPerHostMu.Unlock()
if t.MaxConnsPerHost <= 0 {
t.startDialConnForLocked(w)
return
}
if n := t.connsPerHost[w.key]; n < t.MaxConnsPerHost {
if t.connsPerHost == nil {
t.connsPerHost = make(map[connectMethodKey]int)
}
t.connsPerHost[w.key] = n + 1
go t.dialConnFor(w)
t.startDialConnForLocked(w)
return
}
@ -1491,11 +1531,24 @@ func (t *Transport) queueForDial(w *wantConn) {
t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue)
}
q := t.connsPerHostWait[w.key]
q.cleanFront()
q.cleanFrontNotWaiting()
q.pushBack(w)
t.connsPerHostWait[w.key] = q
}
// startDialConnForLocked records w in t.dialsInProgress and then runs
// dialConnFor(w) in a new goroutine.
// t.connsPerHostMu must be held.
func (t *Transport) startDialConnForLocked(w *wantConn) {
// Drop entries at the front of the queue whose cancelCtx has already
// been cleared before adding w.
t.dialsInProgress.cleanFrontCanceled()
t.dialsInProgress.pushBack(w)
go func() {
t.dialConnFor(w)
// Once dialConnFor has returned, clear w.cancelCtx while holding
// connsPerHostMu.
t.connsPerHostMu.Lock()
defer t.connsPerHostMu.Unlock()
w.cancelCtx = nil
}()
}
// dialConnFor dials on behalf of w and delivers the result to w.
// dialConnFor has received permission to dial w.cm and is counted in t.connCount[w.cm.key()].
// If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()].
@ -1545,7 +1598,7 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) {
for q.len() > 0 {
w := q.popFront()
if w.waiting() {
go t.dialConnFor(w)
t.startDialConnForLocked(w)
done = true
break
}
@ -1626,6 +1679,8 @@ type erringRoundTripper interface {
RoundTripErr() error
}
// testHookProxyConnectTimeout creates the context that dialConn uses to bound
// the proxy CONNECT exchange. It is context.WithTimeout unless a test has
// replaced it.
var testHookProxyConnectTimeout = context.WithTimeout
func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) {
pconn = &persistConn{
t: t,
@ -1742,17 +1797,11 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
Header: hdr,
}
// If there's no done channel (no deadline or cancellation
// from the caller possible), at least set some (long)
// timeout here. This will make sure we don't block forever
// and leak a goroutine if the connection stops replying
// after the TCP connect.
connectCtx := ctx
if ctx.Done() == nil {
newCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()
connectCtx = newCtx
}
// Set a (long) timeout here to make sure we don't block forever
// and leak a goroutine if the connection stops replying after
// the TCP connect.
connectCtx, cancel := testHookProxyConnectTimeout(ctx, 1*time.Minute)
defer cancel()
didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails
var (

View File

@ -0,0 +1,235 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http_test
import (
"context"
"io"
"net"
"net/http"
"net/http/httptrace"
"testing"
)
// TestTransportPoolConnReusePriorConnection checks that a request started
// after an earlier request has finished is served on the earlier request's
// connection.
func TestTransportPoolConnReusePriorConnection(t *testing.T) {
dt := newTransportDialTester(t, http1Mode)
// First request creates a new connection.
rt1 := dt.roundTrip()
c1 := dt.wantDial()
c1.finish(nil)
rt1.wantDone(c1)
rt1.finish()
// Second request reuses the first connection.
// No wantDial call is made here: the request is expected to complete on c1.
rt2 := dt.roundTrip()
rt2.wantDone(c1)
rt2.finish()
}
// TestTransportPoolConnCannotReuseConnectionInUse checks that a request
// started while another request still occupies the only connection is served
// on a newly dialed connection.
func TestTransportPoolConnCannotReuseConnectionInUse(t *testing.T) {
dt := newTransportDialTester(t, http1Mode)
// First request creates a new connection.
// rt1.finish is deliberately not called, so c1 stays in use.
rt1 := dt.roundTrip()
c1 := dt.wantDial()
c1.finish(nil)
rt1.wantDone(c1)
// Second request is made while the first request is still using its connection,
// so it goes on a new connection.
rt2 := dt.roundTrip()
c2 := dt.wantDial()
c2.finish(nil)
rt2.wantDone(c2)
}
// TestTransportPoolConnConnectionBecomesAvailableDuringDial checks that a
// request waiting on its own Dial is served on an existing connection that
// becomes free in the meantime, and that the still-pending Dial can later be
// used by a subsequent request.
func TestTransportPoolConnConnectionBecomesAvailableDuringDial(t *testing.T) {
dt := newTransportDialTester(t, http1Mode)
// First request creates a new connection.
rt1 := dt.roundTrip()
c1 := dt.wantDial()
c1.finish(nil)
rt1.wantDone(c1)
// Second request is made while the first request is still using its connection.
// The first connection completes while the second Dial is in progress, so the
// second request uses the first connection.
rt2 := dt.roundTrip()
c2 := dt.wantDial()
rt1.finish()
rt2.wantDone(c1)
// This section is a bit overfitted to the current Transport implementation:
// A third request starts. We have an in-progress dial that was started by rt2,
// but this new request (rt3) is going to ignore it and make a dial of its own.
// rt3 will use the first of these dials that completes.
rt3 := dt.roundTrip()
c3 := dt.wantDial()
c2.finish(nil)
rt3.wantDone(c2)
// Release the dial started for rt3 as well; the dial hook blocks until
// finish is called on its conn.
c3.finish(nil)
}
// A transportDialTester manages a test of a connection's Dials.
type transportDialTester struct {
t *testing.T
cst *clientServerTest // client/server pair whose Transport dials through the test hook
dials chan *transportDialTesterConn // each new conn is sent to this channel
roundTripCount int // number of RoundTrips started so far; source of roundTripID values
dialCount int // number of Dials received by wantDial so far; source of connID values
}
// A transportDialTesterRoundTrip is a RoundTrip made as part of a dial test.
type transportDialTesterRoundTrip struct {
t *testing.T
roundTripID int // distinguishes RoundTrips in logs
cancel context.CancelFunc // cancels the Request context
reqBody io.WriteCloser // write half of the Request.Body
finished bool // set by finish so that repeated calls are no-ops
done chan struct{} // closed when RoundTrip returns
res *http.Response // response returned by RoundTrip; read only after done is closed
err error // error returned by RoundTrip; read only after done is closed
conn *transportDialTesterConn // connection reported by the GotConn trace hook
}
// A transportDialTesterConn is a client connection created by the Transport as
// part of a dial test.
type transportDialTesterConn struct {
t *testing.T
connID int // distinguishes Dials in logs
ready chan error // sent on to complete the Dial
// Conn is the underlying network connection. It is set by the dial hook
// after the Dial has been released and net.Dial has succeeded.
net.Conn
}
// newTransportDialTester starts a client/server pair in the given mode whose
// Transport dials through a hook controlled by the returned tester.
//
// Every Dial is announced on dt.dials and then blocks until the test sends a
// value on the conn's ready channel: a non-nil error fails the Dial, nil lets
// it proceed to a real net.Dial.
func newTransportDialTester(t *testing.T, mode testMode) *transportDialTester {
	t.Helper()
	dt := &transportDialTester{
		t:     t,
		dials: make(chan *transportDialTesterConn),
	}

	handler := func(w http.ResponseWriter, r *http.Request) {
		// Write response headers when we receive a request.
		rc := http.NewResponseController(w)
		rc.EnableFullDuplex()
		w.WriteHeader(http.StatusOK)
		rc.Flush()
		// Wait for the client to send the request body,
		// to synchronize with the rest of the test.
		io.ReadAll(r.Body)
	}

	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
		tc := &transportDialTesterConn{
			t:     t,
			ready: make(chan error),
		}
		// Notify the test that a Dial has started,
		// and wait for the test to notify us that it should complete.
		dt.dials <- tc
		if dialErr := <-tc.ready; dialErr != nil {
			return nil, dialErr
		}
		underlying, err := net.Dial(network, address)
		if err != nil {
			return nil, err
		}
		// Use the *transportDialTesterConn as the net.Conn,
		// to let tests associate requests with connections.
		tc.Conn = underlying
		return tc, nil
	}

	dt.cst = newClientServerTest(t, mode, http.HandlerFunc(handler), func(tr *http.Transport) {
		tr.DialContext = dial
	})
	return dt
}
// roundTrip starts a RoundTrip.
// It returns immediately, without waiting for the RoundTrip call to complete.
//
// The request is a POST whose body is the read half of a pipe; the write half
// is kept in the returned value so that finish can end the request. A cleanup
// function cancels the request context and finishes the RoundTrip when the
// test ends.
//
// If the request cannot be constructed, the error is stored in the returned
// value's err field and done is closed, so wantDone reports it as a failure
// instead of the goroutine panicking on a nil request.
func (dt *transportDialTester) roundTrip() *transportDialTesterRoundTrip {
	dt.t.Helper()
	ctx, cancel := context.WithCancel(context.Background())
	pr, pw := io.Pipe()
	rt := &transportDialTesterRoundTrip{
		t:           dt.t,
		roundTripID: dt.roundTripCount,
		done:        make(chan struct{}),
		reqBody:     pw,
		cancel:      cancel,
	}
	dt.roundTripCount++
	dt.t.Logf("RoundTrip %v: started", rt.roundTripID)
	dt.t.Cleanup(func() {
		rt.cancel()
		rt.finish()
	})
	go func() {
		// done is closed last, after res/err are set and the result is logged.
		defer close(rt.done)
		ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
			GotConn: func(info httptrace.GotConnInfo) {
				rt.conn = info.Conn.(*transportDialTesterConn)
			},
		})
		req, err := http.NewRequestWithContext(ctx, "POST", dt.cst.ts.URL, pr)
		if err != nil {
			rt.err = err
			dt.t.Logf("RoundTrip %v: done (err:%v)", rt.roundTripID, rt.err)
			return
		}
		req.Header.Set("Content-Type", "text/plain")
		rt.res, rt.err = dt.cst.tr.RoundTrip(req)
		dt.t.Logf("RoundTrip %v: done (err:%v)", rt.roundTripID, rt.err)
	}()
	return rt
}
// wantDone indicates that a RoundTrip should have returned.
// It blocks until the RoundTrip has returned, then fails the test if the
// RoundTrip returned an error or was not served on connection c.
func (rt *transportDialTesterRoundTrip) wantDone(c *transportDialTesterConn) {
	rt.t.Helper()
	<-rt.done
	if rt.err != nil {
		rt.t.Fatalf("RoundTrip %v: want success, got err %v", rt.roundTripID, rt.err)
	}
	if rt.conn != c {
		// Either side may be nil (for example when no GotConn event was
		// recorded), so format the IDs without dereferencing a nil conn.
		connName := func(tc *transportDialTesterConn) any {
			if tc == nil {
				return "<none>"
			}
			return tc.connID
		}
		rt.t.Fatalf("RoundTrip %v: want on conn %v, got conn %v", rt.roundTripID, connName(c), connName(rt.conn))
	}
}
// finish completes a RoundTrip: it waits for RoundTrip to return and, if it
// succeeded, closes the request body, drains the response body, and closes
// the response body. Only the first call does any work.
func (rt *transportDialTesterRoundTrip) finish() {
	rt.t.Helper()
	if !rt.finished {
		rt.finished = true
		<-rt.done
		if rt.err == nil {
			// Closing the write half of the pipe ends the request body.
			rt.reqBody.Close()
			body := rt.res.Body
			io.ReadAll(body)
			body.Close()
			rt.t.Logf("RoundTrip %v: closed request body", rt.roundTripID)
		}
	}
}
// wantDial waits for the Transport to start a Dial.
// It blocks until the dial hook announces a new conn, assigns the conn the
// next connID, and returns it. The Dial itself stays blocked until finish is
// called on the returned conn.
func (dt *transportDialTester) wantDial() *transportDialTesterConn {
	// Mark as a helper, like the tester's other methods, so the log line
	// below is attributed to the calling test line.
	dt.t.Helper()
	c := <-dt.dials
	c.connID = dt.dialCount
	dt.dialCount++
	dt.t.Logf("Dial %v: started", c.connID)
	return c
}
// finish completes a Dial.
// err is delivered to the blocked dial hook: a non-nil err makes the Dial
// fail with that error, nil lets it go on to open the real connection.
// finish must be called at most once per conn, since it closes the channel.
func (c *transportDialTesterConn) finish(err error) {
	// Mark as a helper, like the tester's other methods, so the log line
	// below is attributed to the calling test line.
	c.t.Helper()
	c.t.Logf("Dial %v: finished (err:%v)", c.connID, err)
	c.ready <- err
	close(c.ready)
}

View File

@ -1626,11 +1626,20 @@ func TestOnProxyConnectResponse(t *testing.T) {
// Issue 28012: verify that the Transport closes its TCP connection to http proxies
// when they're slow to reply to HTTPS CONNECT responses.
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
setParallel(t)
defer afterTest(t)
cancelc := make(chan struct{})
SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(ctx)
go func() {
select {
case <-cancelc:
case <-ctx.Done():
}
cancel()
}()
return ctx, cancel
})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
defer afterTest(t)
ln := newLocalListener(t)
defer ln.Close()
@ -1658,7 +1667,7 @@ func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
// Now hang and never write a response; instead, cancel the request and wait
// for the client to close.
// (Prior to Issue 28012 being fixed, we never closed.)
cancel()
close(cancelc)
var buf [1]byte
_, err = br.Read(buf[:])
if err != io.EOF {
@ -1674,7 +1683,7 @@ func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
},
},
}
req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil)
req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
if err != nil {
t.Fatal(err)
}
@ -3927,9 +3936,13 @@ func testTransportDialTLS(t *testing.T, mode testMode) {
func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
func testTransportDialContext(t *testing.T, mode testMode) {
var mu sync.Mutex // guards following
var gotReq bool
var receivedContext context.Context
ctxKey := "some-key"
ctxValue := "some-value"
var (
mu sync.Mutex // guards following
gotReq bool
gotCtxValue any
)
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
mu.Lock()
@ -3939,7 +3952,7 @@ func testTransportDialContext(t *testing.T, mode testMode) {
c := ts.Client()
c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
mu.Lock()
receivedContext = ctx
gotCtxValue = ctx.Value(ctxKey)
mu.Unlock()
return net.Dial(netw, addr)
}
@ -3948,7 +3961,7 @@ func testTransportDialContext(t *testing.T, mode testMode) {
if err != nil {
t.Fatal(err)
}
ctx := context.WithValue(context.Background(), "some-key", "some-value")
ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
res, err := c.Do(req.WithContext(ctx))
if err != nil {
t.Fatal(err)
@ -3958,8 +3971,8 @@ func testTransportDialContext(t *testing.T, mode testMode) {
if !gotReq {
t.Error("didn't get request")
}
if receivedContext != ctx {
t.Error("didn't receive correct context")
if got, want := gotCtxValue, ctxValue; got != want {
t.Errorf("got context with value %v, want %v", got, want)
}
}
@ -3967,9 +3980,13 @@ func TestTransportDialTLSContext(t *testing.T) {
run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
}
func testTransportDialTLSContext(t *testing.T, mode testMode) {
var mu sync.Mutex // guards following
var gotReq bool
var receivedContext context.Context
ctxKey := "some-key"
ctxValue := "some-value"
var (
mu sync.Mutex // guards following
gotReq bool
gotCtxValue any
)
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
mu.Lock()
@ -3979,7 +3996,7 @@ func testTransportDialTLSContext(t *testing.T, mode testMode) {
c := ts.Client()
c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
mu.Lock()
receivedContext = ctx
gotCtxValue = ctx.Value(ctxKey)
mu.Unlock()
c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
if err != nil {
@ -3992,7 +4009,7 @@ func testTransportDialTLSContext(t *testing.T, mode testMode) {
if err != nil {
t.Fatal(err)
}
ctx := context.WithValue(context.Background(), "some-key", "some-value")
ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
res, err := c.Do(req.WithContext(ctx))
if err != nil {
t.Fatal(err)
@ -4002,8 +4019,8 @@ func testTransportDialTLSContext(t *testing.T, mode testMode) {
if !gotReq {
t.Error("didn't get request")
}
if receivedContext != ctx {
t.Error("didn't receive correct context")
if got, want := gotCtxValue, ctxValue; got != want {
t.Errorf("got context with value %v, want %v", got, want)
}
}