diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 44d5515705..bbac2bf448 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -513,6 +513,22 @@ func (t *Transport) alternateRoundTripper(req *Request) RoundTripper { return altProto[req.URL.Scheme] } +func validateHeaders(hdrs Header) string { + for k, vv := range hdrs { + if !httpguts.ValidHeaderFieldName(k) { + return fmt.Sprintf("field name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + // Don't include the value in the error, + // because it may be sensitive. + return fmt.Sprintf("field value for %q", k) + } + } + } + return "" +} + // roundTrip implements a RoundTripper over HTTP. func (t *Transport) roundTrip(req *Request) (*Response, error) { t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) @@ -530,18 +546,16 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { scheme := req.URL.Scheme isHTTP := scheme == "http" || scheme == "https" if isHTTP { - for k, vv := range req.Header { - if !httpguts.ValidHeaderFieldName(k) { - req.closeBody() - return nil, fmt.Errorf("net/http: invalid header field name %q", k) - } - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - req.closeBody() - // Don't include the value in the error, because it may be sensitive. - return nil, fmt.Errorf("net/http: invalid header field value for %q", k) - } - } + // Validate the outgoing headers. + if err := validateHeaders(req.Header); err != "" { + req.closeBody() + return nil, fmt.Errorf("net/http: invalid header %s", err) + } + + // Validate the outgoing trailers too. + if err := validateHeaders(req.Trailer); err != "" { + req.closeBody() + return nil, fmt.Errorf("net/http: invalid trailer %s", err) } } diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index d3f43cfd9a..204133f130 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -7031,3 +7031,42 @@ func TestTransportReqCancelerCleanupOnRequestBodyWriteError(t *testing.T) { return true }) } + +func TestValidateClientRequestTrailers(t *testing.T) { + run(t, testValidateClientRequestTrailers) +} + +func testValidateClientRequestTrailers(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { + rw.Write([]byte("Hello")) + })).ts + + cases := []struct { + trailer Header + wantErr string + }{ + {Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`}, + {Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`}, + } + + for i, tt := range cases { + testName := fmt.Sprintf("%s%d", mode, i) + t.Run(testName, func(t *testing.T) { + req, err := NewRequest("GET", cst.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Trailer = tt.trailer + res, err := cst.Client().Do(req) + if err == nil { + t.Fatal("Expected an error") + } + if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) { + t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w) + } + if res != nil { + t.Fatal("Unexpected non-nil response") + } + }) + } +}