mirror of https://github.com/golang/go.git
net/http: fix cancelation of requests with a readTrackingBody wrapper
Use the original *Request in the reqCanceler map, not the transient
wrapper created to handle body rewinding.
Change the key of reqCanceler to a struct{*Request}, to make it more
difficult to accidentally use the wrong request as the key.
Fixes #40453.
Change-Id: I4e61ee9ff2c794fb4c920a3a66c9a0458693d757
Reviewed-on: https://go-review.googlesource.com/c/go/+/245357
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Russ Cox <rsc@golang.org>
This commit is contained in:
parent
f92337422e
commit
f235275097
|
|
@ -100,7 +100,7 @@ type Transport struct {
|
||||||
idleLRU connLRU
|
idleLRU connLRU
|
||||||
|
|
||||||
reqMu sync.Mutex
|
reqMu sync.Mutex
|
||||||
reqCanceler map[*Request]func(error)
|
reqCanceler map[cancelKey]func(error)
|
||||||
|
|
||||||
altMu sync.Mutex // guards changing altProto only
|
altMu sync.Mutex // guards changing altProto only
|
||||||
altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
|
altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
|
||||||
|
|
@ -273,6 +273,13 @@ type Transport struct {
|
||||||
ForceAttemptHTTP2 bool
|
ForceAttemptHTTP2 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A cancelKey is the key of the reqCanceler map.
|
||||||
|
// We wrap the *Request in this type since we want to use the original request,
|
||||||
|
// not any transient one created by roundTrip.
|
||||||
|
type cancelKey struct {
|
||||||
|
req *Request
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Transport) writeBufferSize() int {
|
func (t *Transport) writeBufferSize() int {
|
||||||
if t.WriteBufferSize > 0 {
|
if t.WriteBufferSize > 0 {
|
||||||
return t.WriteBufferSize
|
return t.WriteBufferSize
|
||||||
|
|
@ -436,6 +443,7 @@ type transportRequest struct {
|
||||||
*Request // original request, not to be mutated
|
*Request // original request, not to be mutated
|
||||||
extra Header // extra headers to write, or nil
|
extra Header // extra headers to write, or nil
|
||||||
trace *httptrace.ClientTrace // optional
|
trace *httptrace.ClientTrace // optional
|
||||||
|
cancelKey cancelKey
|
||||||
|
|
||||||
mu sync.Mutex // guards err
|
mu sync.Mutex // guards err
|
||||||
err error // first setError value for mapRoundTripError to consider
|
err error // first setError value for mapRoundTripError to consider
|
||||||
|
|
@ -512,6 +520,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
origReq := req
|
origReq := req
|
||||||
|
cancelKey := cancelKey{origReq}
|
||||||
req = setupRewindBody(req)
|
req = setupRewindBody(req)
|
||||||
|
|
||||||
if altRT := t.alternateRoundTripper(req); altRT != nil {
|
if altRT := t.alternateRoundTripper(req); altRT != nil {
|
||||||
|
|
@ -546,7 +555,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// treq gets modified by roundTrip, so we need to recreate for each retry.
|
// treq gets modified by roundTrip, so we need to recreate for each retry.
|
||||||
treq := &transportRequest{Request: req, trace: trace}
|
treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey}
|
||||||
cm, err := t.connectMethodForRequest(treq)
|
cm, err := t.connectMethodForRequest(treq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
req.closeBody()
|
req.closeBody()
|
||||||
|
|
@ -559,7 +568,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
|
||||||
// to send it requests.
|
// to send it requests.
|
||||||
pconn, err := t.getConn(treq, cm)
|
pconn, err := t.getConn(treq, cm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.setReqCanceler(req, nil)
|
t.setReqCanceler(cancelKey, nil)
|
||||||
req.closeBody()
|
req.closeBody()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -567,7 +576,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
|
||||||
var resp *Response
|
var resp *Response
|
||||||
if pconn.alt != nil {
|
if pconn.alt != nil {
|
||||||
// HTTP/2 path.
|
// HTTP/2 path.
|
||||||
t.setReqCanceler(req, nil) // not cancelable with CancelRequest
|
t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest
|
||||||
resp, err = pconn.alt.RoundTrip(req)
|
resp, err = pconn.alt.RoundTrip(req)
|
||||||
} else {
|
} else {
|
||||||
resp, err = pconn.roundTrip(treq)
|
resp, err = pconn.roundTrip(treq)
|
||||||
|
|
@ -753,14 +762,14 @@ func (t *Transport) CloseIdleConnections() {
|
||||||
// cancelable context instead. CancelRequest cannot cancel HTTP/2
|
// cancelable context instead. CancelRequest cannot cancel HTTP/2
|
||||||
// requests.
|
// requests.
|
||||||
func (t *Transport) CancelRequest(req *Request) {
|
func (t *Transport) CancelRequest(req *Request) {
|
||||||
t.cancelRequest(req, errRequestCanceled)
|
t.cancelRequest(cancelKey{req}, errRequestCanceled)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cancel an in-flight request, recording the error value.
|
// Cancel an in-flight request, recording the error value.
|
||||||
func (t *Transport) cancelRequest(req *Request, err error) {
|
func (t *Transport) cancelRequest(key cancelKey, err error) {
|
||||||
t.reqMu.Lock()
|
t.reqMu.Lock()
|
||||||
cancel := t.reqCanceler[req]
|
cancel := t.reqCanceler[key]
|
||||||
delete(t.reqCanceler, req)
|
delete(t.reqCanceler, key)
|
||||||
t.reqMu.Unlock()
|
t.reqMu.Unlock()
|
||||||
if cancel != nil {
|
if cancel != nil {
|
||||||
cancel(err)
|
cancel(err)
|
||||||
|
|
@ -1093,16 +1102,16 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool {
|
||||||
return removed
|
return removed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
|
func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) {
|
||||||
t.reqMu.Lock()
|
t.reqMu.Lock()
|
||||||
defer t.reqMu.Unlock()
|
defer t.reqMu.Unlock()
|
||||||
if t.reqCanceler == nil {
|
if t.reqCanceler == nil {
|
||||||
t.reqCanceler = make(map[*Request]func(error))
|
t.reqCanceler = make(map[cancelKey]func(error))
|
||||||
}
|
}
|
||||||
if fn != nil {
|
if fn != nil {
|
||||||
t.reqCanceler[r] = fn
|
t.reqCanceler[key] = fn
|
||||||
} else {
|
} else {
|
||||||
delete(t.reqCanceler, r)
|
delete(t.reqCanceler, key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1110,17 +1119,17 @@ func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
|
||||||
// for the request, we don't set the function and return false.
|
// for the request, we don't set the function and return false.
|
||||||
// Since CancelRequest will clear the canceler, we can use the return value to detect if
|
// Since CancelRequest will clear the canceler, we can use the return value to detect if
|
||||||
// the request was canceled since the last setReqCancel call.
|
// the request was canceled since the last setReqCancel call.
|
||||||
func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool {
|
func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool {
|
||||||
t.reqMu.Lock()
|
t.reqMu.Lock()
|
||||||
defer t.reqMu.Unlock()
|
defer t.reqMu.Unlock()
|
||||||
_, ok := t.reqCanceler[r]
|
_, ok := t.reqCanceler[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if fn != nil {
|
if fn != nil {
|
||||||
t.reqCanceler[r] = fn
|
t.reqCanceler[key] = fn
|
||||||
} else {
|
} else {
|
||||||
delete(t.reqCanceler, r)
|
delete(t.reqCanceler, key)
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
@ -1324,12 +1333,12 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi
|
||||||
// set request canceler to some non-nil function so we
|
// set request canceler to some non-nil function so we
|
||||||
// can detect whether it was cleared between now and when
|
// can detect whether it was cleared between now and when
|
||||||
// we enter roundTrip
|
// we enter roundTrip
|
||||||
t.setReqCanceler(req, func(error) {})
|
t.setReqCanceler(treq.cancelKey, func(error) {})
|
||||||
return pc, nil
|
return pc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cancelc := make(chan error, 1)
|
cancelc := make(chan error, 1)
|
||||||
t.setReqCanceler(req, func(err error) { cancelc <- err })
|
t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err })
|
||||||
|
|
||||||
// Queue for permission to dial.
|
// Queue for permission to dial.
|
||||||
t.queueForDial(w)
|
t.queueForDial(w)
|
||||||
|
|
@ -2078,7 +2087,7 @@ func (pc *persistConn) readLoop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hasBody || bodyWritable {
|
if !hasBody || bodyWritable {
|
||||||
pc.t.setReqCanceler(rc.req, nil)
|
pc.t.setReqCanceler(rc.cancelKey, nil)
|
||||||
|
|
||||||
// Put the idle conn back into the pool before we send the response
|
// Put the idle conn back into the pool before we send the response
|
||||||
// so if they process it quickly and make another request, they'll
|
// so if they process it quickly and make another request, they'll
|
||||||
|
|
@ -2151,7 +2160,7 @@ func (pc *persistConn) readLoop() {
|
||||||
// reading the response body. (or for cancellation or death)
|
// reading the response body. (or for cancellation or death)
|
||||||
select {
|
select {
|
||||||
case bodyEOF := <-waitForBodyRead:
|
case bodyEOF := <-waitForBodyRead:
|
||||||
pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool
|
pc.t.setReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool
|
||||||
alive = alive &&
|
alive = alive &&
|
||||||
bodyEOF &&
|
bodyEOF &&
|
||||||
!pc.sawEOF &&
|
!pc.sawEOF &&
|
||||||
|
|
@ -2165,7 +2174,7 @@ func (pc *persistConn) readLoop() {
|
||||||
pc.t.CancelRequest(rc.req)
|
pc.t.CancelRequest(rc.req)
|
||||||
case <-rc.req.Context().Done():
|
case <-rc.req.Context().Done():
|
||||||
alive = false
|
alive = false
|
||||||
pc.t.cancelRequest(rc.req, rc.req.Context().Err())
|
pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err())
|
||||||
case <-pc.closech:
|
case <-pc.closech:
|
||||||
alive = false
|
alive = false
|
||||||
}
|
}
|
||||||
|
|
@ -2410,6 +2419,7 @@ type responseAndError struct {
|
||||||
type requestAndChan struct {
|
type requestAndChan struct {
|
||||||
_ incomparable
|
_ incomparable
|
||||||
req *Request
|
req *Request
|
||||||
|
cancelKey cancelKey
|
||||||
ch chan responseAndError // unbuffered; always send in select on callerGone
|
ch chan responseAndError // unbuffered; always send in select on callerGone
|
||||||
|
|
||||||
// whether the Transport (as opposed to the user client code)
|
// whether the Transport (as opposed to the user client code)
|
||||||
|
|
@ -2472,7 +2482,7 @@ var (
|
||||||
|
|
||||||
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
|
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
|
||||||
testHookEnterRoundTrip()
|
testHookEnterRoundTrip()
|
||||||
if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) {
|
if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) {
|
||||||
pc.t.putOrCloseIdleConn(pc)
|
pc.t.putOrCloseIdleConn(pc)
|
||||||
return nil, errRequestCanceled
|
return nil, errRequestCanceled
|
||||||
}
|
}
|
||||||
|
|
@ -2524,7 +2534,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pc.t.setReqCanceler(req.Request, nil)
|
pc.t.setReqCanceler(req.cancelKey, nil)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
@ -2540,6 +2550,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
|
||||||
resc := make(chan responseAndError)
|
resc := make(chan responseAndError)
|
||||||
pc.reqch <- requestAndChan{
|
pc.reqch <- requestAndChan{
|
||||||
req: req.Request,
|
req: req.Request,
|
||||||
|
cancelKey: req.cancelKey,
|
||||||
ch: resc,
|
ch: resc,
|
||||||
addedGzip: requestedGzip,
|
addedGzip: requestedGzip,
|
||||||
continueCh: continueCh,
|
continueCh: continueCh,
|
||||||
|
|
@ -2591,10 +2602,10 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
|
||||||
}
|
}
|
||||||
return re.res, nil
|
return re.res, nil
|
||||||
case <-cancelChan:
|
case <-cancelChan:
|
||||||
pc.t.CancelRequest(req.Request)
|
pc.t.cancelRequest(req.cancelKey, errRequestCanceled)
|
||||||
cancelChan = nil
|
cancelChan = nil
|
||||||
case <-ctxDoneChan:
|
case <-ctxDoneChan:
|
||||||
pc.t.cancelRequest(req.Request, req.Context().Err())
|
pc.t.cancelRequest(req.cancelKey, req.Context().Err())
|
||||||
cancelChan = nil
|
cancelChan = nil
|
||||||
ctxDoneChan = nil
|
ctxDoneChan = nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2364,6 +2364,50 @@ func TestTransportCancelRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testTransportCancelRequestInDo(t *testing.T, body io.Reader) {
|
||||||
|
setParallel(t)
|
||||||
|
defer afterTest(t)
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping test in -short mode")
|
||||||
|
}
|
||||||
|
unblockc := make(chan bool)
|
||||||
|
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||||
|
<-unblockc
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
defer close(unblockc)
|
||||||
|
|
||||||
|
c := ts.Client()
|
||||||
|
tr := c.Transport.(*Transport)
|
||||||
|
|
||||||
|
donec := make(chan bool)
|
||||||
|
req, _ := NewRequest("GET", ts.URL, body)
|
||||||
|
go func() {
|
||||||
|
defer close(donec)
|
||||||
|
c.Do(req)
|
||||||
|
}()
|
||||||
|
start := time.Now()
|
||||||
|
timeout := 10 * time.Second
|
||||||
|
for time.Since(start) < timeout {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
tr.CancelRequest(req)
|
||||||
|
select {
|
||||||
|
case <-donec:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Errorf("Do of canceled request has not returned after %v", timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransportCancelRequestInDo(t *testing.T) {
|
||||||
|
testTransportCancelRequestInDo(t, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
|
||||||
|
testTransportCancelRequestInDo(t, bytes.NewBuffer([]byte{0}))
|
||||||
|
}
|
||||||
|
|
||||||
func TestTransportCancelRequestInDial(t *testing.T) {
|
func TestTransportCancelRequestInDial(t *testing.T) {
|
||||||
defer afterTest(t)
|
defer afterTest(t)
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue