diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
index 9afe5995ae..fc206c42d3 100644
--- a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
@@ -153,27 +153,27 @@
 // SRL only considers the bottom 6 bits of y, similarly SRLW only considers the
 // bottom 5 bits of y. Ensure that the result is always zero if the shift exceeds
 // the maximum value. See Lsh above for a detailed description.
-(Rsh8Ux8 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] (ZeroExt8to64 y))))
-(Rsh8Ux16 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] (ZeroExt16to64 y))))
-(Rsh8Ux32 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] (ZeroExt32to64 y))))
-(Rsh8Ux64 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] y)))
-(Rsh16Ux8 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] (ZeroExt8to64 y))))
-(Rsh16Ux16 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] (ZeroExt16to64 y))))
-(Rsh16Ux32 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] (ZeroExt32to64 y))))
-(Rsh16Ux64 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] y)))
-(Rsh32Ux8 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt8to64 y))))
-(Rsh32Ux16 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt16to64 y))))
-(Rsh32Ux32 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt32to64 y))))
-(Rsh32Ux64 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] y)))
-(Rsh64Ux8 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] (ZeroExt8to64 y))))
-(Rsh64Ux16 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] (ZeroExt16to64 y))))
-(Rsh64Ux32 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] (ZeroExt32to64 y))))
-(Rsh64Ux64 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] y)))
+(Rsh8Ux8 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] (ZeroExt8to64 y))))
+(Rsh8Ux16 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] (ZeroExt16to64 y))))
+(Rsh8Ux32 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] (ZeroExt32to64 y))))
+(Rsh8Ux64 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] y)))
+(Rsh16Ux8 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] (ZeroExt8to64 y))))
+(Rsh16Ux16 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] (ZeroExt16to64 y))))
+(Rsh16Ux32 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] (ZeroExt32to64 y))))
+(Rsh16Ux64 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] y)))
+(Rsh32Ux8 x y) && !shiftIsBounded(v) => (AND (SRLW x y) (Neg32 (SLTIU [32] (ZeroExt8to64 y))))
+(Rsh32Ux16 x y) && !shiftIsBounded(v) => (AND (SRLW x y) (Neg32 (SLTIU [32] (ZeroExt16to64 y))))
+(Rsh32Ux32 x y) && !shiftIsBounded(v) => (AND (SRLW x y) (Neg32 (SLTIU [32] (ZeroExt32to64 y))))
+(Rsh32Ux64 x y) && !shiftIsBounded(v) => (AND (SRLW x y) (Neg32 (SLTIU [32] y)))
+(Rsh64Ux8 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] (ZeroExt8to64 y))))
+(Rsh64Ux16 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] (ZeroExt16to64 y))))
+(Rsh64Ux32 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] (ZeroExt32to64 y))))
+(Rsh64Ux64 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] y)))
 
-(Rsh8Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRL (ZeroExt8to64 x) y)
-(Rsh16Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRL (ZeroExt16to64 x) y)
-(Rsh32Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRL (ZeroExt32to64 x) y)
-(Rsh64Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRL x y)
+(Rsh8Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRL (ZeroExt8to64 x) y)
+(Rsh16Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRL (ZeroExt16to64 x) y)
+(Rsh32Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRLW x y)
+(Rsh64Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRL x y)
 
 // SRA only considers the bottom 6 bits of y, similarly SRAW only considers the
 // bottom 5 bits. If y is greater than the maximum value (either 63 or 31
@@ -188,27 +188,27 @@
 //
 // We don't need to sign-extend the OR result, as it will be at minimum 8 bits,
 // more than the 5 or 6 bits SRAW and SRA care about.
-(Rsh8x8 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt8to64 y)))))
-(Rsh8x16 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y)))))
-(Rsh8x32 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y)))))
-(Rsh8x64 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] y))))
-(Rsh16x8 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt8to64 y)))))
-(Rsh16x16 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y)))))
-(Rsh16x32 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y)))))
-(Rsh16x64 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] y))))
-(Rsh32x8 x y) && !shiftIsBounded(v) => (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt8to64 y)))))
-(Rsh32x16 x y) && !shiftIsBounded(v) => (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt16to64 y)))))
-(Rsh32x32 x y) && !shiftIsBounded(v) => (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt32to64 y)))))
-(Rsh32x64 x y) && !shiftIsBounded(v) => (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] y))))
-(Rsh64x8 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] (ZeroExt8to64 y)))))
-(Rsh64x16 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y)))))
-(Rsh64x32 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y)))))
-(Rsh64x64 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] y))))
+(Rsh8x8 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt8to64 y)))))
+(Rsh8x16 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y)))))
+(Rsh8x32 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y)))))
+(Rsh8x64 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] y))))
+(Rsh16x8 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt8to64 y)))))
+(Rsh16x16 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y)))))
+(Rsh16x32 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y)))))
+(Rsh16x64 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] y))))
+(Rsh32x8 x y) && !shiftIsBounded(v) => (SRAW x (OR y (ADDI [-1] (SLTIU [32] (ZeroExt8to64 y)))))
+(Rsh32x16 x y) && !shiftIsBounded(v) => (SRAW x (OR y (ADDI [-1] (SLTIU [32] (ZeroExt16to64 y)))))
+(Rsh32x32 x y) && !shiftIsBounded(v) => (SRAW x (OR y (ADDI [-1] (SLTIU [32] (ZeroExt32to64 y)))))
+(Rsh32x64 x y) && !shiftIsBounded(v) => (SRAW x (OR y (ADDI [-1] (SLTIU [32] y))))
+(Rsh64x8 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] (ZeroExt8to64 y)))))
+(Rsh64x16 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y)))))
+(Rsh64x32 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y)))))
+(Rsh64x64 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] y))))
 
-(Rsh8x(64|32|16|8) x y) && shiftIsBounded(v) => (SRA (SignExt8to64 x) y)
-(Rsh16x(64|32|16|8) x y) && shiftIsBounded(v) => (SRA (SignExt16to64 x) y)
-(Rsh32x(64|32|16|8) x y) && shiftIsBounded(v) => (SRA (SignExt32to64 x) y)
-(Rsh64x(64|32|16|8) x y) && shiftIsBounded(v) => (SRA x y)
+(Rsh8x(64|32|16|8) x y) && shiftIsBounded(v) => (SRA (SignExt8to64 x) y)
+(Rsh16x(64|32|16|8) x y) && shiftIsBounded(v) => (SRA (SignExt16to64 x) y)
+(Rsh32x(64|32|16|8) x y) && shiftIsBounded(v) => (SRAW x y)
+(Rsh64x(64|32|16|8) x y) && shiftIsBounded(v) => (SRA x y)
 
 // Rotates.
 (RotateLeft8 x (MOVDconst [c])) => (Or8 (Lsh8x64 x (MOVDconst [c&7])) (Rsh8Ux64 x (MOVDconst [-c&7])))
@@ -710,10 +710,18 @@
 (MOVDnop (MOVDconst [c])) => (MOVDconst [c])
 
 // Avoid unnecessary zero and sign extension when right shifting.
-(SRL (MOVWUreg x) y) => (SRLW x y)
-(SRLI [x] (MOVWUreg y)) => (SRLIW [int64(x&31)] y)
-(SRA (MOVWreg x) y) => (SRAW x y)
-(SRAI [x] (MOVWreg y)) => (SRAIW [int64(x&31)] y)
+(SRAI [x] (MOVWreg y)) && x >= 0 && x <= 31 => (SRAIW [int64(x)] y)
+(SRLI [x] (MOVWUreg y)) && x >= 0 && x <= 31 => (SRLIW [int64(x)] y)
+
+// Replace right shifts that exceed size of signed type.
+(SRAI [x] (MOVBreg y)) && x >= 8 => (SRAI [63] (SLLI [56] y))
+(SRAI [x] (MOVHreg y)) && x >= 16 => (SRAI [63] (SLLI [48] y))
+(SRAI [x] (MOVWreg y)) && x >= 32 => (SRAIW [31] y)
+
+// Eliminate right shifts that exceed size of unsigned type.
+(SRLI [x] (MOVBUreg y)) && x >= 8 => (MOVDconst [0])
+(SRLI [x] (MOVHUreg y)) && x >= 16 => (MOVDconst [0])
+(SRLI [x] (MOVWUreg y)) && x >= 32 => (MOVDconst [0])
 
 // Fold constant into immediate instructions where possible.
 (ADD (MOVDconst [val]) x) && is32Bit(val) && !t.IsPtr() => (ADDI [val] x)
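Background for the rule changes above: Go defines x >> y as zero (or all sign bits) once y reaches the operand width, while RISC-V SRL/SRLW silently mask the shift amount to 6 or 5 bits, so the unbounded lowerings must force the out-of-range result themselves. A rough Go model of what the new (AND (SRLW x y) (Neg32 (SLTIU [32] y))) form computes for Rsh32Ux64 — an illustrative sketch only, with function names of our own invention, not code from the CL:

// srlw models RISC-V SRLW, which masks the shift amount to 5 bits.
func srlw(x uint32, y uint64) uint32 { return x >> (y & 31) }

// sltiu32 models SLTIU [32]: 1 if y < 32, otherwise 0.
func sltiu32(y uint64) uint32 {
	if y < 32 {
		return 1
	}
	return 0
}

// rsh32Ux64 models (AND (SRLW x y) (Neg32 (SLTIU [32] y))): negating the
// 0/1 flag yields all ones (keep the SRLW result) when y < 32, and zero
// (the Go-mandated result) when y >= 32.
func rsh32Ux64(x uint32, y uint64) uint32 {
	return srlw(x, y) & -sltiu32(y)
}

The SRAW rules rely on the same masking but clamp the shift amount with OR/ADDI instead of masking the result, since an over-large arithmetic shift must produce all sign bits rather than zero.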
diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
index 6009c41f2d..52ddca1c7d 100644
--- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go
+++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
@@ -6260,20 +6260,6 @@ func rewriteValueRISCV64_OpRISCV64SNEZ(v *Value) bool {
 func rewriteValueRISCV64_OpRISCV64SRA(v *Value) bool {
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
-	// match: (SRA (MOVWreg x) y)
-	// result: (SRAW x y)
-	for {
-		t := v.Type
-		if v_0.Op != OpRISCV64MOVWreg {
-			break
-		}
-		x := v_0.Args[0]
-		y := v_1
-		v.reset(OpRISCV64SRAW)
-		v.Type = t
-		v.AddArg2(x, y)
-		return true
-	}
 	// match: (SRA x (MOVDconst [val]))
 	// result: (SRAI [int64(val&63)] x)
 	for {
@@ -6291,8 +6277,10 @@ func rewriteValueRISCV64_OpRISCV64SRA(v *Value) bool {
 }
 func rewriteValueRISCV64_OpRISCV64SRAI(v *Value) bool {
 	v_0 := v.Args[0]
+	b := v.Block
 	// match: (SRAI [x] (MOVWreg y))
-	// result: (SRAIW [int64(x&31)] y)
+	// cond: x >= 0 && x <= 31
+	// result: (SRAIW [int64(x)] y)
 	for {
 		t := v.Type
 		x := auxIntToInt64(v.AuxInt)
@@ -6300,9 +6288,71 @@ func rewriteValueRISCV64_OpRISCV64SRAI(v *Value) bool {
 			break
 		}
 		y := v_0.Args[0]
+		if !(x >= 0 && x <= 31) {
+			break
+		}
 		v.reset(OpRISCV64SRAIW)
 		v.Type = t
-		v.AuxInt = int64ToAuxInt(int64(x & 31))
+		v.AuxInt = int64ToAuxInt(int64(x))
+		v.AddArg(y)
+		return true
+	}
+	// match: (SRAI [x] (MOVBreg y))
+	// cond: x >= 8
+	// result: (SRAI [63] (SLLI [56] y))
+	for {
+		t := v.Type
+		x := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVBreg {
+			break
+		}
+		y := v_0.Args[0]
+		if !(x >= 8) {
+			break
+		}
+		v.reset(OpRISCV64SRAI)
+		v.AuxInt = int64ToAuxInt(63)
+		v0 := b.NewValue0(v.Pos, OpRISCV64SLLI, t)
+		v0.AuxInt = int64ToAuxInt(56)
+		v0.AddArg(y)
+		v.AddArg(v0)
+		return true
+	}
+	// match: (SRAI [x] (MOVHreg y))
+	// cond: x >= 16
+	// result: (SRAI [63] (SLLI [48] y))
+	for {
+		t := v.Type
+		x := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVHreg {
+			break
+		}
+		y := v_0.Args[0]
+		if !(x >= 16) {
+			break
+		}
+		v.reset(OpRISCV64SRAI)
+		v.AuxInt = int64ToAuxInt(63)
+		v0 := b.NewValue0(v.Pos, OpRISCV64SLLI, t)
+		v0.AuxInt = int64ToAuxInt(48)
+		v0.AddArg(y)
+		v.AddArg(v0)
+		return true
+	}
+	// match: (SRAI [x] (MOVWreg y))
+	// cond: x >= 32
+	// result: (SRAIW [31] y)
+	for {
+		x := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVWreg {
+			break
+		}
+		y := v_0.Args[0]
+		if !(x >= 32) {
+			break
+		}
+		v.reset(OpRISCV64SRAIW)
+		v.AuxInt = int64ToAuxInt(31)
 		v.AddArg(y)
 		return true
 	}
@@ -6341,20 +6391,6 @@ func rewriteValueRISCV64_OpRISCV64SRAW(v *Value) bool {
 func rewriteValueRISCV64_OpRISCV64SRL(v *Value) bool {
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
-	// match: (SRL (MOVWUreg x) y)
-	// result: (SRLW x y)
-	for {
-		t := v.Type
-		if v_0.Op != OpRISCV64MOVWUreg {
-			break
-		}
-		x := v_0.Args[0]
-		y := v_1
-		v.reset(OpRISCV64SRLW)
-		v.Type = t
-		v.AddArg2(x, y)
-		return true
-	}
 	// match: (SRL x (MOVDconst [val]))
 	// result: (SRLI [int64(val&63)] x)
 	for {
@@ -6373,7 +6409,8 @@ func rewriteValueRISCV64_OpRISCV64SRL(v *Value) bool {
 }
 func rewriteValueRISCV64_OpRISCV64SRLI(v *Value) bool {
 	v_0 := v.Args[0]
 	// match: (SRLI [x] (MOVWUreg y))
-	// result: (SRLIW [int64(x&31)] y)
+	// cond: x >= 0 && x <= 31
+	// result: (SRLIW [int64(x)] y)
 	for {
 		t := v.Type
 		x := auxIntToInt64(v.AuxInt)
@@ -6381,12 +6418,66 @@ func rewriteValueRISCV64_OpRISCV64SRLI(v *Value) bool {
 			break
 		}
 		y := v_0.Args[0]
+		if !(x >= 0 && x <= 31) {
+			break
+		}
 		v.reset(OpRISCV64SRLIW)
 		v.Type = t
-		v.AuxInt = int64ToAuxInt(int64(x & 31))
+		v.AuxInt = int64ToAuxInt(int64(x))
 		v.AddArg(y)
 		return true
 	}
+	// match: (SRLI [x] (MOVBUreg y))
+	// cond: x >= 8
+	// result: (MOVDconst [0])
+	for {
+		t := v.Type
+		x := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVBUreg {
+			break
+		}
+		if !(x >= 8) {
+			break
+		}
+		v.reset(OpRISCV64MOVDconst)
+		v.Type = t
+		v.AuxInt = int64ToAuxInt(0)
+		return true
+	}
+	// match: (SRLI [x] (MOVHUreg y))
+	// cond: x >= 16
+	// result: (MOVDconst [0])
+	for {
+		t := v.Type
+		x := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVHUreg {
+			break
+		}
+		if !(x >= 16) {
+			break
+		}
+		v.reset(OpRISCV64MOVDconst)
+		v.Type = t
+		v.AuxInt = int64ToAuxInt(0)
+		return true
+	}
+	// match: (SRLI [x] (MOVWUreg y))
+	// cond: x >= 32
+	// result: (MOVDconst [0])
+	for {
+		t := v.Type
+		x := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVWUreg {
+			break
+		}
+		if !(x >= 32) {
+			break
+		}
+		v.reset(OpRISCV64MOVDconst)
+		v.Type = t
+		v.AuxInt = int64ToAuxInt(0)
+		return true
+	}
 	// match: (SRLI [x] (MOVDconst [y]))
 	// result: (MOVDconst [int64(uint64(y) >> uint32(x))])
 	for {
@@ -7035,7 +7126,7 @@ func rewriteValueRISCV64_OpRsh32Ux16(v *Value) bool {
 	typ := &b.Func.Config.Types
 	// match: (Rsh32Ux16 x y)
 	// cond: !shiftIsBounded(v)
-	// result: (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt16to64 y))))
+	// result: (AND (SRLW x y) (Neg32 (SLTIU [32] (ZeroExt16to64 y))))
 	for {
 		t := v.Type
 		x := v_0
@@ -7044,33 +7135,29 @@ func rewriteValueRISCV64_OpRsh32Ux16(v *Value) bool {
 			break
 		}
 		v.reset(OpRISCV64AND)
-		v0 := b.NewValue0(v.Pos, OpRISCV64SRL, t)
-		v1 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64)
-		v1.AddArg(x)
-		v0.AddArg2(v1, y)
-		v2 := b.NewValue0(v.Pos, OpNeg32, t)
-		v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t)
-		v3.AuxInt = int64ToAuxInt(32)
-		v4 := b.NewValue0(v.Pos, OpZeroExt16to64, typ.UInt64)
-		v4.AddArg(y)
-		v3.AddArg(v4)
+		v0 := b.NewValue0(v.Pos, OpRISCV64SRLW, t)
+		v0.AddArg2(x, y)
+		v1 := b.NewValue0(v.Pos, OpNeg32, t)
+		v2 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t)
+		v2.AuxInt = int64ToAuxInt(32)
+		v3 := b.NewValue0(v.Pos, OpZeroExt16to64, typ.UInt64)
+		v3.AddArg(y)
 		v2.AddArg(v3)
-		v.AddArg2(v0, v2)
+		v1.AddArg(v2)
+		v.AddArg2(v0, v1)
 		return true
 	}
 	// match: (Rsh32Ux16 x y)
 	// cond: shiftIsBounded(v)
-	// result: (SRL (ZeroExt32to64 x) y)
+	// result: (SRLW x y)
 	for {
 		x := v_0
 		y := v_1
 		if !(shiftIsBounded(v)) {
 			break
 		}
-		v.reset(OpRISCV64SRL)
-		v0 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64)
-		v0.AddArg(x)
-		v.AddArg2(v0, y)
+		v.reset(OpRISCV64SRLW)
+		v.AddArg2(x, y)
 		return true
 	}
 	return false
@@ -7082,7 +7169,7 @@ func rewriteValueRISCV64_OpRsh32Ux32(v *Value) bool {
 	typ := &b.Func.Config.Types
 	// match: (Rsh32Ux32 x y)
 	// cond: !shiftIsBounded(v)
-	// result: (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt32to64 y))))
+	// result: (AND (SRLW x y) (Neg32 (SLTIU [32] (ZeroExt32to64 y))))
 	for {
 		t := v.Type
 		x := v_0
@@ -7091,33 +7178,29 @@ func rewriteValueRISCV64_OpRsh32Ux32(v *Value) bool {
 			break
 		}
 		v.reset(OpRISCV64AND)
-		v0 := b.NewValue0(v.Pos, OpRISCV64SRL, t)
-		v1 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64)
-		v1.AddArg(x)
-		v0.AddArg2(v1, y)
-		v2 := b.NewValue0(v.Pos, OpNeg32, t)
-		v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t)
-		v3.AuxInt = int64ToAuxInt(32)
-		v4 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64)
-		v4.AddArg(y)
-		v3.AddArg(v4)
+		v0 := b.NewValue0(v.Pos, OpRISCV64SRLW, t)
+		v0.AddArg2(x, y)
+		v1 := b.NewValue0(v.Pos, OpNeg32, t)
+		v2 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t)
+		v2.AuxInt = int64ToAuxInt(32)
+		v3 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64)
+		v3.AddArg(y)
 		v2.AddArg(v3)
-		v.AddArg2(v0, v2)
+		v1.AddArg(v2)
+		v.AddArg2(v0, v1)
 		return true
 	}
 	// match: (Rsh32Ux32 x y)
 	// cond: shiftIsBounded(v)
-	// result: (SRL (ZeroExt32to64 x) y)
+	// result: (SRLW x y)
	for {
 		x := v_0
 		y := v_1
 		if !(shiftIsBounded(v)) {
 			break
 		}
-		v.reset(OpRISCV64SRL)
-		v0 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64)
-		v0.AddArg(x)
-		v.AddArg2(v0, y)
+		v.reset(OpRISCV64SRLW)
+		v.AddArg2(x, y)
 		return true
 	}
 	return false
@@ -7126,10 +7209,9 @@ func rewriteValueRISCV64_OpRsh32Ux64(v *Value) bool {
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
 	b := v.Block
-	typ := &b.Func.Config.Types
 	// match: (Rsh32Ux64 x y)
 	// cond: !shiftIsBounded(v)
-	// result: (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] y)))
+	// result: (AND (SRLW x y) (Neg32 (SLTIU [32] y)))
 	for {
 		t := v.Type
 		x := v_0
@@ -7138,31 +7220,27 @@ func rewriteValueRISCV64_OpRsh32Ux64(v *Value) bool {
 			break
 		}
 		v.reset(OpRISCV64AND)
-		v0 := b.NewValue0(v.Pos, OpRISCV64SRL, t)
-		v1 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64)
-		v1.AddArg(x)
-		v0.AddArg2(v1, y)
-		v2 := b.NewValue0(v.Pos, OpNeg32, t)
-		v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t)
-		v3.AuxInt = int64ToAuxInt(32)
-		v3.AddArg(y)
-		v2.AddArg(v3)
-		v.AddArg2(v0, v2)
+		v0 := b.NewValue0(v.Pos, OpRISCV64SRLW, t)
+		v0.AddArg2(x, y)
+		v1 := b.NewValue0(v.Pos, OpNeg32, t)
+		v2 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t)
+		v2.AuxInt = int64ToAuxInt(32)
+		v2.AddArg(y)
+		v1.AddArg(v2)
+		v.AddArg2(v0, v1)
 		return true
 	}
 	// match: (Rsh32Ux64 x y)
 	// cond: shiftIsBounded(v)
-	// result: (SRL (ZeroExt32to64 x) y)
+	// result: (SRLW x y)
 	for {
 		x := v_0
 		y := v_1
 		if !(shiftIsBounded(v)) {
 			break
 		}
-		v.reset(OpRISCV64SRL)
-		v0 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64)
-		v0.AddArg(x)
-		v.AddArg2(v0, y)
+		v.reset(OpRISCV64SRLW)
+		v.AddArg2(x, y)
 		return true
 	}
 	return false
@@ -7174,7 +7252,7 @@ func rewriteValueRISCV64_OpRsh32Ux8(v *Value) bool {
 	typ := &b.Func.Config.Types
 	// match: (Rsh32Ux8 x y)
 	// cond: !shiftIsBounded(v)
-	// result: (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt8to64 y))))
+	// result: (AND (SRLW x y) (Neg32 (SLTIU [32] (ZeroExt8to64 y))))
 	for {
 		t := v.Type
 		x := v_0
@@ -7183,33 +7261,29 @@ func rewriteValueRISCV64_OpRsh32Ux8(v *Value) bool {
 			break
 		}
 		v.reset(OpRISCV64AND)
-		v0 := b.NewValue0(v.Pos, OpRISCV64SRL, t)
-		v1 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64)
-		v1.AddArg(x)
-		v0.AddArg2(v1, y)
-		v2 := b.NewValue0(v.Pos, OpNeg32, t)
-		v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t)
-		v3.AuxInt = int64ToAuxInt(32)
-		v4 := b.NewValue0(v.Pos, OpZeroExt8to64, typ.UInt64)
-		v4.AddArg(y)
-		v3.AddArg(v4)
+		v0 := b.NewValue0(v.Pos, OpRISCV64SRLW, t)
+		v0.AddArg2(x, y)
+		v1 := b.NewValue0(v.Pos, OpNeg32, t)
+		v2 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t)
+		v2.AuxInt = int64ToAuxInt(32)
+		v3 := b.NewValue0(v.Pos, OpZeroExt8to64, typ.UInt64)
+		v3.AddArg(y)
 		v2.AddArg(v3)
-		v.AddArg2(v0, v2)
+		v1.AddArg(v2)
+		v.AddArg2(v0, v1)
 		return true
 	}
 	// match: (Rsh32Ux8 x y)
 	// cond: shiftIsBounded(v)
-	// result: (SRL (ZeroExt32to64 x) y)
+	// result: (SRLW x y)
 	for {
 		x := v_0
 		y := v_1
 		if !(shiftIsBounded(v)) {
 			break
 		}
-		v.reset(OpRISCV64SRL)
-		v0 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64)
-		v0.AddArg(x)
-		v.AddArg2(v0, y)
+		v.reset(OpRISCV64SRLW)
+		v.AddArg2(x, y)
 		return true
 	}
 	return false
@@ -7221,7 +7295,7 @@ func rewriteValueRISCV64_OpRsh32x16(v *Value) bool {
 	typ := &b.Func.Config.Types
 	// match: (Rsh32x16 x y)
 	// cond: !shiftIsBounded(v)
-	// result: (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt16to64 y)))))
+	// result: (SRAW x (OR y (ADDI [-1] (SLTIU [32] (ZeroExt16to64 y)))))
 	for {
 		t := v.Type
 		x := v_0
@@ -7229,36 +7303,32 @@ func rewriteValueRISCV64_OpRsh32x16(v *Value) bool {
 		if !(!shiftIsBounded(v)) {
 			break
 		}
-		v.reset(OpRISCV64SRA)
+		v.reset(OpRISCV64SRAW)
 		v.Type = t
-		v0 := b.NewValue0(v.Pos, OpSignExt32to64, typ.Int64)
-		v0.AddArg(x)
-		v1 := b.NewValue0(v.Pos, OpRISCV64OR, y.Type)
-		v2 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type)
-		v2.AuxInt = int64ToAuxInt(-1)
-		v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type)
-		v3.AuxInt = int64ToAuxInt(32)
-		v4 := b.NewValue0(v.Pos, OpZeroExt16to64, typ.UInt64)
-		v4.AddArg(y)
-		v3.AddArg(v4)
+		v0 := b.NewValue0(v.Pos, OpRISCV64OR, y.Type)
+		v1 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type)
+		v1.AuxInt = int64ToAuxInt(-1)
+		v2 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type)
+		v2.AuxInt = int64ToAuxInt(32)
+		v3 := b.NewValue0(v.Pos, OpZeroExt16to64, typ.UInt64)
+		v3.AddArg(y)
 		v2.AddArg(v3)
-		v1.AddArg2(y, v2)
-		v.AddArg2(v0, v1)
+		v1.AddArg(v2)
+		v0.AddArg2(y, v1)
+		v.AddArg2(x, v0)
 		return true
 	}
 	// match: (Rsh32x16 x y)
 	// cond: shiftIsBounded(v)
-	// result: (SRA (SignExt32to64 x) y)
+	// result: (SRAW x y)
 	for {
 		x := v_0
 		y := v_1
 		if !(shiftIsBounded(v)) {
 			break
 		}
-		v.reset(OpRISCV64SRA)
-		v0 := b.NewValue0(v.Pos, OpSignExt32to64, typ.Int64)
-		v0.AddArg(x)
-		v.AddArg2(v0, y)
+		v.reset(OpRISCV64SRAW)
+		v.AddArg2(x, y)
 		return true
 	}
 	return false
@@ -7270,7 +7340,7 @@ func rewriteValueRISCV64_OpRsh32x32(v *Value) bool {
 	typ := &b.Func.Config.Types
 	// match: (Rsh32x32 x y)
 	// cond: !shiftIsBounded(v)
-	// result: (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt32to64 y)))))
+	// result: (SRAW x (OR y (ADDI [-1] (SLTIU [32] (ZeroExt32to64 y)))))
 	for {
 		t := v.Type
 		x := v_0
@@ -7278,36 +7348,32 @@ func rewriteValueRISCV64_OpRsh32x32(v *Value) bool {
 		if !(!shiftIsBounded(v)) {
 			break
 		}
-		v.reset(OpRISCV64SRA)
+		v.reset(OpRISCV64SRAW)
 		v.Type = t
-		v0 := b.NewValue0(v.Pos, OpSignExt32to64, typ.Int64)
-		v0.AddArg(x)
-		v1 := b.NewValue0(v.Pos, OpRISCV64OR, y.Type)
-		v2 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type)
-		v2.AuxInt = int64ToAuxInt(-1)
-		v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type)
-		v3.AuxInt = int64ToAuxInt(32)
-		v4 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64)
-		v4.AddArg(y)
-		v3.AddArg(v4)
+		v0 := b.NewValue0(v.Pos, OpRISCV64OR, y.Type)
+		v1 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type)
+		v1.AuxInt = int64ToAuxInt(-1)
+		v2 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type)
+		v2.AuxInt = int64ToAuxInt(32)
+		v3 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64)
+		v3.AddArg(y)
 		v2.AddArg(v3)
-		v1.AddArg2(y, v2)
-		v.AddArg2(v0, v1)
+		v1.AddArg(v2)
+		v0.AddArg2(y, v1)
+		v.AddArg2(x, v0)
 		return true
 	}
 	// match: (Rsh32x32 x y)
 	// cond: shiftIsBounded(v)
-	// result: (SRA (SignExt32to64 x) y)
+	// result: (SRAW x y)
 	for {
 		x := v_0
 		y := v_1
 		if !(shiftIsBounded(v)) {
 			break
 		}
-		v.reset(OpRISCV64SRA)
-		v0 := b.NewValue0(v.Pos, OpSignExt32to64, typ.Int64)
-		v0.AddArg(x)
-		v.AddArg2(v0, y)
+		v.reset(OpRISCV64SRAW)
+		v.AddArg2(x, y)
 		return true
 	}
 	return false
@@ -7316,10 +7382,9 @@ func rewriteValueRISCV64_OpRsh32x64(v *Value) bool {
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
 	b := v.Block
-	typ := &b.Func.Config.Types
 	// match: (Rsh32x64 x y)
 	// cond: !shiftIsBounded(v)
-	// result: (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] y))))
+	// result: (SRAW x (OR y (ADDI [-1] (SLTIU [32] y))))
 	for {
 		t := v.Type
 		x := v_0
@@ -7327,34 +7392,30 @@ func rewriteValueRISCV64_OpRsh32x64(v *Value) bool {
 		if !(!shiftIsBounded(v)) {
 			break
 		}
-		v.reset(OpRISCV64SRA)
+		v.reset(OpRISCV64SRAW)
 		v.Type = t
-		v0 := b.NewValue0(v.Pos, OpSignExt32to64, typ.Int64)
-		v0.AddArg(x)
-		v1 := b.NewValue0(v.Pos, OpRISCV64OR, y.Type)
-		v2 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type)
-		v2.AuxInt = int64ToAuxInt(-1)
-		v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type)
-		v3.AuxInt = int64ToAuxInt(32)
-		v3.AddArg(y)
-		v2.AddArg(v3)
-		v1.AddArg2(y, v2)
-		v.AddArg2(v0, v1)
+		v0 := b.NewValue0(v.Pos, OpRISCV64OR, y.Type)
+		v1 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type)
+		v1.AuxInt = int64ToAuxInt(-1)
+		v2 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type)
+		v2.AuxInt = int64ToAuxInt(32)
+		v2.AddArg(y)
+		v1.AddArg(v2)
+		v0.AddArg2(y, v1)
+		v.AddArg2(x, v0)
 		return true
 	}
 	// match: (Rsh32x64 x y)
 	// cond: shiftIsBounded(v)
-	// result: (SRA (SignExt32to64 x) y)
+	// result: (SRAW x y)
 	for {
 		x := v_0
 		y := v_1
 		if !(shiftIsBounded(v)) {
 			break
 		}
-		v.reset(OpRISCV64SRA)
-		v0 := b.NewValue0(v.Pos, OpSignExt32to64, typ.Int64)
-		v0.AddArg(x)
-		v.AddArg2(v0, y)
+		v.reset(OpRISCV64SRAW)
+		v.AddArg2(x, y)
 		return true
 	}
 	return false
@@ -7366,7 +7427,7 @@ func rewriteValueRISCV64_OpRsh32x8(v *Value) bool {
 	typ := &b.Func.Config.Types
 	// match: (Rsh32x8 x y)
 	// cond: !shiftIsBounded(v)
-	// result: (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt8to64 y)))))
+	// result: (SRAW x (OR y (ADDI [-1] (SLTIU [32] (ZeroExt8to64 y)))))
 	for {
 		t := v.Type
 		x := v_0
@@ -7374,36 +7435,32 @@ func rewriteValueRISCV64_OpRsh32x8(v *Value) bool {
 		if !(!shiftIsBounded(v)) {
 			break
 		}
-		v.reset(OpRISCV64SRA)
+		v.reset(OpRISCV64SRAW)
 		v.Type = t
-		v0 := b.NewValue0(v.Pos, OpSignExt32to64, typ.Int64)
-		v0.AddArg(x)
-		v1 := b.NewValue0(v.Pos, OpRISCV64OR, y.Type)
-		v2 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type)
-		v2.AuxInt = int64ToAuxInt(-1)
-		v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type)
-		v3.AuxInt = int64ToAuxInt(32)
-		v4 := b.NewValue0(v.Pos, OpZeroExt8to64, typ.UInt64)
-		v4.AddArg(y)
-		v3.AddArg(v4)
+		v0 := b.NewValue0(v.Pos, OpRISCV64OR, y.Type)
+		v1 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type)
+		v1.AuxInt = int64ToAuxInt(-1)
+		v2 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type)
+		v2.AuxInt = int64ToAuxInt(32)
+		v3 := b.NewValue0(v.Pos, OpZeroExt8to64, typ.UInt64)
+		v3.AddArg(y)
 		v2.AddArg(v3)
-		v1.AddArg2(y, v2)
-		v.AddArg2(v0, v1)
+		v1.AddArg(v2)
+		v0.AddArg2(y, v1)
+		v.AddArg2(x, v0)
 		return true
 	}
 	// match: (Rsh32x8 x y)
 	// cond: shiftIsBounded(v)
-	// result: (SRA (SignExt32to64 x) y)
+	// result: (SRAW x y)
 	for {
 		x := v_0
 		y := v_1
 		if !(shiftIsBounded(v)) {
 			break
 		}
-		v.reset(OpRISCV64SRA)
-		v0 := b.NewValue0(v.Pos, OpSignExt32to64, typ.Int64)
-		v0.AddArg(x)
-		v.AddArg2(v0, y)
+		v.reset(OpRISCV64SRAW)
+		v.AddArg2(x, y)
 		return true
 	}
 	return false
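The new SRAI cases in the rewriter above rest on a small identity: once a sign-extended 8- or 16-bit value is arithmetically shifted right by at least its width, every surviving bit is a copy of the sign bit, so any such shift can be replaced by (SRAI [63] (SLLI [56] y)) — move the sign bit to bit 63, then smear it across the register. A hedged Go sketch of the 8-bit case (the function name is ours, not the compiler's):

// signSmear8 computes what (SRAI [63] (SLLI [56] y)) produces when y
// holds a sign-extended int8: 0 for non-negative values, -1 for negative
// ones, which equals int64(v) >> k for any constant k >= 8.
func signSmear8(y int64) int64 {
	return (y << 56) >> 63
}

The 16-bit rule is the same with a shift of 48 in place of 56, and the 32-bit case can use SRAIW [31] directly. The unsigned counterparts need no instructions at all: a zero-extended 8-, 16- or 32-bit value shifted right by its width or more is identically zero, hence the MOVDconst [0] rewrites.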
diff --git a/src/cmd/compile/internal/test/testdata/arith_test.go b/src/cmd/compile/internal/test/testdata/arith_test.go
index 2b8cd9fad3..cd7b5bc2c4 100644
--- a/src/cmd/compile/internal/test/testdata/arith_test.go
+++ b/src/cmd/compile/internal/test/testdata/arith_test.go
@@ -268,6 +268,70 @@ func testOverflowConstShift(t *testing.T) {
 	}
 }
 
+//go:noinline
+func rsh64x64ConstOverflow8(x int8) int64 {
+	return int64(x) >> 9
+}
+
+//go:noinline
+func rsh64x64ConstOverflow16(x int16) int64 {
+	return int64(x) >> 17
+}
+
+//go:noinline
+func rsh64x64ConstOverflow32(x int32) int64 {
+	return int64(x) >> 33
+}
+
+func testArithRightShiftConstOverflow(t *testing.T) {
+	allSet := int64(-1)
+	if got, want := rsh64x64ConstOverflow8(0x7f), int64(0); got != want {
+		t.Errorf("rsh64x64ConstOverflow8 failed: got %v, want %v", got, want)
+	}
+	if got, want := rsh64x64ConstOverflow16(0x7fff), int64(0); got != want {
+		t.Errorf("rsh64x64ConstOverflow16 failed: got %v, want %v", got, want)
+	}
+	if got, want := rsh64x64ConstOverflow32(0x7ffffff), int64(0); got != want {
+		t.Errorf("rsh64x64ConstOverflow32 failed: got %v, want %v", got, want)
+	}
+	if got, want := rsh64x64ConstOverflow8(int8(-1)), allSet; got != want {
+		t.Errorf("rsh64x64ConstOverflow8 failed: got %v, want %v", got, want)
+	}
+	if got, want := rsh64x64ConstOverflow16(int16(-1)), allSet; got != want {
+		t.Errorf("rsh64x64ConstOverflow16 failed: got %v, want %v", got, want)
+	}
+	if got, want := rsh64x64ConstOverflow32(int32(-1)), allSet; got != want {
+		t.Errorf("rsh64x64ConstOverflow32 failed: got %v, want %v", got, want)
+	}
+}
+
+//go:noinline
+func rsh64Ux64ConstOverflow8(x uint8) uint64 {
+	return uint64(x) >> 9
+}
+
+//go:noinline
+func rsh64Ux64ConstOverflow16(x uint16) uint64 {
+	return uint64(x) >> 17
+}
+
+//go:noinline
+func rsh64Ux64ConstOverflow32(x uint32) uint64 {
+	return uint64(x) >> 33
+}
+
+func testRightShiftConstOverflow(t *testing.T) {
+	if got, want := rsh64Ux64ConstOverflow8(0xff), uint64(0); got != want {
+		t.Errorf("rsh64Ux64ConstOverflow8 failed: got %v, want %v", got, want)
+	}
+	if got, want := rsh64Ux64ConstOverflow16(0xffff), uint64(0); got != want {
+		t.Errorf("rsh64Ux64ConstOverflow16 failed: got %v, want %v", got, want)
+	}
+	if got, want := rsh64Ux64ConstOverflow32(0xffffffff), uint64(0); got != want {
+		t.Errorf("rsh64Ux64ConstOverflow32 failed: got %v, want %v", got, want)
+	}
+}
+
 // test64BitConstMult tests that rewrite rules don't fold 64 bit constants
 // into multiply instructions.
 func test64BitConstMult(t *testing.T) {
@@ -918,6 +982,8 @@ func TestArithmetic(t *testing.T) {
 	testShiftCX(t)
 	testSubConst(t)
 	testOverflowConstShift(t)
+	testArithRightShiftConstOverflow(t)
+	testRightShiftConstOverflow(t)
 	testArithConstShift(t)
 	testArithRshConst(t)
 	testLargeConst(t)
diff --git a/test/codegen/shift.go b/test/codegen/shift.go
index 32cfaffae0..50d60426d0 100644
--- a/test/codegen/shift.go
+++ b/test/codegen/shift.go
@@ -22,12 +22,42 @@ func rshConst64Ux64(v uint64) uint64 {
 	return v >> uint64(33)
 }
 
+func rshConst64Ux64Overflow32(v uint32) uint64 {
+	// riscv64:"MOV\t\\$0,",-"SRL"
+	return uint64(v) >> 32
+}
+
+func rshConst64Ux64Overflow16(v uint16) uint64 {
+	// riscv64:"MOV\t\\$0,",-"SRL"
+	return uint64(v) >> 16
+}
+
+func rshConst64Ux64Overflow8(v uint8) uint64 {
+	// riscv64:"MOV\t\\$0,",-"SRL"
+	return uint64(v) >> 8
+}
+
 func rshConst64x64(v int64) int64 {
 	// ppc64x:"SRAD"
 	// riscv64:"SRAI\t",-"OR",-"SLTIU"
 	return v >> uint64(33)
 }
 
+func rshConst64x64Overflow32(v int32) int64 {
+	// riscv64:"SRAIW",-"SLLI",-"SRAI\t"
+	return int64(v) >> 32
+}
+
+func rshConst64x64Overflow16(v int16) int64 {
+	// riscv64:"SLLI","SRAI",-"SRAIW"
+	return int64(v) >> 16
+}
+
+func rshConst64x64Overflow8(v int8) int64 {
+	// riscv64:"SLLI","SRAI",-"SRAIW"
+	return int64(v) >> 8
+}
+
 func lshConst32x64(v int32) int32 {
 	// ppc64x:"SLW"
 	// riscv64:"SLLI",-"AND",-"SLTIU", -"MOVW"
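The identities behind both groups of rewrites are cheap to sanity-check exhaustively at the 8-bit width; a throwaway program along these lines (ours, not part of the CL) passes for all 256 inputs:

package main

import "fmt"

func main() {
	for i := -128; i <= 127; i++ {
		v := int8(i)
		// Signed: shifting past the width leaves only the smeared sign bit.
		if got, want := int64(v)>>9, (int64(v)<<56)>>63; got != want {
			fmt.Printf("signed %d: got %d, want %d\n", v, got, want)
		}
		// Unsigned: shifting past the width leaves nothing.
		if got := uint64(uint8(v)) >> 9; got != 0 {
			fmt.Printf("unsigned %d: got %d, want 0\n", v, got)
		}
	}
	fmt.Println("checked all int8 values")
}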