Skip to content

Commit 52207f2

Browse files
bhansconnect authored and andrewrk committed
Add karatsuba to big ints
1 parent 711520d commit 52207f2

File tree

1 file changed

+167
-15
lines changed

1 file changed

+167
-15
lines changed

lib/std/math/big/int.zig

Lines changed: 167 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -766,20 +766,19 @@ pub const Int = struct {
766766
r.deinit();
767767
};
768768

769-
try r.ensureCapacity(a.len() + b.len());
769+
try r.ensureCapacity(a.len() + b.len() + 1);
770770

771-
if (a.len() >= b.len()) {
772-
llmul(r.limbs, a.limbs[0..a.len()], b.limbs[0..b.len()]);
773-
} else {
774-
llmul(r.limbs, b.limbs[0..b.len()], a.limbs[0..a.len()]);
775-
}
771+
mem.set(Limb, r.limbs[0 .. a.len() + b.len() + 1], 0);
772+
773+
try llmulacc(rma.allocator.?, r.limbs, a.limbs[0..a.len()], b.limbs[0..b.len()]);
776774

777775
r.normalize(a.len() + b.len());
778776
r.setSign(a.isPositive() == b.isPositive());
779777
}
780778

781779
// a + b * c + *carry, sets carry to the overflow bits
782780
pub fn addMulLimbWithCarry(a: Limb, b: Limb, c: Limb, carry: *Limb) Limb {
781+
@setRuntimeSafety(false);
783782
var r1: Limb = undefined;
784783

785784
// r1 = a + *carry
@@ -800,25 +799,178 @@ pub const Int = struct {
800799
return r1;
801800
}
802801

802+
// acc += y * xi
//
// Multiplies every limb of `y` by the single limb `xi` and adds the result
// into the low `y.len` limbs of `acc`, then ripples the leftover carry through
// the remaining limbs of `acc`. `acc` must be at least `y.len` limbs long.
// Any carry still pending once `acc` is exhausted is discarded.
fn llmulDigit(acc: []Limb, y: []const Limb, xi: Limb) void {
    @setRuntimeSafety(false);

    // Multiplying by zero adds nothing.
    if (xi == 0) {
        return;
    }

    // The carry is a Limb: it is handed to addMulLimbWithCarry as a *Limb and
    // fed back through @addWithOverflow(Limb, ...).
    var carry: Limb = 0;
    const a_lo = acc[0..y.len];
    const a_hi = acc[y.len..];

    // a_lo[j] += y[j] * xi + carry
    var j: usize = 0;
    while (j < a_lo.len) : (j += 1) {
        a_lo[j] = @inlineCall(addMulLimbWithCarry, a_lo[j], y[j], xi, &carry);
    }

    // Propagate the remaining carry through the upper limbs until it dies out.
    j = 0;
    while ((carry != 0) and (j < a_hi.len)) : (j += 1) {
        carry = @boolToInt(@addWithOverflow(Limb, a_hi[j], carry, &a_hi[j]));
    }
}
822+
803823
// r += a * b
//
// Accumulates the product of `a` and `b` into `r`; callers that want a plain
// product zero `r` beforehand. When the shorter operand has at most 48 limbs
// the schoolbook method is used (Knuth 4.3.1, Algorithm M). Longer operands
// are split in two and multiplied recursively with Karatsuba's method, which
// takes scratch memory from `allocator` and can therefore fail with
// error.OutOfMemory.
//
// r MUST NOT alias any of a or b, and r must hold at least one limb more than
// the normalized lengths of a and b combined.
fn llmulacc(allocator: *Allocator, r: []Limb, a: []const Limb, b: []const Limb) error{OutOfMemory}!void {
    @setRuntimeSafety(false);

    const a_norm = a[0..llnormalize(a)];
    const b_norm = b[0..llnormalize(b)];

    // x is the shorter operand and y the longer one (a and b when tied).
    const swap = a_norm.len > b_norm.len;
    const x = if (swap) b_norm else a_norm;
    const y = if (swap) a_norm else b_norm;

    debug.assert(r.len >= x.len + y.len + 1);

    // 48 is a pretty arbitrary size chosen based on performance of a factorial program.
    if (x.len <= 48) {
        // Basecase multiplication
        var i: usize = 0;
        while (i < x.len) : (i += 1) {
            llmulDigit(r[i..], y, x[i]);
        }
    } else {
        // Karatsuba multiplication
        //
        // Writing B for a shift by `split` limbs, x = x1 * B + x0 and
        // y = y1 * B + y0, so
        //   x * y = x1*y1 * B^2 + (x1*y0 + x0*y1) * B + x0*y0
        // where the middle term is obtained from the two outer products:
        //   x1*y0 + x0*y1 = x1*y1 + x0*y0 - (x1 - x0) * (y1 - y0)
        const split = @divFloor(x.len, 2);
        const x0 = x[0..split];
        const x1 = x[split..x.len];
        const y0 = y[0..split];
        const y1 = y[split..y.len];

        // Scratch space for the partial products; must start out zeroed
        // because llmulacc accumulates.
        const tmp = try allocator.alloc(Limb, x1.len + y1.len + 1);
        defer allocator.free(tmp);
        mem.set(Limb, tmp, 0);

        // tmp = x1 * y1, added at B (middle term) and at B^2.
        try llmulacc(allocator, tmp, x1, y1);

        var length = llnormalize(tmp);
        _ = llaccum(r[split..], tmp[0..length]);
        _ = llaccum(r[split * 2 ..], tmp[0..length]);

        // Only the limbs that were written need clearing before reuse.
        mem.set(Limb, tmp[0..length], 0);

        // tmp = x0 * y0, added at position 0 and at B (middle term).
        try llmulacc(allocator, tmp, x0, y0);

        length = llnormalize(tmp);
        _ = llaccum(r[0..], tmp[0..length]);
        _ = llaccum(r[split..], tmp[0..length]);

        // If either difference is zero the correction term vanishes.
        const x_cmp = llcmp(x1, x0);
        const y_cmp = llcmp(y1, y0);
        if (x_cmp * y_cmp == 0) {
            return;
        }

        // j0 = |x1 - x0|
        const x0_len = llnormalize(x0);
        const x1_len = llnormalize(x1);
        const j0 = try allocator.alloc(Limb, math.max(x0_len, x1_len));
        defer allocator.free(j0);
        if (x_cmp == 1) {
            llsub(j0, x1[0..x1_len], x0[0..x0_len]);
        } else {
            llsub(j0, x0[0..x0_len], x1[0..x1_len]);
        }

        // j1 = |y1 - y0|
        const y0_len = llnormalize(y0);
        const y1_len = llnormalize(y1);
        const j1 = try allocator.alloc(Limb, math.max(y0_len, y1_len));
        defer allocator.free(j1);
        if (y_cmp == 1) {
            llsub(j1, y1[0..y1_len], y0[0..y0_len]);
        } else {
            llsub(j1, y0[0..y0_len], y1[0..y1_len]);
        }

        if (x_cmp == y_cmp) {
            // Both differences have the same sign, so their product is
            // positive and has to be subtracted from the middle term.
            mem.set(Limb, tmp[0..length], 0);
            try llmulacc(allocator, tmp, j0, j1);

            length = llnormalize(tmp);
            llsub(r[split..], r[split..], tmp[0..length]);
        } else {
            // Opposite signs: the product is negative, so subtracting it
            // amounts to accumulating |x1 - x0| * |y1 - y0| straight into r.
            try llmulacc(allocator, r[split..], j0, j1);
        }
    }
}
910+
911+
// r += a
//
// Adds `a` into the low limbs of `r` and ripples the carry through the
// remaining limbs of `r`. Both slices must be non-empty and `r` must be at
// least as long as `a`. Returns the carry left over after the last limb of
// `r` (0 when the sum fit).
fn llaccum(r: []Limb, a: []const Limb) Limb {
    @setRuntimeSafety(false);
    debug.assert(r.len != 0 and a.len != 0);
    debug.assert(r.len >= a.len);

    var carry: Limb = 0;
    var idx: usize = 0;

    // Limbs present in both operands: r[idx] += a[idx] + carry.
    while (idx < a.len) : (idx += 1) {
        const from_limb: Limb = @boolToInt(@addWithOverflow(Limb, r[idx], a[idx], &r[idx]));
        const from_carry: Limb = @boolToInt(@addWithOverflow(Limb, r[idx], carry, &r[idx]));
        carry = from_limb + from_carry;
    }

    // The upper limbs of r only absorb the carry; stop as soon as it is gone.
    while (idx < r.len and carry != 0) : (idx += 1) {
        carry = @boolToInt(@addWithOverflow(Limb, r[idx], carry, &r[idx]));
    }

    return carry;
}
933+
934+
/// Compares the magnitudes of two limb slices, ignoring high-order zero limbs.
/// Returns -1 if |a| < |b|, 0 if |a| == |b| and 1 if |a| > |b|.
pub fn llcmp(a: []const Limb, b: []const Limb) i8 {
    @setRuntimeSafety(false);
    const a_len = llnormalize(a);
    const b_len = llnormalize(b);

    // Once normalized, the operand with more limbs is the larger one.
    if (a_len < b_len) return -1;
    if (a_len > b_len) return 1;

    // Same length: scan from the most significant limb downwards and let the
    // first differing limb decide.
    var idx = a_len;
    while (idx != 0) {
        idx -= 1;
        if (a[idx] < b[idx]) return -1;
        if (a[idx] > b[idx]) return 1;
    }

    return 0;
}
961+
962+
// Returns the number of limbs of `a` that remain once high-order zero limbs
// are dropped. A value of zero still occupies one limb, so the result is
// never less than 1.
fn llnormalize(a: []const Limb) usize {
    @setRuntimeSafety(false);

    // Walk down from the top while the most significant limb is zero.
    var len = a.len;
    while (len != 0 and a[len - 1] == 0) {
        len -= 1;
    }

    // Handle zero
    if (len == 0) {
        return 1;
    }
    return len;
}
823975

824976
/// q = a / b (rem r)

0 commit comments

Comments
 (0)