diff --git a/src/net/http/routing_index.go b/src/net/http/routing_index.go new file mode 100644 index 0000000000..9ac42c997d --- /dev/null +++ b/src/net/http/routing_index.go @@ -0,0 +1,124 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import "math" + +// A routingIndex optimizes conflict detection by indexing patterns. +// +// The basic idea is to rule out patterns that cannot conflict with a given +// pattern because they have a different literal in a corresponding segment. +// See the comments in [routingIndex.possiblyConflictingPatterns] for more details. +type routingIndex struct { + // map from a particular segment position and value to all registered patterns + // with that value in that position. + // For example, the key {1, "b"} would hold the patterns "/a/b" and "/a/b/c" + // but not "/a", "b/a", "/a/c" or "/a/{x}". + segments map[routingIndexKey][]*pattern + // All patterns that end in a multi wildcard (including trailing slash). + // We do not try to be clever about indexing multi patterns, because there + // are unlikely to be many of them. + multis []*pattern +} + +type routingIndexKey struct { + pos int // 0-based segment position + s string // literal, or empty for wildcard +} + +func (idx *routingIndex) addPattern(pat *pattern) { + if pat.lastSegment().multi { + idx.multis = append(idx.multis, pat) + } else { + if idx.segments == nil { + idx.segments = map[routingIndexKey][]*pattern{} + } + for pos, seg := range pat.segments { + key := routingIndexKey{pos: pos, s: ""} + if !seg.wild { + key.s = seg.s + } + idx.segments[key] = append(idx.segments[key], pat) + } + } +} + +// possiblyConflictingPatterns calls f on all patterns that might conflict with +// pat. If f returns a non-nil error, possiblyConflictingPatterns returns immediately +// with that error. +// +// To be correct, possiblyConflictingPatterns must include all patterns that +// might conflict. But it may also include patterns that cannot conflict. +// For instance, an implementation that returns all registered patterns is correct. +// We use this fact throughout, simplifying the implementation by returning more +// patterns that we might need to. +func (idx *routingIndex) possiblyConflictingPatterns(pat *pattern, f func(*pattern) error) (err error) { + // Terminology: + // dollar pattern: one ending in "{$}" + // multi pattern: one ending in a trailing slash or "{x...}" wildcard + // ordinary pattern: neither of the above + + // apply f to all the pats, stopping on error. + apply := func(pats []*pattern) error { + if err != nil { + return err + } + for _, p := range pats { + err = f(p) + if err != nil { + return err + } + } + return nil + } + + // Our simple indexing scheme doesn't try to prune multi patterns; assume + // any of them can match the argument. + if err := apply(idx.multis); err != nil { + return err + } + if pat.lastSegment().s == "/" { + // All paths that a dollar pattern matches end in a slash; no paths that + // an ordinary pattern matches do. So only other dollar or multi + // patterns can conflict with a dollar pattern. Furthermore, conflicting + // dollar patterns must have the {$} in the same position. + return apply(idx.segments[routingIndexKey{s: "/", pos: len(pat.segments) - 1}]) + } + // For ordinary and multi patterns, the only conflicts can be with a multi, + // or a pattern that has the same literal or a wildcard at some literal + // position. + // We could intersect all the possible matches at each position, but we + // do something simpler: we find the position with the fewest patterns. + var lmin, wmin []*pattern + min := math.MaxInt + hasLit := false + for i, seg := range pat.segments { + if seg.multi { + break + } + if !seg.wild { + hasLit = true + lpats := idx.segments[routingIndexKey{s: seg.s, pos: i}] + wpats := idx.segments[routingIndexKey{s: "", pos: i}] + if sum := len(lpats) + len(wpats); sum < min { + lmin = lpats + wmin = wpats + min = sum + } + } + } + if hasLit { + apply(lmin) + apply(wmin) + return err + } + + // This pattern is all wildcards. + // Check it against everything. + for _, pats := range idx.segments { + apply(pats) + } + return err +} diff --git a/src/net/http/routing_index_test.go b/src/net/http/routing_index_test.go new file mode 100644 index 0000000000..7030fc8a67 --- /dev/null +++ b/src/net/http/routing_index_test.go @@ -0,0 +1,207 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "fmt" + "slices" + "sort" + "strings" + "testing" +) + +func TestIndex(t *testing.T) { + pats := []string{"HEAD /", "/a"} + + var patterns []*pattern + var idx routingIndex + for _, p := range pats { + pat := mustParsePattern(t, p) + patterns = append(patterns, pat) + idx.addPattern(pat) + } + + compare := func(pat *pattern) { + t.Helper() + got := indexConflicts(pat, &idx) + want := trueConflicts(pat, patterns) + if !slices.Equal(got, want) { + t.Errorf("%q:\ngot %q\nwant %q", pat, got, want) + } + } + + compare(mustParsePattern(t, "GET /foo")) + compare(mustParsePattern(t, "GET /{x}")) +} + +// This test works by comparing possiblyConflictingPatterns with +// an exhaustive loop through all patterns. +func FuzzIndex(f *testing.F) { + inits := []string{"/a", "/a/b", "/{x0}", "/{x0}/b", "/a/{x0}", "/a/{$}", "/a/b/{$}", + "/a/", "/a/b/", "/{x}/b/c/{$}", "GET /{x0}/", "HEAD /a"} + + var patterns []*pattern + var idx routingIndex + + // compare takes a fatalf function because fuzzing doesn't like + // it when the fuzz function calls f.Fatalf. + compare := func(pat *pattern, fatalf func(string, ...any)) { + got := indexConflicts(pat, &idx) + want := trueConflicts(pat, patterns) + if !slices.Equal(got, want) { + fatalf("%q:\ngot %q\nwant %q", pat, got, want) + } + } + + for _, p := range inits { + pat, err := parsePattern(p) + if err != nil { + f.Fatal(err) + } + compare(pat, f.Fatalf) + patterns = append(patterns, pat) + idx.addPattern(pat) + f.Add(bytesFromPattern(pat)) + } + + f.Fuzz(func(t *testing.T, pb []byte) { + pat := bytesToPattern(pb) + if pat == nil { + return + } + compare(pat, t.Fatalf) + }) +} + +func trueConflicts(pat *pattern, pats []*pattern) []string { + var s []string + for _, p := range pats { + if pat.conflictsWith(p) { + s = append(s, p.String()) + } + } + sort.Strings(s) + return s +} + +func indexConflicts(pat *pattern, idx *routingIndex) []string { + var s []string + idx.possiblyConflictingPatterns(pat, func(p *pattern) error { + if pat.conflictsWith(p) { + s = append(s, p.String()) + } + return nil + }) + sort.Strings(s) + return slices.Compact(s) +} + +// TODO: incorporate host and method; make encoding denser. +func bytesToPattern(bs []byte) *pattern { + if len(bs) == 0 { + return nil + } + var sb strings.Builder + wc := 0 + for _, b := range bs[:len(bs)-1] { + sb.WriteByte('/') + switch b & 0x3 { + case 0: + fmt.Fprintf(&sb, "{x%d}", wc) + wc++ + case 1: + sb.WriteString("a") + case 2: + sb.WriteString("b") + case 3: + sb.WriteString("c") + } + } + sb.WriteByte('/') + switch bs[len(bs)-1] & 0x7 { + case 0: + fmt.Fprintf(&sb, "{x%d}", wc) + case 1: + sb.WriteString("a") + case 2: + sb.WriteString("b") + case 3: + sb.WriteString("c") + case 4, 5: + fmt.Fprintf(&sb, "{x%d...}", wc) + default: + sb.WriteString("{$}") + } + pat, err := parsePattern(sb.String()) + if err != nil { + panic(err) + } + return pat +} + +func bytesFromPattern(p *pattern) []byte { + var bs []byte + for _, s := range p.segments { + var b byte + switch { + case s.multi: + b = 4 + case s.wild: + b = 0 + case s.s == "/": + b = 7 + case s.s == "a": + b = 1 + case s.s == "b": + b = 2 + case s.s == "c": + b = 3 + default: + panic("bad pattern") + } + bs = append(bs, b) + } + return bs +} + +func TestBytesPattern(t *testing.T) { + tests := []struct { + bs []byte + pat string + }{ + {[]byte{0, 1, 2, 3}, "/{x0}/a/b/c"}, + {[]byte{16, 17, 18, 19}, "/{x0}/a/b/c"}, + {[]byte{4, 4}, "/{x0}/{x1...}"}, + {[]byte{6, 7}, "/b/{$}"}, + } + t.Run("To", func(t *testing.T) { + for _, test := range tests { + p := bytesToPattern(test.bs) + got := p.String() + if got != test.pat { + t.Errorf("%v: got %q, want %q", test.bs, got, test.pat) + } + } + }) + t.Run("From", func(t *testing.T) { + for _, test := range tests { + p, err := parsePattern(test.pat) + if err != nil { + t.Fatal(err) + } + got := bytesFromPattern(p) + var want []byte + for _, b := range test.bs[:len(test.bs)-1] { + want = append(want, b%4) + + } + want = append(want, test.bs[len(test.bs)-1]%8) + if !bytes.Equal(got, want) { + t.Errorf("%s: got %v, want %v", test.pat, got, want) + } + } + }) +} diff --git a/src/net/http/server.go b/src/net/http/server.go index bc5bcb9a71..629d8d3c62 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -2347,7 +2347,8 @@ func RedirectHandler(url string, code int) Handler { type ServeMux struct { mu sync.RWMutex tree routingNode - patterns []*pattern + index routingIndex + patterns []*pattern // TODO(jba): remove if possible } // NewServeMux allocates and returns a new ServeMux. @@ -2624,8 +2625,8 @@ func (mux *ServeMux) register(pattern string, handler Handler) { } } -func (mux *ServeMux) registerErr(pattern string, handler Handler) error { - if pattern == "" { +func (mux *ServeMux) registerErr(patstr string, handler Handler) error { + if patstr == "" { return errors.New("http: invalid pattern") } if handler == nil { @@ -2635,9 +2636,9 @@ func (mux *ServeMux) registerErr(pattern string, handler Handler) error { return errors.New("http: nil handler") } - pat, err := parsePattern(pattern) + pat, err := parsePattern(patstr) if err != nil { - return fmt.Errorf("parsing %q: %w", pattern, err) + return fmt.Errorf("parsing %q: %w", patstr, err) } // Get the caller's location, for better conflict error messages. @@ -2652,16 +2653,17 @@ func (mux *ServeMux) registerErr(pattern string, handler Handler) error { mux.mu.Lock() defer mux.mu.Unlock() // Check for conflict. - // This makes a quadratic number of calls to conflictsWith: we check - // each pattern against every other pattern. - // TODO(jba): add indexing to speed this up. - for _, pat2 := range mux.patterns { + if err := mux.index.possiblyConflictingPatterns(pat, func(pat2 *pattern) error { if pat.conflictsWith(pat2) { return fmt.Errorf("pattern %q (registered at %s) conflicts with pattern %q (registered at %s)", pat, pat.loc, pat2, pat2.loc) } + return nil + }); err != nil { + return err } mux.tree.addPattern(pat, handler) + mux.index.addPattern(pat) mux.patterns = append(mux.patterns, pat) return nil } diff --git a/src/net/http/testdata/fuzz/FuzzIndex/48161038f0c8b2da b/src/net/http/testdata/fuzz/FuzzIndex/48161038f0c8b2da new file mode 100644 index 0000000000..06a7336a8d --- /dev/null +++ b/src/net/http/testdata/fuzz/FuzzIndex/48161038f0c8b2da @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("101$") diff --git a/src/net/http/testdata/fuzz/FuzzIndex/716514f590ce7ab3 b/src/net/http/testdata/fuzz/FuzzIndex/716514f590ce7ab3 new file mode 100644 index 0000000000..520bff177b --- /dev/null +++ b/src/net/http/testdata/fuzz/FuzzIndex/716514f590ce7ab3 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("1010")