cmd/compile: add additional arm64 bit field rules

Get rid of a TODO in the prove pass.
We currently avoid marking shifts by constant amounts as bounded, where
bounded means we don't have to worry about shift counts that are <0 or
>= the bit width.
We do this because it causes different rule applications during lowering,
which makes some codegen tests fail.

Add some new rules which ensure that we get the right final instruction
sequence regardless of rule application order. Then we can remove this
special case.

Change-Id: I4e962d4f09992b42ab47e123de5ded3b8b8fb205
Reviewed-on: https://go-review.googlesource.com/c/go/+/602935
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
commit 7273509466 (parent 8aa2eed8fb)
Author: khr@golang.org
Date: 2024-08-04 05:41:38 -07:00
Committed by: Keith Randall
4 changed files with 152 additions and 12 deletions
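To make the prove-pass change concrete, here is a small illustrative Go function (not part of this CL; the name and shift counts are invented) whose constant-amount shifts are now marked bounded:

package example

// With the TODO removed, the prove pass marks both constant shifts below as
// bounded (0 <= 17, 48 < 64), so arm64 lowering never has to guard against
// out-of-range shift counts, and the new rules in this CL keep the selected
// instructions the same no matter which rewrites fire first.
func extractField(x uint64) uint64 {
	return (x << 17) >> 48 // a bit-field extract; expected to lower to a single UBFX-style instruction
}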


@@ -1789,6 +1789,14 @@
(SRAconst [sc] (SBFIZ [bfc] x)) && sc >= bfc.getARM64BFlsb()
&& sc < bfc.getARM64BFlsb()+bfc.getARM64BFwidth()
=> (SBFX [armBFAuxInt(sc-bfc.getARM64BFlsb(), bfc.getARM64BFlsb()+bfc.getARM64BFwidth()-sc)] x)
(SBFX [bfc] s:(SLLconst [sc] x))
&& s.Uses == 1
&& sc <= bfc.getARM64BFlsb()
=> (SBFX [armBFAuxInt(bfc.getARM64BFlsb() - sc, bfc.getARM64BFwidth())] x)
(SBFX [bfc] s:(SLLconst [sc] x))
&& s.Uses == 1
&& sc > bfc.getARM64BFlsb()
=> (SBFIZ [armBFAuxInt(sc - bfc.getARM64BFlsb(), bfc.getARM64BFwidth() - (sc-bfc.getARM64BFlsb()))] x)
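The two SBFX rules above fold a constant left shift into a single signed bit-field op. Below is a small stand-alone Go model of the SBFX/SBFIZ semantics (illustrative helpers only, not compiler code) that checks the bit arithmetic behind both rewrites:

package main

import "fmt"

// sbfx models ARM64 SBFX: extract `width` bits of x starting at bit `lsb`,
// sign-extended (assumes width > 0 and lsb+width <= 64).
func sbfx(x int64, lsb, width uint) int64 {
	return x << (64 - lsb - width) >> (64 - width)
}

// sbfiz models ARM64 SBFIZ: sign-extend the low `width` bits of x, then
// shift the result left by `lsb`.
func sbfiz(x int64, lsb, width uint) int64 {
	return x << (64 - width) >> (64 - width) << lsb
}

func main() {
	for _, x := range []int64{0, 1, -1, 0x1234_5678_9abc_def0, -0x0fed_cba9_8765_4321} {
		// First rule (sc <= lsb): SBFX of a left shift becomes SBFX with a smaller lsb.
		sc, lsb, width := uint(5), uint(12), uint(20)
		fmt.Println(sbfx(x<<sc, lsb, width) == sbfx(x, lsb-sc, width)) // true

		// Second rule (sc > lsb): SBFX of a left shift becomes SBFIZ.
		sc, lsb, width = 12, 5, 20
		fmt.Println(sbfx(x<<sc, lsb, width) == sbfiz(x, sc-lsb, width-(sc-lsb))) // true
	}
}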
// ubfiz
// (x << lc) >> rc
@@ -1833,11 +1841,25 @@
(UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), arm64BFWidth(c, 0)))] x)
(UBFX [bfc] (ANDconst [c] x)) && isARM64BFMask(0, c, 0) && bfc.getARM64BFlsb() + bfc.getARM64BFwidth() <= arm64BFWidth(c, 0) =>
(UBFX [bfc] x)
// merge ubfx and zerso-extension into ubfx
// merge ubfx and zero-extension into ubfx
(MOVWUreg (UBFX [bfc] x)) && bfc.getARM64BFwidth() <= 32 => (UBFX [bfc] x)
(MOVHUreg (UBFX [bfc] x)) && bfc.getARM64BFwidth() <= 16 => (UBFX [bfc] x)
(MOVBUreg (UBFX [bfc] x)) && bfc.getARM64BFwidth() <= 8 => (UBFX [bfc] x)
// Extracting bits from across a zero-extension boundary.
(UBFX [bfc] e:(MOVWUreg x))
&& e.Uses == 1
&& bfc.getARM64BFlsb() < 32
=> (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 32-bfc.getARM64BFlsb()))] x)
(UBFX [bfc] e:(MOVHUreg x))
&& e.Uses == 1
&& bfc.getARM64BFlsb() < 16
=> (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 16-bfc.getARM64BFlsb()))] x)
(UBFX [bfc] e:(MOVBUreg x))
&& e.Uses == 1
&& bfc.getARM64BFlsb() < 8
=> (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 8-bfc.getARM64BFlsb()))] x)
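These three rules rely on the zero-extension supplying zero bits at and above bit 32, 16, or 8, so the extract can simply be clipped rather than blocked. A minimal Go check of that identity for the 32-bit (MOVWUreg) case, using an illustrative ubfx model rather than any compiler helper:

package main

import "fmt"

// ubfx models ARM64 UBFX: extract `width` bits of x starting at bit `lsb`,
// zero-extended (assumes width > 0 and lsb+width <= 64).
func ubfx(x uint64, lsb, width uint) uint64 {
	return (x >> lsb) & (uint64(1)<<width - 1)
}

func main() {
	x := uint64(0xdead_beef_cafe_f00d)
	lsb, width := uint(20), uint(30) // lsb < 32; the field reaches past bit 31

	// MOVWUreg zero-extends from 32 bits, so bits 32 and up are zero and the
	// extract can be done on x directly with width clipped to min(width, 32-lsb).
	zext := uint64(uint32(x))
	fmt.Println(ubfx(zext, lsb, width) == ubfx(x, lsb, min(width, 32-lsb))) // true
}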
// ubfiz/ubfx combinations: merge shifts into bitfield ops
(SRLconst [sc] (UBFX [bfc] x)) && sc < bfc.getARM64BFwidth()
=> (UBFX [armBFAuxInt(bfc.getARM64BFlsb()+sc, bfc.getARM64BFwidth()-sc)] x)
@@ -1868,6 +1890,12 @@
(OR (UBFIZ [bfc] x) (ANDconst [ac] y))
&& ac == ^((1<<uint(bfc.getARM64BFwidth())-1) << uint(bfc.getARM64BFlsb()))
=> (BFI [bfc] y x)
(ORshiftLL [s] (ANDconst [xc] x) (ANDconst [yc] y))
&& xc == ^(yc << s) // opposite masks
&& yc & (yc+1) == 0 // power of 2 minus 1
&& yc > 0 // not 0, not all 64 bits (there are better rewrites in that case)
&& s+log64(yc+1) <= 64 // shifted mask doesn't overflow
=> (BFI [armBFAuxInt(s, log64(yc+1))] x y)
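The new ORshiftLL rule recognizes a mask-and-merge as a bit-field insert: x keeps everything outside the field (xc == ^(yc<<s)) while y supplies the low log2(yc+1) bits, shifted up to position s. A short Go sketch of that equivalence (bfi is an illustrative model of the instruction, not compiler code):

package main

import "fmt"

// bfi models ARM64 BFI: insert the low `width` bits of y into x at bit `lsb`.
func bfi(x, y uint64, lsb, width uint) uint64 {
	mask := uint64(1)<<width - 1
	return x&^(mask<<lsb) | (y&mask)<<lsb
}

func main() {
	x := uint64(0x1122_3344_5566_7788)
	y := uint64(0xffee_ddcc_bbaa_9988)

	// Left-hand side of the rule with s = 16 and yc = 1<<12 - 1, so that
	// xc = ^(yc<<s) clears exactly the field that (y&yc)<<s fills in.
	const s, w = 16, 12
	yc := uint64(1)<<w - 1
	xc := ^(yc << s)
	lhs := x&xc | (y&yc)<<s

	fmt.Println(lhs == bfi(x, y, s, w)) // true
}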
(ORshiftRL [rc] (ANDconst [ac] x) (SLLconst [lc] y))
&& lc > rc && ac == ^((1<<uint(64-lc)-1) << uint64(lc-rc))
=> (BFI [armBFAuxInt(lc-rc, 64-lc)] x y)


@@ -2008,15 +2008,6 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) {
lim := ft.limits[by.ID]
bits := 8 * v.Args[0].Type.Size()
if lim.umax < uint64(bits) || (lim.max < bits && ft.isNonNegative(by)) {
if by.isGenericIntConst() {
// TODO: get rid of this block.
// Currently this causes lowering to happen
// in different orders, which causes some
// problems for codegen tests for arm64
// where rule application order results
// in different final instructions.
break
}
v.AuxInt = 1 // see shiftIsBounded
if b.Func.pass.debug > 0 && !by.isGenericIntConst() {
b.Func.Warnl(v.Pos, "Proved %v bounded", v.Op)
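For reference, here is a tiny stand-alone model of the bounds test above (illustrative only, not compiler code); with the constant special case deleted, a literal shift amount such as 3 now takes the same path as a masked one:

package main

import "fmt"

// shiftBounded mirrors the condition above: a shift is bounded when the
// prove pass knows its amount is in [0, bits), either from the unsigned
// upper bound or from a signed upper bound plus non-negativity.
func shiftBounded(umax uint64, smax int64, nonNegative bool, bits int64) bool {
	return umax < uint64(bits) || (smax < bits && nonNegative)
}

func main() {
	// A constant shift amount like 3 has umax == 3, so a 64-bit shift by it
	// is bounded; this is the case the deleted block used to skip.
	fmt.Println(shiftBounded(3, 3, true, 64)) // true
	// A masked amount s&63 has umax == 63: bounded as well, as before.
	fmt.Println(shiftBounded(63, 63, true, 64)) // true
}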


@@ -352,6 +352,8 @@ func rewriteValueARM64(v *Value) bool {
return rewriteValueARM64_OpARM64RORW(v)
case OpARM64SBCSflags:
return rewriteValueARM64_OpARM64SBCSflags(v)
case OpARM64SBFX:
return rewriteValueARM64_OpARM64SBFX(v)
case OpARM64SLL:
return rewriteValueARM64_OpARM64SLL(v)
case OpARM64SLLconst:
@@ -15203,6 +15205,29 @@ func rewriteValueARM64_OpARM64ORshiftLL(v *Value) bool {
v.AddArg2(x2, x)
return true
}
// match: (ORshiftLL [s] (ANDconst [xc] x) (ANDconst [yc] y))
// cond: xc == ^(yc << s) && yc & (yc+1) == 0 && yc > 0 && s+log64(yc+1) <= 64
// result: (BFI [armBFAuxInt(s, log64(yc+1))] x y)
for {
s := auxIntToInt64(v.AuxInt)
if v_0.Op != OpARM64ANDconst {
break
}
xc := auxIntToInt64(v_0.AuxInt)
x := v_0.Args[0]
if v_1.Op != OpARM64ANDconst {
break
}
yc := auxIntToInt64(v_1.AuxInt)
y := v_1.Args[0]
if !(xc == ^(yc<<s) && yc&(yc+1) == 0 && yc > 0 && s+log64(yc+1) <= 64) {
break
}
v.reset(OpARM64BFI)
v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(s, log64(yc+1)))
v.AddArg2(x, y)
return true
}
// match: (ORshiftLL [sc] (UBFX [bfc] x) (SRLconst [sc] y))
// cond: sc == bfc.getARM64BFwidth()
// result: (BFXIL [bfc] y x)
@@ -15546,6 +15571,48 @@ func rewriteValueARM64_OpARM64SBCSflags(v *Value) bool {
}
return false
}
func rewriteValueARM64_OpARM64SBFX(v *Value) bool {
v_0 := v.Args[0]
// match: (SBFX [bfc] s:(SLLconst [sc] x))
// cond: s.Uses == 1 && sc <= bfc.getARM64BFlsb()
// result: (SBFX [armBFAuxInt(bfc.getARM64BFlsb() - sc, bfc.getARM64BFwidth())] x)
for {
bfc := auxIntToArm64BitField(v.AuxInt)
s := v_0
if s.Op != OpARM64SLLconst {
break
}
sc := auxIntToInt64(s.AuxInt)
x := s.Args[0]
if !(s.Uses == 1 && sc <= bfc.getARM64BFlsb()) {
break
}
v.reset(OpARM64SBFX)
v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(bfc.getARM64BFlsb()-sc, bfc.getARM64BFwidth()))
v.AddArg(x)
return true
}
// match: (SBFX [bfc] s:(SLLconst [sc] x))
// cond: s.Uses == 1 && sc > bfc.getARM64BFlsb()
// result: (SBFIZ [armBFAuxInt(sc - bfc.getARM64BFlsb(), bfc.getARM64BFwidth() - (sc-bfc.getARM64BFlsb()))] x)
for {
bfc := auxIntToArm64BitField(v.AuxInt)
s := v_0
if s.Op != OpARM64SLLconst {
break
}
sc := auxIntToInt64(s.AuxInt)
x := s.Args[0]
if !(s.Uses == 1 && sc > bfc.getARM64BFlsb()) {
break
}
v.reset(OpARM64SBFIZ)
v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(sc-bfc.getARM64BFlsb(), bfc.getARM64BFwidth()-(sc-bfc.getARM64BFlsb())))
v.AddArg(x)
return true
}
return false
}
func rewriteValueARM64_OpARM64SLL(v *Value) bool {
v_1 := v.Args[1]
v_0 := v.Args[0]
@@ -16946,6 +17013,60 @@ func rewriteValueARM64_OpARM64UBFX(v *Value) bool {
v.AddArg(x)
return true
}
// match: (UBFX [bfc] e:(MOVWUreg x))
// cond: e.Uses == 1 && bfc.getARM64BFlsb() < 32
// result: (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 32-bfc.getARM64BFlsb()))] x)
for {
bfc := auxIntToArm64BitField(v.AuxInt)
e := v_0
if e.Op != OpARM64MOVWUreg {
break
}
x := e.Args[0]
if !(e.Uses == 1 && bfc.getARM64BFlsb() < 32) {
break
}
v.reset(OpARM64UBFX)
v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 32-bfc.getARM64BFlsb())))
v.AddArg(x)
return true
}
// match: (UBFX [bfc] e:(MOVHUreg x))
// cond: e.Uses == 1 && bfc.getARM64BFlsb() < 16
// result: (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 16-bfc.getARM64BFlsb()))] x)
for {
bfc := auxIntToArm64BitField(v.AuxInt)
e := v_0
if e.Op != OpARM64MOVHUreg {
break
}
x := e.Args[0]
if !(e.Uses == 1 && bfc.getARM64BFlsb() < 16) {
break
}
v.reset(OpARM64UBFX)
v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 16-bfc.getARM64BFlsb())))
v.AddArg(x)
return true
}
// match: (UBFX [bfc] e:(MOVBUreg x))
// cond: e.Uses == 1 && bfc.getARM64BFlsb() < 8
// result: (UBFX [armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 8-bfc.getARM64BFlsb()))] x)
for {
bfc := auxIntToArm64BitField(v.AuxInt)
e := v_0
if e.Op != OpARM64MOVBUreg {
break
}
x := e.Args[0]
if !(e.Uses == 1 && bfc.getARM64BFlsb() < 8) {
break
}
v.reset(OpARM64UBFX)
v.AuxInt = arm64BitFieldToAuxInt(armBFAuxInt(bfc.getARM64BFlsb(), min(bfc.getARM64BFwidth(), 8-bfc.getARM64BFlsb())))
v.AddArg(x)
return true
}
// match: (UBFX [bfc] (SRLconst [sc] x))
// cond: sc+bfc.getARM64BFwidth()+bfc.getARM64BFlsb() < 64
// result: (UBFX [armBFAuxInt(bfc.getARM64BFlsb()+sc, bfc.getARM64BFwidth())] x)


@@ -233,9 +233,9 @@ func CmpToZero(a, b, d int32, e, f int64, deOptC0, deOptC1 bool) int32 {
// arm:`AND`,-`TST`
// 386:`ANDL`
c6 := a&d >= 0
// arm64:`TST\sR[0-9]+<<3,\sR[0-9]+`
// For arm64, could be TST+BGE or AND+TBZ
c7 := e&(f<<3) < 0
// arm64:`CMN\sR[0-9]+<<3,\sR[0-9]+`
// For arm64, could be CMN+BPL or ADD+TBZ
c8 := e+(f<<3) < 0
// arm64:`TST\sR[0-9],\sR[0-9]+`
c9 := e&(-19) < 0