This commit is contained in:
Eric Gusmão 2025-06-20 15:31:56 -04:00 committed by GitHub
commit 28ed628748
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 56 additions and 2 deletions

View File

@ -9,6 +9,7 @@ package http
import (
"bytes"
"internal/testenv"
"io"
"io/fs"
"net/url"
"os"
@ -189,6 +190,45 @@ func TestNoUnicodeStrings(t *testing.T) {
}
}
type requestTooLargerResponseWriter struct {
called bool
}
func (rw *requestTooLargerResponseWriter) Header() Header { return Header{} }
func (rw *requestTooLargerResponseWriter) Write(b []byte) (int, error) { return len(b), nil }
func (rw *requestTooLargerResponseWriter) WriteHeader(statusCode int) {}
func (rw *requestTooLargerResponseWriter) requestTooLarge() {
rw.called = true
}
type wrapper struct {
ResponseWriter
}
func (w *wrapper) Unwrap() ResponseWriter {
return w.ResponseWriter
}
func TestMaxBytesReaderUnwrapTriggersRequestTooLarge(t *testing.T) {
body := strings.NewReader("123456")
limit := int64(5)
innerRw := &requestTooLargerResponseWriter{}
wrappedRw := &wrapper{ResponseWriter: innerRw}
l := MaxBytesReader(wrappedRw, io.NopCloser(body), limit)
buf := make([]byte, 10)
_, err := l.Read(buf)
if _, ok := err.(*MaxBytesError); !ok {
t.Errorf("expected MaxBytesError, got %T", err)
}
if !innerRw.called {
t.Errorf("expected requestTooLarge to be called, but it wasn't")
}
}
func TestProtocols(t *testing.T) {
var p Protocols
if p.HTTP1() {

View File

@ -1243,9 +1243,23 @@ func (l *maxBytesReader) Read(p []byte) (n int, err error) {
type requestTooLarger interface {
requestTooLarge()
}
if res, ok := l.w.(requestTooLarger); ok {
res.requestTooLarge()
// Unwrap the ResponseWriter wrappers until we find one that implements
// the server-only requestTooLarger interface, then call requestTooLarge().
// This ensures that even if the ResponseWriter is wrapped by a custom implementation,
// the underlying server writer can be notified when the request body is too large.
rw := l.w
for {
if res, ok := rw.(requestTooLarger); ok {
res.requestTooLarge()
break
}
unwrapper, ok := rw.(rwUnwrapper)
if !ok {
break
}
rw = unwrapper.Unwrap()
}
l.err = &MaxBytesError{l.i}
return n, l.err
}