commit 0d4646cfb7375fe8fc16d000f1340a6a773eede0
parent 6929edbb9449c0929418f06fc6137aafc4b342cf
Author: Robert Russell <robert@rr3.xyz>
Date: Tue, 31 Dec 2024 20:01:00 -0800
Move allocation out of hot karatsuba code
Diffstat:
| M | bigmul.c | | | 45 | +++++++++++++++++++++++++-------------------- |
1 file changed, 25 insertions(+), 20 deletions(-)
diff --git a/bigmul.c b/bigmul.c
@@ -90,8 +90,9 @@ mul_quadratic(u64 *r, u64 *x, usize m, u64 *y, usize n) {
}
// Precondition: r does not intersect x nor y
+// TODO: Document precondition regarding size of scratch memory.
void
-mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) {
+karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n, u64 *scratch) {
// If we allow m < 4 and n < 4, then the recursion is not well-founded.
// TODO: Determine best threshold for quadratic.
if (m < 4 && n < 4) {
@@ -99,7 +100,7 @@ mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) {
return;
}
- usize k = (MAX(m, n) + 1) / 2;
+ usize k = (MAX(m, n) + 1) / 2; // k = ceil(max(m, n) / 2)
// 1. Split x
usize mh = m > k ? m - k : 0, ml = MIN(k, m);
@@ -112,28 +113,32 @@ mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) {
// 3. Assign/allocate memory
// Note that we use the output buffer r as temporary storage for
// intermediate results.
- // TODO: Accept capacity of r as an argument, and use excess memory for t
- // if it's big enough instead of allocating. Or at least allocate once in
- // the first recursive instances of the function, and pass the memory down.
usize rw = m + n;
- usize pw = MAX(ml, mh) + 1; u64 *p = r;
- usize qw = MAX(nl, nh) + 1; u64 *q = r + pw;
- usize tw = pw + qw; u64 *t = r_eallocn(tw, sizeof *t);
- usize uw = ml + nl; u64 *u = r;
- usize sw = mh + mh; u64 *s = r + 2 * k;
+ usize pw = ml + 1; u64 *p = r; // Note: ml = MAX(ml, mh)
+ usize qw = nl + 1; u64 *q = r + pw; // Note: nl = MAX(nl, nh)
+ usize tw = pw + qw; u64 *t = scratch;
+ usize uw = ml + nl; u64 *u = r;
+ usize sw = mh + mh; u64 *s = r + 2 * k;
// 4. Arithmetic
// TODO: Explain algorithm.
- add(p, xl, ml, xh, mh); // p = xl + xh
- add(q, yl, nl, yh, nh); // q = yl + yh
- mul_karatsuba(t, p, pw, q, qw); // t = p * q
- mul_karatsuba(u, xl, ml, yl, nl); // u = xl * yl
- mul_karatsuba(s, xh, mh, yh, nh); // s = xh * yh
- sub(t, t, tw, u, uw); // t -= u
- sub(t, t, tw, s, sw); // t -= s
- add(r + k, r + k, rw - k, t, tw); // r[k..] += t
-
- free(t);
+ add(p, xl, ml, xh, mh); // p = xl + xh
+ add(q, yl, nl, yh, nh); // q = yl + yh
+ karatsuba(t, p, pw, q, qw, scratch + tw); // t = p * q
+ karatsuba(u, xl, ml, yl, nl, scratch + tw); // u = xl * yl
+ karatsuba(s, xh, mh, yh, nh, scratch + tw); // s = xh * yh
+ sub(t, t, tw, u, uw); // t -= u
+ sub(t, t, tw, s, sw); // t -= s
+ add(r + k, r + k, rw - k, t, tw); // r[k..] += t
+}
+
+void
+mul_karatsuba(u64 *r, u64 *x, usize m, u64 *y, usize n) {
+ // TODO: Accept capacity of r as an argument, and use excess memory for
+ // scratch if it's big enough instead of allocating.
+ u64 *scratch = r_eallocn((m + n + 2) * 2, sizeof *scratch);
+ karatsuba(r, x, m, y, n, scratch);
+ free(scratch);
}
int