diff --git a/src/go/token/position.go b/src/go/token/position.go index f5a43aecef..e9f1f5561b 100644 --- a/src/go/token/position.go +++ b/src/go/token/position.go @@ -429,7 +429,7 @@ func (f *File) Position(p Pos) (pos Position) { type FileSet struct { mutex sync.RWMutex // protects the file set base int // base offset for the next file - files []*File // list of files in the order added to the set + tree tree // tree of files in ascending base order last atomic.Pointer[File] // cache of last file looked up } @@ -487,7 +487,7 @@ func (s *FileSet) AddFile(filename string, base, size int) *File { } // add the file to the file set s.base = base - s.files = append(s.files, f) + s.tree.add(f) s.last.Store(f) return f } @@ -518,40 +518,9 @@ func (s *FileSet) AddExistingFiles(files ...*File) { s.mutex.Lock() defer s.mutex.Unlock() - // Merge and sort. - newFiles := append(s.files, files...) - slices.SortFunc(newFiles, func(x, y *File) int { - return cmp.Compare(x.Base(), y.Base()) - }) - - // Reject overlapping files. - // Discard adjacent identical files. - out := newFiles[:0] - for i, file := range newFiles { - if i > 0 { - prev := newFiles[i-1] - if file == prev { - continue - } - if prev.Base()+prev.Size()+1 > file.Base() { - panic(fmt.Sprintf("file %s (%d-%d) overlaps with file %s (%d-%d)", - prev.Name(), prev.Base(), prev.Base()+prev.Size(), - file.Name(), file.Base(), file.Base()+file.Size())) - } - } - out = append(out, file) - } - newFiles = out - - s.files = newFiles - - // Advance base. - if len(newFiles) > 0 { - last := newFiles[len(newFiles)-1] - newBase := last.Base() + last.Size() + 1 - if s.base < newBase { - s.base = newBase - } + for _, f := range files { + s.tree.add(f) + s.base = max(s.base, f.Base()+f.Size()+1) } } @@ -567,39 +536,26 @@ func (s *FileSet) RemoveFile(file *File) { s.mutex.Lock() defer s.mutex.Unlock() - if i := searchFiles(s.files, file.base); i >= 0 && s.files[i] == file { - last := &s.files[len(s.files)-1] - s.files = slices.Delete(s.files, i, i+1) - *last = nil // don't prolong lifetime when popping last element + pn, _ := s.tree.locate(file.key()) + if *pn != nil && (*pn).file == file { + s.tree.delete(pn) } } -// Iterate calls f for the files in the file set in the order they were added -// until f returns false. -func (s *FileSet) Iterate(f func(*File) bool) { - for i := 0; ; i++ { - var file *File - s.mutex.RLock() - if i < len(s.files) { - file = s.files[i] - } +// Iterate calls yield for the files in the file set in ascending Base +// order until yield returns false. +func (s *FileSet) Iterate(yield func(*File) bool) { + s.mutex.RLock() + defer s.mutex.RUnlock() + + // Unlock around user code. + // The iterator is robust to modification by yield. + // Avoid range here, so we can use defer. + s.tree.all()(func(f *File) bool { s.mutex.RUnlock() - if file == nil || !f(file) { - break - } - } -} - -func searchFiles(a []*File, x int) int { - i, found := slices.BinarySearchFunc(a, x, func(a *File, x int) int { - return cmp.Compare(a.base, x) + defer s.mutex.RLock() + return yield(f) }) - if !found { - // We want the File containing x, but if we didn't - // find x then i is the next one. - i-- - } - return i } func (s *FileSet) file(p Pos) *File { @@ -611,16 +567,12 @@ func (s *FileSet) file(p Pos) *File { s.mutex.RLock() defer s.mutex.RUnlock() - // p is not in last file - search all files - if i := searchFiles(s.files, int(p)); i >= 0 { - f := s.files[i] - // f.base <= int(p) by definition of searchFiles - if int(p) <= f.base+f.size { - // Update cache of last file. A race is ok, - // but an exclusive lock causes heavy contention. - s.last.Store(f) - return f - } + pn, _ := s.tree.locate(key{int(p), int(p)}) + if n := *pn; n != nil { + // Update cache of last file. A race is ok, + // but an exclusive lock causes heavy contention. + s.last.Store(n.file) + return n.file } return nil } diff --git a/src/go/token/position_bench_test.go b/src/go/token/position_bench_test.go index 7bb9de8946..add0783832 100644 --- a/src/go/token/position_bench_test.go +++ b/src/go/token/position_bench_test.go @@ -84,15 +84,15 @@ func BenchmarkFileSet_Position(b *testing.B) { } func BenchmarkFileSet_AddExistingFiles(b *testing.B) { + rng := rand.New(rand.NewPCG(rand.Uint64(), rand.Uint64())) + // Create the "universe" of files. fset := token.NewFileSet() var files []*token.File for range 25000 { files = append(files, fset.AddFile("", -1, 10000)) } - rand.Shuffle(len(files), func(i, j int) { - files[i], files[j] = files[j], files[i] - }) + token.Shuffle(rng, files) // choose returns n random files. choose := func(n int) []*token.File { diff --git a/src/go/token/serialize.go b/src/go/token/serialize.go index 04a48d90f8..2c06f8c25c 100644 --- a/src/go/token/serialize.go +++ b/src/go/token/serialize.go @@ -4,6 +4,8 @@ package token +import "slices" + type serializedFile struct { // fields correspond 1:1 to fields with same (lower-case) name in File Name string @@ -27,18 +29,15 @@ func (s *FileSet) Read(decode func(any) error) error { s.mutex.Lock() s.base = ss.Base - files := make([]*File, len(ss.Files)) - for i := 0; i < len(ss.Files); i++ { - f := &ss.Files[i] - files[i] = &File{ + for _, f := range ss.Files { + s.tree.add(&File{ name: f.Name, base: f.Base, size: f.Size, lines: f.Lines, infos: f.Infos, - } + }) } - s.files = files s.last.Store(nil) s.mutex.Unlock() @@ -51,16 +50,16 @@ func (s *FileSet) Write(encode func(any) error) error { s.mutex.Lock() ss.Base = s.base - files := make([]serializedFile, len(s.files)) - for i, f := range s.files { + var files []serializedFile + for f := range s.tree.all() { f.mutex.Lock() - files[i] = serializedFile{ + files = append(files, serializedFile{ Name: f.name, Base: f.base, Size: f.size, - Lines: append([]int(nil), f.lines...), - Infos: append([]lineInfo(nil), f.infos...), - } + Lines: slices.Clone(f.lines), + Infos: slices.Clone(f.infos), + }) f.mutex.Unlock() } ss.Files = files diff --git a/src/go/token/serialize_test.go b/src/go/token/serialize_test.go index 8d9799547a..5b64c58a82 100644 --- a/src/go/token/serialize_test.go +++ b/src/go/token/serialize_test.go @@ -8,6 +8,7 @@ import ( "bytes" "encoding/gob" "fmt" + "slices" "testing" ) @@ -29,12 +30,14 @@ func equal(p, q *FileSet) error { return fmt.Errorf("different bases: %d != %d", p.base, q.base) } - if len(p.files) != len(q.files) { - return fmt.Errorf("different number of files: %d != %d", len(p.files), len(q.files)) + pfiles := slices.Collect(p.tree.all()) + qfiles := slices.Collect(q.tree.all()) + if len(pfiles) != len(qfiles) { + return fmt.Errorf("different number of files: %d != %d", len(pfiles), len(qfiles)) } - for i, f := range p.files { - g := q.files[i] + for i, f := range pfiles { + g := qfiles[i] if f.name != g.name { return fmt.Errorf("different filenames: %q != %q", f.name, g.name) } @@ -88,7 +91,7 @@ func TestSerialization(t *testing.T) { p := NewFileSet() checkSerialize(t, p) // add some files - for i := 0; i < 10; i++ { + for i := range 10 { f := p.AddFile(fmt.Sprintf("file%d", i), p.Base()+i, i*100) checkSerialize(t, p) // add some lines and alternative file infos diff --git a/src/go/token/tree.go b/src/go/token/tree.go new file mode 100644 index 0000000000..5c00dcf2df --- /dev/null +++ b/src/go/token/tree.go @@ -0,0 +1,405 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package token + +// tree is a self-balancing AVL tree; see +// Lewis & Denenberg, Data Structures and Their Algorithms. +// +// An AVL tree is a binary tree in which the difference between the +// heights of a node's two subtrees--the node's "balance factor"--is +// at most one. It is more strictly balanced than a red/black tree, +// and thus favors lookups at the expense of updates, which is the +// appropriate trade-off for FileSet. +// +// Insertion at a node may cause its ancestors' balance factors to +// temporarily reach ±2, requiring rebalancing of each such ancestor +// by a rotation. +// +// Each key is the pos-end range of a single File. +// All Files in the tree must have disjoint ranges. +// +// The implementation is simplified from Russ Cox's github.com/rsc/omap. + +import ( + "fmt" + "iter" +) + +// A tree is a tree-based ordered map: +// each value is a *File, keyed by its Pos range. +// All map entries cover disjoint ranges. +// +// The zero value of tree is an empty map ready to use. +type tree struct { + root *node +} + +type node struct { + // We use the notation (parent left right) in many comments. + parent *node + left *node + right *node + file *File + key key // = file.key(), but improves locality (25% faster) + balance int32 // at most ±2 + height int32 +} + +// A key represents the Pos range of a File. +type key struct{ start, end int } + +func (f *File) key() key { + return key{f.base, f.base + f.size} +} + +// compareKey reports whether x is before y (-1), +// after y (+1), or overlapping y (0). +// This is a total order so long as all +// files in the tree have disjoint ranges. +// +// All files are separated by at least one unit. +// This allows us to use strict < comparisons. +// Use key{p, p} to search for a zero-width position +// even at the start or end of a file. +func compareKey(x, y key) int { + switch { + case x.end < y.start: + return -1 + case y.end < x.start: + return +1 + } + return 0 +} + +// check asserts that each node's height, subtree, and parent link is +// correct. +func (n *node) check(parent *node) { + const debugging = false + if debugging { + if n == nil { + return + } + if n.parent != parent { + panic("bad parent") + } + n.left.check(n) + n.right.check(n) + n.checkBalance() + } +} + +func (n *node) checkBalance() { + lheight, rheight := n.left.safeHeight(), n.right.safeHeight() + balance := rheight - lheight + if balance != n.balance { + panic("bad node.balance") + } + if !(-2 <= balance && balance <= +2) { + panic(fmt.Sprintf("node.balance out of range: %d", balance)) + } + h := 1 + max(lheight, rheight) + if h != n.height { + panic("bad node.height") + } +} + +// locate returns a pointer to the variable that holds the node +// identified by k, along with its parent, if any. If the key is not +// present, it returns a pointer to the node where the key should be +// inserted by a subsequent call to [tree.set]. +func (t *tree) locate(k key) (pos **node, parent *node) { + pos, x := &t.root, t.root + for x != nil { + sign := compareKey(k, x.key) + if sign < 0 { + pos, x, parent = &x.left, x.left, x + } else if sign > 0 { + pos, x, parent = &x.right, x.right, x + } else { + break + } + } + return pos, parent +} + +// all returns an iterator over the tree t. +// If t is modified during the iteration, +// some files may not be visited. +// No file will be visited multiple times. +func (t *tree) all() iter.Seq[*File] { + return func(yield func(*File) bool) { + if t == nil { + return + } + x := t.root + if x != nil { + for x.left != nil { + x = x.left + } + } + for x != nil && yield(x.file) { + if x.height >= 0 { + // still in tree + x = x.next() + } else { + // deleted + x = t.nextAfter(t.locate(x.key)) + } + } + } +} + +// nextAfter returns the node in the key sequence following +// (pos, parent), a result pair from [tree.locate]. +func (t *tree) nextAfter(pos **node, parent *node) *node { + switch { + case *pos != nil: + return (*pos).next() + case parent == nil: + return nil + case pos == &parent.left: + return parent + default: + return parent.next() + } +} + +func (x *node) next() *node { + if x.right == nil { + for x.parent != nil && x.parent.right == x { + x = x.parent + } + return x.parent + } + x = x.right + for x.left != nil { + x = x.left + } + return x +} + +func (t *tree) setRoot(x *node) { + t.root = x + if x != nil { + x.parent = nil + } +} + +func (x *node) setLeft(y *node) { + x.left = y + if y != nil { + y.parent = x + } +} + +func (x *node) setRight(y *node) { + x.right = y + if y != nil { + y.parent = x + } +} + +func (n *node) safeHeight() int32 { + if n == nil { + return -1 + } + return n.height +} + +func (n *node) update() { + lheight, rheight := n.left.safeHeight(), n.right.safeHeight() + n.height = max(lheight, rheight) + 1 + n.balance = rheight - lheight +} + +func (t *tree) replaceChild(parent, old, new *node) { + switch { + case parent == nil: + if t.root != old { + panic("corrupt tree") + } + t.setRoot(new) + case parent.left == old: + parent.setLeft(new) + case parent.right == old: + parent.setRight(new) + default: + panic("corrupt tree") + } +} + +// rebalanceUp visits each excessively unbalanced ancestor +// of x, restoring balance by rotating it. +// +// x is a node that has just been mutated, and so the height and +// balance of x and its ancestors may be stale, but the children of x +// must be in a valid state. +func (t *tree) rebalanceUp(x *node) { + for x != nil { + h := x.height + x.update() + switch x.balance { + case -2: + if x.left.balance == 1 { + t.rotateLeft(x.left) + } + x = t.rotateRight(x) + + case +2: + if x.right.balance == -1 { + t.rotateRight(x.right) + } + x = t.rotateLeft(x) + } + if x.height == h { + // x's height has not changed, so the height + // and balance of its ancestors have not changed; + // no further rebalancing is required. + return + } + x = x.parent + } +} + +// rotateRight rotates the subtree rooted at node y. +// turning (y (x a b) c) into (x a (y b c)). +func (t *tree) rotateRight(y *node) *node { + // p -> (y (x a b) c) + p := y.parent + x := y.left + b := x.right + + x.checkBalance() + y.checkBalance() + + x.setRight(y) + y.setLeft(b) + t.replaceChild(p, y, x) + + y.update() + x.update() + return x +} + +// rotateLeft rotates the subtree rooted at node x. +// turning (x a (y b c)) into (y (x a b) c). +func (t *tree) rotateLeft(x *node) *node { + // p -> (x a (y b c)) + p := x.parent + y := x.right + b := y.left + + x.checkBalance() + y.checkBalance() + + y.setLeft(x) + x.setRight(b) + t.replaceChild(p, x, y) + + x.update() + y.update() + return y +} + +// add inserts file into the tree, if not present. +// It panics if file overlaps with another. +func (t *tree) add(file *File) { + pos, parent := t.locate(file.key()) + if *pos == nil { + t.set(file, pos, parent) // missing; insert + return + } + if prev := (*pos).file; prev != file { + panic(fmt.Sprintf("file %s (%d-%d) overlaps with file %s (%d-%d)", + prev.Name(), prev.Base(), prev.Base()+prev.Size(), + file.Name(), file.Base(), file.Base()+file.Size())) + } +} + +// set updates the existing node at (pos, parent) if present, or +// inserts a new node if not, so that it refers to file. +func (t *tree) set(file *File, pos **node, parent *node) { + if x := *pos; x != nil { + // This code path isn't currently needed + // because FileSet never updates an existing entry. + // Remove this assertion if things change. + panic("unreachable according to current FileSet requirements") + x.file = file + return + } + x := &node{file: file, key: file.key(), parent: parent, height: -1} + *pos = x + t.rebalanceUp(x) +} + +// delete deletes the node at pos. +func (t *tree) delete(pos **node) { + t.root.check(nil) + + x := *pos + switch { + case x == nil: + // This code path isn't currently needed because FileSet + // only calls delete after a positive locate. + // Remove this assertion if things change. + panic("unreachable according to current FileSet requirements") + return + + case x.left == nil: + if *pos = x.right; *pos != nil { + (*pos).parent = x.parent + } + t.rebalanceUp(x.parent) + + case x.right == nil: + *pos = x.left + x.left.parent = x.parent + t.rebalanceUp(x.parent) + + default: + t.deleteSwap(pos) + } + + x.balance = -100 + x.parent = nil + x.left = nil + x.right = nil + x.height = -1 + t.root.check(nil) +} + +// deleteSwap deletes a node that has two children by replacing +// it by its in-order successor, then triggers a rebalance. +func (t *tree) deleteSwap(pos **node) { + x := *pos + z := t.deleteMin(&x.right) + + *pos = z + unbalanced := z.parent // lowest potentially unbalanced node + if unbalanced == x { + unbalanced = z // (x a (z nil b)) -> (z a b) + } + z.parent = x.parent + z.height = x.height + z.balance = x.balance + z.setLeft(x.left) + z.setRight(x.right) + + t.rebalanceUp(unbalanced) +} + +// deleteMin updates *zpos to the minimum (leftmost) element +// in that subtree. +func (t *tree) deleteMin(zpos **node) (z *node) { + for (*zpos).left != nil { + zpos = &(*zpos).left + } + z = *zpos + *zpos = z.right + if *zpos != nil { + (*zpos).parent = z.parent + } + return z +} diff --git a/src/go/token/tree_test.go b/src/go/token/tree_test.go new file mode 100644 index 0000000000..4bb9f060a1 --- /dev/null +++ b/src/go/token/tree_test.go @@ -0,0 +1,86 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package token + +import ( + "math/rand/v2" + "slices" + "testing" +) + +// TestTree provides basic coverage of the AVL tree operations. +func TestTree(t *testing.T) { + // Use a reproducible PRNG. + seed1, seed2 := rand.Uint64(), rand.Uint64() + t.Logf("random seeds: %d, %d", seed1, seed2) + rng := rand.New(rand.NewPCG(seed1, seed2)) + + // Create a number of Files of arbitrary size. + files := make([]*File, 500) + var base int + for i := range files { + base++ + size := 1000 + files[i] = &File{base: base, size: size} + base += size + } + + // Add them all to the tree in random order. + var tr tree + { + files2 := slices.Clone(files) + Shuffle(rng, files2) + for _, f := range files2 { + tr.add(f) + } + } + + // Randomly delete a subset of them. + for range 100 { + i := rng.IntN(len(files)) + file := files[i] + if file == nil { + continue // already deleted + } + files[i] = nil + + pn, _ := tr.locate(file.key()) + if (*pn).file != file { + t.Fatalf("locate returned wrong file") + } + tr.delete(pn) + } + + // Check some position lookups within each file. + for _, file := range files { + if file == nil { + continue // deleted + } + for _, pos := range []int{ + file.base, // start + file.base + file.size/2, // midpoint + file.base + file.size, // end + } { + pn, _ := tr.locate(key{pos, pos}) + if (*pn).file != file { + t.Fatalf("lookup %s@%d returned wrong file %s", + file.name, pos, + (*pn).file.name) + } + } + } + + // Check that the sequence is the same. + files = slices.DeleteFunc(files, func(f *File) bool { return f == nil }) + if !slices.Equal(slices.Collect(tr.all()), files) { + t.Fatalf("incorrect tree.all sequence") + } +} + +func Shuffle[T any](rng *rand.Rand, slice []*T) { + rng.Shuffle(len(slice), func(i, j int) { + slice[i], slice[j] = slice[j], slice[i] + }) +}