diff --git a/src/cmd/compile/internal/gc/swt.go b/src/cmd/compile/internal/gc/swt.go index 5089efe08b..4a8e9bceed 100644 --- a/src/cmd/compile/internal/gc/swt.go +++ b/src/cmd/compile/internal/gc/swt.go @@ -10,50 +10,6 @@ import ( "sort" ) -const ( - // expression switch - switchKindExpr = iota // switch a {...} or switch 5 {...} - switchKindTrue // switch true {...} or switch {...} - switchKindFalse // switch false {...} -) - -const ( - binarySearchMin = 4 // minimum number of cases for binary search - integerRangeMin = 2 // minimum size of integer ranges -) - -// An exprSwitch walks an expression switch. -type exprSwitch struct { - exprname *Node // node for the expression being switched on - kind int // kind of switch statement (switchKind*) -} - -// A typeSwitch walks a type switch. -type typeSwitch struct { - hashname *Node // node for the hash of the type of the variable being switched on - facename *Node // node for the concrete type of the variable being switched on - okname *Node // boolean node used for comma-ok type assertions -} - -// A caseClause is a single case clause in a switch statement. -type caseClause struct { - node *Node // points at case statement - ordinal int // position in switch - hash uint32 // hash of a type switch - // isconst indicates whether this case clause is a constant, - // for the purposes of the switch code generation. - // For expression switches, that's generally literals (case 5:, not case x:). - // For type switches, that's concrete types (case time.Time:), not interfaces (case io.Reader:). - isconst bool -} - -// caseClauses are all the case clauses in a switch statement. -type caseClauses struct { - list []caseClause // general cases - defjmp *Node // OGOTO for default case or OBREAK if no default case present - niljmp *Node // OGOTO for nil type case in a type switch -} - // typecheckswitch typechecks a switch statement. func typecheckswitch(n *Node) { typecheckslice(n.Ninit.Slice(), ctxStmt) @@ -71,7 +27,6 @@ func typecheckTypeSwitch(n *Node) { yyerrorl(n.Pos, "cannot type switch on non-interface value %L", n.Left.Right) t = nil } - n.Type = t // TODO(mdempsky): Remove; statements aren't typed. // We don't actually declare the type switch's guarded // declaration itself. So if there are no cases, we won't @@ -212,7 +167,6 @@ func typecheckExprSwitch(n *Node) { t = nil } } - n.Type = t // TODO(mdempsky): Remove; statements aren't typed. var defCase *Node var cs constSet @@ -265,422 +219,267 @@ func typecheckExprSwitch(n *Node) { // walkswitch walks a switch statement. func walkswitch(sw *Node) { - // convert switch {...} to switch true {...} - if sw.Left == nil { - sw.Left = nodbool(true) - sw.Left = typecheck(sw.Left, ctxExpr) - sw.Left = defaultlit(sw.Left, nil) - } - - if sw.Left.Op == OTYPESW { - var s typeSwitch - s.walk(sw) - } else { - var s exprSwitch - s.walk(sw) - } -} - -// walk generates an AST implementing sw. -// sw is an expression switch. -// The AST is generally of the form of a linear -// search using if..goto, although binary search -// is used with long runs of constants. -func (s *exprSwitch) walk(sw *Node) { // Guard against double walk, see #25776. if sw.List.Len() == 0 && sw.Nbody.Len() > 0 { return // Was fatal, but eliminating every possible source of double-walking is hard } - casebody(sw, nil) + if sw.Left != nil && sw.Left.Op == OTYPESW { + walkTypeSwitch(sw) + } else { + walkExprSwitch(sw) + } +} + +// walkExprSwitch generates an AST implementing sw. sw is an +// expression switch. +func walkExprSwitch(sw *Node) { + lno := setlineno(sw) cond := sw.Left sw.Left = nil - s.kind = switchKindExpr - if Isconst(cond, CTBOOL) { - s.kind = switchKindTrue - if !cond.Val().U.(bool) { - s.kind = switchKindFalse - } + // convert switch {...} to switch true {...} + if cond == nil { + cond = nodbool(true) + cond = typecheck(cond, ctxExpr) + cond = defaultlit(cond, nil) } // Given "switch string(byteslice)", - // with all cases being constants (or the default case), + // with all cases being side-effect free, // use a zero-cost alias of the byte slice. - // In theory, we could be more aggressive, - // allowing any side-effect-free expressions in cases, - // but it's a bit tricky because some of that information - // is unavailable due to the introduction of temporaries during order. - // Restricting to constants is simple and probably powerful enough. // Do this before calling walkexpr on cond, // because walkexpr will lower the string // conversion into a runtime call. // See issue 24937 for more discussion. - if cond.Op == OBYTES2STR { - ok := true - for _, cas := range sw.List.Slice() { - if cas.Op != OCASE { - Fatalf("switch string(byteslice) bad op: %v", cas.Op) - } - if cas.Left != nil && !Isconst(cas.Left, CTSTR) { - ok = false - break - } - } - if ok { - cond.Op = OBYTES2STRTMP - } + if cond.Op == OBYTES2STR && allCaseExprsAreSideEffectFree(sw) { + cond.Op = OBYTES2STRTMP } cond = walkexpr(cond, &sw.Ninit) - t := sw.Type - if t == nil { - return + if cond.Op != OLITERAL { + cond = copyexpr(cond, cond.Type, &sw.Nbody) } - // convert the switch into OIF statements - var cas []*Node - if s.kind == switchKindTrue || s.kind == switchKindFalse { - s.exprname = nodbool(s.kind == switchKindTrue) - } else if consttype(cond) > 0 { - // leave constants to enable dead code elimination (issue 9608) - s.exprname = cond - } else { - s.exprname = temp(cond.Type) - cas = []*Node{nod(OAS, s.exprname, cond)} // This gets walk()ed again in walkstmtlist just before end of this function. See #29562. - typecheckslice(cas, ctxStmt) - } - - // Enumerate the cases and prepare the default case. - clauses := s.genCaseClauses(sw.List.Slice()) - sw.List.Set(nil) - cc := clauses.list - - // handle the cases in order - for len(cc) > 0 { - run := 1 - if okforcmp[t.Etype] && cc[0].isconst { - // do binary search on runs of constants - for ; run < len(cc) && cc[run].isconst; run++ { - } - // sort and compile constants - sort.Sort(caseClauseByConstVal(cc[:run])) - } - - a := s.walkCases(cc[:run]) - cas = append(cas, a) - cc = cc[run:] - } - - // handle default case - if nerrors == 0 { - cas = append(cas, clauses.defjmp) - sw.Nbody.Prepend(cas...) - walkstmtlist(sw.Nbody.Slice()) - } -} - -// walkCases generates an AST implementing the cases in cc. -func (s *exprSwitch) walkCases(cc []caseClause) *Node { - if len(cc) < binarySearchMin { - // linear search - var cas []*Node - for _, c := range cc { - n := c.node - lno := setlineno(n) - - a := nod(OIF, nil, nil) - if rng := n.List.Slice(); rng != nil { - // Integer range. - // exprname is a temp or a constant, - // so it is safe to evaluate twice. - // In most cases, this conjunction will be - // rewritten by walkinrange into a single comparison. - low := nod(OGE, s.exprname, rng[0]) - high := nod(OLE, s.exprname, rng[1]) - a.Left = nod(OANDAND, low, high) - } else if (s.kind != switchKindTrue && s.kind != switchKindFalse) || assignop(n.Left.Type, s.exprname.Type, nil) == OCONVIFACE || assignop(s.exprname.Type, n.Left.Type, nil) == OCONVIFACE { - a.Left = nod(OEQ, s.exprname, n.Left) // if name == val - } else if s.kind == switchKindTrue { - a.Left = n.Left // if val - } else { - // s.kind == switchKindFalse - a.Left = nod(ONOT, n.Left, nil) // if !val - } - a.Left = typecheck(a.Left, ctxExpr) - a.Left = defaultlit(a.Left, nil) - a.Nbody.Set1(n.Right) // goto l - - cas = append(cas, a) - lineno = lno - } - return liststmt(cas) - } - - // find the middle and recur - half := len(cc) / 2 - a := nod(OIF, nil, nil) - n := cc[half-1].node - var mid *Node - if rng := n.List.Slice(); rng != nil { - mid = rng[1] // high end of range - } else { - mid = n.Left - } - le := nod(OLE, s.exprname, mid) - if Isconst(mid, CTSTR) { - // Search by length and then by value; see caseClauseByConstVal. - lenlt := nod(OLT, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil)) - leneq := nod(OEQ, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil)) - a.Left = nod(OOROR, lenlt, nod(OANDAND, leneq, le)) - } else { - a.Left = le - } - a.Left = typecheck(a.Left, ctxExpr) - a.Left = defaultlit(a.Left, nil) - a.Nbody.Set1(s.walkCases(cc[:half])) - a.Rlist.Set1(s.walkCases(cc[half:])) - return a -} - -// casebody builds separate lists of statements and cases. -// It makes labels between cases and statements -// and deals with fallthrough, break, and unreachable statements. -func casebody(sw *Node, typeswvar *Node) { - if sw.List.Len() == 0 { - return - } - - lno := setlineno(sw) - - var cas []*Node // cases - var stat []*Node // statements - var def *Node // defaults - br := nod(OBREAK, nil, nil) - - for _, n := range sw.List.Slice() { - setlineno(n) - if n.Op != OXCASE { - Fatalf("casebody %v", n.Op) - } - n.Op = OCASE - needvar := n.List.Len() != 1 || n.List.First().Op == OLITERAL - - lbl := autolabel(".s") - jmp := nodSym(OGOTO, nil, lbl) - switch n.List.Len() { - case 0: - // default - if def != nil { - yyerrorl(n.Pos, "more than one default case") - } - // reuse original default case - n.Right = jmp - def = n - case 1: - // one case -- reuse OCASE node - n.Left = n.List.First() - n.Right = jmp - n.List.Set(nil) - cas = append(cas, n) - default: - // Expand multi-valued cases and detect ranges of integer cases. - if typeswvar != nil || sw.Left.Type.IsInterface() || !n.List.First().Type.IsInteger() || n.List.Len() < integerRangeMin { - // Can't use integer ranges. Expand each case into a separate node. - for _, n1 := range n.List.Slice() { - cas = append(cas, nod(OCASE, n1, jmp)) - } - break - } - // Find integer ranges within runs of constants. - s := n.List.Slice() - j := 0 - for j < len(s) { - // Find a run of constants. - var run int - for run = j; run < len(s) && Isconst(s[run], CTINT); run++ { - } - if run-j >= integerRangeMin { - // Search for integer ranges in s[j:run]. - // Typechecking is done, so all values are already in an appropriate range. - search := s[j:run] - sort.Sort(constIntNodesByVal(search)) - for beg, end := 0, 1; end <= len(search); end++ { - if end < len(search) && search[end].Int64() == search[end-1].Int64()+1 { - continue - } - if end-beg >= integerRangeMin { - // Record range in List. - c := nod(OCASE, nil, jmp) - c.List.Set2(search[beg], search[end-1]) - cas = append(cas, c) - } else { - // Not large enough for range; record separately. - for _, n := range search[beg:end] { - cas = append(cas, nod(OCASE, n, jmp)) - } - } - beg = end - } - j = run - } - // Advance to next constant, adding individual non-constant - // or as-yet-unhandled constant cases as we go. - for ; j < len(s) && (j < run || !Isconst(s[j], CTINT)); j++ { - cas = append(cas, nod(OCASE, s[j], jmp)) - } - } - } - - stat = append(stat, nodSym(OLABEL, nil, lbl)) - if typeswvar != nil && needvar && n.Rlist.Len() != 0 { - l := []*Node{ - nod(ODCL, n.Rlist.First(), nil), - nod(OAS, n.Rlist.First(), typeswvar), - } - typecheckslice(l, ctxStmt) - stat = append(stat, l...) - } - stat = append(stat, n.Nbody.Slice()...) - - // Search backwards for the index of the fallthrough - // statement. Do not assume it'll be in the last - // position, since in some cases (e.g. when the statement - // list contains autotmp_ variables), one or more OVARKILL - // nodes will be at the end of the list. - fallIndex := len(stat) - 1 - for stat[fallIndex].Op == OVARKILL { - fallIndex-- - } - last := stat[fallIndex] - if last.Op != OFALL { - stat = append(stat, br) - } - } - - stat = append(stat, br) - if def != nil { - cas = append(cas, def) - } - - sw.List.Set(cas) - sw.Nbody.Set(stat) lineno = lno -} -// genCaseClauses generates the caseClauses value for clauses. -func (s *exprSwitch) genCaseClauses(clauses []*Node) caseClauses { - var cc caseClauses - for _, n := range clauses { - if n.Left == nil && n.List.Len() == 0 { - // default case - if cc.defjmp != nil { + s := exprSwitch{ + exprname: cond, + } + + br := nod(OBREAK, nil, nil) + var defaultGoto *Node + var body Nodes + for _, ncase := range sw.List.Slice() { + label := autolabel(".s") + jmp := npos(ncase.Pos, nodSym(OGOTO, nil, label)) + + // Process case dispatch. + if ncase.List.Len() == 0 { + if defaultGoto != nil { Fatalf("duplicate default case not detected during typechecking") } - cc.defjmp = n.Right - continue + defaultGoto = jmp } - c := caseClause{node: n, ordinal: len(cc.list)} - if n.List.Len() > 0 { - c.isconst = true + + for _, n1 := range ncase.List.Slice() { + s.Add(ncase.Pos, n1, jmp) } - switch consttype(n.Left) { - case CTFLT, CTINT, CTRUNE, CTSTR: - c.isconst = true + + // Process body. + body.Append(npos(ncase.Pos, nodSym(OLABEL, nil, label))) + body.Append(ncase.Nbody.Slice()...) + if !hasFall(ncase.Nbody.Slice()) { + body.Append(br) } - cc.list = append(cc.list, c) + } + sw.List.Set(nil) + + if defaultGoto == nil { + defaultGoto = br } - if cc.defjmp == nil { - cc.defjmp = nod(OBREAK, nil, nil) - } - return cc + s.Emit(&sw.Nbody) + sw.Nbody.Append(defaultGoto) + sw.Nbody.AppendNodes(&body) + walkstmtlist(sw.Nbody.Slice()) } -// genCaseClauses generates the caseClauses value for clauses. -func (s *typeSwitch) genCaseClauses(clauses []*Node) caseClauses { - var cc caseClauses - for _, n := range clauses { - switch { - case n.Left == nil: - // default case - if cc.defjmp != nil { - Fatalf("duplicate default case not detected during typechecking") - } - cc.defjmp = n.Right - continue - case n.Left.Op == OLITERAL: - // nil case in type switch - if cc.niljmp != nil { - Fatalf("duplicate nil case not detected during typechecking") - } - cc.niljmp = n.Right - continue - } +// An exprSwitch walks an expression switch. +type exprSwitch struct { + exprname *Node // value being switched on - // general case - c := caseClause{ - node: n, - ordinal: len(cc.list), - isconst: !n.Left.Type.IsInterface(), - hash: typehash(n.Left.Type), - } - cc.list = append(cc.list, c) - } - - if cc.defjmp == nil { - cc.defjmp = nod(OBREAK, nil, nil) - } - - return cc + done Nodes + clauses []exprClause } -// walk generates an AST that implements sw, -// where sw is a type switch. -// The AST is generally of the form of a linear -// search using if..goto, although binary search -// is used with long runs of concrete types. -func (s *typeSwitch) walk(sw *Node) { - cond := sw.Left +type exprClause struct { + pos src.XPos + lo, hi *Node + jmp *Node +} + +func (s *exprSwitch) Add(pos src.XPos, expr, jmp *Node) { + c := exprClause{pos: pos, lo: expr, hi: expr, jmp: jmp} + if okforcmp[s.exprname.Type.Etype] && expr.Op == OLITERAL { + s.clauses = append(s.clauses, c) + return + } + + s.flush() + s.clauses = append(s.clauses, c) + s.flush() +} + +func (s *exprSwitch) Emit(out *Nodes) { + s.flush() + out.AppendNodes(&s.done) +} + +func (s *exprSwitch) flush() { + cc := s.clauses + s.clauses = nil + if len(cc) == 0 { + return + } + + // Caution: If len(cc) == 1, then cc[0] might not an OLITERAL. + // The code below is structured to implicitly handle this case + // (e.g., sort.Slice doesn't need to invoke the less function + // when there's only a single slice element). + + // Sort strings by length and then by value. + // It is much cheaper to compare lengths than values, + // and all we need here is consistency. + // We respect this sorting below. + sort.Slice(cc, func(i, j int) bool { + vi := cc[i].lo.Val() + vj := cc[j].lo.Val() + + if s.exprname.Type.IsString() { + si := vi.U.(string) + sj := vj.U.(string) + if len(si) != len(sj) { + return len(si) < len(sj) + } + return si < sj + } + + return compareOp(vi, OLT, vj) + }) + + // Merge consecutive integer cases. + if s.exprname.Type.IsInteger() { + merged := cc[:1] + for _, c := range cc[1:] { + last := &merged[len(merged)-1] + if last.jmp == c.jmp && last.hi.Int64()+1 == c.lo.Int64() { + last.hi = c.lo + } else { + merged = append(merged, c) + } + } + cc = merged + } + + binarySearch(len(cc), &s.done, + func(i int) *Node { + mid := cc[i-1].hi + + le := nod(OLE, s.exprname, mid) + if s.exprname.Type.IsString() { + // Compare strings by length and then + // by value; see sort.Slice above. + lenlt := nod(OLT, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil)) + leneq := nod(OEQ, nod(OLEN, s.exprname, nil), nod(OLEN, mid, nil)) + le = nod(OOROR, lenlt, nod(OANDAND, leneq, le)) + } + return le + }, + func(i int, out *Nodes) { + c := &cc[i] + + nif := nodl(c.pos, OIF, c.test(s.exprname), nil) + nif.Left = typecheck(nif.Left, ctxExpr) + nif.Left = defaultlit(nif.Left, nil) + nif.Nbody.Set1(c.jmp) + out.Append(nif) + }, + ) +} + +func (c *exprClause) test(exprname *Node) *Node { + // Integer range. + if c.hi != c.lo { + low := nodl(c.pos, OGE, exprname, c.lo) + high := nodl(c.pos, OLE, exprname, c.hi) + return nodl(c.pos, OANDAND, low, high) + } + + // Optimize "switch true { ...}" and "switch false { ... }". + if Isconst(exprname, CTBOOL) && !c.lo.Type.IsInterface() { + if exprname.Val().U.(bool) { + return c.lo + } else { + return nodl(c.pos, ONOT, c.lo, nil) + } + } + + return nodl(c.pos, OEQ, exprname, c.lo) +} + +func allCaseExprsAreSideEffectFree(sw *Node) bool { + // In theory, we could be more aggressive, allowing any + // side-effect-free expressions in cases, but it's a bit + // tricky because some of that information is unavailable due + // to the introduction of temporaries during order. + // Restricting to constants is simple and probably powerful + // enough. + + for _, ncase := range sw.List.Slice() { + if ncase.Op != OXCASE { + Fatalf("switch string(byteslice) bad op: %v", ncase.Op) + } + for _, v := range ncase.List.Slice() { + if v.Op != OLITERAL { + return false + } + } + } + return true +} + +// hasFall reports whether stmts ends with a "fallthrough" statement. +func hasFall(stmts []*Node) bool { + // Search backwards for the index of the fallthrough + // statement. Do not assume it'll be in the last + // position, since in some cases (e.g. when the statement + // list contains autotmp_ variables), one or more OVARKILL + // nodes will be at the end of the list. + + i := len(stmts) - 1 + for i >= 0 && stmts[i].Op == OVARKILL { + i-- + } + return i >= 0 && stmts[i].Op == OFALL +} + +// walkTypeSwitch generates an AST that implements sw, where sw is a +// type switch. +func walkTypeSwitch(sw *Node) { + var s typeSwitch + s.facename = sw.Left.Right sw.Left = nil - if cond == nil { - sw.List.Set(nil) - return - } - if cond.Right == nil { - yyerrorl(sw.Pos, "type switch must have an assignment") - return - } - - cond.Right = walkexpr(cond.Right, &sw.Ninit) - if !cond.Right.Type.IsInterface() { - yyerrorl(sw.Pos, "type switch must be on an interface") - return - } - - var cas []*Node - - // predeclare temporary variables and the boolean var - s.facename = temp(cond.Right.Type) - - a := nod(OAS, s.facename, cond.Right) - a = typecheck(a, ctxStmt) - cas = append(cas, a) - + s.facename = walkexpr(s.facename, &sw.Ninit) + s.facename = copyexpr(s.facename, s.facename.Type, &sw.Nbody) s.okname = temp(types.Types[TBOOL]) - s.okname = typecheck(s.okname, ctxExpr) - s.hashname = temp(types.Types[TUINT32]) - s.hashname = typecheck(s.hashname, ctxExpr) - - // set up labels and jumps - casebody(sw, s.facename) - - clauses := s.genCaseClauses(sw.List.Slice()) - sw.List.Set(nil) - def := clauses.defjmp + // Get interface descriptor word. + // For empty interfaces this will be the type. + // For non-empty interfaces this will be the itab. + itab := nod(OITAB, s.facename, nil) // For empty interfaces, do: // if e._type == nil { @@ -688,230 +487,235 @@ func (s *typeSwitch) walk(sw *Node) { // } // h := e._type.hash // Use a similar strategy for non-empty interfaces. - - // Get interface descriptor word. - // For empty interfaces this will be the type. - // For non-empty interfaces this will be the itab. - itab := nod(OITAB, s.facename, nil) - - // Check for nil first. - i := nod(OIF, nil, nil) - i.Left = nod(OEQ, itab, nodnil()) - if clauses.niljmp != nil { - // Do explicit nil case right here. - i.Nbody.Set1(clauses.niljmp) - } else { - // Jump to default case. - lbl := autolabel(".s") - i.Nbody.Set1(nodSym(OGOTO, nil, lbl)) - // Wrap default case with label. - blk := nod(OBLOCK, nil, nil) - blk.List.Set2(nodSym(OLABEL, nil, lbl), def) - def = blk - } - i.Left = typecheck(i.Left, ctxExpr) - i.Left = defaultlit(i.Left, nil) - cas = append(cas, i) + ifNil := nod(OIF, nil, nil) + ifNil.Left = nod(OEQ, itab, nodnil()) + ifNil.Left = typecheck(ifNil.Left, ctxExpr) + ifNil.Left = defaultlit(ifNil.Left, nil) + // ifNil.Nbody assigned at end. + sw.Nbody.Append(ifNil) // Load hash from type or itab. - h := nodSym(ODOTPTR, itab, nil) - h.Type = types.Types[TUINT32] - h.SetTypecheck(1) - if cond.Right.Type.IsEmptyInterface() { - h.Xoffset = int64(2 * Widthptr) // offset of hash in runtime._type + dotHash := nodSym(ODOTPTR, itab, nil) + dotHash.Type = types.Types[TUINT32] + dotHash.SetTypecheck(1) + if s.facename.Type.IsEmptyInterface() { + dotHash.Xoffset = int64(2 * Widthptr) // offset of hash in runtime._type } else { - h.Xoffset = int64(2 * Widthptr) // offset of hash in runtime.itab + dotHash.Xoffset = int64(2 * Widthptr) // offset of hash in runtime.itab } - h.SetBounded(true) // guaranteed not to fault - a = nod(OAS, s.hashname, h) - a = typecheck(a, ctxStmt) - cas = append(cas, a) + dotHash.SetBounded(true) // guaranteed not to fault + s.hashname = copyexpr(dotHash, dotHash.Type, &sw.Nbody) - cc := clauses.list - - // insert type equality check into each case block - for _, c := range cc { - c.node.Right = s.typeone(c.node) - } - - // generate list of if statements, binary search for constant sequences - for len(cc) > 0 { - if !cc[0].isconst { - n := cc[0].node - cas = append(cas, n.Right) - cc = cc[1:] - continue + br := nod(OBREAK, nil, nil) + var defaultGoto, nilGoto *Node + var body Nodes + for _, ncase := range sw.List.Slice() { + var caseVar *Node + if ncase.Rlist.Len() != 0 { + caseVar = ncase.Rlist.First() } - // identify run of constants - var run int - for run = 1; run < len(cc) && cc[run].isconst; run++ { - } + // For single-type cases, we initialize the case + // variable as part of the type assertion; but in + // other cases, we initialize it in the body. + singleType := ncase.List.Len() == 1 && ncase.List.First().Op == OTYPE - // sort by hash - sort.Sort(caseClauseByType(cc[:run])) + label := autolabel(".s") - // for debugging: linear search - if false { - for i := 0; i < run; i++ { - n := cc[i].node - cas = append(cas, n.Right) + jmp := npos(ncase.Pos, nodSym(OGOTO, nil, label)) + if ncase.List.Len() == 0 { // default: + if defaultGoto != nil { + Fatalf("duplicate default case not detected during typechecking") } - continue + defaultGoto = jmp } - // combine adjacent cases with the same hash - var batch []caseClause - for i, j := 0, 0; i < run; i = j { - hash := []*Node{cc[i].node.Right} - for j = i + 1; j < run && cc[i].hash == cc[j].hash; j++ { - hash = append(hash, cc[j].node.Right) + for _, n1 := range ncase.List.Slice() { + if n1.isNil() { // case nil: + if nilGoto != nil { + Fatalf("duplicate nil case not detected during typechecking") + } + nilGoto = jmp + continue + } + + if singleType { + s.Add(n1.Type, caseVar, jmp) + } else { + s.Add(n1.Type, nil, jmp) } - cc[i].node.Right = liststmt(hash) - batch = append(batch, cc[i]) } - // binary search among cases to narrow by hash - cas = append(cas, s.walkCases(batch)) - cc = cc[run:] + body.Append(npos(ncase.Pos, nodSym(OLABEL, nil, label))) + if caseVar != nil && !singleType { + l := []*Node{ + nodl(ncase.Pos, ODCL, caseVar, nil), + nodl(ncase.Pos, OAS, caseVar, s.facename), + } + typecheckslice(l, ctxStmt) + body.Append(l...) + } + body.Append(ncase.Nbody.Slice()...) + body.Append(br) + } + sw.List.Set(nil) + + if defaultGoto == nil { + defaultGoto = br } - // handle default case - if nerrors == 0 { - cas = append(cas, def) - sw.Nbody.Prepend(cas...) - sw.List.Set(nil) - walkstmtlist(sw.Nbody.Slice()) + if nilGoto != nil { + ifNil.Nbody.Set1(nilGoto) + } else { + // TODO(mdempsky): Just use defaultGoto directly. + + // Jump to default case. + label := autolabel(".s") + ifNil.Nbody.Set1(nodSym(OGOTO, nil, label)) + // Wrap default case with label. + blk := nod(OBLOCK, nil, nil) + blk.List.Set2(nodSym(OLABEL, nil, label), defaultGoto) + defaultGoto = blk } + + s.Emit(&sw.Nbody) + sw.Nbody.Append(defaultGoto) + sw.Nbody.AppendNodes(&body) + + walkstmtlist(sw.Nbody.Slice()) } -// typeone generates an AST that jumps to the -// case body if the variable is of type t. -func (s *typeSwitch) typeone(t *Node) *Node { - var name *Node - var init Nodes - if t.Rlist.Len() == 0 { - name = nblank - nblank = typecheck(nblank, ctxExpr|ctxAssign) - } else { - name = t.Rlist.First() - init.Append(nod(ODCL, name, nil)) - a := nod(OAS, name, nil) - a = typecheck(a, ctxStmt) - init.Append(a) - } +// A typeSwitch walks a type switch. +type typeSwitch struct { + // Temporary variables (i.e., ONAMEs) used by type switch dispatch logic: + facename *Node // value being type-switched on + hashname *Node // type hash of the value being type-switched on + okname *Node // boolean used for comma-ok type assertions - a := nod(OAS2, nil, nil) - a.List.Set2(name, s.okname) // name, ok = - b := nod(ODOTTYPE, s.facename, nil) - b.Type = t.Left.Type // interface.(type) - a.Rlist.Set1(b) - a = typecheck(a, ctxStmt) - a = walkexpr(a, &init) - init.Append(a) - - c := nod(OIF, nil, nil) - c.Left = s.okname - c.Nbody.Set1(t.Right) // if ok { goto l } - - init.Append(c) - return init.asblock() + done Nodes + clauses []typeClause } -// walkCases generates an AST implementing the cases in cc. -func (s *typeSwitch) walkCases(cc []caseClause) *Node { - if len(cc) < binarySearchMin { - var cas []*Node - for _, c := range cc { - n := c.node - if !c.isconst { - Fatalf("typeSwitch walkCases") - } +type typeClause struct { + hash uint32 + body Nodes +} + +func (s *typeSwitch) Add(typ *types.Type, caseVar *Node, jmp *Node) { + var body Nodes + if caseVar != nil { + l := []*Node{ + nod(ODCL, caseVar, nil), + nod(OAS, caseVar, nil), + } + typecheckslice(l, ctxStmt) + body.Append(l...) + } else { + caseVar = nblank + } + + // cv, ok = iface.(type) + as := nod(OAS2, nil, nil) + as.List.Set2(caseVar, s.okname) // cv, ok = + dot := nod(ODOTTYPE, s.facename, nil) + dot.Type = typ // iface.(type) + as.Rlist.Set1(dot) + as = typecheck(as, ctxStmt) + as = walkexpr(as, &body) + body.Append(as) + + // if ok { goto label } + nif := nod(OIF, nil, nil) + nif.Left = s.okname + nif.Nbody.Set1(jmp) + body.Append(nif) + + if !typ.IsInterface() { + s.clauses = append(s.clauses, typeClause{ + hash: typehash(typ), + body: body, + }) + return + } + + s.flush() + s.done.AppendNodes(&body) +} + +func (s *typeSwitch) Emit(out *Nodes) { + s.flush() + out.AppendNodes(&s.done) +} + +func (s *typeSwitch) flush() { + cc := s.clauses + s.clauses = nil + if len(cc) == 0 { + return + } + + sort.Slice(cc, func(i, j int) bool { return cc[i].hash < cc[j].hash }) + + // Combine adjacent cases with the same hash. + merged := cc[:1] + for _, c := range cc[1:] { + last := &merged[len(merged)-1] + if last.hash == c.hash { + last.body.AppendNodes(&c.body) + } else { + merged = append(merged, c) + } + } + cc = merged + + binarySearch(len(cc), &s.done, + func(i int) *Node { + return nod(OLE, s.hashname, nodintconst(int64(cc[i-1].hash))) + }, + func(i int, out *Nodes) { + // TODO(mdempsky): Omit hash equality check if + // there's only one type. + c := cc[i] a := nod(OIF, nil, nil) a.Left = nod(OEQ, s.hashname, nodintconst(int64(c.hash))) a.Left = typecheck(a.Left, ctxExpr) a.Left = defaultlit(a.Left, nil) - a.Nbody.Set1(n.Right) - cas = append(cas, a) + a.Nbody.AppendNodes(&c.body) + out.Append(a) + }, + ) +} + +// binarySearch constructs a binary search tree for handling n cases, +// and appends it to out. It's used for efficiently implementing +// switch statements. +// +// less(i) should return a boolean expression. If it evaluates true, +// then cases [0, i) will be tested; otherwise, cases [i, n). +// +// base(i, out) should append statements to out to test the i'th case. +func binarySearch(n int, out *Nodes, less func(i int) *Node, base func(i int, out *Nodes)) { + const binarySearchMin = 4 // minimum number of cases for binary search + + var do func(lo, hi int, out *Nodes) + do = func(lo, hi int, out *Nodes) { + n := hi - lo + if n < binarySearchMin { + for i := lo; i < hi; i++ { + base(i, out) + } + return } - return liststmt(cas) + + half := lo + n/2 + nif := nod(OIF, nil, nil) + nif.Left = less(half) + nif.Left = typecheck(nif.Left, ctxExpr) + nif.Left = defaultlit(nif.Left, nil) + do(lo, half, &nif.Nbody) + do(half, hi, &nif.Rlist) + out.Append(nif) } - // find the middle and recur - half := len(cc) / 2 - a := nod(OIF, nil, nil) - a.Left = nod(OLE, s.hashname, nodintconst(int64(cc[half-1].hash))) - a.Left = typecheck(a.Left, ctxExpr) - a.Left = defaultlit(a.Left, nil) - a.Nbody.Set1(s.walkCases(cc[:half])) - a.Rlist.Set1(s.walkCases(cc[half:])) - return a -} - -// caseClauseByConstVal sorts clauses by constant value to enable binary search. -type caseClauseByConstVal []caseClause - -func (x caseClauseByConstVal) Len() int { return len(x) } -func (x caseClauseByConstVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] } -func (x caseClauseByConstVal) Less(i, j int) bool { - // n1 and n2 might be individual constants or integer ranges. - // We have checked for duplicates already, - // so ranges can be safely represented by any value in the range. - n1 := x[i].node - var v1 interface{} - if s := n1.List.Slice(); s != nil { - v1 = s[0].Val().U - } else { - v1 = n1.Left.Val().U - } - - n2 := x[j].node - var v2 interface{} - if s := n2.List.Slice(); s != nil { - v2 = s[0].Val().U - } else { - v2 = n2.Left.Val().U - } - - switch v1 := v1.(type) { - case *Mpflt: - return v1.Cmp(v2.(*Mpflt)) < 0 - case *Mpint: - return v1.Cmp(v2.(*Mpint)) < 0 - case string: - // Sort strings by length and then by value. - // It is much cheaper to compare lengths than values, - // and all we need here is consistency. - // We respect this sorting in exprSwitch.walkCases. - a := v1 - b := v2.(string) - if len(a) != len(b) { - return len(a) < len(b) - } - return a < b - } - - Fatalf("caseClauseByConstVal passed bad clauses %v < %v", x[i].node.Left, x[j].node.Left) - return false -} - -type caseClauseByType []caseClause - -func (x caseClauseByType) Len() int { return len(x) } -func (x caseClauseByType) Swap(i, j int) { x[i], x[j] = x[j], x[i] } -func (x caseClauseByType) Less(i, j int) bool { - c1, c2 := x[i], x[j] - // sort by hash code, then ordinal (for the rare case of hash collisions) - if c1.hash != c2.hash { - return c1.hash < c2.hash - } - return c1.ordinal < c2.ordinal -} - -type constIntNodesByVal []*Node - -func (x constIntNodesByVal) Len() int { return len(x) } -func (x constIntNodesByVal) Swap(i, j int) { x[i], x[j] = x[j], x[i] } -func (x constIntNodesByVal) Less(i, j int) bool { - return x[i].Val().U.(*Mpint).Cmp(x[j].Val().U.(*Mpint)) < 0 + do(0, n, out) } diff --git a/src/cmd/compile/internal/gc/swt_test.go b/src/cmd/compile/internal/gc/swt_test.go deleted file mode 100644 index 2f73ef7b99..0000000000 --- a/src/cmd/compile/internal/gc/swt_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2015 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package gc - -import ( - "testing" -) - -func nodrune(r rune) *Node { - v := new(Mpint) - v.SetInt64(int64(r)) - v.Rune = true - return nodlit(Val{v}) -} - -func nodflt(f float64) *Node { - v := newMpflt() - v.SetFloat64(f) - return nodlit(Val{v}) -} - -func TestCaseClauseByConstVal(t *testing.T) { - tests := []struct { - a, b *Node - }{ - // CTFLT - {nodflt(0.1), nodflt(0.2)}, - // CTINT - {nodintconst(0), nodintconst(1)}, - // CTRUNE - {nodrune('a'), nodrune('b')}, - // CTSTR - {nodlit(Val{"ab"}), nodlit(Val{"abc"})}, - {nodlit(Val{"ab"}), nodlit(Val{"xyz"})}, - {nodlit(Val{"abc"}), nodlit(Val{"xyz"})}, - } - for i, test := range tests { - a := caseClause{node: nod(OXXX, test.a, nil)} - b := caseClause{node: nod(OXXX, test.b, nil)} - s := caseClauseByConstVal{a, b} - if less := s.Less(0, 1); !less { - t.Errorf("%d: caseClauseByConstVal(%v, %v) = false", i, test.a, test.b) - } - if less := s.Less(1, 0); less { - t.Errorf("%d: caseClauseByConstVal(%v, %v) = true", i, test.a, test.b) - } - } -}