diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go index dd514386ac..e0e2856477 100644 --- a/src/pkg/net/http/transport.go +++ b/src/pkg/net/http/transport.go @@ -323,6 +323,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { cacheKey: cm.String(), conn: conn, reqch: make(chan requestAndChan, 50), + writech: make(chan writeRequest, 50), } switch { @@ -380,6 +381,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { pconn.br = bufio.NewReader(pconn.conn) pconn.bw = bufio.NewWriter(pconn.conn) go pconn.readLoop() + go pconn.writeLoop() return pconn, nil } @@ -487,7 +489,8 @@ type persistConn struct { closed bool // whether conn has been closed br *bufio.Reader // from conn bw *bufio.Writer // to conn - reqch chan requestAndChan // written by roundTrip(); read by readLoop() + reqch chan requestAndChan // written by roundTrip; read by readLoop + writech chan writeRequest // written by roundTrip; read by writeLoop isProxy bool // mutateHeaderFunc is an optional func to modify extra @@ -519,6 +522,7 @@ func remoteSideClosed(err error) bool { } func (pc *persistConn) readLoop() { + defer close(pc.writech) alive := true var lastbody io.ReadCloser // last response body, if any, read on this connection @@ -579,7 +583,7 @@ func (pc *persistConn) readLoop() { if alive && !pc.t.putIdleConn(pc) { alive = false } - if !alive { + if !alive || pc.isBroken() { pc.close() } waitForBodyRead <- true @@ -615,6 +619,23 @@ func (pc *persistConn) readLoop() { } } +func (pc *persistConn) writeLoop() { + for wr := range pc.writech { + if pc.isBroken() { + wr.ch <- errors.New("http: can't write HTTP request on broken connection") + continue + } + err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra) + if err == nil { + err = pc.bw.Flush() + } + if err != nil { + pc.markBroken() + } + wr.ch <- err + } +} + type responseAndError struct { res *Response err error @@ -630,6 +651,15 @@ type requestAndChan struct { addedGzip bool } +// A writeRequest is sent by the readLoop's goroutine to the +// writeLoop's goroutine to write a request while the read loop +// concurrently waits on both the write response and the server's +// reply. +type writeRequest struct { + req *transportRequest + ch chan<- error +} + func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { if pc.mutateHeaderFunc != nil { pc.mutateHeaderFunc(req.extraHeaders()) @@ -652,16 +682,29 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err pc.numExpectedResponses++ pc.lk.Unlock() - err = req.Request.write(pc.bw, pc.isProxy, req.extra) - if err != nil { - pc.close() - return - } - pc.bw.Flush() + // Write the request concurrently with waiting for a response, + // in case the server decides to reply before reading our full + // request body. + writeErrCh := make(chan error, 1) + pc.writech <- writeRequest{req, writeErrCh} + + resc := make(chan responseAndError, 1) + pc.reqch <- requestAndChan{req.Request, resc, requestedGzip} + + var re responseAndError +WaitResponse: + for { + select { + case err := <-writeErrCh: + if err != nil { + re = responseAndError{nil, err} + break WaitResponse + } + case re = <-resc: + break WaitResponse + } + } - ch := make(chan responseAndError, 1) - pc.reqch <- requestAndChan{req.Request, ch, requestedGzip} - re := <-ch pc.lk.Lock() pc.numExpectedResponses-- pc.lk.Unlock() @@ -669,6 +712,15 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err return re.res, re.err } +// markBroken marks a connection as broken (so it's not reused). +// It differs from close in that it doesn't close the underlying +// connection for use when it's still being read. +func (pc *persistConn) markBroken() { + pc.lk.Lock() + defer pc.lk.Unlock() + pc.broken = true +} + func (pc *persistConn) close() { pc.lk.Lock() defer pc.lk.Unlock() diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go index 1312eb8988..c377eff5d1 100644 --- a/src/pkg/net/http/transport_test.go +++ b/src/pkg/net/http/transport_test.go @@ -833,6 +833,30 @@ func TestIssue3644(t *testing.T) { } } +// Test that a client receives a server's reply, even if the server doesn't read +// the entire request body. +func TestIssue3595(t *testing.T) { + const deniedMsg = "sorry, denied." + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + Error(w, deniedMsg, StatusUnauthorized) + })) + defer ts.Close() + tr := &Transport{} + c := &Client{Transport: tr} + res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) + if err != nil { + t.Errorf("Post: %v", err) + return + } + got, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("Body ReadAll: %v", err) + } + if !strings.Contains(string(got), deniedMsg) { + t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg) + } +} + type fooProto struct{} func (fooProto) RoundTrip(req *Request) (*Response, error) {