bigmul

big multiplication in C
git clone git://git.rr3.xyz/bigmul

commit be1311cb7ebfa8ce641861617a4f0c5f739badd5
parent 0714d05592f0527d3572fb31b1b345e89e1facfc
Author: Robert Russell <robert@rr3.xyz>
Date:   Wed,  1 Jan 2025 21:53:15 -0800

Clean up wide u64 math

This is a performance regression.

Diffstat:
M bigmul.c | 153 ++++++++++++++++++++++++++++++++-----------------------------------------------
1 file changed, 62 insertions(+), 91 deletions(-)

diff --git a/bigmul.c b/bigmul.c
@@ -13,75 +13,67 @@ struct nat {
 
 typedef struct nat Nat[1];
 
-// TODO: Make sure that GCC inlines and compiles {add,sub}64 well.
-// Try __builtin_{add,sub}c{,l,ll}.
-inline void
-add64(u64 *co, u64 *r, u64 x, u64 y, u64 ci) {
-	u128 cr = (u128)x + (u128)y + (u128)ci;
-	*co = cr >> 64;
-	*r = cr;
-}
+/* ----- Wide u64 math ----- */
 
-inline void
-sub64(u64 *bo, u64 *r, u64 x, u64 y, u64 bi) {
-	u128 br = (u128)x - (u128)y - (u128)bi;
-	*bo = -(br >> 64);
-	*r = br;
-}
+// TODO: Wide sub doesn't make much sense.
+#define WIDE(rh, rl, x, op, y) do { \
+		u128 r = (u128)(x) op (u128)(y); \
+		*(rh) = r >> 64; \
+		*(rl) = r; \
+	} while (0)
+inline void add64(u64 *rh, u64 *rl, u64 x, u64 y) { WIDE(rh, rl, x, +, y); }
+inline void sub64(u64 *rh, u64 *rl, u64 x, u64 y) { WIDE(rh, rl, x, -, y); }
+inline void mul64(u64 *rh, u64 *rl, u64 x, u64 y) { WIDE(rh, rl, x, *, y); }
+#undef WIDE
 
 inline void
-mul64(u64 *rh, u64 *rl, u64 x, u64 y) {
-	u128 r = (u128)x * (u128)y;
-	*rh = r >> 64;
-	*rl = r;
+fmaa64(u64 *rh, u64 *rl, u64 w, u64 x, u64 y, u64 z) {
+	u64 h0, h1, h2, l;
+	mul64(&h0, &l, w, x); // h0:l = w * x
+	add64(&h1, &l, l, y); // h1:l = l + y
+	add64(&h2, &l, l, z); // h2:l = l + z
+	*rh = h0 + h1 + h2;
+	*rl = l;
 }
 
-/*
-inline void
-fma64(u64 *rh, u64 *rl, u64 x, u64 y, u64 z) {
-	u64 h, l, c;
-	mul64(&h, &l, x, y);
-	add64(&c, rl, l, z, 0);
-	*rh = h + c;
-}
-*/
-inline void
-fmaa64(u64 *rh, u64 *rl, u64 w, u64 x, u64 y, u64 z) {
-	u64 h, l, c, d;
-	mul64(&h, &l, w, x);    // [h, l] = w * x
-	add64(&c, &l, l, y, 0); // [c, l] = l + y
-	add64(&d, rl, l, z, 0); // [d, rl] = l + z
-	*rh = h + c + d;        // rh = h + c + d
-}
+/* ----- Big nat math ----- */
 
 // Precondition: m >= n
 u64
-add(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 c) {
-	for (usize i = 0; i < n; i++)
-		add64(&c, &r[i], x[i], y[i], c);
+add(u64 *r, u64 *x, usize m, u64 *y, usize n) {
+	u64 c = 0;
+	for (usize i = 0; i < n; i++) {
+		u64 h0, h1, l;
+		add64(&h0, &l, x[i], y[i]); // h0:l = x[i] + y[i]
+		add64(&h1, &l, l, c);       // h1:l = l + c
+		c = h0 + h1;
+		r[i] = l;
+	}
 	for (usize i = n; i < m; i++)
-		add64(&c, &r[i], x[i], c, 0);
+		add64(&c, &r[i], x[i], c); // c:r[i] = x[i] + c
 	return c;
 }
 
-u64
-addw(u64 *r, u64 *x, usize m, u64 y) {
-	for (usize i = 0; i < m; i++)
-		add64(&y, &r[i], x[i], y, 0);
-	return y;
-}
-
 // Precondition: m >= n
 // TODO: sub is not commutative like add, so we need a "bus" operation
 // ("sub" backwards) for when m < n.
 u64
-sub(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 b) {
-	for (usize i = 0; i < n; i++)
-		sub64(&b, &r[i], x[i], y[i], b);
-	for (usize i = n; i < m; i++)
-		sub64(&b, &r[i], x[i], b, 0);
+sub(u64 *r, u64 *x, usize m, u64 *y, usize n) {
+	u64 b = 0;
+	for (usize i = 0; i < n; i++) {
+		u64 h0, h1, l;
+		sub64(&h0, &l, x[i], y[i]); // h0:l = x[i] - y[i]
+		sub64(&h1, &l, l, b);       // h1:l = l - b
+		b = -h0 + -h1;
+		r[i] = l;
+	}
+	for (usize i = n; i < m; i++) {
+		u64 h;
+		sub64(&h, &r[i], x[i], b);
+		b = -h;
+	}
 	return b;
 }
 
@@ -171,14 +163,14 @@ karatsuba(u64 *r, u64 *x, u64 *y, usize n, u64 *scratch0) {
 	u64 *scratch1 = scratch0 + kk;
 
 	// 4. Arithmetic
-	p[k] = add(p, xh, h, xl, l, 0);      // p = xh + xl
-	q[k] = add(q, yh, h, yl, l, 0);      // q = yh + yl
-	karatsuba(t, p, q, k, scratch1);     // t = p * q
-	karatsuba(u, xl, yl, l, scratch1);   // u = xl * yl
-	karatsuba(s, xh, yh, h, scratch1);   // s = xh * yh
-	sub(t, t, kk, u, ll, 0);             // t -= u (borrow out must be 0) TODO: explain
-	sub(t, t, kk, s, hh, 0);             // t -= s (borrow out must be 0) TODO: explain
-	add(r + l, r + l, hh + l, t, kk, 0); // r[l..] += t (carry out must be 0) TODO: explain
+	p[k] = add(p, xh, h, xl, l);       // p = xh + xl
+	q[k] = add(q, yh, h, yl, l);       // q = yh + yl
+	karatsuba(t, p, q, k, scratch1);   // t = p * q
+	karatsuba(u, xl, yl, l, scratch1); // u = xl * yl
+	karatsuba(s, xh, yh, h, scratch1); // s = xh * yh
+	sub(t, t, kk, u, ll);              // t -= u (borrow out must be 0) TODO: explain
+	sub(t, t, kk, s, hh);              // t -= s (borrow out must be 0) TODO: explain
+	add(r + l, r + l, hh + l, t, kk);  // r[l..] += t (carry out must be 0) TODO: explain
 }
 
 void
@@ -209,12 +201,12 @@ mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) {
 	for (;;) {
 		for (; m >= n; r += n, x += n, m -= n) {
 			karatsuba(prod, x, y, n, scratch);
-			add(r, prod, 2 * n, r, n, 0);
+			add(r, prod, 2 * n, r, n);
 		}
 		if (m == 0) break;
 		if (m < KARATSUBA_THRESH) { // TODO: Calibrate.
 			mul_quadratic(prod, x, m, y, n);
-			add(r, prod, m + n, r, n, 0);
+			add(r, prod, m + n, r, n);
 			break;
 		}
 		u64 *t0 = x; x = y; y = t0;
@@ -276,15 +268,15 @@ main(void) {
 	for (usize i = 0; i < LEN(x); i++) x[i] = r_prand64();
 	for (usize i = 0; i < LEN(y); i++) y[i] = r_prand64();
 
-	r_bench(bench_quadratic16, 3000);
-	r_bench(bench_quadratic32, 3000);
-	r_bench(bench_quadratic64, 3000);
-	r_bench(bench_quadratic128, 3000);
-	r_bench(bench_quadratic256, 3000);
-	r_bench(bench_quadratic512, 3000);
-	r_bench(bench_quadratic1024, 3000);
-	r_bench(bench_quadratic2048, 3000);
-	r_bench(bench_quadratic4096, 3000);
+	// r_bench(bench_quadratic16, 3000);
+	// r_bench(bench_quadratic32, 3000);
+	// r_bench(bench_quadratic64, 3000);
+	// r_bench(bench_quadratic128, 3000);
+	// r_bench(bench_quadratic256, 3000);
+	// r_bench(bench_quadratic512, 3000);
+	// r_bench(bench_quadratic1024, 3000);
+	// r_bench(bench_quadratic2048, 3000);
+	// r_bench(bench_quadratic4096, 3000);
 
 	r_bench(bench_karatsuba16, 3000);
 	r_bench(bench_karatsuba32, 3000);
@@ -296,24 +288,3 @@ main(void) {
 	r_bench(bench_karatsuba2048, 3000);
 	r_bench(bench_karatsuba4096, 3000);
 }
-
-/*
-benchmark: bench_quadratic16    8174101 iters       417 ns/op
-benchmark: bench_quadratic32    2135396 iters      1673 ns/op
-benchmark: bench_quadratic64     519625 iters      6898 ns/op
-benchmark: bench_quadratic128    136077 iters     26540 ns/op
-benchmark: bench_quadratic256     34462 iters    104515 ns/op
-benchmark: bench_quadratic512      8522 iters    412899 ns/op
-benchmark: bench_quadratic1024     2196 iters   1645027 ns/op
-benchmark: bench_quadratic2048      540 iters   6601480 ns/op
-benchmark: bench_quadratic4096      136 iters  26184686 ns/op
-benchmark: bench_karatsuba16    8448774 iters       426 ns/op
-benchmark: bench_karatsuba32    2520357 iters      1430 ns/op
-benchmark: bench_karatsuba64     769735 iters      4695 ns/op
-benchmark: bench_karatsuba128    242052 iters     14550 ns/op
-benchmark: bench_karatsuba256     79255 iters     45003 ns/op
-benchmark: bench_karatsuba512     26338 iters    136375 ns/op
-benchmark: bench_karatsuba1024     8701 iters    412653 ns/op
-benchmark: bench_karatsuba2048     2900 iters   1245014 ns/op
-benchmark: bench_karatsuba4096      955 iters   3748654 ns/op
-*/
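
A note on the "TODO: explain" items in karatsuba(): with p = xh + xl and q = yh + yl, the product t = p * q expands to xh*yh + xh*yl + xl*yh + xl*yl. Subtracting u = xl*yl and then s = xh*yh leaves t = xh*yl + xl*yh, and every intermediate value along the way stays nonnegative, so neither sub() can underflow and both borrow outs are 0. The usual argument for the final add() is similar: once that middle term is added in at limb offset l, r holds the complete product x*y, which fits in r's 2n limbs, so no carry can leave the window.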
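
As for the performance note in the commit message: the removed add64/sub64 folded the incoming carry or borrow into the same u128 operation that produced the outgoing one, so the limb loops in add() and sub() cost one wide operation per limb. The carry-less replacements force those loops to chain two wide operations and combine the high halves, as the rewritten add() above does, which is a plausible source of the slowdown. Below is a minimal standalone sketch comparing the two styles; the names add64_carry and add64_plain are illustrative only (they are not in bigmul.c), and it assumes a compiler with unsigned __int128 (GCC or Clang).

#include <stdint.h>
#include <stdio.h>

typedef uint64_t u64;
typedef unsigned __int128 u128;

/* Carry-in style (cf. the removed add64): one wide addition per limb. */
static inline void
add64_carry(u64 *co, u64 *r, u64 x, u64 y, u64 ci) {
	u128 cr = (u128)x + (u128)y + (u128)ci;
	*co = cr >> 64;
	*r = cr;
}

/* Carry-less style (cf. the new add64): the caller chains two of these
 * per limb and adds the two high halves to recover the carry. */
static inline void
add64_plain(u64 *rh, u64 *rl, u64 x, u64 y) {
	u128 r = (u128)x + (u128)y;
	*rh = r >> 64;
	*rl = r;
}

int
main(void) {
	u64 x = ~(u64)0, y = 0, c = 1; /* carry comes entirely from the carry-in */

	u64 c1, r1;
	add64_carry(&c1, &r1, x, y, c); /* one wide op per limb */

	u64 h0, h1, l;
	add64_plain(&h0, &l, x, y);     /* h0:l = x + y */
	add64_plain(&h1, &l, l, c);     /* h1:l = l + c */
	u64 c2 = h0 + h1, r2 = l;       /* two wide ops plus a fixup */

	printf("same limb: %d, same carry: %d\n", r1 == r2, c1 == c2);
	return 0;
}

The __builtin_{add,sub}c{,l,ll} builtins mentioned in the removed TODO (available in Clang, and in newer GCC) are another way to keep a one-operation-per-limb carry chain without going through u128.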