diff --git a/internal/lsp/source/implementation.go b/internal/lsp/source/implementation.go index 1eceb8a0d6..030a7f7625 100644 --- a/internal/lsp/source/implementation.go +++ b/internal/lsp/source/implementation.go @@ -46,8 +46,10 @@ func Implementation(ctx context.Context, s Snapshot, f FileHandle, pp protocol.P return locations, nil } -var ErrNotAnInterface = errors.New("not an interface or interface method") +var ErrNotAType = errors.New("not a type name or method") +// implementations returns the concrete implementations of the specified +// interface, or the interfaces implemented by the specified concrete type. func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]qualifiedObject, error) { var ( impls []qualifiedObject @@ -62,25 +64,25 @@ func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol. for _, qo := range qos { var ( - T *types.Interface - method *types.Func + queryType types.Type + queryMethod *types.Func ) switch obj := qo.obj.(type) { case *types.Func: - method = obj + queryMethod = obj if recv := obj.Type().(*types.Signature).Recv(); recv != nil { - T, _ = recv.Type().Underlying().(*types.Interface) + queryType = ensurePointer(recv.Type()) } case *types.TypeName: - T, _ = obj.Type().Underlying().(*types.Interface) + queryType = ensurePointer(obj.Type()) } - if T == nil { - return nil, ErrNotAnInterface + if queryType == nil { + return nil, ErrNotAType } - if T.NumMethods() == 0 { + if types.NewMethodSet(queryType).Len() == 0 { return nil, nil } @@ -108,42 +110,48 @@ func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol. if !ok || obj.IsAlias() { continue } - named, ok := obj.Type().(*types.Named) - // We skip interface types since we only want concrete - // implementations. - if !ok || isInterface(named) { - continue + if named, ok := obj.Type().(*types.Named); ok { + allNamed = append(allNamed, named) } - allNamed = append(allNamed, named) } } - // Find all the named types that implement our interface. - for _, U := range allNamed { - var concrete types.Type = U - if !types.AssignableTo(concrete, T) { - // We also accept T if *T implements our interface. - concrete = types.NewPointer(concrete) - if !types.AssignableTo(concrete, T) { + // Find all the named types that match our query. + for _, named := range allNamed { + var ( + candObj types.Object = named.Obj() + candType = ensurePointer(named) + ) + + if !concreteImplementsIntf(candType, queryType) { + continue + } + + ms := types.NewMethodSet(candType) + if ms.Len() == 0 { + // Skip empty interfaces. + continue + } + + // If client queried a method, look up corresponding candType method. + if queryMethod != nil { + sel := ms.Lookup(queryMethod.Pkg(), queryMethod.Name()) + if sel == nil { continue } + candObj = sel.Obj() } - var obj types.Object = U.Obj() - if method != nil { - obj = types.NewMethodSet(concrete).Lookup(method.Pkg(), method.Name()).Obj() - } - - pos := fset.Position(obj.Pos()) - if obj == method || seen[pos] { + pos := fset.Position(candObj.Pos()) + if candObj == queryMethod || seen[pos] { continue } seen[pos] = true impls = append(impls, qualifiedObject{ - obj: obj, - pkg: pkgs[obj.Pkg()], + obj: candObj, + pkg: pkgs[candObj.Pkg()], }) } } @@ -151,6 +159,35 @@ func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol. return impls, nil } +// concreteImplementsIntf returns true if a is an interface type implemented by +// concrete type b, or vice versa. +func concreteImplementsIntf(a, b types.Type) bool { + aIsIntf, bIsIntf := isInterface(a), isInterface(b) + + // Make sure exactly one is an interface type. + if aIsIntf == bIsIntf { + return false + } + + // Rearrange if needed so "a" is the concrete type. + if aIsIntf { + a, b = b, a + } + + return types.AssignableTo(a, b) +} + +// ensurePointer wraps T in a *types.Pointer if T is a named, non-interface +// type. This is useful to make sure you consider a named type's full method +// set. +func ensurePointer(T types.Type) types.Type { + if _, ok := T.(*types.Named); ok && !isInterface(T) { + return types.NewPointer(T) + } + + return T +} + type qualifiedObject struct { obj types.Object diff --git a/internal/lsp/testdata/lsp/primarymod/implementation/implementation.go b/internal/lsp/testdata/lsp/primarymod/implementation/implementation.go index 5c18004262..c3229121a3 100644 --- a/internal/lsp/testdata/lsp/primarymod/implementation/implementation.go +++ b/internal/lsp/testdata/lsp/primarymod/implementation/implementation.go @@ -2,32 +2,30 @@ package implementation import "golang.org/x/tools/internal/lsp/implementation/other" -type ImpP struct{} //@ImpP +type ImpP struct{} //@ImpP,implementations("ImpP", Laugher, OtherLaugher) -func (*ImpP) Laugh() { //@mark(LaughP, "Laugh") +func (*ImpP) Laugh() { //@mark(LaughP, "Laugh"),implementations("Laugh", Laugh, OtherLaugh) } -type ImpS struct{} //@ImpS +type ImpS struct{} //@ImpS,implementations("ImpS", Laugher, OtherLaugher) -func (ImpS) Laugh() { //@mark(LaughS, "Laugh") +func (ImpS) Laugh() { //@mark(LaughS, "Laugh"),implementations("Laugh", Laugh, OtherLaugh) } -type ImpI interface { - Laugh() //@implementations("Laugh", LaughP, OtherLaughP, LaughS, OtherLaughS) +type Laugher interface { //@Laugher,implementations("Laugher", ImpP, OtherImpP, ImpS, OtherImpS) + Laugh() //@Laugh,implementations("Laugh", LaughP, OtherLaughP, LaughS, OtherLaughS) } -type Laugher interface { //@implementations("Laugher", ImpP, OtherImpP, ImpS, OtherImpS) - Laugh() //@implementations("Laugh", LaughP, OtherLaughP, LaughS, OtherLaughS) -} - -type Foo struct { +type Foo struct { //@implementations("Foo", Joker) other.Foo } -type U interface { - U() //@implementations("U", ImpU) +type Joker interface { //@Joker + Joke() //@Joke,implementations("Joke", ImpJoker) } -type cryer int +type cryer int //@implementations("cryer", Cryer) -func (cryer) Cry(other.CryType) {} //@mark(CryImpl, "Cry") +func (cryer) Cry(other.CryType) {} //@mark(CryImpl, "Cry"),implementations("Cry", Cry) + +type Empty interface{} //@implementations("Empty") diff --git a/internal/lsp/testdata/lsp/primarymod/implementation/other/other.go b/internal/lsp/testdata/lsp/primarymod/implementation/other/other.go index f6dff0a3c4..aff825e91e 100644 --- a/internal/lsp/testdata/lsp/primarymod/implementation/other/other.go +++ b/internal/lsp/testdata/lsp/primarymod/implementation/other/other.go @@ -10,20 +10,18 @@ type ImpS struct{} //@mark(OtherImpS, "ImpS") func (ImpS) Laugh() { //@mark(OtherLaughS, "Laugh") } -type ImpI interface { - Laugh() +type ImpI interface { //@mark(OtherLaugher, "ImpI") + Laugh() //@mark(OtherLaugh, "Laugh") } -type Foo struct { +type Foo struct { //@implementations("Foo", Joker) } -func (Foo) U() { //@mark(ImpU, "U") +func (Foo) Joke() { //@mark(ImpJoker, "Joke"),implementations("Joke", Joke) } type CryType int -const Sob CryType = 1 - -type Cryer interface { - Cry(CryType) //@implementations("Cry", CryImpl) +type Cryer interface { //@Cryer + Cry(CryType) //@Cry,implementations("Cry", CryImpl) } diff --git a/internal/lsp/testdata/lsp/summary.txt.golden b/internal/lsp/testdata/lsp/summary.txt.golden index f1ef410dfa..f5e0cac9ad 100644 --- a/internal/lsp/testdata/lsp/summary.txt.golden +++ b/internal/lsp/testdata/lsp/summary.txt.golden @@ -24,5 +24,5 @@ FuzzyWorkspaceSymbolsCount = 3 CaseSensitiveWorkspaceSymbolsCount = 2 SignaturesCount = 23 LinksCount = 8 -ImplementationsCount = 5 +ImplementationsCount = 14