commit be1311cb7ebfa8ce641861617a4f0c5f739badd5
parent 0714d05592f0527d3572fb31b1b345e89e1facfc
Author: Robert Russell <robert@rr3.xyz>
Date: Wed, 1 Jan 2025 21:53:15 -0800
Clean up wide u64 math
This is a performance regression.
Diffstat:
| M | bigmul.c | | | 153 | ++++++++++++++++++++++++++++++++----------------------------------------------- |
1 file changed, 62 insertions(+), 91 deletions(-)
diff --git a/bigmul.c b/bigmul.c
@@ -13,75 +13,67 @@ struct nat {
typedef struct nat Nat[1];
-// TODO: Make sure that GCC inlines and compiles {add,sub}64 well.
-// Try __builtin_{add,sub}c{,l,ll}.
-inline void
-add64(u64 *co, u64 *r, u64 x, u64 y, u64 ci) {
- u128 cr = (u128)x + (u128)y + (u128)ci;
- *co = cr >> 64;
- *r = cr;
-}
+/* ----- Wide u64 math ----- */
-inline void
-sub64(u64 *bo, u64 *r, u64 x, u64 y, u64 bi) {
- u128 br = (u128)x - (u128)y - (u128)bi;
- *bo = -(br >> 64);
- *r = br;
-}
+// TODO: Wide sub doesn't make much sense.
+#define WIDE(rh, rl, x, op, y) do { \
+ u128 r = (u128)(x) op (u128)(y); \
+ *(rh) = r >> 64; \
+ *(rl) = r; \
+ } while (0)
+inline void add64(u64 *rh, u64 *rl, u64 x, u64 y) { WIDE(rh, rl, x, +, y); }
+inline void sub64(u64 *rh, u64 *rl, u64 x, u64 y) { WIDE(rh, rl, x, -, y); }
+inline void mul64(u64 *rh, u64 *rl, u64 x, u64 y) { WIDE(rh, rl, x, *, y); }
+#undef WIDE
inline void
-mul64(u64 *rh, u64 *rl, u64 x, u64 y) {
- u128 r = (u128)x * (u128)y;
- *rh = r >> 64;
- *rl = r;
+fmaa64(u64 *rh, u64 *rl, u64 w, u64 x, u64 y, u64 z) {
+ u64 h0, h1, h2, l;
+ mul64(&h0, &l, w, x); // h0:l = w * x
+ add64(&h1, &l, l, y); // h1:l = l + y
+ add64(&h2, &l, l, z); // h2:l = l + z
+ *rh = h0 + h1 + h2;
+ *rl = l;
}
-/*
-inline void
-fma64(u64 *rh, u64 *rl, u64 x, u64 y, u64 z) {
- u64 h, l, c;
- mul64(&h, &l, x, y);
- add64(&c, rl, l, z, 0);
- *rh = h + c;
-}
-*/
-inline void
-fmaa64(u64 *rh, u64 *rl, u64 w, u64 x, u64 y, u64 z) {
- u64 h, l, c, d;
- mul64(&h, &l, w, x); // [h, l] = w * x
- add64(&c, &l, l, y, 0); // [c, l] = l + y
- add64(&d, rl, l, z, 0); // [d, rl] = l + z
- *rh = h + c + d; // rh = h + c + d
-}
+/* ----- Big nat math ----- */
// Precondition: m >= n
u64
-add(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 c) {
- for (usize i = 0; i < n; i++)
- add64(&c, &r[i], x[i], y[i], c);
+add(u64 *r, u64 *x, usize m, u64 *y, usize n) {
+ u64 c = 0;
+ for (usize i = 0; i < n; i++) {
+ u64 h0, h1, l;
+ add64(&h0, &l, x[i], y[i]); // h0:l = x[i] + y[i]
+ add64(&h1, &l, l, c); // h1:l = l + c
+ c = h0 + h1;
+ r[i] = l;
+ }
for (usize i = n; i < m; i++)
- add64(&c, &r[i], x[i], c, 0);
+ add64(&c, &r[i], x[i], c); // c:r[i] = x[i] + c
return c;
}
-u64
-addw(u64 *r, u64 *x, usize m, u64 y) {
- for (usize i = 0; i < m; i++)
- add64(&y, &r[i], x[i], y, 0);
- return y;
-}
-
// Precondition: m >= n
// TODO: sub is not commutative like add, so we need a "bus" operation
// ("sub" backwards) for when m < n.
u64
-sub(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 b) {
- for (usize i = 0; i < n; i++)
- sub64(&b, &r[i], x[i], y[i], b);
- for (usize i = n; i < m; i++)
- sub64(&b, &r[i], x[i], b, 0);
+sub(u64 *r, u64 *x, usize m, u64 *y, usize n) {
+ u64 b = 0;
+ for (usize i = 0; i < n; i++) {
+ u64 h0, h1, l;
+ sub64(&h0, &l, x[i], y[i]); // h0:l = x[i] - y[i]
+ sub64(&h1, &l, l, b); // h1:l = l - b
+ b = -h0 + -h1;
+ r[i] = l;
+ }
+ for (usize i = n; i < m; i++) {
+ u64 h;
+ sub64(&h, &r[i], x[i], b);
+ b = -h;
+ }
return b;
}
@@ -171,14 +163,14 @@ karatsuba(u64 *r, u64 *x, u64 *y, usize n, u64 *scratch0) {
u64 *scratch1 = scratch0 + kk;
// 4. Arithmetic
- p[k] = add(p, xh, h, xl, l, 0); // p = xh + xl
- q[k] = add(q, yh, h, yl, l, 0); // q = yh + yl
- karatsuba(t, p, q, k, scratch1); // t = p * q
- karatsuba(u, xl, yl, l, scratch1); // u = xl * yl
- karatsuba(s, xh, yh, h, scratch1); // s = xh * yh
- sub(t, t, kk, u, ll, 0); // t -= u (borrow out must be 0) TODO: explain
- sub(t, t, kk, s, hh, 0); // t -= s (borrow out must be 0) TODO: explain
- add(r + l, r + l, hh + l, t, kk, 0); // r[l..] += t (carry out must be 0) TODO: explain
+ p[k] = add(p, xh, h, xl, l); // p = xh + xl
+ q[k] = add(q, yh, h, yl, l); // q = yh + yl
+ karatsuba(t, p, q, k, scratch1); // t = p * q
+ karatsuba(u, xl, yl, l, scratch1); // u = xl * yl
+ karatsuba(s, xh, yh, h, scratch1); // s = xh * yh
+ sub(t, t, kk, u, ll); // t -= u (borrow out must be 0) TODO: explain
+ sub(t, t, kk, s, hh); // t -= s (borrow out must be 0) TODO: explain
+ add(r + l, r + l, hh + l, t, kk); // r[l..] += t (carry out must be 0) TODO: explain
}
void
@@ -209,12 +201,12 @@ mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) {
for (;;) {
for (; m >= n; r += n, x += n, m -= n) {
karatsuba(prod, x, y, n, scratch);
- add(r, prod, 2 * n, r, n, 0);
+ add(r, prod, 2 * n, r, n);
}
if (m == 0) break;
if (m < KARATSUBA_THRESH) { // TODO: Calibrate.
mul_quadratic(prod, x, m, y, n);
- add(r, prod, m + n, r, n, 0);
+ add(r, prod, m + n, r, n);
break;
}
u64 *t0 = x; x = y; y = t0;
@@ -276,15 +268,15 @@ main(void) {
for (usize i = 0; i < LEN(x); i++) x[i] = r_prand64();
for (usize i = 0; i < LEN(y); i++) y[i] = r_prand64();
- r_bench(bench_quadratic16, 3000);
- r_bench(bench_quadratic32, 3000);
- r_bench(bench_quadratic64, 3000);
- r_bench(bench_quadratic128, 3000);
- r_bench(bench_quadratic256, 3000);
- r_bench(bench_quadratic512, 3000);
- r_bench(bench_quadratic1024, 3000);
- r_bench(bench_quadratic2048, 3000);
- r_bench(bench_quadratic4096, 3000);
+ // r_bench(bench_quadratic16, 3000);
+ // r_bench(bench_quadratic32, 3000);
+ // r_bench(bench_quadratic64, 3000);
+ // r_bench(bench_quadratic128, 3000);
+ // r_bench(bench_quadratic256, 3000);
+ // r_bench(bench_quadratic512, 3000);
+ // r_bench(bench_quadratic1024, 3000);
+ // r_bench(bench_quadratic2048, 3000);
+ // r_bench(bench_quadratic4096, 3000);
r_bench(bench_karatsuba16, 3000);
r_bench(bench_karatsuba32, 3000);
@@ -296,24 +288,3 @@ main(void) {
r_bench(bench_karatsuba2048, 3000);
r_bench(bench_karatsuba4096, 3000);
}
-
-/*
-benchmark: bench_quadratic16 8174101 iters 417 ns/op
-benchmark: bench_quadratic32 2135396 iters 1673 ns/op
-benchmark: bench_quadratic64 519625 iters 6898 ns/op
-benchmark: bench_quadratic128 136077 iters 26540 ns/op
-benchmark: bench_quadratic256 34462 iters 104515 ns/op
-benchmark: bench_quadratic512 8522 iters 412899 ns/op
-benchmark: bench_quadratic1024 2196 iters 1645027 ns/op
-benchmark: bench_quadratic2048 540 iters 6601480 ns/op
-benchmark: bench_quadratic4096 136 iters 26184686 ns/op
-benchmark: bench_karatsuba16 8448774 iters 426 ns/op
-benchmark: bench_karatsuba32 2520357 iters 1430 ns/op
-benchmark: bench_karatsuba64 769735 iters 4695 ns/op
-benchmark: bench_karatsuba128 242052 iters 14550 ns/op
-benchmark: bench_karatsuba256 79255 iters 45003 ns/op
-benchmark: bench_karatsuba512 26338 iters 136375 ns/op
-benchmark: bench_karatsuba1024 8701 iters 412653 ns/op
-benchmark: bench_karatsuba2048 2900 iters 1245014 ns/op
-benchmark: bench_karatsuba4096 955 iters 3748654 ns/op
-*/