diff --git a/src/math/big/arith_test.go b/src/math/big/arith_test.go index 64225bbd53..feffa1bc95 100644 --- a/src/math/big/arith_test.go +++ b/src/math/big/arith_test.go @@ -368,9 +368,12 @@ func TestShiftOverlap(t *testing.T) { } func TestIssue31084(t *testing.T) { + stk := getStack() + defer stk.free() + // compute 10^n via 5^n << n. const n = 165 - p := nat(nil).expNN(nat{5}, nat{n}, nil, false) + p := nat(nil).expNN(stk, nat{5}, nat{n}, nil, false) p = p.shl(p, n) got := string(p.utoa(10)) want := "1" + strings.Repeat("0", n) diff --git a/src/math/big/float.go b/src/math/big/float.go index e1d20d8bb4..2c5234a4ce 100644 --- a/src/math/big/float.go +++ b/src/math/big/float.go @@ -1327,9 +1327,9 @@ func (z *Float) umul(x, y *Float) { e := int64(x.exp) + int64(y.exp) if x == y { - z.mant = z.mant.sqr(x.mant) + z.mant = z.mant.sqr(nil, x.mant) } else { - z.mant = z.mant.mul(x.mant, y.mant) + z.mant = z.mant.mul(nil, x.mant, y.mant) } z.setExpAndRound(e-fnorm(z.mant), 0) } @@ -1363,8 +1363,10 @@ func (z *Float) uquo(x, y *Float) { d := len(xadj) - len(y.mant) // divide + stk := getStack() + defer stk.free() var r nat - z.mant, r = z.mant.div(nil, xadj, y.mant) + z.mant, r = z.mant.div(stk, nil, xadj, y.mant) e := int64(x.exp) - int64(y.exp) - int64(d-len(z.mant))*_W // The result is long enough to include (at least) the rounding bit. diff --git a/src/math/big/int.go b/src/math/big/int.go index 0b710c6968..cb7221250d 100644 --- a/src/math/big/int.go +++ b/src/math/big/int.go @@ -181,16 +181,20 @@ func (z *Int) Sub(x, y *Int) *Int { // Mul sets z to the product x*y and returns z. func (z *Int) Mul(x, y *Int) *Int { + return z.mul(nil, x, y) +} + +func (z *Int) mul(stk *stack, x, y *Int) *Int { // x * y == x * y // x * (-y) == -(x * y) // (-x) * y == -(x * y) // (-x) * (-y) == x * y if x == y { - z.abs = z.abs.sqr(x.abs) + z.abs = z.abs.sqr(stk, x.abs) z.neg = false return z } - z.abs = z.abs.mul(x.abs, y.abs) + z.abs = z.abs.mul(stk, x.abs, y.abs) z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign return z } @@ -213,7 +217,7 @@ func (z *Int) MulRange(a, b int64) *Int { a, b = -b, -a } - z.abs = z.abs.mulRange(uint64(a), uint64(b)) + z.abs = z.abs.mulRange(nil, uint64(a), uint64(b)) z.neg = neg return z } @@ -264,7 +268,7 @@ func (z *Int) Binomial(n, k int64) *Int { // If y == 0, a division-by-zero run-time panic occurs. // Quo implements truncated division (like Go); see [Int.QuoRem] for more details. func (z *Int) Quo(x, y *Int) *Int { - z.abs, _ = z.abs.div(nil, x.abs, y.abs) + z.abs, _ = z.abs.div(nil, nil, x.abs, y.abs) z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign return z } @@ -273,7 +277,7 @@ func (z *Int) Quo(x, y *Int) *Int { // If y == 0, a division-by-zero run-time panic occurs. // Rem implements truncated modulus (like Go); see [Int.QuoRem] for more details. func (z *Int) Rem(x, y *Int) *Int { - _, z.abs = nat(nil).div(z.abs, x.abs, y.abs) + _, z.abs = nat(nil).div(nil, z.abs, x.abs, y.abs) z.neg = len(z.abs) > 0 && x.neg // 0 has no sign return z } @@ -290,7 +294,7 @@ func (z *Int) Rem(x, y *Int) *Int { // (See Daan Leijen, “Division and Modulus for Computer Scientists”.) // See [Int.DivMod] for Euclidean division and modulus (unlike Go). func (z *Int) QuoRem(x, y, r *Int) (*Int, *Int) { - z.abs, r.abs = z.abs.div(r.abs, x.abs, y.abs) + z.abs, r.abs = z.abs.div(nil, r.abs, x.abs, y.abs) z.neg, r.neg = len(z.abs) > 0 && x.neg != y.neg, len(r.abs) > 0 && x.neg // 0 has no sign return z, r } @@ -589,7 +593,7 @@ func (z *Int) exp(x, y, m *Int, slow bool) *Int { mWords = m.abs // m.abs may be nil for m == 0 } - z.abs = z.abs.expNN(xWords, yWords, mWords, slow) + z.abs = z.abs.expNN(nil, xWords, yWords, mWords, slow) z.neg = len(z.abs) > 0 && x.neg && len(yWords) > 0 && yWords[0]&1 == 1 // 0 has no sign if z.neg && len(mWords) > 0 { // make modulus result positive @@ -1298,6 +1302,6 @@ func (z *Int) Sqrt(x *Int) *Int { panic("square root of negative number") } z.neg = false - z.abs = z.abs.sqrt(x.abs) + z.abs = z.abs.sqrt(nil, x.abs) return z } diff --git a/src/math/big/nat.go b/src/math/big/nat.go index 541da229d6..ec75c8f6fd 100644 --- a/src/math/big/nat.go +++ b/src/math/big/nat.go @@ -17,6 +17,7 @@ import ( "internal/byteorder" "math/bits" "math/rand" + "slices" "sync" ) @@ -262,9 +263,9 @@ var karatsubaThreshold = 40 // computed by calibrate_test.go // karatsuba multiplies x and y and leaves the result in z. // Both x and y must have the same length n and n must be a -// power of 2. The result vector z must have len(z) >= 6*n. -// The (non-normalized) result is placed in z[0 : 2*n]. -func karatsuba(z, x, y nat) { +// power of 2. The result vector z must have len(z) == len(x)+len(y). +// The (non-normalized) result is placed in z. +func karatsuba(stk *stack, z, x, y nat) { n := len(y) // Switch to basic multiplication if numbers are odd or small. @@ -304,29 +305,19 @@ func karatsuba(z, x, y nat) { x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0 y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0 - // z is used for the result and temporary storage: - // - // 6*n 5*n 4*n 3*n 2*n 1*n 0*n - // z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ] - // - // For each recursive call of karatsuba, an unused slice of - // z is passed in that has (at least) half the length of the - // caller's z. - // compute z0 and z2 with the result "in place" in z - karatsuba(z, x0, y0) // z0 = x0*y0 - karatsuba(z[n:], x1, y1) // z2 = x1*y1 + karatsuba(stk, z, x0, y0) // z0 = x0*y0 + karatsuba(stk, z[n:], x1, y1) // z2 = x1*y1 - // compute xd (or the negative value if underflow occurs) + // compute xd, yd (or the negative value if underflow occurs) s := 1 // sign of product xd*yd - xd := z[2*n : 2*n+n2] + defer stk.restore(stk.save()) + xd := stk.nat(n2) + yd := stk.nat(n2) if subVV(xd, x1, x0) != 0 { // x1-x0 s = -s subVV(xd, x0, x1) // x0-x1 } - - // compute yd (or the negative value if underflow occurs) - yd := z[2*n+n2 : 3*n] if subVV(yd, y0, y1) != 0 { // y0-y1 s = -s subVV(yd, y1, y0) // y1-y0 @@ -334,12 +325,12 @@ func karatsuba(z, x, y nat) { // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0 // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0 - p := z[n*3:] - karatsuba(p, xd, yd) + p := stk.nat(2 * n2) + karatsuba(stk, p, xd, yd) // save original z2:z0 // (ok to use upper half of z since we're done recurring) - r := z[n*4:] + r := stk.nat(n * 2) copy(r, z[:n*2]) // add up all partial products @@ -396,13 +387,15 @@ func karatsubaLen(n, threshold int) int { return n << i } -func (z nat) mul(x, y nat) nat { +// mul sets z = x*y, using stk for temporary storage. +// The caller may pass stk == nil to request that mul obtain and release one itself. +func (z nat) mul(stk *stack, x, y nat) nat { m := len(x) n := len(y) switch { case m < n: - return z.mul(y, x) + return z.mul(stk, y, x) case m == 0 || n == 0: return z[:0] case n == 1: @@ -432,12 +425,16 @@ func (z nat) mul(x, y nat) nat { k := karatsubaLen(n, karatsubaThreshold) // k <= n + if stk == nil { + stk = getStack() + defer stk.free() + } + // multiply x0 and y0 via Karatsuba - x0 := x[0:k] // x0 is not normalized - y0 := y[0:k] // y0 is not normalized - z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y - karatsuba(z, x0, y0) - z = z[0 : m+n] // z has final length but may be incomplete + x0 := x[0:k] // x0 is not normalized + y0 := y[0:k] // y0 is not normalized + z = z.make(m + n) // enough space for full result of x*y + karatsuba(stk, z, x0, y0) clear(z[2*k:]) // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m) // If xh != 0 or yh != 0, add the missing terms to z. For @@ -454,13 +451,13 @@ func (z nat) mul(x, y nat) nat { // be a larger valid threshold contradicting the assumption about k. // if k < n || m != n { - tp := getNat(3 * k) - t := *tp + defer stk.restore(stk.save()) + t := stk.nat(3 * k) // add x0*y1*b x0 := x0.norm() - y1 := y[k:] // y1 is normalized because y is - t = t.mul(x0, y1) // update t so we don't lose t's underlying array + y1 := y[k:] // y1 is normalized because y is + t = t.mul(stk, x0, y1) // update t so we don't lose t's underlying array addAt(z, t, k) // add xi*y0< 0, len(z) == 2*len(x) // The (non-normalized) result is placed in z. -func basicSqr(z, x nat) { +func basicSqr(stk *stack, z, x nat) { n := len(x) - tp := getNat(2 * n) - t := *tp // temporary variable to hold the products + defer stk.restore(stk.save()) + t := stk.nat(2 * n) clear(t) z[1], z[0] = mulWW(x[0], x[0]) // the initial square for i := 1; i < n; i++ { @@ -502,38 +497,37 @@ func basicSqr(z, x nat) { } t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products addVV(z, z, t) // combine the result - putNat(tp) } // karatsubaSqr squares x and leaves the result in z. -// len(x) must be a power of 2 and len(z) >= 6*len(x). -// The (non-normalized) result is placed in z[0 : 2*len(x)]. +// len(x) must be a power of 2 and len(z) == 2*len(x). +// The (non-normalized) result is placed in z. // // The algorithm and the layout of z are the same as for karatsuba. -func karatsubaSqr(z, x nat) { +func karatsubaSqr(stk *stack, z, x nat) { n := len(x) if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 { - basicSqr(z[:2*n], x) + basicSqr(stk, z[:2*n], x) return } n2 := n >> 1 x1, x0 := x[n2:], x[0:n2] - karatsubaSqr(z, x0) - karatsubaSqr(z[n:], x1) + karatsubaSqr(stk, z, x0) + karatsubaSqr(stk, z[n:], x1) // s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0 - xd := z[2*n : 2*n+n2] + defer stk.restore(stk.save()) + p := stk.nat(2 * n2) + r := stk.nat(n * 2) + xd := r[:n2] if subVV(xd, x1, x0) != 0 { subVV(xd, x0, x1) } - p := z[n*3:] - karatsubaSqr(p, xd) - - r := z[n*4:] + karatsubaSqr(stk, p, xd) copy(r, z[:n*2]) karatsubaAdd(z[n2:], r, n) @@ -547,8 +541,9 @@ func karatsubaSqr(z, x nat) { var basicSqrThreshold = 20 // computed by calibrate_test.go var karatsubaSqrThreshold = 260 // computed by calibrate_test.go -// z = x*x -func (z nat) sqr(x nat) nat { +// sqr sets z = x*x, using stk for temporary storage. +// The caller may pass stk == nil to request that sqr obtain and release one itself. +func (z nat) sqr(stk *stack, x nat) nat { n := len(x) switch { case n == 0: @@ -563,15 +558,20 @@ func (z nat) sqr(x nat) nat { if alias(z, x) { z = nil // z is an alias for x - cannot reuse } + z = z.make(2 * n) if n < basicSqrThreshold { - z = z.make(2 * n) basicMul(z, x, x) return z.norm() } + + if stk == nil { + stk = getStack() + defer stk.free() + } + if n < karatsubaSqrThreshold { - z = z.make(2 * n) - basicSqr(z, x) + basicSqr(stk, z, x) return z.norm() } @@ -583,22 +583,18 @@ func (z nat) sqr(x nat) nat { k := karatsubaLen(n, karatsubaSqrThreshold) x0 := x[0:k] - z = z.make(max(6*k, 2*n)) - karatsubaSqr(z, x0) // z = x0^2 - z = z[0 : 2*n] + karatsubaSqr(stk, z, x0) // z = x0^2 clear(z[2*k:]) if k < n { - tp := getNat(2 * k) - t := *tp + t := stk.nat(2 * k) x0 := x0.norm() x1 := x[k:] - t = t.mul(x0, x1) + t = t.mul(stk, x0, x1) addAt(z, t, k) addAt(z, t, k) // z = 2*x1*x0*b + x0^2 - t = t.sqr(x1) + t = t.sqr(stk, x1) addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2 - putNat(tp) } return z.norm() @@ -606,7 +602,8 @@ func (z nat) sqr(x nat) nat { // mulRange computes the product of all the unsigned integers in the // range [a, b] inclusively. If a > b (empty range), the result is 1. -func (z nat) mulRange(a, b uint64) nat { +// The caller may pass stk == nil to request that mulRange obtain and release one itself. +func (z nat) mulRange(stk *stack, a, b uint64) nat { switch { case a == 0: // cut long ranges short (optimization) @@ -616,35 +613,80 @@ func (z nat) mulRange(a, b uint64) nat { case a == b: return z.setUint64(a) case a+1 == b: - return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b)) + return z.mul(stk, nat(nil).setUint64(a), nat(nil).setUint64(b)) } + + if stk == nil { + stk = getStack() + defer stk.free() + } + m := a + (b-a)/2 // avoid overflow - return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b)) + return z.mul(stk, nat(nil).mulRange(stk, a, m), nat(nil).mulRange(stk, m+1, b)) } -// getNat returns a *nat of len n. The contents may not be zero. -// The pool holds *nat to avoid allocation when converting to interface{}. -func getNat(n int) *nat { - var z *nat - if v := natPool.Get(); v != nil { - z = v.(*nat) +// A stack provides temporary storage for complex calculations +// such as multiplication and division. +// The stack is a simple slice of words, extended as needed +// to hold all the temporary storage for a calculation. +// In general, if a function takes a *stack, it expects a non-nil *stack. +// However, certain functions may allow passing a nil *stack instead, +// so that they can handle trivial stack-free cases without forcing the +// caller to obtain and free a stack that will be unused. These functions +// document that they accept a nil *stack in their doc comments. +type stack struct { + w []Word +} + +var stackPool sync.Pool + +// getStack returns a temporary stack. +// The caller must call [stack.free] to give up use of the stack when finished. +func getStack() *stack { + s, _ := stackPool.Get().(*stack) + if s == nil { + s = new(stack) } - if z == nil { - z = new(nat) - } - *z = z.make(n) + return s +} + +// free returns the stack for use by another calculation. +func (s *stack) free() { + s.w = s.w[:0] + stackPool.Put(s) +} + +// save returns the current stack pointer. +// A future call to restore with the same value +// frees any temporaries allocated on the stack after the call to save. +func (s *stack) save() int { + return len(s.w) +} + +// restore restores the stack pointer to n. +// It is almost always invoked as +// +// defer stk.restore(stk.save()) +// +// which makes sure to pop any temporaries allocated in the current function +// from the stack before returning. +func (s *stack) restore(n int) { + s.w = s.w[:n] +} + +// nat returns a nat of n words, allocated on the stack. +func (s *stack) nat(n int) nat { + nr := (n + 3) &^ 3 // round up to multiple of 4 + off := len(s.w) + s.w = slices.Grow(s.w, nr) + s.w = s.w[:off+nr] + x := s.w[off : off+n : off+n] if n > 0 { - (*z)[0] = 0xfedcb // break code expecting zero + x[0] = 0xfedcb } - return z + return x } -func putNat(x *nat) { - natPool.Put(x) -} - -var natPool sync.Pool - // bitLen returns the length of x in bits. // Unlike most methods, it works even if x is not normalized. func (x nat) bitLen() int { @@ -930,7 +972,8 @@ func (z nat) random(rand *rand.Rand, limit nat, n int) nat { // If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m; // otherwise it sets z to x**y. The result is the value of z. -func (z nat) expNN(x, y, m nat, slow bool) nat { +// The caller may pass stk == nil to request that expNN obtain and release one itself. +func (z nat) expNN(stk *stack, x, y, m nat, slow bool) nat { if alias(z, x) || alias(z, y) { // We cannot allow in-place modification of x or y. z = nil @@ -961,12 +1004,17 @@ func (z nat) expNN(x, y, m nat, slow bool) nat { // x > 1 // x**1 == x - if len(y) == 1 && y[0] == 1 { - if len(m) != 0 { - return z.rem(x, m) - } + if len(y) == 1 && y[0] == 1 && len(m) == 0 { return z.set(x) } + if stk == nil { + stk = getStack() + defer stk.free() + } + if len(y) == 1 && y[0] == 1 { // len(m) > 0 + return z.rem(stk, x, m) + } + // y > 1 if len(m) != 0 { @@ -980,12 +1028,12 @@ func (z nat) expNN(x, y, m nat, slow bool) nat { // instance of each of the first two cases). if len(y) > 1 && !slow { if m[0]&1 == 1 { - return z.expNNMontgomery(x, y, m) + return z.expNNMontgomery(stk, x, y, m) } if logM, ok := m.isPow2(); ok { - return z.expNNWindowed(x, y, logM) + return z.expNNWindowed(stk, x, y, logM) } - return z.expNNMontgomeryEven(x, y, m) + return z.expNNMontgomeryEven(stk, x, y, m) } } @@ -1006,16 +1054,16 @@ func (z nat) expNN(x, y, m nat, slow bool) nat { // otherwise the arguments would alias. var zz, r nat for j := 0; j < w; j++ { - zz = zz.sqr(z) + zz = zz.sqr(stk, z) zz, z = z, zz if v&mask != 0 { - zz = zz.mul(z, x) + zz = zz.mul(stk, z, x) zz, z = z, zz } if len(m) != 0 { - zz, r = zz.div(r, z, m) + zz, r = zz.div(stk, r, z, m) zz, r, q, z = q, z, zz, r } @@ -1026,16 +1074,16 @@ func (z nat) expNN(x, y, m nat, slow bool) nat { v = y[i] for j := 0; j < _W; j++ { - zz = zz.sqr(z) + zz = zz.sqr(stk, z) zz, z = z, zz if v&mask != 0 { - zz = zz.mul(z, x) + zz = zz.mul(stk, z, x) zz, z = z, zz } if len(m) != 0 { - zz, r = zz.div(r, z, m) + zz, r = zz.div(stk, r, z, m) zz, r, q, z = q, z, zz, r } @@ -1054,7 +1102,7 @@ func (z nat) expNN(x, y, m nat, slow bool) nat { // For more details, see Ç. K. Koç, “Montgomery Reduction with Even Modulus”, // IEE Proceedings: Computers and Digital Techniques, 141(5) 314-316, September 1994. // http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf -func (z nat) expNNMontgomeryEven(x, y, m nat) nat { +func (z nat) expNNMontgomeryEven(stk *stack, x, y, m nat) nat { // Split m = m₁ × m₂ where m₁ = 2ⁿ n := m.trailingZeroBits() m1 := nat(nil).shl(natOne, n) @@ -1066,8 +1114,8 @@ func (z nat) expNNMontgomeryEven(x, y, m nat) nat { // (We are using the math/big convention for names here, // where the computation is z = x**y mod m, so its parts are z1 and z2. // The paper is computing x = a**e mod n; it refers to these as x2 and z1.) - z1 := nat(nil).expNN(x, y, m1, false) - z2 := nat(nil).expNN(x, y, m2, false) + z1 := nat(nil).expNN(stk, x, y, m1, false) + z2 := nat(nil).expNN(stk, x, y, m2, false) // Reconstruct z from z₁, z₂ using CRT, using algorithm from paper, // which uses only a single modInverse (and an easy one at that). @@ -1086,18 +1134,18 @@ func (z nat) expNNMontgomeryEven(x, y, m nat) nat { // Reuse z2 for p = (z₁ - z₂) [in z1] * m2⁻¹ (mod m₁ [= 2ⁿ]). m2inv := nat(nil).modInverse(m2, m1) - z2 = z2.mul(z1, m2inv) + z2 = z2.mul(stk, z1, m2inv) z2 = z2.trunc(z2, n) // Reuse z1 for p * m2. - z = z.add(z, z1.mul(z2, m2)) + z = z.add(z, z1.mul(stk, z2, m2)) return z } // expNNWindowed calculates x**y mod m using a fixed, 4-bit window, // where m = 2**logM. -func (z nat) expNNWindowed(x, y nat, logM uint) nat { +func (z nat) expNNWindowed(stk *stack, x, y nat, logM uint) nat { if len(y) <= 1 { panic("big: misuse of expNNWindowed") } @@ -1112,23 +1160,23 @@ func (z nat) expNNWindowed(x, y nat, logM uint) nat { // zz is used to avoid allocating in mul as otherwise // the arguments would alias. + defer stk.restore(stk.save()) w := int((logM + _W - 1) / _W) - zzp := getNat(w) - zz := *zzp + zz := stk.nat(w) const n = 4 // powers[i] contains x^i. - var powers [1 << n]*nat + var powers [1 << n]nat for i := range powers { - powers[i] = getNat(w) + powers[i] = stk.nat(w) } - *powers[0] = powers[0].set(natOne) - *powers[1] = powers[1].trunc(x, logM) + powers[0] = powers[0].set(natOne) + powers[1] = powers[1].trunc(x, logM) for i := 2; i < 1<>(_W-n)]) + zz = zz.mul(stk, z, powers[yi>>(_W-n)]) zz, z = z, zz z = z.trunc(z, logM) @@ -1185,24 +1233,18 @@ func (z nat) expNNWindowed(x, y nat, logM uint) nat { } } - *zzp = zz - putNat(zzp) - for i := range powers { - putNat(powers[i]) - } - return z.norm() } // expNNMontgomery calculates x**y mod m using a fixed, 4-bit window. // Uses Montgomery representation. -func (z nat) expNNMontgomery(x, y, m nat) nat { +func (z nat) expNNMontgomery(stk *stack, x, y, m nat) nat { numWords := len(m) // We want the lengths of x and m to be equal. // It is OK if x >= m as long as len(x) == len(m). if len(x) > numWords { - _, x = nat(nil).div(nil, x, m) + _, x = nat(nil).div(stk, nil, x, m) // Note: now len(x) <= numWords, not guaranteed ==. } if len(x) < numWords { @@ -1225,7 +1267,7 @@ func (z nat) expNNMontgomery(x, y, m nat) nat { // RR = 2**(2*_W*len(m)) mod m RR := nat(nil).setWord(1) zz := nat(nil).shl(RR, uint(2*numWords*_W)) - _, RR = nat(nil).div(RR, zz, m) + _, RR = nat(nil).div(stk, RR, zz, m) if len(RR) < numWords { zz = zz.make(numWords) copy(zz, RR) @@ -1280,7 +1322,7 @@ func (z nat) expNNMontgomery(x, y, m nat) nat { // The div is not expected to be reached. zz = zz.sub(zz, m) if zz.cmp(m) >= 0 { - _, zz = nat(nil).div(nil, zz, m) + _, zz = nat(nil).div(stk, nil, zz, m) } } @@ -1349,7 +1391,8 @@ func (z nat) setBytes(buf []byte) nat { } // sqrt sets z = ⌊√x⌋ -func (z nat) sqrt(x nat) nat { +// The caller may pass stk == nil to request that sqrt obtain and release one itself. +func (z nat) sqrt(stk *stack, x nat) nat { if x.cmp(natOne) <= 0 { return z.set(x) } @@ -1357,6 +1400,11 @@ func (z nat) sqrt(x nat) nat { z = nil } + if stk == nil { + stk = getStack() + defer stk.free() + } + // Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller. // See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt). // https://members.loria.fr/PZimmermann/mca/pub226.html @@ -1367,7 +1415,7 @@ func (z nat) sqrt(x nat) nat { z1 = z1.setUint64(1) z1 = z1.shl(z1, uint(x.bitLen()+1)/2) // must be ≥ √x for n := 0; ; n++ { - z2, _ = z2.div(nil, x, z1) + z2, _ = z2.div(stk, nil, x, z1) z2 = z2.add(z2, z1) z2 = z2.shr(z2, 1) if z2.cmp(z1) >= 0 { diff --git a/src/math/big/nat_test.go b/src/math/big/nat_test.go index 46231f7976..1811dccfe3 100644 --- a/src/math/big/nat_test.go +++ b/src/math/big/nat_test.go @@ -42,6 +42,7 @@ func TestCmp(t *testing.T) { } type funNN func(z, x, y nat) nat +type funSNN func(z nat, stk *stack, x, y nat) nat type argNN struct { z, x, y nat } @@ -112,6 +113,15 @@ func testFunNN(t *testing.T, msg string, f funNN, a argNN) { } } +func testFunSNN(t *testing.T, msg string, f funSNN, a argNN) { + stk := getStack() + defer stk.free() + z := f(nil, stk, a.x, a.y) + if z.cmp(a.z) != 0 { + t.Errorf("%s%+v\n\tgot z = %v; want %v", msg, a, z, a.z) + } +} + func TestFunNN(t *testing.T) { for _, a := range sumNN { arg := a @@ -129,10 +139,10 @@ func TestFunNN(t *testing.T) { for _, a := range prodNN { arg := a - testFunNN(t, "mul", nat.mul, arg) + testFunSNN(t, "mul", nat.mul, arg) arg = argNN{a.z, a.y, a.x} - testFunNN(t, "mul symmetric", nat.mul, arg) + testFunSNN(t, "mul symmetric", nat.mul, arg) } } @@ -163,8 +173,11 @@ var mulRangesN = []struct { } func TestMulRangeN(t *testing.T) { + stk := getStack() + defer stk.free() + for i, r := range mulRangesN { - prod := string(nat(nil).mulRange(r.a, r.b).utoa(10)) + prod := string(nat(nil).mulRange(stk, r.a, r.b).utoa(10)) if prod != r.prod { t.Errorf("#%d: got %s; want %s", i, prod, r.prod) } @@ -185,11 +198,14 @@ func allocBytes(f func()) uint64 { // does not cause deep recursion and in turn allocate too much memory. // Test case for issue 3807. func TestMulUnbalanced(t *testing.T) { + stk := getStack() + defer stk.free() + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) x := rndNat(50000) y := rndNat(40) allocSize := allocBytes(func() { - nat(nil).mul(x, y) + nat(nil).mul(stk, x, y) }) inputSize := uint64(len(x)+len(y)) * _S if ratio := allocSize / uint64(inputSize); ratio > 10 { @@ -214,12 +230,15 @@ func rndNat1(n int) nat { } func BenchmarkMul(b *testing.B) { + stk := getStack() + defer stk.free() + mulx := rndNat(1e4) muly := rndNat(1e4) b.ResetTimer() for i := 0; i < b.N; i++ { var z nat - z.mul(mulx, muly) + z.mul(stk, mulx, muly) } } @@ -230,7 +249,7 @@ func benchmarkNatMul(b *testing.B, nwords int) { b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { - z.mul(x, y) + z.mul(nil, x, y) } } @@ -444,6 +463,9 @@ var montgomeryTests = []struct { } func TestMontgomery(t *testing.T) { + stk := getStack() + defer stk.free() + one := NewInt(1) _B := new(Int).Lsh(one, _W) for i, test := range montgomeryTests { @@ -458,11 +480,11 @@ func TestMontgomery(t *testing.T) { } if x.cmp(m) > 0 { - _, r := nat(nil).div(nil, x, m) + _, r := nat(nil).div(stk, nil, x, m) t.Errorf("#%d: x > m (0x%s > 0x%s; use 0x%s)", i, x.utoa(16), m.utoa(16), r.utoa(16)) } if y.cmp(m) > 0 { - _, r := nat(nil).div(nil, x, m) + _, r := nat(nil).div(stk, nil, x, m) t.Errorf("#%d: y > m (0x%s > 0x%s; use 0x%s)", i, y.utoa(16), m.utoa(16), r.utoa(16)) } @@ -538,6 +560,9 @@ var expNNTests = []struct { } func TestExpNN(t *testing.T) { + stk := getStack() + defer stk.free() + for i, test := range expNNTests { x := natFromString(test.x) y := natFromString(test.y) @@ -548,7 +573,7 @@ func TestExpNN(t *testing.T) { m = natFromString(test.m) } - z := nat(nil).expNN(x, y, m, false) + z := nat(nil).expNN(stk, x, y, m, false) if z.cmp(out) != 0 { t.Errorf("#%d got %s want %s", i, z.utoa(10), out.utoa(10)) } @@ -572,6 +597,9 @@ func FuzzExpMont(f *testing.F) { } func BenchmarkExp3Power(b *testing.B) { + stk := getStack() + defer stk.free() + const x = 3 for _, y := range []Word{ 0x10, 0x40, 0x100, 0x400, 0x1000, 0x4000, 0x10000, 0x40000, 0x100000, 0x400000, @@ -579,7 +607,7 @@ func BenchmarkExp3Power(b *testing.B) { b.Run(fmt.Sprintf("%#x", y), func(b *testing.B) { var z nat for i := 0; i < b.N; i++ { - z.expWW(x, y) + z.expWW(stk, x, y) } }) } @@ -712,10 +740,13 @@ func TestSticky(t *testing.T) { } func testSqr(t *testing.T, x nat) { + stk := getStack() + defer stk.free() + got := make(nat, 2*len(x)) want := make(nat, 2*len(x)) - got = got.sqr(x) - want = want.mul(x, x) + got = got.sqr(stk, x) + want = want.mul(stk, x, x) if got.cmp(want) != 0 { t.Errorf("basicSqr(%v), got %v, want %v", x, got, want) } @@ -741,7 +772,7 @@ func benchmarkNatSqr(b *testing.B, nwords int) { b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { - z.sqr(x) + z.sqr(nil, x) } } @@ -830,6 +861,9 @@ func BenchmarkNatSetBytes(b *testing.B) { } func TestNatDiv(t *testing.T) { + stk := getStack() + defer stk.free() + sizes := []int{ 1, 2, 5, 8, 15, 25, 40, 65, 100, 200, 500, 800, 1500, 2500, 4000, 6500, 10000, @@ -849,11 +883,11 @@ func TestNatDiv(t *testing.T) { c = c.norm() } // compute x = a*b+c - x := nat(nil).mul(a, b) + x := nat(nil).mul(stk, a, b) x = x.add(x, c) var q, r nat - q, r = q.div(r, x, b) + q, r = q.div(stk, r, x, b) if q.cmp(a) != 0 { t.Fatalf("wrong quotient: got %s; want %s for %s/%s", q.utoa(10), a.utoa(10), x.utoa(10), b.utoa(10)) } @@ -868,6 +902,9 @@ func TestNatDiv(t *testing.T) { // the inaccurate estimate of the first word's quotient // happens at the very beginning of the loop. func TestIssue37499(t *testing.T) { + stk := getStack() + defer stk.free() + // Choose u and v such that v is slightly larger than u >> N. // This tricks divBasic into choosing 1 as the first word // of the quotient. This works in both 32-bit and 64-bit settings. @@ -875,7 +912,7 @@ func TestIssue37499(t *testing.T) { v := natFromString("0x2b6c385a05be027f5c22005b63c42a1165b79ff510e1706c") q := nat(nil).make(8) - q.divBasic(u, v) + q.divBasic(stk, u, v) q = q.norm() if s := string(q.utoa(16)); s != "fffffffffffffffffffffffffffffffffffffffffffffffb" { t.Fatalf("incorrect quotient: %s", s) @@ -886,8 +923,11 @@ func TestIssue37499(t *testing.T) { // where the first division loop is never entered, and correcting // the remainder takes exactly two iterations in the final loop. func TestIssue42552(t *testing.T) { + stk := getStack() + defer stk.free() + u := natFromString("0xc23b166884c3869092a520eceedeced2b00847bd256c9cf3b2c5e2227c15bd5e6ee7ef8a2f49236ad0eedf2c8a3b453cf6e0706f64285c526b372c4b1321245519d430540804a50b7ca8b6f1b34a2ec05cdbc24de7599af112d3e3c8db347e8799fe70f16e43c6566ba3aeb169463a3ecc486172deb2d9b80a3699c776e44fef20036bd946f1b4d054dd88a2c1aeb986199b0b2b7e58c42288824b74934d112fe1fc06e06b4d99fe1c5e725946b23210521e209cd507cce90b5f39a523f27e861f9e232aee50c3f585208b4573dcc0b897b6177f2ba20254fd5c50a033e849dee1b3a93bd2dc44ba8ca836cab2c2ae50e50b126284524fa0187af28628ff0face68d87709200329db1392852c8b8963fbe3d05fb1efe19f0ed5ca9fadc2f96f82187c24bb2512b2e85a66333a7e176605695211e1c8e0b9b9e82813e50654964945b1e1e66a90840396c7d10e23e47f364d2d3f660fa54598e18d1ca2ea4fe4f35a40a11f69f201c80b48eaee3e2e9b0eda63decf92bec08a70f731587d4ed0f218d5929285c8b2ccbc497e20db42de73885191fa453350335990184d8df805072f958d5354debda38f5421effaaafd6cb9b721ace74be0892d77679f62a4a126697cd35797f6858193da4ba1770c06aea2e5c59ec04b8ea26749e61b72ecdde403f3bc7e5e546cd799578cc939fa676dfd5e648576d4a06cbadb028adc2c0b461f145b2321f42e5e0f3b4fb898ecd461df07a6f5154067787bf74b5cc5c03704a1ce47494961931f0263b0aac32505102595957531a2de69dd71aac51f8a49902f81f21283dbe8e21e01e5d82517868826f86acf338d935aa6b4d5a25c8d540389b277dd9d64569d68baf0f71bd03dba45b92a7fc052601d1bd011a2fc6790a23f97c6fa5caeea040ab86841f268d39ce4f7caf01069df78bba098e04366492f0c2ac24f1bf16828752765fa523c9a4d42b71109d123e6be8c7b1ab3ccf8ea03404075fe1a9596f1bba1d267f9a7879ceece514818316c9c0583469d2367831fc42b517ea028a28df7c18d783d16ea2436cee2b15d52db68b5dfdee6b4d26f0905f9b030c911a04d078923a4136afea96eed6874462a482917353264cc9bee298f167ac65a6db4e4eda88044b39cc0b33183843eaa946564a00c3a0ab661f2c915e70bf0bb65bfbb6fa2eea20aed16bf2c1a1d00ec55fb4ff2f76b8e462ea70c19efa579c9ee78194b86708fdae66a9ce6e2cf3d366037798cfb50277ba6d2fd4866361022fd788ab7735b40b8b61d55e32243e06719e53992e9ac16c9c4b6e6933635c3c47c8f7e73e17dd54d0dd8aeba5d76de46894e7b3f9d3ec25ad78ee82297ba69905ea0fa094b8667faa2b8885e2187b3da80268aa1164761d7b0d6de206b676777348152b8ae1d4afed753bc63c739a5ca8ce7afb2b241a226bd9e502baba391b5b13f5054f070b65a9cf3a67063bfaa803ba390732cd03888f664023f888741d04d564e0b5674b0a183ace81452001b3fbb4214c77d42ca75376742c471e58f67307726d56a1032bd236610cbcbcd03d0d7a452900136897dc55bb3ce959d10d4e6a10fb635006bd8c41cd9ded2d3dfdd8f2e229590324a7370cb2124210b2330f4c56155caa09a2564932ceded8d92c79664dcdeb87faad7d3da006cc2ea267ee3df41e9677789cc5a8cc3b83add6491561b3047919e0648b1b2e97d7ad6f6c2aa80cab8e9ae10e1f75b1fdd0246151af709d259a6a0ed0b26bd711024965ecad7c41387de45443defce53f66612948694a6032279131c257119ed876a8e805dfb49576ef5c563574115ee87050d92d191bc761ef51d966918e2ef925639400069e3959d8fe19f36136e947ff430bf74e71da0aa5923b00000000") v := natFromString("0x838332321d443a3d30373d47301d47073847473a383d3030f25b3d3d3e00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002e00000000000000000041603038331c3d32f5303441e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e01c0a5459bfc7b9be9fcbb9d2383840464319434707303030f43a32f53034411c0a5459413820878787878787878787878787878787878787878787878787878787878787878787870630303a3a30334036605b923a6101f83638413943413960204337602043323801526040523241846038414143015238604060328452413841413638523c0240384141364036605b923a6101f83638413943413960204334602043323801526040523241846038414143015238604060328452413841413638523c02403841413638433030f25a8b83838383838383838383838383838383837d838383ffffffffffffffff838383838383838383000000000000000000030000007d26e27c7c8b83838383838383838383838383838383837d838383ffffffffffffffff83838383838383838383838383838383838383838383435960f535073030f3343200000000000000011881301938343030fa398383300000002300000000000000000000f11af4600c845252904141364138383c60406032414443095238010241414303364443434132305b595a15434160b042385341ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff47476043410536613603593a6005411c437405fcfcfcfcfcfcfc0000000000005a3b075815054359000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") q := nat(nil).make(16) - q.div(q, u, v) + q.div(stk, q, u, v) } diff --git a/src/math/big/natconv.go b/src/math/big/natconv.go index ce94f2cf72..8a47ec9f9c 100644 --- a/src/math/big/natconv.go +++ b/src/math/big/natconv.go @@ -321,17 +321,20 @@ func (x nat) itoa(neg bool, base int) []byte { } } else { + stk := getStack() + defer stk.free() + bb, ndigits := maxPow(b) // construct table of successive squares of bb*leafSize to use in subdivisions // result (table != nil) <=> (len(x) > leafSize > 0) - table := divisors(len(x), b, ndigits, bb) + table := divisors(stk, len(x), b, ndigits, bb) // preserve x, create local copy for use by convertWords q := nat(nil).set(x) // convert q to string s in base b - q.convertWords(s, b, ndigits, bb, table) + q.convertWords(stk, s, b, ndigits, bb, table) // strip leading zeros // (x != 0; thus s must contain at least one non-zero digit @@ -365,7 +368,7 @@ func (x nat) itoa(neg bool, base int) []byte { // range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and // ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for // specific hardware. -func (q nat) convertWords(s []byte, b Word, ndigits int, bb Word, table []divisor) { +func (q nat) convertWords(stk *stack, s []byte, b Word, ndigits int, bb Word, table []divisor) { // split larger blocks recursively if table != nil { // len(q) > leafSize > 0 @@ -386,12 +389,12 @@ func (q nat) convertWords(s []byte, b Word, ndigits int, bb Word, table []diviso } // split q into the two digit number (q'*bbb + r) to form independent subblocks - q, r = q.div(r, q, table[index].bbb) + q, r = q.div(stk, r, q, table[index].bbb) // convert subblocks and collect results in s[:h] and s[h:] h := len(s) - table[index].ndigits - r.convertWords(s[h:], b, ndigits, bb, table[0:index]) - s = s[:h] // == q.convertWords(s, b, ndigits, bb, table[0:index+1]) + r.convertWords(stk, s[h:], b, ndigits, bb, table[0:index]) + s = s[:h] // == q.convertWords(stk, s, b, ndigits, bb, table[0:index+1]) } } @@ -451,12 +454,12 @@ var cacheBase10 struct { } // expWW computes x**y -func (z nat) expWW(x, y Word) nat { - return z.expNN(nat(nil).setWord(x), nat(nil).setWord(y), nil, false) +func (z nat) expWW(stk *stack, x, y Word) nat { + return z.expNN(stk, nat(nil).setWord(x), nat(nil).setWord(y), nil, false) } // construct table of powers of bb*leafSize to use in subdivisions. -func divisors(m int, b Word, ndigits int, bb Word) []divisor { +func divisors(stk *stack, m int, b Word, ndigits int, bb Word) []divisor { // only compute table when recursive conversion is enabled and x is large if leafSize == 0 || m <= leafSize { return nil @@ -484,10 +487,10 @@ func divisors(m int, b Word, ndigits int, bb Word) []divisor { for i := 0; i < k; i++ { if table[i].ndigits == 0 { if i == 0 { - table[0].bbb = nat(nil).expWW(bb, Word(leafSize)) + table[0].bbb = nat(nil).expWW(stk, bb, Word(leafSize)) table[0].ndigits = ndigits * leafSize } else { - table[i].bbb = nat(nil).sqr(table[i-1].bbb) + table[i].bbb = nat(nil).sqr(stk, table[i-1].bbb) table[i].ndigits = 2 * table[i-1].ndigits } diff --git a/src/math/big/natconv_test.go b/src/math/big/natconv_test.go index d390272108..66300e412b 100644 --- a/src/math/big/natconv_test.go +++ b/src/math/big/natconv_test.go @@ -350,6 +350,9 @@ func BenchmarkStringPiParallel(b *testing.B) { } func BenchmarkScan(b *testing.B) { + stk := getStack() + defer stk.free() + const x = 10 for _, base := range []int{2, 8, 10, 16} { for _, y := range []Word{10, 100, 1000, 10000, 100000} { @@ -359,7 +362,7 @@ func BenchmarkScan(b *testing.B) { b.Run(fmt.Sprintf("%d/Base%d", y, base), func(b *testing.B) { b.StopTimer() var z nat - z = z.expWW(x, y) + z = z.expWW(stk, x, y) s := z.utoa(base) if t := itoa(z, base); !bytes.Equal(s, t) { @@ -376,6 +379,9 @@ func BenchmarkScan(b *testing.B) { } func BenchmarkString(b *testing.B) { + stk := getStack() + defer stk.free() + const x = 10 for _, base := range []int{2, 8, 10, 16} { for _, y := range []Word{10, 100, 1000, 10000, 100000} { @@ -385,7 +391,7 @@ func BenchmarkString(b *testing.B) { b.Run(fmt.Sprintf("%d/Base%d", y, base), func(b *testing.B) { b.StopTimer() var z nat - z = z.expWW(x, y) + z = z.expWW(stk, x, y) z.utoa(base) // warm divisor cache b.StartTimer() @@ -416,9 +422,11 @@ func LeafSizeHelper(b *testing.B, base, size int) { for d := 1; d <= 10000; d *= 10 { b.StopTimer() + stk := getStack() var z nat - z = z.expWW(Word(base), Word(d)) // build target number - _ = z.utoa(base) // warm divisor cache + z = z.expWW(stk, Word(base), Word(d)) // build target number + _ = z.utoa(base) // warm divisor cache + stk.free() b.StartTimer() for i := 0; i < b.N; i++ { @@ -443,13 +451,16 @@ func resetTable(table []divisor) { } func TestStringPowers(t *testing.T) { + stk := getStack() + defer stk.free() + var p Word for b := 2; b <= 16; b++ { for p = 0; p <= 512; p++ { if testing.Short() && p > 10 { break } - x := nat(nil).expWW(Word(b), p) + x := nat(nil).expWW(stk, Word(b), p) xs := x.utoa(b) xs2 := itoa(x, b) if !bytes.Equal(xs, xs2) { diff --git a/src/math/big/natdiv.go b/src/math/big/natdiv.go index 2e66e3425c..b514e2ce21 100644 --- a/src/math/big/natdiv.go +++ b/src/math/big/natdiv.go @@ -502,30 +502,24 @@ import "math/bits" // rem returns r such that r = u%v. // It uses z as the storage for r. -func (z nat) rem(u, v nat) (r nat) { +func (z nat) rem(stk *stack, u, v nat) (r nat) { if alias(z, u) { z = nil } - qp := getNat(0) - q, r := qp.div(z, u, v) - *qp = q - putNat(qp) + defer stk.restore(stk.save()) + q := stk.nat(len(u) - (len(v) - 1)) + _, r = q.div(stk, z, u, v) return r } // div returns q, r such that q = ⌊u/v⌋ and r = u%v = u - q·v. // It uses z and z2 as the storage for q and r. -func (z nat) div(z2, u, v nat) (q, r nat) { +// The caller may pass stk == nil to request that div obtain and release one itself. +func (z nat) div(stk *stack, z2, u, v nat) (q, r nat) { if len(v) == 0 { panic("division by zero") } - if u.cmp(v) < 0 { - q = z[:0] - r = z2.set(u) - return - } - if len(v) == 1 { // Short division: long optimized for a single-word divisor. // In that case, the 2-by-1 guess is all we need at each step. @@ -535,7 +529,18 @@ func (z nat) div(z2, u, v nat) (q, r nat) { return } - q, r = z.divLarge(z2, u, v) + if u.cmp(v) < 0 { + q = z[:0] + r = z2.set(u) + return + } + + if stk == nil { + stk = getStack() + defer stk.free() + } + + q, r = z.divLarge(stk, z2, u, v) return } @@ -589,7 +594,7 @@ func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) { // It uses z and u as the storage for q and r. // The caller must ensure that len(vIn) ≥ 2 (use divW otherwise) // and that len(uIn) ≥ len(vIn) (the answer is 0, uIn otherwise). -func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) { +func (z nat) divLarge(stk *stack, u, uIn, vIn nat) (q, r nat) { n := len(vIn) m := len(uIn) - n @@ -597,9 +602,9 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) { // vIn is treated as a read-only input (it may be in use by another // goroutine), so we must make a copy. // uIn is copied to u. + defer stk.restore(stk.save()) shift := nlz(vIn[n-1]) - vp := getNat(n) - v := *vp + v := stk.nat(n) shlVU(v, vIn, shift) u = u.make(len(uIn) + 1) u[len(uIn)] = shlVU(u[:len(uIn)], uIn, shift) @@ -613,11 +618,10 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) { // Use basic or recursive long division depending on size. if n < divRecursiveThreshold { - q.divBasic(u, v) + q.divBasic(stk, u, v) } else { - q.divRecursive(u, v) + q.divRecursive(stk, u, v) } - putNat(vp) q = q.norm() @@ -631,12 +635,12 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) { // divBasic implements long division as described above. // It overwrites q with ⌊u/v⌋ and overwrites u with the remainder r. // q must be large enough to hold ⌊u/v⌋. -func (q nat) divBasic(u, v nat) { +func (q nat) divBasic(stk *stack, u, v nat) { n := len(v) m := len(u) - n - qhatvp := getNat(n + 1) - qhatv := *qhatvp + defer stk.restore(stk.save()) + qhatv := stk.nat(n + 1) // Set up for divWW below, precomputing reciprocal argument. vn1 := v[n-1] @@ -707,8 +711,6 @@ func (q nat) divBasic(u, v nat) { } q[j] = qhat } - - putNat(qhatvp) } // greaterThan reports whether the two digit numbers x1 x2 > y1 y2. @@ -727,24 +729,9 @@ const divRecursiveThreshold = 100 // z must be large enough to hold ⌊u/v⌋. // This function is just for allocating and freeing temporaries // around divRecursiveStep, the real implementation. -func (z nat) divRecursive(u, v nat) { - // Recursion depth is (much) less than 2 log₂(len(v)). - // Allocate a slice of temporaries to be reused across recursion, - // plus one extra temporary not live across the recursion. - recDepth := 2 * bits.Len(uint(len(v))) - tmp := getNat(3 * len(v)) - temps := make([]*nat, recDepth) - +func (z nat) divRecursive(stk *stack, u, v nat) { clear(z) - z.divRecursiveStep(u, v, 0, tmp, temps) - - // Free temporaries. - for _, n := range temps { - if n != nil { - putNat(n) - } - } - putNat(tmp) + z.divRecursiveStep(stk, u, v, 0) } // divRecursiveStep is the actual implementation of recursive division. @@ -752,7 +739,7 @@ func (z nat) divRecursive(u, v nat) { // z must be large enough to hold ⌊u/v⌋. // It uses temps[depth] (allocating if needed) as a temporary live across // the recursive call. It also uses tmp, but not live across the recursion. -func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { +func (z nat) divRecursiveStep(stk *stack, u, v nat, depth int) { // u is a subsection of the original and may have leading zeros. // TODO(rsc): The v = v.norm() is useless and should be removed. // We know (and require) that v's top digit is ≥ B/2. @@ -766,7 +753,7 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { // Fall back to basic division if the problem is now small enough. n := len(v) if n < divRecursiveThreshold { - z.divBasic(u, v) + z.divBasic(stk, u, v) return } @@ -785,11 +772,8 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { B := n / 2 // Allocate a nat for qhat below. - if temps[depth] == nil { - temps[depth] = getNat(n) // TODO(rsc): Can be just B+1. - } else { - *temps[depth] = temps[depth].make(B + 1) - } + defer stk.restore(stk.save()) + qhat0 := stk.nat(B + 1) // Compute each wide digit of the quotient. // @@ -816,9 +800,9 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { uu := u[j-B:] // Compute the 2-by-1 guess q̂, leaving r̂ in uu[s:B+n]. - qhat := *temps[depth] + qhat := qhat0 clear(qhat) - qhat.divRecursiveStep(uu[s:B+n], v[s:], depth+1, tmp, temps) + qhat.divRecursiveStep(stk, uu[s:B+n], v[s:], depth+1) qhat = qhat.norm() // Extend to a 3-by-2 quotient and remainder. @@ -833,9 +817,10 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { // q̂·vₙ₋₂ and decrementing q̂ until that product is ≤ u. // But we can do the subtraction directly, as in the comment above // and in long division, because we know that q̂ is wrong by at most one. - qhatv := tmp.make(3 * n) + mark := stk.save() + qhatv := stk.nat(3 * n) clear(qhatv) - qhatv = qhatv.mul(qhat, v[:s]) + qhatv = qhatv.mul(stk, qhat, v[:s]) for i := 0; i < 2; i++ { e := qhatv.cmp(uu.norm()) if e <= 0 { @@ -857,6 +842,7 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { } addAt(z, qhat, j-B) j -= B + stk.restore(mark) } // TODO(rsc): Rewrite loop as described above and delete all this code. @@ -864,13 +850,13 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { // Now u < (v< 0 { diff --git a/src/math/big/prime.go b/src/math/big/prime.go index 26688bbd64..bba5a07685 100644 --- a/src/math/big/prime.go +++ b/src/math/big/prime.go @@ -75,7 +75,9 @@ func (x *Int) ProbablyPrime(n int) bool { return false } - return x.abs.probablyPrimeMillerRabin(n+1, true) && x.abs.probablyPrimeLucas() + stk := getStack() + defer stk.free() + return x.abs.probablyPrimeMillerRabin(stk, n+1, true) && x.abs.probablyPrimeLucas(stk) } // probablyPrimeMillerRabin reports whether n passes reps rounds of the @@ -83,7 +85,7 @@ func (x *Int) ProbablyPrime(n int) bool { // If force2 is true, one of the rounds is forced to use base 2. // See Handbook of Applied Cryptography, p. 139, Algorithm 4.24. // The number n is known to be non-zero. -func (n nat) probablyPrimeMillerRabin(reps int, force2 bool) bool { +func (n nat) probablyPrimeMillerRabin(stk *stack, reps int, force2 bool) bool { nm1 := nat(nil).sub(n, natOne) // determine q, k such that nm1 = q << k k := nm1.trailingZeroBits() @@ -103,13 +105,13 @@ NextRandom: x = x.random(rand, nm3, nm3Len) x = x.add(x, natTwo) } - y = y.expNN(x, q, n, false) + y = y.expNN(stk, x, q, n, false) if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 { continue } for j := uint(1); j < k; j++ { - y = y.sqr(y) - quotient, y = quotient.div(y, y, n) + y = y.sqr(stk, y) + quotient, y = quotient.div(stk, y, y, n) if y.cmp(nm1) == 0 { continue NextRandom } @@ -147,7 +149,7 @@ NextRandom: // // Crandall and Pomerance, Prime Numbers: A Computational Perspective, 2nd ed. // Springer, 2005. -func (n nat) probablyPrimeLucas() bool { +func (n nat) probablyPrimeLucas(stk *stack) bool { // Discard 0, 1. if len(n) == 0 || n.cmp(natOne) == 0 { return false @@ -193,8 +195,8 @@ func (n nat) probablyPrimeLucas() bool { // We'll never find (d/n) = -1 if n is a square. // If n is a non-square we expect to find a d in just a few attempts on average. // After 40 attempts, take a moment to check if n is indeed a square. - t1 = t1.sqrt(n) - t1 = t1.sqr(t1) + t1 = t1.sqrt(stk, n) + t1 = t1.sqr(stk, t1) if t1.cmp(n) == 0 { return false } @@ -254,25 +256,25 @@ func (n nat) probablyPrimeLucas() bool { if s.bit(uint(i)) != 0 { // k' = 2k+1 // V(k') = V(2k+1) = V(k) V(k+1) - P. - t1 = t1.mul(vk, vk1) + t1 = t1.mul(stk, vk, vk1) t1 = t1.add(t1, n) t1 = t1.sub(t1, natP) - t2, vk = t2.div(vk, t1, n) + t2, vk = t2.div(stk, vk, t1, n) // V(k'+1) = V(2k+2) = V(k+1)² - 2. - t1 = t1.sqr(vk1) + t1 = t1.sqr(stk, vk1) t1 = t1.add(t1, nm2) - t2, vk1 = t2.div(vk1, t1, n) + t2, vk1 = t2.div(stk, vk1, t1, n) } else { // k' = 2k // V(k'+1) = V(2k+1) = V(k) V(k+1) - P. - t1 = t1.mul(vk, vk1) + t1 = t1.mul(stk, vk, vk1) t1 = t1.add(t1, n) t1 = t1.sub(t1, natP) - t2, vk1 = t2.div(vk1, t1, n) + t2, vk1 = t2.div(stk, vk1, t1, n) // V(k') = V(2k) = V(k)² - 2 - t1 = t1.sqr(vk) + t1 = t1.sqr(stk, vk) t1 = t1.add(t1, nm2) - t2, vk = t2.div(vk, t1, n) + t2, vk = t2.div(stk, vk, t1, n) } } @@ -285,7 +287,7 @@ func (n nat) probablyPrimeLucas() bool { // // Since we are checking for U(k) == 0 it suffices to check 2 V(k+1) == P V(k) mod n, // or P V(k) - 2 V(k+1) == 0 mod n. - t1 := t1.mul(vk, natP) + t1 := t1.mul(stk, vk, natP) t2 := t2.shl(vk1, 1) if t1.cmp(t2) < 0 { t1, t2 = t2, t1 @@ -294,7 +296,7 @@ func (n nat) probablyPrimeLucas() bool { t3 := vk1 // steal vk1, no longer needed below vk1 = nil _ = vk1 - t2, t3 = t2.div(t3, t1, n) + t2, t3 = t2.div(stk, t3, t1, n) if len(t3) == 0 { return true } @@ -312,9 +314,9 @@ func (n nat) probablyPrimeLucas() bool { } // k' = 2k // V(k') = V(2k) = V(k)² - 2 - t1 = t1.sqr(vk) + t1 = t1.sqr(stk, vk) t1 = t1.sub(t1, natTwo) - t2, vk = t2.div(vk, t1, n) + t2, vk = t2.div(stk, vk, t1, n) } return false } diff --git a/src/math/big/prime_test.go b/src/math/big/prime_test.go index 8596e33a13..2b1995bcb2 100644 --- a/src/math/big/prime_test.go +++ b/src/math/big/prime_test.go @@ -159,6 +159,9 @@ func TestProbablyPrime(t *testing.T) { } func BenchmarkProbablyPrime(b *testing.B) { + stk := getStack() + defer stk.free() + p, _ := new(Int).SetString("203956878356401977405765866929034577280193993314348263094772646453283062722701277632936616063144088173312372882677123879538709400158306567338328279154499698366071906766440037074217117805690872792848149112022286332144876183376326512083574821647933992961249917319836219304274280243803104015000563790123", 10) for _, n := range []int{0, 1, 5, 10, 20} { b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { @@ -170,26 +173,32 @@ func BenchmarkProbablyPrime(b *testing.B) { b.Run("Lucas", func(b *testing.B) { for i := 0; i < b.N; i++ { - p.abs.probablyPrimeLucas() + p.abs.probablyPrimeLucas(stk) } }) b.Run("MillerRabinBase2", func(b *testing.B) { for i := 0; i < b.N; i++ { - p.abs.probablyPrimeMillerRabin(1, true) + p.abs.probablyPrimeMillerRabin(stk, 1, true) } }) } func TestMillerRabinPseudoprimes(t *testing.T) { + stk := getStack() + defer stk.free() + testPseudoprimes(t, "probablyPrimeMillerRabin", - func(n nat) bool { return n.probablyPrimeMillerRabin(1, true) && !n.probablyPrimeLucas() }, + func(n nat) bool { return n.probablyPrimeMillerRabin(stk, 1, true) && !n.probablyPrimeLucas(stk) }, // https://oeis.org/A001262 []int{2047, 3277, 4033, 4681, 8321, 15841, 29341, 42799, 49141, 52633, 65281, 74665, 80581, 85489, 88357, 90751}) } func TestLucasPseudoprimes(t *testing.T) { + stk := getStack() + defer stk.free() + testPseudoprimes(t, "probablyPrimeLucas", - func(n nat) bool { return n.probablyPrimeLucas() && !n.probablyPrimeMillerRabin(1, true) }, + func(n nat) bool { return n.probablyPrimeLucas(stk) && !n.probablyPrimeMillerRabin(stk, 1, true) }, // https://oeis.org/A217719 []int{989, 3239, 5777, 10877, 27971, 29681, 30739, 31631, 39059, 72389, 73919, 75077}) } diff --git a/src/math/big/rat.go b/src/math/big/rat.go index e58433ecea..ac94056a83 100644 --- a/src/math/big/rat.go +++ b/src/math/big/rat.go @@ -74,7 +74,7 @@ func (z *Rat) SetFloat64(f float64) *Rat { // nearest to the quotient a/b, using round-to-even in // halfway cases. It does not mutate its arguments. // Preconditions: b is non-zero; a and b have no common factors. -func quotToFloat32(a, b nat) (f float32, exact bool) { +func quotToFloat32(stk *stack, a, b nat) (f float32, exact bool) { const ( // float size in bits Fsize = 32 @@ -121,7 +121,7 @@ func quotToFloat32(a, b nat) (f float32, exact bool) { // extra shift, the low-order bit of q is logically the // high-order bit of r. var q nat - q, r := q.div(a2, a2, b2) // (recycle a2) + q, r := q.div(stk, a2, a2, b2) // (recycle a2) mantissa := low32(q) haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half @@ -172,7 +172,7 @@ func quotToFloat32(a, b nat) (f float32, exact bool) { // nearest to the quotient a/b, using round-to-even in // halfway cases. It does not mutate its arguments. // Preconditions: b is non-zero; a and b have no common factors. -func quotToFloat64(a, b nat) (f float64, exact bool) { +func quotToFloat64(stk *stack, a, b nat) (f float64, exact bool) { const ( // float size in bits Fsize = 64 @@ -219,7 +219,7 @@ func quotToFloat64(a, b nat) (f float64, exact bool) { // extra shift, the low-order bit of q is logically the // high-order bit of r. var q nat - q, r := q.div(a2, a2, b2) // (recycle a2) + q, r := q.div(stk, a2, a2, b2) // (recycle a2) mantissa := low64(q) haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half @@ -275,7 +275,9 @@ func (x *Rat) Float32() (f float32, exact bool) { if len(b) == 0 { b = natOne } - f, exact = quotToFloat32(x.a.abs, b) + stk := getStack() + defer stk.free() + f, exact = quotToFloat32(stk, x.a.abs, b) if x.a.neg { f = -f } @@ -291,7 +293,9 @@ func (x *Rat) Float64() (f float64, exact bool) { if len(b) == 0 { b = natOne } - f, exact = quotToFloat64(x.a.abs, b) + stk := getStack() + defer stk.free() + f, exact = quotToFloat64(stk, x.a.abs, b) if x.a.neg { f = -f } @@ -437,12 +441,14 @@ func (z *Rat) norm() *Rat { z.b.abs = z.b.abs.setWord(1) default: // z is fraction; normalize numerator and denominator + stk := getStack() + defer stk.free() neg := z.a.neg z.a.neg = false z.b.neg = false if f := NewInt(0).lehmerGCD(nil, nil, &z.a, &z.b); f.Cmp(intOne) != 0 { - z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f.abs) - z.b.abs, _ = z.b.abs.div(nil, z.b.abs, f.abs) + z.a.abs, _ = z.a.abs.div(stk, nil, z.a.abs, f.abs) + z.b.abs, _ = z.b.abs.div(stk, nil, z.b.abs, f.abs) } z.a.neg = neg } @@ -452,7 +458,7 @@ func (z *Rat) norm() *Rat { // mulDenom sets z to the denominator product x*y (by taking into // account that 0 values for x or y must be interpreted as 1) and // returns z. -func mulDenom(z, x, y nat) nat { +func mulDenom(stk *stack, z, x, y nat) nat { switch { case len(x) == 0 && len(y) == 0: return z.setWord(1) @@ -461,17 +467,17 @@ func mulDenom(z, x, y nat) nat { case len(y) == 0: return z.set(x) } - return z.mul(x, y) + return z.mul(stk, x, y) } // scaleDenom sets z to the product x*f. // If f == 0 (zero value of denominator), z is set to (a copy of) x. -func (z *Int) scaleDenom(x *Int, f nat) { +func (z *Int) scaleDenom(stk *stack, x *Int, f nat) { if len(f) == 0 { z.Set(x) return } - z.abs = z.abs.mul(x.abs, f) + z.abs = z.abs.mul(stk, x.abs, f) z.neg = x.neg } @@ -481,58 +487,73 @@ func (z *Int) scaleDenom(x *Int, f nat) { // - +1 if x > y. func (x *Rat) Cmp(y *Rat) int { var a, b Int - a.scaleDenom(&x.a, y.b.abs) - b.scaleDenom(&y.a, x.b.abs) + stk := getStack() + defer stk.free() + a.scaleDenom(stk, &x.a, y.b.abs) + b.scaleDenom(stk, &y.a, x.b.abs) return a.Cmp(&b) } // Add sets z to the sum x+y and returns z. func (z *Rat) Add(x, y *Rat) *Rat { + stk := getStack() + defer stk.free() + var a1, a2 Int - a1.scaleDenom(&x.a, y.b.abs) - a2.scaleDenom(&y.a, x.b.abs) + a1.scaleDenom(stk, &x.a, y.b.abs) + a2.scaleDenom(stk, &y.a, x.b.abs) z.a.Add(&a1, &a2) - z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs) + z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs) return z.norm() } // Sub sets z to the difference x-y and returns z. func (z *Rat) Sub(x, y *Rat) *Rat { + stk := getStack() + defer stk.free() + var a1, a2 Int - a1.scaleDenom(&x.a, y.b.abs) - a2.scaleDenom(&y.a, x.b.abs) + a1.scaleDenom(stk, &x.a, y.b.abs) + a2.scaleDenom(stk, &y.a, x.b.abs) z.a.Sub(&a1, &a2) - z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs) + z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs) return z.norm() } // Mul sets z to the product x*y and returns z. func (z *Rat) Mul(x, y *Rat) *Rat { + stk := getStack() + defer stk.free() + if x == y { // a squared Rat is positive and can't be reduced (no need to call norm()) z.a.neg = false - z.a.abs = z.a.abs.sqr(x.a.abs) + z.a.abs = z.a.abs.sqr(stk, x.a.abs) if len(x.b.abs) == 0 { z.b.abs = z.b.abs.setWord(1) } else { - z.b.abs = z.b.abs.sqr(x.b.abs) + z.b.abs = z.b.abs.sqr(stk, x.b.abs) } return z } - z.a.Mul(&x.a, &y.a) - z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs) + + z.a.mul(stk, &x.a, &y.a) + z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs) return z.norm() } // Quo sets z to the quotient x/y and returns z. // If y == 0, Quo panics. func (z *Rat) Quo(x, y *Rat) *Rat { + stk := getStack() + defer stk.free() + if len(y.a.abs) == 0 { panic("division by zero") } var a, b Int - a.scaleDenom(&x.a, y.b.abs) - b.scaleDenom(&y.a, x.b.abs) + a.scaleDenom(stk, &x.a, y.b.abs) + b.scaleDenom(stk, &y.a, x.b.abs) z.a.abs = a.abs z.b.abs = b.abs z.a.neg = a.neg != b.neg diff --git a/src/math/big/ratconv.go b/src/math/big/ratconv.go index 12f9888c37..84602ff455 100644 --- a/src/math/big/ratconv.go +++ b/src/math/big/ratconv.go @@ -163,6 +163,9 @@ func (z *Rat) SetString(s string) (*Rat, bool) { } // exp consumed - not needed anymore + stk := getStack() + defer stk.free() + // apply exp5 contributions // (start with exp5 so the numbers to multiply are smaller) if exp5 != 0 { @@ -178,9 +181,9 @@ func (z *Rat) SetString(s string) (*Rat, bool) { if n > 1e6 { return nil, false // avoid excessively large exponents } - pow5 := z.b.abs.expNN(natFive, nat(nil).setWord(Word(n)), nil, false) // use underlying array of z.b.abs + pow5 := z.b.abs.expNN(stk, natFive, nat(nil).setWord(Word(n)), nil, false) // use underlying array of z.b.abs if exp5 > 0 { - z.a.abs = z.a.abs.mul(z.a.abs, pow5) + z.a.abs = z.a.abs.mul(stk, z.a.abs, pow5) z.b.abs = z.b.abs.setWord(1) } else { z.b.abs = pow5 @@ -343,15 +346,17 @@ func (x *Rat) FloatString(prec int) string { } // x.b.abs != 0 - q, r := nat(nil).div(nat(nil), x.a.abs, x.b.abs) + stk := getStack() + defer stk.free() + q, r := nat(nil).div(stk, nat(nil), x.a.abs, x.b.abs) p := natOne if prec > 0 { - p = nat(nil).expNN(natTen, nat(nil).setUint64(uint64(prec)), nil, false) + p = nat(nil).expNN(stk, natTen, nat(nil).setUint64(uint64(prec)), nil, false) } - r = r.mul(r, p) - r, r2 := r.div(nat(nil), r, x.b.abs) + r = r.mul(stk, r, p) + r, r2 := r.div(stk, nat(nil), r, x.b.abs) // see if we need to round up r2 = r2.add(r2, r2) @@ -398,6 +403,9 @@ func (x *Rat) FloatString(prec int) string { // 1/4 2 true 0.25 // 1/6 1 false 0.2 (0.166... rounded) func (x *Rat) FloatPrec() (n int, exact bool) { + stk := getStack() + defer stk.free() + // Determine q and largest p2, p5 such that d = q·2^p2·5^p5. // The results n, exact are: // @@ -425,11 +433,11 @@ func (x *Rat) FloatPrec() (n int, exact bool) { f := nat{1220703125} // == 5^fp (must fit into a uint32 Word) var t, r nat // temporaries for { - if _, r = t.div(r, q, f); len(r) != 0 { + if _, r = t.div(stk, r, q, f); len(r) != 0 { break // f doesn't divide q evenly } tab = append(tab, f) - f = nat(nil).sqr(f) // nat(nil) to ensure a new f for each table entry + f = nat(nil).sqr(stk, f) // nat(nil) to ensure a new f for each table entry } // Factor q using the table entries, if any. @@ -441,7 +449,7 @@ func (x *Rat) FloatPrec() (n int, exact bool) { // The same reasoning applies to the subsequent factors. var p5 uint for i := len(tab) - 1; i >= 0; i-- { - if t, r = t.div(r, q, tab[i]); len(r) == 0 { + if t, r = t.div(stk, r, q, tab[i]); len(r) == 0 { p5 += fp * (1 << i) // tab[i] == 5^(fp·2^i) q = q.set(t) } @@ -449,7 +457,7 @@ func (x *Rat) FloatPrec() (n int, exact bool) { // If fp != 1, we may still have multiples of 5 left. for { - if t, r = t.div(r, q, natFive); len(r) != 0 { + if t, r = t.div(stk, r, q, natFive); len(r) != 0 { break } p5++