commit 0714d05592f0527d3572fb31b1b345e89e1facfc
parent a58e698cfc79a9511cf576aee2288e07b6c54e7c
Author: Robert Russell <robert@rr3.xyz>
Date: Wed, 1 Jan 2025 19:31:37 -0800
Only karatsuba numbers of same width
This results in a performance boost, and simplifies the karatsuba
function and analysis.
This is still WIP.
Diffstat:
| M | bigmul.c | | | 133 | ++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------------- |
1 file changed, 87 insertions(+), 46 deletions(-)
diff --git a/bigmul.c b/bigmul.c
@@ -3,6 +3,8 @@
#include <stdio.h>
#include <unistd.h>
+#define KARATSUBA_THRESH 32 // Best power of 2 determined via benchmarking
+
struct nat {
usize cap;
usize len;
@@ -55,27 +57,32 @@ fmaa64(u64 *rh, u64 *rl, u64 w, u64 x, u64 y, u64 z) {
}
// Precondition: m >= n
-void
-add(u64 *r, u64 *x, usize m, u64 *y, usize n) {
- u64 c = 0;
+u64
+add(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 c) {
for (usize i = 0; i < n; i++)
add64(&c, &r[i], x[i], y[i], c);
for (usize i = n; i < m; i++)
add64(&c, &r[i], x[i], c, 0);
- r[m] = c;
+ return c;
+}
+
+u64
+addw(u64 *r, u64 *x, usize m, u64 y) {
+ for (usize i = 0; i < m; i++)
+ add64(&y, &r[i], x[i], y, 0);
+ return y;
}
// Precondition: m >= n
// TODO: sub is not commutative like add, so we need a "bus" operation
// ("sub" backwards) for when m < n.
-void
-sub(u64 *r, u64 *x, usize m, u64 *y, usize n) {
- u64 b = 0;
+u64
+sub(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 b) {
for (usize i = 0; i < n; i++)
sub64(&b, &r[i], x[i], y[i], b);
for (usize i = n; i < m; i++)
sub64(&b, &r[i], x[i], b, 0);
- r[m] = -b; // TODO: I don't think this makes sense for nats.
+ return b;
}
// Precondition: r does not intersect x nor y
@@ -93,8 +100,9 @@ mul_quadratic(u64 *r, u64 *x, usize m, u64 *y, usize n) {
// Precondition: r does not intersect x nor y
// TODO: Document precondition regarding size of scratch memory.
void
-karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) {
- /* We seek to multiply x and y, which have m and n "words" (digits of
+karatsuba(u64 *r, u64 *x, u64 *y, usize n, u64 *scratch0) {
+ /* TODO: Update
+ * We seek to multiply x and y, which have m and n "words" (digits of
* base b := 2^64), respectively. For this, we let k := ceil(max(m, n) / 2)
* and split x and y as
* x = xh * b^k + xl and y = yh * b^k + yl.
@@ -135,54 +143,85 @@ karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) {
* n >= 4 (instead of m >= 4), then |q| < n and |p| <= m. It therefore
* suffices to separate the case m < 4 || n < 4 for the recursion basis. */
- usize maxmn = MAX(m, n);
- if (m < 4 || n < 4 || maxmn < 32) { // 32 was determined via benchmarks.
- mul_quadratic(r, x, m, y, n);
+ if (n < KARATSUBA_THRESH) { // TODO: Calibrate, and ensure necessary bounds
+ mul_quadratic(r, x, n, y, n);
return;
}
- usize k = (maxmn + 1) / 2;
-
- // 1. Split x
- usize mh = m > k ? m - k : 0, ml = MIN(k, m);
- u64 *xh = x + ml, *xl = x;
+ // 1. Compute l, h, k, and their doubles
+ usize l = n / 2, ll = l * 2;
+ usize h = n - l, hh = h * 2;
+ usize k = h + 1, kk = k * 2;
- // 2. Split y
- usize nh = n > k ? n - k : 0, nl = MIN(k, n);
- u64 *yh = y + nl, *yl = y;
+ // 2. Split x and y
+ u64 *xh = x + l, *xl = x;
+ u64 *yh = y + l, *yl = y;
// 3. Assign blocks of memory for intermediate results.
// Note that we use the output buffer r as temporary storage for p and q.
// We also store u and s directly in r at the appropriate offsets, such
// that p and q overlap with u and s, but that's ok, because we're done
// with p and q by the time we calculate u and s.
- usize pw = MIN(ml + 1, m); u64 *p = r;
- usize qw = MIN(nl + 1, n); u64 *q = r + pw;
- usize tw = pw + qw; u64 *t = scratch;
- usize uw = ml + nl; u64 *u = r;
- usize sw = mh + nh; u64 *s = r + 2 * k;
+ // TODO: Justify all these lengths
+ u64 *p = r;
+ u64 *q = r + k;
+ u64 *t = scratch0;
+ u64 *u = r;
+ u64 *s = r + ll;
+ u64 *scratch1 = scratch0 + kk;
// 4. Arithmetic
- add(p, xl, ml, xh, mh); // p = xl + xh
- add(q, yl, nl, yh, nh); // q = yl + yh
- karatsuba(t, p, pw, q, qw, scratch + tw); // t = p * q
- karatsuba(u, xl, ml, yl, nl, scratch + tw); // u = xl * yl
- for (usize i = uw; i < 2 * k; i++) r[i] = 0; // r[uw..2*k] = 0
- karatsuba(s, xh, mh, yh, nh, scratch + tw); // s = xh * yh
- sub(t, t, tw, u, uw); // t -= u
- sub(t, t, tw, s, sw); // t -= s
- add(r + k, t, tw, r + k, k + sw); // r[k..] += t
- for (usize i = tw + k + sw; i < m + n; i++) r[i] = 0; // TODO
- // TODO: Prove that tw + sw <= m + n - k.
+ p[k] = add(p, xh, h, xl, l, 0); // p = xh + xl
+ q[k] = add(q, yh, h, yl, l, 0); // q = yh + yl
+ karatsuba(t, p, q, k, scratch1); // t = p * q
+ karatsuba(u, xl, yl, l, scratch1); // u = xl * yl
+ karatsuba(s, xh, yh, h, scratch1); // s = xh * yh
+ sub(t, t, kk, u, ll, 0); // t -= u (borrow out must be 0) TODO: explain
+ sub(t, t, kk, s, hh, 0); // t -= s (borrow out must be 0) TODO: explain
+ add(r + l, r + l, hh + l, t, kk, 0); // r[l..] += t (carry out must be 0) TODO: explain
}
void
mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) {
+ if (m < n) {
+ u64 *t0 = x; x = y; y = t0; // Swap x and y
+ usize t1 = m; m = n; n = t1; // Swap m and n
+ }
+
+ if (n < KARATSUBA_THRESH) { // TODO: Calibrate.
+ mul_quadratic(r, x, m, y, n);
+ return;
+ }
+
+ // TODO: Special-case m == n?
+
// TODO: Accept capacity of r as an argument, and use excess memory for
// scratch, if it's big enough, instead of allocating.
- u64 *scratch = r_eallocn((m + n) * 2, sizeof *scratch);
- karatsuba(r, x, m, y, n, scratch);
- free(scratch);
+ usize firstkk = 2 * (n - n / 2 + 1);
+ u64 *mem = r_eallocn(2 * n + 2 * firstkk, sizeof *mem);
+ u64 *prod = mem;
+ u64 *scratch = mem + 2 * n;
+
+ // TODO: The control flow here kinda sucks.
+ // TODO: There are unnecessary copies between prod and r. Try to do it in
+ // "one pass", without first initializing r to 0.
+ memset(r, 0, (m + n) * sizeof *r);
+ for (;;) {
+ for (; m >= n; r += n, x += n, m -= n) {
+ karatsuba(prod, x, y, n, scratch);
+ add(r, prod, 2 * n, r, n, 0);
+ }
+ if (m == 0) break;
+ if (m < KARATSUBA_THRESH) { // TODO: Calibrate.
+ mul_quadratic(prod, x, m, y, n);
+ add(r, prod, m + n, r, n, 0);
+ break;
+ }
+ u64 *t0 = x; x = y; y = t0;
+ usize t1 = m; m = n; n = t1;
+ }
+
+ free(mem);
}
u64 x[4096];
@@ -225,12 +264,14 @@ void bench_karatsuba4096(u64 n) { bench_karatsuba(4096, n); }
int
main(void) {
- // u64 x[] = { 0x1234123412341234, 0x5678567856785678, 0x89ab89ab89ab89ab, 0xcdefcdefcdefcdef };
- // u64 y[] = { 0x4321432143214321, 0x8765876587658765, 0xba98ba98ba98ba98, 0xfedcfedcfedcfedc };
- // u64 r0[LEN(x) + LEN(y)]; mul_quadratic(r0, x, LEN(x), y, LEN(y));
- // u64 r1[LEN(x) + LEN(y)]; mul_karatsuba(r1, x, LEN(x), y, LEN(y));
- // printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r0[7], r0[6], r0[5], r0[4], r0[3], r0[2], r0[1], r0[0]);
- // printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r1[7], r1[6], r1[5], r1[4], r1[3], r1[2], r1[1], r1[0]);
+/*
+ u64 x[] = { 0x1234123412341234, 0x5678567856785678, 0x89ab89ab89ab89ab, 0xcdefcdefcdefcdef };
+ u64 y[] = { 0x4321432143214321, 0x8765876587658765, 0xba98ba98ba98ba98, 0xfedcfedcfedcfedc };
+ u64 r0[LEN(x) + LEN(y)]; mul_quadratic(r0, x, LEN(x), y, LEN(y));
+ u64 r1[LEN(x) + LEN(y)]; mul_karatsuba(r1, x, LEN(x), y, LEN(y));
+ printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r0[7], r0[6], r0[5], r0[4], r0[3], r0[2], r0[1], r0[0]);
+ printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r1[7], r1[6], r1[5], r1[4], r1[3], r1[2], r1[1], r1[0]);
+*/
for (usize i = 0; i < LEN(x); i++) x[i] = r_prand64();
for (usize i = 0; i < LEN(y); i++) y[i] = r_prand64();