Skip to content

Commit 1bec79e

Browse files
committed
lift comba limit for s_mp_mul_comba
this is how it is done in tfm
1 parent ee42591 commit 1bec79e

File tree

4 files changed

+53
-13
lines changed

4 files changed

+53
-13
lines changed

etc/tune.c

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ static int s_offset = 1;
6060

6161
static mp_err s_mul_full(const mp_int *a, const mp_int *b, mp_int *c)
6262
{
63-
if (MP_HAS(S_MP_MUL_HIGH_COMBA)
64-
&& (MP_MIN(a->used, b->used) < MP_MAX_COMBA)) {
63+
if (MP_HAS(S_MP_MUL_COMBA)) {
6564
return s_mp_mul_comba(a, b, c, a->used + b->used + 1);
6665
}
6766
return s_mp_mul(a, b, c, a->used + b->used + 1);

mp_mul.c

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ mp_err mp_mul(const mp_int *a, const mp_int *b, mp_int *c)
3131
} else if (MP_HAS(S_MP_MUL_KARATSUBA) &&
3232
(min >= MP_MUL_KARATSUBA_CUTOFF)) {
3333
err = s_mp_mul_karatsuba(a, b, c);
34-
} else if (MP_HAS(S_MP_MUL_COMBA) && /* can we use the fast multiplier? */
35-
(min <= MP_MAX_COMBA)) {
34+
} else if (MP_HAS(S_MP_MUL_COMBA)) {
3635
err = s_mp_mul_comba(a, b, c, digs);
3736
} else if (MP_HAS(S_MP_MUL)) {
3837
err = s_mp_mul(a, b, c, digs);

mp_reduce.c

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ mp_err mp_reduce(mp_int *x, const mp_int *m, const mp_int *mu)
4040
if ((err = mp_mod_2d(x, MP_DIGIT_BIT * (um + 1), x)) != MP_OKAY) goto LBL_ERR;
4141

4242
/* q = q * m mod b**(k+1), quick (no division) */
43-
if (MP_HAS(S_MP_MUL_COMBA)
44-
&& (MP_MIN(q.used, m->used) < MP_MAX_COMBA)) {
43+
if (MP_HAS(S_MP_MUL_COMBA)) {
4544
if ((err = s_mp_mul_comba(&q, m, &q, um + 1)) != MP_OKAY) goto LBL_ERR;
4645
} else {
4746
if ((err = s_mp_mul(&q, m, &q, um + 1)) != MP_OKAY) goto LBL_ERR;

s_mp_mul_comba.c

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ mp_err s_mp_mul_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs)
2323
{
2424
int oldused, pa, ix;
2525
mp_err err;
26-
mp_word W;
26+
mp_digit c0, c1, c2;
2727
mp_int tmp, *c_;
2828

2929
/* prepare the destination */
@@ -38,7 +38,7 @@ mp_err s_mp_mul_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs)
3838
pa = MP_MIN(digs, a->used + b->used);
3939

4040
/* clear the carry */
41-
W = 0;
41+
c0 = c1 = c2 = 0;
4242
for (ix = 0; ix < pa; ix++) {
4343
int tx, ty, iy, iz;
4444

@@ -51,16 +51,59 @@ mp_err s_mp_mul_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs)
5151
*/
5252
iy = MP_MIN(a->used-tx, ty+1);
5353

54-
/* execute loop */
55-
for (iz = 0; iz < iy; ++iz) {
56-
W += (mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz];
54+
/* execute loop
55+
*
56+
* Give the autovectorizer a hint! this might not be necessary.
57+
* I don't think the generated code will be particularily good here,
58+
* if we will use full width digits the masks will go away.
59+
*/
60+
for (iz = 0; iz + 3 < iy;) {
61+
mp_word w = (mp_word)c0 + ((mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]);
62+
c0 = (mp_digit)(w & MP_MASK);
63+
w = (mp_word)c1 + (w >> MP_DIGIT_BIT);
64+
c1 = (mp_digit)(w & MP_MASK);
65+
c2 += (mp_digit)(w >> MP_DIGIT_BIT);
66+
++iz;
67+
68+
w = (mp_word)c0 + ((mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]);
69+
c0 = (mp_digit)(w & MP_MASK);
70+
w = (mp_word)c1 + (w >> MP_DIGIT_BIT);
71+
c1 = (mp_digit)(w & MP_MASK);
72+
c2 += (mp_digit)(w >> MP_DIGIT_BIT);
73+
++iz;
74+
75+
w = (mp_word)c0 + ((mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]);
76+
c0 = (mp_digit)(w & MP_MASK);
77+
w = (mp_word)c1 + (w >> MP_DIGIT_BIT);
78+
c1 = (mp_digit)(w & MP_MASK);
79+
c2 += (mp_digit)(w >> MP_DIGIT_BIT);
80+
++iz;
81+
82+
w = (mp_word)c0 + ((mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]);
83+
c0 = (mp_digit)(w & MP_MASK);
84+
w = (mp_word)c1 + (w >> MP_DIGIT_BIT);
85+
c1 = (mp_digit)(w & MP_MASK);
86+
c2 += (mp_digit)(w >> MP_DIGIT_BIT);
87+
++iz;
88+
}
89+
90+
/* execute rest of loop */
91+
for (; iz < iy;) {
92+
mp_word w = (mp_word)c0 + ((mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz]);
93+
c0 = (mp_digit)(w & MP_MASK);
94+
w = (mp_word)c1 + (w >> MP_DIGIT_BIT);
95+
c1 = (mp_digit)(w & MP_MASK);
96+
c2 += (mp_digit)(w >> MP_DIGIT_BIT);
97+
++iz;
5798
}
5899

59100
/* store term */
60-
c_->dp[ix] = (mp_digit)W & MP_MASK;
101+
c_->dp[ix] = c0;
61102

62103
/* make next carry */
63-
W = W >> (mp_word)MP_DIGIT_BIT;
104+
c0 = c1;
105+
c1 = c2;
106+
c2 = 0;
64107
}
65108

66109
/* setup dest */

0 commit comments

Comments
 (0)