diff --git a/src/net/http/request.go b/src/net/http/request.go index aca55b1ca7..ff21f19942 100644 --- a/src/net/http/request.go +++ b/src/net/http/request.go @@ -1124,6 +1124,9 @@ func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err erro // MaxBytesReader prevents clients from accidentally or maliciously // sending a large request and wasting server resources. func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser { + if n < 0 { // Treat negative limits as equivalent to 0. + n = 0 + } return &maxBytesReader{w: w, r: r, n: n} } diff --git a/src/net/http/request_test.go b/src/net/http/request_test.go index 29297b0e7b..f09c63ed7e 100644 --- a/src/net/http/request_test.go +++ b/src/net/http/request_test.go @@ -850,6 +850,92 @@ func TestMaxBytesReaderStickyError(t *testing.T) { } } +// Issue 45101: maxBytesReader's Read panicked when n < -1. This test +// also ensures that Read treats negative limits as equivalent to 0. +func TestMaxBytesReaderDifferentLimits(t *testing.T) { + const testStr = "1234" + tests := [...]struct { + limit int64 + lenP int + wantN int + wantErr bool + }{ + 0: { + limit: -123, + lenP: 0, + wantN: 0, + wantErr: false, // Ensure we won't return an error when the limit is negative, but we don't need to read. + }, + 1: { + limit: -100, + lenP: 32 * 1024, + wantN: 0, + wantErr: true, + }, + 2: { + limit: -2, + lenP: 1, + wantN: 0, + wantErr: true, + }, + 3: { + limit: -1, + lenP: 2, + wantN: 0, + wantErr: true, + }, + 4: { + limit: 0, + lenP: 3, + wantN: 0, + wantErr: true, + }, + 5: { + limit: 1, + lenP: 4, + wantN: 1, + wantErr: true, + }, + 6: { + limit: 2, + lenP: 5, + wantN: 2, + wantErr: true, + }, + 7: { + limit: 3, + lenP: 2, + wantN: 2, + wantErr: false, + }, + 8: { + limit: int64(len(testStr)), + lenP: len(testStr), + wantN: len(testStr), + wantErr: false, + }, + 9: { + limit: 100, + lenP: 6, + wantN: len(testStr), + wantErr: false, + }, + } + for i, tt := range tests { + rc := MaxBytesReader(nil, io.NopCloser(strings.NewReader(testStr)), tt.limit) + + n, err := rc.Read(make([]byte, tt.lenP)) + + if n != tt.wantN { + t.Errorf("%d. n: %d, want n: %d", i, n, tt.wantN) + } + + if (err != nil) != tt.wantErr { + t.Errorf("%d. error: %v", i, err) + } + } +} + func TestWithContextDeepCopiesURL(t *testing.T) { req, err := NewRequest("POST", "https://golang.org/", nil) if err != nil {