mirror of https://github.com/golang/go.git
go/go2go: support type inference
We can now translate code that use type inference when calling generic functions. A couple of test cases were adjusted to use it. Change-Id: I53c2f3dd8f9fcdb44b4a8f592acee1a19ff46f22
This commit is contained in:
parent
4172f069ae
commit
5a5b14a4b3
|
|
@ -20,15 +20,15 @@ var float64s = []float64{74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3, math.NaN()
|
|||
var strings = []string{"", "Hello", "foo", "bar", "foo", "f00", "%*&^*&^&", "***"}
|
||||
|
||||
func TestSortOrderedInts(t *testing.T) {
|
||||
testOrdered(int)(t, ints, sort.Ints)
|
||||
testOrdered(t, ints, sort.Ints)
|
||||
}
|
||||
|
||||
func TestSortOrderedFloat64s(t *testing.T) {
|
||||
testOrdered(float64)(t, float64s, sort.Float64s)
|
||||
testOrdered(t, float64s, sort.Float64s)
|
||||
}
|
||||
|
||||
func TestSortOrderedStrings(t *testing.T) {
|
||||
testOrdered(string)(t, strings, sort.Strings)
|
||||
testOrdered(t, strings, sort.Strings)
|
||||
}
|
||||
|
||||
func testOrdered(type Elem contracts.Ordered)(t *testing.T, s []Elem, sorter func([]Elem)) {
|
||||
|
|
@ -36,9 +36,9 @@ func testOrdered(type Elem contracts.Ordered)(t *testing.T, s []Elem, sorter fun
|
|||
copy(s1, s)
|
||||
s2 := make([]Elem, len(s))
|
||||
copy(s2, s)
|
||||
OrderedSlice(Elem)(s1)
|
||||
OrderedSlice(s1)
|
||||
sorter(s2)
|
||||
if !slices.Equal(Elem)(s1, s2) {
|
||||
if !slices.Equal(s1, s2) {
|
||||
t.Fatalf("got %v, want %v", s1, s2)
|
||||
}
|
||||
for i := len(s1) - 1; i > 0; i-- {
|
||||
|
|
@ -85,8 +85,8 @@ func sorter(s1, s2 []int) bool {
|
|||
func TestSortSliceFn(t *testing.T) {
|
||||
c := make([][]int, len(slicesToSort))
|
||||
copy(c, slicesToSort)
|
||||
SliceFn([]int)(c, sorter)
|
||||
if !slices.EqualFn([]int)(c, sortedSlices, func(a, b []int) bool { return slices.Equal(int)(a, b) }) {
|
||||
SliceFn(c, sorter)
|
||||
if !slices.EqualFn(c, sortedSlices, func(a, b []int) bool { return slices.Equal(int)(a, b) }) {
|
||||
t.Errorf("got %v, want %v", c, sortedSlices)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,9 +49,10 @@ var _ types.ImporterFrom = &Importer{}
|
|||
// The tmpdir will become a GOPATH with translated files.
|
||||
func NewImporter(tmpdir string) *Importer {
|
||||
info := &types.Info{
|
||||
Types: make(map[ast.Expr]types.TypeAndValue),
|
||||
Defs: make(map[*ast.Ident]types.Object),
|
||||
Uses: make(map[*ast.Ident]types.Object),
|
||||
Types: make(map[ast.Expr]types.TypeAndValue),
|
||||
Inferred: make(map[*ast.CallExpr]types.Inferred),
|
||||
Defs: make(map[*ast.Ident]types.Object),
|
||||
Uses: make(map[*ast.Ident]types.Object),
|
||||
}
|
||||
return &Importer{
|
||||
tmpdir: tmpdir,
|
||||
|
|
|
|||
|
|
@ -44,11 +44,7 @@ func typeArgsFromFields(t *translator, astTypes []ast.Expr, typeTypes []types.Ty
|
|||
if !ok {
|
||||
panic(fmt.Sprintf("%v is not a TypeParam", objParam))
|
||||
}
|
||||
var astType ast.Expr
|
||||
if len(astTypes) > 0 {
|
||||
astType = astTypes[i]
|
||||
}
|
||||
ta.add(obj, objParam, astType, typeTypes[i])
|
||||
ta.add(obj, objParam, astTypes[i], typeTypes[i])
|
||||
}
|
||||
}
|
||||
return ta
|
||||
|
|
@ -75,9 +71,7 @@ func typeArgsFromExprs(t *translator, astTypes []ast.Expr, typeTypes []types.Typ
|
|||
|
||||
// add adds mappings for obj to ast and typ.
|
||||
func (ta *typeArgs) add(obj types.Object, objParam *types.TypeParam, ast ast.Expr, typ types.Type) {
|
||||
if ast != nil {
|
||||
ta.toAST[obj] = ast
|
||||
}
|
||||
ta.toAST[obj] = ast
|
||||
ta.toTyp[objParam] = typ
|
||||
}
|
||||
|
||||
|
|
@ -575,16 +569,36 @@ func (t *translator) instantiateExpr(ta *typeArgs, e ast.Expr) ast.Expr {
|
|||
case *ast.CallExpr:
|
||||
fun := t.instantiateExpr(ta, e.Fun)
|
||||
args, argsChanged := t.instantiateExprList(ta, e.Args)
|
||||
if fun == e.Fun && !argsChanged {
|
||||
origInferred, haveInferred := t.importer.info.Inferred[e]
|
||||
var newInferred types.Inferred
|
||||
inferredChanged := false
|
||||
if haveInferred {
|
||||
for _, typ := range origInferred.Targs {
|
||||
nt := t.instantiateType(ta, typ)
|
||||
newInferred.Targs = append(newInferred.Targs, nt)
|
||||
if nt != typ {
|
||||
inferredChanged = true
|
||||
}
|
||||
}
|
||||
newInferred.Sig = t.instantiateType(ta, origInferred.Sig).(*types.Signature)
|
||||
if newInferred.Sig != origInferred.Sig {
|
||||
inferredChanged = true
|
||||
}
|
||||
}
|
||||
if fun == e.Fun && !argsChanged && !inferredChanged {
|
||||
return e
|
||||
}
|
||||
r = &ast.CallExpr{
|
||||
newCall := &ast.CallExpr{
|
||||
Fun: fun,
|
||||
Lparen: e.Lparen,
|
||||
Args: args,
|
||||
Ellipsis: e.Ellipsis,
|
||||
Rparen: e.Rparen,
|
||||
}
|
||||
if haveInferred {
|
||||
t.importer.info.Inferred[newCall] = newInferred
|
||||
}
|
||||
r = newCall
|
||||
case *ast.FuncType:
|
||||
params := t.instantiateFieldList(ta, e.Params)
|
||||
results := t.instantiateFieldList(ta, e.Results)
|
||||
|
|
|
|||
|
|
@ -488,32 +488,36 @@ func (t *translator) translateFunctionInstantiation(pe *ast.Expr) {
|
|||
qid := t.instantiatedIdent(call)
|
||||
argList, typeList, typeArgs := t.instantiationTypes(call)
|
||||
|
||||
var instIdent *ast.Ident
|
||||
key := qid.String()
|
||||
instantiations := t.instantiations[key]
|
||||
for _, inst := range instantiations {
|
||||
if t.sameTypes(typeList, inst.types) {
|
||||
*pe = inst.decl
|
||||
return
|
||||
instIdent = inst.decl
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
instIdent, err := t.instantiateFunction(qid, argList, typeList)
|
||||
if err != nil {
|
||||
t.err = err
|
||||
return
|
||||
}
|
||||
if instIdent == nil {
|
||||
var err error
|
||||
instIdent, err = t.instantiateFunction(qid, argList, typeList)
|
||||
if err != nil {
|
||||
t.err = err
|
||||
return
|
||||
}
|
||||
|
||||
n := &instantiation{
|
||||
types: typeList,
|
||||
decl: instIdent,
|
||||
n := &instantiation{
|
||||
types: typeList,
|
||||
decl: instIdent,
|
||||
}
|
||||
t.instantiations[key] = append(instantiations, n)
|
||||
}
|
||||
t.instantiations[key] = append(instantiations, n)
|
||||
|
||||
if typeArgs {
|
||||
*pe = instIdent
|
||||
} else {
|
||||
newCall := *call
|
||||
call.Fun = instIdent
|
||||
newCall.Fun = instIdent
|
||||
*pe = &newCall
|
||||
}
|
||||
}
|
||||
|
|
@ -580,15 +584,9 @@ func (t *translator) instantiatedIdent(call *ast.CallExpr) qualifiedIdent {
|
|||
// It also returns the AST arguments if they are present.
|
||||
// The typeArgs result reports whether the AST arguments are types.
|
||||
func (t *translator) instantiationTypes(call *ast.CallExpr) (argList []ast.Expr, typeList []types.Type, typeArgs bool) {
|
||||
if len(call.Args) > 0 {
|
||||
tv, ok := t.importer.info.Types[call.Args[0]]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("no type found for argument %v", call.Args[0]))
|
||||
}
|
||||
typeArgs = tv.IsType()
|
||||
}
|
||||
inferred, haveInferred := t.importer.info.Inferred[call]
|
||||
|
||||
if typeArgs {
|
||||
if !haveInferred {
|
||||
argList = call.Args
|
||||
typeList = make([]types.Type, 0, len(argList))
|
||||
for _, arg := range argList {
|
||||
|
|
@ -598,12 +596,13 @@ func (t *translator) instantiationTypes(call *ast.CallExpr) (argList []ast.Expr,
|
|||
typeList = append(typeList, at)
|
||||
}
|
||||
}
|
||||
typeArgs = true
|
||||
} else {
|
||||
params := t.lookupType(call.Fun).(*types.Signature).Params()
|
||||
ln := params.Len()
|
||||
typeList = make([]types.Type, 0, ln)
|
||||
for i := 0; i < ln; i++ {
|
||||
typeList = append(typeList, params.At(i).Type())
|
||||
for _, typ := range inferred.Targs {
|
||||
typeList = append(typeList, typ)
|
||||
arg := ast.NewIdent(typ.String())
|
||||
argList = append(argList, arg)
|
||||
t.setType(arg, typ)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -616,7 +615,7 @@ func (t *translator) sameTypes(a, b []types.Type) bool {
|
|||
return false
|
||||
}
|
||||
for i, x := range a {
|
||||
if x != b[i] {
|
||||
if !types.Identical(x, b[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,15 +44,15 @@ var float64s = []float64{74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3, math.NaN()
|
|||
var strings = []string{"", "Hello", "foo", "bar", "foo", "f00", "%*&^*&^&", "***"}
|
||||
|
||||
func TestSortOrderedInts() bool {
|
||||
return testOrdered(int)("ints", ints, sort.Ints)
|
||||
return testOrdered("ints", ints, sort.Ints)
|
||||
}
|
||||
|
||||
func TestSortOrderedFloat64s() bool {
|
||||
return testOrdered(float64)("float64s", float64s, sort.Float64s)
|
||||
return testOrdered("float64s", float64s, sort.Float64s)
|
||||
}
|
||||
|
||||
func TestSortOrderedStrings() bool {
|
||||
return testOrdered(string)("strings", strings, sort.Strings)
|
||||
return testOrdered("strings", strings, sort.Strings)
|
||||
}
|
||||
|
||||
func testOrdered(type Elem Ordered)(name string, s []Elem, sorter func([]Elem)) bool {
|
||||
|
|
@ -60,10 +60,10 @@ func testOrdered(type Elem Ordered)(name string, s []Elem, sorter func([]Elem))
|
|||
copy(s1, s)
|
||||
s2 := make([]Elem, len(s))
|
||||
copy(s2, s)
|
||||
OrderedSlice(Elem)(s1)
|
||||
OrderedSlice(s1)
|
||||
sorter(s2)
|
||||
ok := true
|
||||
if !sliceEq(Elem)(s1, s2) {
|
||||
if !sliceEq(s1, s2) {
|
||||
fmt.Printf("%s: got %v, want %v", name, s1, s2)
|
||||
ok = false
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue