From 74ba054b38398cb87216fbd373d1971ec8504568 Mon Sep 17 00:00:00 2001 From: Ian Lance Taylor Date: Mon, 2 Mar 2020 17:35:07 -0800 Subject: [PATCH] go/go2go: start to handle inferred types Also support for statements and inc/dec statements. Change-Id: I9af474065d3b433f582422b739fc49ef266d2751 --- src/go/go2go/instantiate.go | 35 +++++++++++++++- src/go/go2go/rewrite.go | 81 ++++++++++++++++++++++++++----------- 2 files changed, 91 insertions(+), 25 deletions(-) diff --git a/src/go/go2go/instantiate.go b/src/go/go2go/instantiate.go index 54cd970067..0da462479a 100644 --- a/src/go/go2go/instantiate.go +++ b/src/go/go2go/instantiate.go @@ -46,7 +46,11 @@ func typeArgsFromFields(t *translator, info *types.Info, astTypes []ast.Expr, ty if !ok { panic(fmt.Sprintf("%v is not a TypeParam", objParam)) } - ta.add(obj, objParam, astTypes[i], typeTypes[i]) + var astType ast.Expr + if len(astTypes) > 0 { + astType = astTypes[i] + } + ta.add(obj, objParam, astType, typeTypes[i]) } } return ta @@ -73,7 +77,9 @@ func typeArgsFromExprs(t *translator, info *types.Info, astTypes []ast.Expr, typ // add adds mappings for obj to ast and typ. func (ta *typeArgs) add(obj types.Object, objParam *types.TypeParam, ast ast.Expr, typ types.Type) { - ta.toAST[obj] = ast + if ast != nil { + ta.toAST[obj] = ast + } ta.toTyp[objParam] = typ } @@ -325,6 +331,16 @@ func (t *translator) instantiateStmt(ta *typeArgs, s ast.Stmt) ast.Stmt { return &ast.DeclStmt{ Decl: decl, } + case *ast.IncDecStmt: + x := t.instantiateExpr(ta, s.X) + if x == s.X { + return s + } + return &ast.IncDecStmt{ + X: x, + TokPos: s.TokPos, + Tok: s.Tok, + } case *ast.AssignStmt: lhs, lchanged := t.instantiateExprList(ta, s.Lhs) rhs, rchanged := t.instantiateExprList(ta, s.Rhs) @@ -352,6 +368,21 @@ func (t *translator) instantiateStmt(ta *typeArgs, s ast.Stmt) ast.Stmt { Body: body, Else: els, } + case *ast.ForStmt: + init := t.instantiateStmt(ta, s.Init) + cond := t.instantiateExpr(ta, s.Cond) + post := t.instantiateStmt(ta, s.Post) + body := t.instantiateBlockStmt(ta, s.Body) + if init == s.Init && cond == s.Cond && post == s.Post && body == s.Body { + return s + } + return &ast.ForStmt{ + For: s.For, + Init: init, + Cond: cond, + Post: post, + Body: body, + } case *ast.RangeStmt: key := t.instantiateExpr(ta, s.Key) value := t.instantiateExpr(ta, s.Value) diff --git a/src/go/go2go/rewrite.go b/src/go/go2go/rewrite.go index c2d117dd72..981f92df7a 100644 --- a/src/go/go2go/rewrite.go +++ b/src/go/go2go/rewrite.go @@ -279,6 +279,8 @@ func (t *translator) translateStmt(ps *ast.Stmt) { t.translateBlockStmt(s) case *ast.ExprStmt: t.translateExpr(&s.X) + case *ast.IncDecStmt: + t.translateExpr(&s.X) case *ast.AssignStmt: t.translateExprList(s.Lhs) t.translateExprList(s.Rhs) @@ -287,6 +289,11 @@ func (t *translator) translateStmt(ps *ast.Stmt) { t.translateExpr(&s.Cond) t.translateBlockStmt(s.Body) t.translateStmt(&s.Else) + case *ast.ForStmt: + t.translateStmt(&s.Init) + t.translateExpr(&s.Cond) + t.translateStmt(&s.Post) + t.translateBlockStmt(s.Body) case *ast.RangeStmt: t.translateExpr(&s.Key) t.translateExpr(&s.Value) @@ -393,36 +400,35 @@ func (t *translator) translateField(f *ast.Field) { func (t *translator) translateFunctionInstantiation(pe *ast.Expr) { call := (*pe).(*ast.CallExpr) qid := t.instantiatedIdent(call) - types := make([]types.Type, 0, len(call.Args)) - for _, arg := range call.Args { - if at := t.lookupType(arg); at == nil { - panic(fmt.Sprintf("no type found for %T %v", arg, arg)) - } else { - types = append(types, at) - } - } + argList, typeList, typeArgs := t.instantiationTypes(call) instantiations := t.instantiations[qid] for _, inst := range instantiations { - if t.sameTypes(types, inst.types) { + if t.sameTypes(typeList, inst.types) { *pe = inst.decl return } } - instIdent, err := t.instantiateFunction(qid, call.Args, types) + instIdent, err := t.instantiateFunction(qid, argList, typeList) if err != nil { t.err = err return } n := &instantiation{ - types: types, + types: typeList, decl: instIdent, } t.instantiations[qid] = append(instantiations, n) - *pe = instIdent + if typeArgs { + *pe = instIdent + } else { + newCall := *call + call.Fun = instIdent + *pe = &newCall + } } // translateTypeInstantiation translates an instantiated type to Go 1. @@ -430,32 +436,27 @@ func (t *translator) translateTypeInstantiation(pe *ast.Expr) { call := (*pe).(*ast.CallExpr) qid := t.instantiatedIdent(call) typ := t.lookupType(call.Fun).(*types.Named) - - types := make([]types.Type, 0, len(call.Args)) - for _, arg := range call.Args { - if at := t.lookupType(arg); at == nil { - panic(fmt.Sprintf("no type found for %T %v", arg, arg)) - } else { - types = append(types, at) - } + argList, typeList, typeArgs := t.instantiationTypes(call) + if !typeArgs { + panic("no type arguments for type") } instantiations := t.typeInstantiations[typ] for _, inst := range instantiations { - if t.sameTypes(types, inst.types) { + if t.sameTypes(typeList, inst.types) { *pe = inst.decl return } } - instIdent, instType, err := t.instantiateTypeDecl(qid, typ, call.Args, types) + instIdent, instType, err := t.instantiateTypeDecl(qid, typ, argList, typeList) if err != nil { t.err = err return } n := &typeInstantiation{ - types: types, + types: typeList, decl: instIdent, typ: instType, } @@ -488,6 +489,40 @@ func (t *translator) instantiatedIdent(call *ast.CallExpr) qualifiedIdent { panic(fmt.Sprintf("instantiated object %v is not an identifier", call.Fun)) } +// instantiationTypes returns the type arguments of an instantiation. +// It also returns the AST arguments if they are present. +// The typeArgs result reports whether the AST arguments are types. +func (t *translator) instantiationTypes(call *ast.CallExpr) (argList []ast.Expr, typeList []types.Type, typeArgs bool) { + if len(call.Args) > 0 { + tv, ok := t.info.Types[call.Args[0]] + if !ok { + panic(fmt.Sprintf("no type found for argument %v", call.Args[0])) + } + typeArgs = tv.IsType() + } + + if typeArgs { + argList = call.Args + typeList = make([]types.Type, 0, len(argList)) + for _, arg := range argList { + if at := t.lookupType(arg); at == nil { + panic(fmt.Sprintf("no type found for %T %v", arg, arg)) + } else { + typeList = append(typeList, at) + } + } + } else { + params := t.lookupType(call.Fun).(*types.Signature).Params() + ln := params.Len() + typeList = make([]types.Type, 0, ln) + for i := 0; i < ln; i++ { + typeList = append(typeList, params.At(i).Type()) + } + } + + return +} + // sameTypes reports whether two type slices are the same. func (t *translator) sameTypes(a, b []types.Type) bool { if len(a) != len(b) {