diff --git a/src/net/http/internal/chunked.go b/src/net/http/internal/chunked.go index 3967ad614f..2e62c00d5d 100644 --- a/src/net/http/internal/chunked.go +++ b/src/net/http/internal/chunked.go @@ -220,8 +220,7 @@ type FlushAfterChunkWriter struct { } func parseHexUint(v []byte) (n uint64, err error) { - for _, b := range v { - n <<= 4 + for i, b := range v { switch { case '0' <= b && b <= '9': b = b - '0' @@ -232,6 +231,10 @@ func parseHexUint(v []byte) (n uint64, err error) { default: return 0, errors.New("invalid byte in chunk length") } + if i == 16 { + return 0, errors.New("http chunk length too large") + } + n <<= 4 n |= uint64(b) } return diff --git a/src/net/http/internal/chunked_test.go b/src/net/http/internal/chunked_test.go index 7c1c91662f..a136dc99a6 100644 --- a/src/net/http/internal/chunked_test.go +++ b/src/net/http/internal/chunked_test.go @@ -139,19 +139,35 @@ func TestChunkReaderAllocs(t *testing.T) { } func TestParseHexUint(t *testing.T) { - for i := uint64(0); i <= 1234; i++ { - line := []byte(fmt.Sprintf("%x", i)) - got, err := parseHexUint(line) - if err != nil { - t.Fatalf("on %d: %v", i, err) - } - if got != i { - t.Errorf("for input %q = %d; want %d", line, got, i) - } + type testCase struct { + in string + want uint64 + wantErr string } - _, err := parseHexUint([]byte("bogus")) - if err == nil { - t.Error("expected error on bogus input") + tests := []testCase{ + {"x", 0, "invalid byte in chunk length"}, + {"0000000000000000", 0, ""}, + {"0000000000000001", 1, ""}, + {"ffffffffffffffff", 1<<64 - 1, ""}, + {"000000000000bogus", 0, "invalid byte in chunk length"}, + {"00000000000000000", 0, "http chunk length too large"}, // could accept if we wanted + {"10000000000000000", 0, "http chunk length too large"}, + {"00000000000000001", 0, "http chunk length too large"}, // could accept if we wanted + } + for i := uint64(0); i <= 1234; i++ { + tests = append(tests, testCase{in: fmt.Sprintf("%x", i), want: i}) + } + for _, tt := range tests { + got, err := parseHexUint([]byte(tt.in)) + if tt.wantErr != "" { + if !strings.Contains(fmt.Sprint(err), tt.wantErr) { + t.Errorf("parseHexUint(%q) = %v, %v; want error %q", tt.in, got, err, tt.wantErr) + } + } else { + if err != nil || got != tt.want { + t.Errorf("parseHexUint(%q) = %v, %v; want %v", tt.in, got, err, tt.want) + } + } } }