bigmul

big multiplication in C
git clone git://git.rr3.xyz/bigmul
Log | Files | Refs | README | LICENSE

commit 7b0840047035241b7d2596c216111f84c47fac1e
parent 0a2da1e9f35a1bda6ac0514223247a721018e1d5
Author: Robert Russell <robert@rr3.xyz>
Date:   Thu,  2 Jan 2025 18:08:43 -0800

Clean up

Diffstat:
Mbigmul.c | 189+++++++++++++++++++++++++++++++++----------------------------------------------
1 file changed, 78 insertions(+), 111 deletions(-)

diff --git a/bigmul.c b/bigmul.c @@ -6,7 +6,10 @@ // TODO: Add (slow) fallback for __builtin_{add,sub}cl when not using clang or // GCC 14. Maybe that should go in rcx. -#define KARATSUBA_THRESH 32 // Best power of 2 determined via benchmarking +// This power of 2 results in the lowest run-time (on the hardware on which the +// benchmarks were run). This must be at least 4 (see the comments in +// karatsuba). +#define KARATSUBA_THRESH 32 struct nat { usize cap; @@ -105,30 +108,25 @@ fma_quadratic(u64 *r, u64 *x, usize m, u64 *y, usize n) { } } -// Precondition: r does not intersect x nor y -void -mul_quadratic(u64 *r, u64 *x, usize m, u64 *y, usize n) { - memset(r, 0, (m + n) * sizeof *r); - fma_quadratic(r, x, m, y, n); -} - -// Precondition: r does not intersect x nor y -// TODO: Document precondition regarding size of scratch memory. +// Precondition: capacity(r) >= 2 * n, capacity(x) = capacity(y) = n +// Precondition: r is disjoint with x and y +// Precondition: capacity(tt) >= 2 * kk, where kk := 2 * (n - n / 2 + 1) void karatsuba(u64 *r, u64 *x, u64 *y, usize n, u64 *tt) { - /* TODO: Update - * We seek to multiply x and y, which have m and n "words" (digits of - * base b := 2^64), respectively. For this, we let k := ceil(max(m, n) / 2) - * and split x and y as - * x = xh * b^k + xl and y = yh * b^k + yl. + /* We seek to multiply x and y, which each have n "words" (digits of base + * b := 2^64), obtaining the full 2*n word product. For this, we let + * l := floor(n / 2) and h := ceil(n / 2) + * and partition x and y into low parts xl and yl with l words and high + * parts xh and yh with h words, such that + * x = xh * b^l + xl and y = yh * b^l + yl. * Then - * x * y = s * b^(2k) + t * b^k + u, + * x * y = s * b^(2*l) + t * b^l + u, * where * s := xh * yh, t := xh * yl + xl * yh, and u := xl * yl. 
* Thus, we could multiply x and y by recursively evaluating the four * products in the definition of s, t, and u (the products involving b^i - * are just bit shifts). However, this would result in quadratic - * performance, the same asymptotic performance as the naive "doubly-nested + * are just bit shifts). However, this would result in a time complexity + * of O(n^2), the same asymptotic performance as the naive "doubly-nested * for-loop" algorithm. Instead, we exploit the following identity, whose * significance for multiplication algorithms was first noticed by Anatoly * Karatsuba in 1960: @@ -136,68 +134,82 @@ karatsuba(u64 *r, u64 *x, u64 *y, usize n, u64 *tt) { * where * p := xl + xh and q := yl + yh. * Computing t in its latter form saves one multiplication at the expense - * of a few additions/subtractions, which reduces the asymptotic run-time - * to O(max(m, n)^(lg 3)). + * of a few additions/subtractions (which have O(n) time complexity), + * thereby reducing the time complexity to O(n^(lg 3)). * * Let |z| denote the number of words in a number z. For well-founded - * recusion, we need the pair (m, n) to lexicographically decrease in each - * of the three recursive calls. That is, for each recursive call, we need - * m' < m && n' <= n or m' <= m && n' < n. - * When we compute xl * yl and xh * yh, this is true as long as m >= 2 (or - * n >= 2), for then |xl|,|xh| < |x| = m (resp., |yl|,|yh| < |y| = n). On - * the other hand, if m >= n, then - * |p| = |xl + xh| (definition of p) - * <= max(|xl|, |xh|) + 1 (addition adds at most 1 word) - * = |xl| + 1 (|xl| >= |xh|) - * = min(k, m) + 1 (definition of xl) - * = min(ceil(max(m, n) / 2), m) + 1 (definition of k) - * = min(ceil(m / 2), m) + 1, (m >= n) - * and as long as m >= 4, this quantity is strictly less than m. Since we - * always have |q| = |yl + yh| <= |y| = n, this ensures the termination of - * the evaluation of p * q. 
Similarly, if n >= m (instead of m >= n) and - * n >= 4 (instead of m >= 4), then |q| < n and |p| <= m. It therefore - * suffices to separate the case m < 4 || n < 4 for the recursion basis. */ - - if (n < KARATSUBA_THRESH) { // TODO: Calibrate, and ensure necessary bounds + * recursion, we need n to strictly decrease in each of the three recursive + * calls. When we compute u = xl * yl and s = xh * yh, this is true as long + * as n >= 2, for then |xl|,|yl| <= l < n and |xh|,|yh| <= h < n. For the + * computation of p * q in t, on the other hand, we have + * |p|,|q| = |xl + xh|,|yl + yh| (definition of p and q) + * <= max(l, h) + 1 (addition adds at most 1 word) + * = h + 1 (h >= l) + * = ceil(n / 2) + 1 (definition of h) + * and as long as n >= 4, this quantity is strictly less than n. It + * therefore suffices to separate the case n < 4 for the recursion + * basis. */ + + // 1. Basis + if (n < KARATSUBA_THRESH) { fma_quadratic(r, x, n, y, n); return; } - // 1. Compute l, h, k, and their doubles - usize l = n / 2, ll = l * 2; - usize h = n - l, hh = h * 2; - usize k = h + 1, kk = k * 2; - - // 2. Split x and y + // 2. Compute l, h, and k, and their doubles ll, hh, and kk + // The significance of these quantities is as follows: + // - l is the max width of xl and yl + // - h is the max width of xh and yh + // - k is the max width of p and q + // - ll is the max width of u + // - hh is the max width of s + // - kk is the max width of t + usize l = n / 2, ll = 2 * l; + usize h = n - l, hh = 2 * h; + usize k = h + 1, kk = 2 * k; + + // 3. Split x and y u64 *xh = x + l, *xl = x; u64 *yh = y + l, *yl = y; - // 3. Assign blocks of memory for intermediate results. + // 4. Assign blocks of memory for intermediate results // Note that we use the output buffer r as temporary storage for p and q. 
- // We also store u and s directly in r at the appropriate offsets, such - // that p and q overlap with u and s, but that's ok, because we're done - // with p and q by the time we calculate u and s. - // TODO: Justify all these lengths + // We also store u and s directly in r at the appropriate offsets (free bit + // shifts!), such that p and q overlap with u and s, but that's ok, because + // we're done with p and q by the time we calculate u and s. u64 *p = r; u64 *q = r + k; - u64 *t0 = tt; + u64 *t0 = tt; // Storage for t in this invocation + u64 *t1 = tt + kk; // Storage for t in future (recursive) invocations u64 *u = r; u64 *s = r + ll; - u64 *t1 = tt + kk; - // 4. Arithmetic + // 5. Arithmetic p[k] = add(p, xh, h, xl, l); // p = xh + xl q[k] = add(q, yh, h, yl, l); // q = yh + yl - karatsuba(t0, p, q, k, t1); // t0 = p * q + karatsuba(t0, p, q, k, t1); // t = p * q karatsuba(u, xl, yl, l, t1); // u = xl * yl karatsuba(s, xh, yh, h, t1); // s = xh * yh - sub(t0, t0, kk, u, ll); // t0 -= u (borrow out must be 0) TODO: explain - sub(t0, t0, kk, s, hh); // t0 -= s (borrow out must be 0) TODO: explain - add(r + l, r + l, hh + l, t0, kk); // r[l..] += t0 (carry out must be 0) TODO: explain + sub(t0, t0, kk, u, ll); // t -= u [1] + sub(t0, t0, kk, s, hh); // t -= s [1] + add(r + l, r + l, hh + l, t0, kk); // r[l..] += t [2] [3] + // [1]: The borrow outs are guaranteed to be 0, because t0 - u - s must + // be non-negative. + // [2]: The carry out is guaranteed to be 0, because the full product + // x * y must fit in 2 * n words. + // [3]: The add precondition hh + l >= kk is satisfied here as long as + // n >= 4, and it is, because n < 4 is the recursion basis. 
} +// Precondition: capacity(r) >= m + n, capacity(x) = m, capacity(y) = n +// Precondition: r is disjoint with x and y +// Precondition: m >= n void -fma_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) { +fma_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) { + // TODO: Accept capacity of r as an argument, and use excess memory for + // scratch, if it's big enough, instead of allocating. + usize firstkk = 2 * (n - n / 2 + 1); + u64 *scratch = r_eallocn(2 * n + 2 * firstkk, sizeof *scratch); u64 *prod = scratch; u64 *tt = scratch + 2 * n; @@ -207,19 +219,22 @@ fma_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) { add(r, prod, 2 * n, r, n); } - if (m == 0) break; // TODO: Remove this if we special case m == n. + if (m == 0) break; - if (m < KARATSUBA_THRESH) { // TODO: Calibrate. - fma_quadratic(prod, x, m, y, n); - add(r, prod, m + n, r, n); + if (m < KARATSUBA_THRESH) { + fma_quadratic(r, x, m, y, n); break; } u64 *t0 = x; x = y; y = t0; // Swap x and y usize t1 = m; m = n; n = t1; // Swap m and n } + + free(scratch); } +// Precondition: capacity(r) >= m + n, capacity(x) = m, capacity(y) = n +// Precondition: r is disjoint with x and y void mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) { if (m < n) { @@ -227,22 +242,8 @@ mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) { usize t1 = m; m = n; n = t1; // Swap m and n } - if (n < KARATSUBA_THRESH) { // TODO: Calibrate. - mul_quadratic(r, x, m, y, n); - return; - } - - // TODO: Special-case m == n? - - // TODO: Accept capacity of r as an argument, and use excess memory for - // scratch, if it's big enough, instead of allocating. - usize firstkk = 2 * (n - n / 2 + 1); - u64 *scratch = r_eallocn(2 * n + 2 * firstkk, sizeof *scratch); - memset(r, 0, (m + n) * sizeof *r); - fma_karatsuba(r, x, m, y, n, scratch); - - free(scratch); + (n < KARATSUBA_THRESH ? 
fma_quadratic : fma_karatsuba)(r, x, m, y, n); } u64 x[4096]; @@ -250,23 +251,6 @@ u64 y[4096]; u64 r[8192]; NOINLINE void -bench_quadratic(u64 l, u64 n) { - r_bench_start(); - for (u64 i = 0; i < n; i++) mul_quadratic(r, x, l, y, l); - r_bench_stop(); -} - -void bench_quadratic16(u64 n) { bench_quadratic(16, n); } -void bench_quadratic32(u64 n) { bench_quadratic(32, n); } -void bench_quadratic64(u64 n) { bench_quadratic(64, n); } -void bench_quadratic128(u64 n) { bench_quadratic(128, n); } -void bench_quadratic256(u64 n) { bench_quadratic(256, n); } -void bench_quadratic512(u64 n) { bench_quadratic(512, n); } -void bench_quadratic1024(u64 n) { bench_quadratic(1024, n); } -void bench_quadratic2048(u64 n) { bench_quadratic(2048, n); } -void bench_quadratic4096(u64 n) { bench_quadratic(4096, n); } - -NOINLINE void bench_karatsuba(u64 l, u64 n) { r_bench_start(); for (u64 i = 0; i < n; i++) mul_karatsuba(r, x, l, y, l); @@ -285,26 +269,9 @@ void bench_karatsuba4096(u64 n) { bench_karatsuba(4096, n); } int main(void) { - u64 x[] = { 0x1234123412341234, 0x5678567856785678, 0x89ab89ab89ab89ab, 0xcdefcdefcdefcdef }; - u64 y[] = { 0x4321432143214321, 0x8765876587658765, 0xba98ba98ba98ba98, 0xfedcfedcfedcfedc }; - u64 r0[LEN(x) + LEN(y)]; mul_quadratic(r0, x, LEN(x), y, LEN(y)); - u64 r1[LEN(x) + LEN(y)]; mul_karatsuba(r1, x, LEN(x), y, LEN(y)); - printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r0[7], r0[6], r0[5], r0[4], r0[3], r0[2], r0[1], r0[0]); - printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r1[7], r1[6], r1[5], r1[4], r1[3], r1[2], r1[1], r1[0]); - for (usize i = 0; i < LEN(x); i++) x[i] = r_prand64(); for (usize i = 0; i < LEN(y); i++) y[i] = r_prand64(); - r_bench(bench_quadratic16, 1000); - r_bench(bench_quadratic32, 1000); - r_bench(bench_quadratic64, 1000); - r_bench(bench_quadratic128, 1000); - r_bench(bench_quadratic256, 1000); - r_bench(bench_quadratic512, 1000); - r_bench(bench_quadratic1024, 1000); - 
r_bench(bench_quadratic2048, 1000); - r_bench(bench_quadratic4096, 1000); - r_bench(bench_karatsuba16, 1000); r_bench(bench_karatsuba32, 1000); r_bench(bench_karatsuba64, 1000);