diff --git a/src/go/go2go/instantiate.go b/src/go/go2go/instantiate.go index f7731df6fa..5eba7355b9 100644 --- a/src/go/go2go/instantiate.go +++ b/src/go/go2go/instantiate.go @@ -289,8 +289,26 @@ func (t *translator) instantiateStmt(ta *typeArgs, s ast.Stmt) ast.Stmt { switch s := s.(type) { case nil: return nil - case *ast.BlockStmt: - return t.instantiateBlockStmt(ta, s) + case *ast.DeclStmt: + decl := t.instantiateDecl(ta, s.Decl) + if decl == s.Decl { + return s + } + return &ast.DeclStmt{ + Decl: decl, + } + case *ast.EmptyStmt: + return s + case *ast.LabeledStmt: + stmt := t.instantiateStmt(ta, s.Stmt) + if stmt == s.Stmt { + return s + } + return &ast.LabeledStmt{ + Label: s.Label, + Colon: s.Colon, + Stmt: stmt, + } case *ast.ExprStmt: x := t.instantiateExpr(ta, s.X) if x == s.X { @@ -299,13 +317,16 @@ func (t *translator) instantiateStmt(ta *typeArgs, s ast.Stmt) ast.Stmt { return &ast.ExprStmt{ X: x, } - case *ast.DeclStmt: - decl := t.instantiateDecl(ta, s.Decl) - if decl == s.Decl { + case *ast.SendStmt: + ch := t.instantiateExpr(ta, s.Chan) + value := t.instantiateExpr(ta, s.Value) + if ch == s.Chan && value == s.Value { return s } - return &ast.DeclStmt{ - Decl: decl, + return &ast.SendStmt{ + Chan: ch, + Arrow: s.Arrow, + Value: value, } case *ast.IncDecStmt: x := t.instantiateExpr(ta, s.X) @@ -329,6 +350,37 @@ func (t *translator) instantiateStmt(ta *typeArgs, s ast.Stmt) ast.Stmt { Tok: s.Tok, Rhs: rhs, } + case *ast.GoStmt: + call := t.instantiateExpr(ta, s.Call).(*ast.CallExpr) + if call == s.Call { + return s + } + return &ast.GoStmt{ + Go: s.Go, + Call: call, + } + case *ast.DeferStmt: + call := t.instantiateExpr(ta, s.Call).(*ast.CallExpr) + if call == s.Call { + return s + } + return &ast.DeferStmt{ + Defer: s.Defer, + Call: call, + } + case *ast.ReturnStmt: + results, changed := t.instantiateExprList(ta, s.Results) + if !changed { + return s + } + return &ast.ReturnStmt{ + Return: s.Return, + Results: results, + } + case *ast.BranchStmt: + return s + case *ast.BlockStmt: + return t.instantiateBlockStmt(ta, s) case *ast.IfStmt: init := t.instantiateStmt(ta, s.Init) cond := t.instantiateExpr(ta, s.Cond) @@ -344,6 +396,65 @@ func (t *translator) instantiateStmt(ta *typeArgs, s ast.Stmt) ast.Stmt { Body: body, Else: els, } + case *ast.CaseClause: + list, listChanged := t.instantiateExprList(ta, s.List) + body, bodyChanged := t.instantiateStmtList(ta, s.Body) + if !listChanged && !bodyChanged { + return s + } + return &ast.CaseClause{ + Case: s.Case, + List: list, + Colon: s.Colon, + Body: body, + } + case *ast.SwitchStmt: + init := t.instantiateStmt(ta, s.Init) + tag := t.instantiateExpr(ta, s.Tag) + body := t.instantiateBlockStmt(ta, s.Body) + if init == s.Init && tag == s.Tag && body == s.Body { + return s + } + return &ast.SwitchStmt{ + Switch: s.Switch, + Init: init, + Tag: tag, + Body: body, + } + case *ast.TypeSwitchStmt: + init := t.instantiateStmt(ta, s.Init) + assign := t.instantiateStmt(ta, s.Assign) + body := t.instantiateBlockStmt(ta, s.Body) + if init == s.Init && assign == s.Assign && body == s.Body { + return s + } + return &ast.TypeSwitchStmt{ + Switch: s.Switch, + Init: init, + Assign: assign, + Body: body, + } + case *ast.CommClause: + comm := t.instantiateStmt(ta, s.Comm) + body, bodyChanged := t.instantiateStmtList(ta, s.Body) + if comm == s.Comm && !bodyChanged { + return s + } + return &ast.CommClause{ + Case: s.Case, + Comm: comm, + Colon: s.Colon, + Body: body, + } + case *ast.SelectStmt: + body := t.instantiateBlockStmt(ta, s.Body) + if body == s.Body { + return s + } + return &ast.SelectStmt{ + Select: s.Select, + Body: body, + } case *ast.ForStmt: init := t.instantiateStmt(ta, s.Init) cond := t.instantiateExpr(ta, s.Cond) @@ -376,15 +487,6 @@ func (t *translator) instantiateStmt(ta *typeArgs, s ast.Stmt) ast.Stmt { X: x, Body: body, } - case *ast.ReturnStmt: - results, changed := t.instantiateExprList(ta, s.Results) - if !changed { - return s - } - return &ast.ReturnStmt{ - Return: s.Return, - Results: results, - } default: panic(fmt.Sprintf("unimplemented Stmt %T", s)) } @@ -411,6 +513,23 @@ func (t *translator) instantiateBlockStmt(ta *typeArgs, pbs *ast.BlockStmt) *ast } } +// instantiateStmtList instantiates a statement list. +func (t *translator) instantiateStmtList(ta *typeArgs, sl []ast.Stmt) ([]ast.Stmt, bool) { + nsl := make([]ast.Stmt, len(sl)) + changed := false + for i, s := range sl { + ns := t.instantiateStmt(ta, s) + if ns != s { + changed = true + } + nsl[i] = ns + } + if !changed { + return sl, false + } + return nsl, true +} + // instantiateFieldList instantiates a field list. func (t *translator) instantiateFieldList(ta *typeArgs, fl *ast.FieldList) *ast.FieldList { if fl == nil { @@ -517,6 +636,80 @@ func (t *translator) instantiateExpr(ta *typeArgs, e ast.Expr) ast.Expr { X: x, Sel: e.Sel, } + case *ast.IndexExpr: + x := t.instantiateExpr(ta, e.X) + index := t.instantiateExpr(ta, e.Index) + if x == e.X && index == e.Index { + return e + } + r = &ast.IndexExpr{ + X: x, + Lbrack: e.Lbrack, + Index: index, + Rbrack: e.Rbrack, + } + case *ast.SliceExpr: + x := t.instantiateExpr(ta, e.X) + low := t.instantiateExpr(ta, e.Low) + high := t.instantiateExpr(ta, e.High) + max := t.instantiateExpr(ta, e.Max) + if x == e.X && low == e.Low && high == e.High && max == e.Max { + return e + } + r = &ast.SliceExpr{ + X: x, + Lbrack: e.Lbrack, + Low: low, + High: high, + Max: max, + Slice3: e.Slice3, + Rbrack: e.Rbrack, + } + case *ast.TypeAssertExpr: + x := t.instantiateExpr(ta, e.X) + typ := t.instantiateExpr(ta, e.Type) + if x == e.X && typ == e.Type { + return e + } + r = &ast.TypeAssertExpr{ + X: x, + Lparen: e.Lparen, + Type: typ, + Rparen: e.Rparen, + } + case *ast.CallExpr: + fun := t.instantiateExpr(ta, e.Fun) + args, argsChanged := t.instantiateExprList(ta, e.Args) + origInferred, haveInferred := t.importer.info.Inferred[e] + var newInferred types.Inferred + inferredChanged := false + if haveInferred { + for _, typ := range origInferred.Targs { + nt := t.instantiateType(ta, typ) + newInferred.Targs = append(newInferred.Targs, nt) + if nt != typ { + inferredChanged = true + } + } + newInferred.Sig = t.instantiateType(ta, origInferred.Sig).(*types.Signature) + if newInferred.Sig != origInferred.Sig { + inferredChanged = true + } + } + if fun == e.Fun && !argsChanged && !inferredChanged { + return e + } + newCall := &ast.CallExpr{ + Fun: fun, + Lparen: e.Lparen, + Args: args, + Ellipsis: e.Ellipsis, + Rparen: e.Rparen, + } + if haveInferred { + t.importer.info.Inferred[newCall] = newInferred + } + r = newCall case *ast.StarExpr: x := t.instantiateExpr(ta, e.X) if x == e.X { @@ -559,68 +752,6 @@ func (t *translator) instantiateExpr(ta *typeArgs, e ast.Expr) ast.Expr { Colon: e.Colon, Value: value, } - case *ast.IndexExpr: - x := t.instantiateExpr(ta, e.X) - index := t.instantiateExpr(ta, e.Index) - if x == e.X && index == e.Index { - return e - } - r = &ast.IndexExpr{ - X: x, - Lbrack: e.Lbrack, - Index: index, - Rbrack: e.Rbrack, - } - case *ast.SliceExpr: - x := t.instantiateExpr(ta, e.X) - low := t.instantiateExpr(ta, e.Low) - high := t.instantiateExpr(ta, e.High) - max := t.instantiateExpr(ta, e.Max) - if x == e.X && low == e.Low && high == e.High && max == e.Max { - return e - } - r = &ast.SliceExpr{ - X: x, - Lbrack: e.Lbrack, - Low: low, - High: high, - Max: max, - Slice3: e.Slice3, - Rbrack: e.Rbrack, - } - case *ast.CallExpr: - fun := t.instantiateExpr(ta, e.Fun) - args, argsChanged := t.instantiateExprList(ta, e.Args) - origInferred, haveInferred := t.importer.info.Inferred[e] - var newInferred types.Inferred - inferredChanged := false - if haveInferred { - for _, typ := range origInferred.Targs { - nt := t.instantiateType(ta, typ) - newInferred.Targs = append(newInferred.Targs, nt) - if nt != typ { - inferredChanged = true - } - } - newInferred.Sig = t.instantiateType(ta, origInferred.Sig).(*types.Signature) - if newInferred.Sig != origInferred.Sig { - inferredChanged = true - } - } - if fun == e.Fun && !argsChanged && !inferredChanged { - return e - } - newCall := &ast.CallExpr{ - Fun: fun, - Lparen: e.Lparen, - Args: args, - Ellipsis: e.Ellipsis, - Rparen: e.Rparen, - } - if haveInferred { - t.importer.info.Inferred[newCall] = newInferred - } - r = newCall case *ast.ArrayType: ln := t.instantiateExpr(ta, e.Len) elt := t.instantiateExpr(ta, e.Elt) diff --git a/src/go/go2go/rewrite.go b/src/go/go2go/rewrite.go index 94db2469bb..2d519a1488 100644 --- a/src/go/go2go/rewrite.go +++ b/src/go/go2go/rewrite.go @@ -330,15 +330,46 @@ func (t *translator) translateStmt(ps *ast.Stmt) { return } switch s := (*ps).(type) { - case *ast.BlockStmt: - t.translateBlockStmt(s) + case *ast.DeclStmt: + d := s.Decl.(*ast.GenDecl) + switch d.Tok { + case token.TYPE: + for i := range d.Specs { + t.translateTypeSpec(&d.Specs[i]) + } + case token.CONST, token.VAR: + for i := range d.Specs { + t.translateValueSpec(&d.Specs[i]) + } + default: + panic(fmt.Sprintf("unknown decl type %v", d.Tok)) + } + case *ast.EmptyStmt: + case *ast.LabeledStmt: + t.translateStmt(&s.Stmt) case *ast.ExprStmt: t.translateExpr(&s.X) + case *ast.SendStmt: + t.translateExpr(&s.Chan) + t.translateExpr(&s.Value) case *ast.IncDecStmt: t.translateExpr(&s.X) case *ast.AssignStmt: t.translateExprList(s.Lhs) t.translateExprList(s.Rhs) + case *ast.GoStmt: + e := ast.Expr(s.Call) + t.translateExpr(&e) + s.Call = e.(*ast.CallExpr) + case *ast.DeferStmt: + e := ast.Expr(s.Call) + t.translateExpr(&e) + s.Call = e.(*ast.CallExpr) + case *ast.ReturnStmt: + t.translateExprList(s.Results) + case *ast.BranchStmt: + case *ast.BlockStmt: + t.translateBlockStmt(s) case *ast.IfStmt: t.translateStmt(&s.Init) t.translateExpr(&s.Cond) @@ -358,6 +389,8 @@ func (t *translator) translateStmt(ps *ast.Stmt) { case *ast.CommClause: t.translateStmt(&s.Comm) t.translateStmtList(s.Body) + case *ast.SelectStmt: + t.translateBlockStmt(s.Body) case *ast.ForStmt: t.translateStmt(&s.Init) t.translateExpr(&s.Cond) @@ -368,22 +401,6 @@ func (t *translator) translateStmt(ps *ast.Stmt) { t.translateExpr(&s.Value) t.translateExpr(&s.X) t.translateBlockStmt(s.Body) - case *ast.DeclStmt: - d := s.Decl.(*ast.GenDecl) - switch d.Tok { - case token.TYPE: - for i := range d.Specs { - t.translateTypeSpec(&d.Specs[i]) - } - case token.CONST, token.VAR: - for i := range d.Specs { - t.translateValueSpec(&d.Specs[i]) - } - default: - panic(fmt.Sprintf("unknown decl type %v", d.Tok)) - } - case *ast.ReturnStmt: - t.translateExprList(s.Results) default: panic(fmt.Sprintf("unimplemented Stmt %T", s)) } @@ -420,13 +437,7 @@ func (t *translator) translateExpr(pe *ast.Expr) { t.translateExprList(e.Elts) case *ast.ParenExpr: t.translateExpr(&e.X) - case *ast.BinaryExpr: - t.translateExpr(&e.X) - t.translateExpr(&e.Y) - case *ast.KeyValueExpr: - t.translateExpr(&e.Key) - t.translateExpr(&e.Value) - case *ast.UnaryExpr: + case *ast.SelectorExpr: t.translateExpr(&e.X) case *ast.IndexExpr: t.translateExpr(&e.X) @@ -436,6 +447,9 @@ func (t *translator) translateExpr(pe *ast.Expr) { t.translateExpr(&e.Low) t.translateExpr(&e.High) t.translateExpr(&e.Max) + case *ast.TypeAssertExpr: + t.translateExpr(&e.X) + t.translateExpr(&e.Type) case *ast.CallExpr: t.translateExprList(e.Args) if ftyp, ok := t.lookupType(e.Fun).(*types.Signature); ok && len(ftyp.TParams()) > 0 { @@ -446,8 +460,14 @@ func (t *translator) translateExpr(pe *ast.Expr) { t.translateExpr(&e.Fun) case *ast.StarExpr: t.translateExpr(&e.X) - case *ast.SelectorExpr: + case *ast.UnaryExpr: t.translateExpr(&e.X) + case *ast.BinaryExpr: + t.translateExpr(&e.X) + t.translateExpr(&e.Y) + case *ast.KeyValueExpr: + t.translateExpr(&e.Key) + t.translateExpr(&e.Value) case *ast.ArrayType: t.translateExpr(&e.Elt) case *ast.StructType: