mirror of https://github.com/golang/go.git
internal/typeparams: add a helper to return the origin method
With instantiated types, method objects are no longer unique: they may be instantiations of methods with generic receiver. However, some use-cases require finding the canonical method representing the method in the source. For these use-cases, provide an OriginMethod helper. For golang/go#50447 Change-Id: I6f8af3fb5c5eeefb11f8f3bdba54cd6692ca389f Reviewed-on: https://go-review.googlesource.com/c/tools/+/380554 Trust: Robert Findley <rfindley@google.com> Run-TryBot: Robert Findley <rfindley@google.com> TryBot-Result: Gopher Robot <gobot@golang.org> gopls-CI: kokoro <noreply+kokoro@google.com> Reviewed-by: Tim King <taking@google.com>
This commit is contained in:
parent
a739c97304
commit
ea5e1dc8bc
|
|
@ -77,3 +77,35 @@ func IsTypeParam(t types.Type) bool {
|
|||
_, ok := t.(*TypeParam)
|
||||
return ok
|
||||
}
|
||||
|
||||
// OriginMethod returns the origin method associated with the method fn.
|
||||
// For methods on a non-generic receiver base type, this is just
|
||||
// fn. However, for methods with a generic receiver, OriginMethod returns the
|
||||
// corresponding method in the method set of the origin type.
|
||||
//
|
||||
// As a special case, if fn is not a method (has no receiver), OriginMethod
|
||||
// returns fn.
|
||||
func OriginMethod(fn *types.Func) *types.Func {
|
||||
recv := fn.Type().(*types.Signature).Recv()
|
||||
if recv == nil {
|
||||
|
||||
return fn
|
||||
}
|
||||
base := recv.Type()
|
||||
p, isPtr := base.(*types.Pointer)
|
||||
if isPtr {
|
||||
base = p.Elem()
|
||||
}
|
||||
named, isNamed := base.(*types.Named)
|
||||
if !isNamed {
|
||||
// Receiver is a *types.Interface.
|
||||
return fn
|
||||
}
|
||||
if ForNamed(named).Len() == 0 {
|
||||
// Receiver base has no type parameters, so we can avoid the lookup below.
|
||||
return fn
|
||||
}
|
||||
orig := NamedTypeOrigin(named)
|
||||
gfn, _, _ := types.LookupFieldOrMethod(orig, true, fn.Pkg(), fn.Name())
|
||||
return gfn.(*types.Func)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,16 +6,20 @@ package typeparams_test
|
|||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/tools/internal/typeparams"
|
||||
"golang.org/x/tools/internal/testenv"
|
||||
. "golang.org/x/tools/internal/typeparams"
|
||||
)
|
||||
|
||||
func TestGetIndexExprData(t *testing.T) {
|
||||
x := &ast.Ident{}
|
||||
i := &ast.Ident{}
|
||||
|
||||
want := &typeparams.IndexListExpr{X: x, Lbrack: 1, Indices: []ast.Expr{i}, Rbrack: 2}
|
||||
want := &IndexListExpr{X: x, Lbrack: 1, Indices: []ast.Expr{i}, Rbrack: 2}
|
||||
tests := map[ast.Node]bool{
|
||||
&ast.IndexExpr{X: x, Lbrack: 1, Index: i, Rbrack: 2}: true,
|
||||
want: true,
|
||||
|
|
@ -23,7 +27,7 @@ func TestGetIndexExprData(t *testing.T) {
|
|||
}
|
||||
|
||||
for n, isIndexExpr := range tests {
|
||||
X, lbrack, indices, rbrack := typeparams.UnpackIndexExpr(n)
|
||||
X, lbrack, indices, rbrack := UnpackIndexExpr(n)
|
||||
if got := X != nil; got != isIndexExpr {
|
||||
t.Errorf("UnpackIndexExpr(%v) = %v, _, _, _; want nil: %t", n, x, !isIndexExpr)
|
||||
}
|
||||
|
|
@ -35,3 +39,121 @@ func TestGetIndexExprData(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOriginMethodRecursive(t *testing.T) {
|
||||
testenv.NeedsGo1Point(t, 18)
|
||||
src := `package p
|
||||
|
||||
type N[A any] int
|
||||
|
||||
func (r N[B]) m() { r.m(); r.n() }
|
||||
|
||||
func (r *N[C]) n() { }
|
||||
`
|
||||
fset := token.NewFileSet()
|
||||
f, err := parser.ParseFile(fset, "p.go", src, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
info := types.Info{
|
||||
Defs: make(map[*ast.Ident]types.Object),
|
||||
Uses: make(map[*ast.Ident]types.Object),
|
||||
}
|
||||
var conf types.Config
|
||||
if _, err := conf.Check("p", fset, []*ast.File{f}, &info); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Collect objects from types.Info.
|
||||
var m, n *types.Func // the 'origin' methods in Info.Defs
|
||||
var mm, mn *types.Func // the methods used in the body of m
|
||||
|
||||
for _, decl := range f.Decls {
|
||||
fdecl, ok := decl.(*ast.FuncDecl)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
def := info.Defs[fdecl.Name].(*types.Func)
|
||||
switch fdecl.Name.Name {
|
||||
case "m":
|
||||
m = def
|
||||
ast.Inspect(fdecl.Body, func(n ast.Node) bool {
|
||||
if call, ok := n.(*ast.CallExpr); ok {
|
||||
sel := call.Fun.(*ast.SelectorExpr)
|
||||
use := info.Uses[sel.Sel].(*types.Func)
|
||||
switch sel.Sel.Name {
|
||||
case "m":
|
||||
mm = use
|
||||
case "n":
|
||||
mn = use
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
case "n":
|
||||
n = def
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input, want *types.Func
|
||||
}{
|
||||
{"declared m", m, m},
|
||||
{"declared n", n, n},
|
||||
{"used m", mm, m},
|
||||
{"used n", mn, n},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
if got := OriginMethod(test.input); got != test.want {
|
||||
t.Errorf("OriginMethod(%q) = %v, want %v", test.name, test.input, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOriginMethodUses(t *testing.T) {
|
||||
testenv.NeedsGo1Point(t, 18)
|
||||
|
||||
tests := []string{
|
||||
`type T interface { m() }; func _(t T) { t.m() }`,
|
||||
`type T[P any] interface { m() P }; func _[A any](t T[A]) { t.m() }`,
|
||||
`type T[P any] interface { m() P }; func _(t T[int]) { t.m() }`,
|
||||
`type T[P any] int; func (r T[A]) m() { r.m() }`,
|
||||
`type T[P any] int; func (r *T[A]) m() { r.m() }`,
|
||||
`type T[P any] int; func (r *T[A]) m() {}; func _(t T[int]) { t.m() }`,
|
||||
`type T[P any] int; func (r *T[A]) m() {}; func _[A any](t T[A]) { t.m() }`,
|
||||
}
|
||||
|
||||
for _, src := range tests {
|
||||
fset := token.NewFileSet()
|
||||
f, err := parser.ParseFile(fset, "p.go", "package p; "+src, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
info := types.Info{
|
||||
Uses: make(map[*ast.Ident]types.Object),
|
||||
}
|
||||
var conf types.Config
|
||||
pkg, err := conf.Check("p", fset, []*ast.File{f}, &info)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
T := pkg.Scope().Lookup("T").Type()
|
||||
obj, _, _ := types.LookupFieldOrMethod(T, true, pkg, "m")
|
||||
m := obj.(*types.Func)
|
||||
|
||||
ast.Inspect(f, func(n ast.Node) bool {
|
||||
if call, ok := n.(*ast.CallExpr); ok {
|
||||
sel := call.Fun.(*ast.SelectorExpr)
|
||||
use := info.Uses[sel.Sel].(*types.Func)
|
||||
orig := OriginMethod(use)
|
||||
if orig != m {
|
||||
t.Errorf("%s:\nUses[%v] = %v, want %v", src, types.ExprString(sel), use, m)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue