diff --git a/internal/lsp/implementation.go b/internal/lsp/implementation.go index c4df76a268..0a1877328b 100644 --- a/internal/lsp/implementation.go +++ b/internal/lsp/implementation.go @@ -9,9 +9,7 @@ import ( "golang.org/x/tools/internal/lsp/protocol" "golang.org/x/tools/internal/lsp/source" - "golang.org/x/tools/internal/lsp/telemetry" "golang.org/x/tools/internal/span" - "golang.org/x/tools/internal/telemetry/log" ) func (s *Server) implementation(ctx context.Context, params *protocol.ImplementationParams) ([]protocol.Location, error) { @@ -28,43 +26,5 @@ func (s *Server) implementation(ctx context.Context, params *protocol.Implementa if fh.Identity().Kind != source.Go { return nil, nil } - phs, err := snapshot.PackageHandles(ctx, fh) - if err != nil { - return nil, err - } - var ( - allLocs []protocol.Location - seen = make(map[protocol.Location]bool) - ) - for _, ph := range phs { - ctx := telemetry.Package.With(ctx, ph.ID()) - - ident, err := source.Identifier(ctx, snapshot, fh, params.Position, source.SpecificPackageHandle(ph.ID())) - if err != nil { - if err == source.ErrNoIdentFound { - return nil, err - } - log.Error(ctx, "failed to find Identifer", err) - continue - } - - locs, err := ident.Implementation(ctx) - if err != nil { - if err == source.ErrNotAnInterface { - return nil, err - } - log.Error(ctx, "failed to find Implemenation", err) - continue - } - - for _, loc := range locs { - if seen[loc] { - continue - } - seen[loc] = true - allLocs = append(allLocs, loc) - } - } - - return allLocs, nil + return source.Implementation(ctx, snapshot, fh, params.Position) } diff --git a/internal/lsp/source/implementation.go b/internal/lsp/source/implementation.go index 9fe59e4aec..e405f1a33a 100644 --- a/internal/lsp/source/implementation.go +++ b/internal/lsp/source/implementation.go @@ -6,18 +6,22 @@ package source import ( "context" + "fmt" + "go/ast" + "go/token" "go/types" "golang.org/x/tools/internal/lsp/protocol" - "golang.org/x/tools/internal/lsp/telemetry" "golang.org/x/tools/internal/telemetry/log" + "golang.org/x/tools/internal/telemetry/trace" errors "golang.org/x/xerrors" ) -func (i *IdentifierInfo) Implementation(ctx context.Context) ([]protocol.Location, error) { - ctx = telemetry.Package.With(ctx, i.pkg.ID()) +func Implementation(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]protocol.Location, error) { + ctx, done := trace.StartSpan(ctx, "source.Implementation") + defer done() - impls, err := i.implementations(ctx) + impls, err := implementations(ctx, s, f, pp) if err != nil { return nil, err } @@ -28,27 +32,21 @@ func (i *IdentifierInfo) Implementation(ctx context.Context) ([]protocol.Locatio continue } - file, _, err := i.Snapshot.View().FindPosInPackage(impl.pkg, impl.obj.Pos()) - if err != nil { - log.Error(ctx, "Error getting file for object", err) - continue - } - - ident, err := findIdentifier(i.Snapshot, impl.pkg, file, impl.obj.Pos()) - if err != nil { - log.Error(ctx, "Error getting ident for object", err) - continue - } - - decRange, err := ident.Declaration.Range() + rng, err := objToMappedRange(s.View(), impl.pkg, impl.obj) if err != nil { log.Error(ctx, "Error getting range for object", err) continue } + pr, err := rng.Range() + if err != nil { + log.Error(ctx, "Error getting protocol range for object", err) + continue + } + locations = append(locations, protocol.Location{ - URI: protocol.NewURI(ident.Declaration.URI()), - Range: decRange, + URI: protocol.NewURI(rng.URI()), + Range: pr, }) } return locations, nil @@ -56,84 +54,94 @@ func (i *IdentifierInfo) Implementation(ctx context.Context) ([]protocol.Locatio var ErrNotAnInterface = errors.New("not an interface or interface method") -func (i *IdentifierInfo) implementations(ctx context.Context) ([]implementation, error) { - var ( - T *types.Interface - method *types.Func - ) - - switch obj := i.Declaration.obj.(type) { - case *types.Func: - method = obj - if recv := obj.Type().(*types.Signature).Recv(); recv != nil { - T, _ = recv.Type().Underlying().(*types.Interface) - } - case *types.TypeName: - T, _ = obj.Type().Underlying().(*types.Interface) - } - - if T == nil { - return nil, ErrNotAnInterface - } - - if T.NumMethods() == 0 { - return nil, nil - } - - // Find all named types, even local types (which can have methods - // due to promotion). - var ( - allNamed []*types.Named - pkgs = make(map[*types.Package]Package) - ) - for _, pkg := range i.Snapshot.KnownPackages(ctx) { - pkgs[pkg.GetTypes()] = pkg - - info := pkg.GetTypesInfo() - for _, obj := range info.Defs { - // We ignore aliases 'type M = N' to avoid duplicate reporting - // of the Named type N. - if obj, ok := obj.(*types.TypeName); ok && !obj.IsAlias() { - // We skip interface types since we only want concrete - // implementations. - if named, ok := obj.Type().(*types.Named); ok && !isInterface(named) { - allNamed = append(allNamed, named) - } - } - } - } +func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]implementation, error) { var ( impls []implementation - seen = make(map[types.Object]bool) + seen = make(map[token.Position]bool) + fset = s.View().Session().Cache().FileSet() ) - // Find all the named types that implement our interface. - for _, U := range allNamed { - var concrete types.Type = U - if !types.AssignableTo(concrete, T) { - // We also accept T if *T implements our interface. - concrete = types.NewPointer(concrete) - if !types.AssignableTo(concrete, T) { - continue + objs, err := objectsAtProtocolPos(ctx, s, f, pp) + if err != nil { + return nil, err + } + + for _, obj := range objs { + var ( + T *types.Interface + method *types.Func + ) + + switch obj := obj.(type) { + case *types.Func: + method = obj + if recv := obj.Type().(*types.Signature).Recv(); recv != nil { + T, _ = recv.Type().Underlying().(*types.Interface) + } + case *types.TypeName: + T, _ = obj.Type().Underlying().(*types.Interface) + } + + if T == nil { + return nil, ErrNotAnInterface + } + + if T.NumMethods() == 0 { + return nil, nil + } + + // Find all named types, even local types (which can have methods + // due to promotion). + var ( + allNamed []*types.Named + pkgs = make(map[*types.Package]Package) + ) + for _, pkg := range s.KnownPackages(ctx) { + pkgs[pkg.GetTypes()] = pkg + + info := pkg.GetTypesInfo() + for _, obj := range info.Defs { + // We ignore aliases 'type M = N' to avoid duplicate reporting + // of the Named type N. + if obj, ok := obj.(*types.TypeName); ok && !obj.IsAlias() { + // We skip interface types since we only want concrete + // implementations. + if named, ok := obj.Type().(*types.Named); ok && !isInterface(named) { + allNamed = append(allNamed, named) + } + } } } - var obj types.Object = U.Obj() - if method != nil { - obj = types.NewMethodSet(concrete).Lookup(method.Pkg(), method.Name()).Obj() + // Find all the named types that implement our interface. + for _, U := range allNamed { + var concrete types.Type = U + if !types.AssignableTo(concrete, T) { + // We also accept T if *T implements our interface. + concrete = types.NewPointer(concrete) + if !types.AssignableTo(concrete, T) { + continue + } + } + + var obj types.Object = U.Obj() + if method != nil { + obj = types.NewMethodSet(concrete).Lookup(method.Pkg(), method.Name()).Obj() + } + + pos := fset.Position(obj.Pos()) + if obj == method || seen[pos] { + continue + } + + seen[pos] = true + + impls = append(impls, implementation{ + obj: obj, + pkg: pkgs[obj.Pkg()], + }) } - - if obj == method || seen[obj] { - continue - } - - seen[obj] = true - - impls = append(impls, implementation{ - obj: obj, - pkg: pkgs[obj.Pkg()], - }) } return impls, nil @@ -146,3 +154,99 @@ type implementation struct { // pkg is the Package that contains obj's definition. pkg Package } + +// objectsAtProtocolPos returns all the type.Objects referenced at the given position. +// An object will be returned for every package that the file belongs to. +func objectsAtProtocolPos(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]types.Object, error) { + phs, err := s.PackageHandles(ctx, f) + if err != nil { + return nil, err + } + + var objs []types.Object + + // Check all the packages that the file belongs to. + for _, ph := range phs { + pkg, err := ph.Check(ctx) + if err != nil { + return nil, err + } + + astFile, pos, err := getASTFile(pkg, f, pp) + if err != nil { + return nil, err + } + + path := pathEnclosingIdent(astFile, pos) + if len(path) == 0 { + return nil, ErrNoIdentFound + } + + ident := path[len(path)-1].(*ast.Ident) + + obj := pkg.GetTypesInfo().ObjectOf(ident) + if obj == nil { + return nil, fmt.Errorf("no object for %q", ident.Name) + } + + objs = append(objs, obj) + } + + return objs, nil +} + +func getASTFile(pkg Package, f FileHandle, pos protocol.Position) (*ast.File, token.Pos, error) { + pgh, err := pkg.File(f.Identity().URI) + if err != nil { + return nil, 0, err + } + + file, m, _, err := pgh.Cached() + if err != nil { + return nil, 0, err + } + + spn, err := m.PointSpan(pos) + if err != nil { + return nil, 0, err + } + + rng, err := spn.Range(m.Converter) + if err != nil { + return nil, 0, err + } + + return file, rng.Start, nil +} + +// pathEnclosingIdent returns the ast path to the node that contains pos. +// It is similar to astutil.PathEnclosingInterval, but simpler, and it +// matches *ast.Ident nodes if pos is equal to node.End(). +func pathEnclosingIdent(f *ast.File, pos token.Pos) []ast.Node { + var ( + path []ast.Node + found bool + ) + + ast.Inspect(f, func(n ast.Node) bool { + if found { + return false + } + + if n == nil { + path = path[:len(path)-1] + return false + } + + switch n := n.(type) { + case *ast.Ident: + found = n.Pos() <= pos && pos <= n.End() + } + + path = append(path, n) + + return !found + }) + + return path +} diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index 94159a55af..a2e391535d 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -563,12 +563,7 @@ func (r *runner) Implementation(t *testing.T, spn span.Span, impls []span.Span) if err != nil { t.Fatal(err) } - ident, err := source.Identifier(r.ctx, r.view.Snapshot(), fh, loc.Range.Start, source.WidestCheckPackageHandle) - if err != nil { - t.Fatalf("failed for %v: %v", spn, err) - } - var locs []protocol.Location - locs, err = ident.Implementation(r.ctx) + locs, err := source.Implementation(r.ctx, r.view.Snapshot(), fh, loc.Range.Start) if err != nil { t.Fatalf("failed for %v: %v", spn, err) } diff --git a/internal/lsp/testdata/implementation/other/other_test.go b/internal/lsp/testdata/implementation/other/other_test.go new file mode 100644 index 0000000000..846e0d591d --- /dev/null +++ b/internal/lsp/testdata/implementation/other/other_test.go @@ -0,0 +1,10 @@ +package other + +import ( + "testing" +) + +// This exists so the other.test package comes into existence. + +func TestOther(t *testing.T) { +}