net/http: add MaxBytesHandler

Fixes #39567

Change-Id: I226089b678a6a13d7ce69f360a23fc5bd297d550
GitHub-Last-Rev: 6435fd5881
GitHub-Pull-Request: golang/go#48104
Reviewed-on: https://go-review.googlesource.com/c/go/+/346569
Trust: Damien Neil <dneil@google.com>
Trust: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Carl Johnson 2021-08-31 20:35:35 +00:00 committed by Damien Neil
parent 36dbf7f7e6
commit 55e6e825d4
2 changed files with 69 additions and 0 deletions

View File

@ -6682,3 +6682,63 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon
}
}
}
func TestMaxBytesHandler(t *testing.T) {
setParallel(t)
defer afterTest(t)
for _, maxSize := range []int64{100, 1_000, 1_000_000} {
for _, requestSize := range []int64{100, 1_000, 1_000_000} {
t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
func(t *testing.T) {
testMaxBytesHandler(t, maxSize, requestSize)
})
}
}
}
func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) {
var (
handlerN int64
handlerErr error
)
echo := HandlerFunc(func(w ResponseWriter, r *Request) {
var buf bytes.Buffer
handlerN, handlerErr = io.Copy(&buf, r.Body)
io.Copy(w, &buf)
})
ts := httptest.NewServer(MaxBytesHandler(echo, maxSize))
defer ts.Close()
c := ts.Client()
var buf strings.Builder
body := strings.NewReader(strings.Repeat("a", int(requestSize)))
res, err := c.Post(ts.URL, "text/plain", body)
if err != nil {
t.Errorf("unexpected connection error: %v", err)
} else {
_, err = io.Copy(&buf, res.Body)
res.Body.Close()
if err != nil {
t.Errorf("unexpected read error: %v", err)
}
}
if handlerN > maxSize {
t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
}
if requestSize > maxSize && handlerErr == nil {
t.Error("expected error on handler side; got nil")
}
if requestSize <= maxSize {
if handlerErr != nil {
t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
}
if handlerN != requestSize {
t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
}
}
if buf.Len() != int(handlerN) {
t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
}
}

View File

@ -3610,3 +3610,12 @@ func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool {
}
return false
}
// MaxBytesHandler returns a Handler that runs h with its ResponseWriter and Request.Body wrapped by a MaxBytesReader.
func MaxBytesHandler(h Handler, n int64) Handler {
return HandlerFunc(func(w ResponseWriter, r *Request) {
r2 := *r
r2.Body = MaxBytesReader(w, r.Body, n)
h.ServeHTTP(w, &r2)
})
}