diff --git a/demo/test.c b/demo/test.c index 3e432cf0e..18e67eeeb 100644 --- a/demo/test.c +++ b/demo/test.c @@ -1895,7 +1895,7 @@ static int test_s_mp_mul_balance(void) return EXIT_FAILURE; } -#define s_mp_mul_full(a, b, c) s_mp_mul(a, b, c, (a)->used + (b)->used + 1) +#define s_mp_mul_full(a, b, c) s_mp_mul_comba(a, b, c, (a)->used + (b)->used + 1) static int test_s_mp_mul_karatsuba(void) { mp_int a, b, c, d; @@ -1929,7 +1929,7 @@ static int test_s_mp_sqr_karatsuba(void) for (size = MP_SQR_KARATSUBA_CUTOFF; size < (MP_SQR_KARATSUBA_CUTOFF + 20); size++) { DO(mp_rand(&a, size)); DO(s_mp_sqr_karatsuba(&a, &b)); - DO(s_mp_sqr(&a, &c)); + DO(s_mp_sqr_comba(&a, &c)); if (mp_cmp(&b, &c) != MP_EQ) { fprintf(stderr, "Karatsuba squaring failed at size %d\n", size); goto LBL_ERR; @@ -2002,7 +2002,7 @@ static int test_s_mp_sqr_toom(void) for (size = MP_SQR_TOOM_CUTOFF; size < (MP_SQR_TOOM_CUTOFF + 20); size++) { DO(mp_rand(&a, size)); DO(s_mp_sqr_toom(&a, &b)); - DO(s_mp_sqr(&a, &c)); + DO(s_mp_sqr_comba(&a, &c)); if (mp_cmp(&b, &c) != MP_EQ) { fprintf(stderr, "Toom-Cook 3-way squaring failed at size %d\n", size); goto LBL_ERR; diff --git a/etc/tune.c b/etc/tune.c index 0b7373448..334aec080 100644 --- a/etc/tune.c +++ b/etc/tune.c @@ -58,7 +58,11 @@ static int s_number_of_test_loops; static int s_stabilization_extra; static int s_offset = 1; -#define s_mp_mul_full(a, b, c) s_mp_mul(a, b, c, (a)->used + (b)->used + 1) +static mp_err s_mul_full(const mp_int *a, const mp_int *b, mp_int *c) +{ + return s_mp_mul_comba(a, b, c, a->used + b->used + 1); +} + static uint64_t s_time_mul(int size) { int x; @@ -87,7 +91,7 @@ static uint64_t s_time_mul(int size) goto LBL_ERR; } if (s_check_result == 1) { - if ((e = s_mp_mul_full(&a,&b,&d)) != MP_OKAY) { + if ((e = s_mul_full(&a,&b,&d)) != MP_OKAY) { t1 = UINT64_MAX; goto LBL_ERR; } @@ -129,7 +133,7 @@ static uint64_t s_time_sqr(int size) goto LBL_ERR; } if (s_check_result == 1) { - if ((e = s_mp_sqr(&a,&c)) != MP_OKAY) { + if ((e = 
s_mp_sqr_comba(&a,&c)) != MP_OKAY) { t1 = UINT64_MAX; goto LBL_ERR; } diff --git a/libtommath_VS2008.vcproj b/libtommath_VS2008.vcproj index 7e16199e8..929191f07 100644 --- a/libtommath_VS2008.vcproj +++ b/libtommath_VS2008.vcproj @@ -868,10 +868,6 @@ RelativePath="s_mp_montgomery_reduce_comba.c" > - - @@ -916,10 +912,6 @@ RelativePath="s_mp_rand_platform.c" > - - diff --git a/makefile b/makefile index 88eff7921..8cdade17e 100644 --- a/makefile +++ b/makefile @@ -46,10 +46,10 @@ mp_sqrmod.o mp_sqrt.o mp_sqrtmod_prime.o mp_sub.o mp_sub_d.o mp_submod.o mp_to_r mp_to_ubin.o mp_ubin_size.o mp_unpack.o mp_xor.o mp_zero.o s_mp_add.o s_mp_copy_digs.o s_mp_div_3.o \ s_mp_div_recursive.o s_mp_div_school.o s_mp_div_small.o s_mp_exptmod.o s_mp_exptmod_fast.o s_mp_get_bit.o \ s_mp_invmod.o s_mp_invmod_odd.o s_mp_log.o s_mp_log_d.o s_mp_log_pow2.o s_mp_montgomery_reduce_comba.o \ -s_mp_mul.o s_mp_mul_balance.o s_mp_mul_comba.o s_mp_mul_high.o s_mp_mul_high_comba.o s_mp_mul_karatsuba.o \ +s_mp_mul_balance.o s_mp_mul_comba.o s_mp_mul_high.o s_mp_mul_high_comba.o s_mp_mul_karatsuba.o \ s_mp_mul_toom.o s_mp_prime_is_divisible.o s_mp_prime_tab.o s_mp_radix_map.o s_mp_rand_jenkins.o \ -s_mp_rand_platform.o s_mp_sqr.o s_mp_sqr_comba.o s_mp_sqr_karatsuba.o s_mp_sqr_toom.o s_mp_sub.o \ -s_mp_zero_buf.o s_mp_zero_digs.o +s_mp_rand_platform.o s_mp_sqr_comba.o s_mp_sqr_karatsuba.o s_mp_sqr_toom.o s_mp_sub.o s_mp_zero_buf.o \ +s_mp_zero_digs.o #END_INS diff --git a/makefile.mingw b/makefile.mingw index 3a3bc631f..52f470a18 100644 --- a/makefile.mingw +++ b/makefile.mingw @@ -48,10 +48,10 @@ mp_sqrmod.o mp_sqrt.o mp_sqrtmod_prime.o mp_sub.o mp_sub_d.o mp_submod.o mp_to_r mp_to_ubin.o mp_ubin_size.o mp_unpack.o mp_xor.o mp_zero.o s_mp_add.o s_mp_copy_digs.o s_mp_div_3.o \ s_mp_div_recursive.o s_mp_div_school.o s_mp_div_small.o s_mp_exptmod.o s_mp_exptmod_fast.o s_mp_get_bit.o \ s_mp_invmod.o s_mp_invmod_odd.o s_mp_log.o s_mp_log_d.o s_mp_log_pow2.o s_mp_montgomery_reduce_comba.o \ -s_mp_mul.o 
s_mp_mul_balance.o s_mp_mul_comba.o s_mp_mul_high.o s_mp_mul_high_comba.o s_mp_mul_karatsuba.o \ +s_mp_mul_balance.o s_mp_mul_comba.o s_mp_mul_high.o s_mp_mul_high_comba.o s_mp_mul_karatsuba.o \ s_mp_mul_toom.o s_mp_prime_is_divisible.o s_mp_prime_tab.o s_mp_radix_map.o s_mp_rand_jenkins.o \ -s_mp_rand_platform.o s_mp_sqr.o s_mp_sqr_comba.o s_mp_sqr_karatsuba.o s_mp_sqr_toom.o s_mp_sub.o \ -s_mp_zero_buf.o s_mp_zero_digs.o +s_mp_rand_platform.o s_mp_sqr_comba.o s_mp_sqr_karatsuba.o s_mp_sqr_toom.o s_mp_sub.o s_mp_zero_buf.o \ +s_mp_zero_digs.o HEADERS_PUB=tommath.h HEADERS=tommath_private.h tommath_class.h tommath_superclass.h tommath_cutoffs.h $(HEADERS_PUB) diff --git a/makefile.msvc b/makefile.msvc index a22267c4a..2255faf4b 100644 --- a/makefile.msvc +++ b/makefile.msvc @@ -41,10 +41,10 @@ mp_sqrmod.obj mp_sqrt.obj mp_sqrtmod_prime.obj mp_sub.obj mp_sub_d.obj mp_submod mp_to_ubin.obj mp_ubin_size.obj mp_unpack.obj mp_xor.obj mp_zero.obj s_mp_add.obj s_mp_copy_digs.obj s_mp_div_3.obj \ s_mp_div_recursive.obj s_mp_div_school.obj s_mp_div_small.obj s_mp_exptmod.obj s_mp_exptmod_fast.obj s_mp_get_bit.obj \ s_mp_invmod.obj s_mp_invmod_odd.obj s_mp_log.obj s_mp_log_d.obj s_mp_log_pow2.obj s_mp_montgomery_reduce_comba.obj \ -s_mp_mul.obj s_mp_mul_balance.obj s_mp_mul_comba.obj s_mp_mul_high.obj s_mp_mul_high_comba.obj s_mp_mul_karatsuba.obj \ +s_mp_mul_balance.obj s_mp_mul_comba.obj s_mp_mul_high.obj s_mp_mul_high_comba.obj s_mp_mul_karatsuba.obj \ s_mp_mul_toom.obj s_mp_prime_is_divisible.obj s_mp_prime_tab.obj s_mp_radix_map.obj s_mp_rand_jenkins.obj \ -s_mp_rand_platform.obj s_mp_sqr.obj s_mp_sqr_comba.obj s_mp_sqr_karatsuba.obj s_mp_sqr_toom.obj s_mp_sub.obj \ -s_mp_zero_buf.obj s_mp_zero_digs.obj +s_mp_rand_platform.obj s_mp_sqr_comba.obj s_mp_sqr_karatsuba.obj s_mp_sqr_toom.obj s_mp_sub.obj s_mp_zero_buf.obj \ +s_mp_zero_digs.obj HEADERS_PUB=tommath.h HEADERS=tommath_private.h tommath_class.h tommath_superclass.h tommath_cutoffs.h $(HEADERS_PUB) diff --git 
a/makefile.shared b/makefile.shared index ad58e61fe..9c4fd4dec 100644 --- a/makefile.shared +++ b/makefile.shared @@ -43,10 +43,10 @@ mp_sqrmod.o mp_sqrt.o mp_sqrtmod_prime.o mp_sub.o mp_sub_d.o mp_submod.o mp_to_r mp_to_ubin.o mp_ubin_size.o mp_unpack.o mp_xor.o mp_zero.o s_mp_add.o s_mp_copy_digs.o s_mp_div_3.o \ s_mp_div_recursive.o s_mp_div_school.o s_mp_div_small.o s_mp_exptmod.o s_mp_exptmod_fast.o s_mp_get_bit.o \ s_mp_invmod.o s_mp_invmod_odd.o s_mp_log.o s_mp_log_d.o s_mp_log_pow2.o s_mp_montgomery_reduce_comba.o \ -s_mp_mul.o s_mp_mul_balance.o s_mp_mul_comba.o s_mp_mul_high.o s_mp_mul_high_comba.o s_mp_mul_karatsuba.o \ +s_mp_mul_balance.o s_mp_mul_comba.o s_mp_mul_high.o s_mp_mul_high_comba.o s_mp_mul_karatsuba.o \ s_mp_mul_toom.o s_mp_prime_is_divisible.o s_mp_prime_tab.o s_mp_radix_map.o s_mp_rand_jenkins.o \ -s_mp_rand_platform.o s_mp_sqr.o s_mp_sqr_comba.o s_mp_sqr_karatsuba.o s_mp_sqr_toom.o s_mp_sub.o \ -s_mp_zero_buf.o s_mp_zero_digs.o +s_mp_rand_platform.o s_mp_sqr_comba.o s_mp_sqr_karatsuba.o s_mp_sqr_toom.o s_mp_sub.o s_mp_zero_buf.o \ +s_mp_zero_digs.o #END_INS diff --git a/makefile.unix b/makefile.unix index 1e0da7393..6365b7acf 100644 --- a/makefile.unix +++ b/makefile.unix @@ -49,10 +49,10 @@ mp_sqrmod.o mp_sqrt.o mp_sqrtmod_prime.o mp_sub.o mp_sub_d.o mp_submod.o mp_to_r mp_to_ubin.o mp_ubin_size.o mp_unpack.o mp_xor.o mp_zero.o s_mp_add.o s_mp_copy_digs.o s_mp_div_3.o \ s_mp_div_recursive.o s_mp_div_school.o s_mp_div_small.o s_mp_exptmod.o s_mp_exptmod_fast.o s_mp_get_bit.o \ s_mp_invmod.o s_mp_invmod_odd.o s_mp_log.o s_mp_log_d.o s_mp_log_pow2.o s_mp_montgomery_reduce_comba.o \ -s_mp_mul.o s_mp_mul_balance.o s_mp_mul_comba.o s_mp_mul_high.o s_mp_mul_high_comba.o s_mp_mul_karatsuba.o \ +s_mp_mul_balance.o s_mp_mul_comba.o s_mp_mul_high.o s_mp_mul_high_comba.o s_mp_mul_karatsuba.o \ s_mp_mul_toom.o s_mp_prime_is_divisible.o s_mp_prime_tab.o s_mp_radix_map.o s_mp_rand_jenkins.o \ -s_mp_rand_platform.o s_mp_sqr.o s_mp_sqr_comba.o 
s_mp_sqr_karatsuba.o s_mp_sqr_toom.o s_mp_sub.o \ -s_mp_zero_buf.o s_mp_zero_digs.o +s_mp_rand_platform.o s_mp_sqr_comba.o s_mp_sqr_karatsuba.o s_mp_sqr_toom.o s_mp_sub.o s_mp_zero_buf.o \ +s_mp_zero_digs.o HEADERS_PUB=tommath.h HEADERS=tommath_private.h tommath_class.h tommath_superclass.h tommath_cutoffs.h $(HEADERS_PUB) diff --git a/mp_mul.c b/mp_mul.c index b2dbf7d72..f77ff457d 100644 --- a/mp_mul.c +++ b/mp_mul.c @@ -21,13 +21,8 @@ mp_err mp_mul(const mp_int *a, const mp_int *b, mp_int *c) (a->used >= MP_SQR_KARATSUBA_CUTOFF)) { err = s_mp_sqr_karatsuba(a, c); } else if ((a == b) && - MP_HAS(S_MP_SQR_COMBA) && /* can we use the fast comba multiplier? */ - (((a->used * 2) + 1) < MP_WARRAY) && - (a->used < (MP_MAX_COMBA / 2))) { + MP_HAS(S_MP_SQR_COMBA)) { err = s_mp_sqr_comba(a, c); - } else if ((a == b) && - MP_HAS(S_MP_SQR)) { - err = s_mp_sqr(a, c); } else if (MP_HAS(S_MP_MUL_BALANCE) && /* Check sizes. The smaller one needs to be larger than the Karatsuba cut-off. * The bigger one needs to be at least about one MP_MUL_KARATSUBA_CUTOFF bigger @@ -47,18 +42,8 @@ mp_err mp_mul(const mp_int *a, const mp_int *b, mp_int *c) } else if (MP_HAS(S_MP_MUL_KARATSUBA) && (min >= MP_MUL_KARATSUBA_CUTOFF)) { err = s_mp_mul_karatsuba(a, b, c); - } else if (MP_HAS(S_MP_MUL_COMBA) && - /* can we use the fast multiplier? 
- * - * The fast multiplier can be used if the output will - * have less than MP_WARRAY digits and the number of - * digits won't affect carry propagation - */ - (digs < MP_WARRAY) && - (min <= MP_MAX_COMBA)) { + } else if (MP_HAS(S_MP_MUL_COMBA)) { err = s_mp_mul_comba(a, b, c, digs); - } else if (MP_HAS(S_MP_MUL)) { - err = s_mp_mul(a, b, c, digs); } else { err = MP_VAL; } diff --git a/mp_reduce.c b/mp_reduce.c index b6fae55cc..f08bcf1c9 100644 --- a/mp_reduce.c +++ b/mp_reduce.c @@ -23,17 +23,12 @@ mp_err mp_reduce(mp_int *x, const mp_int *m, const mp_int *mu) /* according to HAC this optimization is ok */ if ((mp_digit)um > ((mp_digit)1 << (MP_DIGIT_BIT - 1))) { - if ((err = mp_mul(&q, mu, &q)) != MP_OKAY) { - goto LBL_ERR; - } - } else if (MP_HAS(S_MP_MUL_HIGH)) { - if ((err = s_mp_mul_high(&q, mu, &q, um)) != MP_OKAY) { - goto LBL_ERR; - } + if ((err = mp_mul(&q, mu, &q)) != MP_OKAY) goto LBL_ERR; - } else if (MP_HAS(S_MP_MUL_HIGH_COMBA)) { - if ((err = s_mp_mul_high_comba(&q, mu, &q, um)) != MP_OKAY) { - goto LBL_ERR; - } + } else if (MP_HAS(S_MP_MUL_HIGH_COMBA) && (MP_MIN(q.used, mu->used) < MP_MAX_COMBA)) { + /* s_mp_mul_high_comba sums each column in a single mp_word, so it is only safe below MP_MAX_COMBA digits */ + if ((err = s_mp_mul_high_comba(&q, mu, &q, um)) != MP_OKAY) goto LBL_ERR; + } else if (MP_HAS(S_MP_MUL_HIGH)) { + if ((err = s_mp_mul_high(&q, mu, &q, um)) != MP_OKAY) goto LBL_ERR; } else { err = MP_VAL; goto LBL_ERR; @@ -43,41 +38,28 @@ mp_err mp_reduce(mp_int *x, const mp_int *m, const mp_int *mu) mp_rshd(&q, um + 1); /* x = x mod b**(k+1), quick (no division) */ - if ((err = mp_mod_2d(x, MP_DIGIT_BIT * (um + 1), x)) != MP_OKAY) { - goto LBL_ERR; - } + if ((err = mp_mod_2d(x, MP_DIGIT_BIT * (um + 1), x)) != MP_OKAY) goto LBL_ERR; /* q = q * m mod b**(k+1), quick (no division) */ - if ((err = s_mp_mul(&q, m, &q, um + 1)) != MP_OKAY) { - goto LBL_ERR; - } + if ((err = s_mp_mul_comba(&q, m, &q, um + 1)) != MP_OKAY) goto LBL_ERR; /* x = x - q */ - if ((err = mp_sub(x, &q, x)) != MP_OKAY) { - goto LBL_ERR; - } + if ((err = mp_sub(x, &q, x)) != MP_OKAY) goto LBL_ERR; /* If x < 0, add b**(k+1) to it */ if (mp_cmp_d(x, 0uL) == MP_LT) { 
mp_set(&q, 1uL); - if ((err = mp_lshd(&q, um + 1)) != MP_OKAY) { - goto LBL_ERR; - } - if ((err = mp_add(x, &q, x)) != MP_OKAY) { - goto LBL_ERR; - } + if ((err = mp_lshd(&q, um + 1)) != MP_OKAY) goto LBL_ERR; + if ((err = mp_add(x, &q, x)) != MP_OKAY) goto LBL_ERR; } /* Back off if it's too big */ while (mp_cmp(x, m) != MP_LT) { - if ((err = s_mp_sub(x, m, x)) != MP_OKAY) { - goto LBL_ERR; - } + if ((err = s_mp_sub(x, m, x)) != MP_OKAY) goto LBL_ERR; } LBL_ERR: mp_clear(&q); - return err; } #endif diff --git a/s_mp_mul.c b/s_mp_mul.c deleted file mode 100644 index fb99d8054..000000000 --- a/s_mp_mul.c +++ /dev/null @@ -1,61 +0,0 @@ -#include "tommath_private.h" -#ifdef S_MP_MUL_C -/* LibTomMath, multiple-precision integer library -- Tom St Denis */ -/* SPDX-License-Identifier: Unlicense */ - -/* multiplies |a| * |b| and only computes upto digs digits of result - * HAC pp. 595, Algorithm 14.12 Modified so you can control how - * many digits of output are created. - */ -mp_err s_mp_mul(const mp_int *a, const mp_int *b, mp_int *c, int digs) -{ - mp_int t; - mp_err err; - int pa, ix; - - /* can we use the fast multiplier? 
*/ - if ((digs < MP_WARRAY) && - (MP_MIN(a->used, b->used) < MP_MAX_COMBA)) { - return s_mp_mul_comba(a, b, c, digs); - } - - if ((err = mp_init_size(&t, digs)) != MP_OKAY) { - return err; - } - t.used = digs; - - /* compute the digits of the product directly */ - pa = a->used; - for (ix = 0; ix < pa; ix++) { - int iy, pb; - mp_digit u = 0; - - /* limit ourselves to making digs digits of output */ - pb = MP_MIN(b->used, digs - ix); - - /* compute the columns of the output and propagate the carry */ - for (iy = 0; iy < pb; iy++) { - /* compute the column as a mp_word */ - mp_word r = (mp_word)t.dp[ix + iy] + - ((mp_word)a->dp[ix] * (mp_word)b->dp[iy]) + - (mp_word)u; - - /* the new column is the lower part of the result */ - t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK); - - /* get the carry word from the result */ - u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT); - } - /* set carry if it is placed below digs */ - if ((ix + iy) < digs) { - t.dp[ix + pb] = u; - } - } - - mp_clamp(&t); - mp_exch(&t, c); - - mp_clear(&t); - return MP_OKAY; -} -#endif diff --git a/s_mp_mul_comba.c b/s_mp_mul_comba.c index 07dd7913d..cca228c19 100644 --- a/s_mp_mul_comba.c +++ b/s_mp_mul_comba.c @@ -23,11 +23,14 @@ mp_err s_mp_mul_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs) { int oldused, pa, ix; mp_err err; - mp_digit W[MP_WARRAY]; - mp_word _W; + mp_digit c0, c1, c2; + mp_int tmp, *c_; - /* grow the destination as required */ - if ((err = mp_grow(c, digs)) != MP_OKAY) { + /* prepare the destination */ + err = (MP_ALIAS(a, c) || MP_ALIAS(b, c)) + ? 
mp_init_size((c_ = &tmp), digs) + : mp_grow((c_ = c), digs); + if (err != MP_OKAY) { return err; } @@ -35,7 +38,7 @@ mp_err s_mp_mul_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs) pa = MP_MIN(digs, a->used + b->used); /* clear the carry */ - _W = 0; + c0 = c1 = c2 = 0; for (ix = 0; ix < pa; ix++) { int tx, ty, iy, iz; @@ -48,31 +51,75 @@ mp_err s_mp_mul_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs) */ iy = MP_MIN(a->used-tx, ty+1); - /* execute loop */ - for (iz = 0; iz < iy; ++iz) { - _W += (mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]; + /* execute loop + * + * Give the autovectorizer a hint! This might not be necessary. + * I don't think the generated code will be particularly good here; + * if we use full-width digits the masks will go away. + */ + for (iz = 0; iz + 3 < iy;) { + mp_word w = (mp_word)c0 + ((mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]); + c0 = (mp_digit)(w & MP_MASK); + w = (mp_word)c1 + (w >> MP_DIGIT_BIT); + c1 = (mp_digit)(w & MP_MASK); + c2 += (mp_digit)(w >> MP_DIGIT_BIT); + ++iz; + + w = (mp_word)c0 + ((mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]); + c0 = (mp_digit)(w & MP_MASK); + w = (mp_word)c1 + (w >> MP_DIGIT_BIT); + c1 = (mp_digit)(w & MP_MASK); + c2 += (mp_digit)(w >> MP_DIGIT_BIT); + ++iz; + + w = (mp_word)c0 + ((mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]); + c0 = (mp_digit)(w & MP_MASK); + w = (mp_word)c1 + (w >> MP_DIGIT_BIT); + c1 = (mp_digit)(w & MP_MASK); + c2 += (mp_digit)(w >> MP_DIGIT_BIT); + ++iz; + + w = (mp_word)c0 + ((mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]); + c0 = (mp_digit)(w & MP_MASK); + w = (mp_word)c1 + (w >> MP_DIGIT_BIT); + c1 = (mp_digit)(w & MP_MASK); + c2 += (mp_digit)(w >> MP_DIGIT_BIT); + ++iz; + } + + /* execute rest of loop */ + for (; iz < iy;) { + mp_word w = (mp_word)c0 + ((mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]); + c0 = (mp_digit)(w & MP_MASK); + w = (mp_word)c1 + (w >> MP_DIGIT_BIT); + c1 = (mp_digit)(w & MP_MASK); + c2 += 
(mp_digit)(w >> MP_DIGIT_BIT); + ++iz; } /* store term */ - W[ix] = (mp_digit)_W & MP_MASK; + c_->dp[ix] = c0; /* make next carry */ - _W = _W >> (mp_word)MP_DIGIT_BIT; + c0 = c1; + c1 = c2; + c2 = 0; } /* setup dest */ - oldused = c->used; - c->used = pa; - - for (ix = 0; ix < pa; ix++) { - /* now extract the previous digit [below the carry] */ - c->dp[ix] = W[ix]; - } + oldused = c_->used; + c_->used = pa; /* clear unused digits [that existed in the old copy of c] */ - s_mp_zero_digs(c->dp + c->used, oldused - c->used); + s_mp_zero_digs(c_->dp + c_->used, oldused - c_->used); + + mp_clamp(c_); + + if (c_ == &tmp) { + mp_clear(c); + *c = *c_; + } - mp_clamp(c); return MP_OKAY; } #endif diff --git a/s_mp_mul_high.c b/s_mp_mul_high.c index 1bde00aa9..5820730f4 100644 --- a/s_mp_mul_high.c +++ b/s_mp_mul_high.c @@ -8,21 +8,20 @@ */ mp_err s_mp_mul_high(const mp_int *a, const mp_int *b, mp_int *c, int digs) { - mp_int t; + mp_int tmp, *c_; int pa, pb, ix; mp_err err; - /* can we use the fast multiplier? */ - if (MP_HAS(S_MP_MUL_HIGH_COMBA) - && ((a->used + b->used + 1) < MP_WARRAY) - && (MP_MIN(a->used, b->used) < MP_MAX_COMBA)) { - return s_mp_mul_high_comba(a, b, c, digs); - } - - if ((err = mp_init_size(&t, a->used + b->used + 1)) != MP_OKAY) { + /* prepare the destination */ + err = (MP_ALIAS(a, c) || MP_ALIAS(b, c)) + ? 
mp_init_size((c_ = &tmp), a->used + b->used + 1) + : mp_grow((c_ = c), a->used + b->used + 1); + if (err != MP_OKAY) { return err; } - t.used = a->used + b->used + 1; + + s_mp_zero_digs(c_->dp, c_->used); + c_->used = a->used + b->used + 1; pa = a->used; pb = b->used; @@ -32,21 +31,26 @@ mp_err s_mp_mul_high(const mp_int *a, const mp_int *b, mp_int *c, int digs) for (iy = digs - ix; iy < pb; iy++) { /* calculate the double precision result */ - mp_word r = (mp_word)t.dp[ix + iy] + + mp_word r = (mp_word)c_->dp[ix + iy] + ((mp_word)a->dp[ix] * (mp_word)b->dp[iy]) + (mp_word)u; /* get the lower part */ - t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK); + c_->dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK); /* carry the carry */ u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT); } - t.dp[ix + pb] = u; + c_->dp[ix + pb] = u; } - mp_clamp(&t); - mp_exch(&t, c); - mp_clear(&t); + + mp_clamp(c_); + + if (c_ == &tmp) { + mp_clear(c); + *c = *c_; + } + return MP_OKAY; } #endif diff --git a/s_mp_mul_high_comba.c b/s_mp_mul_high_comba.c index 317346dfa..951f992ff 100644 --- a/s_mp_mul_high_comba.c +++ b/s_mp_mul_high_comba.c @@ -16,18 +16,21 @@ mp_err s_mp_mul_high_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs { int oldused, pa, ix; mp_err err; - mp_digit W[MP_WARRAY]; - mp_word _W; + mp_word W; + mp_int tmp, *c_; - /* grow the destination as required */ + /* prepare the destination */ pa = a->used + b->used; - if ((err = mp_grow(c, pa)) != MP_OKAY) { + err = (MP_ALIAS(a, c) || MP_ALIAS(b, c)) + ? 
mp_init_size((c_ = &tmp), pa) + : mp_grow((c_ = c), pa); + if (err != MP_OKAY) { return err; } /* number of output digits to produce */ pa = a->used + b->used; - _W = 0; + W = 0; for (ix = digs; ix < pa; ix++) { int tx, ty, iy, iz; @@ -42,29 +45,29 @@ mp_err s_mp_mul_high_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs /* execute loop */ for (iz = 0; iz < iy; iz++) { - _W += (mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]; + W += (mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]; } /* store term */ - W[ix] = (mp_digit)_W & MP_MASK; + c_->dp[ix] = (mp_digit)W & MP_MASK; /* make next carry */ - _W = _W >> (mp_word)MP_DIGIT_BIT; + W = W >> (mp_word)MP_DIGIT_BIT; } /* setup dest */ - oldused = c->used; - c->used = pa; - - for (ix = digs; ix < pa; ix++) { - /* now extract the previous digit [below the carry] */ - c->dp[ix] = W[ix]; - } + oldused = c_->used; + c_->used = pa; /* clear unused digits [that existed in the old copy of c] */ - s_mp_zero_digs(c->dp + c->used, oldused - c->used); + s_mp_zero_digs(c_->dp + c_->used, oldused - c_->used); + mp_clamp(c_); + + if (c_ == &tmp) { + mp_clear(c); + *c = *c_; + } - mp_clamp(c); return MP_OKAY; } #endif diff --git a/s_mp_sqr.c b/s_mp_sqr.c deleted file mode 100644 index 4a2030638..000000000 --- a/s_mp_sqr.c +++ /dev/null @@ -1,65 +0,0 @@ -#include "tommath_private.h" -#ifdef S_MP_SQR_C -/* LibTomMath, multiple-precision integer library -- Tom St Denis */ -/* SPDX-License-Identifier: Unlicense */ - -/* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */ -mp_err s_mp_sqr(const mp_int *a, mp_int *b) -{ - mp_int t; - int ix, pa; - mp_err err; - - pa = a->used; - if ((err = mp_init_size(&t, (2 * pa) + 1)) != MP_OKAY) { - return err; - } - - /* default used is maximum possible size */ - t.used = (2 * pa) + 1; - - for (ix = 0; ix < pa; ix++) { - mp_digit u; - int iy; - - /* first calculate the digit at 2*ix */ - /* calculate double precision result */ - mp_word r = (mp_word)t.dp[2*ix] + - 
((mp_word)a->dp[ix] * (mp_word)a->dp[ix]); - - /* store lower part in result */ - t.dp[ix+ix] = (mp_digit)(r & (mp_word)MP_MASK); - - /* get the carry */ - u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT); - - for (iy = ix + 1; iy < pa; iy++) { - /* first calculate the product */ - r = (mp_word)a->dp[ix] * (mp_word)a->dp[iy]; - - /* now calculate the double precision result, note we use - * addition instead of *2 since it's easier to optimize - */ - r = (mp_word)t.dp[ix + iy] + r + r + (mp_word)u; - - /* store lower part */ - t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK); - - /* get carry */ - u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT); - } - /* propagate upwards */ - while (u != 0uL) { - r = (mp_word)t.dp[ix + iy] + (mp_word)u; - t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK); - u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT); - ++iy; - } - } - - mp_clamp(&t); - mp_exch(&t, b); - mp_clear(&t); - return MP_OKAY; -} -#endif diff --git a/s_mp_sqr_comba.c b/s_mp_sqr_comba.c index cb88dcc9e..47f3dbddd 100644 --- a/s_mp_sqr_comba.c +++ b/s_mp_sqr_comba.c @@ -16,24 +16,24 @@ After that loop you do the squares and add them in. mp_err s_mp_sqr_comba(const mp_int *a, mp_int *b) { int oldused, pa, ix; - mp_digit W[MP_WARRAY]; - mp_word W1; + mp_digit c0, c1, c2; mp_err err; + mp_int tmp, *b_; - /* grow the destination as required */ - pa = a->used + a->used; - if ((err = mp_grow(b, pa)) != MP_OKAY) { + pa = 2 * a->used; + + /* prepare the destination */ + err = MP_ALIAS(a, b) + ? 
mp_init_size((b_ = &tmp), pa) + : mp_grow((b_ = b), pa); + if (err != MP_OKAY) { return err; } /* number of output digits to produce */ - W1 = 0; + c0 = c1 = c2 = 0; for (ix = 0; ix < pa; ix++) { int tx, ty, iy, iz; - mp_word _W; - - /* clear counter */ - _W = 0; /* get offsets into the two bignums */ ty = MP_MIN(a->used-1, ix); @@ -52,36 +52,49 @@ mp_err s_mp_sqr_comba(const mp_int *a, mp_int *b) /* execute loop */ for (iz = 0; iz < iy; iz++) { - _W += (mp_word)a->dp[tx + iz] * (mp_word)a->dp[ty - iz]; + mp_word t = (mp_word)a->dp[tx + iz] * (mp_word)a->dp[ty - iz]; + int j; + for (j = 0; j < 2; ++j) { + mp_word w = (mp_word)c0 + t; + c0 = (mp_digit)(w & MP_MASK); + w = (mp_word)c1 + (w >> MP_DIGIT_BIT); + c1 = (mp_digit)(w & MP_MASK); + c2 += (mp_digit)(w >> MP_DIGIT_BIT); + } } - /* double the inner product and add carry */ - _W = _W + _W + W1; - /* even columns have the square term in them */ if (((unsigned)ix & 1u) == 0u) { - _W += (mp_word)a->dp[ix>>1] * (mp_word)a->dp[ix>>1]; + mp_word w = (mp_word)c0 + ((mp_word)a->dp[ix / 2] * (mp_word)a->dp[ix / 2]); + c0 = (mp_digit)(w & MP_MASK); + w = (mp_word)c1 + (w >> MP_DIGIT_BIT); + c1 = (mp_digit)(w & MP_MASK); + c2 += (mp_digit)(w >> MP_DIGIT_BIT); } - /* store it */ - W[ix] = (mp_digit)_W & MP_MASK; + /* store term */ + b_->dp[ix] = c0; /* make next carry */ - W1 = _W >> (mp_word)MP_DIGIT_BIT; + c0 = c1; + c1 = c2; + c2 = 0; } /* setup dest */ - oldused = b->used; - b->used = a->used+a->used; - - for (ix = 0; ix < pa; ix++) { - b->dp[ix] = W[ix] & MP_MASK; - } + oldused = b_->used; + b_->used = 2 * a->used; /* clear unused digits [that existed in the old copy of c] */ - s_mp_zero_digs(b->dp + b->used, oldused - b->used); + s_mp_zero_digs(b_->dp + b_->used, oldused - b_->used); + + mp_clamp(b_); + + if (b_ == &tmp) { + mp_clear(b); + *b = *b_; + } - mp_clamp(b); return MP_OKAY; } #endif diff --git a/tommath_class.h b/tommath_class.h index f5f99074c..0e61f643d 100644 --- a/tommath_class.h +++ b/tommath_class.h @@ 
-150,7 +150,6 @@ # define S_MP_LOG_D_C # define S_MP_LOG_POW2_C # define S_MP_MONTGOMERY_REDUCE_COMBA_C -# define S_MP_MUL_C # define S_MP_MUL_BALANCE_C # define S_MP_MUL_COMBA_C # define S_MP_MUL_HIGH_C @@ -162,7 +161,6 @@ # define S_MP_RADIX_MAP_C # define S_MP_RAND_JENKINS_C # define S_MP_RAND_PLATFORM_C -# define S_MP_SQR_C # define S_MP_SQR_COMBA_C # define S_MP_SQR_KARATSUBA_C # define S_MP_SQR_TOOM_C @@ -542,11 +540,9 @@ #if defined(MP_MUL_C) # define S_MP_MUL_BALANCE_C -# define S_MP_MUL_C # define S_MP_MUL_COMBA_C # define S_MP_MUL_KARATSUBA_C # define S_MP_MUL_TOOM_C -# define S_MP_SQR_C # define S_MP_SQR_COMBA_C # define S_MP_SQR_KARATSUBA_C # define S_MP_SQR_TOOM_C @@ -737,7 +733,7 @@ # define MP_RSHD_C # define MP_SET_C # define MP_SUB_C -# define S_MP_MUL_C +# define S_MP_MUL_COMBA_C # define S_MP_MUL_HIGH_C # define S_MP_MUL_HIGH_COMBA_C # define S_MP_SUB_C @@ -1124,14 +1120,6 @@ # define S_MP_ZERO_DIGS_C #endif -#if defined(S_MP_MUL_C) -# define MP_CLAMP_C -# define MP_CLEAR_C -# define MP_EXCH_C -# define MP_INIT_SIZE_C -# define S_MP_MUL_COMBA_C -#endif - #if defined(S_MP_MUL_BALANCE_C) # define MP_ADD_C # define MP_CLAMP_C @@ -1147,21 +1135,25 @@ #if defined(S_MP_MUL_COMBA_C) # define MP_CLAMP_C +# define MP_CLEAR_C # define MP_GROW_C +# define MP_INIT_SIZE_C # define S_MP_ZERO_DIGS_C #endif #if defined(S_MP_MUL_HIGH_C) # define MP_CLAMP_C # define MP_CLEAR_C -# define MP_EXCH_C +# define MP_GROW_C # define MP_INIT_SIZE_C -# define S_MP_MUL_HIGH_COMBA_C +# define S_MP_ZERO_DIGS_C #endif #if defined(S_MP_MUL_HIGH_COMBA_C) # define MP_CLAMP_C +# define MP_CLEAR_C # define MP_GROW_C +# define MP_INIT_SIZE_C # define S_MP_ZERO_DIGS_C #endif @@ -1210,16 +1202,11 @@ #if defined(S_MP_RAND_PLATFORM_C) #endif -#if defined(S_MP_SQR_C) -# define MP_CLAMP_C -# define MP_CLEAR_C -# define MP_EXCH_C -# define MP_INIT_SIZE_C -#endif - #if defined(S_MP_SQR_COMBA_C) # define MP_CLAMP_C +# define MP_CLEAR_C # define MP_GROW_C +# define MP_INIT_SIZE_C # define 
S_MP_ZERO_DIGS_C #endif diff --git a/tommath_private.h b/tommath_private.h index f6295020f..350ca1ffc 100644 --- a/tommath_private.h +++ b/tommath_private.h @@ -116,6 +116,8 @@ extern void MP_FREE(void *mem, size_t size); #define MP_MIN(x, y) (((x) < (y)) ? (x) : (y)) #define MP_MAX(x, y) (((x) > (y)) ? (x) : (y)) +#define MP_ALIAS(a, b) ((a) == (b) || (a)->dp == (b)->dp) + #define MP_TOUPPER(c) ((((c) >= 'a') && ((c) <= 'z')) ? (((c) + 'A') - 'a') : (c)) #define MP_EXCH(t, a, b) do { t _c = a; a = b; b = _c; } while (0) @@ -127,6 +129,7 @@ extern void MP_FREE(void *mem, size_t size); #define MP_SIZEOF_BITS(type) ((size_t)CHAR_BIT * sizeof(type)) +/* TODO: Remove MP_MAX_COMBA and MP_WARRAY */ #define MP_MAX_COMBA (int)(1uL << (MP_SIZEOF_BITS(mp_word) - (2u * (size_t)MP_DIGIT_BIT))) #define MP_WARRAY (int)(1uL << ((MP_SIZEOF_BITS(mp_word) - (2u * (size_t)MP_DIGIT_BIT)) + 1u)) @@ -173,7 +176,6 @@ MP_PRIVATE mp_err s_mp_invmod(const mp_int *a, const mp_int *b, mp_int *c) MP_WU MP_PRIVATE mp_err s_mp_invmod_odd(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR; MP_PRIVATE mp_err s_mp_log(const mp_int *a, uint32_t base, uint32_t *c) MP_WUR; MP_PRIVATE mp_err s_mp_montgomery_reduce_comba(mp_int *x, const mp_int *n, mp_digit rho) MP_WUR; -MP_PRIVATE mp_err s_mp_mul(const mp_int *a, const mp_int *b, mp_int *c, int digs) MP_WUR; MP_PRIVATE mp_err s_mp_mul_balance(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR; MP_PRIVATE mp_err s_mp_mul_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs) MP_WUR; MP_PRIVATE mp_err s_mp_mul_high(const mp_int *a, const mp_int *b, mp_int *c, int digs) MP_WUR; @@ -182,7 +184,6 @@ MP_PRIVATE mp_err s_mp_mul_karatsuba(const mp_int *a, const mp_int *b, mp_int *c MP_PRIVATE mp_err s_mp_mul_toom(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR; MP_PRIVATE mp_err s_mp_prime_is_divisible(const mp_int *a, bool *result) MP_WUR; MP_PRIVATE mp_err s_mp_rand_platform(void *p, size_t n) MP_WUR; -MP_PRIVATE mp_err s_mp_sqr(const mp_int *a, 
mp_int *b) MP_WUR; MP_PRIVATE mp_err s_mp_sqr_comba(const mp_int *a, mp_int *b) MP_WUR; MP_PRIVATE mp_err s_mp_sqr_karatsuba(const mp_int *a, mp_int *b) MP_WUR; MP_PRIVATE mp_err s_mp_sqr_toom(const mp_int *a, mp_int *b) MP_WUR;