diff --git a/src/pkg/net/http/serve_test.go b/src/pkg/net/http/serve_test.go index a0d9d9e205..c187b1cd07 100644 --- a/src/pkg/net/http/serve_test.go +++ b/src/pkg/net/http/serve_test.go @@ -632,22 +632,20 @@ func Test304Responses(t *testing.T) { } } -// TestHeadResponses verifies that responses to HEAD requests don't -// declare that they're chunking in their response headers, aren't -// allowed to produce output, and don't set a Content-Type since -// the real type of the body data cannot be inferred. +// TestHeadResponses verifies that all MIME type sniffing and Content-Length +// counting of GET requests also happens on HEAD requests. func TestHeadResponses(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - _, err := w.Write([]byte("Ignored body")) - if err != ErrBodyNotAllowed { - t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) + _, err := w.Write([]byte("")) + if err != nil { + t.Errorf("ResponseWriter.Write: %v", err) } // Also exercise the ReaderFrom path - _, err = io.Copy(w, strings.NewReader("Ignored body")) - if err != ErrBodyNotAllowed { - t.Errorf("on Copy, expected ErrBodyNotAllowed, got %v", err) + _, err = io.Copy(w, strings.NewReader("789a")) + if err != nil { + t.Errorf("Copy(ResponseWriter, ...): %v", err) } })) defer ts.Close() @@ -658,9 +656,11 @@ func TestHeadResponses(t *testing.T) { if len(res.TransferEncoding) > 0 { t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) } - ct := res.Header.Get("Content-Type") - if ct != "" { - t.Errorf("expected no Content-Type; got %s", ct) + if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" { + t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct) + } + if v := res.ContentLength; v != 10 { + t.Errorf("Content-Length: %d; want 10", v) } body, err := ioutil.ReadAll(res.Body) if err != nil { diff --git a/src/pkg/net/http/server.go b/src/pkg/net/http/server.go index 4e8f6dce2e..5b93a61125 100644 --- a/src/pkg/net/http/server.go +++ b/src/pkg/net/http/server.go @@ -246,6 +246,10 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) { if !cw.wroteHeader { cw.writeHeader(p) } + if cw.res.req.Method == "HEAD" { + // Eat writes. + return len(p), nil + } if cw.chunking { _, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p)) if err != nil { @@ -704,6 +708,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { cw.wroteHeader = true w := cw.res + isHEAD := w.req.Method == "HEAD" // header is written out to w.conn.buf below. Depending on the // state of the handler, we either own the map or not. If we @@ -735,7 +740,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { // response header and this is our first (and last) write, set // it, even to zero. This helps HTTP/1.0 clients keep their // "keep-alive" connections alive. - if w.handlerDone && header.get("Content-Length") == "" && w.req.Method != "HEAD" { + if w.handlerDone && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) { w.contentLength = int64(len(p)) setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10) } @@ -752,7 +757,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { // Check for a explicit (and valid) Content-Length header. hasCL := w.contentLength != -1 - if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) { + if w.req.wantsHttp10KeepAlive() && (isHEAD || hasCL) { _, connectionHeaderSet := header["Connection"] if !connectionHeaderSet { setHeader.connection = "keep-alive" @@ -793,7 +798,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { } else { // If no content type, apply sniffing algorithm to body. _, haveType := header["Content-Type"] - if !haveType && w.req.Method != "HEAD" { + if !haveType { setHeader.contentType = DetectContentType(p) } } @@ -905,7 +910,7 @@ func (w *response) bodyAllowed() bool { if !w.wroteHeader { panic("") } - return w.status != StatusNotModified && w.req.Method != "HEAD" + return w.status != StatusNotModified } // The Life Of A Write is like this: @@ -983,7 +988,7 @@ func (w *response) finishRequest() { w.req.MultipartForm.RemoveAll() } - if w.contentLength != -1 && w.bodyAllowed() && w.contentLength != w.written { + if w.req.Method != "HEAD" && w.contentLength != -1 && w.bodyAllowed() && w.contentLength != w.written { // Did not write enough. Avoid getting out of sync. w.closeAfterReply = true } diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go index 48a8c441f7..df01a65667 100644 --- a/src/pkg/net/http/transport_test.go +++ b/src/pkg/net/http/transport_test.go @@ -470,6 +470,7 @@ func TestTransportHeadResponses(t *testing.T) { res, err := c.Head(ts.URL) if err != nil { t.Errorf("error on loop %d: %v", i, err) + continue } if e, g := "123", res.Header.Get("Content-Length"); e != g { t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) @@ -477,6 +478,11 @@ func TestTransportHeadResponses(t *testing.T) { if e, g := int64(123), res.ContentLength; e != g { t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) } + if all, err := ioutil.ReadAll(res.Body); err != nil { + t.Errorf("loop %d: Body ReadAll: %v", i, err) + } else if len(all) != 0 { + t.Errorf("Bogus body %q", all) + } } }