diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go index 8f86e16112..26176af07c 100644 --- a/src/cmd/compile/internal/ssa/prove.go +++ b/src/cmd/compile/internal/ssa/prove.go @@ -16,6 +16,10 @@ const ( unknown branch = iota positive negative + // The outedges from a jump table are jumpTable0, + // jumpTable0+1, jumpTable0+2, etc. There could be an + // arbitrary number so we can't list them all here. + jumpTable0 ) // relation represents the set of possible relations between @@ -940,20 +944,31 @@ func prove(f *Func) { // getBranch returns the range restrictions added by p // when reaching b. p is the immediate dominator of b. func getBranch(sdom SparseTree, p *Block, b *Block) branch { - if p == nil || p.Kind != BlockIf { + if p == nil { return unknown } - // If p and p.Succs[0] are dominators it means that every path - // from entry to b passes through p and p.Succs[0]. We care that - // no path from entry to b passes through p.Succs[1]. If p.Succs[0] - // has one predecessor then (apart from the degenerate case), - // there is no path from entry that can reach b through p.Succs[1]. - // TODO: how about p->yes->b->yes, i.e. a loop in yes. - if sdom.IsAncestorEq(p.Succs[0].b, b) && len(p.Succs[0].b.Preds) == 1 { - return positive - } - if sdom.IsAncestorEq(p.Succs[1].b, b) && len(p.Succs[1].b.Preds) == 1 { - return negative + switch p.Kind { + case BlockIf: + // If p and p.Succs[0] are dominators it means that every path + // from entry to b passes through p and p.Succs[0]. We care that + // no path from entry to b passes through p.Succs[1]. If p.Succs[0] + // has one predecessor then (apart from the degenerate case), + // there is no path from entry that can reach b through p.Succs[1]. + // TODO: how about p->yes->b->yes, i.e. a loop in yes. + if sdom.IsAncestorEq(p.Succs[0].b, b) && len(p.Succs[0].b.Preds) == 1 { + return positive + } + if sdom.IsAncestorEq(p.Succs[1].b, b) && len(p.Succs[1].b.Preds) == 1 { + return negative + } + case BlockJumpTable: + // TODO: this loop can lead to quadratic behavior, as + // getBranch can be called len(p.Succs) times. + for i, e := range p.Succs { + if sdom.IsAncestorEq(e.b, b) && len(e.b.Preds) == 1 { + return jumpTable0 + branch(i) + } + } } return unknown } @@ -984,11 +999,36 @@ func addIndVarRestrictions(ft *factsTable, b *Block, iv indVar) { // branching from Block b in direction br. func addBranchRestrictions(ft *factsTable, b *Block, br branch) { c := b.Controls[0] - switch br { - case negative: + switch { + case br == negative: addRestrictions(b, ft, boolean, nil, c, eq) - case positive: + case br == positive: addRestrictions(b, ft, boolean, nil, c, lt|gt) + case br >= jumpTable0: + idx := br - jumpTable0 + val := int64(idx) + if v, off := isConstDelta(c); v != nil { + // Establish the bound on the underlying value we're switching on, + // not on the offset-ed value used as the jump table index. + c = v + val -= off + } + old, ok := ft.limits[c.ID] + if !ok { + old = noLimit + } + ft.limitStack = append(ft.limitStack, limitFact{c.ID, old}) + if val < old.min || val > old.max || uint64(val) < old.umin || uint64(val) > old.umax { + ft.unsat = true + if b.Func.pass.debug > 2 { + b.Func.Warnl(b.Pos, "block=%s outedge=%d %s=%d unsat", b, idx, c, val) + } + } else { + ft.limits[c.ID] = limit{val, val, uint64(val), uint64(val)} + if b.Func.pass.debug > 2 { + b.Func.Warnl(b.Pos, "block=%s outedge=%d %s=%d", b, idx, c, val) + } + } default: panic("unknown branch") } @@ -1343,10 +1383,14 @@ func removeBranch(b *Block, branch branch) { // attempt to preserve statement marker. b.Pos = b.Pos.WithIsStmt() } - b.Kind = BlockFirst - b.ResetControls() - if branch == positive { - b.swapSuccessors() + if branch == positive || branch == negative { + b.Kind = BlockFirst + b.ResetControls() + if branch == positive { + b.swapSuccessors() + } + } else { + // TODO: figure out how to remove an entry from a jump table } } diff --git a/src/cmd/compile/internal/test/switch_test.go b/src/cmd/compile/internal/test/switch_test.go index 6f7bfcf3d8..30dee6257e 100644 --- a/src/cmd/compile/internal/test/switch_test.go +++ b/src/cmd/compile/internal/test/switch_test.go @@ -77,6 +77,49 @@ func benchmarkSwitch32(b *testing.B, predictable bool) { sink = n } +func BenchmarkSwitchStringPredictable(b *testing.B) { + benchmarkSwitchString(b, true) +} +func BenchmarkSwitchStringUnpredictable(b *testing.B) { + benchmarkSwitchString(b, false) +} +func benchmarkSwitchString(b *testing.B, predictable bool) { + a := []string{ + "foo", + "foo1", + "foo22", + "foo333", + "foo4444", + "foo55555", + "foo666666", + "foo7777777", + } + n := 0 + rng := newRNG() + for i := 0; i < b.N; i++ { + rng = rng.next(predictable) + switch a[rng.value()&7] { + case "foo": + n += 1 + case "foo1": + n += 2 + case "foo22": + n += 3 + case "foo333": + n += 4 + case "foo4444": + n += 5 + case "foo55555": + n += 6 + case "foo666666": + n += 7 + case "foo7777777": + n += 8 + } + } + sink = n +} + // A simple random number generator used to make switches conditionally predictable. type rng uint64 diff --git a/src/cmd/compile/internal/walk/switch.go b/src/cmd/compile/internal/walk/switch.go index a4003ecea4..6a2dbe1753 100644 --- a/src/cmd/compile/internal/walk/switch.go +++ b/src/cmd/compile/internal/walk/switch.go @@ -67,6 +67,7 @@ func walkSwitchExpr(sw *ir.SwitchStmt) { base.Pos = lno s := exprSwitch{ + pos: lno, exprname: cond, } @@ -113,6 +114,7 @@ func walkSwitchExpr(sw *ir.SwitchStmt) { // An exprSwitch walks an expression switch. type exprSwitch struct { + pos src.XPos exprname ir.Node // value being switched on done ir.Nodes @@ -183,17 +185,59 @@ func (s *exprSwitch) flush() { } runs = append(runs, cc[start:]) - // Perform two-level binary search. - binarySearch(len(runs), &s.done, - func(i int) ir.Node { - return ir.NewBinaryExpr(base.Pos, ir.OLE, ir.NewUnaryExpr(base.Pos, ir.OLEN, s.exprname), ir.NewInt(runLen(runs[i-1]))) - }, - func(i int, nif *ir.IfStmt) { - run := runs[i] - nif.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, ir.NewUnaryExpr(base.Pos, ir.OLEN, s.exprname), ir.NewInt(runLen(run))) - s.search(run, &nif.Body) - }, - ) + if len(runs) == 1 { + s.search(runs[0], &s.done) + return + } + // We have strings of more than one length. Generate an + // outer switch which switches on the length of the string + // and an inner switch in each case which resolves all the + // strings of the same length. The code looks something like this: + + // goto outerLabel + // len5: + // ... search among length 5 strings ... + // goto endLabel + // len8: + // ... search among length 8 strings ... + // goto endLabel + // ... other lengths ... + // outerLabel: + // switch len(s) { + // case 5: goto len5 + // case 8: goto len8 + // ... other lengths ... + // } + // endLabel: + + outerLabel := typecheck.AutoLabel(".s") + endLabel := typecheck.AutoLabel(".s") + + // Jump around all the individual switches for each length. + s.done.Append(ir.NewBranchStmt(s.pos, ir.OGOTO, outerLabel)) + + var outer exprSwitch + outer.exprname = ir.NewUnaryExpr(s.pos, ir.OLEN, s.exprname) + outer.exprname.SetType(types.Types[types.TINT]) + + for _, run := range runs { + // Target label to jump to when we match this length. + label := typecheck.AutoLabel(".s") + + // Search within this run of same-length strings. + pos := run[0].pos + s.done.Append(ir.NewLabelStmt(pos, label)) + s.search(run, &s.done) + s.done.Append(ir.NewBranchStmt(pos, ir.OGOTO, endLabel)) + + // Add length case to outer switch. + cas := ir.NewBasicLit(pos, constant.MakeInt64(runLen(run))) + jmp := ir.NewBranchStmt(pos, ir.OGOTO, label) + outer.Add(pos, cas, jmp) + } + s.done.Append(ir.NewLabelStmt(s.pos, outerLabel)) + outer.Emit(&s.done) + s.done.Append(ir.NewLabelStmt(s.pos, endLabel)) return } @@ -278,7 +322,6 @@ func (s *exprSwitch) tryJumpTable(cc []exprClause, out *ir.Nodes) bool { } } out.Append(jt) - // TODO: handle the size portion of string switches using a jump table. return true } @@ -587,6 +630,7 @@ func (s *typeSwitch) flush() { } cc = merged + // TODO: figure out if we could use a jump table using some low bits of the type hashes. binarySearch(len(cc), &s.done, func(i int) ir.Node { return ir.NewBinaryExpr(base.Pos, ir.OLE, s.hashname, ir.NewInt(int64(cc[i-1].hash)))