mirror of https://github.com/golang/go.git
net/http: fix cancelation of requests with a readTrackingBody wrapper
Use the original *Request in the reqCanceler map, not the transient
wrapper created to handle body rewinding.
Change the key of reqCanceler to a struct{*Request}, to make it more
difficult to accidentally use the wrong request as the key.
Fixes #40453.
Change-Id: I4e61ee9ff2c794fb4c920a3a66c9a0458693d757
Reviewed-on: https://go-review.googlesource.com/c/go/+/245357
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Russ Cox <rsc@golang.org>
This commit is contained in:
parent
f92337422e
commit
f235275097
|
|
@ -100,7 +100,7 @@ type Transport struct {
|
|||
idleLRU connLRU
|
||||
|
||||
reqMu sync.Mutex
|
||||
reqCanceler map[*Request]func(error)
|
||||
reqCanceler map[cancelKey]func(error)
|
||||
|
||||
altMu sync.Mutex // guards changing altProto only
|
||||
altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
|
||||
|
|
@ -273,6 +273,13 @@ type Transport struct {
|
|||
ForceAttemptHTTP2 bool
|
||||
}
|
||||
|
||||
// A cancelKey is the key of the reqCanceler map.
|
||||
// We wrap the *Request in this type since we want to use the original request,
|
||||
// not any transient one created by roundTrip.
|
||||
type cancelKey struct {
|
||||
req *Request
|
||||
}
|
||||
|
||||
func (t *Transport) writeBufferSize() int {
|
||||
if t.WriteBufferSize > 0 {
|
||||
return t.WriteBufferSize
|
||||
|
|
@ -433,9 +440,10 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
|
|||
// optional extra headers to write and stores any error to return
|
||||
// from roundTrip.
|
||||
type transportRequest struct {
|
||||
*Request // original request, not to be mutated
|
||||
extra Header // extra headers to write, or nil
|
||||
trace *httptrace.ClientTrace // optional
|
||||
*Request // original request, not to be mutated
|
||||
extra Header // extra headers to write, or nil
|
||||
trace *httptrace.ClientTrace // optional
|
||||
cancelKey cancelKey
|
||||
|
||||
mu sync.Mutex // guards err
|
||||
err error // first setError value for mapRoundTripError to consider
|
||||
|
|
@ -512,6 +520,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
|
|||
}
|
||||
|
||||
origReq := req
|
||||
cancelKey := cancelKey{origReq}
|
||||
req = setupRewindBody(req)
|
||||
|
||||
if altRT := t.alternateRoundTripper(req); altRT != nil {
|
||||
|
|
@ -546,7 +555,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
|
|||
}
|
||||
|
||||
// treq gets modified by roundTrip, so we need to recreate for each retry.
|
||||
treq := &transportRequest{Request: req, trace: trace}
|
||||
treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey}
|
||||
cm, err := t.connectMethodForRequest(treq)
|
||||
if err != nil {
|
||||
req.closeBody()
|
||||
|
|
@ -559,7 +568,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
|
|||
// to send it requests.
|
||||
pconn, err := t.getConn(treq, cm)
|
||||
if err != nil {
|
||||
t.setReqCanceler(req, nil)
|
||||
t.setReqCanceler(cancelKey, nil)
|
||||
req.closeBody()
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -567,7 +576,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
|
|||
var resp *Response
|
||||
if pconn.alt != nil {
|
||||
// HTTP/2 path.
|
||||
t.setReqCanceler(req, nil) // not cancelable with CancelRequest
|
||||
t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest
|
||||
resp, err = pconn.alt.RoundTrip(req)
|
||||
} else {
|
||||
resp, err = pconn.roundTrip(treq)
|
||||
|
|
@ -753,14 +762,14 @@ func (t *Transport) CloseIdleConnections() {
|
|||
// cancelable context instead. CancelRequest cannot cancel HTTP/2
|
||||
// requests.
|
||||
func (t *Transport) CancelRequest(req *Request) {
|
||||
t.cancelRequest(req, errRequestCanceled)
|
||||
t.cancelRequest(cancelKey{req}, errRequestCanceled)
|
||||
}
|
||||
|
||||
// Cancel an in-flight request, recording the error value.
|
||||
func (t *Transport) cancelRequest(req *Request, err error) {
|
||||
func (t *Transport) cancelRequest(key cancelKey, err error) {
|
||||
t.reqMu.Lock()
|
||||
cancel := t.reqCanceler[req]
|
||||
delete(t.reqCanceler, req)
|
||||
cancel := t.reqCanceler[key]
|
||||
delete(t.reqCanceler, key)
|
||||
t.reqMu.Unlock()
|
||||
if cancel != nil {
|
||||
cancel(err)
|
||||
|
|
@ -1093,16 +1102,16 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool {
|
|||
return removed
|
||||
}
|
||||
|
||||
func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
|
||||
func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) {
|
||||
t.reqMu.Lock()
|
||||
defer t.reqMu.Unlock()
|
||||
if t.reqCanceler == nil {
|
||||
t.reqCanceler = make(map[*Request]func(error))
|
||||
t.reqCanceler = make(map[cancelKey]func(error))
|
||||
}
|
||||
if fn != nil {
|
||||
t.reqCanceler[r] = fn
|
||||
t.reqCanceler[key] = fn
|
||||
} else {
|
||||
delete(t.reqCanceler, r)
|
||||
delete(t.reqCanceler, key)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1110,17 +1119,17 @@ func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
|
|||
// for the request, we don't set the function and return false.
|
||||
// Since CancelRequest will clear the canceler, we can use the return value to detect if
|
||||
// the request was canceled since the last setReqCancel call.
|
||||
func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool {
|
||||
func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool {
|
||||
t.reqMu.Lock()
|
||||
defer t.reqMu.Unlock()
|
||||
_, ok := t.reqCanceler[r]
|
||||
_, ok := t.reqCanceler[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if fn != nil {
|
||||
t.reqCanceler[r] = fn
|
||||
t.reqCanceler[key] = fn
|
||||
} else {
|
||||
delete(t.reqCanceler, r)
|
||||
delete(t.reqCanceler, key)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
@ -1324,12 +1333,12 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi
|
|||
// set request canceler to some non-nil function so we
|
||||
// can detect whether it was cleared between now and when
|
||||
// we enter roundTrip
|
||||
t.setReqCanceler(req, func(error) {})
|
||||
t.setReqCanceler(treq.cancelKey, func(error) {})
|
||||
return pc, nil
|
||||
}
|
||||
|
||||
cancelc := make(chan error, 1)
|
||||
t.setReqCanceler(req, func(err error) { cancelc <- err })
|
||||
t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err })
|
||||
|
||||
// Queue for permission to dial.
|
||||
t.queueForDial(w)
|
||||
|
|
@ -2078,7 +2087,7 @@ func (pc *persistConn) readLoop() {
|
|||
}
|
||||
|
||||
if !hasBody || bodyWritable {
|
||||
pc.t.setReqCanceler(rc.req, nil)
|
||||
pc.t.setReqCanceler(rc.cancelKey, nil)
|
||||
|
||||
// Put the idle conn back into the pool before we send the response
|
||||
// so if they process it quickly and make another request, they'll
|
||||
|
|
@ -2151,7 +2160,7 @@ func (pc *persistConn) readLoop() {
|
|||
// reading the response body. (or for cancellation or death)
|
||||
select {
|
||||
case bodyEOF := <-waitForBodyRead:
|
||||
pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool
|
||||
pc.t.setReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool
|
||||
alive = alive &&
|
||||
bodyEOF &&
|
||||
!pc.sawEOF &&
|
||||
|
|
@ -2165,7 +2174,7 @@ func (pc *persistConn) readLoop() {
|
|||
pc.t.CancelRequest(rc.req)
|
||||
case <-rc.req.Context().Done():
|
||||
alive = false
|
||||
pc.t.cancelRequest(rc.req, rc.req.Context().Err())
|
||||
pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err())
|
||||
case <-pc.closech:
|
||||
alive = false
|
||||
}
|
||||
|
|
@ -2408,9 +2417,10 @@ type responseAndError struct {
|
|||
}
|
||||
|
||||
type requestAndChan struct {
|
||||
_ incomparable
|
||||
req *Request
|
||||
ch chan responseAndError // unbuffered; always send in select on callerGone
|
||||
_ incomparable
|
||||
req *Request
|
||||
cancelKey cancelKey
|
||||
ch chan responseAndError // unbuffered; always send in select on callerGone
|
||||
|
||||
// whether the Transport (as opposed to the user client code)
|
||||
// added the Accept-Encoding gzip header. If the Transport
|
||||
|
|
@ -2472,7 +2482,7 @@ var (
|
|||
|
||||
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
|
||||
testHookEnterRoundTrip()
|
||||
if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) {
|
||||
if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) {
|
||||
pc.t.putOrCloseIdleConn(pc)
|
||||
return nil, errRequestCanceled
|
||||
}
|
||||
|
|
@ -2524,7 +2534,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
|
|||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
pc.t.setReqCanceler(req.Request, nil)
|
||||
pc.t.setReqCanceler(req.cancelKey, nil)
|
||||
}
|
||||
}()
|
||||
|
||||
|
|
@ -2540,6 +2550,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
|
|||
resc := make(chan responseAndError)
|
||||
pc.reqch <- requestAndChan{
|
||||
req: req.Request,
|
||||
cancelKey: req.cancelKey,
|
||||
ch: resc,
|
||||
addedGzip: requestedGzip,
|
||||
continueCh: continueCh,
|
||||
|
|
@ -2591,10 +2602,10 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
|
|||
}
|
||||
return re.res, nil
|
||||
case <-cancelChan:
|
||||
pc.t.CancelRequest(req.Request)
|
||||
pc.t.cancelRequest(req.cancelKey, errRequestCanceled)
|
||||
cancelChan = nil
|
||||
case <-ctxDoneChan:
|
||||
pc.t.cancelRequest(req.Request, req.Context().Err())
|
||||
pc.t.cancelRequest(req.cancelKey, req.Context().Err())
|
||||
cancelChan = nil
|
||||
ctxDoneChan = nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2364,6 +2364,50 @@ func TestTransportCancelRequest(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func testTransportCancelRequestInDo(t *testing.T, body io.Reader) {
|
||||
setParallel(t)
|
||||
defer afterTest(t)
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in -short mode")
|
||||
}
|
||||
unblockc := make(chan bool)
|
||||
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
<-unblockc
|
||||
}))
|
||||
defer ts.Close()
|
||||
defer close(unblockc)
|
||||
|
||||
c := ts.Client()
|
||||
tr := c.Transport.(*Transport)
|
||||
|
||||
donec := make(chan bool)
|
||||
req, _ := NewRequest("GET", ts.URL, body)
|
||||
go func() {
|
||||
defer close(donec)
|
||||
c.Do(req)
|
||||
}()
|
||||
start := time.Now()
|
||||
timeout := 10 * time.Second
|
||||
for time.Since(start) < timeout {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
tr.CancelRequest(req)
|
||||
select {
|
||||
case <-donec:
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
t.Errorf("Do of canceled request has not returned after %v", timeout)
|
||||
}
|
||||
|
||||
func TestTransportCancelRequestInDo(t *testing.T) {
|
||||
testTransportCancelRequestInDo(t, nil)
|
||||
}
|
||||
|
||||
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
|
||||
testTransportCancelRequestInDo(t, bytes.NewBuffer([]byte{0}))
|
||||
}
|
||||
|
||||
func TestTransportCancelRequestInDial(t *testing.T) {
|
||||
defer afterTest(t)
|
||||
if testing.Short() {
|
||||
|
|
|
|||
Loading…
Reference in New Issue