From ea5e1dc8bc30f53cf48f1eff344a001537f1ba9d Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Mon, 24 Jan 2022 11:16:17 -0500 Subject: [PATCH] internal/typeparams: add a helper to return the origin method With instantiated types, method objects are no longer unique: they may be instantiations of methods with generic receiver. However, some use-cases require finding the canonical method representing the method in the source. For these use-cases, provide an OriginMethod helper. For golang/go#50447 Change-Id: I6f8af3fb5c5eeefb11f8f3bdba54cd6692ca389f Reviewed-on: https://go-review.googlesource.com/c/tools/+/380554 Trust: Robert Findley Run-TryBot: Robert Findley TryBot-Result: Gopher Robot gopls-CI: kokoro Reviewed-by: Tim King --- internal/typeparams/common.go | 32 ++++++++ internal/typeparams/common_test.go | 128 ++++++++++++++++++++++++++++- 2 files changed, 157 insertions(+), 3 deletions(-) diff --git a/internal/typeparams/common.go b/internal/typeparams/common.go index 1222764b6a..53b0696a1d 100644 --- a/internal/typeparams/common.go +++ b/internal/typeparams/common.go @@ -77,3 +77,35 @@ func IsTypeParam(t types.Type) bool { _, ok := t.(*TypeParam) return ok } + +// OriginMethod returns the origin method associated with the method fn. +// For methods on a non-generic receiver base type, this is just +// fn. However, for methods with a generic receiver, OriginMethod returns the +// corresponding method in the method set of the origin type. +// +// As a special case, if fn is not a method (has no receiver), OriginMethod +// returns fn. +func OriginMethod(fn *types.Func) *types.Func { + recv := fn.Type().(*types.Signature).Recv() + if recv == nil { + + return fn + } + base := recv.Type() + p, isPtr := base.(*types.Pointer) + if isPtr { + base = p.Elem() + } + named, isNamed := base.(*types.Named) + if !isNamed { + // Receiver is a *types.Interface. + return fn + } + if ForNamed(named).Len() == 0 { + // Receiver base has no type parameters, so we can avoid the lookup below. + return fn + } + orig := NamedTypeOrigin(named) + gfn, _, _ := types.LookupFieldOrMethod(orig, true, fn.Pkg(), fn.Name()) + return gfn.(*types.Func) +} diff --git a/internal/typeparams/common_test.go b/internal/typeparams/common_test.go index 1bd15d794b..da084d173f 100644 --- a/internal/typeparams/common_test.go +++ b/internal/typeparams/common_test.go @@ -6,16 +6,20 @@ package typeparams_test import ( "go/ast" + "go/parser" + "go/token" + "go/types" "testing" - "golang.org/x/tools/internal/typeparams" + "golang.org/x/tools/internal/testenv" + . "golang.org/x/tools/internal/typeparams" ) func TestGetIndexExprData(t *testing.T) { x := &ast.Ident{} i := &ast.Ident{} - want := &typeparams.IndexListExpr{X: x, Lbrack: 1, Indices: []ast.Expr{i}, Rbrack: 2} + want := &IndexListExpr{X: x, Lbrack: 1, Indices: []ast.Expr{i}, Rbrack: 2} tests := map[ast.Node]bool{ &ast.IndexExpr{X: x, Lbrack: 1, Index: i, Rbrack: 2}: true, want: true, @@ -23,7 +27,7 @@ func TestGetIndexExprData(t *testing.T) { } for n, isIndexExpr := range tests { - X, lbrack, indices, rbrack := typeparams.UnpackIndexExpr(n) + X, lbrack, indices, rbrack := UnpackIndexExpr(n) if got := X != nil; got != isIndexExpr { t.Errorf("UnpackIndexExpr(%v) = %v, _, _, _; want nil: %t", n, x, !isIndexExpr) } @@ -35,3 +39,121 @@ func TestGetIndexExprData(t *testing.T) { } } } + +func TestOriginMethodRecursive(t *testing.T) { + testenv.NeedsGo1Point(t, 18) + src := `package p + +type N[A any] int + +func (r N[B]) m() { r.m(); r.n() } + +func (r *N[C]) n() { } +` + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "p.go", src, 0) + if err != nil { + t.Fatal(err) + } + info := types.Info{ + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + } + var conf types.Config + if _, err := conf.Check("p", fset, []*ast.File{f}, &info); err != nil { + t.Fatal(err) + } + + // Collect objects from types.Info. + var m, n *types.Func // the 'origin' methods in Info.Defs + var mm, mn *types.Func // the methods used in the body of m + + for _, decl := range f.Decls { + fdecl, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + def := info.Defs[fdecl.Name].(*types.Func) + switch fdecl.Name.Name { + case "m": + m = def + ast.Inspect(fdecl.Body, func(n ast.Node) bool { + if call, ok := n.(*ast.CallExpr); ok { + sel := call.Fun.(*ast.SelectorExpr) + use := info.Uses[sel.Sel].(*types.Func) + switch sel.Sel.Name { + case "m": + mm = use + case "n": + mn = use + } + } + return true + }) + case "n": + n = def + } + } + + tests := []struct { + name string + input, want *types.Func + }{ + {"declared m", m, m}, + {"declared n", n, n}, + {"used m", mm, m}, + {"used n", mn, n}, + } + + for _, test := range tests { + if got := OriginMethod(test.input); got != test.want { + t.Errorf("OriginMethod(%q) = %v, want %v", test.name, test.input, test.want) + } + } +} + +func TestOriginMethodUses(t *testing.T) { + testenv.NeedsGo1Point(t, 18) + + tests := []string{ + `type T interface { m() }; func _(t T) { t.m() }`, + `type T[P any] interface { m() P }; func _[A any](t T[A]) { t.m() }`, + `type T[P any] interface { m() P }; func _(t T[int]) { t.m() }`, + `type T[P any] int; func (r T[A]) m() { r.m() }`, + `type T[P any] int; func (r *T[A]) m() { r.m() }`, + `type T[P any] int; func (r *T[A]) m() {}; func _(t T[int]) { t.m() }`, + `type T[P any] int; func (r *T[A]) m() {}; func _[A any](t T[A]) { t.m() }`, + } + + for _, src := range tests { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "p.go", "package p; "+src, 0) + if err != nil { + t.Fatal(err) + } + info := types.Info{ + Uses: make(map[*ast.Ident]types.Object), + } + var conf types.Config + pkg, err := conf.Check("p", fset, []*ast.File{f}, &info) + if err != nil { + t.Fatal(err) + } + + T := pkg.Scope().Lookup("T").Type() + obj, _, _ := types.LookupFieldOrMethod(T, true, pkg, "m") + m := obj.(*types.Func) + + ast.Inspect(f, func(n ast.Node) bool { + if call, ok := n.(*ast.CallExpr); ok { + sel := call.Fun.(*ast.SelectorExpr) + use := info.Uses[sel.Sel].(*types.Func) + orig := OriginMethod(use) + if orig != m { + t.Errorf("%s:\nUses[%v] = %v, want %v", src, types.ExprString(sel), use, m) + } + } + return true + }) + } +}