bigmul

big multiplication in C
git clone git://git.rr3.xyz/bigmul
Log | Files | Refs | README | LICENSE

commit 0a2da1e9f35a1bda6ac0514223247a721018e1d5
parent a00d44757c3d3d113a44359dc5520132067c5a41
Author: Robert Russell <robert@rr3.xyz>
Date:   Thu,  2 Jan 2025 16:50:02 -0800

Clean up karatsuba

Diffstat:
Mbigmul.c | 66++++++++++++++++++++++++++++++++++++------------------------------
1 file changed, 36 insertions(+), 30 deletions(-)

diff --git a/bigmul.c b/bigmul.c @@ -115,7 +115,7 @@ mul_quadratic(u64 *r, u64 *x, usize m, u64 *y, usize n) { // Precondition: r does not intersect x nor y // TODO: Document precondition regarding size of scratch memory. void -karatsuba(u64 *r, u64 *x, u64 *y, usize n, u64 *scratch0) { +karatsuba(u64 *r, u64 *x, u64 *y, usize n, u64 *tt) { /* TODO: Update * We seek to multiply x and y, which have m and n "words" (digits of * base b := 2^64), respectively. For this, we let k := ceil(max(m, n) / 2) @@ -180,20 +180,44 @@ karatsuba(u64 *r, u64 *x, u64 *y, usize n, u64 *scratch0) { // TODO: Justify all these lengths u64 *p = r; u64 *q = r + k; - u64 *t = scratch0; + u64 *t0 = tt; u64 *u = r; u64 *s = r + ll; - u64 *scratch1 = scratch0 + kk; + u64 *t1 = tt + kk; // 4. Arithmetic p[k] = add(p, xh, h, xl, l); // p = xh + xl q[k] = add(q, yh, h, yl, l); // q = yh + yl - karatsuba(t, p, q, k, scratch1); // t = p * q - karatsuba(u, xl, yl, l, scratch1); // u = xl * yl - karatsuba(s, xh, yh, h, scratch1); // s = xh * yh - sub(t, t, kk, u, ll); // t -= u (borrow out must be 0) TODO: explain - sub(t, t, kk, s, hh); // t -= s (borrow out must be 0) TODO: explain - add(r + l, r + l, hh + l, t, kk); // r[l..] += t (carry out must be 0) TODO: explain + karatsuba(t0, p, q, k, t1); // t0 = p * q + karatsuba(u, xl, yl, l, t1); // u = xl * yl + karatsuba(s, xh, yh, h, t1); // s = xh * yh + sub(t0, t0, kk, u, ll); // t0 -= u (borrow out must be 0) TODO: explain + sub(t0, t0, kk, s, hh); // t0 -= s (borrow out must be 0) TODO: explain + add(r + l, r + l, hh + l, t0, kk); // r[l..] += t0 (carry out must be 0) TODO: explain +} + +void +fma_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) { + u64 *prod = scratch; + u64 *tt = scratch + 2 * n; + + for (;;) { + for (; m >= n; r += n, x += n, m -= n) { + karatsuba(prod, x, y, n, tt); + add(r, prod, 2 * n, r, n); + } + + if (m == 0) break; // TODO: Remove this if we special case m == n. + + if (m < KARATSUBA_THRESH) { // TODO: Calibrate. + fma_quadratic(prod, x, m, y, n); + add(r, prod, m + n, r, n); + break; + } + + u64 *t0 = x; x = y; y = t0; // Swap x and y + usize t1 = m; m = n; n = t1; // Swap m and n + } } void @@ -213,30 +237,12 @@ mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) { // TODO: Accept capacity of r as an argument, and use excess memory for // scratch, if it's big enough, instead of allocating. usize firstkk = 2 * (n - n / 2 + 1); - u64 *mem = r_eallocn(2 * n + 2 * firstkk, sizeof *mem); - u64 *prod = mem; - u64 *scratch = mem + 2 * n; + u64 *scratch = r_eallocn(2 * n + 2 * firstkk, sizeof *scratch); - // TODO: The control flow here kinda sucks. - // TODO: There are unnecessary copies between prod and r. Try to do it in - // "one pass", without first initializing r to 0. memset(r, 0, (m + n) * sizeof *r); - for (;;) { - for (; m >= n; r += n, x += n, m -= n) { - karatsuba(prod, x, y, n, scratch); - add(r, prod, 2 * n, r, n); - } - if (m == 0) break; - if (m < KARATSUBA_THRESH) { // TODO: Calibrate. - fma_quadratic(prod, x, m, y, n); - add(r, prod, m + n, r, n); - break; - } - u64 *t0 = x; x = y; y = t0; - usize t1 = m; m = n; n = t1; - } + fma_karatsuba(r, x, m, y, n, scratch); - free(mem); + free(scratch); } u64 x[4096];