diff --git a/internal/lsp/definition.go b/internal/lsp/definition.go index 78aae36130..ab2b44223d 100644 --- a/internal/lsp/definition.go +++ b/internal/lsp/definition.go @@ -23,7 +23,7 @@ func (s *Server) definition(ctx context.Context, params *protocol.DefinitionPara if err != nil { return nil, err } - ident, err := source.Identifier(ctx, snapshot, f, params.Position) + ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle) if err != nil { return nil, err } @@ -50,7 +50,7 @@ func (s *Server) typeDefinition(ctx context.Context, params *protocol.TypeDefini if err != nil { return nil, err } - ident, err := source.Identifier(ctx, snapshot, f, params.Position) + ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle) if err != nil { return nil, err } diff --git a/internal/lsp/hover.go b/internal/lsp/hover.go index 7334d92425..b8ab7025ce 100644 --- a/internal/lsp/hover.go +++ b/internal/lsp/hover.go @@ -26,7 +26,7 @@ func (s *Server) hover(ctx context.Context, params *protocol.HoverParams) (*prot if err != nil { return nil, err } - ident, err := source.Identifier(ctx, snapshot, f, params.Position) + ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle) if err != nil { return nil, nil } diff --git a/internal/lsp/implementation.go b/internal/lsp/implementation.go index 9af0f7057a..010be7da43 100644 --- a/internal/lsp/implementation.go +++ b/internal/lsp/implementation.go @@ -9,7 +9,9 @@ import ( "golang.org/x/tools/internal/lsp/protocol" "golang.org/x/tools/internal/lsp/source" + "golang.org/x/tools/internal/lsp/telemetry" "golang.org/x/tools/internal/span" + "golang.org/x/tools/internal/telemetry/log" ) func (s *Server) implementation(ctx context.Context, params *protocol.ImplementationParams) ([]protocol.Location, error) { @@ -23,9 +25,45 @@ func (s *Server) implementation(ctx context.Context, params *protocol.Implementa if err != nil { return nil, err } - ident, err := source.Identifier(ctx, snapshot, f, params.Position) + + phs, err := snapshot.PackageHandles(ctx, snapshot.Handle(ctx, f)) if err != nil { return nil, err } - return ident.Implementation(ctx) + + var ( + allLocs []protocol.Location + seen = make(map[protocol.Location]bool) + ) + for _, ph := range phs { + ctx := telemetry.Package.With(ctx, ph.ID()) + + ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.SpecificPackageHandle(ph.ID())) + if err != nil { + if err == source.ErrNoIdentFound { + return nil, err + } + log.Error(ctx, "failed to find Identifer", err) + continue + } + + locs, err := ident.Implementation(ctx) + if err != nil { + if err == source.ErrNotAMethod { + return nil, err + } + log.Error(ctx, "failed to find Implemenation", err) + continue + } + + for _, loc := range locs { + if seen[loc] { + continue + } + seen[loc] = true + allLocs = append(allLocs, loc) + } + } + + return allLocs, nil } diff --git a/internal/lsp/references.go b/internal/lsp/references.go index 2dc37f2bdc..ea6754b5a1 100644 --- a/internal/lsp/references.go +++ b/internal/lsp/references.go @@ -26,7 +26,7 @@ func (s *Server) references(ctx context.Context, params *protocol.ReferenceParam return nil, err } // Find all references to the identifier at the position. - ident, err := source.Identifier(ctx, snapshot, f, params.Position) + ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle) if err != nil { return nil, err } diff --git a/internal/lsp/rename.go b/internal/lsp/rename.go index 1db1abdc99..18f9c79464 100644 --- a/internal/lsp/rename.go +++ b/internal/lsp/rename.go @@ -23,7 +23,7 @@ func (s *Server) rename(ctx context.Context, params *protocol.RenameParams) (*pr if err != nil { return nil, err } - ident, err := source.Identifier(ctx, snapshot, f, params.Position) + ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle) if err != nil { return nil, err } @@ -56,7 +56,7 @@ func (s *Server) prepareRename(ctx context.Context, params *protocol.PrepareRena if err != nil { return nil, err } - ident, err := source.Identifier(ctx, snapshot, f, params.Position) + ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle) if err != nil { return nil, nil // ignore errors } diff --git a/internal/lsp/source/identifier.go b/internal/lsp/source/identifier.go index a942f5e5a5..26d81837d0 100644 --- a/internal/lsp/source/identifier.go +++ b/internal/lsp/source/identifier.go @@ -58,11 +58,11 @@ func (i *IdentifierInfo) DeclarationReferenceInfo() *ReferenceInfo { // Identifier returns identifier information for a position // in a file, accounting for a potentially incomplete selector. -func Identifier(ctx context.Context, snapshot Snapshot, f File, pos protocol.Position) (*IdentifierInfo, error) { +func Identifier(ctx context.Context, snapshot Snapshot, f File, pos protocol.Position, selectPackage PackagePolicy) (*IdentifierInfo, error) { ctx, done := trace.StartSpan(ctx, "source.Identifier") defer done() - pkg, pgh, err := getParsedFile(ctx, snapshot, f, WidestCheckPackageHandle) + pkg, pgh, err := getParsedFile(ctx, snapshot, f, selectPackage) if err != nil { return nil, fmt.Errorf("getting file for Identifier: %v", err) } @@ -81,6 +81,8 @@ func Identifier(ctx context.Context, snapshot Snapshot, f File, pos protocol.Pos return findIdentifier(snapshot, pkg, file, rng.Start) } +var ErrNoIdentFound = errors.New("no identifier found") + func findIdentifier(snapshot Snapshot, pkg Package, file *ast.File, pos token.Pos) (*IdentifierInfo, error) { if result, err := identifier(snapshot, pkg, file, pos); err != nil || result != nil { return result, err @@ -90,7 +92,7 @@ func findIdentifier(snapshot Snapshot, pkg Package, file *ast.File, pos token.Po // requesting a completion), use the path to the preceding node. ident, err := identifier(snapshot, pkg, file, pos-1) if ident == nil && err == nil { - err = errors.New("no identifier found") + err = ErrNoIdentFound } return ident, err } diff --git a/internal/lsp/source/implementation.go b/internal/lsp/source/implementation.go index ed8b16383f..e1578cdb89 100644 --- a/internal/lsp/source/implementation.go +++ b/internal/lsp/source/implementation.go @@ -19,6 +19,7 @@ import ( "golang.org/x/tools/internal/lsp/protocol" "golang.org/x/tools/internal/lsp/telemetry" "golang.org/x/tools/internal/telemetry/log" + errors "golang.org/x/xerrors" ) func (i *IdentifierInfo) Implementation(ctx context.Context) ([]protocol.Location, error) { @@ -101,6 +102,8 @@ func (i *IdentifierInfo) Implementation(ctx context.Context) ([]protocol.Locatio return locations, nil } +var ErrNotAMethod = errors.New("this function is not a method") + func (i *IdentifierInfo) implementations(ctx context.Context) (implementsResult, error) { var T types.Type var method *types.Func @@ -112,7 +115,7 @@ func (i *IdentifierInfo) implementations(ctx context.Context) (implementsResult, } recv := obj.Type().(*types.Signature).Recv() if recv == nil { - return implementsResult{}, fmt.Errorf("this function is not a method") + return implementsResult{}, ErrNotAMethod } method = obj T = recv.Type() diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index bdb3c66cdc..f30e2914bc 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -499,7 +499,7 @@ func (r *runner) Definition(t *testing.T, spn span.Span, d tests.Definition) { if err != nil { t.Fatal(err) } - ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start) + ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start, source.WidestCheckPackageHandle) if err != nil { t.Fatalf("failed for %v: %v", d.Src, err) } @@ -562,7 +562,7 @@ func (r *runner) Implementation(t *testing.T, spn span.Span, impls []span.Span) if err != nil { t.Fatalf("failed for %v: %v", spn, err) } - ident, err := source.Identifier(ctx, r.view.Snapshot(), f, loc.Range.Start) + ident, err := source.Identifier(ctx, r.view.Snapshot(), f, loc.Range.Start, source.WidestCheckPackageHandle) if err != nil { t.Fatalf("failed for %v: %v", spn, err) } @@ -649,7 +649,7 @@ func (r *runner) References(t *testing.T, src span.Span, itemList []span.Span) { if err != nil { t.Fatal(err) } - ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start) + ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start, source.WidestCheckPackageHandle) if err != nil { t.Fatalf("failed for %v: %v", src, err) } @@ -693,7 +693,7 @@ func (r *runner) Rename(t *testing.T, spn span.Span, newText string) { if err != nil { t.Fatal(err) } - ident, err := source.Identifier(r.ctx, r.view.Snapshot(), f, srcRng.Start) + ident, err := source.Identifier(r.ctx, r.view.Snapshot(), f, srcRng.Start, source.WidestCheckPackageHandle) if err != nil { t.Error(err) return @@ -782,7 +782,7 @@ func (r *runner) PrepareRename(t *testing.T, src span.Span, want *source.Prepare t.Fatal(err) } // Find the identifier at the position. - ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start) + ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start, source.WidestCheckPackageHandle) if err != nil { if want.Text != "" { // expected an ident. t.Errorf("prepare rename failed for %v: got error: %v", src, err) diff --git a/internal/lsp/source/util.go b/internal/lsp/source/util.go index 191e529da5..ada8444419 100644 --- a/internal/lsp/source/util.go +++ b/internal/lsp/source/util.go @@ -67,7 +67,7 @@ func (s mappedRange) URI() span.URI { // getParsedFile is a convenience function that extracts the Package and ParseGoHandle for a File in a Snapshot. // selectPackage is typically Narrowest/WidestCheckPackageHandle below. -func getParsedFile(ctx context.Context, snapshot Snapshot, f File, selectPackage func([]PackageHandle) (PackageHandle, error)) (Package, ParseGoHandle, error) { +func getParsedFile(ctx context.Context, snapshot Snapshot, f File, selectPackage PackagePolicy) (Package, ParseGoHandle, error) { fh := snapshot.Handle(ctx, f) phs, err := snapshot.PackageHandles(ctx, fh) if err != nil { @@ -85,6 +85,8 @@ func getParsedFile(ctx context.Context, snapshot Snapshot, f File, selectPackage return pkg, pgh, err } +type PackagePolicy func([]PackageHandle) (PackageHandle, error) + // NarrowestCheckPackageHandle picks the "narrowest" package for a given file. // // By "narrowest" package, we mean the package with the fewest number of files @@ -126,6 +128,20 @@ func WidestCheckPackageHandle(handles []PackageHandle) (PackageHandle, error) { return result, nil } +// SpecificPackageHandle creates a PackagePolicy to select a +// particular PackageHandle when you alread know the one you want. +func SpecificPackageHandle(desiredID string) PackagePolicy { + return func(handles []PackageHandle) (PackageHandle, error) { + for _, h := range handles { + if h.ID() == desiredID { + return h, nil + } + } + + return nil, fmt.Errorf("no package handle with expected id %q", desiredID) + } +} + func IsGenerated(ctx context.Context, view View, uri span.URI) bool { f, err := view.GetFile(ctx, uri) if err != nil {