diff --git a/demo/test.c b/demo/test.c
index 3e432cf0e..18e67eeeb 100644
--- a/demo/test.c
+++ b/demo/test.c
@@ -1895,7 +1895,7 @@ static int test_s_mp_mul_balance(void)
return EXIT_FAILURE;
}
-#define s_mp_mul_full(a, b, c) s_mp_mul(a, b, c, (a)->used + (b)->used + 1)
+#define s_mp_mul_full(a, b, c) s_mp_mul_comba(a, b, c, (a)->used + (b)->used + 1)
static int test_s_mp_mul_karatsuba(void)
{
mp_int a, b, c, d;
@@ -1929,7 +1929,7 @@ static int test_s_mp_sqr_karatsuba(void)
for (size = MP_SQR_KARATSUBA_CUTOFF; size < (MP_SQR_KARATSUBA_CUTOFF + 20); size++) {
DO(mp_rand(&a, size));
DO(s_mp_sqr_karatsuba(&a, &b));
- DO(s_mp_sqr(&a, &c));
+ DO(s_mp_sqr_comba(&a, &c));
if (mp_cmp(&b, &c) != MP_EQ) {
fprintf(stderr, "Karatsuba squaring failed at size %d\n", size);
goto LBL_ERR;
@@ -2002,7 +2002,7 @@ static int test_s_mp_sqr_toom(void)
for (size = MP_SQR_TOOM_CUTOFF; size < (MP_SQR_TOOM_CUTOFF + 20); size++) {
DO(mp_rand(&a, size));
DO(s_mp_sqr_toom(&a, &b));
- DO(s_mp_sqr(&a, &c));
+ DO(s_mp_sqr_comba(&a, &c));
if (mp_cmp(&b, &c) != MP_EQ) {
fprintf(stderr, "Toom-Cook 3-way squaring failed at size %d\n", size);
goto LBL_ERR;
diff --git a/etc/tune.c b/etc/tune.c
index 0b7373448..334aec080 100644
--- a/etc/tune.c
+++ b/etc/tune.c
@@ -58,7 +58,11 @@ static int s_number_of_test_loops;
static int s_stabilization_extra;
static int s_offset = 1;
-#define s_mp_mul_full(a, b, c) s_mp_mul(a, b, c, (a)->used + (b)->used + 1)
+static mp_err s_mul_full(const mp_int *a, const mp_int *b, mp_int *c)
+{
+ return s_mp_mul_comba(a, b, c, a->used + b->used + 1);
+}
+
static uint64_t s_time_mul(int size)
{
int x;
@@ -87,7 +91,7 @@ static uint64_t s_time_mul(int size)
goto LBL_ERR;
}
if (s_check_result == 1) {
- if ((e = s_mp_mul_full(&a,&b,&d)) != MP_OKAY) {
+ if ((e = s_mul_full(&a,&b,&d)) != MP_OKAY) {
t1 = UINT64_MAX;
goto LBL_ERR;
}
@@ -129,7 +133,7 @@ static uint64_t s_time_sqr(int size)
goto LBL_ERR;
}
if (s_check_result == 1) {
- if ((e = s_mp_sqr(&a,&c)) != MP_OKAY) {
+ if ((e = s_mp_sqr_comba(&a,&c)) != MP_OKAY) {
t1 = UINT64_MAX;
goto LBL_ERR;
}
diff --git a/libtommath_VS2008.vcproj b/libtommath_VS2008.vcproj
index 7e16199e8..929191f07 100644
--- a/libtommath_VS2008.vcproj
+++ b/libtommath_VS2008.vcproj
@@ -868,10 +868,6 @@
RelativePath="s_mp_montgomery_reduce_comba.c"
>
-
-
@@ -916,10 +912,6 @@
RelativePath="s_mp_rand_platform.c"
>
-
-
diff --git a/makefile b/makefile
index 88eff7921..8cdade17e 100644
--- a/makefile
+++ b/makefile
@@ -46,10 +46,10 @@ mp_sqrmod.o mp_sqrt.o mp_sqrtmod_prime.o mp_sub.o mp_sub_d.o mp_submod.o mp_to_r
mp_to_ubin.o mp_ubin_size.o mp_unpack.o mp_xor.o mp_zero.o s_mp_add.o s_mp_copy_digs.o s_mp_div_3.o \
s_mp_div_recursive.o s_mp_div_school.o s_mp_div_small.o s_mp_exptmod.o s_mp_exptmod_fast.o s_mp_get_bit.o \
s_mp_invmod.o s_mp_invmod_odd.o s_mp_log.o s_mp_log_d.o s_mp_log_pow2.o s_mp_montgomery_reduce_comba.o \
-s_mp_mul.o s_mp_mul_balance.o s_mp_mul_comba.o s_mp_mul_high.o s_mp_mul_high_comba.o s_mp_mul_karatsuba.o \
+s_mp_mul_balance.o s_mp_mul_comba.o s_mp_mul_high.o s_mp_mul_high_comba.o s_mp_mul_karatsuba.o \
s_mp_mul_toom.o s_mp_prime_is_divisible.o s_mp_prime_tab.o s_mp_radix_map.o s_mp_rand_jenkins.o \
-s_mp_rand_platform.o s_mp_sqr.o s_mp_sqr_comba.o s_mp_sqr_karatsuba.o s_mp_sqr_toom.o s_mp_sub.o \
-s_mp_zero_buf.o s_mp_zero_digs.o
+s_mp_rand_platform.o s_mp_sqr_comba.o s_mp_sqr_karatsuba.o s_mp_sqr_toom.o s_mp_sub.o s_mp_zero_buf.o \
+s_mp_zero_digs.o
#END_INS
diff --git a/makefile.mingw b/makefile.mingw
index 3a3bc631f..52f470a18 100644
--- a/makefile.mingw
+++ b/makefile.mingw
@@ -48,10 +48,10 @@ mp_sqrmod.o mp_sqrt.o mp_sqrtmod_prime.o mp_sub.o mp_sub_d.o mp_submod.o mp_to_r
mp_to_ubin.o mp_ubin_size.o mp_unpack.o mp_xor.o mp_zero.o s_mp_add.o s_mp_copy_digs.o s_mp_div_3.o \
s_mp_div_recursive.o s_mp_div_school.o s_mp_div_small.o s_mp_exptmod.o s_mp_exptmod_fast.o s_mp_get_bit.o \
s_mp_invmod.o s_mp_invmod_odd.o s_mp_log.o s_mp_log_d.o s_mp_log_pow2.o s_mp_montgomery_reduce_comba.o \
-s_mp_mul.o s_mp_mul_balance.o s_mp_mul_comba.o s_mp_mul_high.o s_mp_mul_high_comba.o s_mp_mul_karatsuba.o \
+s_mp_mul_balance.o s_mp_mul_comba.o s_mp_mul_high.o s_mp_mul_high_comba.o s_mp_mul_karatsuba.o \
s_mp_mul_toom.o s_mp_prime_is_divisible.o s_mp_prime_tab.o s_mp_radix_map.o s_mp_rand_jenkins.o \
-s_mp_rand_platform.o s_mp_sqr.o s_mp_sqr_comba.o s_mp_sqr_karatsuba.o s_mp_sqr_toom.o s_mp_sub.o \
-s_mp_zero_buf.o s_mp_zero_digs.o
+s_mp_rand_platform.o s_mp_sqr_comba.o s_mp_sqr_karatsuba.o s_mp_sqr_toom.o s_mp_sub.o s_mp_zero_buf.o \
+s_mp_zero_digs.o
HEADERS_PUB=tommath.h
HEADERS=tommath_private.h tommath_class.h tommath_superclass.h tommath_cutoffs.h $(HEADERS_PUB)
diff --git a/makefile.msvc b/makefile.msvc
index a22267c4a..2255faf4b 100644
--- a/makefile.msvc
+++ b/makefile.msvc
@@ -41,10 +41,10 @@ mp_sqrmod.obj mp_sqrt.obj mp_sqrtmod_prime.obj mp_sub.obj mp_sub_d.obj mp_submod
mp_to_ubin.obj mp_ubin_size.obj mp_unpack.obj mp_xor.obj mp_zero.obj s_mp_add.obj s_mp_copy_digs.obj s_mp_div_3.obj \
s_mp_div_recursive.obj s_mp_div_school.obj s_mp_div_small.obj s_mp_exptmod.obj s_mp_exptmod_fast.obj s_mp_get_bit.obj \
s_mp_invmod.obj s_mp_invmod_odd.obj s_mp_log.obj s_mp_log_d.obj s_mp_log_pow2.obj s_mp_montgomery_reduce_comba.obj \
-s_mp_mul.obj s_mp_mul_balance.obj s_mp_mul_comba.obj s_mp_mul_high.obj s_mp_mul_high_comba.obj s_mp_mul_karatsuba.obj \
+s_mp_mul_balance.obj s_mp_mul_comba.obj s_mp_mul_high.obj s_mp_mul_high_comba.obj s_mp_mul_karatsuba.obj \
s_mp_mul_toom.obj s_mp_prime_is_divisible.obj s_mp_prime_tab.obj s_mp_radix_map.obj s_mp_rand_jenkins.obj \
-s_mp_rand_platform.obj s_mp_sqr.obj s_mp_sqr_comba.obj s_mp_sqr_karatsuba.obj s_mp_sqr_toom.obj s_mp_sub.obj \
-s_mp_zero_buf.obj s_mp_zero_digs.obj
+s_mp_rand_platform.obj s_mp_sqr_comba.obj s_mp_sqr_karatsuba.obj s_mp_sqr_toom.obj s_mp_sub.obj s_mp_zero_buf.obj \
+s_mp_zero_digs.obj
HEADERS_PUB=tommath.h
HEADERS=tommath_private.h tommath_class.h tommath_superclass.h tommath_cutoffs.h $(HEADERS_PUB)
diff --git a/makefile.shared b/makefile.shared
index ad58e61fe..9c4fd4dec 100644
--- a/makefile.shared
+++ b/makefile.shared
@@ -43,10 +43,10 @@ mp_sqrmod.o mp_sqrt.o mp_sqrtmod_prime.o mp_sub.o mp_sub_d.o mp_submod.o mp_to_r
mp_to_ubin.o mp_ubin_size.o mp_unpack.o mp_xor.o mp_zero.o s_mp_add.o s_mp_copy_digs.o s_mp_div_3.o \
s_mp_div_recursive.o s_mp_div_school.o s_mp_div_small.o s_mp_exptmod.o s_mp_exptmod_fast.o s_mp_get_bit.o \
s_mp_invmod.o s_mp_invmod_odd.o s_mp_log.o s_mp_log_d.o s_mp_log_pow2.o s_mp_montgomery_reduce_comba.o \
-s_mp_mul.o s_mp_mul_balance.o s_mp_mul_comba.o s_mp_mul_high.o s_mp_mul_high_comba.o s_mp_mul_karatsuba.o \
+s_mp_mul_balance.o s_mp_mul_comba.o s_mp_mul_high.o s_mp_mul_high_comba.o s_mp_mul_karatsuba.o \
s_mp_mul_toom.o s_mp_prime_is_divisible.o s_mp_prime_tab.o s_mp_radix_map.o s_mp_rand_jenkins.o \
-s_mp_rand_platform.o s_mp_sqr.o s_mp_sqr_comba.o s_mp_sqr_karatsuba.o s_mp_sqr_toom.o s_mp_sub.o \
-s_mp_zero_buf.o s_mp_zero_digs.o
+s_mp_rand_platform.o s_mp_sqr_comba.o s_mp_sqr_karatsuba.o s_mp_sqr_toom.o s_mp_sub.o s_mp_zero_buf.o \
+s_mp_zero_digs.o
#END_INS
diff --git a/makefile.unix b/makefile.unix
index 1e0da7393..6365b7acf 100644
--- a/makefile.unix
+++ b/makefile.unix
@@ -49,10 +49,10 @@ mp_sqrmod.o mp_sqrt.o mp_sqrtmod_prime.o mp_sub.o mp_sub_d.o mp_submod.o mp_to_r
mp_to_ubin.o mp_ubin_size.o mp_unpack.o mp_xor.o mp_zero.o s_mp_add.o s_mp_copy_digs.o s_mp_div_3.o \
s_mp_div_recursive.o s_mp_div_school.o s_mp_div_small.o s_mp_exptmod.o s_mp_exptmod_fast.o s_mp_get_bit.o \
s_mp_invmod.o s_mp_invmod_odd.o s_mp_log.o s_mp_log_d.o s_mp_log_pow2.o s_mp_montgomery_reduce_comba.o \
-s_mp_mul.o s_mp_mul_balance.o s_mp_mul_comba.o s_mp_mul_high.o s_mp_mul_high_comba.o s_mp_mul_karatsuba.o \
+s_mp_mul_balance.o s_mp_mul_comba.o s_mp_mul_high.o s_mp_mul_high_comba.o s_mp_mul_karatsuba.o \
s_mp_mul_toom.o s_mp_prime_is_divisible.o s_mp_prime_tab.o s_mp_radix_map.o s_mp_rand_jenkins.o \
-s_mp_rand_platform.o s_mp_sqr.o s_mp_sqr_comba.o s_mp_sqr_karatsuba.o s_mp_sqr_toom.o s_mp_sub.o \
-s_mp_zero_buf.o s_mp_zero_digs.o
+s_mp_rand_platform.o s_mp_sqr_comba.o s_mp_sqr_karatsuba.o s_mp_sqr_toom.o s_mp_sub.o s_mp_zero_buf.o \
+s_mp_zero_digs.o
HEADERS_PUB=tommath.h
HEADERS=tommath_private.h tommath_class.h tommath_superclass.h tommath_cutoffs.h $(HEADERS_PUB)
diff --git a/mp_mul.c b/mp_mul.c
index b2dbf7d72..f77ff457d 100644
--- a/mp_mul.c
+++ b/mp_mul.c
@@ -21,13 +21,8 @@ mp_err mp_mul(const mp_int *a, const mp_int *b, mp_int *c)
(a->used >= MP_SQR_KARATSUBA_CUTOFF)) {
err = s_mp_sqr_karatsuba(a, c);
} else if ((a == b) &&
- MP_HAS(S_MP_SQR_COMBA) && /* can we use the fast comba multiplier? */
- (((a->used * 2) + 1) < MP_WARRAY) &&
- (a->used < (MP_MAX_COMBA / 2))) {
+ MP_HAS(S_MP_SQR_COMBA)) {
err = s_mp_sqr_comba(a, c);
- } else if ((a == b) &&
- MP_HAS(S_MP_SQR)) {
- err = s_mp_sqr(a, c);
} else if (MP_HAS(S_MP_MUL_BALANCE) &&
/* Check sizes. The smaller one needs to be larger than the Karatsuba cut-off.
* The bigger one needs to be at least about one MP_MUL_KARATSUBA_CUTOFF bigger
@@ -47,18 +42,8 @@ mp_err mp_mul(const mp_int *a, const mp_int *b, mp_int *c)
} else if (MP_HAS(S_MP_MUL_KARATSUBA) &&
(min >= MP_MUL_KARATSUBA_CUTOFF)) {
err = s_mp_mul_karatsuba(a, b, c);
- } else if (MP_HAS(S_MP_MUL_COMBA) &&
- /* can we use the fast multiplier?
- *
- * The fast multiplier can be used if the output will
- * have less than MP_WARRAY digits and the number of
- * digits won't affect carry propagation
- */
- (digs < MP_WARRAY) &&
- (min <= MP_MAX_COMBA)) {
+ } else if (MP_HAS(S_MP_MUL_COMBA)) {
err = s_mp_mul_comba(a, b, c, digs);
- } else if (MP_HAS(S_MP_MUL)) {
- err = s_mp_mul(a, b, c, digs);
} else {
err = MP_VAL;
}
diff --git a/mp_reduce.c b/mp_reduce.c
index b6fae55cc..f08bcf1c9 100644
--- a/mp_reduce.c
+++ b/mp_reduce.c
@@ -23,17 +23,11 @@ mp_err mp_reduce(mp_int *x, const mp_int *m, const mp_int *mu)
/* according to HAC this optimization is ok */
if ((mp_digit)um > ((mp_digit)1 << (MP_DIGIT_BIT - 1))) {
- if ((err = mp_mul(&q, mu, &q)) != MP_OKAY) {
- goto LBL_ERR;
- }
- } else if (MP_HAS(S_MP_MUL_HIGH)) {
- if ((err = s_mp_mul_high(&q, mu, &q, um)) != MP_OKAY) {
- goto LBL_ERR;
- }
- } else if (MP_HAS(S_MP_MUL_HIGH_COMBA)) {
- if ((err = s_mp_mul_high_comba(&q, mu, &q, um)) != MP_OKAY) {
- goto LBL_ERR;
- }
+ if ((err = mp_mul(&q, mu, &q)) != MP_OKAY) goto LBL_ERR;
+ } else if (MP_HAS(S_MP_MUL_HIGH_COMBA) && (MP_MIN(q.used, mu->used) < MP_MAX_COMBA)) { /* comba keeps each column sum in one mp_word: only safe below MP_MAX_COMBA products */
+ if ((err = s_mp_mul_high_comba(&q, mu, &q, um)) != MP_OKAY) goto LBL_ERR;
+ } else if (MP_HAS(S_MP_MUL_HIGH)) {
+ if ((err = s_mp_mul_high(&q, mu, &q, um)) != MP_OKAY) goto LBL_ERR;
} else {
err = MP_VAL;
goto LBL_ERR;
@@ -43,41 +37,28 @@ mp_err mp_reduce(mp_int *x, const mp_int *m, const mp_int *mu)
mp_rshd(&q, um + 1);
/* x = x mod b**(k+1), quick (no division) */
- if ((err = mp_mod_2d(x, MP_DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
- goto LBL_ERR;
- }
+ if ((err = mp_mod_2d(x, MP_DIGIT_BIT * (um + 1), x)) != MP_OKAY) goto LBL_ERR;
/* q = q * m mod b**(k+1), quick (no division) */
- if ((err = s_mp_mul(&q, m, &q, um + 1)) != MP_OKAY) {
- goto LBL_ERR;
- }
+ if ((err = s_mp_mul_comba(&q, m, &q, um + 1)) != MP_OKAY) goto LBL_ERR;
/* x = x - q */
- if ((err = mp_sub(x, &q, x)) != MP_OKAY) {
- goto LBL_ERR;
- }
+ if ((err = mp_sub(x, &q, x)) != MP_OKAY) goto LBL_ERR;
/* If x < 0, add b**(k+1) to it */
if (mp_cmp_d(x, 0uL) == MP_LT) {
mp_set(&q, 1uL);
- if ((err = mp_lshd(&q, um + 1)) != MP_OKAY) {
- goto LBL_ERR;
- }
- if ((err = mp_add(x, &q, x)) != MP_OKAY) {
- goto LBL_ERR;
- }
+ if ((err = mp_lshd(&q, um + 1)) != MP_OKAY) goto LBL_ERR;
+ if ((err = mp_add(x, &q, x)) != MP_OKAY) goto LBL_ERR;
}
/* Back off if it's too big */
while (mp_cmp(x, m) != MP_LT) {
- if ((err = s_mp_sub(x, m, x)) != MP_OKAY) {
- goto LBL_ERR;
- }
+ if ((err = s_mp_sub(x, m, x)) != MP_OKAY) goto LBL_ERR;
}
LBL_ERR:
mp_clear(&q);
-
return err;
}
#endif
diff --git a/s_mp_mul.c b/s_mp_mul.c
deleted file mode 100644
index fb99d8054..000000000
--- a/s_mp_mul.c
+++ /dev/null
@@ -1,61 +0,0 @@
-#include "tommath_private.h"
-#ifdef S_MP_MUL_C
-/* LibTomMath, multiple-precision integer library -- Tom St Denis */
-/* SPDX-License-Identifier: Unlicense */
-
-/* multiplies |a| * |b| and only computes upto digs digits of result
- * HAC pp. 595, Algorithm 14.12 Modified so you can control how
- * many digits of output are created.
- */
-mp_err s_mp_mul(const mp_int *a, const mp_int *b, mp_int *c, int digs)
-{
- mp_int t;
- mp_err err;
- int pa, ix;
-
- /* can we use the fast multiplier? */
- if ((digs < MP_WARRAY) &&
- (MP_MIN(a->used, b->used) < MP_MAX_COMBA)) {
- return s_mp_mul_comba(a, b, c, digs);
- }
-
- if ((err = mp_init_size(&t, digs)) != MP_OKAY) {
- return err;
- }
- t.used = digs;
-
- /* compute the digits of the product directly */
- pa = a->used;
- for (ix = 0; ix < pa; ix++) {
- int iy, pb;
- mp_digit u = 0;
-
- /* limit ourselves to making digs digits of output */
- pb = MP_MIN(b->used, digs - ix);
-
- /* compute the columns of the output and propagate the carry */
- for (iy = 0; iy < pb; iy++) {
- /* compute the column as a mp_word */
- mp_word r = (mp_word)t.dp[ix + iy] +
- ((mp_word)a->dp[ix] * (mp_word)b->dp[iy]) +
- (mp_word)u;
-
- /* the new column is the lower part of the result */
- t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
-
- /* get the carry word from the result */
- u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
- }
- /* set carry if it is placed below digs */
- if ((ix + iy) < digs) {
- t.dp[ix + pb] = u;
- }
- }
-
- mp_clamp(&t);
- mp_exch(&t, c);
-
- mp_clear(&t);
- return MP_OKAY;
-}
-#endif
diff --git a/s_mp_mul_comba.c b/s_mp_mul_comba.c
index 07dd7913d..cca228c19 100644
--- a/s_mp_mul_comba.c
+++ b/s_mp_mul_comba.c
@@ -23,11 +23,14 @@ mp_err s_mp_mul_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs)
{
int oldused, pa, ix;
mp_err err;
- mp_digit W[MP_WARRAY];
- mp_word _W;
+ mp_digit c0, c1, c2;
+ mp_int tmp, *c_;
- /* grow the destination as required */
- if ((err = mp_grow(c, digs)) != MP_OKAY) {
+ /* prepare the destination */
+ err = (MP_ALIAS(a, c) || MP_ALIAS(b, c))
+ ? mp_init_size((c_ = &tmp), digs)
+ : mp_grow((c_ = c), digs);
+ if (err != MP_OKAY) {
return err;
}
@@ -35,7 +38,7 @@ mp_err s_mp_mul_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs)
pa = MP_MIN(digs, a->used + b->used);
/* clear the carry */
- _W = 0;
+ c0 = c1 = c2 = 0;
for (ix = 0; ix < pa; ix++) {
int tx, ty, iy, iz;
@@ -48,31 +51,75 @@ mp_err s_mp_mul_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs)
*/
iy = MP_MIN(a->used-tx, ty+1);
- /* execute loop */
- for (iz = 0; iz < iy; ++iz) {
- _W += (mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz];
+ /* execute loop
+ *
+ * Give the autovectorizer a hint! This might not be necessary.
+ * I don't think the generated code will be particularly good here;
+ * if we use full-width digits, the masks will go away.
+ */
+ for (iz = 0; iz + 3 < iy;) {
+ mp_word w = (mp_word)c0 + ((mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]);
+ c0 = (mp_digit)(w & MP_MASK);
+ w = (mp_word)c1 + (w >> MP_DIGIT_BIT);
+ c1 = (mp_digit)(w & MP_MASK);
+ c2 += (mp_digit)(w >> MP_DIGIT_BIT);
+ ++iz;
+
+ w = (mp_word)c0 + ((mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]);
+ c0 = (mp_digit)(w & MP_MASK);
+ w = (mp_word)c1 + (w >> MP_DIGIT_BIT);
+ c1 = (mp_digit)(w & MP_MASK);
+ c2 += (mp_digit)(w >> MP_DIGIT_BIT);
+ ++iz;
+
+ w = (mp_word)c0 + ((mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]);
+ c0 = (mp_digit)(w & MP_MASK);
+ w = (mp_word)c1 + (w >> MP_DIGIT_BIT);
+ c1 = (mp_digit)(w & MP_MASK);
+ c2 += (mp_digit)(w >> MP_DIGIT_BIT);
+ ++iz;
+
+ w = (mp_word)c0 + ((mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]);
+ c0 = (mp_digit)(w & MP_MASK);
+ w = (mp_word)c1 + (w >> MP_DIGIT_BIT);
+ c1 = (mp_digit)(w & MP_MASK);
+ c2 += (mp_digit)(w >> MP_DIGIT_BIT);
+ ++iz;
+ }
+
+ /* execute rest of loop */
+ for (; iz < iy;) {
+ mp_word w = (mp_word)c0 + ((mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]);
+ c0 = (mp_digit)(w & MP_MASK);
+ w = (mp_word)c1 + (w >> MP_DIGIT_BIT);
+ c1 = (mp_digit)(w & MP_MASK);
+ c2 += (mp_digit)(w >> MP_DIGIT_BIT);
+ ++iz;
}
/* store term */
- W[ix] = (mp_digit)_W & MP_MASK;
+ c_->dp[ix] = c0;
/* make next carry */
- _W = _W >> (mp_word)MP_DIGIT_BIT;
+ c0 = c1;
+ c1 = c2;
+ c2 = 0;
}
/* setup dest */
- oldused = c->used;
- c->used = pa;
-
- for (ix = 0; ix < pa; ix++) {
- /* now extract the previous digit [below the carry] */
- c->dp[ix] = W[ix];
- }
+ oldused = c_->used;
+ c_->used = pa;
/* clear unused digits [that existed in the old copy of c] */
- s_mp_zero_digs(c->dp + c->used, oldused - c->used);
+ s_mp_zero_digs(c_->dp + c_->used, oldused - c_->used);
+
+ mp_clamp(c_);
+
+ if (c_ == &tmp) {
+ mp_clear(c);
+ *c = *c_;
+ }
- mp_clamp(c);
return MP_OKAY;
}
#endif
diff --git a/s_mp_mul_high.c b/s_mp_mul_high.c
index 1bde00aa9..5820730f4 100644
--- a/s_mp_mul_high.c
+++ b/s_mp_mul_high.c
@@ -8,21 +8,20 @@
*/
mp_err s_mp_mul_high(const mp_int *a, const mp_int *b, mp_int *c, int digs)
{
- mp_int t;
+ mp_int tmp, *c_;
int pa, pb, ix;
mp_err err;
- /* can we use the fast multiplier? */
- if (MP_HAS(S_MP_MUL_HIGH_COMBA)
- && ((a->used + b->used + 1) < MP_WARRAY)
- && (MP_MIN(a->used, b->used) < MP_MAX_COMBA)) {
- return s_mp_mul_high_comba(a, b, c, digs);
- }
-
- if ((err = mp_init_size(&t, a->used + b->used + 1)) != MP_OKAY) {
+ /* prepare the destination */
+ err = (MP_ALIAS(a, c) || MP_ALIAS(b, c))
+ ? mp_init_size((c_ = &tmp), a->used + b->used + 1)
+ : mp_grow((c_ = c), a->used + b->used + 1);
+ if (err != MP_OKAY) {
return err;
}
- t.used = a->used + b->used + 1;
+
+ s_mp_zero_digs(c_->dp, c_->used);
+ c_->used = a->used + b->used + 1;
pa = a->used;
pb = b->used;
@@ -32,21 +31,26 @@ mp_err s_mp_mul_high(const mp_int *a, const mp_int *b, mp_int *c, int digs)
for (iy = digs - ix; iy < pb; iy++) {
/* calculate the double precision result */
- mp_word r = (mp_word)t.dp[ix + iy] +
+ mp_word r = (mp_word)c_->dp[ix + iy] +
((mp_word)a->dp[ix] * (mp_word)b->dp[iy]) +
(mp_word)u;
/* get the lower part */
- t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
+ c_->dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
/* carry the carry */
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
}
- t.dp[ix + pb] = u;
+ c_->dp[ix + pb] = u;
}
- mp_clamp(&t);
- mp_exch(&t, c);
- mp_clear(&t);
+
+ mp_clamp(c_);
+
+ if (c_ == &tmp) {
+ mp_clear(c);
+ *c = *c_;
+ }
+
return MP_OKAY;
}
#endif
diff --git a/s_mp_mul_high_comba.c b/s_mp_mul_high_comba.c
index 317346dfa..951f992ff 100644
--- a/s_mp_mul_high_comba.c
+++ b/s_mp_mul_high_comba.c
@@ -16,18 +16,21 @@ mp_err s_mp_mul_high_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs
{
int oldused, pa, ix;
mp_err err;
- mp_digit W[MP_WARRAY];
- mp_word _W;
+ mp_word W;
+ mp_int tmp, *c_;
- /* grow the destination as required */
+ /* prepare the destination */
pa = a->used + b->used;
- if ((err = mp_grow(c, pa)) != MP_OKAY) {
+ err = (MP_ALIAS(a, c) || MP_ALIAS(b, c))
+ ? mp_init_size((c_ = &tmp), pa)
+ : mp_grow((c_ = c), pa);
+ if (err != MP_OKAY) {
return err;
}
/* number of output digits to produce */
pa = a->used + b->used;
- _W = 0;
+ W = 0;
for (ix = digs; ix < pa; ix++) {
int tx, ty, iy, iz;
@@ -42,29 +45,29 @@ mp_err s_mp_mul_high_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs
/* execute loop */
for (iz = 0; iz < iy; iz++) {
- _W += (mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz];
+ W += (mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz];
}
/* store term */
- W[ix] = (mp_digit)_W & MP_MASK;
+ c_->dp[ix] = (mp_digit)W & MP_MASK;
/* make next carry */
- _W = _W >> (mp_word)MP_DIGIT_BIT;
+ W = W >> (mp_word)MP_DIGIT_BIT;
}
/* setup dest */
- oldused = c->used;
- c->used = pa;
-
- for (ix = digs; ix < pa; ix++) {
- /* now extract the previous digit [below the carry] */
- c->dp[ix] = W[ix];
- }
+ oldused = c_->used;
+ c_->used = pa;
/* clear unused digits [that existed in the old copy of c] */
- s_mp_zero_digs(c->dp + c->used, oldused - c->used);
+ s_mp_zero_digs(c_->dp + c_->used, oldused - c_->used);
+ mp_clamp(c_);
+
+ if (c_ == &tmp) {
+ mp_clear(c);
+ *c = *c_;
+ }
- mp_clamp(c);
return MP_OKAY;
}
#endif
diff --git a/s_mp_sqr.c b/s_mp_sqr.c
deleted file mode 100644
index 4a2030638..000000000
--- a/s_mp_sqr.c
+++ /dev/null
@@ -1,65 +0,0 @@
-#include "tommath_private.h"
-#ifdef S_MP_SQR_C
-/* LibTomMath, multiple-precision integer library -- Tom St Denis */
-/* SPDX-License-Identifier: Unlicense */
-
-/* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
-mp_err s_mp_sqr(const mp_int *a, mp_int *b)
-{
- mp_int t;
- int ix, pa;
- mp_err err;
-
- pa = a->used;
- if ((err = mp_init_size(&t, (2 * pa) + 1)) != MP_OKAY) {
- return err;
- }
-
- /* default used is maximum possible size */
- t.used = (2 * pa) + 1;
-
- for (ix = 0; ix < pa; ix++) {
- mp_digit u;
- int iy;
-
- /* first calculate the digit at 2*ix */
- /* calculate double precision result */
- mp_word r = (mp_word)t.dp[2*ix] +
- ((mp_word)a->dp[ix] * (mp_word)a->dp[ix]);
-
- /* store lower part in result */
- t.dp[ix+ix] = (mp_digit)(r & (mp_word)MP_MASK);
-
- /* get the carry */
- u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
-
- for (iy = ix + 1; iy < pa; iy++) {
- /* first calculate the product */
- r = (mp_word)a->dp[ix] * (mp_word)a->dp[iy];
-
- /* now calculate the double precision result, note we use
- * addition instead of *2 since it's easier to optimize
- */
- r = (mp_word)t.dp[ix + iy] + r + r + (mp_word)u;
-
- /* store lower part */
- t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
-
- /* get carry */
- u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
- }
- /* propagate upwards */
- while (u != 0uL) {
- r = (mp_word)t.dp[ix + iy] + (mp_word)u;
- t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
- u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
- ++iy;
- }
- }
-
- mp_clamp(&t);
- mp_exch(&t, b);
- mp_clear(&t);
- return MP_OKAY;
-}
-#endif
diff --git a/s_mp_sqr_comba.c b/s_mp_sqr_comba.c
index cb88dcc9e..47f3dbddd 100644
--- a/s_mp_sqr_comba.c
+++ b/s_mp_sqr_comba.c
@@ -16,24 +16,24 @@ After that loop you do the squares and add them in.
mp_err s_mp_sqr_comba(const mp_int *a, mp_int *b)
{
int oldused, pa, ix;
- mp_digit W[MP_WARRAY];
- mp_word W1;
+ mp_digit c0, c1, c2;
mp_err err;
+ mp_int tmp, *b_;
- /* grow the destination as required */
- pa = a->used + a->used;
- if ((err = mp_grow(b, pa)) != MP_OKAY) {
+ pa = 2 * a->used;
+
+ /* prepare the destination */
+ err = MP_ALIAS(a, b)
+ ? mp_init_size((b_ = &tmp), pa)
+ : mp_grow((b_ = b), pa);
+ if (err != MP_OKAY) {
return err;
}
/* number of output digits to produce */
- W1 = 0;
+ c0 = c1 = c2 = 0;
for (ix = 0; ix < pa; ix++) {
int tx, ty, iy, iz;
- mp_word _W;
-
- /* clear counter */
- _W = 0;
/* get offsets into the two bignums */
ty = MP_MIN(a->used-1, ix);
@@ -52,36 +52,49 @@ mp_err s_mp_sqr_comba(const mp_int *a, mp_int *b)
/* execute loop */
for (iz = 0; iz < iy; iz++) {
- _W += (mp_word)a->dp[tx + iz] * (mp_word)a->dp[ty - iz];
+ mp_word t = (mp_word)a->dp[tx + iz] * (mp_word)a->dp[ty - iz];
+ int j;
+ for (j = 0; j < 2; ++j) {
+ mp_word w = (mp_word)c0 + t;
+ c0 = (mp_digit)(w & MP_MASK);
+ w = (mp_word)c1 + (w >> MP_DIGIT_BIT);
+ c1 = (mp_digit)(w & MP_MASK);
+ c2 += (mp_digit)(w >> MP_DIGIT_BIT);
+ }
}
- /* double the inner product and add carry */
- _W = _W + _W + W1;
-
/* even columns have the square term in them */
if (((unsigned)ix & 1u) == 0u) {
- _W += (mp_word)a->dp[ix>>1] * (mp_word)a->dp[ix>>1];
+ mp_word w = (mp_word)c0 + ((mp_word)a->dp[ix / 2] * (mp_word)a->dp[ix / 2]);
+ c0 = (mp_digit)(w & MP_MASK);
+ w = (mp_word)c1 + (w >> MP_DIGIT_BIT);
+ c1 = (mp_digit)(w & MP_MASK);
+ c2 += (mp_digit)(w >> MP_DIGIT_BIT);
}
- /* store it */
- W[ix] = (mp_digit)_W & MP_MASK;
+ /* store term */
+ b_->dp[ix] = c0;
/* make next carry */
- W1 = _W >> (mp_word)MP_DIGIT_BIT;
+ c0 = c1;
+ c1 = c2;
+ c2 = 0;
}
/* setup dest */
- oldused = b->used;
- b->used = a->used+a->used;
-
- for (ix = 0; ix < pa; ix++) {
- b->dp[ix] = W[ix] & MP_MASK;
- }
+ oldused = b_->used;
+ b_->used = 2 * a->used;
/* clear unused digits [that existed in the old copy of c] */
- s_mp_zero_digs(b->dp + b->used, oldused - b->used);
+ s_mp_zero_digs(b_->dp + b_->used, oldused - b_->used);
+
+ mp_clamp(b_);
+
+ if (b_ == &tmp) {
+ mp_clear(b);
+ *b = *b_;
+ }
- mp_clamp(b);
return MP_OKAY;
}
#endif
diff --git a/tommath_class.h b/tommath_class.h
index f5f99074c..0e61f643d 100644
--- a/tommath_class.h
+++ b/tommath_class.h
@@ -150,7 +150,6 @@
# define S_MP_LOG_D_C
# define S_MP_LOG_POW2_C
# define S_MP_MONTGOMERY_REDUCE_COMBA_C
-# define S_MP_MUL_C
# define S_MP_MUL_BALANCE_C
# define S_MP_MUL_COMBA_C
# define S_MP_MUL_HIGH_C
@@ -162,7 +161,6 @@
# define S_MP_RADIX_MAP_C
# define S_MP_RAND_JENKINS_C
# define S_MP_RAND_PLATFORM_C
-# define S_MP_SQR_C
# define S_MP_SQR_COMBA_C
# define S_MP_SQR_KARATSUBA_C
# define S_MP_SQR_TOOM_C
@@ -542,11 +540,9 @@
#if defined(MP_MUL_C)
# define S_MP_MUL_BALANCE_C
-# define S_MP_MUL_C
# define S_MP_MUL_COMBA_C
# define S_MP_MUL_KARATSUBA_C
# define S_MP_MUL_TOOM_C
-# define S_MP_SQR_C
# define S_MP_SQR_COMBA_C
# define S_MP_SQR_KARATSUBA_C
# define S_MP_SQR_TOOM_C
@@ -737,7 +733,7 @@
# define MP_RSHD_C
# define MP_SET_C
# define MP_SUB_C
-# define S_MP_MUL_C
+# define S_MP_MUL_COMBA_C
# define S_MP_MUL_HIGH_C
# define S_MP_MUL_HIGH_COMBA_C
# define S_MP_SUB_C
@@ -1124,14 +1120,6 @@
# define S_MP_ZERO_DIGS_C
#endif
-#if defined(S_MP_MUL_C)
-# define MP_CLAMP_C
-# define MP_CLEAR_C
-# define MP_EXCH_C
-# define MP_INIT_SIZE_C
-# define S_MP_MUL_COMBA_C
-#endif
-
#if defined(S_MP_MUL_BALANCE_C)
# define MP_ADD_C
# define MP_CLAMP_C
@@ -1147,21 +1135,25 @@
#if defined(S_MP_MUL_COMBA_C)
# define MP_CLAMP_C
+# define MP_CLEAR_C
# define MP_GROW_C
+# define MP_INIT_SIZE_C
# define S_MP_ZERO_DIGS_C
#endif
#if defined(S_MP_MUL_HIGH_C)
# define MP_CLAMP_C
# define MP_CLEAR_C
-# define MP_EXCH_C
+# define MP_GROW_C
# define MP_INIT_SIZE_C
-# define S_MP_MUL_HIGH_COMBA_C
+# define S_MP_ZERO_DIGS_C
#endif
#if defined(S_MP_MUL_HIGH_COMBA_C)
# define MP_CLAMP_C
+# define MP_CLEAR_C
# define MP_GROW_C
+# define MP_INIT_SIZE_C
# define S_MP_ZERO_DIGS_C
#endif
@@ -1210,16 +1202,11 @@
#if defined(S_MP_RAND_PLATFORM_C)
#endif
-#if defined(S_MP_SQR_C)
-# define MP_CLAMP_C
-# define MP_CLEAR_C
-# define MP_EXCH_C
-# define MP_INIT_SIZE_C
-#endif
-
#if defined(S_MP_SQR_COMBA_C)
# define MP_CLAMP_C
+# define MP_CLEAR_C
# define MP_GROW_C
+# define MP_INIT_SIZE_C
# define S_MP_ZERO_DIGS_C
#endif
diff --git a/tommath_private.h b/tommath_private.h
index f6295020f..350ca1ffc 100644
--- a/tommath_private.h
+++ b/tommath_private.h
@@ -116,6 +116,8 @@ extern void MP_FREE(void *mem, size_t size);
#define MP_MIN(x, y) (((x) < (y)) ? (x) : (y))
#define MP_MAX(x, y) (((x) > (y)) ? (x) : (y))
+#define MP_ALIAS(a, b) ((a) == (b) || (a)->dp == (b)->dp)
+
#define MP_TOUPPER(c) ((((c) >= 'a') && ((c) <= 'z')) ? (((c) + 'A') - 'a') : (c))
#define MP_EXCH(t, a, b) do { t _c = a; a = b; b = _c; } while (0)
@@ -127,6 +129,7 @@ extern void MP_FREE(void *mem, size_t size);
#define MP_SIZEOF_BITS(type) ((size_t)CHAR_BIT * sizeof(type))
+/* TODO: Remove MP_MAX_COMBA and MP_WARRAY */
#define MP_MAX_COMBA (int)(1uL << (MP_SIZEOF_BITS(mp_word) - (2u * (size_t)MP_DIGIT_BIT)))
#define MP_WARRAY (int)(1uL << ((MP_SIZEOF_BITS(mp_word) - (2u * (size_t)MP_DIGIT_BIT)) + 1u))
@@ -173,7 +176,6 @@ MP_PRIVATE mp_err s_mp_invmod(const mp_int *a, const mp_int *b, mp_int *c) MP_WU
MP_PRIVATE mp_err s_mp_invmod_odd(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR;
MP_PRIVATE mp_err s_mp_log(const mp_int *a, uint32_t base, uint32_t *c) MP_WUR;
MP_PRIVATE mp_err s_mp_montgomery_reduce_comba(mp_int *x, const mp_int *n, mp_digit rho) MP_WUR;
-MP_PRIVATE mp_err s_mp_mul(const mp_int *a, const mp_int *b, mp_int *c, int digs) MP_WUR;
MP_PRIVATE mp_err s_mp_mul_balance(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR;
MP_PRIVATE mp_err s_mp_mul_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs) MP_WUR;
MP_PRIVATE mp_err s_mp_mul_high(const mp_int *a, const mp_int *b, mp_int *c, int digs) MP_WUR;
@@ -182,7 +184,6 @@ MP_PRIVATE mp_err s_mp_mul_karatsuba(const mp_int *a, const mp_int *b, mp_int *c
MP_PRIVATE mp_err s_mp_mul_toom(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR;
MP_PRIVATE mp_err s_mp_prime_is_divisible(const mp_int *a, bool *result) MP_WUR;
MP_PRIVATE mp_err s_mp_rand_platform(void *p, size_t n) MP_WUR;
-MP_PRIVATE mp_err s_mp_sqr(const mp_int *a, mp_int *b) MP_WUR;
MP_PRIVATE mp_err s_mp_sqr_comba(const mp_int *a, mp_int *b) MP_WUR;
MP_PRIVATE mp_err s_mp_sqr_karatsuba(const mp_int *a, mp_int *b) MP_WUR;
MP_PRIVATE mp_err s_mp_sqr_toom(const mp_int *a, mp_int *b) MP_WUR;