mirror of https://github.com/golang/go.git
math/big: specialize Karatsuba implementation for squaring
Currently we use three different algorithms for squaring: 1. basic multiplication for small numbers 2. basic squaring for medium numbers 3. Karatsuba multiplication for large numbers Change 3. to a version of Karatsuba multiplication specialized for x == y. Increasing the performance of 3. lets us lower the threshold between 2. and 3. Adapt TestCalibrate to the change that 3. isn't independent of the threshold between 1. and 2. any more. Fixes #23221. benchstat old.txt new.txt name old time/op new time/op delta NatSqr/1-4 29.6ns ± 7% 29.5ns ± 5% ~ (p=0.103 n=50+50) NatSqr/2-4 51.9ns ± 1% 51.9ns ± 1% ~ (p=0.693 n=42+49) NatSqr/3-4 64.3ns ± 1% 64.1ns ± 0% -0.26% (p=0.000 n=46+43) NatSqr/5-4 93.5ns ± 2% 93.1ns ± 1% -0.39% (p=0.000 n=48+49) NatSqr/8-4 131ns ± 1% 131ns ± 1% ~ (p=0.870 n=46+49) NatSqr/10-4 175ns ± 1% 175ns ± 1% +0.38% (p=0.000 n=49+47) NatSqr/20-4 426ns ± 1% 429ns ± 1% +0.84% (p=0.000 n=46+48) NatSqr/30-4 702ns ± 2% 699ns ± 1% -0.38% (p=0.011 n=46+44) NatSqr/50-4 1.44µs ± 2% 1.43µs ± 1% -0.54% (p=0.010 n=48+48) NatSqr/80-4 2.85µs ± 1% 2.87µs ± 1% +0.68% (p=0.000 n=47+47) NatSqr/100-4 4.06µs ± 1% 4.07µs ± 1% +0.29% (p=0.000 n=46+45) NatSqr/200-4 13.4µs ± 1% 13.5µs ± 1% +0.73% (p=0.000 n=48+48) NatSqr/300-4 28.5µs ± 1% 28.2µs ± 1% -1.22% (p=0.000 n=46+48) NatSqr/500-4 81.9µs ± 1% 67.0µs ± 1% -18.25% (p=0.000 n=48+48) NatSqr/800-4 161µs ± 1% 140µs ± 1% -13.29% (p=0.000 n=47+48) NatSqr/1000-4 245µs ± 1% 207µs ± 1% -15.17% (p=0.000 n=49+49) go test -v -calibrate --run TestCalibrate ... Calibrating threshold between basicSqr(x) and karatsubaSqr(x) Looking for a timing difference for x between 200 - 500 words by 10 step words = 200 deltaT = -980ns ( -7%) is karatsubaSqr(x) better: false words = 210 deltaT = -773ns ( -5%) is karatsubaSqr(x) better: false words = 220 deltaT = -695ns ( -4%) is karatsubaSqr(x) better: false words = 230 deltaT = -570ns ( -3%) is karatsubaSqr(x) better: false words = 240 deltaT = -458ns ( -2%) is karatsubaSqr(x) better: false words = 250 deltaT = -63ns ( 0%) is karatsubaSqr(x) better: false words = 260 deltaT = 118ns ( 0%) is karatsubaSqr(x) better: true threshold found words = 270 deltaT = 377ns ( 1%) is karatsubaSqr(x) better: true words = 280 deltaT = 765ns ( 3%) is karatsubaSqr(x) better: true words = 290 deltaT = 673ns ( 2%) is karatsubaSqr(x) better: true words = 300 deltaT = 502ns ( 1%) is karatsubaSqr(x) better: true words = 310 deltaT = 629ns ( 2%) is karatsubaSqr(x) better: true words = 320 deltaT = 1.011µs ( 3%) is karatsubaSqr(x) better: true words = 330 deltaT = 1.36µs ( 4%) is karatsubaSqr(x) better: true words = 340 deltaT = 3.001µs ( 8%) is karatsubaSqr(x) better: true words = 350 deltaT = 3.178µs ( 8%) is karatsubaSqr(x) better: true ... Change-Id: I6f13c23d94d042539ac28e77fd2618cdc37a429e Reviewed-on: https://go-review.googlesource.com/105075 Run-TryBot: Robert Griesemer <gri@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Robert Griesemer <gri@golang.org>
This commit is contained in:
parent
723f4286b9
commit
1dc20e9124
|
|
@ -28,24 +28,32 @@ import (
|
|||
|
||||
var calibrate = flag.Bool("calibrate", false, "run calibration test")
|
||||
|
||||
func TestCalibrate(t *testing.T) {
|
||||
if *calibrate {
|
||||
computeKaratsubaThresholds()
|
||||
const (
|
||||
sqrModeMul = "mul(x, x)"
|
||||
sqrModeBasic = "basicSqr(x)"
|
||||
sqrModeKaratsuba = "karatsubaSqr(x)"
|
||||
)
|
||||
|
||||
// compute basicSqrThreshold where overhead becomes negligible
|
||||
minSqr := computeSqrThreshold(10, 30, 1, 3)
|
||||
// compute karatsubaSqrThreshold where karatsuba is faster
|
||||
maxSqr := computeSqrThreshold(300, 500, 10, 3)
|
||||
if minSqr != 0 {
|
||||
fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
|
||||
} else {
|
||||
fmt.Println("no basicSqrThreshold found")
|
||||
}
|
||||
if maxSqr != 0 {
|
||||
fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
|
||||
} else {
|
||||
fmt.Println("no karatsubaSqrThreshold found")
|
||||
}
|
||||
func TestCalibrate(t *testing.T) {
|
||||
if !*calibrate {
|
||||
return
|
||||
}
|
||||
|
||||
computeKaratsubaThresholds()
|
||||
|
||||
// compute basicSqrThreshold where overhead becomes negligible
|
||||
minSqr := computeSqrThreshold(10, 30, 1, 3, sqrModeMul, sqrModeBasic)
|
||||
// compute karatsubaSqrThreshold where karatsuba is faster
|
||||
maxSqr := computeSqrThreshold(200, 500, 10, 3, sqrModeBasic, sqrModeKaratsuba)
|
||||
if minSqr != 0 {
|
||||
fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
|
||||
} else {
|
||||
fmt.Println("no basicSqrThreshold found")
|
||||
}
|
||||
if maxSqr != 0 {
|
||||
fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
|
||||
} else {
|
||||
fmt.Println("no karatsubaSqrThreshold found")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -109,16 +117,17 @@ func computeKaratsubaThresholds() {
|
|||
}
|
||||
}
|
||||
|
||||
func measureBasicSqr(words, nruns int, enable bool) time.Duration {
|
||||
func measureSqr(words, nruns int, mode string) time.Duration {
|
||||
// more runs for better statistics
|
||||
initBasicSqr, initKaratsubaSqr := basicSqrThreshold, karatsubaSqrThreshold
|
||||
|
||||
if enable {
|
||||
// set thresholds to use basicSqr at this number of words
|
||||
switch mode {
|
||||
case sqrModeMul:
|
||||
basicSqrThreshold = words + 1
|
||||
case sqrModeBasic:
|
||||
basicSqrThreshold, karatsubaSqrThreshold = words-1, words+1
|
||||
} else {
|
||||
// set thresholds to disable basicSqr for any number of words
|
||||
basicSqrThreshold, karatsubaSqrThreshold = -1, -1
|
||||
case sqrModeKaratsuba:
|
||||
karatsubaSqrThreshold = words - 1
|
||||
}
|
||||
|
||||
var testval int64
|
||||
|
|
@ -133,18 +142,18 @@ func measureBasicSqr(words, nruns int, enable bool) time.Duration {
|
|||
return time.Duration(testval)
|
||||
}
|
||||
|
||||
func computeSqrThreshold(from, to, step, nruns int) int {
|
||||
fmt.Println("Calibrating thresholds for basicSqr via benchmarks of z.mul(x,x)")
|
||||
func computeSqrThreshold(from, to, step, nruns int, lower, upper string) int {
|
||||
fmt.Printf("Calibrating threshold between %s and %s\n", lower, upper)
|
||||
fmt.Printf("Looking for a timing difference for x between %d - %d words by %d step\n", from, to, step)
|
||||
var initPos bool
|
||||
var threshold int
|
||||
for i := from; i <= to; i += step {
|
||||
baseline := measureBasicSqr(i, nruns, false)
|
||||
testval := measureBasicSqr(i, nruns, true)
|
||||
baseline := measureSqr(i, nruns, lower)
|
||||
testval := measureSqr(i, nruns, upper)
|
||||
pos := baseline > testval
|
||||
delta := baseline - testval
|
||||
percent := delta * 100 / baseline
|
||||
fmt.Printf("words = %3d deltaT = %10s (%4d%%) is basicSqr better: %v", i, delta, percent, pos)
|
||||
fmt.Printf("words = %3d deltaT = %10s (%4d%%) is %s better: %v", i, delta, percent, upper, pos)
|
||||
if i == from {
|
||||
initPos = pos
|
||||
}
|
||||
|
|
|
|||
|
|
@ -388,12 +388,12 @@ func max(x, y int) int {
|
|||
}
|
||||
|
||||
// karatsubaLen computes an approximation to the maximum k <= n such that
|
||||
// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
|
||||
// k = p<<i for a number p <= threshold and an i >= 0. Thus, the
|
||||
// result is the largest number that can be divided repeatedly by 2 before
|
||||
// becoming about the value of karatsubaThreshold.
|
||||
func karatsubaLen(n int) int {
|
||||
// becoming about the value of threshold.
|
||||
func karatsubaLen(n, threshold int) int {
|
||||
i := uint(0)
|
||||
for n > karatsubaThreshold {
|
||||
for n > threshold {
|
||||
n >>= 1
|
||||
i++
|
||||
}
|
||||
|
|
@ -433,7 +433,7 @@ func (z nat) mul(x, y nat) nat {
|
|||
// y = yh*b + y0 (0 <= y0 < b)
|
||||
// b = 1<<(_W*k) ("base" of digits xi, yi)
|
||||
//
|
||||
k := karatsubaLen(n)
|
||||
k := karatsubaLen(n, karatsubaThreshold)
|
||||
// k <= n
|
||||
|
||||
// multiply x0 and y0 via Karatsuba
|
||||
|
|
@ -486,8 +486,8 @@ func (z nat) mul(x, y nat) nat {
|
|||
|
||||
// basicSqr sets z = x*x and is asymptotically faster than basicMul
|
||||
// by about a factor of 2, but slower for small arguments due to overhead.
|
||||
// Requirements: len(x) > 0, len(z) >= 2*len(x)
|
||||
// The (non-normalized) result is placed in z[0 : 2 * len(x)].
|
||||
// Requirements: len(x) > 0, len(z) == 2*len(x)
|
||||
// The (non-normalized) result is placed in z.
|
||||
func basicSqr(z, x nat) {
|
||||
n := len(x)
|
||||
t := make(nat, 2*n) // temporary variable to hold the products
|
||||
|
|
@ -503,11 +503,48 @@ func basicSqr(z, x nat) {
|
|||
addVV(z, z, t) // combine the result
|
||||
}
|
||||
|
||||
// karatsubaSqr squares x and leaves the result in z.
|
||||
// len(x) must be a power of 2 and len(z) >= 6*len(x).
|
||||
// The (non-normalized) result is placed in z[0 : 2*len(x)].
|
||||
//
|
||||
// The algorithm and the layout of z are the same as for karatsuba.
|
||||
func karatsubaSqr(z, x nat) {
|
||||
n := len(x)
|
||||
|
||||
if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 {
|
||||
z = z[:2*n]
|
||||
basicSqr(z, x)
|
||||
return
|
||||
}
|
||||
|
||||
n2 := n >> 1
|
||||
x1, x0 := x[n2:], x[0:n2]
|
||||
|
||||
karatsubaSqr(z, x0)
|
||||
karatsubaSqr(z[n:], x1)
|
||||
|
||||
// s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0
|
||||
xd := z[2*n : 2*n+n2]
|
||||
if subVV(xd, x1, x0) != 0 {
|
||||
subVV(xd, x0, x1)
|
||||
}
|
||||
|
||||
p := z[n*3:]
|
||||
karatsubaSqr(p, xd)
|
||||
|
||||
r := z[n*4:]
|
||||
copy(r, z[:n*2])
|
||||
|
||||
karatsubaAdd(z[n2:], r, n)
|
||||
karatsubaAdd(z[n2:], r[n:], n)
|
||||
karatsubaSub(z[n2:], p, n) // s == -1 for p != 0; s == 1 for p == 0
|
||||
}
|
||||
|
||||
// Operands that are shorter than basicSqrThreshold are squared using
|
||||
// "grade school" multiplication; for operands longer than karatsubaSqrThreshold
|
||||
// the Karatsuba algorithm is used.
|
||||
// we use the Karatsuba algorithm optimized for x == y.
|
||||
var basicSqrThreshold = 20 // computed by calibrate_test.go
|
||||
var karatsubaSqrThreshold = 400 // computed by calibrate_test.go
|
||||
var karatsubaSqrThreshold = 260 // computed by calibrate_test.go
|
||||
|
||||
// z = x*x
|
||||
func (z nat) sqr(x nat) nat {
|
||||
|
|
@ -536,7 +573,31 @@ func (z nat) sqr(x nat) nat {
|
|||
return z.norm()
|
||||
}
|
||||
|
||||
return z.mul(x, x)
|
||||
// Use Karatsuba multiplication optimized for x == y.
|
||||
// The algorithm and layout of z are the same as for mul.
|
||||
|
||||
// z = (x1*b + x0)^2 = x1^2*b^2 + 2*x1*x0*b + x0^2
|
||||
|
||||
k := karatsubaLen(n, karatsubaSqrThreshold)
|
||||
|
||||
x0 := x[0:k]
|
||||
z = z.make(max(6*k, 2*n))
|
||||
karatsubaSqr(z, x0) // z = x0^2
|
||||
z = z[0 : 2*n]
|
||||
z[2*k:].clear()
|
||||
|
||||
if k < n {
|
||||
var t nat
|
||||
x0 := x0.norm()
|
||||
x1 := x[k:]
|
||||
t = t.mul(x0, x1)
|
||||
addAt(z, t, k)
|
||||
addAt(z, t, k) // z = 2*x1*x0*b + x0^2
|
||||
t = t.sqr(x1)
|
||||
addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2
|
||||
}
|
||||
|
||||
return z.norm()
|
||||
}
|
||||
|
||||
// mulRange computes the product of all the unsigned integers in the
|
||||
|
|
|
|||
|
|
@ -648,26 +648,26 @@ func TestSticky(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func testBasicSqr(t *testing.T, x nat) {
|
||||
func testSqr(t *testing.T, x nat) {
|
||||
got := make(nat, 2*len(x))
|
||||
want := make(nat, 2*len(x))
|
||||
basicSqr(got, x)
|
||||
basicMul(want, x, x)
|
||||
got = got.sqr(x)
|
||||
want = want.mul(x, x)
|
||||
if got.cmp(want) != 0 {
|
||||
t.Errorf("basicSqr(%v), got %v, want %v", x, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicSqr(t *testing.T) {
|
||||
func TestSqr(t *testing.T) {
|
||||
for _, a := range prodNN {
|
||||
if a.x != nil {
|
||||
testBasicSqr(t, a.x)
|
||||
testSqr(t, a.x)
|
||||
}
|
||||
if a.y != nil {
|
||||
testBasicSqr(t, a.y)
|
||||
testSqr(t, a.y)
|
||||
}
|
||||
if a.z != nil {
|
||||
testBasicSqr(t, a.z)
|
||||
testSqr(t, a.z)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue