diff --git a/internal/lsp/analysis/stubmethods/stubmethods.go b/internal/lsp/analysis/stubmethods/stubmethods.go index c2a4138fad..f9dc69a965 100644 --- a/internal/lsp/analysis/stubmethods/stubmethods.go +++ b/internal/lsp/analysis/stubmethods/stubmethods.go @@ -105,11 +105,59 @@ func GetStubInfo(ti *types.Info, path []ast.Node, pos token.Pos) *StubInfo { return si case *ast.AssignStmt: return fromAssignStmt(ti, n, pos) + case *ast.CallExpr: + // Note that some call expressions don't carry the interface type + // because they don't point to a function or method declaration elsewhere. + // For eaxmple, "var Interface = (*Concrete)(nil)". In that case, continue + // this loop to encounter other possibilities such as *ast.ValueSpec or others. + si := fromCallExpr(ti, pos, n) + if si != nil { + return si + } } } return nil } +// fromCallExpr tries to find an *ast.CallExpr's function declaration and +// analyzes a function call's signature against the passed in parameter to deduce +// the concrete and interface types. +func fromCallExpr(ti *types.Info, pos token.Pos, ce *ast.CallExpr) *StubInfo { + paramIdx := -1 + for i, p := range ce.Args { + if pos >= p.Pos() && pos <= p.End() { + paramIdx = i + break + } + } + if paramIdx == -1 { + return nil + } + p := ce.Args[paramIdx] + concObj, pointer := concreteType(p, ti) + if concObj == nil || concObj.Obj().Pkg() == nil { + return nil + } + tv, ok := ti.Types[ce.Fun] + if !ok { + return nil + } + sig, ok := tv.Type.(*types.Signature) + if !ok { + return nil + } + sigVar := sig.Params().At(paramIdx) + iface := ifaceObjFromType(sigVar.Type()) + if iface == nil { + return nil + } + return &StubInfo{ + Concrete: concObj, + Pointer: pointer, + Interface: iface, + } +} + // fromReturnStmt analyzes a "return" statement to extract // a concrete type that is trying to be returned as an interface type. // @@ -290,8 +338,11 @@ func ifaceType(n ast.Expr, ti *types.Info) types.Object { if !ok { return nil } - typ := tv.Type - named, ok := typ.(*types.Named) + return ifaceObjFromType(tv.Type) +} + +func ifaceObjFromType(t types.Type) types.Object { + named, ok := t.(*types.Named) if !ok { return nil } diff --git a/internal/lsp/testdata/stub/stub_call_expr.go b/internal/lsp/testdata/stub/stub_call_expr.go new file mode 100644 index 0000000000..775b0e5545 --- /dev/null +++ b/internal/lsp/testdata/stub/stub_call_expr.go @@ -0,0 +1,13 @@ +package stub + +func main() { + check(&callExpr{}) //@suggestedfix("&", "refactor.rewrite") +} + +func check(err error) { + if err != nil { + panic(err) + } +} + +type callExpr struct{} diff --git a/internal/lsp/testdata/stub/stub_call_expr.go.golden b/internal/lsp/testdata/stub/stub_call_expr.go.golden new file mode 100644 index 0000000000..2d12f8651f --- /dev/null +++ b/internal/lsp/testdata/stub/stub_call_expr.go.golden @@ -0,0 +1,20 @@ +-- suggestedfix_stub_call_expr_4_8 -- +package stub + +func main() { + check(&callExpr{}) //@suggestedfix("&", "refactor.rewrite") +} + +func check(err error) { + if err != nil { + panic(err) + } +} + +type callExpr struct{} + +// Error implements error +func (*callExpr) Error() string { + panic("unimplemented") +} + diff --git a/internal/lsp/testdata/summary.txt.golden b/internal/lsp/testdata/summary.txt.golden index 89f9579963..9e1d84d1d5 100644 --- a/internal/lsp/testdata/summary.txt.golden +++ b/internal/lsp/testdata/summary.txt.golden @@ -13,7 +13,7 @@ FoldingRangesCount = 2 FormatCount = 6 ImportCount = 8 SemanticTokenCount = 3 -SuggestedFixCount = 62 +SuggestedFixCount = 63 FunctionExtractionCount = 25 MethodExtractionCount = 6 DefinitionsCount = 95 diff --git a/internal/lsp/testdata/summary_go1.18.txt.golden b/internal/lsp/testdata/summary_go1.18.txt.golden index 1b4891eaf9..1c6ad922c3 100644 --- a/internal/lsp/testdata/summary_go1.18.txt.golden +++ b/internal/lsp/testdata/summary_go1.18.txt.golden @@ -13,7 +13,7 @@ FoldingRangesCount = 2 FormatCount = 6 ImportCount = 8 SemanticTokenCount = 3 -SuggestedFixCount = 63 +SuggestedFixCount = 64 FunctionExtractionCount = 25 MethodExtractionCount = 6 DefinitionsCount = 108