commit a58e698cfc79a9511cf576aee2288e07b6c54e7c
parent 7c154e5299815ed6d84c2fa17e52826228cfdd37
Author: Robert Russell <robert@rr3.xyz>
Date: Wed, 1 Jan 2025 14:13:31 -0800
Fix some karatsuba bugs and improve quadratic threshold
Diffstat:
M bigmul.c | 47 +++++++++++++++++++++++++++++++++++------------
1 file changed, 35 insertions(+), 12 deletions(-)
diff --git a/bigmul.c b/bigmul.c
@@ -132,11 +132,11 @@ karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) {
* and as long as m >= 4, this quantity is strictly less than m. Since we
* always have |q| = |yl + yh| <= |y| = n, this ensures the termination of
* the evaluation of p * q. Similarly, if n >= m (instead of m >= n) and
- * n >= 4, then |q| < n and |p| <= m. It therefore suffices to separate the
- * case m < 4 || n < 4 for the recursion basis. */
+ * n >= 4 (instead of m >= 4), then |q| < n and |p| <= m. It therefore
+ * suffices to separate the case m < 4 || n < 4 for the recursion basis. */
usize maxmn = MAX(m, n);
- if (m < 4 || n < 4 || maxmn < 256) { // 256 was determined via benchmarking.
+ if (m < 4 || n < 4 || maxmn < 32) { // 32 was determined via benchmarks.
mul_quadratic(r, x, m, y, n);
return;
}
@@ -160,17 +160,19 @@ karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) {
usize qw = MIN(nl + 1, n); u64 *q = r + pw;
usize tw = pw + qw; u64 *t = scratch;
usize uw = ml + nl; u64 *u = r;
- usize sw = mh + mh; u64 *s = r + 2 * k;
+ usize sw = mh + nh; u64 *s = r + 2 * k;
// 4. Arithmetic
- add(p, xl, ml, xh, mh); // p = xl + xh
- add(q, yl, nl, yh, nh); // q = yl + yh
- karatsuba(t, p, pw, q, qw, scratch + tw); // t = p * q
- karatsuba(u, xl, ml, yl, nl, scratch + tw); // u = xl * yl
- karatsuba(s, xh, mh, yh, nh, scratch + tw); // s = xh * yh
- sub(t, t, tw, u, uw); // t -= u
- sub(t, t, tw, s, sw); // t -= s
- add(r + k, r + k, sw, t, tw); // r[k..] += t
+ add(p, xl, ml, xh, mh); // p = xl + xh
+ add(q, yl, nl, yh, nh); // q = yl + yh
+ karatsuba(t, p, pw, q, qw, scratch + tw); // t = p * q
+ karatsuba(u, xl, ml, yl, nl, scratch + tw); // u = xl * yl
+ for (usize i = uw; i < 2 * k; i++) r[i] = 0; // r[uw..2*k] = 0
+ karatsuba(s, xh, mh, yh, nh, scratch + tw); // s = xh * yh
+ sub(t, t, tw, u, uw); // t -= u
+ sub(t, t, tw, s, sw); // t -= s
+ add(r + k, t, tw, r + k, k + sw); // r[k..] += t
+ for (usize i = tw + k + sw; i < m + n; i++) r[i] = 0; // TODO
// TODO: Prove that tw + sw <= m + n - k.
}
@@ -253,3 +255,24 @@ main(void) {
r_bench(bench_karatsuba2048, 3000);
r_bench(bench_karatsuba4096, 3000);
}
+
+/*
+benchmark: bench_quadratic16 8174101 iters 417 ns/op
+benchmark: bench_quadratic32 2135396 iters 1673 ns/op
+benchmark: bench_quadratic64 519625 iters 6898 ns/op
+benchmark: bench_quadratic128 136077 iters 26540 ns/op
+benchmark: bench_quadratic256 34462 iters 104515 ns/op
+benchmark: bench_quadratic512 8522 iters 412899 ns/op
+benchmark: bench_quadratic1024 2196 iters 1645027 ns/op
+benchmark: bench_quadratic2048 540 iters 6601480 ns/op
+benchmark: bench_quadratic4096 136 iters 26184686 ns/op
+benchmark: bench_karatsuba16 8448774 iters 426 ns/op
+benchmark: bench_karatsuba32 2520357 iters 1430 ns/op
+benchmark: bench_karatsuba64 769735 iters 4695 ns/op
+benchmark: bench_karatsuba128 242052 iters 14550 ns/op
+benchmark: bench_karatsuba256 79255 iters 45003 ns/op
+benchmark: bench_karatsuba512 26338 iters 136375 ns/op
+benchmark: bench_karatsuba1024 8701 iters 412653 ns/op
+benchmark: bench_karatsuba2048 2900 iters 1245014 ns/op
+benchmark: bench_karatsuba4096 955 iters 3748654 ns/op
+*/