diff --git a/src/net/http/transport.go b/src/net/http/transport.go index ed7c2a52c2..c980e727a6 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -2398,8 +2398,6 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr trace.GotFirstResponseByte() } } - num1xx := 0 // number of informational 1xx headers received - const max1xxResponses = 5 // arbitrary bound on number of informational responses continueCh := rc.continueCh for { @@ -2419,15 +2417,18 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr // treat 101 as a terminal status, see issue 26161 is1xxNonTerminal := is1xx && resCode != StatusSwitchingProtocols if is1xxNonTerminal { - num1xx++ - if num1xx > max1xxResponses { - return nil, errors.New("net/http: too many 1xx informational responses") - } - pc.readLimit = pc.maxHeaderResponseSize() // reset the limit if trace != nil && trace.Got1xxResponse != nil { if err := trace.Got1xxResponse(resCode, textproto.MIMEHeader(resp.Header)); err != nil { return nil, err } + // If the 1xx response was delivered to the user, + // then they're responsible for limiting the number of + // responses. Reset the header limit. + // + // If the user didn't examine the 1xx response, then we + // limit the size of all headers (including both 1xx + // and the final response) to maxHeaderResponseSize. + pc.readLimit = pc.maxHeaderResponseSize() // reset the limit } continue } diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index b76b8dfcff..30a7e5eabd 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -3258,29 +3258,67 @@ func testTransportIgnore1xxResponses(t *testing.T, mode testMode) { } } -func TestTransportLimits1xxResponses(t *testing.T) { - run(t, testTransportLimits1xxResponses, []testMode{http1Mode}) -} +func TestTransportLimits1xxResponses(t *testing.T) { run(t, testTransportLimits1xxResponses) } func testTransportLimits1xxResponses(t *testing.T, mode testMode) { cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { - conn, buf, _ := w.(Hijacker).Hijack() + w.Header().Add("X-Header", strings.Repeat("a", 100)) for i := 0; i < 10; i++ { - buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n")) + w.WriteHeader(123) } - buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) - buf.Flush() - conn.Close() + w.WriteHeader(204) })) cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway + cst.tr.MaxResponseHeaderBytes = 1000 res, err := cst.c.Get(cst.ts.URL) - if res != nil { - defer res.Body.Close() + if err == nil { + res.Body.Close() + t.Fatalf("RoundTrip succeeded; want error") } - got := fmt.Sprint(err) - wantSub := "too many 1xx informational responses" - if !strings.Contains(got, wantSub) { - t.Errorf("Get error = %v; want substring %q", err, wantSub) + for _, want := range []string{ + "response headers exceeded", + "too many 1xx", + } { + if strings.Contains(err.Error(), want) { + return + } + } + t.Errorf(`got error %q; want "response headers exceeded" or "too many 1xx"`, err) +} + +func TestTransportDoesNotLimitDelivered1xxResponses(t *testing.T) { + run(t, testTransportDoesNotLimitDelivered1xxResponses) +} +func testTransportDoesNotLimitDelivered1xxResponses(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("skip until x/net/http2 updated") + } + const num1xx = 10 + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Add("X-Header", strings.Repeat("a", 100)) + for i := 0; i < 10; i++ { + w.WriteHeader(123) + } + w.WriteHeader(204) + })) + cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway + cst.tr.MaxResponseHeaderBytes = 1000 + + got1xx := 0 + ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + got1xx++ + return nil + }, + }) + req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil) + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if got1xx != num1xx { + t.Errorf("Got %v 1xx responses, want %x", got1xx, num1xx) } }