go/go2go: support type inference

We can now translate code that use type inference when calling generic
functions. A couple of test cases were adjusted to use it.

Change-Id: I53c2f3dd8f9fcdb44b4a8f592acee1a19ff46f22
This commit is contained in:
Ian Lance Taylor 2020-03-08 14:40:38 -07:00 committed by Robert Griesemer
parent 4172f069ae
commit 5a5b14a4b3
5 changed files with 65 additions and 51 deletions

View File

@ -20,15 +20,15 @@ var float64s = []float64{74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3, math.NaN()
var strings = []string{"", "Hello", "foo", "bar", "foo", "f00", "%*&^*&^&", "***"}
func TestSortOrderedInts(t *testing.T) {
testOrdered(int)(t, ints, sort.Ints)
testOrdered(t, ints, sort.Ints)
}
func TestSortOrderedFloat64s(t *testing.T) {
testOrdered(float64)(t, float64s, sort.Float64s)
testOrdered(t, float64s, sort.Float64s)
}
func TestSortOrderedStrings(t *testing.T) {
testOrdered(string)(t, strings, sort.Strings)
testOrdered(t, strings, sort.Strings)
}
func testOrdered(type Elem contracts.Ordered)(t *testing.T, s []Elem, sorter func([]Elem)) {
@ -36,9 +36,9 @@ func testOrdered(type Elem contracts.Ordered)(t *testing.T, s []Elem, sorter fun
copy(s1, s)
s2 := make([]Elem, len(s))
copy(s2, s)
OrderedSlice(Elem)(s1)
OrderedSlice(s1)
sorter(s2)
if !slices.Equal(Elem)(s1, s2) {
if !slices.Equal(s1, s2) {
t.Fatalf("got %v, want %v", s1, s2)
}
for i := len(s1) - 1; i > 0; i-- {
@ -85,8 +85,8 @@ func sorter(s1, s2 []int) bool {
func TestSortSliceFn(t *testing.T) {
c := make([][]int, len(slicesToSort))
copy(c, slicesToSort)
SliceFn([]int)(c, sorter)
if !slices.EqualFn([]int)(c, sortedSlices, func(a, b []int) bool { return slices.Equal(int)(a, b) }) {
SliceFn(c, sorter)
if !slices.EqualFn(c, sortedSlices, func(a, b []int) bool { return slices.Equal(int)(a, b) }) {
t.Errorf("got %v, want %v", c, sortedSlices)
}
}

View File

@ -49,9 +49,10 @@ var _ types.ImporterFrom = &Importer{}
// The tmpdir will become a GOPATH with translated files.
func NewImporter(tmpdir string) *Importer {
info := &types.Info{
Types: make(map[ast.Expr]types.TypeAndValue),
Defs: make(map[*ast.Ident]types.Object),
Uses: make(map[*ast.Ident]types.Object),
Types: make(map[ast.Expr]types.TypeAndValue),
Inferred: make(map[*ast.CallExpr]types.Inferred),
Defs: make(map[*ast.Ident]types.Object),
Uses: make(map[*ast.Ident]types.Object),
}
return &Importer{
tmpdir: tmpdir,

View File

@ -44,11 +44,7 @@ func typeArgsFromFields(t *translator, astTypes []ast.Expr, typeTypes []types.Ty
if !ok {
panic(fmt.Sprintf("%v is not a TypeParam", objParam))
}
var astType ast.Expr
if len(astTypes) > 0 {
astType = astTypes[i]
}
ta.add(obj, objParam, astType, typeTypes[i])
ta.add(obj, objParam, astTypes[i], typeTypes[i])
}
}
return ta
@ -75,9 +71,7 @@ func typeArgsFromExprs(t *translator, astTypes []ast.Expr, typeTypes []types.Typ
// add adds mappings for obj to ast and typ.
func (ta *typeArgs) add(obj types.Object, objParam *types.TypeParam, ast ast.Expr, typ types.Type) {
if ast != nil {
ta.toAST[obj] = ast
}
ta.toAST[obj] = ast
ta.toTyp[objParam] = typ
}
@ -575,16 +569,36 @@ func (t *translator) instantiateExpr(ta *typeArgs, e ast.Expr) ast.Expr {
case *ast.CallExpr:
fun := t.instantiateExpr(ta, e.Fun)
args, argsChanged := t.instantiateExprList(ta, e.Args)
if fun == e.Fun && !argsChanged {
origInferred, haveInferred := t.importer.info.Inferred[e]
var newInferred types.Inferred
inferredChanged := false
if haveInferred {
for _, typ := range origInferred.Targs {
nt := t.instantiateType(ta, typ)
newInferred.Targs = append(newInferred.Targs, nt)
if nt != typ {
inferredChanged = true
}
}
newInferred.Sig = t.instantiateType(ta, origInferred.Sig).(*types.Signature)
if newInferred.Sig != origInferred.Sig {
inferredChanged = true
}
}
if fun == e.Fun && !argsChanged && !inferredChanged {
return e
}
r = &ast.CallExpr{
newCall := &ast.CallExpr{
Fun: fun,
Lparen: e.Lparen,
Args: args,
Ellipsis: e.Ellipsis,
Rparen: e.Rparen,
}
if haveInferred {
t.importer.info.Inferred[newCall] = newInferred
}
r = newCall
case *ast.FuncType:
params := t.instantiateFieldList(ta, e.Params)
results := t.instantiateFieldList(ta, e.Results)

View File

@ -488,32 +488,36 @@ func (t *translator) translateFunctionInstantiation(pe *ast.Expr) {
qid := t.instantiatedIdent(call)
argList, typeList, typeArgs := t.instantiationTypes(call)
var instIdent *ast.Ident
key := qid.String()
instantiations := t.instantiations[key]
for _, inst := range instantiations {
if t.sameTypes(typeList, inst.types) {
*pe = inst.decl
return
instIdent = inst.decl
break
}
}
instIdent, err := t.instantiateFunction(qid, argList, typeList)
if err != nil {
t.err = err
return
}
if instIdent == nil {
var err error
instIdent, err = t.instantiateFunction(qid, argList, typeList)
if err != nil {
t.err = err
return
}
n := &instantiation{
types: typeList,
decl: instIdent,
n := &instantiation{
types: typeList,
decl: instIdent,
}
t.instantiations[key] = append(instantiations, n)
}
t.instantiations[key] = append(instantiations, n)
if typeArgs {
*pe = instIdent
} else {
newCall := *call
call.Fun = instIdent
newCall.Fun = instIdent
*pe = &newCall
}
}
@ -580,15 +584,9 @@ func (t *translator) instantiatedIdent(call *ast.CallExpr) qualifiedIdent {
// It also returns the AST arguments if they are present.
// The typeArgs result reports whether the AST arguments are types.
func (t *translator) instantiationTypes(call *ast.CallExpr) (argList []ast.Expr, typeList []types.Type, typeArgs bool) {
if len(call.Args) > 0 {
tv, ok := t.importer.info.Types[call.Args[0]]
if !ok {
panic(fmt.Sprintf("no type found for argument %v", call.Args[0]))
}
typeArgs = tv.IsType()
}
inferred, haveInferred := t.importer.info.Inferred[call]
if typeArgs {
if !haveInferred {
argList = call.Args
typeList = make([]types.Type, 0, len(argList))
for _, arg := range argList {
@ -598,12 +596,13 @@ func (t *translator) instantiationTypes(call *ast.CallExpr) (argList []ast.Expr,
typeList = append(typeList, at)
}
}
typeArgs = true
} else {
params := t.lookupType(call.Fun).(*types.Signature).Params()
ln := params.Len()
typeList = make([]types.Type, 0, ln)
for i := 0; i < ln; i++ {
typeList = append(typeList, params.At(i).Type())
for _, typ := range inferred.Targs {
typeList = append(typeList, typ)
arg := ast.NewIdent(typ.String())
argList = append(argList, arg)
t.setType(arg, typ)
}
}
@ -616,7 +615,7 @@ func (t *translator) sameTypes(a, b []types.Type) bool {
return false
}
for i, x := range a {
if x != b[i] {
if !types.Identical(x, b[i]) {
return false
}
}

View File

@ -44,15 +44,15 @@ var float64s = []float64{74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3, math.NaN()
var strings = []string{"", "Hello", "foo", "bar", "foo", "f00", "%*&^*&^&", "***"}
func TestSortOrderedInts() bool {
return testOrdered(int)("ints", ints, sort.Ints)
return testOrdered("ints", ints, sort.Ints)
}
func TestSortOrderedFloat64s() bool {
return testOrdered(float64)("float64s", float64s, sort.Float64s)
return testOrdered("float64s", float64s, sort.Float64s)
}
func TestSortOrderedStrings() bool {
return testOrdered(string)("strings", strings, sort.Strings)
return testOrdered("strings", strings, sort.Strings)
}
func testOrdered(type Elem Ordered)(name string, s []Elem, sorter func([]Elem)) bool {
@ -60,10 +60,10 @@ func testOrdered(type Elem Ordered)(name string, s []Elem, sorter func([]Elem))
copy(s1, s)
s2 := make([]Elem, len(s))
copy(s2, s)
OrderedSlice(Elem)(s1)
OrderedSlice(s1)
sorter(s2)
ok := true
if !sliceEq(Elem)(s1, s2) {
if !sliceEq(s1, s2) {
fmt.Printf("%s: got %v, want %v", name, s1, s2)
ok = false
}