diff --git a/src/net/http/transfer.go b/src/net/http/transfer.go index 6d5ea05c32..9019afb61d 100644 --- a/src/net/http/transfer.go +++ b/src/net/http/transfer.go @@ -335,7 +335,7 @@ func (t *transferWriter) writeBody(w io.Writer) error { var ncopy int64 // Write body. We "unwrap" the body first if it was wrapped in a - // nopCloser. This is to ensure that we can take advantage of + // nopCloser or readTrackingBody. This is to ensure that we can take advantage of // OS-level optimizations in the event that the body is an // *os.File. if t.Body != nil { @@ -413,7 +413,10 @@ func (t *transferWriter) unwrapBody() io.Reader { if reflect.TypeOf(t.Body) == nopCloserType { return reflect.ValueOf(t.Body).Field(0).Interface().(io.Reader) } - + if r, ok := t.Body.(*readTrackingBody); ok { + r.didRead = true + return r.ReadCloser + } return t.Body } @@ -1075,6 +1078,9 @@ func isKnownInMemoryReader(r io.Reader) bool { if reflect.TypeOf(r) == nopCloserType { return isKnownInMemoryReader(reflect.ValueOf(r).Field(0).Interface().(io.Reader)) } + if r, ok := r.(*readTrackingBody); ok { + return isKnownInMemoryReader(r.ReadCloser) + } return false } diff --git a/src/net/http/transport.go b/src/net/http/transport.go index b1705d5439..da86b26106 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -511,10 +511,17 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { } } + req = setupRewindBody(req) + if altRT := t.alternateRoundTripper(req); altRT != nil { if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol { return resp, err } + var err error + req, err = rewindBody(req) + if err != nil { + return nil, err + } } if !isHTTP { req.closeBody() @@ -584,18 +591,59 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { testHookRoundTripRetried() // Rewind the body if we're able to. - if req.GetBody != nil { - newReq := *req - var err error - newReq.Body, err = req.GetBody() - if err != nil { - return nil, err - } - req = &newReq + req, err = rewindBody(req) + if err != nil { + return nil, err } } } +var errCannotRewind = errors.New("net/http: cannot rewind body after connection loss") + +type readTrackingBody struct { + io.ReadCloser + didRead bool +} + +func (r *readTrackingBody) Read(data []byte) (int, error) { + r.didRead = true + return r.ReadCloser.Read(data) +} + +// setupRewindBody returns a new request with a custom body wrapper +// that can report whether the body needs rewinding. +// This lets rewindBody avoid an error result when the request +// does not have GetBody but the body hasn't been read at all yet. +func setupRewindBody(req *Request) *Request { + if req.Body == nil || req.Body == NoBody { + return req + } + newReq := *req + newReq.Body = &readTrackingBody{ReadCloser: req.Body} + return &newReq +} + +// rewindBody returns a new request with the body rewound. +// It returns req unmodified if the body does not need rewinding. +// rewindBody takes care of closing req.Body when appropriate +// (in all cases except when rewindBody returns req unmodified). +func rewindBody(req *Request) (rewound *Request, err error) { + if req.Body == nil || req.Body == NoBody || !req.Body.(*readTrackingBody).didRead { + return req, nil // nothing to rewind + } + req.closeBody() + if req.GetBody == nil { + return nil, errCannotRewind + } + body, err := req.GetBody() + if err != nil { + return nil, err + } + newReq := *req + newReq.Body = &readTrackingBody{ReadCloser: body} + return &newReq, nil +} + // shouldRetryRequest reports whether we should retry sending a failed // HTTP request on a new connection. The non-nil input error is the // error from roundTrip. diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index f4014d95bb..5ccb3d14ab 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -6196,3 +6196,29 @@ func (timeoutProto) RoundTrip(req *Request) (*Response, error) { return nil, errors.New("request was not canceled") } } + +type roundTripFunc func(r *Request) (*Response, error) + +func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) } + +// Issue 32441: body is not reset after ErrSkipAltProtocol +func TestIssue32441(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if n, _ := io.Copy(ioutil.Discard, r.Body); n == 0 { + t.Error("body length is zero") + } + })) + defer ts.Close() + c := ts.Client() + c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) { + // Draining body to trigger failure condition on actual request to server. + if n, _ := io.Copy(ioutil.Discard, r.Body); n == 0 { + t.Error("body length is zero during round trip") + } + return nil, ErrSkipAltProtocol + })) + if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil { + t.Error(err) + } +}