mirror of https://github.com/golang/go.git
net/http/httptest: mimic the normal HTTP server's ResponseWriter more closely
Also adds tests, which didn't exist before. Fixes #4188 R=golang-dev, rsc CC=golang-dev https://golang.org/cl/6613062
This commit is contained in:
parent
d9953c9dde
commit
13576e3b65
|
|
@ -17,6 +17,8 @@ type ResponseRecorder struct {
|
||||||
HeaderMap http.Header // the HTTP response headers
|
HeaderMap http.Header // the HTTP response headers
|
||||||
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
|
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
|
||||||
Flushed bool
|
Flushed bool
|
||||||
|
|
||||||
|
wroteHeader bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRecorder returns an initialized ResponseRecorder.
|
// NewRecorder returns an initialized ResponseRecorder.
|
||||||
|
|
@ -24,6 +26,7 @@ func NewRecorder() *ResponseRecorder {
|
||||||
return &ResponseRecorder{
|
return &ResponseRecorder{
|
||||||
HeaderMap: make(http.Header),
|
HeaderMap: make(http.Header),
|
||||||
Body: new(bytes.Buffer),
|
Body: new(bytes.Buffer),
|
||||||
|
Code: 200,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -33,26 +36,37 @@ const DefaultRemoteAddr = "1.2.3.4"
|
||||||
|
|
||||||
// Header returns the response headers.
|
// Header returns the response headers.
|
||||||
func (rw *ResponseRecorder) Header() http.Header {
|
func (rw *ResponseRecorder) Header() http.Header {
|
||||||
return rw.HeaderMap
|
m := rw.HeaderMap
|
||||||
|
if m == nil {
|
||||||
|
m = make(http.Header)
|
||||||
|
rw.HeaderMap = m
|
||||||
|
}
|
||||||
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write always succeeds and writes to rw.Body, if not nil.
|
// Write always succeeds and writes to rw.Body, if not nil.
|
||||||
func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
|
func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
|
||||||
|
if !rw.wroteHeader {
|
||||||
|
rw.WriteHeader(200)
|
||||||
|
}
|
||||||
if rw.Body != nil {
|
if rw.Body != nil {
|
||||||
rw.Body.Write(buf)
|
rw.Body.Write(buf)
|
||||||
}
|
}
|
||||||
if rw.Code == 0 {
|
|
||||||
rw.Code = http.StatusOK
|
|
||||||
}
|
|
||||||
return len(buf), nil
|
return len(buf), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteHeader sets rw.Code.
|
// WriteHeader sets rw.Code.
|
||||||
func (rw *ResponseRecorder) WriteHeader(code int) {
|
func (rw *ResponseRecorder) WriteHeader(code int) {
|
||||||
|
if !rw.wroteHeader {
|
||||||
rw.Code = code
|
rw.Code = code
|
||||||
}
|
}
|
||||||
|
rw.wroteHeader = true
|
||||||
|
}
|
||||||
|
|
||||||
// Flush sets rw.Flushed to true.
|
// Flush sets rw.Flushed to true.
|
||||||
func (rw *ResponseRecorder) Flush() {
|
func (rw *ResponseRecorder) Flush() {
|
||||||
|
if !rw.wroteHeader {
|
||||||
|
rw.WriteHeader(200)
|
||||||
|
}
|
||||||
rw.Flushed = true
|
rw.Flushed = true
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,90 @@
|
||||||
|
// Copyright 2012 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package httptest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRecorder(t *testing.T) {
|
||||||
|
type checkFunc func(*ResponseRecorder) error
|
||||||
|
check := func(fns ...checkFunc) []checkFunc { return fns }
|
||||||
|
|
||||||
|
hasStatus := func(wantCode int) checkFunc {
|
||||||
|
return func(rec *ResponseRecorder) error {
|
||||||
|
if rec.Code != wantCode {
|
||||||
|
return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
hasContents := func(want string) checkFunc {
|
||||||
|
return func(rec *ResponseRecorder) error {
|
||||||
|
if rec.Body.String() != want {
|
||||||
|
return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
hasFlush := func(want bool) checkFunc {
|
||||||
|
return func(rec *ResponseRecorder) error {
|
||||||
|
if rec.Flushed != want {
|
||||||
|
return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
h func(w http.ResponseWriter, r *http.Request)
|
||||||
|
checks []checkFunc
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"200 default",
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {},
|
||||||
|
check(hasStatus(200), hasContents("")),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"first code only",
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(201)
|
||||||
|
w.WriteHeader(202)
|
||||||
|
w.Write([]byte("hi"))
|
||||||
|
},
|
||||||
|
check(hasStatus(201), hasContents("hi")),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"write sends 200",
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("hi first"))
|
||||||
|
w.WriteHeader(201)
|
||||||
|
w.WriteHeader(202)
|
||||||
|
},
|
||||||
|
check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"flush",
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.(http.Flusher).Flush() // also sends a 200
|
||||||
|
w.WriteHeader(201)
|
||||||
|
},
|
||||||
|
check(hasStatus(200), hasFlush(true)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
r, _ := http.NewRequest("GET", "http://foo.com/", nil)
|
||||||
|
for _, tt := range tests {
|
||||||
|
h := http.HandlerFunc(tt.h)
|
||||||
|
rec := NewRecorder()
|
||||||
|
h.ServeHTTP(rec, r)
|
||||||
|
for _, check := range tt.checks {
|
||||||
|
if err := check(rec); err != nil {
|
||||||
|
t.Errorf("%s: %v", tt.name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue