mirror of https://github.com/golang/go.git
crypto/internal/fips/bigmod: add support for even moduli
It doesn't need to be fast because we will only use it for RSA key generation / precomputation / validation. Change-Id: If4f5d0d4ac350939b69561e75dec5791db77f68c Reviewed-on: https://go-review.googlesource.com/c/go/+/630515 Reviewed-by: Russ Cox <rsc@golang.org> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com> Auto-Submit: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
parent
10fb001c75
commit
0598229d97
|
|
@ -297,19 +297,20 @@ func (x *Nat) sub(y *Nat) (c uint) {
|
|||
|
||||
// Modulus is used for modular arithmetic, precomputing relevant constants.
|
||||
//
|
||||
// Moduli are assumed to be odd numbers. Moduli can also leak the exact
|
||||
// number of bits needed to store their value, and are stored without padding.
|
||||
//
|
||||
// Their actual value is still kept secret.
|
||||
// A Modulus can leak the exact number of bits needed to store its value
|
||||
// and is stored without padding. Its actual value is still kept secret.
|
||||
type Modulus struct {
|
||||
// The underlying natural number for this modulus.
|
||||
//
|
||||
// This will be stored without any padding, and shouldn't alias with any
|
||||
// other natural number being used.
|
||||
nat *Nat
|
||||
leading int // number of leading zeros in the modulus
|
||||
m0inv uint // -nat.limbs[0]⁻¹ mod _W
|
||||
rr *Nat // R*R for montgomeryRepresentation
|
||||
leading int // number of leading zeros in the modulus
|
||||
|
||||
// If m is even, the following fields are not set.
|
||||
odd bool
|
||||
m0inv uint // -nat.limbs[0]⁻¹ mod _W
|
||||
rr *Nat // R*R for montgomeryRepresentation
|
||||
}
|
||||
|
||||
// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
|
||||
|
|
@ -380,17 +381,20 @@ func minusInverseModW(x uint) uint {
|
|||
|
||||
// NewModulus creates a new Modulus from a slice of big-endian bytes.
|
||||
//
|
||||
// The value must be odd. The number of significant bits (and nothing else) is
|
||||
// leaked through timing side-channels.
|
||||
// The number of significant bits and whether the modulus is even is leaked
|
||||
// through timing side-channels.
|
||||
func NewModulus(b []byte) (*Modulus, error) {
|
||||
if len(b) == 0 || b[len(b)-1]&1 != 1 {
|
||||
return nil, errors.New("modulus must be > 0 and odd")
|
||||
}
|
||||
m := &Modulus{}
|
||||
m.nat = NewNat().resetToBytes(b)
|
||||
if len(m.nat.limbs) == 0 {
|
||||
return nil, errors.New("modulus must be > 0")
|
||||
}
|
||||
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
|
||||
m.m0inv = minusInverseModW(m.nat.limbs[0])
|
||||
m.rr = rr(m)
|
||||
if m.nat.limbs[0]&1 == 1 {
|
||||
m.odd = true
|
||||
m.m0inv = minusInverseModW(m.nat.limbs[0])
|
||||
m.rr = rr(m)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
|
|
@ -719,17 +723,71 @@ func addMulVVW(z, x []uint, y uint) (carry uint) {
|
|||
// The length of both operands must be the same as the modulus. Both operands
|
||||
// must already be reduced modulo m.
|
||||
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
|
||||
// A Montgomery multiplication by a value out of the Montgomery domain
|
||||
// takes the result out of Montgomery representation.
|
||||
xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
|
||||
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
|
||||
if m.odd {
|
||||
// A Montgomery multiplication by a value out of the Montgomery domain
|
||||
// takes the result out of Montgomery representation.
|
||||
xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
|
||||
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
|
||||
}
|
||||
|
||||
n := len(m.nat.limbs)
|
||||
xLimbs := x.limbs[:n]
|
||||
yLimbs := y.limbs[:n]
|
||||
|
||||
switch n {
|
||||
default:
|
||||
// Attempt to use a stack-allocated backing array.
|
||||
T := make([]uint, 0, preallocLimbs*2)
|
||||
if cap(T) < n*2 {
|
||||
T = make([]uint, 0, n*2)
|
||||
}
|
||||
T = T[:n*2]
|
||||
|
||||
// T = x * y
|
||||
for i := 0; i < n; i++ {
|
||||
T[n+i] = addMulVVW(T[i:n+i], xLimbs, yLimbs[i])
|
||||
}
|
||||
|
||||
// x = T mod m
|
||||
return x.Mod(&Nat{limbs: T}, m)
|
||||
|
||||
// The following specialized cases follow the exact same algorithm, but
|
||||
// optimized for the sizes most used in RSA. See montgomeryMul for details.
|
||||
case 1024 / _W:
|
||||
const n = 1024 / _W // compiler hint
|
||||
T := make([]uint, n*2)
|
||||
for i := 0; i < n; i++ {
|
||||
T[n+i] = addMulVVW1024(&T[i], &xLimbs[0], yLimbs[i])
|
||||
}
|
||||
return x.Mod(&Nat{limbs: T}, m)
|
||||
case 1536 / _W:
|
||||
const n = 1536 / _W // compiler hint
|
||||
T := make([]uint, n*2)
|
||||
for i := 0; i < n; i++ {
|
||||
T[n+i] = addMulVVW1536(&T[i], &xLimbs[0], yLimbs[i])
|
||||
}
|
||||
return x.Mod(&Nat{limbs: T}, m)
|
||||
case 2048 / _W:
|
||||
const n = 2048 / _W // compiler hint
|
||||
T := make([]uint, n*2)
|
||||
for i := 0; i < n; i++ {
|
||||
T[n+i] = addMulVVW2048(&T[i], &xLimbs[0], yLimbs[i])
|
||||
}
|
||||
return x.Mod(&Nat{limbs: T}, m)
|
||||
}
|
||||
}
|
||||
|
||||
// Exp calculates out = x^e mod m.
|
||||
//
|
||||
// The exponent e is represented in big-endian order. The output will be resized
|
||||
// to the size of m and overwritten. x must already be reduced modulo m.
|
||||
//
|
||||
// m must be odd, or Exp will panic.
|
||||
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
|
||||
if !m.odd {
|
||||
panic("bigmod: modulus for Exp must be odd")
|
||||
}
|
||||
|
||||
// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
|
||||
// than 2 bit windows, but use an extra 12 nats worth of scratch space.
|
||||
// Using bit sizes that don't divide 8 are more complex to implement, but
|
||||
|
|
@ -778,7 +836,12 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
|
|||
//
|
||||
// The output will be resized to the size of m and overwritten. x must already
|
||||
// be reduced modulo m. This leaks the exponent through timing side-channels.
|
||||
//
|
||||
// m must be odd, or ExpShortVarTime will panic.
|
||||
func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
|
||||
if !m.odd {
|
||||
panic("bigmod: modulus for ExpShortVarTime must be odd")
|
||||
}
|
||||
// For short exponents, precomputing a table and using a window like in Exp
|
||||
// doesn't pay off. Instead, we do a simple conditional square-and-multiply
|
||||
// chain, skipping the initial run of zeroes.
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
package bigmod
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
cryptorand "crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"math/bits"
|
||||
|
|
@ -352,6 +354,56 @@ func TestMulReductions(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestMul(t *testing.T) {
|
||||
t.Run("small", func(t *testing.T) { testMul(t, 760/8) })
|
||||
t.Run("1024", func(t *testing.T) { testMul(t, 1024/8) })
|
||||
t.Run("1536", func(t *testing.T) { testMul(t, 1536/8) })
|
||||
t.Run("2048", func(t *testing.T) { testMul(t, 2048/8) })
|
||||
}
|
||||
|
||||
func testMul(t *testing.T, n int) {
|
||||
a, b, m := make([]byte, n), make([]byte, n), make([]byte, n)
|
||||
cryptorand.Read(a)
|
||||
cryptorand.Read(b)
|
||||
cryptorand.Read(m)
|
||||
|
||||
// Pick the highest as the modulus.
|
||||
if bytes.Compare(a, m) > 0 {
|
||||
a, m = m, a
|
||||
}
|
||||
if bytes.Compare(b, m) > 0 {
|
||||
b, m = m, b
|
||||
}
|
||||
|
||||
M, err := NewModulus(m)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
A, err := NewNat().SetBytes(a, M)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
B, err := NewNat().SetBytes(b, M)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
A.Mul(B, M)
|
||||
ABytes := A.Bytes(M)
|
||||
|
||||
mBig := new(big.Int).SetBytes(m)
|
||||
aBig := new(big.Int).SetBytes(a)
|
||||
bBig := new(big.Int).SetBytes(b)
|
||||
nBig := new(big.Int).Mul(aBig, bBig)
|
||||
nBig.Mod(nBig, mBig)
|
||||
nBigBytes := make([]byte, len(ABytes))
|
||||
nBig.FillBytes(nBigBytes)
|
||||
|
||||
if !bytes.Equal(ABytes, nBigBytes) {
|
||||
t.Errorf("got %x, want %x", ABytes, nBigBytes)
|
||||
}
|
||||
}
|
||||
|
||||
func natBytes(n *Nat) []byte {
|
||||
return n.Bytes(maxModulus(uint(len(n.limbs))))
|
||||
}
|
||||
|
|
@ -480,7 +532,7 @@ func BenchmarkExp(b *testing.B) {
|
|||
}
|
||||
|
||||
func TestNewModulus(t *testing.T) {
|
||||
expected := "modulus must be > 0 and odd"
|
||||
expected := "modulus must be > 0"
|
||||
_, err := NewModulus([]byte{})
|
||||
if err == nil || err.Error() != expected {
|
||||
t.Errorf("NewModulus(0) got %q, want %q", err, expected)
|
||||
|
|
@ -493,10 +545,6 @@ func TestNewModulus(t *testing.T) {
|
|||
if err == nil || err.Error() != expected {
|
||||
t.Errorf("NewModulus(0) got %q, want %q", err, expected)
|
||||
}
|
||||
_, err = NewModulus([]byte{1, 1, 1, 1, 2})
|
||||
if err == nil || err.Error() != expected {
|
||||
t.Errorf("NewModulus(2) got %q, want %q", err, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func makeTestValue(nbits int) []uint {
|
||||
|
|
|
|||
Loading…
Reference in New Issue