diff --git a/src/cmd/compile/internal/ssa/loopbce.go b/src/cmd/compile/internal/ssa/loopbce.go index b7dfaa33e3..3dbd7350ae 100644 --- a/src/cmd/compile/internal/ssa/loopbce.go +++ b/src/cmd/compile/internal/ssa/loopbce.go @@ -13,12 +13,14 @@ import ( type indVarFlags uint8 const ( - indVarMinExc indVarFlags = 1 << iota // minimum value is exclusive (default: inclusive) - indVarMaxInc // maximum value is inclusive (default: exclusive) + indVarMinExc indVarFlags = 1 << iota // minimum value is exclusive (default: inclusive) + indVarMaxInc // maximum value is inclusive (default: exclusive) + indVarCountDown // if set the iteration starts at max and count towards min (default: min towards max) ) type indVar struct { ind *Value // induction variable + nxt *Value // the incremented variable min *Value // minimum value, inclusive/exclusive depends on flags max *Value // maximum value, inclusive/exclusive depends on flags entry *Block // entry block in the loop. @@ -277,6 +279,7 @@ func findIndVar(f *Func) []indVar { if !inclusive { flags |= indVarMinExc } + flags |= indVarCountDown step = -step } if f.pass.debug >= 1 { @@ -285,6 +288,7 @@ func findIndVar(f *Func) []indVar { iv = append(iv, indVar{ ind: ind, + nxt: nxt, min: min, max: max, entry: b.Succs[0].b, diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go index 38758c3361..91f5fbe765 100644 --- a/src/cmd/compile/internal/ssa/prove.go +++ b/src/cmd/compile/internal/ssa/prove.go @@ -798,6 +798,166 @@ func (ft *factsTable) cleanup(f *Func) { // its negation. If either leads to a contradiction, it can trim that // successor. func prove(f *Func) { + // Find induction variables. Currently, findIndVars + // is limited to one induction variable per block. + var indVars map[*Block]indVar + for _, v := range findIndVar(f) { + ind := v.ind + if len(ind.Args) != 2 { + // the rewrite code assumes there is only ever two parents to loops + panic("unexpected induction with too many parents") + } + + nxt := v.nxt + if !(ind.Uses == 2 && // 2 used by comparison and next + nxt.Uses == 1) { // 1 used by induction + // ind or nxt is used inside the loop, add it for the facts table + if indVars == nil { + indVars = make(map[*Block]indVar) + } + indVars[v.entry] = v + continue + } else { + // Since this induction variable is not used for anything but counting the iterations, + // no point in putting it into the facts table. + } + + // try to rewrite to a downward counting loop checking against start if the + // loop body does not depends on ind or nxt and end is known before the loop. + // This reduce pressure on the register allocator because this do not need + // to use end on each iteration anymore. We compare against the start constant instead. + // That means this code: + // + // loop: + // ind = (Phi (Const [x]) nxt), + // if ind < end + // then goto enter_loop + // else goto exit_loop + // + // enter_loop: + // do something without using ind nor nxt + // nxt = inc + ind + // goto loop + // + // exit_loop: + // + // is rewritten to: + // + // loop: + // ind = (Phi end nxt) + // if (Const [x]) < ind + // then goto enter_loop + // else goto exit_loop + // + // enter_loop: + // do something without using ind nor nxt + // nxt = ind - inc + // goto loop + // + // exit_loop: + // + // this is better because it only require to keep ind then nxt alive while looping, + // while the original form keeps ind then nxt and end alive + start, end := v.min, v.max + if v.flags&indVarCountDown != 0 { + start, end = end, start + } + + if !(start.Op == OpConst8 || start.Op == OpConst16 || start.Op == OpConst32 || start.Op == OpConst64) { + // if start is not a constant we would be winning nothing from inverting the loop + continue + } + if end.Op == OpConst8 || end.Op == OpConst16 || end.Op == OpConst32 || end.Op == OpConst64 { + // TODO: if both start and end are constants we should rewrite such that the comparison + // is against zero and nxt is ++ or -- operation + // That means: + // for i := 2; i < 11; i += 2 { + // should be rewritten to: + // for i := 5; 0 < i; i-- { + continue + } + + header := ind.Block + check := header.Controls[0] + if check == nil { + // we don't know how to rewrite a loop that not simple comparison + continue + } + switch check.Op { + case OpLeq64, OpLeq32, OpLeq16, OpLeq8, + OpLess64, OpLess32, OpLess16, OpLess8: + default: + // we don't know how to rewrite a loop that not simple comparison + continue + } + if !((check.Args[0] == ind && check.Args[1] == end) || + (check.Args[1] == ind && check.Args[0] == end)) { + // we don't know how to rewrite a loop that not simple comparison + continue + } + if end.Block == ind.Block { + // we can't rewrite loops where the condition depends on the loop body + // this simple check is forced to work because if this is true a Phi in ind.Block must exists + continue + } + + // invert the check + check.Args[0], check.Args[1] = check.Args[1], check.Args[0] + + // invert start and end in the loop + for i, v := range check.Args { + if v != end { + continue + } + + check.SetArg(i, start) + goto replacedEnd + } + panic(fmt.Sprintf("unreachable, ind: %v, start: %v, end: %v", ind, start, end)) + replacedEnd: + + for i, v := range ind.Args { + if v != start { + continue + } + + ind.SetArg(i, end) + goto replacedStart + } + panic(fmt.Sprintf("unreachable, ind: %v, start: %v, end: %v", ind, start, end)) + replacedStart: + + if nxt.Args[0] != ind { + // unlike additions subtractions are not commutative so be sure we get it right + nxt.Args[0], nxt.Args[1] = nxt.Args[1], nxt.Args[0] + } + + switch nxt.Op { + case OpAdd8: + nxt.Op = OpSub8 + case OpAdd16: + nxt.Op = OpSub16 + case OpAdd32: + nxt.Op = OpSub32 + case OpAdd64: + nxt.Op = OpSub64 + case OpSub8: + nxt.Op = OpAdd8 + case OpSub16: + nxt.Op = OpAdd16 + case OpSub32: + nxt.Op = OpAdd32 + case OpSub64: + nxt.Op = OpAdd64 + default: + panic("unreachable") + } + + if f.pass.debug > 0 { + f.Warnl(ind.Pos, "Inverted loop iteration") + } + } + ft := newFactsTable(f) ft.checkpoint() @@ -933,15 +1093,6 @@ func prove(f *Func) { } } } - // Find induction variables. Currently, findIndVars - // is limited to one induction variable per block. - var indVars map[*Block]indVar - for _, v := range findIndVar(f) { - if indVars == nil { - indVars = make(map[*Block]indVar) - } - indVars[v.entry] = v - } // current node state type walkState int diff --git a/test/codegen/compare_and_branch.go b/test/codegen/compare_and_branch.go index b3feef0eb7..3502a03022 100644 --- a/test/codegen/compare_and_branch.go +++ b/test/codegen/compare_and_branch.go @@ -23,24 +23,25 @@ func si64(x, y chan int64) { } // Signed 64-bit compare-and-branch with 8-bit immediate. -func si64x8() { +func si64x8(doNotOptimize int64) { + // take in doNotOptimize as an argument to avoid the loops being rewritten to count down // s390x:"CGIJ\t[$]12, R[0-9]+, [$]127, " - for i := int64(0); i < 128; i++ { + for i := doNotOptimize; i < 128; i++ { dummy() } // s390x:"CGIJ\t[$]10, R[0-9]+, [$]-128, " - for i := int64(0); i > -129; i-- { + for i := doNotOptimize; i > -129; i-- { dummy() } // s390x:"CGIJ\t[$]2, R[0-9]+, [$]127, " - for i := int64(0); i >= 128; i++ { + for i := doNotOptimize; i >= 128; i++ { dummy() } // s390x:"CGIJ\t[$]4, R[0-9]+, [$]-128, " - for i := int64(0); i <= -129; i-- { + for i := doNotOptimize; i <= -129; i-- { dummy() } } @@ -95,24 +96,25 @@ func si32(x, y chan int32) { } // Signed 32-bit compare-and-branch with 8-bit immediate. -func si32x8() { +func si32x8(doNotOptimize int32) { + // take in doNotOptimize as an argument to avoid the loops being rewritten to count down // s390x:"CIJ\t[$]12, R[0-9]+, [$]127, " - for i := int32(0); i < 128; i++ { + for i := doNotOptimize; i < 128; i++ { dummy() } // s390x:"CIJ\t[$]10, R[0-9]+, [$]-128, " - for i := int32(0); i > -129; i-- { + for i := doNotOptimize; i > -129; i-- { dummy() } // s390x:"CIJ\t[$]2, R[0-9]+, [$]127, " - for i := int32(0); i >= 128; i++ { + for i := doNotOptimize; i >= 128; i++ { dummy() } // s390x:"CIJ\t[$]4, R[0-9]+, [$]-128, " - for i := int32(0); i <= -129; i-- { + for i := doNotOptimize; i <= -129; i-- { dummy() } } diff --git a/test/prove_invert_loop_with_unused_iterators.go b/test/prove_invert_loop_with_unused_iterators.go new file mode 100644 index 0000000000..f278e5aee0 --- /dev/null +++ b/test/prove_invert_loop_with_unused_iterators.go @@ -0,0 +1,10 @@ +// +build amd64 +// errorcheck -0 -d=ssa/prove/debug=1 + +package main + +func invert(b func(), n int) { + for i := 0; i < n; i++ { // ERROR "(Inverted loop iteration|Induction variable: limits \[0,\?\), increment 1)" + b() + } +}