diff --git a/src/pkg/net/http/response_test.go b/src/pkg/net/http/response_test.go index 2f5f77369f..49836ce2cc 100644 --- a/src/pkg/net/http/response_test.go +++ b/src/pkg/net/http/response_test.go @@ -466,7 +466,7 @@ func TestReadResponseCloseInMiddle(t *testing.T) { if test.compressed { gzReader, err := gzip.NewReader(resp.Body) checkErr(err, "gzip.NewReader") - resp.Body = &readFirstCloseBoth{gzReader, resp.Body} + resp.Body = &readerAndCloser{gzReader, resp.Body} } rbuf := make([]byte, 2500) diff --git a/src/pkg/net/http/sniff_test.go b/src/pkg/net/http/sniff_test.go index 8ab72ac23f..09665901dc 100644 --- a/src/pkg/net/http/sniff_test.go +++ b/src/pkg/net/http/sniff_test.go @@ -54,6 +54,7 @@ func TestDetectContentType(t *testing.T) { } func TestServerContentType(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { i, _ := strconv.Atoi(r.FormValue("i")) tt := sniffTests[i] @@ -84,6 +85,8 @@ func TestServerContentType(t *testing.T) { } func TestContentTypeWithCopy(t *testing.T) { + defer checkLeakedTransports(t) + const ( input = "\n\n\t\n" expected = "text/html; charset=utf-8" @@ -116,6 +119,7 @@ func TestContentTypeWithCopy(t *testing.T) { } func TestSniffWriteSize(t *testing.T) { + defer checkLeakedTransports(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { size, _ := strconv.Atoi(r.FormValue("size")) written, err := io.WriteString(w, strings.Repeat("a", size)) @@ -133,6 +137,11 @@ func TestSniffWriteSize(t *testing.T) { if err != nil { t.Fatalf("size %d: %v", size, err) } - res.Body.Close() + if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + t.Fatalf("size %d: io.Copy of body = %v", size, err) + } + if err := res.Body.Close(); err != nil { + t.Fatalf("size %d: body Close = %v", size, err) + } } } diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go index 685d7d56c4..7bf08b8ae4 100644 --- a/src/pkg/net/http/transport.go +++ b/src/pkg/net/http/transport.go @@ -17,7 +17,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "net" "net/url" @@ -592,7 +591,6 @@ func remoteSideClosed(err error) bool { func (pc *persistConn) readLoop() { defer close(pc.closech) alive := true - var lastbody io.ReadCloser // last response body, if any, read on this connection for alive { pb, err := pc.br.Peek(1) @@ -611,13 +609,6 @@ func (pc *persistConn) readLoop() { rc := <-pc.reqch - // Advance past the previous response's body, if the - // caller hasn't done so. - if lastbody != nil { - lastbody.Close() // assumed idempotent - lastbody = nil - } - var resp *Response if err == nil { resp, err = ReadResponse(pc.br, rc.req) @@ -636,7 +627,7 @@ func (pc *persistConn) readLoop() { pc.close() err = zerr } else { - resp.Body = &readFirstCloseBoth{&discardOnCloseReadCloser{gzReader}, resp.Body} + resp.Body = &readerAndCloser{gzReader, resp.Body} } } resp.Body = &bodyEOFSignal{body: resp.Body} @@ -648,8 +639,14 @@ func (pc *persistConn) readLoop() { var waitForBodyRead chan bool if hasBody { - lastbody = resp.Body - waitForBodyRead = make(chan bool, 1) + waitForBodyRead = make(chan bool, 2) + resp.Body.(*bodyEOFSignal).earlyCloseFn = func() error { + // Sending false here sets alive to + // false and closes the connection + // below. + waitForBodyRead <- false + return nil + } resp.Body.(*bodyEOFSignal).fn = func(err error) { alive1 := alive if err != nil { @@ -666,15 +663,6 @@ func (pc *persistConn) readLoop() { } if alive && !hasBody { - // When there's no response body, we immediately - // reuse the TCP connection (putIdleConn), but - // we need to prevent ClientConn.Read from - // closing the Response.Body on the next - // loop, otherwise it might close the body - // before the client code has had a chance to - // read it (even though it'll just be 0, EOF). - lastbody = nil - if !pc.t.putIdleConn(pc) { alive = false } @@ -868,13 +856,16 @@ func canonicalAddr(url *url.URL) string { // bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most // once, right before its final (error-producing) Read or Close call -// returns. +// returns. If earlyCloseFn is non-nil and Close is called before +// io.EOF is seen, earlyCloseFn is called instead of fn, and its +// return value is the return value from Close. type bodyEOFSignal struct { - body io.ReadCloser - mu sync.Mutex // guards closed, rerr and fn - closed bool // whether Close has been called - rerr error // sticky Read error - fn func(error) // error will be nil on Read io.EOF + body io.ReadCloser + mu sync.Mutex // guards following 4 fields + closed bool // whether Close has been called + rerr error // sticky Read error + fn func(error) // error will be nil on Read io.EOF + earlyCloseFn func() error // optional alt Close func used if io.EOF not seen } func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { @@ -907,6 +898,9 @@ func (es *bodyEOFSignal) Close() error { return nil } es.closed = true + if es.earlyCloseFn != nil && es.rerr != io.EOF { + return es.earlyCloseFn() + } err := es.body.Close() es.condfn(err) return err @@ -924,28 +918,7 @@ func (es *bodyEOFSignal) condfn(err error) { es.fn = nil } -type readFirstCloseBoth struct { - io.ReadCloser +type readerAndCloser struct { + io.Reader io.Closer } - -func (r *readFirstCloseBoth) Close() error { - if err := r.ReadCloser.Close(); err != nil { - r.Closer.Close() - return err - } - if err := r.Closer.Close(); err != nil { - return err - } - return nil -} - -// discardOnCloseReadCloser consumes all its input on Close. -type discardOnCloseReadCloser struct { - io.ReadCloser -} - -func (d *discardOnCloseReadCloser) Close() error { - io.Copy(ioutil.Discard, d.ReadCloser) // ignore errors; likely invalid or already closed - return d.ReadCloser.Close() -} diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go index 68010e68b3..feaa53d7a5 100644 --- a/src/pkg/net/http/transport_test.go +++ b/src/pkg/net/http/transport_test.go @@ -1214,6 +1214,70 @@ func TestTransportCancelRequest(t *testing.T) { } } +// golang.org/issue/3672 -- Client can't close HTTP stream +// Calling Close on a Response.Body used to just read until EOF. +// Now it actually closes the TCP connection. +func TestTransportCloseResponseBody(t *testing.T) { + defer checkLeakedTransports(t) + writeErr := make(chan error, 1) + msg := []byte("young\n") + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + for { + _, err := w.Write(msg) + if err != nil { + writeErr <- err + return + } + w.(Flusher).Flush() + } + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + req, _ := NewRequest("GET", ts.URL, nil) + defer tr.CancelRequest(req) + + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + + const repeats = 3 + buf := make([]byte, len(msg)*repeats) + want := bytes.Repeat(msg, repeats) + + _, err = io.ReadFull(res.Body, buf) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf, want) { + t.Errorf("read %q; want %q", buf, want) + } + didClose := make(chan error, 1) + go func() { + didClose <- res.Body.Close() + }() + select { + case err := <-didClose: + if err != nil { + t.Errorf("Close = %v", err) + } + case <-time.After(10 * time.Second): + t.Fatal("too long waiting for close") + } + select { + case err := <-writeErr: + if err == nil { + t.Errorf("expected non-nil write error") + } + case <-time.After(10 * time.Second): + t.Fatal("too long waiting for write error") + } +} + type fooProto struct{} func (fooProto) RoundTrip(req *Request) (*Response, error) {