diff --git a/src/cmd/compile/internal/ir/node_gen.go b/src/cmd/compile/internal/ir/node_gen.go index 5796544b48..7412969425 100644 --- a/src/cmd/compile/internal/ir/node_gen.go +++ b/src/cmd/compile/internal/ir/node_gen.go @@ -297,21 +297,18 @@ func (n *CommStmt) Format(s fmt.State, verb rune) { FmtNode(n, s, verb) } func (n *CommStmt) copy() Node { c := *n c.init = c.init.Copy() - c.List = c.List.Copy() c.Body = c.Body.Copy() return &c } func (n *CommStmt) doChildren(do func(Node) error) error { var err error err = maybeDoList(n.init, err, do) - err = maybeDoList(n.List, err, do) err = maybeDo(n.Comm, err, do) err = maybeDoList(n.Body, err, do) return err } func (n *CommStmt) editChildren(edit func(Node) Node) { editList(n.init, edit) - editList(n.List, edit) n.Comm = maybeEdit(n.Comm, edit) editList(n.Body, edit) } diff --git a/src/cmd/compile/internal/ir/stmt.go b/src/cmd/compile/internal/ir/stmt.go index 181a0fd582..0f44acd8b4 100644 --- a/src/cmd/compile/internal/ir/stmt.go +++ b/src/cmd/compile/internal/ir/stmt.go @@ -220,13 +220,12 @@ func editCases(list []*CaseStmt, edit func(Node) Node) { type CommStmt struct { miniStmt - List Nodes // list of expressions for switch, early select - Comm Node // communication case (Exprs[0]) after select is type-checked + Comm Node // communication case Body Nodes } -func NewCommStmt(pos src.XPos, list, body []Node) *CommStmt { - n := &CommStmt{List: list, Body: body} +func NewCommStmt(pos src.XPos, comm Node, body []Node) *CommStmt { + n := &CommStmt{Comm: comm, Body: body} n.pos = pos n.op = OCASE return n @@ -274,11 +273,13 @@ type ForStmt struct { HasBreak bool } -func NewForStmt(pos src.XPos, init []Node, cond, post Node, body []Node) *ForStmt { +func NewForStmt(pos src.XPos, init Node, cond, post Node, body []Node) *ForStmt { n := &ForStmt{Cond: cond, Post: post} n.pos = pos n.op = OFOR - n.init.Set(init) + if init != nil { + n.init = []Node{init} + } n.Body.Set(body) return n } diff --git a/src/cmd/compile/internal/noder/noder.go b/src/cmd/compile/internal/noder/noder.go index ff699cd54d..19a88e21a2 100644 --- a/src/cmd/compile/internal/noder/noder.go +++ b/src/cmd/compile/internal/noder/noder.go @@ -1149,9 +1149,11 @@ func (p *noder) blockStmt(stmt *syntax.BlockStmt) []ir.Node { func (p *noder) ifStmt(stmt *syntax.IfStmt) ir.Node { p.openScope(stmt.Pos()) - init := p.simpleStmt(stmt.Init) + init := p.stmt(stmt.Init) n := ir.NewIfStmt(p.pos(stmt), p.expr(stmt.Cond), p.blockStmt(stmt.Then), nil) - *n.PtrInit() = init + if init != nil { + *n.PtrInit() = []ir.Node{init} + } if stmt.Else != nil { e := p.stmt(stmt.Else) if e.Op() == ir.OBLOCK { @@ -1186,7 +1188,7 @@ func (p *noder) forStmt(stmt *syntax.ForStmt) ir.Node { return n } - n := ir.NewForStmt(p.pos(stmt), p.simpleStmt(stmt.Init), p.expr(stmt.Cond), p.stmt(stmt.Post), p.blockStmt(stmt.Body)) + n := ir.NewForStmt(p.pos(stmt), p.stmt(stmt.Init), p.expr(stmt.Cond), p.stmt(stmt.Post), p.blockStmt(stmt.Body)) p.closeAnotherScope() return n } @@ -1194,9 +1196,11 @@ func (p *noder) forStmt(stmt *syntax.ForStmt) ir.Node { func (p *noder) switchStmt(stmt *syntax.SwitchStmt) ir.Node { p.openScope(stmt.Pos()) - init := p.simpleStmt(stmt.Init) + init := p.stmt(stmt.Init) n := ir.NewSwitchStmt(p.pos(stmt), p.expr(stmt.Tag), nil) - *n.PtrInit() = init + if init != nil { + *n.PtrInit() = []ir.Node{init} + } var tswitch *ir.TypeSwitchGuard if l := n.Tag; l != nil && l.Op() == ir.OTYPESW { @@ -1259,13 +1263,6 @@ func (p *noder) selectStmt(stmt *syntax.SelectStmt) ir.Node { return ir.NewSelectStmt(p.pos(stmt), p.commClauses(stmt.Body, stmt.Rbrace)) } -func (p *noder) simpleStmt(stmt syntax.SimpleStmt) []ir.Node { - if stmt == nil { - return nil - } - return []ir.Node{p.stmt(stmt)} -} - func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []*ir.CommStmt { nodes := make([]*ir.CommStmt, len(clauses)) for i, clause := range clauses { @@ -1275,7 +1272,7 @@ func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []* } p.openScope(clause.Pos()) - nodes[i] = ir.NewCommStmt(p.pos(clause), p.simpleStmt(clause.Comm), p.stmts(clause.Body)) + nodes[i] = ir.NewCommStmt(p.pos(clause), p.stmt(clause.Comm), p.stmts(clause.Body)) } if len(clauses) > 0 { p.closeScope(rbrace) diff --git a/src/cmd/compile/internal/typecheck/iexport.go b/src/cmd/compile/internal/typecheck/iexport.go index ef2c4527a9..bf093c60c7 100644 --- a/src/cmd/compile/internal/typecheck/iexport.go +++ b/src/cmd/compile/internal/typecheck/iexport.go @@ -1197,7 +1197,7 @@ func (w *exportWriter) commList(cases []*ir.CommStmt) { w.uint64(uint64(len(cases))) for _, cas := range cases { w.pos(cas.Pos()) - w.stmtList(cas.List) + w.node(cas.Comm) w.stmtList(cas.Body) } } diff --git a/src/cmd/compile/internal/typecheck/iimport.go b/src/cmd/compile/internal/typecheck/iimport.go index ba7ea2f156..af2dd84a38 100644 --- a/src/cmd/compile/internal/typecheck/iimport.go +++ b/src/cmd/compile/internal/typecheck/iimport.go @@ -792,7 +792,7 @@ func (r *importReader) caseList(switchExpr ir.Node) []*ir.CaseStmt { func (r *importReader) commList() []*ir.CommStmt { cases := make([]*ir.CommStmt, r.uint64()) for i := range cases { - cases[i] = ir.NewCommStmt(r.pos(), r.stmtList(), r.stmtList()) + cases[i] = ir.NewCommStmt(r.pos(), r.node(), r.stmtList()) } return cases } @@ -1033,7 +1033,9 @@ func (r *importReader) node() ir.Node { case ir.OFOR: pos, init := r.pos(), r.stmtList() cond, post := r.exprsOrNil() - return ir.NewForStmt(pos, init, cond, post, r.stmtList()) + n := ir.NewForStmt(pos, nil, cond, post, r.stmtList()) + n.PtrInit().Set(init) + return n case ir.ORANGE: pos := r.pos() diff --git a/src/cmd/compile/internal/typecheck/stmt.go b/src/cmd/compile/internal/typecheck/stmt.go index 03c3e399eb..bfeea06e83 100644 --- a/src/cmd/compile/internal/typecheck/stmt.go +++ b/src/cmd/compile/internal/typecheck/stmt.go @@ -360,29 +360,23 @@ func tcReturn(n *ir.ReturnStmt) ir.Node { // select func tcSelect(sel *ir.SelectStmt) { - var def ir.Node + var def *ir.CommStmt lno := ir.SetPos(sel) Stmts(sel.Init()) for _, ncase := range sel.Cases { - if len(ncase.List) == 0 { + if ncase.Comm == nil { // default if def != nil { base.ErrorfAt(ncase.Pos(), "multiple defaults in select (first at %v)", ir.Line(def)) } else { def = ncase } - } else if len(ncase.List) > 1 { - base.ErrorfAt(ncase.Pos(), "select cases cannot be lists") } else { - ncase.List[0] = Stmt(ncase.List[0]) - n := ncase.List[0] + n := Stmt(ncase.Comm) ncase.Comm = n - ncase.List.Set(nil) - oselrecv2 := func(dst, recv ir.Node, colas bool) { - n := ir.NewAssignListStmt(n.Pos(), ir.OSELRECV2, nil, nil) - n.Lhs = []ir.Node{dst, ir.BlankNode} - n.Rhs = []ir.Node{recv} - n.Def = colas + oselrecv2 := func(dst, recv ir.Node, def bool) { + n := ir.NewAssignListStmt(n.Pos(), ir.OSELRECV2, []ir.Node{dst, ir.BlankNode}, []ir.Node{recv}) + n.Def = def n.SetTypecheck(1) ncase.Comm = n }