mirror of https://github.com/golang/go.git
net/http: don't cancel Dials when requests are canceled
Currently, when a Transport creates a new connection for a request, it uses the request's Context to make the Dial. If a request times out or is canceled before a Dial completes, the Dial is canceled. Change this so that the lifetime of a Dial call is not bound by the request that originated it. This change avoids a scenario where a Transport can start and then cancel many Dial calls in rapid succession: - Request starts a Dial. - A previous request completes, making its connection available. - The new request uses the now-idle connection, and completes. - The request Context is canceled, and the Dial is aborted. Fixes #59017 Change-Id: I996ffabc56d3b1b43129cbfd9b3e9ea7d53d263c Reviewed-on: https://go-review.googlesource.com/c/go/+/576555 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Cherry Mui <cherryyz@google.com>
This commit is contained in:
parent
4742c52e10
commit
334ce51004
|
|
@ -1938,21 +1938,25 @@ func TestClientCloseIdleConnections(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type testRoundTripper func(*Request) (*Response, error)
|
||||
|
||||
func (t testRoundTripper) RoundTrip(req *Request) (*Response, error) {
|
||||
return t(req)
|
||||
}
|
||||
|
||||
func TestClientPropagatesTimeoutToContext(t *testing.T) {
|
||||
errDial := errors.New("not actually dialing")
|
||||
c := &Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &Transport{
|
||||
DialContext: func(ctx context.Context, netw, addr string) (net.Conn, error) {
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
t.Error("no deadline")
|
||||
} else {
|
||||
t.Logf("deadline in %v", deadline.Sub(time.Now()).Round(time.Second/10))
|
||||
}
|
||||
return nil, errDial
|
||||
},
|
||||
},
|
||||
Transport: testRoundTripper(func(req *Request) (*Response, error) {
|
||||
ctx := req.Context()
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
t.Error("no deadline")
|
||||
} else {
|
||||
t.Logf("deadline in %v", deadline.Sub(time.Now()).Round(time.Second/10))
|
||||
}
|
||||
return nil, errors.New("not actually making a request")
|
||||
}),
|
||||
}
|
||||
c.Get("https://example.tld/")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -86,6 +86,14 @@ func SetPendingDialHooks(before, after func()) {
|
|||
|
||||
func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
|
||||
|
||||
func SetTestHookProxyConnectTimeout(t *testing.T, f func(context.Context, time.Duration) (context.Context, context.CancelFunc)) {
|
||||
orig := testHookProxyConnectTimeout
|
||||
t.Cleanup(func() {
|
||||
testHookProxyConnectTimeout = orig
|
||||
})
|
||||
testHookProxyConnectTimeout = f
|
||||
}
|
||||
|
||||
func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler {
|
||||
return &timeoutHandler{
|
||||
handler: handler,
|
||||
|
|
|
|||
|
|
@ -108,6 +108,7 @@ type Transport struct {
|
|||
connsPerHostMu sync.Mutex
|
||||
connsPerHost map[connectMethodKey]int
|
||||
connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns
|
||||
dialsInProgress wantConnQueue
|
||||
|
||||
// Proxy specifies a function to return a proxy for a given
|
||||
// Request. If the function returns a non-nil error, the
|
||||
|
|
@ -807,6 +808,13 @@ func (t *Transport) CloseIdleConnections() {
|
|||
pconn.close(errCloseIdleConns)
|
||||
}
|
||||
}
|
||||
t.connsPerHostMu.Lock()
|
||||
t.dialsInProgress.all(func(w *wantConn) {
|
||||
if w.cancelCtx != nil && !w.waiting() {
|
||||
w.cancelCtx()
|
||||
}
|
||||
})
|
||||
t.connsPerHostMu.Unlock()
|
||||
if t2 := t.h2transport; t2 != nil {
|
||||
t2.CloseIdleConnections()
|
||||
}
|
||||
|
|
@ -1116,7 +1124,7 @@ func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) {
|
|||
t.idleConnWait = make(map[connectMethodKey]wantConnQueue)
|
||||
}
|
||||
q := t.idleConnWait[w.key]
|
||||
q.cleanFront()
|
||||
q.cleanFrontNotWaiting()
|
||||
q.pushBack(w)
|
||||
t.idleConnWait[w.key] = q
|
||||
return false
|
||||
|
|
@ -1230,10 +1238,11 @@ type wantConn struct {
|
|||
beforeDial func()
|
||||
afterDial func()
|
||||
|
||||
mu sync.Mutex // protects ctx, done and sending of the result
|
||||
ctx context.Context // context for dial, cleared after delivered or canceled
|
||||
done bool // true after delivered or canceled
|
||||
result chan connOrError // channel to deliver connection or error
|
||||
mu sync.Mutex // protects ctx, done and sending of the result
|
||||
ctx context.Context // context for dial, cleared after delivered or canceled
|
||||
cancelCtx context.CancelFunc
|
||||
done bool // true after delivered or canceled
|
||||
result chan connOrError // channel to deliver connection or error
|
||||
}
|
||||
|
||||
type connOrError struct {
|
||||
|
|
@ -1352,9 +1361,9 @@ func (q *wantConnQueue) peekFront() *wantConn {
|
|||
return nil
|
||||
}
|
||||
|
||||
// cleanFront pops any wantConns that are no longer waiting from the head of the
|
||||
// cleanFrontNotWaiting pops any wantConns that are no longer waiting from the head of the
|
||||
// queue, reporting whether any were popped.
|
||||
func (q *wantConnQueue) cleanFront() (cleaned bool) {
|
||||
func (q *wantConnQueue) cleanFrontNotWaiting() (cleaned bool) {
|
||||
for {
|
||||
w := q.peekFront()
|
||||
if w == nil || w.waiting() {
|
||||
|
|
@ -1365,6 +1374,28 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) {
|
|||
}
|
||||
}
|
||||
|
||||
// cleanFrontCanceled pops any wantConns with canceled dials from the head of the queue.
|
||||
func (q *wantConnQueue) cleanFrontCanceled() {
|
||||
for {
|
||||
w := q.peekFront()
|
||||
if w == nil || w.cancelCtx != nil {
|
||||
return
|
||||
}
|
||||
q.popFront()
|
||||
}
|
||||
}
|
||||
|
||||
// all iterates over all wantConns in the queue.
|
||||
// The caller must not modify the queue while iterating.
|
||||
func (q *wantConnQueue) all(f func(*wantConn)) {
|
||||
for _, w := range q.head[q.headPos:] {
|
||||
f(w)
|
||||
}
|
||||
for _, w := range q.tail {
|
||||
f(w)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||
if t.DialTLSContext != nil {
|
||||
conn, err = t.DialTLSContext(ctx, network, addr)
|
||||
|
|
@ -1389,10 +1420,18 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis
|
|||
trace.GetConn(cm.addr())
|
||||
}
|
||||
|
||||
// Detach from the request context's cancellation signal.
|
||||
// The dial should proceed even if the request is canceled,
|
||||
// because a future request may be able to make use of the connection.
|
||||
//
|
||||
// We retain the request context's values.
|
||||
dialCtx, dialCancel := context.WithCancel(context.WithoutCancel(ctx))
|
||||
|
||||
w := &wantConn{
|
||||
cm: cm,
|
||||
key: cm.key(),
|
||||
ctx: ctx,
|
||||
ctx: dialCtx,
|
||||
cancelCtx: dialCancel,
|
||||
result: make(chan connOrError, 1),
|
||||
beforeDial: testHookPrePendingDial,
|
||||
afterDial: testHookPostPendingDial,
|
||||
|
|
@ -1470,20 +1509,21 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis
|
|||
// Once w receives permission to dial, it will do so in a separate goroutine.
|
||||
func (t *Transport) queueForDial(w *wantConn) {
|
||||
w.beforeDial()
|
||||
if t.MaxConnsPerHost <= 0 {
|
||||
go t.dialConnFor(w)
|
||||
return
|
||||
}
|
||||
|
||||
t.connsPerHostMu.Lock()
|
||||
defer t.connsPerHostMu.Unlock()
|
||||
|
||||
if t.MaxConnsPerHost <= 0 {
|
||||
t.startDialConnForLocked(w)
|
||||
return
|
||||
}
|
||||
|
||||
if n := t.connsPerHost[w.key]; n < t.MaxConnsPerHost {
|
||||
if t.connsPerHost == nil {
|
||||
t.connsPerHost = make(map[connectMethodKey]int)
|
||||
}
|
||||
t.connsPerHost[w.key] = n + 1
|
||||
go t.dialConnFor(w)
|
||||
t.startDialConnForLocked(w)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -1491,11 +1531,24 @@ func (t *Transport) queueForDial(w *wantConn) {
|
|||
t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue)
|
||||
}
|
||||
q := t.connsPerHostWait[w.key]
|
||||
q.cleanFront()
|
||||
q.cleanFrontNotWaiting()
|
||||
q.pushBack(w)
|
||||
t.connsPerHostWait[w.key] = q
|
||||
}
|
||||
|
||||
// startDialConnFor calls dialConn in a new goroutine.
|
||||
// t.connsPerHostMu must be held.
|
||||
func (t *Transport) startDialConnForLocked(w *wantConn) {
|
||||
t.dialsInProgress.cleanFrontCanceled()
|
||||
t.dialsInProgress.pushBack(w)
|
||||
go func() {
|
||||
t.dialConnFor(w)
|
||||
t.connsPerHostMu.Lock()
|
||||
defer t.connsPerHostMu.Unlock()
|
||||
w.cancelCtx = nil
|
||||
}()
|
||||
}
|
||||
|
||||
// dialConnFor dials on behalf of w and delivers the result to w.
|
||||
// dialConnFor has received permission to dial w.cm and is counted in t.connCount[w.cm.key()].
|
||||
// If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()].
|
||||
|
|
@ -1545,7 +1598,7 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) {
|
|||
for q.len() > 0 {
|
||||
w := q.popFront()
|
||||
if w.waiting() {
|
||||
go t.dialConnFor(w)
|
||||
t.startDialConnForLocked(w)
|
||||
done = true
|
||||
break
|
||||
}
|
||||
|
|
@ -1626,6 +1679,8 @@ type erringRoundTripper interface {
|
|||
RoundTripErr() error
|
||||
}
|
||||
|
||||
var testHookProxyConnectTimeout = context.WithTimeout
|
||||
|
||||
func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) {
|
||||
pconn = &persistConn{
|
||||
t: t,
|
||||
|
|
@ -1742,17 +1797,11 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
|
|||
Header: hdr,
|
||||
}
|
||||
|
||||
// If there's no done channel (no deadline or cancellation
|
||||
// from the caller possible), at least set some (long)
|
||||
// timeout here. This will make sure we don't block forever
|
||||
// and leak a goroutine if the connection stops replying
|
||||
// after the TCP connect.
|
||||
connectCtx := ctx
|
||||
if ctx.Done() == nil {
|
||||
newCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
|
||||
defer cancel()
|
||||
connectCtx = newCtx
|
||||
}
|
||||
// Set a (long) timeout here to make sure we don't block forever
|
||||
// and leak a goroutine if the connection stops replying after
|
||||
// the TCP connect.
|
||||
connectCtx, cancel := testHookProxyConnectTimeout(ctx, 1*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails
|
||||
var (
|
||||
|
|
|
|||
|
|
@ -0,0 +1,235 @@
|
|||
// Copyright 2024 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTransportPoolConnReusePriorConnection(t *testing.T) {
|
||||
dt := newTransportDialTester(t, http1Mode)
|
||||
|
||||
// First request creates a new connection.
|
||||
rt1 := dt.roundTrip()
|
||||
c1 := dt.wantDial()
|
||||
c1.finish(nil)
|
||||
rt1.wantDone(c1)
|
||||
rt1.finish()
|
||||
|
||||
// Second request reuses the first connection.
|
||||
rt2 := dt.roundTrip()
|
||||
rt2.wantDone(c1)
|
||||
rt2.finish()
|
||||
}
|
||||
|
||||
func TestTransportPoolConnCannotReuseConnectionInUse(t *testing.T) {
|
||||
dt := newTransportDialTester(t, http1Mode)
|
||||
|
||||
// First request creates a new connection.
|
||||
rt1 := dt.roundTrip()
|
||||
c1 := dt.wantDial()
|
||||
c1.finish(nil)
|
||||
rt1.wantDone(c1)
|
||||
|
||||
// Second request is made while the first request is still using its connection,
|
||||
// so it goes on a new connection.
|
||||
rt2 := dt.roundTrip()
|
||||
c2 := dt.wantDial()
|
||||
c2.finish(nil)
|
||||
rt2.wantDone(c2)
|
||||
}
|
||||
|
||||
func TestTransportPoolConnConnectionBecomesAvailableDuringDial(t *testing.T) {
|
||||
dt := newTransportDialTester(t, http1Mode)
|
||||
|
||||
// First request creates a new connection.
|
||||
rt1 := dt.roundTrip()
|
||||
c1 := dt.wantDial()
|
||||
c1.finish(nil)
|
||||
rt1.wantDone(c1)
|
||||
|
||||
// Second request is made while the first request is still using its connection.
|
||||
// The first connection completes while the second Dial is in progress, so the
|
||||
// second request uses the first connection.
|
||||
rt2 := dt.roundTrip()
|
||||
c2 := dt.wantDial()
|
||||
rt1.finish()
|
||||
rt2.wantDone(c1)
|
||||
|
||||
// This section is a bit overfitted to the current Transport implementation:
|
||||
// A third request starts. We have an in-progress dial that was started by rt2,
|
||||
// but this new request (rt3) is going to ignore it and make a dial of its own.
|
||||
// rt3 will use the first of these dials that completes.
|
||||
rt3 := dt.roundTrip()
|
||||
c3 := dt.wantDial()
|
||||
c2.finish(nil)
|
||||
rt3.wantDone(c2)
|
||||
|
||||
c3.finish(nil)
|
||||
}
|
||||
|
||||
// A transportDialTester manages a test of a connection's Dials.
|
||||
type transportDialTester struct {
|
||||
t *testing.T
|
||||
cst *clientServerTest
|
||||
|
||||
dials chan *transportDialTesterConn // each new conn is sent to this channel
|
||||
|
||||
roundTripCount int
|
||||
dialCount int
|
||||
}
|
||||
|
||||
// A transportDialTesterRoundTrip is a RoundTrip made as part of a dial test.
|
||||
type transportDialTesterRoundTrip struct {
|
||||
t *testing.T
|
||||
|
||||
roundTripID int // distinguishes RoundTrips in logs
|
||||
cancel context.CancelFunc // cancels the Request context
|
||||
reqBody io.WriteCloser // write half of the Request.Body
|
||||
finished bool
|
||||
|
||||
done chan struct{} // closed when RoundTrip returns:w
|
||||
res *http.Response
|
||||
err error
|
||||
conn *transportDialTesterConn
|
||||
}
|
||||
|
||||
// A transportDialTesterConn is a client connection created by the Transport as
|
||||
// part of a dial test.
|
||||
type transportDialTesterConn struct {
|
||||
t *testing.T
|
||||
|
||||
connID int // distinguished Dials in logs
|
||||
ready chan error // sent on to complete the Dial
|
||||
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func newTransportDialTester(t *testing.T, mode testMode) *transportDialTester {
|
||||
t.Helper()
|
||||
dt := &transportDialTester{
|
||||
t: t,
|
||||
dials: make(chan *transportDialTesterConn),
|
||||
}
|
||||
dt.cst = newClientServerTest(t, mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Write response headers when we receive a request.
|
||||
http.NewResponseController(w).EnableFullDuplex()
|
||||
w.WriteHeader(200)
|
||||
http.NewResponseController(w).Flush()
|
||||
// Wait for the client to send the request body,
|
||||
// to synchronize with the rest of the test.
|
||||
io.ReadAll(r.Body)
|
||||
}), func(tr *http.Transport) {
|
||||
tr.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
c := &transportDialTesterConn{
|
||||
t: t,
|
||||
ready: make(chan error),
|
||||
}
|
||||
// Notify the test that a Dial has started,
|
||||
// and wait for the test to notify us that it should complete.
|
||||
dt.dials <- c
|
||||
if err := <-c.ready; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nc, err := net.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Use the *transportDialTesterConn as the net.Conn,
|
||||
// to let tests associate requests with connections.
|
||||
c.Conn = nc
|
||||
return c, err
|
||||
}
|
||||
})
|
||||
return dt
|
||||
}
|
||||
|
||||
// roundTrip starts a RoundTrip.
|
||||
// It returns immediately, without waiting for the RoundTrip call to complete.
|
||||
func (dt *transportDialTester) roundTrip() *transportDialTesterRoundTrip {
|
||||
dt.t.Helper()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pr, pw := io.Pipe()
|
||||
rt := &transportDialTesterRoundTrip{
|
||||
t: dt.t,
|
||||
roundTripID: dt.roundTripCount,
|
||||
done: make(chan struct{}),
|
||||
reqBody: pw,
|
||||
cancel: cancel,
|
||||
}
|
||||
dt.roundTripCount++
|
||||
dt.t.Logf("RoundTrip %v: started", rt.roundTripID)
|
||||
dt.t.Cleanup(func() {
|
||||
rt.cancel()
|
||||
rt.finish()
|
||||
})
|
||||
go func() {
|
||||
ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
|
||||
GotConn: func(info httptrace.GotConnInfo) {
|
||||
rt.conn = info.Conn.(*transportDialTesterConn)
|
||||
},
|
||||
})
|
||||
req, _ := http.NewRequestWithContext(ctx, "POST", dt.cst.ts.URL, pr)
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
rt.res, rt.err = dt.cst.tr.RoundTrip(req)
|
||||
dt.t.Logf("RoundTrip %v: done (err:%v)", rt.roundTripID, rt.err)
|
||||
close(rt.done)
|
||||
}()
|
||||
return rt
|
||||
}
|
||||
|
||||
// wantDone indicates that a RoundTrip should have returned.
|
||||
func (rt *transportDialTesterRoundTrip) wantDone(c *transportDialTesterConn) {
|
||||
rt.t.Helper()
|
||||
<-rt.done
|
||||
if rt.err != nil {
|
||||
rt.t.Fatalf("RoundTrip %v: want success, got err %v", rt.roundTripID, rt.err)
|
||||
}
|
||||
if rt.conn != c {
|
||||
rt.t.Fatalf("RoundTrip %v: want on conn %v, got conn %v", rt.roundTripID, c.connID, rt.conn.connID)
|
||||
}
|
||||
}
|
||||
|
||||
// finish completes a RoundTrip by sending the request body, consuming the response body,
|
||||
// and closing the response body.
|
||||
func (rt *transportDialTesterRoundTrip) finish() {
|
||||
rt.t.Helper()
|
||||
|
||||
if rt.finished {
|
||||
return
|
||||
}
|
||||
rt.finished = true
|
||||
|
||||
<-rt.done
|
||||
|
||||
if rt.err != nil {
|
||||
return
|
||||
}
|
||||
rt.reqBody.Close()
|
||||
io.ReadAll(rt.res.Body)
|
||||
rt.res.Body.Close()
|
||||
rt.t.Logf("RoundTrip %v: closed request body", rt.roundTripID)
|
||||
}
|
||||
|
||||
// wantDial waits for the Transport to start a Dial.
|
||||
func (dt *transportDialTester) wantDial() *transportDialTesterConn {
|
||||
c := <-dt.dials
|
||||
c.connID = dt.dialCount
|
||||
dt.dialCount++
|
||||
dt.t.Logf("Dial %v: started", c.connID)
|
||||
return c
|
||||
}
|
||||
|
||||
// finish completes a Dial.
|
||||
func (c *transportDialTesterConn) finish(err error) {
|
||||
c.t.Logf("Dial %v: finished (err:%v)", c.connID, err)
|
||||
c.ready <- err
|
||||
close(c.ready)
|
||||
}
|
||||
|
|
@ -1626,11 +1626,20 @@ func TestOnProxyConnectResponse(t *testing.T) {
|
|||
// Issue 28012: verify that the Transport closes its TCP connection to http proxies
|
||||
// when they're slow to reply to HTTPS CONNECT responses.
|
||||
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
|
||||
setParallel(t)
|
||||
defer afterTest(t)
|
||||
cancelc := make(chan struct{})
|
||||
SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
select {
|
||||
case <-cancelc:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
return ctx, cancel
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
defer afterTest(t)
|
||||
|
||||
ln := newLocalListener(t)
|
||||
defer ln.Close()
|
||||
|
|
@ -1658,7 +1667,7 @@ func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
|
|||
// Now hang and never write a response; instead, cancel the request and wait
|
||||
// for the client to close.
|
||||
// (Prior to Issue 28012 being fixed, we never closed.)
|
||||
cancel()
|
||||
close(cancelc)
|
||||
var buf [1]byte
|
||||
_, err = br.Read(buf[:])
|
||||
if err != io.EOF {
|
||||
|
|
@ -1674,7 +1683,7 @@ func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil)
|
||||
req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -3927,9 +3936,13 @@ func testTransportDialTLS(t *testing.T, mode testMode) {
|
|||
|
||||
func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
|
||||
func testTransportDialContext(t *testing.T, mode testMode) {
|
||||
var mu sync.Mutex // guards following
|
||||
var gotReq bool
|
||||
var receivedContext context.Context
|
||||
ctxKey := "some-key"
|
||||
ctxValue := "some-value"
|
||||
var (
|
||||
mu sync.Mutex // guards following
|
||||
gotReq bool
|
||||
gotCtxValue any
|
||||
)
|
||||
|
||||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
mu.Lock()
|
||||
|
|
@ -3939,7 +3952,7 @@ func testTransportDialContext(t *testing.T, mode testMode) {
|
|||
c := ts.Client()
|
||||
c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
|
||||
mu.Lock()
|
||||
receivedContext = ctx
|
||||
gotCtxValue = ctx.Value(ctxKey)
|
||||
mu.Unlock()
|
||||
return net.Dial(netw, addr)
|
||||
}
|
||||
|
|
@ -3948,7 +3961,7 @@ func testTransportDialContext(t *testing.T, mode testMode) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), "some-key", "some-value")
|
||||
ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
|
||||
res, err := c.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
@ -3958,8 +3971,8 @@ func testTransportDialContext(t *testing.T, mode testMode) {
|
|||
if !gotReq {
|
||||
t.Error("didn't get request")
|
||||
}
|
||||
if receivedContext != ctx {
|
||||
t.Error("didn't receive correct context")
|
||||
if got, want := gotCtxValue, ctxValue; got != want {
|
||||
t.Errorf("got context with value %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3967,9 +3980,13 @@ func TestTransportDialTLSContext(t *testing.T) {
|
|||
run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
|
||||
}
|
||||
func testTransportDialTLSContext(t *testing.T, mode testMode) {
|
||||
var mu sync.Mutex // guards following
|
||||
var gotReq bool
|
||||
var receivedContext context.Context
|
||||
ctxKey := "some-key"
|
||||
ctxValue := "some-value"
|
||||
var (
|
||||
mu sync.Mutex // guards following
|
||||
gotReq bool
|
||||
gotCtxValue any
|
||||
)
|
||||
|
||||
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
mu.Lock()
|
||||
|
|
@ -3979,7 +3996,7 @@ func testTransportDialTLSContext(t *testing.T, mode testMode) {
|
|||
c := ts.Client()
|
||||
c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
|
||||
mu.Lock()
|
||||
receivedContext = ctx
|
||||
gotCtxValue = ctx.Value(ctxKey)
|
||||
mu.Unlock()
|
||||
c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
|
||||
if err != nil {
|
||||
|
|
@ -3992,7 +4009,7 @@ func testTransportDialTLSContext(t *testing.T, mode testMode) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), "some-key", "some-value")
|
||||
ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
|
||||
res, err := c.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
@ -4002,8 +4019,8 @@ func testTransportDialTLSContext(t *testing.T, mode testMode) {
|
|||
if !gotReq {
|
||||
t.Error("didn't get request")
|
||||
}
|
||||
if receivedContext != ctx {
|
||||
t.Error("didn't receive correct context")
|
||||
if got, want := gotCtxValue, ctxValue; got != want {
|
||||
t.Errorf("got context with value %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue