go/go2go: fill in missing ast.Expr and Stmt nodes

Change-Id: I151a923c72fb61c8f31dc6d3688bcd338ab472f7
This commit is contained in:
Ian Lance Taylor 2020-03-17 15:30:47 -07:00 committed by Robert Griesemer
parent c65d33fa3d
commit c165308eb6
2 changed files with 255 additions and 104 deletions

View File

@ -289,8 +289,26 @@ func (t *translator) instantiateStmt(ta *typeArgs, s ast.Stmt) ast.Stmt {
switch s := s.(type) {
case nil:
return nil
case *ast.BlockStmt:
return t.instantiateBlockStmt(ta, s)
case *ast.DeclStmt:
decl := t.instantiateDecl(ta, s.Decl)
if decl == s.Decl {
return s
}
return &ast.DeclStmt{
Decl: decl,
}
case *ast.EmptyStmt:
return s
case *ast.LabeledStmt:
stmt := t.instantiateStmt(ta, s.Stmt)
if stmt == s.Stmt {
return s
}
return &ast.LabeledStmt{
Label: s.Label,
Colon: s.Colon,
Stmt: stmt,
}
case *ast.ExprStmt:
x := t.instantiateExpr(ta, s.X)
if x == s.X {
@ -299,13 +317,16 @@ func (t *translator) instantiateStmt(ta *typeArgs, s ast.Stmt) ast.Stmt {
return &ast.ExprStmt{
X: x,
}
case *ast.DeclStmt:
decl := t.instantiateDecl(ta, s.Decl)
if decl == s.Decl {
case *ast.SendStmt:
ch := t.instantiateExpr(ta, s.Chan)
value := t.instantiateExpr(ta, s.Value)
if ch == s.Chan && value == s.Value {
return s
}
return &ast.DeclStmt{
Decl: decl,
return &ast.SendStmt{
Chan: ch,
Arrow: s.Arrow,
Value: value,
}
case *ast.IncDecStmt:
x := t.instantiateExpr(ta, s.X)
@ -329,6 +350,37 @@ func (t *translator) instantiateStmt(ta *typeArgs, s ast.Stmt) ast.Stmt {
Tok: s.Tok,
Rhs: rhs,
}
case *ast.GoStmt:
call := t.instantiateExpr(ta, s.Call).(*ast.CallExpr)
if call == s.Call {
return s
}
return &ast.GoStmt{
Go: s.Go,
Call: call,
}
case *ast.DeferStmt:
call := t.instantiateExpr(ta, s.Call).(*ast.CallExpr)
if call == s.Call {
return s
}
return &ast.DeferStmt{
Defer: s.Defer,
Call: call,
}
case *ast.ReturnStmt:
results, changed := t.instantiateExprList(ta, s.Results)
if !changed {
return s
}
return &ast.ReturnStmt{
Return: s.Return,
Results: results,
}
case *ast.BranchStmt:
return s
case *ast.BlockStmt:
return t.instantiateBlockStmt(ta, s)
case *ast.IfStmt:
init := t.instantiateStmt(ta, s.Init)
cond := t.instantiateExpr(ta, s.Cond)
@ -344,6 +396,65 @@ func (t *translator) instantiateStmt(ta *typeArgs, s ast.Stmt) ast.Stmt {
Body: body,
Else: els,
}
case *ast.CaseClause:
list, listChanged := t.instantiateExprList(ta, s.List)
body, bodyChanged := t.instantiateStmtList(ta, s.Body)
if !listChanged && !bodyChanged {
return s
}
return &ast.CaseClause{
Case: s.Case,
List: list,
Colon: s.Colon,
Body: body,
}
case *ast.SwitchStmt:
init := t.instantiateStmt(ta, s.Init)
tag := t.instantiateExpr(ta, s.Tag)
body := t.instantiateBlockStmt(ta, s.Body)
if init == s.Init && tag == s.Tag && body == s.Body {
return s
}
return &ast.SwitchStmt{
Switch: s.Switch,
Init: init,
Tag: tag,
Body: body,
}
case *ast.TypeSwitchStmt:
init := t.instantiateStmt(ta, s.Init)
assign := t.instantiateStmt(ta, s.Assign)
body := t.instantiateBlockStmt(ta, s.Body)
if init == s.Init && assign == s.Assign && body == s.Body {
return s
}
return &ast.TypeSwitchStmt{
Switch: s.Switch,
Init: init,
Assign: assign,
Body: body,
}
case *ast.CommClause:
comm := t.instantiateStmt(ta, s.Comm)
body, bodyChanged := t.instantiateStmtList(ta, s.Body)
if comm == s.Comm && !bodyChanged {
return s
}
return &ast.CommClause{
Case: s.Case,
Comm: comm,
Colon: s.Colon,
Body: body,
}
case *ast.SelectStmt:
body := t.instantiateBlockStmt(ta, s.Body)
if body == s.Body {
return s
}
return &ast.SelectStmt{
Select: s.Select,
Body: body,
}
case *ast.ForStmt:
init := t.instantiateStmt(ta, s.Init)
cond := t.instantiateExpr(ta, s.Cond)
@ -376,15 +487,6 @@ func (t *translator) instantiateStmt(ta *typeArgs, s ast.Stmt) ast.Stmt {
X: x,
Body: body,
}
case *ast.ReturnStmt:
results, changed := t.instantiateExprList(ta, s.Results)
if !changed {
return s
}
return &ast.ReturnStmt{
Return: s.Return,
Results: results,
}
default:
panic(fmt.Sprintf("unimplemented Stmt %T", s))
}
@ -411,6 +513,23 @@ func (t *translator) instantiateBlockStmt(ta *typeArgs, pbs *ast.BlockStmt) *ast
}
}
// instantiateStmtList instantiates a statement list.
func (t *translator) instantiateStmtList(ta *typeArgs, sl []ast.Stmt) ([]ast.Stmt, bool) {
nsl := make([]ast.Stmt, len(sl))
changed := false
for i, s := range sl {
ns := t.instantiateStmt(ta, s)
if ns != s {
changed = true
}
nsl[i] = ns
}
if !changed {
return sl, false
}
return nsl, true
}
// instantiateFieldList instantiates a field list.
func (t *translator) instantiateFieldList(ta *typeArgs, fl *ast.FieldList) *ast.FieldList {
if fl == nil {
@ -517,6 +636,80 @@ func (t *translator) instantiateExpr(ta *typeArgs, e ast.Expr) ast.Expr {
X: x,
Sel: e.Sel,
}
case *ast.IndexExpr:
x := t.instantiateExpr(ta, e.X)
index := t.instantiateExpr(ta, e.Index)
if x == e.X && index == e.Index {
return e
}
r = &ast.IndexExpr{
X: x,
Lbrack: e.Lbrack,
Index: index,
Rbrack: e.Rbrack,
}
case *ast.SliceExpr:
x := t.instantiateExpr(ta, e.X)
low := t.instantiateExpr(ta, e.Low)
high := t.instantiateExpr(ta, e.High)
max := t.instantiateExpr(ta, e.Max)
if x == e.X && low == e.Low && high == e.High && max == e.Max {
return e
}
r = &ast.SliceExpr{
X: x,
Lbrack: e.Lbrack,
Low: low,
High: high,
Max: max,
Slice3: e.Slice3,
Rbrack: e.Rbrack,
}
case *ast.TypeAssertExpr:
x := t.instantiateExpr(ta, e.X)
typ := t.instantiateExpr(ta, e.Type)
if x == e.X && typ == e.Type {
return e
}
r = &ast.TypeAssertExpr{
X: x,
Lparen: e.Lparen,
Type: typ,
Rparen: e.Rparen,
}
case *ast.CallExpr:
fun := t.instantiateExpr(ta, e.Fun)
args, argsChanged := t.instantiateExprList(ta, e.Args)
origInferred, haveInferred := t.importer.info.Inferred[e]
var newInferred types.Inferred
inferredChanged := false
if haveInferred {
for _, typ := range origInferred.Targs {
nt := t.instantiateType(ta, typ)
newInferred.Targs = append(newInferred.Targs, nt)
if nt != typ {
inferredChanged = true
}
}
newInferred.Sig = t.instantiateType(ta, origInferred.Sig).(*types.Signature)
if newInferred.Sig != origInferred.Sig {
inferredChanged = true
}
}
if fun == e.Fun && !argsChanged && !inferredChanged {
return e
}
newCall := &ast.CallExpr{
Fun: fun,
Lparen: e.Lparen,
Args: args,
Ellipsis: e.Ellipsis,
Rparen: e.Rparen,
}
if haveInferred {
t.importer.info.Inferred[newCall] = newInferred
}
r = newCall
case *ast.StarExpr:
x := t.instantiateExpr(ta, e.X)
if x == e.X {
@ -559,68 +752,6 @@ func (t *translator) instantiateExpr(ta *typeArgs, e ast.Expr) ast.Expr {
Colon: e.Colon,
Value: value,
}
case *ast.IndexExpr:
x := t.instantiateExpr(ta, e.X)
index := t.instantiateExpr(ta, e.Index)
if x == e.X && index == e.Index {
return e
}
r = &ast.IndexExpr{
X: x,
Lbrack: e.Lbrack,
Index: index,
Rbrack: e.Rbrack,
}
case *ast.SliceExpr:
x := t.instantiateExpr(ta, e.X)
low := t.instantiateExpr(ta, e.Low)
high := t.instantiateExpr(ta, e.High)
max := t.instantiateExpr(ta, e.Max)
if x == e.X && low == e.Low && high == e.High && max == e.Max {
return e
}
r = &ast.SliceExpr{
X: x,
Lbrack: e.Lbrack,
Low: low,
High: high,
Max: max,
Slice3: e.Slice3,
Rbrack: e.Rbrack,
}
case *ast.CallExpr:
fun := t.instantiateExpr(ta, e.Fun)
args, argsChanged := t.instantiateExprList(ta, e.Args)
origInferred, haveInferred := t.importer.info.Inferred[e]
var newInferred types.Inferred
inferredChanged := false
if haveInferred {
for _, typ := range origInferred.Targs {
nt := t.instantiateType(ta, typ)
newInferred.Targs = append(newInferred.Targs, nt)
if nt != typ {
inferredChanged = true
}
}
newInferred.Sig = t.instantiateType(ta, origInferred.Sig).(*types.Signature)
if newInferred.Sig != origInferred.Sig {
inferredChanged = true
}
}
if fun == e.Fun && !argsChanged && !inferredChanged {
return e
}
newCall := &ast.CallExpr{
Fun: fun,
Lparen: e.Lparen,
Args: args,
Ellipsis: e.Ellipsis,
Rparen: e.Rparen,
}
if haveInferred {
t.importer.info.Inferred[newCall] = newInferred
}
r = newCall
case *ast.ArrayType:
ln := t.instantiateExpr(ta, e.Len)
elt := t.instantiateExpr(ta, e.Elt)

View File

@ -330,15 +330,46 @@ func (t *translator) translateStmt(ps *ast.Stmt) {
return
}
switch s := (*ps).(type) {
case *ast.BlockStmt:
t.translateBlockStmt(s)
case *ast.DeclStmt:
d := s.Decl.(*ast.GenDecl)
switch d.Tok {
case token.TYPE:
for i := range d.Specs {
t.translateTypeSpec(&d.Specs[i])
}
case token.CONST, token.VAR:
for i := range d.Specs {
t.translateValueSpec(&d.Specs[i])
}
default:
panic(fmt.Sprintf("unknown decl type %v", d.Tok))
}
case *ast.EmptyStmt:
case *ast.LabeledStmt:
t.translateStmt(&s.Stmt)
case *ast.ExprStmt:
t.translateExpr(&s.X)
case *ast.SendStmt:
t.translateExpr(&s.Chan)
t.translateExpr(&s.Value)
case *ast.IncDecStmt:
t.translateExpr(&s.X)
case *ast.AssignStmt:
t.translateExprList(s.Lhs)
t.translateExprList(s.Rhs)
case *ast.GoStmt:
e := ast.Expr(s.Call)
t.translateExpr(&e)
s.Call = e.(*ast.CallExpr)
case *ast.DeferStmt:
e := ast.Expr(s.Call)
t.translateExpr(&e)
s.Call = e.(*ast.CallExpr)
case *ast.ReturnStmt:
t.translateExprList(s.Results)
case *ast.BranchStmt:
case *ast.BlockStmt:
t.translateBlockStmt(s)
case *ast.IfStmt:
t.translateStmt(&s.Init)
t.translateExpr(&s.Cond)
@ -358,6 +389,8 @@ func (t *translator) translateStmt(ps *ast.Stmt) {
case *ast.CommClause:
t.translateStmt(&s.Comm)
t.translateStmtList(s.Body)
case *ast.SelectStmt:
t.translateBlockStmt(s.Body)
case *ast.ForStmt:
t.translateStmt(&s.Init)
t.translateExpr(&s.Cond)
@ -368,22 +401,6 @@ func (t *translator) translateStmt(ps *ast.Stmt) {
t.translateExpr(&s.Value)
t.translateExpr(&s.X)
t.translateBlockStmt(s.Body)
case *ast.DeclStmt:
d := s.Decl.(*ast.GenDecl)
switch d.Tok {
case token.TYPE:
for i := range d.Specs {
t.translateTypeSpec(&d.Specs[i])
}
case token.CONST, token.VAR:
for i := range d.Specs {
t.translateValueSpec(&d.Specs[i])
}
default:
panic(fmt.Sprintf("unknown decl type %v", d.Tok))
}
case *ast.ReturnStmt:
t.translateExprList(s.Results)
default:
panic(fmt.Sprintf("unimplemented Stmt %T", s))
}
@ -420,13 +437,7 @@ func (t *translator) translateExpr(pe *ast.Expr) {
t.translateExprList(e.Elts)
case *ast.ParenExpr:
t.translateExpr(&e.X)
case *ast.BinaryExpr:
t.translateExpr(&e.X)
t.translateExpr(&e.Y)
case *ast.KeyValueExpr:
t.translateExpr(&e.Key)
t.translateExpr(&e.Value)
case *ast.UnaryExpr:
case *ast.SelectorExpr:
t.translateExpr(&e.X)
case *ast.IndexExpr:
t.translateExpr(&e.X)
@ -436,6 +447,9 @@ func (t *translator) translateExpr(pe *ast.Expr) {
t.translateExpr(&e.Low)
t.translateExpr(&e.High)
t.translateExpr(&e.Max)
case *ast.TypeAssertExpr:
t.translateExpr(&e.X)
t.translateExpr(&e.Type)
case *ast.CallExpr:
t.translateExprList(e.Args)
if ftyp, ok := t.lookupType(e.Fun).(*types.Signature); ok && len(ftyp.TParams()) > 0 {
@ -446,8 +460,14 @@ func (t *translator) translateExpr(pe *ast.Expr) {
t.translateExpr(&e.Fun)
case *ast.StarExpr:
t.translateExpr(&e.X)
case *ast.SelectorExpr:
case *ast.UnaryExpr:
t.translateExpr(&e.X)
case *ast.BinaryExpr:
t.translateExpr(&e.X)
t.translateExpr(&e.Y)
case *ast.KeyValueExpr:
t.translateExpr(&e.Key)
t.translateExpr(&e.Value)
case *ast.ArrayType:
t.translateExpr(&e.Elt)
case *ast.StructType: