diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index 866b093e4f..686c1cee05 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -466,16 +466,30 @@ func (r *runner) Implementation(t *testing.T, spn span.Span, m tests.Implementat if len(locs) != len(m.Implementations) { t.Fatalf("got %d locations for implementation, expected %d", len(locs), len(m.Implementations)) } + + var results []span.Span for i := range locs { locURI := span.NewURI(locs[i].URI) lm, err := r.data.Mapper(locURI) if err != nil { t.Fatal(err) } - if imp, err := lm.Span(locs[i]); err != nil { + imp, err := lm.Span(locs[i]) + if err != nil { t.Fatalf("failed for %v: %v", locs[i], err) - } else if imp != m.Implementations[i] { - t.Errorf("for %dth implementation of %v got %v want %v", i, m.Src, imp, m.Implementations[i]) + } + results = append(results, imp) + } + // Sort results and expected to make tests deterministic. + sort.SliceStable(results, func(i, j int) bool { + return span.Compare(results[i], results[j]) == -1 + }) + sort.SliceStable(m.Implementations, func(i, j int) bool { + return span.Compare(m.Implementations[i], m.Implementations[j]) == -1 + }) + for i := range results { + if results[i] != m.Implementations[i] { + t.Errorf("for %dth implementation of %v got %v want %v", i, m.Src, results[i], m.Implementations[i]) } } } diff --git a/internal/lsp/source/implementation.go b/internal/lsp/source/implementation.go index 7702b8e2c6..61e80cfb9c 100644 --- a/internal/lsp/source/implementation.go +++ b/internal/lsp/source/implementation.go @@ -12,8 +12,8 @@ package source import ( "context" "fmt" + "go/token" "go/types" - "sort" "golang.org/x/tools/go/types/typeutil" "golang.org/x/tools/internal/lsp/protocol" @@ -27,10 +27,16 @@ func (i *IdentifierInfo) Implementation(ctx context.Context) ([]protocol.Locatio var objs []types.Object pkgs := map[types.Object]Package{} + // To ensure that we get one object per position, the seen map tracks object positions. + var seen map[token.Pos]bool if res.toMethod != nil { + seen = make(map[token.Pos]bool, len(res.toMethod)) // If we looked up a method, results are in toMethod. for _, s := range res.toMethod { + if seen[s.Obj().Pos()] { + continue + } // Determine package of receiver. recv := s.Recv() if p, ok := recv.(*types.Pointer); ok { @@ -42,8 +48,10 @@ func (i *IdentifierInfo) Implementation(ctx context.Context) ([]protocol.Locatio } // Add object to objs. objs = append(objs, s.Obj()) + seen[s.Obj().Pos()] = true } } else { + seen = make(map[token.Pos]bool, len(res.to)) // Otherwise, the results are in to. for _, t := range res.to { // We'll provide implementations that are named types and pointers to named types. @@ -51,15 +59,18 @@ func (i *IdentifierInfo) Implementation(ctx context.Context) ([]protocol.Locatio t = p.Elem() } if n, ok := t.(*types.Named); ok { + if seen[n.Obj().Pos()] { + continue + } pkg := res.pkgs[n] - objs = append(objs, n.Obj()) pkgs[n.Obj()] = pkg + objs = append(objs, n.Obj()) + seen[n.Obj().Pos()] = true } } } var locations []protocol.Location - for _, obj := range objs { pkg := pkgs[obj] if pkgs[obj] == nil || len(pkg.CompiledGoFiles()) == 0 { @@ -86,9 +97,9 @@ func (i *IdentifierInfo) Implementation(ctx context.Context) ([]protocol.Locatio Range: decRange, }) } - return locations, nil } + func (i *IdentifierInfo) implementations(ctx context.Context) (implementsResult, error) { var T types.Type var method *types.Func @@ -174,12 +185,6 @@ func (i *IdentifierInfo) implementations(ctx context.Context) (implementsResult, } } } - - // Sort types (arbitrarily) to ensure test determinism. - sort.Sort(typesByString(to)) - sort.Sort(typesByString(from)) - sort.Sort(typesByString(fromPtr)) - var toMethod []*types.Selection // contain nils if method != nil { for _, t := range to { @@ -187,7 +192,6 @@ func (i *IdentifierInfo) implementations(ctx context.Context) (implementsResult, types.NewMethodSet(t).Lookup(method.Pkg(), method.Name())) } } - return implementsResult{pkgs, to, from, fromPtr, toMethod}, nil } @@ -199,9 +203,3 @@ type implementsResult struct { fromPtr []types.Type // named interfaces assignable only from *T toMethod []*types.Selection } - -type typesByString []types.Type - -func (p typesByString) Len() int { return len(p) } -func (p typesByString) Less(i, j int) bool { return p[i].String() < p[j].String() } -func (p typesByString) Swap(i, j int) { p[i], p[j] = p[j], p[i] } diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index c9a56fa099..cda30b4db7 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -573,16 +573,29 @@ func (r *runner) Implementation(t *testing.T, spn span.Span, m tests.Implementat if len(locs) != len(m.Implementations) { t.Fatalf("got %d locations for implementation, expected %d", len(locs), len(m.Implementations)) } + var results []span.Span for i := range locs { locURI := span.NewURI(locs[i].URI) lm, err := r.data.Mapper(locURI) if err != nil { t.Fatal(err) } - if imp, err := lm.Span(locs[i]); err != nil { + imp, err := lm.Span(locs[i]) + if err != nil { t.Fatalf("failed for %v: %v", locs[i], err) - } else if imp != m.Implementations[i] { - t.Errorf("for %dth implementation of %v got %v want %v", i, m.Src, imp, m.Implementations[i]) + } + results = append(results, imp) + } + // Sort results and expected to make tests deterministic. + sort.SliceStable(results, func(i, j int) bool { + return span.Compare(results[i], results[j]) == -1 + }) + sort.SliceStable(m.Implementations, func(i, j int) bool { + return span.Compare(m.Implementations[i], m.Implementations[j]) == -1 + }) + for i := range results { + if results[i] != m.Implementations[i] { + t.Errorf("for %dth implementation of %v got %v want %v", i, m.Src, results[i], m.Implementations[i]) } } }