diff --git a/src/cmd/go2go/testdata/go2path/src/gsort/gsort_test.go2 b/src/cmd/go2go/testdata/go2path/src/gsort/gsort_test.go2 index 2f108a1165..47715ed23f 100644 --- a/src/cmd/go2go/testdata/go2path/src/gsort/gsort_test.go2 +++ b/src/cmd/go2go/testdata/go2path/src/gsort/gsort_test.go2 @@ -20,15 +20,15 @@ var float64s = []float64{74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3, math.NaN() var strings = []string{"", "Hello", "foo", "bar", "foo", "f00", "%*&^*&^&", "***"} func TestSortOrderedInts(t *testing.T) { - testOrdered(int)(t, ints, sort.Ints) + testOrdered(t, ints, sort.Ints) } func TestSortOrderedFloat64s(t *testing.T) { - testOrdered(float64)(t, float64s, sort.Float64s) + testOrdered(t, float64s, sort.Float64s) } func TestSortOrderedStrings(t *testing.T) { - testOrdered(string)(t, strings, sort.Strings) + testOrdered(t, strings, sort.Strings) } func testOrdered(type Elem contracts.Ordered)(t *testing.T, s []Elem, sorter func([]Elem)) { @@ -36,9 +36,9 @@ func testOrdered(type Elem contracts.Ordered)(t *testing.T, s []Elem, sorter fun copy(s1, s) s2 := make([]Elem, len(s)) copy(s2, s) - OrderedSlice(Elem)(s1) + OrderedSlice(s1) sorter(s2) - if !slices.Equal(Elem)(s1, s2) { + if !slices.Equal(s1, s2) { t.Fatalf("got %v, want %v", s1, s2) } for i := len(s1) - 1; i > 0; i-- { @@ -85,8 +85,8 @@ func sorter(s1, s2 []int) bool { func TestSortSliceFn(t *testing.T) { c := make([][]int, len(slicesToSort)) copy(c, slicesToSort) - SliceFn([]int)(c, sorter) - if !slices.EqualFn([]int)(c, sortedSlices, func(a, b []int) bool { return slices.Equal(int)(a, b) }) { + SliceFn(c, sorter) + if !slices.EqualFn(c, sortedSlices, func(a, b []int) bool { return slices.Equal(int)(a, b) }) { t.Errorf("got %v, want %v", c, sortedSlices) } } diff --git a/src/go/go2go/importer.go b/src/go/go2go/importer.go index 4baa50e0c5..5ee69d572c 100644 --- a/src/go/go2go/importer.go +++ b/src/go/go2go/importer.go @@ -49,9 +49,10 @@ var _ types.ImporterFrom = &Importer{} // The tmpdir will become a GOPATH with translated files. func NewImporter(tmpdir string) *Importer { info := &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Defs: make(map[*ast.Ident]types.Object), - Uses: make(map[*ast.Ident]types.Object), + Types: make(map[ast.Expr]types.TypeAndValue), + Inferred: make(map[*ast.CallExpr]types.Inferred), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), } return &Importer{ tmpdir: tmpdir, diff --git a/src/go/go2go/instantiate.go b/src/go/go2go/instantiate.go index 3f5ff7a69b..98ebaf84c6 100644 --- a/src/go/go2go/instantiate.go +++ b/src/go/go2go/instantiate.go @@ -44,11 +44,7 @@ func typeArgsFromFields(t *translator, astTypes []ast.Expr, typeTypes []types.Ty if !ok { panic(fmt.Sprintf("%v is not a TypeParam", objParam)) } - var astType ast.Expr - if len(astTypes) > 0 { - astType = astTypes[i] - } - ta.add(obj, objParam, astType, typeTypes[i]) + ta.add(obj, objParam, astTypes[i], typeTypes[i]) } } return ta @@ -75,9 +71,7 @@ func typeArgsFromExprs(t *translator, astTypes []ast.Expr, typeTypes []types.Typ // add adds mappings for obj to ast and typ. func (ta *typeArgs) add(obj types.Object, objParam *types.TypeParam, ast ast.Expr, typ types.Type) { - if ast != nil { - ta.toAST[obj] = ast - } + ta.toAST[obj] = ast ta.toTyp[objParam] = typ } @@ -575,16 +569,36 @@ func (t *translator) instantiateExpr(ta *typeArgs, e ast.Expr) ast.Expr { case *ast.CallExpr: fun := t.instantiateExpr(ta, e.Fun) args, argsChanged := t.instantiateExprList(ta, e.Args) - if fun == e.Fun && !argsChanged { + origInferred, haveInferred := t.importer.info.Inferred[e] + var newInferred types.Inferred + inferredChanged := false + if haveInferred { + for _, typ := range origInferred.Targs { + nt := t.instantiateType(ta, typ) + newInferred.Targs = append(newInferred.Targs, nt) + if nt != typ { + inferredChanged = true + } + } + newInferred.Sig = t.instantiateType(ta, origInferred.Sig).(*types.Signature) + if newInferred.Sig != origInferred.Sig { + inferredChanged = true + } + } + if fun == e.Fun && !argsChanged && !inferredChanged { return e } - r = &ast.CallExpr{ + newCall := &ast.CallExpr{ Fun: fun, Lparen: e.Lparen, Args: args, Ellipsis: e.Ellipsis, Rparen: e.Rparen, } + if haveInferred { + t.importer.info.Inferred[newCall] = newInferred + } + r = newCall case *ast.FuncType: params := t.instantiateFieldList(ta, e.Params) results := t.instantiateFieldList(ta, e.Results) diff --git a/src/go/go2go/rewrite.go b/src/go/go2go/rewrite.go index 7660c847fd..127c1c6fe1 100644 --- a/src/go/go2go/rewrite.go +++ b/src/go/go2go/rewrite.go @@ -488,32 +488,36 @@ func (t *translator) translateFunctionInstantiation(pe *ast.Expr) { qid := t.instantiatedIdent(call) argList, typeList, typeArgs := t.instantiationTypes(call) + var instIdent *ast.Ident key := qid.String() instantiations := t.instantiations[key] for _, inst := range instantiations { if t.sameTypes(typeList, inst.types) { - *pe = inst.decl - return + instIdent = inst.decl + break } } - instIdent, err := t.instantiateFunction(qid, argList, typeList) - if err != nil { - t.err = err - return - } + if instIdent == nil { + var err error + instIdent, err = t.instantiateFunction(qid, argList, typeList) + if err != nil { + t.err = err + return + } - n := &instantiation{ - types: typeList, - decl: instIdent, + n := &instantiation{ + types: typeList, + decl: instIdent, + } + t.instantiations[key] = append(instantiations, n) } - t.instantiations[key] = append(instantiations, n) if typeArgs { *pe = instIdent } else { newCall := *call - call.Fun = instIdent + newCall.Fun = instIdent *pe = &newCall } } @@ -580,15 +584,9 @@ func (t *translator) instantiatedIdent(call *ast.CallExpr) qualifiedIdent { // It also returns the AST arguments if they are present. // The typeArgs result reports whether the AST arguments are types. func (t *translator) instantiationTypes(call *ast.CallExpr) (argList []ast.Expr, typeList []types.Type, typeArgs bool) { - if len(call.Args) > 0 { - tv, ok := t.importer.info.Types[call.Args[0]] - if !ok { - panic(fmt.Sprintf("no type found for argument %v", call.Args[0])) - } - typeArgs = tv.IsType() - } + inferred, haveInferred := t.importer.info.Inferred[call] - if typeArgs { + if !haveInferred { argList = call.Args typeList = make([]types.Type, 0, len(argList)) for _, arg := range argList { @@ -598,12 +596,13 @@ func (t *translator) instantiationTypes(call *ast.CallExpr) (argList []ast.Expr, typeList = append(typeList, at) } } + typeArgs = true } else { - params := t.lookupType(call.Fun).(*types.Signature).Params() - ln := params.Len() - typeList = make([]types.Type, 0, ln) - for i := 0; i < ln; i++ { - typeList = append(typeList, params.At(i).Type()) + for _, typ := range inferred.Targs { + typeList = append(typeList, typ) + arg := ast.NewIdent(typ.String()) + argList = append(argList, arg) + t.setType(arg, typ) } } @@ -616,7 +615,7 @@ func (t *translator) sameTypes(a, b []types.Type) bool { return false } for i, x := range a { - if x != b[i] { + if !types.Identical(x, b[i]) { return false } } diff --git a/test/gen/g006.go2 b/test/gen/g006.go2 index fadc672808..be19f81fe1 100644 --- a/test/gen/g006.go2 +++ b/test/gen/g006.go2 @@ -44,15 +44,15 @@ var float64s = []float64{74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3, math.NaN() var strings = []string{"", "Hello", "foo", "bar", "foo", "f00", "%*&^*&^&", "***"} func TestSortOrderedInts() bool { - return testOrdered(int)("ints", ints, sort.Ints) + return testOrdered("ints", ints, sort.Ints) } func TestSortOrderedFloat64s() bool { - return testOrdered(float64)("float64s", float64s, sort.Float64s) + return testOrdered("float64s", float64s, sort.Float64s) } func TestSortOrderedStrings() bool { - return testOrdered(string)("strings", strings, sort.Strings) + return testOrdered("strings", strings, sort.Strings) } func testOrdered(type Elem Ordered)(name string, s []Elem, sorter func([]Elem)) bool { @@ -60,10 +60,10 @@ func testOrdered(type Elem Ordered)(name string, s []Elem, sorter func([]Elem)) copy(s1, s) s2 := make([]Elem, len(s)) copy(s2, s) - OrderedSlice(Elem)(s1) + OrderedSlice(s1) sorter(s2) ok := true - if !sliceEq(Elem)(s1, s2) { + if !sliceEq(s1, s2) { fmt.Printf("%s: got %v, want %v", name, s1, s2) ok = false }