commit 1e05d117b80ca748c86bc6c2db9df359acbd3e26
parent 0d4646cfb7375fe8fc16d000f1340a6a773eede0
Author: Robert Russell <robert@rr3.xyz>
Date: Wed, 1 Jan 2025 01:14:27 -0800
Document Karatsuba
Diffstat:
| M | bigmul.c | | | 75 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------- |
1 file changed, 58 insertions(+), 17 deletions(-)
diff --git a/bigmul.c b/bigmul.c
@@ -93,35 +93,76 @@ mul_quadratic(u64 *r, u64 *x, usize m, u64 *y, usize n) {
// TODO: Document precondition regarding size of scratch memory.
void
karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) {
- // If we allow m < 4 and n < 4, then the recursion is not well-founded.
- // TODO: Determine best threshold for quadratic.
- if (m < 4 && n < 4) {
+ /* We seek to multiply x and y, which have m and n "words" (digits of
+ * base b := 2^64), respectively. For this, we let k := ceil(max(m, n) / 2)
+ * and split x and y as
+ * x = xh * b^k + xl and y = yh * b^k + yl.
+ * Then
+ * x * y = s * b^(2k) + t * b^k + u,
+ * where
+ * s := xh * yh, t := xh * yl + xl * yh, and u := xl * yl.
+ * Thus, we could multiply x and y by recursively evaluating the four
+ * products in the definition of s, t, and u (the products involving b^i
+ * are just bit shifts). However, this would result in quadratic
+ * performance, the same asymptotic performance as the naive "doubly-nested
+ * for-loop" algorithm. Instead, we exploit the following identity, whose
+ * significance for multiplication algorithms was first noticed by Anatoly
+ * Karatsuba in 1960:
+ * t = p * q - u - s
+ * where
+ * p := xl + xh and q := yl + yh.
+ * Computing t in its latter form saves one multiplication at the expense
+ * of a few additions/subtractions, which reduces the asymptotic run-time
+ * to O(max(m, n)^(lg 3)).
+ *
+ * Let |z| denote the number of words in a number z. For well-founded
+ * recusion, we need the pair (m, n) to lexicographically decrease in each
+ * of the three recursive calls. That is, for each recursive call, we need
+ * m' < m && n' <= n or m' <= m && n' < n.
+ * When we compute xl * yl and xh * yh, this is true as long as m >= 2 (or
+ * n >= 2), for then |xl|,|xh| < |x| = m (resp., |yl|,|yh| < |y| = n). On
+ * the other hand, if m >= n, then
+ * |p| = |xl + xh| (definition of p)
+ * <= max(|xl|, |xh|) + 1 (addition adds at most 1 word)
+ * = |xl| + 1 (|xl| >= |xh|)
+ * = min(k, m) + 1 (definition of xl)
+ * = min(ceil(max(m, n) / 2), m) + 1 (definition of k)
+ * = min(ceil(m / 2), m) + 1, (m >= n)
+ * and as long as m >= 4, this quantity is strictly less than m. Since we
+ * always have |q| = |yl + yh| <= |y| = n, this ensures the termination of
+ * the evaluation of p * q. Similarly, if n >= m (instead of m >= n) and
+ * n >= 4, then |q| < n and |p| <= m. It therefore suffices to separate the
+ * case m < 4 || n < 4 for the recursion basis. */
+
+ // TODO: Determine best threshold for switching to the quadratic method.
+ if (m < 4 || n < 4) {
mul_quadratic(r, x, m, y, n);
return;
}
- usize k = (MAX(m, n) + 1) / 2; // k = ceil(max(m, n) / 2)
+ usize k = (MAX(m, n) + 1) / 2;
// 1. Split x
usize mh = m > k ? m - k : 0, ml = MIN(k, m);
- u64 *xh = x + ml, *xl = x; // x = xh * b^k + xl
+ u64 *xh = x + ml, *xl = x;
// 2. Split y
usize nh = n > k ? n - k : 0, nl = MIN(k, n);
- u64 *yh = y + nl, *yl = y; // y = yh * b^k + yl
+ u64 *yh = y + nl, *yl = y;
- // 3. Assign/allocate memory
- // Note that we use the output buffer r as temporary storage for
- // intermediate results.
+ // 3. Assign blocks of memory for intermediate results.
+ // Note that we use the output buffer r as temporary storage for p and q.
+ // We also store u and s directly in r at the appropriate offsets, such
+ // that p and q overlap with u and s, but that's ok, because we're done
+ // with p and q by the time we calculate u and s.
usize rw = m + n;
- usize pw = ml + 1; u64 *p = r; // Note: ml = MAX(ml, mh)
- usize qw = nl + 1; u64 *q = r + pw; // Note: nl = MAX(nl, nh)
- usize tw = pw + qw; u64 *t = scratch;
- usize uw = ml + nl; u64 *u = r;
- usize sw = mh + mh; u64 *s = r + 2 * k;
+ usize pw = MIN(ml + 1, m); u64 *p = r;
+ usize qw = MIN(nl + 1, n); u64 *q = r + pw;
+ usize tw = pw + qw; u64 *t = scratch;
+ usize uw = ml + nl; u64 *u = r;
+ usize sw = mh + mh; u64 *s = r + 2 * k;
// 4. Arithmetic
- // TODO: Explain algorithm.
add(p, xl, ml, xh, mh); // p = xl + xh
add(q, yl, nl, yh, nh); // q = yl + yh
karatsuba(t, p, pw, q, qw, scratch + tw); // t = p * q
@@ -135,8 +176,8 @@ karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) {
void
mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) {
// TODO: Accept capacity of r as an argument, and use excess memory for
- // scratch if it's big enough instead of allocating.
- u64 *scratch = r_eallocn((m + n + 2) * 2, sizeof *scratch);
+ // scratch, if it's big enough, instead of allocating.
+ u64 *scratch = r_eallocn((m + n) * 2, sizeof *scratch);
karatsuba(r, x, m, y, n, scratch);
free(scratch);
}