From 73bb221120ac008a2b533d2f82b9d54de36d62a4 Mon Sep 17 00:00:00 2001 From: Robert Griesemer Date: Wed, 4 Mar 2020 14:18:10 -0800 Subject: [PATCH] go/types: add types.Info.Inferred map recording inferred function signatures Change-Id: Idc57c1662fe63edbe8f494961cb4dcda6db61e79 --- src/go/types/api.go | 5 ++++ src/go/types/api_test.go | 65 +++++++++++++++++++++++++++++++++++++--- src/go/types/call.go | 1 + src/go/types/check.go | 8 +++++ 4 files changed, 75 insertions(+), 4 deletions(-) diff --git a/src/go/types/api.go b/src/go/types/api.go index 71f93e236d..4283fb9c98 100644 --- a/src/go/types/api.go +++ b/src/go/types/api.go @@ -171,6 +171,11 @@ type Info struct { // qualified identifiers are collected in the Uses map. Types map[ast.Expr]TypeAndValue + // Inferred maps calls of parameterized functions which use + // type inferrence to the inferred signature of the function + // called. + Inferred map[*ast.CallExpr]*Signature + // Defs maps identifiers to the objects they define (including // package names, dots "." of dot-imports, and blank "_" identifiers). // For identifiers that do not denote objects (e.g., the package name diff --git a/src/go/types/api_test.go b/src/go/types/api_test.go index fe3950a52d..5f7095b0c9 100644 --- a/src/go/types/api_test.go +++ b/src/go/types/api_test.go @@ -42,7 +42,7 @@ func mustTypecheck(t *testing.T, path, source string, info *Info) string { return pkg.Name() } -func mayTypecheck(t *testing.T, path, source string, info *Info) string { +func mayTypecheck(t *testing.T, path, source string, info *Info) (string, error) { fset := token.NewFileSet() f, err := parser.ParseFile(fset, path, source, 0) if f == nil { // ignore errors unless f is nil @@ -52,8 +52,8 @@ func mayTypecheck(t *testing.T, path, source string, info *Info) string { Error: func(err error) {}, Importer: importer.Default(), } - pkg, _ := conf.Check(f.Name.Name, fset, []*ast.File{f}, info) - return pkg.Name() + pkg, err := conf.Check(f.Name.Name, fset, []*ast.File{f}, info) + return pkg.Name(), err } func TestValuesInfo(t *testing.T) { @@ -274,11 +274,17 @@ func TestTypesInfo(t *testing.T) { {`package x3; var x = panic("");`, `panic`, `func(interface{})`}, {`package x4; func _() { panic("") }`, `panic`, `func(interface{})`}, {`package x5; func _() { var x map[string][...]int; x = map[string][...]int{"": {1,2,3}} }`, `x`, `map[string][-1]int`}, + + // parameterized functions + {`package p0; func f(type T)(T); var _ = f(int)`, `f`, `func(type T interface{})(T₁)`}, + {`package p1; func f(type T)(T); var _ = f(int)`, `f(int)`, `func(int)`}, + {`package p2; func f(type T)(T); var _ = f(42)`, `f`, `func(type T interface{})(T₁)`}, + {`package p2; func f(type T)(T); var _ = f(42)`, `f(42)`, `()`}, } for _, test := range tests { info := Info{Types: make(map[ast.Expr]TypeAndValue)} - name := mayTypecheck(t, "TypesInfo", test.src, &info) + name, _ := mayTypecheck(t, "TypesInfo", test.src, &info) // look for expression type var typ Type @@ -300,6 +306,57 @@ func TestTypesInfo(t *testing.T) { } } +func TestInferredInfo(t *testing.T) { + var tests = []struct { + src string + fun string + sig string + }{ + {`package p0; func f(type T)(T); func _() { f(42) }`, `f`, `func(int)`}, + {`package p1; func f(type T)(T) T; func _() { f('@') }`, `f`, `func(rune) rune`}, + {`package p2; func f(type T)(...T) T; func _() { f(0i) }`, `f`, `func(...complex128) complex128`}, + {`package p3; func f(type A, B, C)(A, *B, []C); func _() { f(1.2, new(string), []byte{}) }`, `f`, `func(float64, *string, []byte)`}, + {`package p4; func f(type A, B)(A, *B, ...[]B); func _() { f(1.2, new(byte)) }`, `f`, `func(float64, *byte, ...[]byte)`}, + + // we don't know how to translate these but we can type-check them + {`package q0; type T struct{}; func (T) m(type P)(P); func _(x T) { x.m(42) }`, `x.m`, `func(int)`}, + {`package q1; type T struct{}; func (T) m(type P)(P) P; func _(x T) { x.m(42) }`, `x.m`, `func(int) int`}, + {`package q2; type T struct{}; func (T) m(type P)(...P) P; func _(x T) { x.m(42) }`, `x.m`, `func(...int) int`}, + {`package q3; type T struct{}; func (T) m(type A, B, C)(A, *B, []C); func _(x T) { x.m(1.2, new(string), []byte{}) }`, `x.m`, `func(float64, *string, []byte)`}, + {`package q4; type T struct{}; func (T) m(type A, B)(A, *B, ...[]B); func _(x T) { x.m(1.2, new(byte)) }`, `x.m`, `func(float64, *byte, ...[]byte)`}, + + {`package r0; type T(type P) struct{}; func (_ T(P)) m(type Q)(Q); func _(type P)(x T(P)) { x.m(42) }`, `x.m`, `func(int)`}, + {`package r1; type T interface{ m(type P)(P) }; func _(x T) { x.m(4.2) }`, `x.m`, `func(float64)`}, + } + + for _, test := range tests { + info := Info{Inferred: make(map[*ast.CallExpr]*Signature)} + name, err := mayTypecheck(t, "InferredInfo", test.src, &info) + if err != nil { + t.Errorf("package %s: %v", name, err) + continue + } + + // look for inferred signature + var sig *Signature + for call, typ := range info.Inferred { + if ExprString(call.Fun) == test.fun { + sig = typ + break + } + } + if sig == nil { + t.Errorf("package %s: no signature found for %s", name, test.fun) + continue + } + + // check that signature is correct + if got := sig.String(); got != test.sig { + t.Errorf("package %s: got %s; want %s", name, got, test.sig) + } + } +} + func TestImplicitsInfo(t *testing.T) { testenv.MustHaveGoBuild(t) diff --git a/src/go/types/call.go b/src/go/types/call.go index 44a0dea1e5..808855e9f8 100644 --- a/src/go/types/call.go +++ b/src/go/types/call.go @@ -324,6 +324,7 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, args []*oper // compute result signature rsig = check.instantiate(call.Pos(), sig, targs, nil).(*Signature) assert(rsig.tparams == nil) // signature is not generic anymore + check.recordInferred(call, rsig) // Optimization: Only if the parameter list was adjusted do we // need to compute it from the adjusted list; otherwise we can diff --git a/src/go/types/check.go b/src/go/types/check.go index 9be79f2503..2fc96bf3dd 100644 --- a/src/go/types/check.go +++ b/src/go/types/check.go @@ -396,6 +396,14 @@ func (check *Checker) recordCommaOkTypes(x ast.Expr, a [2]Type) { } } +func (check *Checker) recordInferred(call *ast.CallExpr, sig *Signature) { + assert(call != nil) + assert(sig != nil) + if m := check.Inferred; m != nil { + m[call] = sig + } +} + func (check *Checker) recordDef(id *ast.Ident, obj Object) { assert(id != nil) if m := check.Defs; m != nil {