commit 7b0840047035241b7d2596c216111f84c47fac1e
parent 0a2da1e9f35a1bda6ac0514223247a721018e1d5
Author: Robert Russell <robert@rr3.xyz>
Date: Thu, 2 Jan 2025 18:08:43 -0800
Clean up
Diffstat:
| M | bigmul.c | | | 189 | +++++++++++++++++++++++++++++++++---------------------------------------------- |
1 file changed, 78 insertions(+), 111 deletions(-)
diff --git a/bigmul.c b/bigmul.c
@@ -6,7 +6,10 @@
// TODO: Add (slow) fallback for __builtin_{add,sub}cl when not using clang or
// GCC 14. Maybe that should go in rcx.
-#define KARATSUBA_THRESH 32 // Best power of 2 determined via benchmarking
+// This power of 2 results in the lowest run-time (on the hardware on which the
+// benchmarks were run). This must be at least 4 (see the comments in
+// karatsuba).
+#define KARATSUBA_THRESH 32
struct nat {
usize cap;
@@ -105,30 +108,25 @@ fma_quadratic(u64 *r, u64 *x, usize m, u64 *y, usize n) {
}
}
-// Precondition: r does not intersect x nor y
-void
-mul_quadratic(u64 *r, u64 *x, usize m, u64 *y, usize n) {
- memset(r, 0, (m + n) * sizeof *r);
- fma_quadratic(r, x, m, y, n);
-}
-
-// Precondition: r does not intersect x nor y
-// TODO: Document precondition regarding size of scratch memory.
+// Precondition: capacity(r) >= 2 * n, capacity(x) = capactiy(y) = n
+// Precondition: r is disjoint with x and y
+// Precondition: capacity(tt) >= 2 * kk, where kk := 2 * (n - n / 2 + 1)
void
karatsuba(u64 *r, u64 *x, u64 *y, usize n, u64 *tt) {
- /* TODO: Update
- * We seek to multiply x and y, which have m and n "words" (digits of
- * base b := 2^64), respectively. For this, we let k := ceil(max(m, n) / 2)
- * and split x and y as
- * x = xh * b^k + xl and y = yh * b^k + yl.
+ /* We seek to multiply x and y, which each have n "words" (digits of base
+ * b := 2^64), obtaining the full 2*n word product. For this, we let
+ * l := floor(n / 2) and h := ceil(n / 2)
+ * and partition x and y into low parts xl and yl with l words and high
+ * parts xh and yh with h words, such that
+ * x = xh * b^l + xl and y = yh * b^l + yl.
* Then
- * x * y = s * b^(2k) + t * b^k + u,
+ * x * y = s * b^(2*l) + t * b^l + u,
* where
* s := xh * yh, t := xh * yl + xl * yh, and u := xl * yl.
* Thus, we could multiply x and y by recursively evaluating the four
* products in the definition of s, t, and u (the products involving b^i
- * are just bit shifts). However, this would result in quadratic
- * performance, the same asymptotic performance as the naive "doubly-nested
+ * are just bit shifts). However, this would result in a time complexity
+ * of O(n^2), the same asymptotic performance as the naive "doubly-nested
* for-loop" algorithm. Instead, we exploit the following identity, whose
* significance for multiplication algorithms was first noticed by Anatoly
* Karatsuba in 1960:
@@ -136,68 +134,82 @@ karatsuba(u64 *r, u64 *x, u64 *y, usize n, u64 *tt) {
* where
* p := xl + xh and q := yl + yh.
* Computing t in its latter form saves one multiplication at the expense
- * of a few additions/subtractions, which reduces the asymptotic run-time
- * to O(max(m, n)^(lg 3)).
+ * of a few additions/subtractions (which have O(n) time complexity),
+ * thereby reducing the time complexity to O(n^(lg 3)).
*
* Let |z| denote the number of words in a number z. For well-founded
- * recusion, we need the pair (m, n) to lexicographically decrease in each
- * of the three recursive calls. That is, for each recursive call, we need
- * m' < m && n' <= n or m' <= m && n' < n.
- * When we compute xl * yl and xh * yh, this is true as long as m >= 2 (or
- * n >= 2), for then |xl|,|xh| < |x| = m (resp., |yl|,|yh| < |y| = n). On
- * the other hand, if m >= n, then
- * |p| = |xl + xh| (definition of p)
- * <= max(|xl|, |xh|) + 1 (addition adds at most 1 word)
- * = |xl| + 1 (|xl| >= |xh|)
- * = min(k, m) + 1 (definition of xl)
- * = min(ceil(max(m, n) / 2), m) + 1 (definition of k)
- * = min(ceil(m / 2), m) + 1, (m >= n)
- * and as long as m >= 4, this quantity is strictly less than m. Since we
- * always have |q| = |yl + yh| <= |y| = n, this ensures the termination of
- * the evaluation of p * q. Similarly, if n >= m (instead of m >= n) and
- * n >= 4 (instead of m >= 4), then |q| < n and |p| <= m. It therefore
- * suffices to separate the case m < 4 || n < 4 for the recursion basis. */
-
- if (n < KARATSUBA_THRESH) { // TODO: Calibrate, and ensure necessary bounds
+ * recusion, we need n to strictly decrease in each of the three recursive
+ * calls. When we compute u = xl * yl and s = xh * yh, this is true as long
+ * as n >= 2, for then |xl|,|yl| <= l < n and |yh|,|yh| <= h < n. For the
+ * computation of p * q in t, on the other hand, we have
+ * |p|,|q| = |xl + xh|,|yl + yh| (definition of p and q)
+ * <= max(l, h) + 1 (addition adds at most 1 word)
+ * = h + 1 (h >= l)
+ * = ceil(n / 2) + 1 (definition of h)
+ * and as long as n >= 4, this quantity is strictly less than n. It
+ * therefore suffices to separate the case n < 4 for the recursion
+ * basis. */
+
+ // 1. Basis
+ if (n < KARATSUBA_THRESH) {
fma_quadratic(r, x, n, y, n);
return;
}
- // 1. Compute l, h, k, and their doubles
- usize l = n / 2, ll = l * 2;
- usize h = n - l, hh = h * 2;
- usize k = h + 1, kk = k * 2;
-
- // 2. Split x and y
+ // 2. Compute l, h, and k, and their doubles ll, hh, and kk
+ // The significance of these quantities is as follows:
+ // - l is the max width of xl and yl
+ // - h is the max width of xh and yh
+ // - k is the max width of p and q
+ // - ll is the max width of u
+ // - hh is the max width of s
+ // - kk is the max width of t
+ usize l = n / 2, ll = 2 * l;
+ usize h = n - l, hh = 2 * h;
+ usize k = h + 1, kk = 2 * k;
+
+ // 3. Split x and y
u64 *xh = x + l, *xl = x;
u64 *yh = y + l, *yl = y;
- // 3. Assign blocks of memory for intermediate results.
+ // 4. Assign blocks of memory for intermediate results
// Note that we use the output buffer r as temporary storage for p and q.
- // We also store u and s directly in r at the appropriate offsets, such
- // that p and q overlap with u and s, but that's ok, because we're done
- // with p and q by the time we calculate u and s.
- // TODO: Justify all these lengths
+ // We also store u and s directly in r at the appropriate offsets (free bit
+ // shifts!), such that p and q overlap with u and s, but that's ok, because
+ // we're done with p and q by the time we calculate u and s.
u64 *p = r;
u64 *q = r + k;
- u64 *t0 = tt;
+ u64 *t0 = tt; // Storage for t in this invocation
+ u64 *t1 = tt + kk; // Storage for t in future (recursive) invocations
u64 *u = r;
u64 *s = r + ll;
- u64 *t1 = tt + kk;
- // 4. Arithmetic
+ // 5. Arithmetic
p[k] = add(p, xh, h, xl, l); // p = xh + xl
q[k] = add(q, yh, h, yl, l); // q = yh + yl
- karatsuba(t0, p, q, k, t1); // t0 = p * q
+ karatsuba(t0, p, q, k, t1); // t = p * q
karatsuba(u, xl, yl, l, t1); // u = xl * yl
karatsuba(s, xh, yh, h, t1); // s = xh * yh
- sub(t0, t0, kk, u, ll); // t0 -= u (borrow out must be 0) TODO: explain
- sub(t0, t0, kk, s, hh); // t0 -= s (borrow out must be 0) TODO: explain
- add(r + l, r + l, hh + l, t0, kk); // r[l..] += t0 (carry out must be 0) TODO: explain
+ sub(t0, t0, kk, u, ll); // t -= u [1]
+ sub(t0, t0, kk, s, hh); // t -= s [1]
+ add(r + l, r + l, hh + l, t0, kk); // r[l..] += t [2] [3]
+ // [1]: The borrow outs are guaranteed to be 0, because t0 - u - s must
+ // be positive.
+ // [2]: The carry out is guaranteed to be 0, because the full product
+ // x * y must fit in 2 * n words.
+ // [3]: The add precondition hh + l >= kk is satisfied here as long as
+ // n >= 4, and it is, because n < 4 is the recursion basis.
}
+// Precondition: capacity(r) >= m + n, capacity(x) = m, capactiy(y) = n
+// Precondition: r is disjoint with x and y
+// Precondition: m >= n
void
-fma_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) {
+fma_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) {
+ // TODO: Accept capacity of r as an argument, and use excess memory for
+ // scratch, if it's big enough, instead of allocating.
+ usize firstkk = 2 * (n - n / 2 + 1);
+ u64 *scratch = r_eallocn(2 * n + 2 * firstkk, sizeof *scratch);
u64 *prod = scratch;
u64 *tt = scratch + 2 * n;
@@ -207,19 +219,22 @@ fma_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) {
add(r, prod, 2 * n, r, n);
}
- if (m == 0) break; // TODO: Remove this if we special case m == n.
+ if (m == 0) break;
- if (m < KARATSUBA_THRESH) { // TODO: Calibrate.
- fma_quadratic(prod, x, m, y, n);
- add(r, prod, m + n, r, n);
+ if (m < KARATSUBA_THRESH) {
+ fma_quadratic(r, x, m, y, n);
break;
}
u64 *t0 = x; x = y; y = t0; // Swap x and y
usize t1 = m; m = n; n = t1; // Swap m and n
}
+
+ free(scratch);
}
+// Precondition: capacity(r) >= m + n, capacity(x) = m, capactiy(y) = n
+// Precondition: r is disjoint with x and y
void
mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) {
if (m < n) {
@@ -227,22 +242,8 @@ mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) {
usize t1 = m; m = n; n = t1; // Swap m and n
}
- if (n < KARATSUBA_THRESH) { // TODO: Calibrate.
- mul_quadratic(r, x, m, y, n);
- return;
- }
-
- // TODO: Special-case m == n?
-
- // TODO: Accept capacity of r as an argument, and use excess memory for
- // scratch, if it's big enough, instead of allocating.
- usize firstkk = 2 * (n - n / 2 + 1);
- u64 *scratch = r_eallocn(2 * n + 2 * firstkk, sizeof *scratch);
-
memset(r, 0, (m + n) * sizeof *r);
- fma_karatsuba(r, x, m, y, n, scratch);
-
- free(scratch);
+ (n < KARATSUBA_THRESH ? fma_quadratic : fma_karatsuba)(r, x, m, y, n);
}
u64 x[4096];
@@ -250,23 +251,6 @@ u64 y[4096];
u64 r[8192];
NOINLINE void
-bench_quadratic(u64 l, u64 n) {
- r_bench_start();
- for (u64 i = 0; i < n; i++) mul_quadratic(r, x, l, y, l);
- r_bench_stop();
-}
-
-void bench_quadratic16(u64 n) { bench_quadratic(16, n); }
-void bench_quadratic32(u64 n) { bench_quadratic(32, n); }
-void bench_quadratic64(u64 n) { bench_quadratic(64, n); }
-void bench_quadratic128(u64 n) { bench_quadratic(128, n); }
-void bench_quadratic256(u64 n) { bench_quadratic(256, n); }
-void bench_quadratic512(u64 n) { bench_quadratic(512, n); }
-void bench_quadratic1024(u64 n) { bench_quadratic(1024, n); }
-void bench_quadratic2048(u64 n) { bench_quadratic(2048, n); }
-void bench_quadratic4096(u64 n) { bench_quadratic(4096, n); }
-
-NOINLINE void
bench_karatsuba(u64 l, u64 n) {
r_bench_start();
for (u64 i = 0; i < n; i++) mul_karatsuba(r, x, l, y, l);
@@ -285,26 +269,9 @@ void bench_karatsuba4096(u64 n) { bench_karatsuba(4096, n); }
int
main(void) {
- u64 x[] = { 0x1234123412341234, 0x5678567856785678, 0x89ab89ab89ab89ab, 0xcdefcdefcdefcdef };
- u64 y[] = { 0x4321432143214321, 0x8765876587658765, 0xba98ba98ba98ba98, 0xfedcfedcfedcfedc };
- u64 r0[LEN(x) + LEN(y)]; mul_quadratic(r0, x, LEN(x), y, LEN(y));
- u64 r1[LEN(x) + LEN(y)]; mul_karatsuba(r1, x, LEN(x), y, LEN(y));
- printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r0[7], r0[6], r0[5], r0[4], r0[3], r0[2], r0[1], r0[0]);
- printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r1[7], r1[6], r1[5], r1[4], r1[3], r1[2], r1[1], r1[0]);
-
for (usize i = 0; i < LEN(x); i++) x[i] = r_prand64();
for (usize i = 0; i < LEN(y); i++) y[i] = r_prand64();
- r_bench(bench_quadratic16, 1000);
- r_bench(bench_quadratic32, 1000);
- r_bench(bench_quadratic64, 1000);
- r_bench(bench_quadratic128, 1000);
- r_bench(bench_quadratic256, 1000);
- r_bench(bench_quadratic512, 1000);
- r_bench(bench_quadratic1024, 1000);
- r_bench(bench_quadratic2048, 1000);
- r_bench(bench_quadratic4096, 1000);
-
r_bench(bench_karatsuba16, 1000);
r_bench(bench_karatsuba32, 1000);
r_bench(bench_karatsuba64, 1000);