commit 7c154e5299815ed6d84c2fa17e52826228cfdd37
parent 1e05d117b80ca748c86bc6c2db9df359acbd3e26
Author: Robert Russell <robert@rr3.xyz>
Date: Wed, 1 Jan 2025 02:11:30 -0800
Add benchmarks and use quadratic method in karatsuba for small m,n
Diffstat:
| M | bigmul.c | | | 86 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------- |
1 file changed, 74 insertions(+), 12 deletions(-)
diff --git a/bigmul.c b/bigmul.c
@@ -1,4 +1,5 @@
#include <rcx/all.h>
+#include <rcx/bench.h>
#include <stdio.h>
#include <unistd.h>
@@ -74,7 +75,7 @@ sub(u64 *r, u64 *x, usize m, u64 *y, usize n) {
sub64(&b, &r[i], x[i], y[i], b);
for (usize i = n; i < m; i++)
sub64(&b, &r[i], x[i], b, 0);
- r[m] = -b;
+ r[m] = -b; // TODO: I don't think this makes sense for nats.
}
// Precondition: r does not intersect x nor y
@@ -134,13 +135,13 @@ karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) {
* n >= 4, then |q| < n and |p| <= m. It therefore suffices to separate the
* case m < 4 || n < 4 for the recursion basis. */
- // TODO: Determine best threshold for switching to the quadratic method.
- if (m < 4 || n < 4) {
+ usize maxmn = MAX(m, n);
+ if (m < 4 || n < 4 || maxmn < 256) { // 256 was determined via benchmarking.
mul_quadratic(r, x, m, y, n);
return;
}
- usize k = (MAX(m, n) + 1) / 2;
+ usize k = (maxmn + 1) / 2;
// 1. Split x
usize mh = m > k ? m - k : 0, ml = MIN(k, m);
@@ -155,7 +156,6 @@ karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) {
// We also store u and s directly in r at the appropriate offsets, such
// that p and q overlap with u and s, but that's ok, because we're done
// with p and q by the time we calculate u and s.
- usize rw = m + n;
usize pw = MIN(ml + 1, m); u64 *p = r;
usize qw = MIN(nl + 1, n); u64 *q = r + pw;
usize tw = pw + qw; u64 *t = scratch;
@@ -170,7 +170,8 @@ karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) {
karatsuba(s, xh, mh, yh, nh, scratch + tw); // s = xh * yh
sub(t, t, tw, u, uw); // t -= u
sub(t, t, tw, s, sw); // t -= s
- add(r + k, r + k, rw - k, t, tw); // r[k..] += t
+ add(r + k, r + k, sw, t, tw); // r[k..] += t
+ // TODO: Prove that tw + sw <= m + n - k.
}
void
@@ -182,12 +183,73 @@ mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) {
free(scratch);
}
+u64 x[4096];
+u64 y[4096];
+u64 r[8192];
+
+NOINLINE void
+bench_quadratic(u64 l, u64 n) {
+ r_bench_start();
+ for (u64 i = 0; i < n; i++) mul_quadratic(r, x, l, y, l);
+ r_bench_stop();
+}
+
+void bench_quadratic16(u64 n) { bench_quadratic(16, n); }
+void bench_quadratic32(u64 n) { bench_quadratic(32, n); }
+void bench_quadratic64(u64 n) { bench_quadratic(64, n); }
+void bench_quadratic128(u64 n) { bench_quadratic(128, n); }
+void bench_quadratic256(u64 n) { bench_quadratic(256, n); }
+void bench_quadratic512(u64 n) { bench_quadratic(512, n); }
+void bench_quadratic1024(u64 n) { bench_quadratic(1024, n); }
+void bench_quadratic2048(u64 n) { bench_quadratic(2048, n); }
+void bench_quadratic4096(u64 n) { bench_quadratic(4096, n); }
+
+NOINLINE void
+bench_karatsuba(u64 l, u64 n) {
+ r_bench_start();
+ for (u64 i = 0; i < n; i++) mul_karatsuba(r, x, l, y, l);
+ r_bench_stop();
+}
+
+void bench_karatsuba16(u64 n) { bench_karatsuba(16, n); }
+void bench_karatsuba32(u64 n) { bench_karatsuba(32, n); }
+void bench_karatsuba64(u64 n) { bench_karatsuba(64, n); }
+void bench_karatsuba128(u64 n) { bench_karatsuba(128, n); }
+void bench_karatsuba256(u64 n) { bench_karatsuba(256, n); }
+void bench_karatsuba512(u64 n) { bench_karatsuba(512, n); }
+void bench_karatsuba1024(u64 n) { bench_karatsuba(1024, n); }
+void bench_karatsuba2048(u64 n) { bench_karatsuba(2048, n); }
+void bench_karatsuba4096(u64 n) { bench_karatsuba(4096, n); }
+
int
main(void) {
- u64 x[] = { 0x1234123412341234, 0x5678567856785678, 0x89ab89ab89ab89ab, 0xcdefcdefcdefcdef };
- u64 y[] = { 0x4321432143214321, 0x8765876587658765, 0xba98ba98ba98ba98, 0xfedcfedcfedcfedc };
- u64 r0[LEN(x) + LEN(y)]; mul_quadratic(r0, x, LEN(x), y, LEN(y));
- u64 r1[LEN(x) + LEN(y)]; mul_karatsuba(r1, x, LEN(x), y, LEN(y));
- printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r0[7], r0[6], r0[5], r0[4], r0[3], r0[2], r0[1], r0[0]);
- printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r1[7], r1[6], r1[5], r1[4], r1[3], r1[2], r1[1], r1[0]);
+ // u64 x[] = { 0x1234123412341234, 0x5678567856785678, 0x89ab89ab89ab89ab, 0xcdefcdefcdefcdef };
+ // u64 y[] = { 0x4321432143214321, 0x8765876587658765, 0xba98ba98ba98ba98, 0xfedcfedcfedcfedc };
+ // u64 r0[LEN(x) + LEN(y)]; mul_quadratic(r0, x, LEN(x), y, LEN(y));
+ // u64 r1[LEN(x) + LEN(y)]; mul_karatsuba(r1, x, LEN(x), y, LEN(y));
+ // printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r0[7], r0[6], r0[5], r0[4], r0[3], r0[2], r0[1], r0[0]);
+ // printf("0x%016lx%016lx%016lx%016lx%016lx%016lx%016lx%016lx\n", r1[7], r1[6], r1[5], r1[4], r1[3], r1[2], r1[1], r1[0]);
+
+ for (usize i = 0; i < LEN(x); i++) x[i] = r_prand64();
+ for (usize i = 0; i < LEN(y); i++) y[i] = r_prand64();
+
+ r_bench(bench_quadratic16, 3000);
+ r_bench(bench_quadratic32, 3000);
+ r_bench(bench_quadratic64, 3000);
+ r_bench(bench_quadratic128, 3000);
+ r_bench(bench_quadratic256, 3000);
+ r_bench(bench_quadratic512, 3000);
+ r_bench(bench_quadratic1024, 3000);
+ r_bench(bench_quadratic2048, 3000);
+ r_bench(bench_quadratic4096, 3000);
+
+ r_bench(bench_karatsuba16, 3000);
+ r_bench(bench_karatsuba32, 3000);
+ r_bench(bench_karatsuba64, 3000);
+ r_bench(bench_karatsuba128, 3000);
+ r_bench(bench_karatsuba256, 3000);
+ r_bench(bench_karatsuba512, 3000);
+ r_bench(bench_karatsuba1024, 3000);
+ r_bench(bench_karatsuba2048, 3000);
+ r_bench(bench_karatsuba4096, 3000);
}