bigmul

big multiplication in C
git clone git://git.rr3.xyz/bigmul
Log | Files | Refs | README | LICENSE

commit 7c154e5299815ed6d84c2fa17e52826228cfdd37
parent 1e05d117b80ca748c86bc6c2db9df359acbd3e26
Author: Robert Russell <robert@rr3.xyz>
Date:   Wed,  1 Jan 2025 02:11:30 -0800

Add benchmarks and use quadratic method in karatsuba for small m,n

Diffstat:
Mbigmul.c | 86++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----------
1 file changed, 74 insertions(+), 12 deletions(-)

diff --git a/bigmul.c b/bigmul.c @@ -1,4 +1,5 @@ #include <rcx/all.h> +#include <rcx/bench.h> #include <stdio.h> #include <unistd.h> @@ -74,7 +75,7 @@ sub(u64 *r, u64 *x, usize m, u64 *y, usize n) { sub64(&b, &r[i], x[i], y[i], b); for (usize i = n; i < m; i++) sub64(&b, &r[i], x[i], b, 0); - r[m] = -b; + r[m] = -b; // TODO: I don't think this makes sense for nats. } // Precondition: r does not intersect x nor y @@ -134,13 +135,13 @@ karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) { * n >= 4, then |q| < n and |p| <= m. It therefore suffices to separate the * case m < 4 || n < 4 for the recursion basis. */ - // TODO: Determine best threshold for switching to the quadratic method. - if (m < 4 || n < 4) { + usize maxmn = MAX(m, n); + if (m < 4 || n < 4 || maxmn < 256) { // 256 was determined via benchmarking. mul_quadratic(r, x, m, y, n); return; } - usize k = (MAX(m, n) + 1) / 2; + usize k = (maxmn + 1) / 2; // 1. Split x usize mh = m > k ? m - k : 0, ml = MIN(k, m); @@ -155,7 +156,6 @@ karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) { // We also store u and s directly in r at the appropriate offsets, such // that p and q overlap with u and s, but that's ok, because we're done // with p and q by the time we calculate u and s. - usize rw = m + n; usize pw = MIN(ml + 1, m); u64 *p = r; usize qw = MIN(nl + 1, n); u64 *q = r + pw; usize tw = pw + qw; u64 *t = scratch; @@ -170,7 +170,8 @@ karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) { karatsuba(s, xh, mh, yh, nh, scratch + tw); // s = xh * yh sub(t, t, tw, u, uw); // t -= u sub(t, t, tw, s, sw); // t -= s - add(r + k, r + k, rw - k, t, tw); // r[k..] += t + add(r + k, r + k, sw, t, tw); // r[k..] += t + // TODO: Prove that tw + sw <= m + n - k. } void @@ -182,12 +183,73 @@ mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) { free(scratch); } +u64 x[4096]; +u64 y[4096]; +u64 r[8192]; + +NOINLINE void +bench_quadratic(u64 l, u64 n) { + r_bench_start(); + for (u64 i = 0; i < n; i++) mul_quadratic(r, x, l, y, l); + r_bench_stop(); +} + +void bench_quadratic16(u64 n) { bench_quadratic(16, n); } +void bench_quadratic32(u64 n) { bench_quadratic(32, n); } +void bench_quadratic64(u64 n) { bench_quadratic(64, n); } +void bench_quadratic128(u64 n) { bench_quadratic(128, n); } +void bench_quadratic256(u64 n) { bench_quadratic(256, n); } +void bench_quadratic512(u64 n) { bench_quadratic(512, n); } +void bench_quadratic1024(u64 n) { bench_quadratic(1024, n); } +void bench_quadratic2048(u64 n) { bench_quadratic(2048, n); } +void bench_quadratic4096(u64 n) { bench_quadratic(4096, n); } + +NOINLINE void +bench_karatsuba(u64 l, u64 n) { + r_bench_start(); + for (u64 i = 0; i < n; i++) mul_karatsuba(r, x, l, y, l); + r_bench_stop(); +} + +void bench_karatsuba16(u64 n) { bench_karatsuba(16, n); } +void bench_karatsuba32(u64 n) { bench_karatsuba(32, n); } +void bench_karatsuba64(u64 n) { bench_karatsuba(64, n); } +void bench_karatsuba128(u64 n) { bench_karatsuba(128, n); } +void bench_karatsuba256(u64 n) { bench_karatsuba(256, n); } +void bench_karatsuba512(u64 n) { bench_karatsuba(512, n); } +void bench_karatsuba1024(u64 n) { bench_karatsuba(1024, n); } +void bench_karatsuba2048(u64 n) { bench_karatsuba(2048, n); } +void bench_karatsuba4096(u64 n) { bench_karatsuba(4096, n); } + int main(void) { - u64 x[] = { 0x1234123412341234, 0x5678567856785678, 0x89ab89ab89ab89ab, 0xcdefcdefcdefcdef }; - u64 y[] = { 0x4321432143214321, 0x8765876587658765, 0xba98ba98ba98ba98, 0xfedcfedcfedcfedc }; - u64 r0[LEN(x) + LEN(y)]; mul_quadratic(r0, x, LEN(x), y, LEN(y)); - u64 r1[LEN(x) + LEN(y)]; mul_karatsuba(r1, x, LEN(x), y, LEN(y)); - printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r0[7], r0[6], r0[5], r0[4], r0[3], r0[2], r0[1], r0[0]); - printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r1[7], r1[6], r1[5], r1[4], r1[3], r1[2], r1[1], r1[0]); + // u64 x[] = { 0x1234123412341234, 0x5678567856785678, 0x89ab89ab89ab89ab, 0xcdefcdefcdefcdef }; + // u64 y[] = { 0x4321432143214321, 0x8765876587658765, 0xba98ba98ba98ba98, 0xfedcfedcfedcfedc }; + // u64 r0[LEN(x) + LEN(y)]; mul_quadratic(r0, x, LEN(x), y, LEN(y)); + // u64 r1[LEN(x) + LEN(y)]; mul_karatsuba(r1, x, LEN(x), y, LEN(y)); + // printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r0[7], r0[6], r0[5], r0[4], r0[3], r0[2], r0[1], r0[0]); + // printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r1[7], r1[6], r1[5], r1[4], r1[3], r1[2], r1[1], r1[0]); + + for (usize i = 0; i < LEN(x); i++) x[i] = r_prand64(); + for (usize i = 0; i < LEN(y); i++) y[i] = r_prand64(); + + r_bench(bench_quadratic16, 3000); + r_bench(bench_quadratic32, 3000); + r_bench(bench_quadratic64, 3000); + r_bench(bench_quadratic128, 3000); + r_bench(bench_quadratic256, 3000); + r_bench(bench_quadratic512, 3000); + r_bench(bench_quadratic1024, 3000); + r_bench(bench_quadratic2048, 3000); + r_bench(bench_quadratic4096, 3000); + + r_bench(bench_karatsuba16, 3000); + r_bench(bench_karatsuba32, 3000); + r_bench(bench_karatsuba64, 3000); + r_bench(bench_karatsuba128, 3000); + r_bench(bench_karatsuba256, 3000); + r_bench(bench_karatsuba512, 3000); + r_bench(bench_karatsuba1024, 3000); + r_bench(bench_karatsuba2048, 3000); + r_bench(bench_karatsuba4096, 3000); }