diff --git a/src/cmd/compile/internal/riscv64/ssa.go b/src/cmd/compile/internal/riscv64/ssa.go
index 1100878794..63dfc04d40 100644
--- a/src/cmd/compile/internal/riscv64/ssa.go
+++ b/src/cmd/compile/internal/riscv64/ssa.go
@@ -278,7 +278,7 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
 		p.To.Type = obj.TYPE_REG
 		p.To.Reg = rd
 	case ssa.OpRISCV64ADD, ssa.OpRISCV64SUB, ssa.OpRISCV64SUBW, ssa.OpRISCV64XOR, ssa.OpRISCV64OR, ssa.OpRISCV64AND,
-		ssa.OpRISCV64SLL, ssa.OpRISCV64SRA, ssa.OpRISCV64SRL,
+		ssa.OpRISCV64SLL, ssa.OpRISCV64SRA, ssa.OpRISCV64SRL, ssa.OpRISCV64SRLW,
 		ssa.OpRISCV64SLT, ssa.OpRISCV64SLTU, ssa.OpRISCV64MUL, ssa.OpRISCV64MULW, ssa.OpRISCV64MULH,
 		ssa.OpRISCV64MULHU, ssa.OpRISCV64DIV, ssa.OpRISCV64DIVU, ssa.OpRISCV64DIVW,
 		ssa.OpRISCV64DIVUW, ssa.OpRISCV64REM, ssa.OpRISCV64REMU, ssa.OpRISCV64REMW,
@@ -356,7 +356,7 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
 		p.To.Type = obj.TYPE_REG
 		p.To.Reg = v.Reg()
 	case ssa.OpRISCV64ADDI, ssa.OpRISCV64ADDIW, ssa.OpRISCV64XORI, ssa.OpRISCV64ORI, ssa.OpRISCV64ANDI,
-		ssa.OpRISCV64SLLI, ssa.OpRISCV64SRAI, ssa.OpRISCV64SRLI, ssa.OpRISCV64SLTI,
+		ssa.OpRISCV64SLLI, ssa.OpRISCV64SRAI, ssa.OpRISCV64SRLI, ssa.OpRISCV64SRLIW, ssa.OpRISCV64SLTI,
 		ssa.OpRISCV64SLTIU:
 		p := s.Prog(v.Op.Asm())
 		p.From.Type = obj.TYPE_CONST
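For context on what the two new opcodes compute: per the RV64I spec, the W-form shifts use only the low five bits of the shift amount, operate on the low 32 bits of the source, and sign-extend their 32-bit result into the destination register. A small executable model (my illustration, not compiler code):

```go
package main

import "fmt"

// srlw models RV64I SRLW: shift the low 32 bits of x right by y&31,
// then sign-extend the 32-bit result to 64 bits.
func srlw(x, y uint64) uint64 {
	r := uint32(x) >> (y & 31)
	return uint64(int64(int32(r)))
}

// srliw models the immediate form; imm must be in the range 0-31.
func srliw(x uint64, imm uint) uint64 { return srlw(x, uint64(imm)) }

func main() {
	fmt.Printf("%#x\n", srlw(0xffffffff80000000, 0)) // 0xffffffff80000000: a zero shift still sign-extends
	fmt.Printf("%#x\n", srliw(0x80000000, 1))        // 0x40000000
}
```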
diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
index e498218c60..837c0e18f6 100644
--- a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
@@ -150,8 +150,9 @@
 (Lsh32x(64|32|16|8) x y) && shiftIsBounded(v) => (SLL x y)
 (Lsh64x(64|32|16|8) x y) && shiftIsBounded(v) => (SLL x y)
 
-// SRL only considers the bottom 6 bits of y. If y > 64, the result should
-// always be 0. See Lsh above for a detailed description.
+// SRL only considers the bottom 6 bits of y; similarly, SRLW only considers
+// the bottom 5 bits of y. Ensure that the result is always zero if the shift
+// is at least the width of the shifted value. See Lsh above for a detailed description.
 (Rsh8Ux8 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] (ZeroExt8to64 y))))
 (Rsh8Ux16 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] (ZeroExt16to64 y))))
 (Rsh8Ux32 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] (ZeroExt32to64 y))))
@@ -160,10 +161,10 @@
 (Rsh16Ux16 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] (ZeroExt16to64 y))))
 (Rsh16Ux32 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] (ZeroExt32to64 y))))
 (Rsh16Ux64 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] y)))
-(Rsh32Ux8 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [64] (ZeroExt8to64 y))))
-(Rsh32Ux16 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [64] (ZeroExt16to64 y))))
-(Rsh32Ux32 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [64] (ZeroExt32to64 y))))
-(Rsh32Ux64 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [64] y)))
+(Rsh32Ux8 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt8to64 y))))
+(Rsh32Ux16 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt16to64 y))))
+(Rsh32Ux32 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt32to64 y))))
+(Rsh32Ux64 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] y)))
 (Rsh64Ux8 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] (ZeroExt8to64 y))))
 (Rsh64Ux16 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] (ZeroExt16to64 y))))
 (Rsh64Ux32 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] (ZeroExt32to64 y))))
@@ -707,6 +708,10 @@
 // But for now, this is enough to get rid of lots of them.
 (MOVDnop (MOVDconst [c])) => (MOVDconst [c])
 
+// Avoid unnecessary zero extension when right shifting.
+(SRL <t> (MOVWUreg x) y) => (SRLW <t> x y)
+(SRLI <t> [x] (MOVWUreg y)) => (SRLIW <t> [int64(x&31)] y)
+
 // Fold constant into immediate instructions where possible.
 (ADD (MOVDconst [val]) x) && is32Bit(val) && !t.IsPtr() => (ADDI [val] x)
 (AND (MOVDconst [val]) x) && is32Bit(val) => (ANDI [val] x)
@@ -714,6 +719,7 @@
 (XOR (MOVDconst [val]) x) && is32Bit(val) => (XORI [val] x)
 (SLL x (MOVDconst [val])) => (SLLI [int64(val&63)] x)
 (SRL x (MOVDconst [val])) => (SRLI [int64(val&63)] x)
+(SRLW x (MOVDconst [val])) => (SRLIW [int64(val&31)] x)
 (SRA x (MOVDconst [val])) => (SRAI [int64(val&63)] x)
 (SLT x (MOVDconst [val])) && val >= -2048 && val <= 2047 => (SLTI [val] x)
 (SLTU x (MOVDconst [val])) && val >= -2048 && val <= 2047 => (SLTIU [val] x)
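The change from SLTIU [64] to SLTIU [32] in the Rsh32Ux rules is what makes the SRL-to-SRLW rewrite above safe: SRLW wraps shift amounts at 32, but the mask now zeroes the result for any shift of 32 or more, exactly as Go semantics require. A plain-Go sketch of the unbounded lowering after the SRLW rewrite fires (names are mine, for illustration):

```go
package main

// rsh32Model mirrors (AND (SRLW x y) (Neg32 (SLTIU [32] y))) in plain Go.
func rsh32Model(x uint32, y uint64) uint32 {
	shifted := x >> (y & 31) // SRLW: only the low 5 bits of y participate
	var mask uint32          // Neg32 (SLTIU [32] y): all ones if y < 32, else zero
	if y < 32 {
		mask = ^uint32(0)
	}
	return shifted & mask // AND
}

func main() {
	if rsh32Model(0xffffffff, 33) != 0 {
		panic("shifts of 32 or more must produce zero")
	}
	if rsh32Model(0xffffffff, 1) != 0x7fffffff {
		panic("in-range shifts behave normally")
	}
}
```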
diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
index 741769f036..47ba20a66b 100644
--- a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
@@ -207,12 +207,14 @@ func init() {
 		{name: "MOVDnop", argLength: 1, reg: regInfo{inputs: []regMask{gpMask}, outputs: []regMask{gpMask}}, resultInArg0: true}, // nop, return arg0 in same register
 
 		// Shift ops
-		{name: "SLL", argLength: 2, reg: gp21, asm: "SLL"},                 // arg0 << (aux1 & 63)
-		{name: "SRA", argLength: 2, reg: gp21, asm: "SRA"},                 // arg0 >> (aux1 & 63), signed
-		{name: "SRL", argLength: 2, reg: gp21, asm: "SRL"},                 // arg0 >> (aux1 & 63), unsigned
-		{name: "SLLI", argLength: 1, reg: gp11, asm: "SLLI", aux: "Int64"}, // arg0 << auxint, shift amount 0-63
-		{name: "SRAI", argLength: 1, reg: gp11, asm: "SRAI", aux: "Int64"}, // arg0 >> auxint, signed, shift amount 0-63
-		{name: "SRLI", argLength: 1, reg: gp11, asm: "SRLI", aux: "Int64"}, // arg0 >> auxint, unsigned, shift amount 0-63
+		{name: "SLL", argLength: 2, reg: gp21, asm: "SLL"},                   // arg0 << (arg1 & 63)
+		{name: "SRA", argLength: 2, reg: gp21, asm: "SRA"},                   // arg0 >> (arg1 & 63), signed
+		{name: "SRL", argLength: 2, reg: gp21, asm: "SRL"},                   // arg0 >> (arg1 & 63), unsigned
+		{name: "SRLW", argLength: 2, reg: gp21, asm: "SRLW"},                 // arg0 >> (arg1 & 31), unsigned
+		{name: "SLLI", argLength: 1, reg: gp11, asm: "SLLI", aux: "Int64"},   // arg0 << auxint, shift amount 0-63
+		{name: "SRAI", argLength: 1, reg: gp11, asm: "SRAI", aux: "Int64"},   // arg0 >> auxint, signed, shift amount 0-63
+		{name: "SRLI", argLength: 1, reg: gp11, asm: "SRLI", aux: "Int64"},   // arg0 >> auxint, unsigned, shift amount 0-63
+		{name: "SRLIW", argLength: 1, reg: gp11, asm: "SRLIW", aux: "Int64"}, // arg0 >> auxint, unsigned, shift amount 0-31
 
 		// Bitwise ops
 		{name: "XOR", argLength: 2, reg: gp21, asm: "XOR", commutative: true}, // arg0 ^ arg1
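Note the differing immediate ranges above: SRLI accepts 0-63 while SRLIW accepts 0-31, which is why the rules mask the converted immediate with x&31. This sketch (hypothetical helper names) checks the property the conversion relies on, namely that SRLI of a zero-extended word and SRLIW agree for every amount in 0-31:

```go
package main

// srli models a 64-bit logical right shift by an immediate (0-63).
func srli(x uint64, imm uint) uint64 { return x >> (imm & 63) }

// srliwLow32 models the low 32 bits of an SRLIW result for imm in 0-31;
// the instruction's sign extension does not affect these bits.
func srliwLow32(x uint32, imm uint) uint32 { return x >> (imm & 31) }

func main() {
	x := uint32(0xcafebabe)
	for imm := uint(0); imm < 32; imm++ {
		if uint32(srli(uint64(x), imm)) != srliwLow32(x, imm) {
			panic("mismatch")
		}
	}
}
```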
diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go
index 56404830eb..2e4b376cb0 100644
--- a/src/cmd/compile/internal/ssa/opGen.go
+++ b/src/cmd/compile/internal/ssa/opGen.go
@@ -2383,9 +2383,11 @@ const (
 	OpRISCV64SLL
 	OpRISCV64SRA
 	OpRISCV64SRL
+	OpRISCV64SRLW
 	OpRISCV64SLLI
 	OpRISCV64SRAI
 	OpRISCV64SRLI
+	OpRISCV64SRLIW
 	OpRISCV64XOR
 	OpRISCV64XORI
 	OpRISCV64OR
@@ -31968,6 +31970,20 @@ var opcodeTable = [...]opInfo{
 			},
 		},
 	},
+	{
+		name:   "SRLW",
+		argLen: 2,
+		asm:    riscv.ASRLW,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+				{1, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+			outputs: []outputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+		},
+	},
 	{
 		name:    "SLLI",
 		auxType: auxInt64,
@@ -32010,6 +32026,20 @@ var opcodeTable = [...]opInfo{
 			},
 		},
 	},
+	{
+		name:    "SRLIW",
+		auxType: auxInt64,
+		argLen:  1,
+		asm:     riscv.ASRLIW,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+			outputs: []outputInfo{
+				{0, 1006632944}, // X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 X19 X20 X21 X22 X23 X24 X25 X26 X28 X29 X30
+			},
+		},
+	},
 	{
 		name:   "XOR",
 		argLen: 2,
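The recurring 1006632944 in the generated regInfo entries is a bitmap over the allocatable integer registers. A small decoder (assuming, from the generated comments, that bit i corresponds to X(i+1)) reproduces the register list and shows that X27, which I believe the riscv64 port reserves as the g register, is excluded:

```go
package main

import "fmt"

func main() {
	const mask = 1006632944
	for i := 0; i < 31; i++ {
		if mask&(1<<i) != 0 {
			fmt.Printf("X%d ", i+1)
		}
	}
	fmt.Println() // X5 through X26, then X28 through X30
}
```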
diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
index 1ca03a58a9..02629de3ae 100644
--- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go
+++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
@@ -545,6 +545,8 @@ func rewriteValueRISCV64(v *Value) bool {
 		return rewriteValueRISCV64_OpRISCV64SRL(v)
 	case OpRISCV64SRLI:
 		return rewriteValueRISCV64_OpRISCV64SRLI(v)
+	case OpRISCV64SRLW:
+		return rewriteValueRISCV64_OpRISCV64SRLW(v)
 	case OpRISCV64SUB:
 		return rewriteValueRISCV64_OpRISCV64SUB(v)
 	case OpRISCV64SUBW:
@@ -6293,6 +6295,20 @@ func rewriteValueRISCV64_OpRISCV64SRAI(v *Value) bool {
 func rewriteValueRISCV64_OpRISCV64SRL(v *Value) bool {
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
+	// match: (SRL <t> (MOVWUreg x) y)
+	// result: (SRLW <t> x y)
+	for {
+		t := v.Type
+		if v_0.Op != OpRISCV64MOVWUreg {
+			break
+		}
+		x := v_0.Args[0]
+		y := v_1
+		v.reset(OpRISCV64SRLW)
+		v.Type = t
+		v.AddArg2(x, y)
+		return true
+	}
 	// match: (SRL x (MOVDconst [val]))
 	// result: (SRLI [int64(val&63)] x)
 	for {
@@ -6310,6 +6326,21 @@ func rewriteValueRISCV64_OpRISCV64SRL(v *Value) bool {
 }
 func rewriteValueRISCV64_OpRISCV64SRLI(v *Value) bool {
 	v_0 := v.Args[0]
+	// match: (SRLI <t> [x] (MOVWUreg y))
+	// result: (SRLIW <t> [int64(x&31)] y)
+	for {
+		t := v.Type
+		x := auxIntToInt64(v.AuxInt)
+		if v_0.Op != OpRISCV64MOVWUreg {
+			break
+		}
+		y := v_0.Args[0]
+		v.reset(OpRISCV64SRLIW)
+		v.Type = t
+		v.AuxInt = int64ToAuxInt(int64(x & 31))
+		v.AddArg(y)
+		return true
+	}
 	// match: (SRLI [x] (MOVDconst [y]))
 	// result: (MOVDconst [int64(uint64(y) >> uint32(x))])
 	for {
@@ -6324,6 +6355,24 @@ func rewriteValueRISCV64_OpRISCV64SRLI(v *Value) bool {
 	}
 	return false
 }
+func rewriteValueRISCV64_OpRISCV64SRLW(v *Value) bool {
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	// match: (SRLW x (MOVDconst [val]))
+	// result: (SRLIW [int64(val&31)] x)
+	for {
+		x := v_0
+		if v_1.Op != OpRISCV64MOVDconst {
+			break
+		}
+		val := auxIntToInt64(v_1.AuxInt)
+		v.reset(OpRISCV64SRLIW)
+		v.AuxInt = int64ToAuxInt(int64(val & 31))
+		v.AddArg(x)
+		return true
+	}
+	return false
+}
 func rewriteValueRISCV64_OpRISCV64SUB(v *Value) bool {
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
@@ -6940,7 +6989,7 @@ func rewriteValueRISCV64_OpRsh32Ux16(v *Value) bool {
 	typ := &b.Func.Config.Types
 	// match: (Rsh32Ux16 x y)
 	// cond: !shiftIsBounded(v)
-	// result: (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [64] (ZeroExt16to64 y))))
+	// result: (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt16to64 y))))
 	for {
 		t := v.Type
 		x := v_0
@@ -6955,7 +7004,7 @@ func rewriteValueRISCV64_OpRsh32Ux16(v *Value) bool {
 		v0.AddArg2(v1, y)
 		v2 := b.NewValue0(v.Pos, OpNeg32, t)
 		v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t)
-		v3.AuxInt = int64ToAuxInt(64)
+		v3.AuxInt = int64ToAuxInt(32)
 		v4 := b.NewValue0(v.Pos, OpZeroExt16to64, typ.UInt64)
 		v4.AddArg(y)
 		v3.AddArg(v4)
@@ -6987,7 +7036,7 @@ func rewriteValueRISCV64_OpRsh32Ux32(v *Value) bool {
 	typ := &b.Func.Config.Types
 	// match: (Rsh32Ux32 x y)
 	// cond: !shiftIsBounded(v)
-	// result: (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [64] (ZeroExt32to64 y))))
+	// result: (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt32to64 y))))
 	for {
 		t := v.Type
 		x := v_0
@@ -7002,7 +7051,7 @@ func rewriteValueRISCV64_OpRsh32Ux32(v *Value) bool {
 		v0.AddArg2(v1, y)
 		v2 := b.NewValue0(v.Pos, OpNeg32, t)
 		v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t)
-		v3.AuxInt = int64ToAuxInt(64)
+		v3.AuxInt = int64ToAuxInt(32)
 		v4 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64)
 		v4.AddArg(y)
 		v3.AddArg(v4)
@@ -7034,7 +7083,7 @@ func rewriteValueRISCV64_OpRsh32Ux64(v *Value) bool {
 	typ := &b.Func.Config.Types
 	// match: (Rsh32Ux64 x y)
 	// cond: !shiftIsBounded(v)
-	// result: (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [64] y)))
+	// result: (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] y)))
 	for {
 		t := v.Type
 		x := v_0
@@ -7049,7 +7098,7 @@ func rewriteValueRISCV64_OpRsh32Ux64(v *Value) bool {
 		v0.AddArg2(v1, y)
 		v2 := b.NewValue0(v.Pos, OpNeg32, t)
 		v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t)
-		v3.AuxInt = int64ToAuxInt(64)
+		v3.AuxInt = int64ToAuxInt(32)
 		v3.AddArg(y)
 		v2.AddArg(v3)
 		v.AddArg2(v0, v2)
@@ -7079,7 +7128,7 @@ func rewriteValueRISCV64_OpRsh32Ux8(v *Value) bool {
 	typ := &b.Func.Config.Types
 	// match: (Rsh32Ux8 x y)
 	// cond: !shiftIsBounded(v)
-	// result: (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [64] (ZeroExt8to64 y))))
+	// result: (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt8to64 y))))
 	for {
 		t := v.Type
 		x := v_0
@@ -7094,7 +7143,7 @@ func rewriteValueRISCV64_OpRsh32Ux8(v *Value) bool {
 		v0.AddArg2(v1, y)
 		v2 := b.NewValue0(v.Pos, OpNeg32, t)
 		v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t)
-		v3.AuxInt = int64ToAuxInt(64)
+		v3.AuxInt = int64ToAuxInt(32)
 		v4 := b.NewValue0(v.Pos, OpZeroExt8to64, typ.UInt64)
 		v4.AddArg(y)
 		v3.AddArg(v4)
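Taken together, these rewrites collapse a zero-extended constant shift into a single instruction. My reading of the chain for the function below (an illustration, not compiler output): Rsh32Ux64 lowers to an SRL of the zero-extended value, the constant shift folds SRL into SRLI, and the new SRLI-of-MOVWUreg rule (or the SRLW path, whichever fires first) leaves one SRLIW with immediate 29:

```go
// shr29 matches the rshConst32Ux64 codegen test below; on riscv64 the
// expected output is a lone SRLIW, with no MOVW zero extension.
func shr29(v uint32) uint32 {
	return v >> 29
}
```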
s390x:-"RISBGZ",-"AND",-"LOCGR" return v >> (s & 63) } @@ -149,7 +154,7 @@ func lshMask64x32Ext(v int64, s int32) int64 { func rshMask64Ux32Ext(v uint64, s int32) uint64 { // ppc64x:"ANDCC",-"ORN",-"ISEL" - // riscv64:"SRL",-"AND\t",-"SLTIU" + // riscv64:"SRL\t",-"AND\t",-"SLTIU" // s390x:-"RISBGZ",-"AND",-"LOCGR" return v >> uint(s&63) } @@ -206,7 +211,7 @@ func lshGuarded64(v int64, s uint) int64 { func rshGuarded64U(v uint64, s uint) uint64 { if s < 64 { - // riscv64:"SRL",-"AND",-"SLTIU" + // riscv64:"SRL\t",-"AND",-"SLTIU" // s390x:-"RISBGZ",-"AND",-"LOCGR" // wasm:-"Select",-".*LtU" // arm64:"LSR",-"CSEL"