diff --git a/src/pkg/net/http/response_test.go b/src/pkg/net/http/response_test.go index 5044306a87..f73172189e 100644 --- a/src/pkg/net/http/response_test.go +++ b/src/pkg/net/http/response_test.go @@ -14,6 +14,7 @@ import ( "io/ioutil" "net/url" "reflect" + "regexp" "strings" "testing" ) @@ -406,8 +407,7 @@ func TestWriteResponse(t *testing.T) { t.Errorf("#%d: %v", i, err) continue } - bout := bytes.NewBuffer(nil) - err = resp.Write(bout) + err = resp.Write(ioutil.Discard) if err != nil { t.Errorf("#%d: %v", i, err) continue @@ -506,6 +506,9 @@ func TestReadResponseCloseInMiddle(t *testing.T) { rest, err := ioutil.ReadAll(bufr) checkErr(err, "ReadAll on remainder") if e, g := "Next Request Here", string(rest); e != g { + g = regexp.MustCompile(`(xx+)`).ReplaceAllStringFunc(g, func(match string) string { + return fmt.Sprintf("x(repeated x%d)", len(match)) + }) fatalf("remainder = %q, expected %q", g, e) } } diff --git a/src/pkg/net/http/serve_test.go b/src/pkg/net/http/serve_test.go index e4d9b340be..8f382fa6ea 100644 --- a/src/pkg/net/http/serve_test.go +++ b/src/pkg/net/http/serve_test.go @@ -2090,6 +2090,64 @@ func TestNoContentTypeOnNotModified(t *testing.T) { } } +// Issue 6995 +// A server Handler can receive a Request, and then turn around and +// give a copy of that Request.Body out to the Transport (e.g. any +// proxy). So then two people own that Request.Body (both the server +// and the http client), and both think they can close it on failure. +// Therefore, all incoming server requests Bodies need to be thread-safe. +func TestTransportAndServerSharedBodyRace(t *testing.T) { + defer afterTest(t) + + const bodySize = 1 << 20 + + unblockBackend := make(chan bool) + backend := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + io.CopyN(rw, req.Body, bodySize/2) + <-unblockBackend + })) + defer backend.Close() + + backendRespc := make(chan *Response, 1) + proxy := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + if req.RequestURI == "/foo" { + rw.Write([]byte("bar")) + return + } + req2, _ := NewRequest("POST", backend.URL, req.Body) + req2.ContentLength = bodySize + + bresp, err := DefaultClient.Do(req2) + if err != nil { + t.Errorf("Proxy outbound request: %v", err) + return + } + _, err = io.CopyN(ioutil.Discard, bresp.Body, bodySize/4) + if err != nil { + t.Errorf("Proxy copy error: %v", err) + return + } + backendRespc <- bresp // to close later + + // Try to cause a race: Both the DefaultTransport and the proxy handler's Server + // will try to read/close req.Body (aka req2.Body) + DefaultTransport.(*Transport).CancelRequest(req2) + rw.Write([]byte("OK")) + })) + defer proxy.Close() + + req, _ := NewRequest("POST", proxy.URL, io.LimitReader(neverEnding('a'), bodySize)) + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatalf("Original request: %v", err) + } + + // Cleanup, so we don't leak goroutines. + res.Body.Close() + close(unblockBackend) + (<-backendRespc).Body.Close() +} + func TestResponseWriterWriteStringAllocs(t *testing.T) { ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { if r.URL.Path == "/s" { diff --git a/src/pkg/net/http/transfer.go b/src/pkg/net/http/transfer.go index bacd83732d..4a2bda19fa 100644 --- a/src/pkg/net/http/transfer.go +++ b/src/pkg/net/http/transfer.go @@ -14,6 +14,7 @@ import ( "net/textproto" "strconv" "strings" + "sync" ) // transferWriter inspects the fields of a user-supplied Request or Response, @@ -331,17 +332,17 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { if noBodyExpected(t.RequestMethod) { t.Body = eofReader } else { - t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} + t.Body = &body{src: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} } case realLength == 0: t.Body = eofReader case realLength > 0: - t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close} + t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} default: // realLength < 0, i.e. "Content-Length" not mentioned in header if t.Close { // Close semantics (i.e. HTTP/1.0) - t.Body = &body{Reader: r, closing: t.Close} + t.Body = &body{src: r, closing: t.Close} } else { // Persistent connection (i.e. HTTP/1.1) t.Body = eofReader @@ -514,11 +515,13 @@ func fixTrailer(header Header, te []string) (Header, error) { // Close ensures that the body has been fully read // and then reads the trailer if necessary. type body struct { - io.Reader + src io.Reader hdr interface{} // non-nil (Response or Request) value means read trailer r *bufio.Reader // underlying wire-format reader for the trailer closing bool // is the connection to be closed after reading body? - closed bool + + mu sync.Mutex // guards closed, and calls to Read and Close + closed bool } // ErrBodyReadAfterClose is returned when reading a Request or Response @@ -528,10 +531,17 @@ type body struct { var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body") func (b *body) Read(p []byte) (n int, err error) { + b.mu.Lock() + defer b.mu.Unlock() if b.closed { return 0, ErrBodyReadAfterClose } - n, err = b.Reader.Read(p) + return b.readLocked(p) +} + +// Must hold b.mu. +func (b *body) readLocked(p []byte) (n int, err error) { + n, err = b.src.Read(p) if err == io.EOF { // Chunked case. Read the trailer. @@ -543,7 +553,7 @@ func (b *body) Read(p []byte) (n int, err error) { } else { // If the server declared the Content-Length, our body is a LimitedReader // and we need to check whether this EOF arrived early. - if lr, ok := b.Reader.(*io.LimitedReader); ok && lr.N > 0 { + if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 { err = io.ErrUnexpectedEOF } } @@ -618,6 +628,8 @@ func (b *body) readTrailer() error { } func (b *body) Close() error { + b.mu.Lock() + defer b.mu.Unlock() if b.closed { return nil } @@ -629,12 +641,25 @@ func (b *body) Close() error { default: // Fully consume the body, which will also lead to us reading // the trailer headers after the body, if present. - _, err = io.Copy(ioutil.Discard, b) + _, err = io.Copy(ioutil.Discard, bodyLocked{b}) } b.closed = true return err } +// bodyLocked is a io.Reader reading from a *body when its mutex is +// already held. +type bodyLocked struct { + b *body +} + +func (bl bodyLocked) Read(p []byte) (n int, err error) { + if bl.b.closed { + return 0, ErrBodyReadAfterClose + } + return bl.b.readLocked(p) +} + // parseContentLength trims whitespace from s and returns -1 if no value // is set, or the value if it's >= 0. func parseContentLength(cl string) (int64, error) { diff --git a/src/pkg/net/http/transfer_test.go b/src/pkg/net/http/transfer_test.go index 8627a374c8..fb5ef37a0f 100644 --- a/src/pkg/net/http/transfer_test.go +++ b/src/pkg/net/http/transfer_test.go @@ -12,9 +12,9 @@ import ( func TestBodyReadBadTrailer(t *testing.T) { b := &body{ - Reader: strings.NewReader("foobar"), - hdr: true, // force reading the trailer - r: bufio.NewReader(strings.NewReader("")), + src: strings.NewReader("foobar"), + hdr: true, // force reading the trailer + r: bufio.NewReader(strings.NewReader("")), } buf := make([]byte, 7) n, err := b.Read(buf[:3])