crypto/internal/fips/bigmod: add support for even moduli

It doesn't need to be fast because we will only use it for RSA key
generation / precomputation / validation.

Change-Id: If4f5d0d4ac350939b69561e75dec5791db77f68c
Reviewed-on: https://go-review.googlesource.com/c/go/+/630515
Reviewed-by: Russ Cox <rsc@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
Filippo Valsorda 2024-11-19 12:57:55 +01:00 committed by Gopher Robot
parent 10fb001c75
commit 0598229d97
2 changed files with 134 additions and 23 deletions

View File

@ -297,19 +297,20 @@ func (x *Nat) sub(y *Nat) (c uint) {
// Modulus is used for modular arithmetic, precomputing relevant constants.
//
// Moduli are assumed to be odd numbers. Moduli can also leak the exact
// number of bits needed to store their value, and are stored without padding.
//
// Their actual value is still kept secret.
// A Modulus can leak the exact number of bits needed to store its value
// and is stored without padding. Its actual value is still kept secret.
type Modulus struct {
// The underlying natural number for this modulus.
//
// This will be stored without any padding, and shouldn't alias with any
// other natural number being used.
nat *Nat
leading int // number of leading zeros in the modulus
m0inv uint // -nat.limbs[0]⁻¹ mod _W
rr *Nat // R*R for montgomeryRepresentation
leading int // number of leading zeros in the modulus
// If m is even, the following fields are not set.
odd bool
m0inv uint // -nat.limbs[0]⁻¹ mod _W
rr *Nat // R*R for montgomeryRepresentation
}
// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
@ -380,17 +381,20 @@ func minusInverseModW(x uint) uint {
// NewModulus creates a new Modulus from a slice of big-endian bytes.
//
// The value must be odd. The number of significant bits (and nothing else) is
// leaked through timing side-channels.
// The number of significant bits and whether the modulus is even is leaked
// through timing side-channels.
func NewModulus(b []byte) (*Modulus, error) {
if len(b) == 0 || b[len(b)-1]&1 != 1 {
return nil, errors.New("modulus must be > 0 and odd")
}
m := &Modulus{}
m.nat = NewNat().resetToBytes(b)
if len(m.nat.limbs) == 0 {
return nil, errors.New("modulus must be > 0")
}
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
m.m0inv = minusInverseModW(m.nat.limbs[0])
m.rr = rr(m)
if m.nat.limbs[0]&1 == 1 {
m.odd = true
m.m0inv = minusInverseModW(m.nat.limbs[0])
m.rr = rr(m)
}
return m, nil
}
@ -719,17 +723,71 @@ func addMulVVW(z, x []uint, y uint) (carry uint) {
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
// A Montgomery multiplication by a value out of the Montgomery domain
// takes the result out of Montgomery representation.
xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
if m.odd {
// A Montgomery multiplication by a value out of the Montgomery domain
// takes the result out of Montgomery representation.
xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
}
n := len(m.nat.limbs)
xLimbs := x.limbs[:n]
yLimbs := y.limbs[:n]
switch n {
default:
// Attempt to use a stack-allocated backing array.
T := make([]uint, 0, preallocLimbs*2)
if cap(T) < n*2 {
T = make([]uint, 0, n*2)
}
T = T[:n*2]
// T = x * y
for i := 0; i < n; i++ {
T[n+i] = addMulVVW(T[i:n+i], xLimbs, yLimbs[i])
}
// x = T mod m
return x.Mod(&Nat{limbs: T}, m)
// The following specialized cases follow the exact same algorithm, but
// optimized for the sizes most used in RSA. See montgomeryMul for details.
case 1024 / _W:
const n = 1024 / _W // compiler hint
T := make([]uint, n*2)
for i := 0; i < n; i++ {
T[n+i] = addMulVVW1024(&T[i], &xLimbs[0], yLimbs[i])
}
return x.Mod(&Nat{limbs: T}, m)
case 1536 / _W:
const n = 1536 / _W // compiler hint
T := make([]uint, n*2)
for i := 0; i < n; i++ {
T[n+i] = addMulVVW1536(&T[i], &xLimbs[0], yLimbs[i])
}
return x.Mod(&Nat{limbs: T}, m)
case 2048 / _W:
const n = 2048 / _W // compiler hint
T := make([]uint, n*2)
for i := 0; i < n; i++ {
T[n+i] = addMulVVW2048(&T[i], &xLimbs[0], yLimbs[i])
}
return x.Mod(&Nat{limbs: T}, m)
}
}
// Exp calculates out = x^e mod m.
//
// The exponent e is represented in big-endian order. The output will be resized
// to the size of m and overwritten. x must already be reduced modulo m.
//
// m must be odd, or Exp will panic.
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
if !m.odd {
panic("bigmod: modulus for Exp must be odd")
}
// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
// than 2 bit windows, but use an extra 12 nats worth of scratch space.
// Using bit sizes that don't divide 8 are more complex to implement, but
@ -778,7 +836,12 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
//
// The output will be resized to the size of m and overwritten. x must already
// be reduced modulo m. This leaks the exponent through timing side-channels.
//
// m must be odd, or ExpShortVarTime will panic.
func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
if !m.odd {
panic("bigmod: modulus for ExpShortVarTime must be odd")
}
// For short exponents, precomputing a table and using a window like in Exp
// doesn't pay off. Instead, we do a simple conditional square-and-multiply
// chain, skipping the initial run of zeroes.

View File

@ -5,6 +5,8 @@
package bigmod
import (
"bytes"
cryptorand "crypto/rand"
"fmt"
"math/big"
"math/bits"
@ -352,6 +354,56 @@ func TestMulReductions(t *testing.T) {
}
}
func TestMul(t *testing.T) {
t.Run("small", func(t *testing.T) { testMul(t, 760/8) })
t.Run("1024", func(t *testing.T) { testMul(t, 1024/8) })
t.Run("1536", func(t *testing.T) { testMul(t, 1536/8) })
t.Run("2048", func(t *testing.T) { testMul(t, 2048/8) })
}
func testMul(t *testing.T, n int) {
a, b, m := make([]byte, n), make([]byte, n), make([]byte, n)
cryptorand.Read(a)
cryptorand.Read(b)
cryptorand.Read(m)
// Pick the highest as the modulus.
if bytes.Compare(a, m) > 0 {
a, m = m, a
}
if bytes.Compare(b, m) > 0 {
b, m = m, b
}
M, err := NewModulus(m)
if err != nil {
t.Fatal(err)
}
A, err := NewNat().SetBytes(a, M)
if err != nil {
t.Fatal(err)
}
B, err := NewNat().SetBytes(b, M)
if err != nil {
t.Fatal(err)
}
A.Mul(B, M)
ABytes := A.Bytes(M)
mBig := new(big.Int).SetBytes(m)
aBig := new(big.Int).SetBytes(a)
bBig := new(big.Int).SetBytes(b)
nBig := new(big.Int).Mul(aBig, bBig)
nBig.Mod(nBig, mBig)
nBigBytes := make([]byte, len(ABytes))
nBig.FillBytes(nBigBytes)
if !bytes.Equal(ABytes, nBigBytes) {
t.Errorf("got %x, want %x", ABytes, nBigBytes)
}
}
func natBytes(n *Nat) []byte {
return n.Bytes(maxModulus(uint(len(n.limbs))))
}
@ -480,7 +532,7 @@ func BenchmarkExp(b *testing.B) {
}
func TestNewModulus(t *testing.T) {
expected := "modulus must be > 0 and odd"
expected := "modulus must be > 0"
_, err := NewModulus([]byte{})
if err == nil || err.Error() != expected {
t.Errorf("NewModulus(0) got %q, want %q", err, expected)
@ -493,10 +545,6 @@ func TestNewModulus(t *testing.T) {
if err == nil || err.Error() != expected {
t.Errorf("NewModulus(0) got %q, want %q", err, expected)
}
_, err = NewModulus([]byte{1, 1, 1, 1, 2})
if err == nil || err.Error() != expected {
t.Errorf("NewModulus(2) got %q, want %q", err, expected)
}
}
func makeTestValue(nbits int) []uint {