From a27fdba2771b0ff59adfc5ea08ce9eac8387fc01 Mon Sep 17 00:00:00 2001 From: Muir Manders Date: Sat, 7 Dec 2019 22:07:30 -0800 Subject: [PATCH] internal/lsp: check all package variants in find-implementations We previously only searched for implementations of the object we found in the "widest" package variant. We instead need to search all variants because each variant is type checked separately, and implementations can be located in packages associated with different variants. For example, say you have: -- foo/foo.go -- package foo type Foo int type Fooer interface { Foo() Foo } -- foo/foo_test.go -- package foo func TestFoo(t *testing.T) {} -- bar/bar.go -- package bar import "foo" type impl struct {} func (impl) Foo() foo.Foo { return 0 } When you run find-implementations on the Fooer interface, we previously would start from the (widest) foo.test's Fooer named type. Unfortunately bar imports foo, not foo.test, so bar.impl does not implement foo.test.Fooer. The specific reason is that bar.impl.Foo returns foo.Foo, whereas foo.test.Fooer.Foo returns foo.test.Foo, which are distinct *types.Named objects. Starting our search instead from foo.Fooer resolves this issue. However, we also need to search from foo.test.Fooer so we match any implementations in foo_test.go. Change-Id: I0b0039c98925410751c8f643c8ebd185340e409f Reviewed-on: https://go-review.googlesource.com/c/tools/+/210459 Run-TryBot: Muir Manders TryBot-Result: Gobot Gobot Reviewed-by: Rebecca Stambler --- internal/lsp/definition.go | 4 +-- internal/lsp/hover.go | 2 +- internal/lsp/implementation.go | 42 +++++++++++++++++++++++++-- internal/lsp/references.go | 2 +- internal/lsp/rename.go | 4 +-- internal/lsp/source/identifier.go | 8 +++-- internal/lsp/source/implementation.go | 5 +++- internal/lsp/source/source_test.go | 10 +++---- internal/lsp/source/util.go | 18 +++++++++++- 9 files changed, 77 insertions(+), 18 deletions(-) diff --git a/internal/lsp/definition.go b/internal/lsp/definition.go index 78aae36130..ab2b44223d 100644 --- a/internal/lsp/definition.go +++ b/internal/lsp/definition.go @@ -23,7 +23,7 @@ func (s *Server) definition(ctx context.Context, params *protocol.DefinitionPara if err != nil { return nil, err } - ident, err := source.Identifier(ctx, snapshot, f, params.Position) + ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle) if err != nil { return nil, err } @@ -50,7 +50,7 @@ func (s *Server) typeDefinition(ctx context.Context, params *protocol.TypeDefini if err != nil { return nil, err } - ident, err := source.Identifier(ctx, snapshot, f, params.Position) + ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle) if err != nil { return nil, err } diff --git a/internal/lsp/hover.go b/internal/lsp/hover.go index 7334d92425..b8ab7025ce 100644 --- a/internal/lsp/hover.go +++ b/internal/lsp/hover.go @@ -26,7 +26,7 @@ func (s *Server) hover(ctx context.Context, params *protocol.HoverParams) (*prot if err != nil { return nil, err } - ident, err := source.Identifier(ctx, snapshot, f, params.Position) + ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle) if err != nil { return nil, nil } diff --git a/internal/lsp/implementation.go b/internal/lsp/implementation.go index 9af0f7057a..010be7da43 100644 --- a/internal/lsp/implementation.go +++ b/internal/lsp/implementation.go @@ -9,7 +9,9 @@ import ( "golang.org/x/tools/internal/lsp/protocol" "golang.org/x/tools/internal/lsp/source" + "golang.org/x/tools/internal/lsp/telemetry" "golang.org/x/tools/internal/span" + "golang.org/x/tools/internal/telemetry/log" ) func (s *Server) implementation(ctx context.Context, params *protocol.ImplementationParams) ([]protocol.Location, error) { @@ -23,9 +25,45 @@ func (s *Server) implementation(ctx context.Context, params *protocol.Implementa if err != nil { return nil, err } - ident, err := source.Identifier(ctx, snapshot, f, params.Position) + + phs, err := snapshot.PackageHandles(ctx, snapshot.Handle(ctx, f)) if err != nil { return nil, err } - return ident.Implementation(ctx) + + var ( + allLocs []protocol.Location + seen = make(map[protocol.Location]bool) + ) + for _, ph := range phs { + ctx := telemetry.Package.With(ctx, ph.ID()) + + ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.SpecificPackageHandle(ph.ID())) + if err != nil { + if err == source.ErrNoIdentFound { + return nil, err + } + log.Error(ctx, "failed to find Identifer", err) + continue + } + + locs, err := ident.Implementation(ctx) + if err != nil { + if err == source.ErrNotAMethod { + return nil, err + } + log.Error(ctx, "failed to find Implemenation", err) + continue + } + + for _, loc := range locs { + if seen[loc] { + continue + } + seen[loc] = true + allLocs = append(allLocs, loc) + } + } + + return allLocs, nil } diff --git a/internal/lsp/references.go b/internal/lsp/references.go index 2dc37f2bdc..ea6754b5a1 100644 --- a/internal/lsp/references.go +++ b/internal/lsp/references.go @@ -26,7 +26,7 @@ func (s *Server) references(ctx context.Context, params *protocol.ReferenceParam return nil, err } // Find all references to the identifier at the position. - ident, err := source.Identifier(ctx, snapshot, f, params.Position) + ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle) if err != nil { return nil, err } diff --git a/internal/lsp/rename.go b/internal/lsp/rename.go index 1db1abdc99..18f9c79464 100644 --- a/internal/lsp/rename.go +++ b/internal/lsp/rename.go @@ -23,7 +23,7 @@ func (s *Server) rename(ctx context.Context, params *protocol.RenameParams) (*pr if err != nil { return nil, err } - ident, err := source.Identifier(ctx, snapshot, f, params.Position) + ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle) if err != nil { return nil, err } @@ -56,7 +56,7 @@ func (s *Server) prepareRename(ctx context.Context, params *protocol.PrepareRena if err != nil { return nil, err } - ident, err := source.Identifier(ctx, snapshot, f, params.Position) + ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle) if err != nil { return nil, nil // ignore errors } diff --git a/internal/lsp/source/identifier.go b/internal/lsp/source/identifier.go index a942f5e5a5..26d81837d0 100644 --- a/internal/lsp/source/identifier.go +++ b/internal/lsp/source/identifier.go @@ -58,11 +58,11 @@ func (i *IdentifierInfo) DeclarationReferenceInfo() *ReferenceInfo { // Identifier returns identifier information for a position // in a file, accounting for a potentially incomplete selector. -func Identifier(ctx context.Context, snapshot Snapshot, f File, pos protocol.Position) (*IdentifierInfo, error) { +func Identifier(ctx context.Context, snapshot Snapshot, f File, pos protocol.Position, selectPackage PackagePolicy) (*IdentifierInfo, error) { ctx, done := trace.StartSpan(ctx, "source.Identifier") defer done() - pkg, pgh, err := getParsedFile(ctx, snapshot, f, WidestCheckPackageHandle) + pkg, pgh, err := getParsedFile(ctx, snapshot, f, selectPackage) if err != nil { return nil, fmt.Errorf("getting file for Identifier: %v", err) } @@ -81,6 +81,8 @@ func Identifier(ctx context.Context, snapshot Snapshot, f File, pos protocol.Pos return findIdentifier(snapshot, pkg, file, rng.Start) } +var ErrNoIdentFound = errors.New("no identifier found") + func findIdentifier(snapshot Snapshot, pkg Package, file *ast.File, pos token.Pos) (*IdentifierInfo, error) { if result, err := identifier(snapshot, pkg, file, pos); err != nil || result != nil { return result, err @@ -90,7 +92,7 @@ func findIdentifier(snapshot Snapshot, pkg Package, file *ast.File, pos token.Po // requesting a completion), use the path to the preceding node. ident, err := identifier(snapshot, pkg, file, pos-1) if ident == nil && err == nil { - err = errors.New("no identifier found") + err = ErrNoIdentFound } return ident, err } diff --git a/internal/lsp/source/implementation.go b/internal/lsp/source/implementation.go index ed8b16383f..e1578cdb89 100644 --- a/internal/lsp/source/implementation.go +++ b/internal/lsp/source/implementation.go @@ -19,6 +19,7 @@ import ( "golang.org/x/tools/internal/lsp/protocol" "golang.org/x/tools/internal/lsp/telemetry" "golang.org/x/tools/internal/telemetry/log" + errors "golang.org/x/xerrors" ) func (i *IdentifierInfo) Implementation(ctx context.Context) ([]protocol.Location, error) { @@ -101,6 +102,8 @@ func (i *IdentifierInfo) Implementation(ctx context.Context) ([]protocol.Locatio return locations, nil } +var ErrNotAMethod = errors.New("this function is not a method") + func (i *IdentifierInfo) implementations(ctx context.Context) (implementsResult, error) { var T types.Type var method *types.Func @@ -112,7 +115,7 @@ func (i *IdentifierInfo) implementations(ctx context.Context) (implementsResult, } recv := obj.Type().(*types.Signature).Recv() if recv == nil { - return implementsResult{}, fmt.Errorf("this function is not a method") + return implementsResult{}, ErrNotAMethod } method = obj T = recv.Type() diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index bdb3c66cdc..f30e2914bc 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -499,7 +499,7 @@ func (r *runner) Definition(t *testing.T, spn span.Span, d tests.Definition) { if err != nil { t.Fatal(err) } - ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start) + ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start, source.WidestCheckPackageHandle) if err != nil { t.Fatalf("failed for %v: %v", d.Src, err) } @@ -562,7 +562,7 @@ func (r *runner) Implementation(t *testing.T, spn span.Span, impls []span.Span) if err != nil { t.Fatalf("failed for %v: %v", spn, err) } - ident, err := source.Identifier(ctx, r.view.Snapshot(), f, loc.Range.Start) + ident, err := source.Identifier(ctx, r.view.Snapshot(), f, loc.Range.Start, source.WidestCheckPackageHandle) if err != nil { t.Fatalf("failed for %v: %v", spn, err) } @@ -649,7 +649,7 @@ func (r *runner) References(t *testing.T, src span.Span, itemList []span.Span) { if err != nil { t.Fatal(err) } - ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start) + ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start, source.WidestCheckPackageHandle) if err != nil { t.Fatalf("failed for %v: %v", src, err) } @@ -693,7 +693,7 @@ func (r *runner) Rename(t *testing.T, spn span.Span, newText string) { if err != nil { t.Fatal(err) } - ident, err := source.Identifier(r.ctx, r.view.Snapshot(), f, srcRng.Start) + ident, err := source.Identifier(r.ctx, r.view.Snapshot(), f, srcRng.Start, source.WidestCheckPackageHandle) if err != nil { t.Error(err) return @@ -782,7 +782,7 @@ func (r *runner) PrepareRename(t *testing.T, src span.Span, want *source.Prepare t.Fatal(err) } // Find the identifier at the position. - ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start) + ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start, source.WidestCheckPackageHandle) if err != nil { if want.Text != "" { // expected an ident. t.Errorf("prepare rename failed for %v: got error: %v", src, err) diff --git a/internal/lsp/source/util.go b/internal/lsp/source/util.go index 191e529da5..ada8444419 100644 --- a/internal/lsp/source/util.go +++ b/internal/lsp/source/util.go @@ -67,7 +67,7 @@ func (s mappedRange) URI() span.URI { // getParsedFile is a convenience function that extracts the Package and ParseGoHandle for a File in a Snapshot. // selectPackage is typically Narrowest/WidestCheckPackageHandle below. -func getParsedFile(ctx context.Context, snapshot Snapshot, f File, selectPackage func([]PackageHandle) (PackageHandle, error)) (Package, ParseGoHandle, error) { +func getParsedFile(ctx context.Context, snapshot Snapshot, f File, selectPackage PackagePolicy) (Package, ParseGoHandle, error) { fh := snapshot.Handle(ctx, f) phs, err := snapshot.PackageHandles(ctx, fh) if err != nil { @@ -85,6 +85,8 @@ func getParsedFile(ctx context.Context, snapshot Snapshot, f File, selectPackage return pkg, pgh, err } +type PackagePolicy func([]PackageHandle) (PackageHandle, error) + // NarrowestCheckPackageHandle picks the "narrowest" package for a given file. // // By "narrowest" package, we mean the package with the fewest number of files @@ -126,6 +128,20 @@ func WidestCheckPackageHandle(handles []PackageHandle) (PackageHandle, error) { return result, nil } +// SpecificPackageHandle creates a PackagePolicy to select a +// particular PackageHandle when you alread know the one you want. +func SpecificPackageHandle(desiredID string) PackagePolicy { + return func(handles []PackageHandle) (PackageHandle, error) { + for _, h := range handles { + if h.ID() == desiredID { + return h, nil + } + } + + return nil, fmt.Errorf("no package handle with expected id %q", desiredID) + } +} + func IsGenerated(ctx context.Context, view View, uri span.URI) bool { f, err := view.GetFile(ctx, uri) if err != nil {