commit 0a2da1e9f35a1bda6ac0514223247a721018e1d5
parent a00d44757c3d3d113a44359dc5520132067c5a41
Author: Robert Russell <robert@rr3.xyz>
Date: Thu, 2 Jan 2025 16:50:02 -0800
Clean up karatsuba
Diffstat:
| M | bigmul.c | | | 66 | ++++++++++++++++++++++++++++++++++++------------------------------ |
1 file changed, 36 insertions(+), 30 deletions(-)
diff --git a/bigmul.c b/bigmul.c
@@ -115,7 +115,7 @@ mul_quadratic(u64 *r, u64 *x, usize m, u64 *y, usize n) {
// Precondition: r does not intersect x nor y
// TODO: Document precondition regarding size of scratch memory.
void
-karatsuba(u64 *r, u64 *x, u64 *y, usize n, u64 *scratch0) {
+karatsuba(u64 *r, u64 *x, u64 *y, usize n, u64 *tt) {
/* TODO: Update
* We seek to multiply x and y, which have m and n "words" (digits of
* base b := 2^64), respectively. For this, we let k := ceil(max(m, n) / 2)
@@ -180,20 +180,44 @@ karatsuba(u64 *r, u64 *x, u64 *y, usize n, u64 *scratch0) {
// TODO: Justify all these lengths
u64 *p = r;
u64 *q = r + k;
- u64 *t = scratch0;
+ u64 *t0 = tt;
u64 *u = r;
u64 *s = r + ll;
- u64 *scratch1 = scratch0 + kk;
+ u64 *t1 = tt + kk;
// 4. Arithmetic
p[k] = add(p, xh, h, xl, l); // p = xh + xl
q[k] = add(q, yh, h, yl, l); // q = yh + yl
- karatsuba(t, p, q, k, scratch1); // t = p * q
- karatsuba(u, xl, yl, l, scratch1); // u = xl * yl
- karatsuba(s, xh, yh, h, scratch1); // s = xh * yh
- sub(t, t, kk, u, ll); // t -= u (borrow out must be 0) TODO: explain
- sub(t, t, kk, s, hh); // t -= s (borrow out must be 0) TODO: explain
- add(r + l, r + l, hh + l, t, kk); // r[l..] += t (carry out must be 0) TODO: explain
+ karatsuba(t0, p, q, k, t1); // t0 = p * q
+ karatsuba(u, xl, yl, l, t1); // u = xl * yl
+ karatsuba(s, xh, yh, h, t1); // s = xh * yh
+ sub(t0, t0, kk, u, ll); // t0 -= u (borrow out must be 0) TODO: explain
+ sub(t0, t0, kk, s, hh); // t0 -= s (borrow out must be 0) TODO: explain
+ add(r + l, r + l, hh + l, t0, kk); // r[l..] += t0 (carry out must be 0) TODO: explain
+}
+
+void
+fma_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) {
+ u64 *prod = scratch;
+ u64 *tt = scratch + 2 * n;
+
+ for (;;) {
+ for (; m >= n; r += n, x += n, m -= n) {
+ karatsuba(prod, x, y, n, tt);
+ add(r, prod, 2 * n, r, n);
+ }
+
+ if (m == 0) break; // TODO: Remove this if we special case m == n.
+
+ if (m < KARATSUBA_THRESH) { // TODO: Calibrate.
+ fma_quadratic(prod, x, m, y, n);
+ add(r, prod, m + n, r, n);
+ break;
+ }
+
+ u64 *t0 = x; x = y; y = t0; // Swap x and y
+ usize t1 = m; m = n; n = t1; // Swap m and n
+ }
}
void
@@ -213,30 +237,12 @@ mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) {
// TODO: Accept capacity of r as an argument, and use excess memory for
// scratch, if it's big enough, instead of allocating.
usize firstkk = 2 * (n - n / 2 + 1);
- u64 *mem = r_eallocn(2 * n + 2 * firstkk, sizeof *mem);
- u64 *prod = mem;
- u64 *scratch = mem + 2 * n;
+ u64 *scratch = r_eallocn(2 * n + 2 * firstkk, sizeof *scratch);
- // TODO: The control flow here kinda sucks.
- // TODO: There are unnecessary copies between prod and r. Try to do it in
- // "one pass", without first initializing r to 0.
memset(r, 0, (m + n) * sizeof *r);
- for (;;) {
- for (; m >= n; r += n, x += n, m -= n) {
- karatsuba(prod, x, y, n, scratch);
- add(r, prod, 2 * n, r, n);
- }
- if (m == 0) break;
- if (m < KARATSUBA_THRESH) { // TODO: Calibrate.
- fma_quadratic(prod, x, m, y, n);
- add(r, prod, m + n, r, n);
- break;
- }
- u64 *t0 = x; x = y; y = t0;
- usize t1 = m; m = n; n = t1;
- }
+ fma_karatsuba(r, x, m, y, n, scratch);
- free(mem);
+ free(scratch);
}
u64 x[4096];