diff --git a/src/cmd/compile/internal/riscv64/ssa.go b/src/cmd/compile/internal/riscv64/ssa.go
index 332f5841b7..22338188e5 100644
--- a/src/cmd/compile/internal/riscv64/ssa.go
+++ b/src/cmd/compile/internal/riscv64/ssa.go
@@ -278,7 +278,7 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
 		p.To.Type = obj.TYPE_REG
 		p.To.Reg = rd
 	case ssa.OpRISCV64ADD, ssa.OpRISCV64SUB, ssa.OpRISCV64SUBW, ssa.OpRISCV64XOR, ssa.OpRISCV64OR, ssa.OpRISCV64AND,
-		ssa.OpRISCV64SLL, ssa.OpRISCV64SRA, ssa.OpRISCV64SRL, ssa.OpRISCV64SRLW,
+		ssa.OpRISCV64SLL, ssa.OpRISCV64SRA, ssa.OpRISCV64SRAW, ssa.OpRISCV64SRL, ssa.OpRISCV64SRLW,
 		ssa.OpRISCV64SLT, ssa.OpRISCV64SLTU, ssa.OpRISCV64MUL, ssa.OpRISCV64MULW, ssa.OpRISCV64MULH,
 		ssa.OpRISCV64MULHU, ssa.OpRISCV64DIV, ssa.OpRISCV64DIVU, ssa.OpRISCV64DIVW,
 		ssa.OpRISCV64DIVUW, ssa.OpRISCV64REM, ssa.OpRISCV64REMU, ssa.OpRISCV64REMW,
@@ -356,7 +356,7 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
 		p.To.Type = obj.TYPE_REG
 		p.To.Reg = v.Reg()
 	case ssa.OpRISCV64ADDI, ssa.OpRISCV64ADDIW, ssa.OpRISCV64XORI, ssa.OpRISCV64ORI, ssa.OpRISCV64ANDI,
-		ssa.OpRISCV64SLLI, ssa.OpRISCV64SRAI, ssa.OpRISCV64SRLI, ssa.OpRISCV64SRLIW, ssa.OpRISCV64SLTI,
+		ssa.OpRISCV64SLLI, ssa.OpRISCV64SRAI, ssa.OpRISCV64SRAIW, ssa.OpRISCV64SRLI, ssa.OpRISCV64SRLIW, ssa.OpRISCV64SLTI,
 		ssa.OpRISCV64SLTIU:
 		p := s.Prog(v.Op.Asm())
 		p.From.Type = obj.TYPE_CONST
diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
index 4cacabb236..9afe5995ae 100644
--- a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
@@ -175,16 +175,19 @@
 (Rsh32Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRL (ZeroExt32to64 x) y)
 (Rsh64Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRL x y)
 
-// SRA only considers the bottom 6 bits of y. If y > 64, the result should
-// be either 0 or -1 based on the sign bit.
+// SRA only considers the bottom 6 bits of y, similarly SRAW only considers the
+// bottom 5 bits. If y is greater than the maximum value (either 63 or 31
+// depending on the instruction), the result of the shift should be either 0
+// or -1 based on the sign bit of x.
 //
-// We implement this by performing the max shift (-1) if y >= 64.
+// We implement this by performing the max shift (-1) if y > the maximum value.
 //
 // We OR (uint64(y < 64) - 1) into y before passing it to SRA. This leaves
-// us with -1 (0xffff...) if y >= 64.
+// us with -1 (0xffff...) if y >= 64. Similarly, we OR (uint64(y < 32) - 1) into y
+// before passing it to SRAW.
 //
 // We don't need to sign-extend the OR result, as it will be at minimum 8 bits,
-// more than the 6 bits SRA cares about.
+// more than the 5 or 6 bits SRAW and SRA care about.
 (Rsh8x8 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt8to64 y)))))
 (Rsh8x16 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y)))))
 (Rsh8x32 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y)))))
@@ -193,10 +196,10 @@
 (Rsh16x16 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y)))))
 (Rsh16x32 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y)))))
 (Rsh16x64 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] y))))
-(Rsh32x8 x y) && !shiftIsBounded(v) => (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt8to64 y)))))
-(Rsh32x16 x y) && !shiftIsBounded(v) => (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y)))))
-(Rsh32x32 x y) && !shiftIsBounded(v) => (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y)))))
-(Rsh32x64 x y) && !shiftIsBounded(v) => (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [64] y))))
+(Rsh32x8 x y) && !shiftIsBounded(v) => (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt8to64 y)))))
+(Rsh32x16 x y) && !shiftIsBounded(v) => (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt16to64 y)))))
+(Rsh32x32 x y) && !shiftIsBounded(v) => (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt32to64 y)))))
+(Rsh32x64 x y) && !shiftIsBounded(v) => (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] y))))
 (Rsh64x8 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] (ZeroExt8to64 y)))))
 (Rsh64x16 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y)))))
 (Rsh64x32 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y)))))
@@ -706,9 +709,11 @@
 // But for now, this is enough to get rid of lots of them.
 (MOVDnop (MOVDconst [c])) => (MOVDconst [c])
 
-// Avoid unnecessary zero extension when right shifting.
+// Avoid unnecessary zero and sign extension when right shifting.
 (SRL (MOVWUreg x) y) => (SRLW x y)
 (SRLI [x] (MOVWUreg y)) => (SRLIW [int64(x&31)] y)
+(SRA (MOVWreg x) y) => (SRAW x y)
+(SRAI [x] (MOVWreg y)) => (SRAIW [int64(x&31)] y)
 
 // Fold constant into immediate instructions where possible.
 (ADD (MOVDconst [val]) x) && is32Bit(val) && !t.IsPtr() => (ADDI [val] x)
@@ -719,6 +724,7 @@
 (SRL x (MOVDconst [val])) => (SRLI [int64(val&63)] x)
 (SRLW x (MOVDconst [val])) => (SRLIW [int64(val&31)] x)
 (SRA x (MOVDconst [val])) => (SRAI [int64(val&63)] x)
+(SRAW x (MOVDconst [val])) => (SRAIW [int64(val&31)] x)
 
 (SLT x (MOVDconst [val])) && val >= -2048 && val <= 2047 => (SLTI [val] x)
 (SLTU x (MOVDconst [val])) && val >= -2048 && val <= 2047 => (SLTIU [val] x)
diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
index 360eff6bcf..93f20f8a99 100644
--- a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
@@ -209,10 +209,12 @@ func init() {
 		// Shift ops
 		{name: "SLL", argLength: 2, reg: gp21, asm: "SLL"},                   // arg0 << (aux1 & 63)
 		{name: "SRA", argLength: 2, reg: gp21, asm: "SRA"},                   // arg0 >> (aux1 & 63), signed
+		{name: "SRAW", argLength: 2, reg: gp21, asm: "SRAW"},                 // arg0 >> (aux1 & 31), signed
 		{name: "SRL", argLength: 2, reg: gp21, asm: "SRL"},                   // arg0 >> (aux1 & 63), unsigned
 		{name: "SRLW", argLength: 2, reg: gp21, asm: "SRLW"},                 // arg0 >> (aux1 & 31), unsigned
 		{name: "SLLI", argLength: 1, reg: gp11, asm: "SLLI", aux: "Int64"},   // arg0 << auxint, shift amount 0-63
 		{name: "SRAI", argLength: 1, reg: gp11, asm: "SRAI", aux: "Int64"},   // arg0 >> auxint, signed, shift amount 0-63
+		{name: "SRAIW", argLength: 1, reg: gp11, asm: "SRAIW", aux: "Int64"}, // arg0 >> auxint, signed, shift amount 0-31
 		{name: "SRLI", argLength: 1, reg: gp11, asm: "SRLI", aux: "Int64"},   // arg0 >> auxint, unsigned, shift amount 0-63
 		{name: "SRLIW", argLength: 1, reg: gp11, asm: "SRLIW", aux: "Int64"}, // arg0 >> auxint, unsigned, shift amount 0-31
 
diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go
index b2af30a37d..6b2320f44c 100644
--- a/src/cmd/compile/internal/ssa/opGen.go
+++ b/src/cmd/compile/internal/ssa/opGen.go
@@ -2382,10 +2382,12 @@ const (
 	OpRISCV64MOVDnop
 	OpRISCV64SLL
 	OpRISCV64SRA
+	OpRISCV64SRAW
 	OpRISCV64SRL
 	OpRISCV64SRLW
 	OpRISCV64SLLI
 	OpRISCV64SRAI
+	OpRISCV64SRAIW
 	OpRISCV64SRLI
 	OpRISCV64SRLIW
 	OpRISCV64XOR
@@ -31953,6 +31955,20 @@ var opcodeTable = [...]opInfo{
 		},
 	},
+	{
+		name:   "SRAW",
+		argLen: 2,
+		asm:    riscv.ASRAW,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+				{1, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+			outputs: []outputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+		},
+	},
 	{
 		name:   "SRL",
 		argLen: 2,
@@ -32009,6 +32025,20 @@ var opcodeTable = [...]opInfo{
 		},
 	},
+	{
+		name:    "SRAIW",
+		auxType: auxInt64,
+		argLen:  1,
+		asm:     riscv.ASRAIW,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+			outputs: []outputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+		},
+	},
 	{
 		name:    "SRLI",
 		auxType: auxInt64,
diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
index 7eed0f1700..6009c41f2d 100644
--- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go
+++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
@@ -538,6 +538,8 @@ func rewriteValueRISCV64(v *Value) bool {
 		return rewriteValueRISCV64_OpRISCV64SRA(v)
 	case OpRISCV64SRAI:
 		return rewriteValueRISCV64_OpRISCV64SRAI(v)
+	case OpRISCV64SRAW:
+		return rewriteValueRISCV64_OpRISCV64SRAW(v)
 	case OpRISCV64SRL:
 		return rewriteValueRISCV64_OpRISCV64SRL(v)
 	case OpRISCV64SRLI:
@@ -6258,6 +6260,20 @@ func rewriteValueRISCV64_OpRISCV64SNEZ(v *Value) bool {
 func rewriteValueRISCV64_OpRISCV64SRA(v *Value) bool {
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
+	// match: (SRA (MOVWreg x) y)
+	// result: (SRAW x y)
+	for {
+		t := v.Type
+		if v_0.Op != OpRISCV64MOVWreg {
+			break
+		}
+		x := v_0.Args[0]
+		y := v_1
+		v.reset(OpRISCV64SRAW)
+		v.Type = t
+		v.AddArg2(x, y)
+		return true
+	}
 	// match: (SRA x (MOVDconst [val]))
 	// result: (SRAI [int64(val&63)] x)
 	for {
@@ -6275,6 +6291,21 @@ func rewriteValueRISCV64_OpRISCV64SRA(v *Value) bool {
 }
 func rewriteValueRISCV64_OpRISCV64SRAI(v *Value) bool {
 	v_0 := v.Args[0]
+	// match: (SRAI [x] (MOVWreg y))
+	// result: (SRAIW [int64(x&31)] y)
+	for {
+		t := v.Type
+		x := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVWreg {
+			break
+		}
+		y := v_0.Args[0]
+		v.reset(OpRISCV64SRAIW)
+		v.Type = t
+		v.AuxInt = int64ToAuxInt(int64(x & 31))
+		v.AddArg(y)
+		return true
+	}
 	// match: (SRAI [x] (MOVDconst [y]))
 	// result: (MOVDconst [int64(y) >> uint32(x)])
 	for {
@@ -6289,6 +6320,24 @@ func rewriteValueRISCV64_OpRISCV64SRAI(v *Value) bool {
 	}
 	return false
 }
+func rewriteValueRISCV64_OpRISCV64SRAW(v *Value) bool {
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	// match: (SRAW x (MOVDconst [val]))
+	// result: (SRAIW [int64(val&31)] x)
+	for {
+		x := v_0
+		if v_1.Op != OpRISCV64MOVDconst {
+			break
+		}
+		val := auxIntToInt64(v_1.AuxInt)
+		v.reset(OpRISCV64SRAIW)
+		v.AuxInt = int64ToAuxInt(int64(val & 31))
+		v.AddArg(x)
+		return true
+	}
+	return false
+}
 func rewriteValueRISCV64_OpRISCV64SRL(v *Value) bool {
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
@@ -7172,7 +7221,7 @@ func rewriteValueRISCV64_OpRsh32x16(v *Value) bool {
 	typ := &b.Func.Config.Types
 	// match: (Rsh32x16 x y)
 	// cond: !shiftIsBounded(v)
-	// result: (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y)))))
+	// result: (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt16to64 y)))))
 	for {
 		t := v.Type
 		x := v_0
@@ -7188,7 +7237,7 @@ func rewriteValueRISCV64_OpRsh32x16(v *Value) bool {
 		v2 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type)
 		v2.AuxInt = int64ToAuxInt(-1)
 		v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type)
-		v3.AuxInt = int64ToAuxInt(64)
+		v3.AuxInt = int64ToAuxInt(32)
 		v4 := b.NewValue0(v.Pos, OpZeroExt16to64, typ.UInt64)
 		v4.AddArg(y)
 		v3.AddArg(v4)
@@ -7221,7 +7270,7 @@ func rewriteValueRISCV64_OpRsh32x32(v *Value) bool {
 	typ := &b.Func.Config.Types
 	// match: (Rsh32x32 x y)
 	// cond: !shiftIsBounded(v)
-	// result: (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y)))))
+	// result: (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt32to64 y)))))
 	for {
 		t := v.Type
 		x := v_0
@@ -7237,7 +7286,7 @@ func rewriteValueRISCV64_OpRsh32x32(v *Value) bool {
 		v2 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type)
 		v2.AuxInt = int64ToAuxInt(-1)
 		v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type)
-		v3.AuxInt = int64ToAuxInt(64)
+		v3.AuxInt = int64ToAuxInt(32)
 		v4 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64)
 		v4.AddArg(y)
 		v3.AddArg(v4)
@@ -7270,7 +7319,7 @@ func rewriteValueRISCV64_OpRsh32x64(v *Value) bool {
 	typ := &b.Func.Config.Types
 	// match: (Rsh32x64 x y)
 	// cond: !shiftIsBounded(v)
-	// result: (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [64] y))))
+	// result: (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] y))))
 	for {
 		t := v.Type
 		x := v_0
@@ -7286,7 +7335,7 @@ func rewriteValueRISCV64_OpRsh32x64(v *Value) bool {
 		v2 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type)
 		v2.AuxInt = int64ToAuxInt(-1)
 		v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type)
-		v3.AuxInt = int64ToAuxInt(64)
+		v3.AuxInt = int64ToAuxInt(32)
 		v3.AddArg(y)
 		v2.AddArg(v3)
 		v1.AddArg2(y, v2)
@@ -7317,7 +7366,7 @@ func rewriteValueRISCV64_OpRsh32x8(v *Value) bool {
 	typ := &b.Func.Config.Types
 	// match: (Rsh32x8 x y)
 	// cond: !shiftIsBounded(v)
-	// result: (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt8to64 y)))))
+	// result: (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt8to64 y)))))
 	for {
 		t := v.Type
 		x := v_0
@@ -7333,7 +7382,7 @@ func rewriteValueRISCV64_OpRsh32x8(v *Value) bool {
 		v2 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type)
 		v2.AuxInt = int64ToAuxInt(-1)
 		v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type)
-		v3.AuxInt = int64ToAuxInt(64)
+		v3.AuxInt = int64ToAuxInt(32)
 		v4 := b.NewValue0(v.Pos, OpZeroExt8to64, typ.UInt64)
 		v4.AddArg(y)
 		v3.AddArg(v4)
diff --git a/test/codegen/shift.go b/test/codegen/shift.go
index bf8b633905..32cfaffae0 100644
--- a/test/codegen/shift.go
+++ b/test/codegen/shift.go
@@ -24,7 +24,7 @@ func rshConst64Ux64(v uint64) uint64 {
 
 func rshConst64x64(v int64) int64 {
 	// ppc64x:"SRAD"
-	// riscv64:"SRAI",-"OR",-"SLTIU"
+	// riscv64:"SRAI\t",-"OR",-"SLTIU"
 	return v >> uint64(33)
 }
 
@@ -42,7 +42,7 @@ func rshConst32Ux64(v uint32) uint32 {
 
 func rshConst32x64(v int32) int32 {
 	// ppc64x:"SRAW"
-	// riscv64:"SRAI",-"OR",-"SLTIU", -"MOVW"
+	// riscv64:"SRAIW",-"OR",-"SLTIU", -"MOVW"
 	return v >> uint64(29)
 }
 
@@ -60,7 +60,7 @@ func rshConst64Ux32(v uint64) uint64 {
 
 func rshConst64x32(v int64) int64 {
 	// ppc64x:"SRAD"
-	// riscv64:"SRAI",-"OR",-"SLTIU"
+	// riscv64:"SRAI\t",-"OR",-"SLTIU"
 	return v >> uint32(33)
 }
 
@@ -87,7 +87,7 @@ func rshMask64Ux64(v uint64, s uint64) uint64 {
 func rshMask64x64(v int64, s uint64) int64 {
 	// arm64:"ASR",-"AND",-"CSEL"
 	// ppc64x:"RLDICL",-"ORN",-"ISEL"
-	// riscv64:"SRA",-"OR",-"SLTIU"
+	// riscv64:"SRA\t",-"OR",-"SLTIU"
 	// s390x:-"RISBGZ",-"AND",-"LOCGR"
 	return v >> (s & 63)
 }
@@ -116,11 +116,16 @@ func rsh5Mask32Ux64(v uint32, s uint64) uint32 {
 func rshMask32x64(v int32, s uint64) int32 {
 	// arm64:"ASR",-"AND"
 	// ppc64x:"ISEL",-"ORN"
-	// riscv64:"SRA",-"OR",-"SLTIU"
+	// riscv64:"SRAW","OR","SLTIU"
 	// s390x:-"RISBGZ",-"AND",-"LOCGR"
 	return v >> (s & 63)
 }
 
+func rsh5Mask32x64(v int32, s uint64) int32 {
+	// riscv64:"SRAW",-"OR",-"SLTIU"
+	return v >> (s & 31)
+}
+
 func lshMask64x32(v int64, s uint32) int64 {
 	// arm64:"LSL",-"AND"
 	// ppc64x:"RLDICL",-"ORN"
@@ -140,7 +145,7 @@ func rshMask64Ux32(v uint64, s uint32) uint64 {
 func rshMask64x32(v int64, s uint32) int64 {
 	// arm64:"ASR",-"AND",-"CSEL"
 	// ppc64x:"RLDICL",-"ORN",-"ISEL"
-	// riscv64:"SRA",-"OR",-"SLTIU"
+	// riscv64:"SRA\t",-"OR",-"SLTIU"
 	// s390x:-"RISBGZ",-"AND",-"LOCGR"
 	return v >> (s & 63)
 }
@@ -161,7 +166,7 @@ func rshMask64Ux32Ext(v uint64, s int32) uint64 {
 
 func rshMask64x32Ext(v int64, s int32) int64 {
 	// ppc64x:"RLDICL",-"ORN",-"ISEL"
-	// riscv64:"SRA",-"OR",-"SLTIU"
+	// riscv64:"SRA\t",-"OR",-"SLTIU"
 	// s390x:-"RISBGZ",-"AND",-"LOCGR"
 	return v >> uint(s&63)
 }
@@ -222,7 +227,7 @@ func rshGuarded64U(v uint64, s uint) uint64 {
 
 func rshGuarded64(v int64, s uint) int64 {
 	if s < 64 {
-		// riscv64:"SRA",-"OR",-"SLTIU"
+		// riscv64:"SRA\t",-"OR",-"SLTIU"
 		// s390x:-"RISBGZ",-"AND",-"LOCGR"
 		// wasm:-"Select",-".*LtU"
 		// arm64:"ASR",-"CSEL"
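
Illustrative note (not part of the patch): the new SLTIU [32]/OR lowering and the SRA (MOVWreg x) y => (SRAW x y) rewrite rely on SRAW reading only the low 5 bits of the shift amount, so an out-of-range amount must first be forced to all ones. The standalone Go sketch below mimics that behaviour in plain code; the helper name emulateSRAW and the sample values are invented for illustration only.

package main

import "fmt"

// emulateSRAW mirrors the unbounded 32-bit arithmetic right shift lowering:
// OR (uint64(y < 32) - 1) into y, then shift by the low 5 bits only, the way
// SRAW does. For y >= 32 the OR produces all ones, so the effective shift
// amount becomes 31 and the result saturates to 0 or -1 depending on the
// sign of x.
func emulateSRAW(x int32, y uint64) int32 {
	lt32 := uint64(0)
	if y < 32 {
		lt32 = 1
	}
	y |= lt32 - 1        // all ones when y >= 32, unchanged otherwise
	return x >> (y & 31) // SRAW only considers the low 5 bits
}

func main() {
	fmt.Println(emulateSRAW(-8, 1))  // -4
	fmt.Println(emulateSRAW(-8, 40)) // -1: oversized shift saturates to the sign bit
	fmt.Println(emulateSRAW(8, 40))  // 0
}

This is also why the rshMask32x64 codegen test now expects OR and SLTIU: masking the shift amount with 63 does not bound it for a 32-bit operand, so the saturation sequence is still emitted, whereas masking with 31 (rsh5Mask32x64) makes the shift bounded and leaves a bare SRAW.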