net/http: fix cancelation of requests with a readTrackingBody wrapper

Use the original *Request in the reqCanceler map, not the transient
wrapper created to handle body rewinding.

Change the key of reqCanceler to a struct{*Request}, to make it more
difficult to accidentally use the wrong request as the key.

Fixes #40453.

Change-Id: I4e61ee9ff2c794fb4c920a3a66c9a0458693d757
Reviewed-on: https://go-review.googlesource.com/c/go/+/245357
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Russ Cox <rsc@golang.org>
This commit is contained in:
Damien Neil 2020-07-28 12:49:52 -07:00
parent f92337422e
commit f235275097
2 changed files with 86 additions and 31 deletions

View File

@ -100,7 +100,7 @@ type Transport struct {
idleLRU connLRU idleLRU connLRU
reqMu sync.Mutex reqMu sync.Mutex
reqCanceler map[*Request]func(error) reqCanceler map[cancelKey]func(error)
altMu sync.Mutex // guards changing altProto only altMu sync.Mutex // guards changing altProto only
altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
@ -273,6 +273,13 @@ type Transport struct {
ForceAttemptHTTP2 bool ForceAttemptHTTP2 bool
} }
// A cancelKey is the key of the reqCanceler map.
// We wrap the *Request in this type since we want to use the original request,
// not any transient one created by roundTrip.
type cancelKey struct {
req *Request
}
func (t *Transport) writeBufferSize() int { func (t *Transport) writeBufferSize() int {
if t.WriteBufferSize > 0 { if t.WriteBufferSize > 0 {
return t.WriteBufferSize return t.WriteBufferSize
@ -433,9 +440,10 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
// optional extra headers to write and stores any error to return // optional extra headers to write and stores any error to return
// from roundTrip. // from roundTrip.
type transportRequest struct { type transportRequest struct {
*Request // original request, not to be mutated *Request // original request, not to be mutated
extra Header // extra headers to write, or nil extra Header // extra headers to write, or nil
trace *httptrace.ClientTrace // optional trace *httptrace.ClientTrace // optional
cancelKey cancelKey
mu sync.Mutex // guards err mu sync.Mutex // guards err
err error // first setError value for mapRoundTripError to consider err error // first setError value for mapRoundTripError to consider
@ -512,6 +520,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
} }
origReq := req origReq := req
cancelKey := cancelKey{origReq}
req = setupRewindBody(req) req = setupRewindBody(req)
if altRT := t.alternateRoundTripper(req); altRT != nil { if altRT := t.alternateRoundTripper(req); altRT != nil {
@ -546,7 +555,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
} }
// treq gets modified by roundTrip, so we need to recreate for each retry. // treq gets modified by roundTrip, so we need to recreate for each retry.
treq := &transportRequest{Request: req, trace: trace} treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey}
cm, err := t.connectMethodForRequest(treq) cm, err := t.connectMethodForRequest(treq)
if err != nil { if err != nil {
req.closeBody() req.closeBody()
@ -559,7 +568,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
// to send it requests. // to send it requests.
pconn, err := t.getConn(treq, cm) pconn, err := t.getConn(treq, cm)
if err != nil { if err != nil {
t.setReqCanceler(req, nil) t.setReqCanceler(cancelKey, nil)
req.closeBody() req.closeBody()
return nil, err return nil, err
} }
@ -567,7 +576,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
var resp *Response var resp *Response
if pconn.alt != nil { if pconn.alt != nil {
// HTTP/2 path. // HTTP/2 path.
t.setReqCanceler(req, nil) // not cancelable with CancelRequest t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest
resp, err = pconn.alt.RoundTrip(req) resp, err = pconn.alt.RoundTrip(req)
} else { } else {
resp, err = pconn.roundTrip(treq) resp, err = pconn.roundTrip(treq)
@ -753,14 +762,14 @@ func (t *Transport) CloseIdleConnections() {
// cancelable context instead. CancelRequest cannot cancel HTTP/2 // cancelable context instead. CancelRequest cannot cancel HTTP/2
// requests. // requests.
func (t *Transport) CancelRequest(req *Request) { func (t *Transport) CancelRequest(req *Request) {
t.cancelRequest(req, errRequestCanceled) t.cancelRequest(cancelKey{req}, errRequestCanceled)
} }
// Cancel an in-flight request, recording the error value. // Cancel an in-flight request, recording the error value.
func (t *Transport) cancelRequest(req *Request, err error) { func (t *Transport) cancelRequest(key cancelKey, err error) {
t.reqMu.Lock() t.reqMu.Lock()
cancel := t.reqCanceler[req] cancel := t.reqCanceler[key]
delete(t.reqCanceler, req) delete(t.reqCanceler, key)
t.reqMu.Unlock() t.reqMu.Unlock()
if cancel != nil { if cancel != nil {
cancel(err) cancel(err)
@ -1093,16 +1102,16 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool {
return removed return removed
} }
func (t *Transport) setReqCanceler(r *Request, fn func(error)) { func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) {
t.reqMu.Lock() t.reqMu.Lock()
defer t.reqMu.Unlock() defer t.reqMu.Unlock()
if t.reqCanceler == nil { if t.reqCanceler == nil {
t.reqCanceler = make(map[*Request]func(error)) t.reqCanceler = make(map[cancelKey]func(error))
} }
if fn != nil { if fn != nil {
t.reqCanceler[r] = fn t.reqCanceler[key] = fn
} else { } else {
delete(t.reqCanceler, r) delete(t.reqCanceler, key)
} }
} }
@ -1110,17 +1119,17 @@ func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
// for the request, we don't set the function and return false. // for the request, we don't set the function and return false.
// Since CancelRequest will clear the canceler, we can use the return value to detect if // Since CancelRequest will clear the canceler, we can use the return value to detect if
// the request was canceled since the last setReqCancel call. // the request was canceled since the last setReqCancel call.
func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool { func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool {
t.reqMu.Lock() t.reqMu.Lock()
defer t.reqMu.Unlock() defer t.reqMu.Unlock()
_, ok := t.reqCanceler[r] _, ok := t.reqCanceler[key]
if !ok { if !ok {
return false return false
} }
if fn != nil { if fn != nil {
t.reqCanceler[r] = fn t.reqCanceler[key] = fn
} else { } else {
delete(t.reqCanceler, r) delete(t.reqCanceler, key)
} }
return true return true
} }
@ -1324,12 +1333,12 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi
// set request canceler to some non-nil function so we // set request canceler to some non-nil function so we
// can detect whether it was cleared between now and when // can detect whether it was cleared between now and when
// we enter roundTrip // we enter roundTrip
t.setReqCanceler(req, func(error) {}) t.setReqCanceler(treq.cancelKey, func(error) {})
return pc, nil return pc, nil
} }
cancelc := make(chan error, 1) cancelc := make(chan error, 1)
t.setReqCanceler(req, func(err error) { cancelc <- err }) t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err })
// Queue for permission to dial. // Queue for permission to dial.
t.queueForDial(w) t.queueForDial(w)
@ -2078,7 +2087,7 @@ func (pc *persistConn) readLoop() {
} }
if !hasBody || bodyWritable { if !hasBody || bodyWritable {
pc.t.setReqCanceler(rc.req, nil) pc.t.setReqCanceler(rc.cancelKey, nil)
// Put the idle conn back into the pool before we send the response // Put the idle conn back into the pool before we send the response
// so if they process it quickly and make another request, they'll // so if they process it quickly and make another request, they'll
@ -2151,7 +2160,7 @@ func (pc *persistConn) readLoop() {
// reading the response body. (or for cancellation or death) // reading the response body. (or for cancellation or death)
select { select {
case bodyEOF := <-waitForBodyRead: case bodyEOF := <-waitForBodyRead:
pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool pc.t.setReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool
alive = alive && alive = alive &&
bodyEOF && bodyEOF &&
!pc.sawEOF && !pc.sawEOF &&
@ -2165,7 +2174,7 @@ func (pc *persistConn) readLoop() {
pc.t.CancelRequest(rc.req) pc.t.CancelRequest(rc.req)
case <-rc.req.Context().Done(): case <-rc.req.Context().Done():
alive = false alive = false
pc.t.cancelRequest(rc.req, rc.req.Context().Err()) pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err())
case <-pc.closech: case <-pc.closech:
alive = false alive = false
} }
@ -2408,9 +2417,10 @@ type responseAndError struct {
} }
type requestAndChan struct { type requestAndChan struct {
_ incomparable _ incomparable
req *Request req *Request
ch chan responseAndError // unbuffered; always send in select on callerGone cancelKey cancelKey
ch chan responseAndError // unbuffered; always send in select on callerGone
// whether the Transport (as opposed to the user client code) // whether the Transport (as opposed to the user client code)
// added the Accept-Encoding gzip header. If the Transport // added the Accept-Encoding gzip header. If the Transport
@ -2472,7 +2482,7 @@ var (
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
testHookEnterRoundTrip() testHookEnterRoundTrip()
if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) { if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) {
pc.t.putOrCloseIdleConn(pc) pc.t.putOrCloseIdleConn(pc)
return nil, errRequestCanceled return nil, errRequestCanceled
} }
@ -2524,7 +2534,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
defer func() { defer func() {
if err != nil { if err != nil {
pc.t.setReqCanceler(req.Request, nil) pc.t.setReqCanceler(req.cancelKey, nil)
} }
}() }()
@ -2540,6 +2550,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
resc := make(chan responseAndError) resc := make(chan responseAndError)
pc.reqch <- requestAndChan{ pc.reqch <- requestAndChan{
req: req.Request, req: req.Request,
cancelKey: req.cancelKey,
ch: resc, ch: resc,
addedGzip: requestedGzip, addedGzip: requestedGzip,
continueCh: continueCh, continueCh: continueCh,
@ -2591,10 +2602,10 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
} }
return re.res, nil return re.res, nil
case <-cancelChan: case <-cancelChan:
pc.t.CancelRequest(req.Request) pc.t.cancelRequest(req.cancelKey, errRequestCanceled)
cancelChan = nil cancelChan = nil
case <-ctxDoneChan: case <-ctxDoneChan:
pc.t.cancelRequest(req.Request, req.Context().Err()) pc.t.cancelRequest(req.cancelKey, req.Context().Err())
cancelChan = nil cancelChan = nil
ctxDoneChan = nil ctxDoneChan = nil
} }

View File

@ -2364,6 +2364,50 @@ func TestTransportCancelRequest(t *testing.T) {
} }
} }
func testTransportCancelRequestInDo(t *testing.T, body io.Reader) {
setParallel(t)
defer afterTest(t)
if testing.Short() {
t.Skip("skipping test in -short mode")
}
unblockc := make(chan bool)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
<-unblockc
}))
defer ts.Close()
defer close(unblockc)
c := ts.Client()
tr := c.Transport.(*Transport)
donec := make(chan bool)
req, _ := NewRequest("GET", ts.URL, body)
go func() {
defer close(donec)
c.Do(req)
}()
start := time.Now()
timeout := 10 * time.Second
for time.Since(start) < timeout {
time.Sleep(100 * time.Millisecond)
tr.CancelRequest(req)
select {
case <-donec:
return
default:
}
}
t.Errorf("Do of canceled request has not returned after %v", timeout)
}
func TestTransportCancelRequestInDo(t *testing.T) {
testTransportCancelRequestInDo(t, nil)
}
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
testTransportCancelRequestInDo(t, bytes.NewBuffer([]byte{0}))
}
func TestTransportCancelRequestInDial(t *testing.T) { func TestTransportCancelRequestInDial(t *testing.T) {
defer afterTest(t) defer afterTest(t)
if testing.Short() { if testing.Short() {