diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go index 7c51af1867..4e3948dd91 100644 --- a/src/net/http/httptest/recorder.go +++ b/src/net/http/httptest/recorder.go @@ -18,6 +18,9 @@ type ResponseRecorder struct { Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to Flushed bool + stagingMap http.Header // map that handlers manipulate to set headers + trailerMap http.Header // lazily filled when Trailers() is called + wroteHeader bool } @@ -36,10 +39,10 @@ const DefaultRemoteAddr = "1.2.3.4" // Header returns the response headers. func (rw *ResponseRecorder) Header() http.Header { - m := rw.HeaderMap + m := rw.stagingMap if m == nil { m = make(http.Header) - rw.HeaderMap = m + rw.stagingMap = m } return m } @@ -59,16 +62,15 @@ func (rw *ResponseRecorder) writeHeader(b []byte, str string) { str = str[:512] } - _, hasType := rw.HeaderMap["Content-Type"] - hasTE := rw.HeaderMap.Get("Transfer-Encoding") != "" + m := rw.Header() + + _, hasType := m["Content-Type"] + hasTE := m.Get("Transfer-Encoding") != "" if !hasType && !hasTE { if b == nil { b = []byte(str) } - if rw.HeaderMap == nil { - rw.HeaderMap = make(http.Header) - } - rw.HeaderMap.Set("Content-Type", http.DetectContentType(b)) + m.Set("Content-Type", http.DetectContentType(b)) } rw.WriteHeader(200) @@ -92,11 +94,21 @@ func (rw *ResponseRecorder) WriteString(str string) (int, error) { return len(str), nil } -// WriteHeader sets rw.Code. +// WriteHeader sets rw.Code. After it is called, changing rw.Header +// will not affect rw.HeaderMap. func (rw *ResponseRecorder) WriteHeader(code int) { - if !rw.wroteHeader { - rw.Code = code - rw.wroteHeader = true + if rw.wroteHeader { + return + } + rw.Code = code + rw.wroteHeader = true + if rw.HeaderMap == nil { + rw.HeaderMap = make(http.Header) + } + for k, vv := range rw.stagingMap { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + rw.HeaderMap[k] = vv2 } } @@ -107,3 +119,33 @@ func (rw *ResponseRecorder) Flush() { } rw.Flushed = true } + +// Trailers returns any trailers set by the handler. It must be called +// after the handler finished running. +func (rw *ResponseRecorder) Trailers() http.Header { + if rw.trailerMap != nil { + return rw.trailerMap + } + trailers, ok := rw.HeaderMap["Trailer"] + if !ok { + rw.trailerMap = make(http.Header) + return rw.trailerMap + } + rw.trailerMap = make(http.Header, len(trailers)) + for _, k := range trailers { + switch k { + case "Transfer-Encoding", "Content-Length", "Trailer": + // Ignore since forbidden by RFC 2616 14.40. + continue + } + k = http.CanonicalHeaderKey(k) + vv, ok := rw.stagingMap[k] + if !ok { + continue + } + vv2 := make([]string, len(vv)) + copy(vv2, vv) + rw.trailerMap[k] = vv2 + } + return rw.trailerMap +} diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go index c29b6d4cf9..19a37b6c54 100644 --- a/src/net/http/httptest/recorder_test.go +++ b/src/net/http/httptest/recorder_test.go @@ -47,6 +47,37 @@ func TestRecorder(t *testing.T) { return nil } } + hasNotHeaders := func(keys ...string) checkFunc { + return func(rec *ResponseRecorder) error { + for _, k := range keys { + _, ok := rec.HeaderMap[http.CanonicalHeaderKey(k)] + if ok { + return fmt.Errorf("unexpected header %s", k) + } + } + return nil + } + } + hasTrailer := func(key, want string) checkFunc { + return func(rec *ResponseRecorder) error { + if got := rec.Trailers().Get(key); got != want { + return fmt.Errorf("trailer %s = %q; want %q", key, got, want) + } + return nil + } + } + hasNotTrailers := func(keys ...string) checkFunc { + return func(rec *ResponseRecorder) error { + trailers := rec.Trailers() + for _, k := range keys { + _, ok := trailers[http.CanonicalHeaderKey(k)] + if ok { + return fmt.Errorf("unexpected trailer %s", k) + } + } + return nil + } + } tests := []struct { name string @@ -130,6 +161,39 @@ func TestRecorder(t *testing.T) { }, check(hasHeader("Content-Type", "text/html; charset=utf-8")), }, + { + "Header is not changed after write", + func(w http.ResponseWriter, r *http.Request) { + hdr := w.Header() + hdr.Set("Key", "correct") + w.WriteHeader(200) + hdr.Set("Key", "incorrect") + }, + check(hasHeader("Key", "correct")), + }, + { + "Trailer headers are correctly recorded", + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Non-Trailer", "correct") + w.Header().Set("Trailer", "Trailer-A") + w.Header().Add("Trailer", "Trailer-B") + w.Header().Add("Trailer", "Trailer-C") + io.WriteString(w, "") + w.Header().Set("Non-Trailer", "incorrect") + w.Header().Set("Trailer-A", "valuea") + w.Header().Set("Trailer-C", "valuec") + w.Header().Set("Trailer-NotDeclared", "should be omitted") + }, + check( + hasStatus(200), + hasHeader("Content-Type", "text/html; charset=utf-8"), + hasHeader("Non-Trailer", "correct"), + hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"), + hasTrailer("Trailer-A", "valuea"), + hasTrailer("Trailer-C", "valuec"), + hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"), + ), + }, } r, _ := http.NewRequest("GET", "http://foo.com/", nil) for _, tt := range tests {