diff --git a/src/math/big/nat.go b/src/math/big/nat.go index 436c108c96..3de32d27e9 100644 --- a/src/math/big/nat.go +++ b/src/math/big/nat.go @@ -728,8 +728,21 @@ func (x nat) trailingZeroBits() uint { return i*_W + uint(bits.TrailingZeros(uint(x[i]))) } +func same(x, y nat) bool { + return len(x) == len(y) && len(x) > 0 && &x[0] == &y[0] +} + // z = x << s func (z nat) shl(x nat, s uint) nat { + if s == 0 { + if same(z, x) { + return z + } + if !alias(z, x) { + return z.set(x) + } + } + m := len(x) if m == 0 { return z[:0] @@ -746,6 +759,15 @@ func (z nat) shl(x nat, s uint) nat { // z = x >> s func (z nat) shr(x nat, s uint) nat { + if s == 0 { + if same(z, x) { + return z + } + if !alias(z, x) { + return z.set(x) + } + } + m := len(x) n := m - int(s/_W) if n <= 0 { diff --git a/src/math/big/nat_test.go b/src/math/big/nat_test.go index 9bb96b1157..0b94db3476 100644 --- a/src/math/big/nat_test.go +++ b/src/math/big/nat_test.go @@ -267,6 +267,34 @@ func TestShiftRight(t *testing.T) { } } +func BenchmarkZeroShifts(b *testing.B) { + x := rndNat(800) + + b.Run("Shl", func(b *testing.B) { + for i := 0; i < b.N; i++ { + var z nat + z.shl(x, 0) + } + }) + b.Run("ShlSame", func(b *testing.B) { + for i := 0; i < b.N; i++ { + x.shl(x, 0) + } + }) + + b.Run("Shr", func(b *testing.B) { + for i := 0; i < b.N; i++ { + var z nat + z.shr(x, 0) + } + }) + b.Run("ShrSame", func(b *testing.B) { + for i := 0; i < b.N; i++ { + x.shr(x, 0) + } + }) +} + type modWTest struct { in string dividend string