Skip to content

Commit de0aa5f

Browse files
committed
remove W array from s_mp_mul_comba and s_mp_sqr_comba
remove calls to comba from s_mp_mul and s_mp_mul_high. TODO: (1) remove the remaining W arrays; (2) replace mp_exch/mp_clear pairs with mp_clear followed by a struct copy; (3) check whether more mp_init* calls can be replaced by the MP_ALIAS/mp_init_size/mp_grow optimization.
1 parent cc77fad commit de0aa5f

11 files changed

+122
-92
lines changed

etc/tune.c

+10-2
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,15 @@ static int s_number_of_test_loops;
5858
static int s_stabilization_extra;
5959
static int s_offset = 1;
6060

61-
#define s_mp_mul_full(a, b, c) s_mp_mul(a, b, c, (a)->used + (b)->used + 1)
61+
/*
 * Full-width low-level product c = a * b, used by s_time_mul() to
 * cross-check results.
 *
 * Requests a->used + b->used + 1 output digits, the same digit count the
 * s_mp_mul_full() macro it replaces passed to s_mp_mul(). s_mp_mul() no
 * longer forwards to the comba multiplier itself, so the choice between
 * the two routines is made here.
 *
 * Returns the mp_err of whichever multiplier was called.
 */
static mp_err s_mul_full(const mp_int *a, const mp_int *b, mp_int *c)
{
   const int digs = a->used + b->used + 1;

   /*
    * Guard with the feature macro of the routine that is actually called
    * (S_MP_MUL_COMBA, not S_MP_MUL_HIGH_COMBA), so the call is only
    * reachable when s_mp_mul_comba() is part of the build.
    */
   if (MP_HAS(S_MP_MUL_COMBA)
       && (MP_MIN(a->used, b->used) < MP_MAX_COMBA)) {
      return s_mp_mul_comba(a, b, c, digs);
   }
   return s_mp_mul(a, b, c, digs);
}
69+
6270
static uint64_t s_time_mul(int size)
6371
{
6472
int x;
@@ -87,7 +95,7 @@ static uint64_t s_time_mul(int size)
8795
goto LBL_ERR;
8896
}
8997
if (s_check_result == 1) {
90-
if ((e = s_mp_mul_full(&a,&b,&d)) != MP_OKAY) {
98+
if ((e = s_mul_full(&a,&b,&d)) != MP_OKAY) {
9199
t1 = UINT64_MAX;
92100
goto LBL_ERR;
93101
}

mp_mul.c

+1-8
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,7 @@ mp_err mp_mul(const mp_int *a, const mp_int *b, mp_int *c)
3131
} else if (MP_HAS(S_MP_MUL_KARATSUBA) &&
3232
(min >= MP_MUL_KARATSUBA_CUTOFF)) {
3333
err = s_mp_mul_karatsuba(a, b, c);
34-
} else if (MP_HAS(S_MP_MUL_COMBA) &&
35-
/* can we use the fast multiplier?
36-
*
37-
* The fast multiplier can be used if the output will
38-
* have less than MP_WARRAY digits and the number of
39-
* digits won't affect carry propagation
40-
*/
41-
(digs < MP_WARRAY) &&
34+
} else if (MP_HAS(S_MP_MUL_COMBA) && /* can we use the fast multiplier? */
4235
(min <= MP_MAX_COMBA)) {
4336
err = s_mp_mul_comba(a, b, c, digs);
4437
} else if (MP_HAS(S_MP_MUL)) {

mp_reduce.c

+14-5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@
33
/* LibTomMath, multiple-precision integer library -- Tom St Denis */
44
/* SPDX-License-Identifier: Unlicense */
55

6+
/*
 * Low-level multiply c = a * b, passing digs through to the chosen
 * multiplier.
 *
 * Uses s_mp_mul_comba() when that routine is enabled and the shorter
 * operand has fewer than MP_MAX_COMBA digits; otherwise uses s_mp_mul().
 * Returns the mp_err of whichever routine was called.
 */
static mp_err s_mul(const mp_int *a, const mp_int *b, mp_int *c, int digs)
{
   /* keep the comba call lexically inside the MP_HAS() test */
   if (MP_HAS(S_MP_MUL_COMBA)) {
      if (MP_MIN(a->used, b->used) < MP_MAX_COMBA) {
         return s_mp_mul_comba(a, b, c, digs);
      }
   }

   /* generic baseline multiplier */
   return s_mp_mul(a, b, c, digs);
}
14+
615
/* reduces x mod m, assumes 0 < x < m**2, mu is
716
* precomputed via mp_reduce_setup.
817
* From HAC pp.604 Algorithm 14.42
@@ -26,14 +35,14 @@ mp_err mp_reduce(mp_int *x, const mp_int *m, const mp_int *mu)
2635
if ((err = mp_mul(&q, mu, &q)) != MP_OKAY) {
2736
goto LBL_ERR;
2837
}
29-
} else if (MP_HAS(S_MP_MUL_HIGH)) {
30-
if ((err = s_mp_mul_high(&q, mu, &q, um)) != MP_OKAY) {
31-
goto LBL_ERR;
32-
}
3338
} else if (MP_HAS(S_MP_MUL_HIGH_COMBA)) {
3439
if ((err = s_mp_mul_high_comba(&q, mu, &q, um)) != MP_OKAY) {
3540
goto LBL_ERR;
3641
}
42+
} else if (MP_HAS(S_MP_MUL_HIGH)) {
43+
if ((err = s_mp_mul_high(&q, mu, &q, um)) != MP_OKAY) {
44+
goto LBL_ERR;
45+
}
3746
} else {
3847
err = MP_VAL;
3948
goto LBL_ERR;
@@ -48,7 +57,7 @@ mp_err mp_reduce(mp_int *x, const mp_int *m, const mp_int *mu)
4857
}
4958

5059
/* q = q * m mod b**(k+1), quick (no division) */
51-
if ((err = s_mp_mul(&q, m, &q, um + 1)) != MP_OKAY) {
60+
if ((err = s_mul(&q, m, &q, um + 1)) != MP_OKAY) {
5261
goto LBL_ERR;
5362
}
5463

mp_sqr.c

-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ mp_err mp_sqr(const mp_int *a, mp_int *b)
1414
(a->used >= MP_SQR_KARATSUBA_CUTOFF)) {
1515
err = s_mp_sqr_karatsuba(a, b);
1616
} else if (MP_HAS(S_MP_SQR_COMBA) && /* can we use the fast comba multiplier? */
17-
(((a->used * 2) + 1) < MP_WARRAY) &&
1817
(a->used < (MP_MAX_COMBA / 2))) {
1918
err = s_mp_sqr_comba(a, b);
2019
} else if (MP_HAS(S_MP_SQR)) {

s_mp_mul.c

+18-15
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,20 @@
99
*/
1010
mp_err s_mp_mul(const mp_int *a, const mp_int *b, mp_int *c, int digs)
1111
{
12-
mp_int t;
12+
mp_int tmp, *c_;
1313
mp_err err;
1414
int pa, ix;
1515

16-
/* can we use the fast multiplier? */
17-
if ((digs < MP_WARRAY) &&
18-
(MP_MIN(a->used, b->used) < MP_MAX_COMBA)) {
19-
return s_mp_mul_comba(a, b, c, digs);
20-
}
21-
22-
if ((err = mp_init_size(&t, digs)) != MP_OKAY) {
16+
/* prepare the destination */
17+
err = (MP_ALIAS(a, c) || MP_ALIAS(b, c))
18+
? mp_init_size((c_ = &tmp), digs)
19+
: mp_grow((c_ = c), digs);
20+
if (err != MP_OKAY) {
2321
return err;
2422
}
25-
t.used = digs;
23+
24+
s_mp_zero_digs(c_->dp, c_->used);
25+
c_->used = digs;
2626

2727
/* compute the digits of the product directly */
2828
pa = a->used;
@@ -36,26 +36,29 @@ mp_err s_mp_mul(const mp_int *a, const mp_int *b, mp_int *c, int digs)
3636
/* compute the columns of the output and propagate the carry */
3737
for (iy = 0; iy < pb; iy++) {
3838
/* compute the column as a mp_word */
39-
mp_word r = (mp_word)t.dp[ix + iy] +
39+
mp_word r = (mp_word)c_->dp[ix + iy] +
4040
((mp_word)a->dp[ix] * (mp_word)b->dp[iy]) +
4141
(mp_word)u;
4242

4343
/* the new column is the lower part of the result */
44-
t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
44+
c_->dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
4545

4646
/* get the carry word from the result */
4747
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
4848
}
4949
/* set carry if it is placed below digs */
5050
if ((ix + iy) < digs) {
51-
t.dp[ix + pb] = u;
51+
c_->dp[ix + pb] = u;
5252
}
5353
}
5454

55-
mp_clamp(&t);
56-
mp_exch(&t, c);
55+
mp_clamp(c_);
56+
57+
if (c_ == &tmp) {
58+
mp_clear(c);
59+
*c = *c_;
60+
}
5761

58-
mp_clear(&t);
5962
return MP_OKAY;
6063
}
6164
#endif

s_mp_mul_comba.c

+21-17
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,22 @@ mp_err s_mp_mul_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs)
2323
{
2424
int oldused, pa, ix;
2525
mp_err err;
26-
mp_digit W[MP_WARRAY];
27-
mp_word _W;
26+
mp_word W;
27+
mp_int tmp, *c_;
2828

29-
/* grow the destination as required */
30-
if ((err = mp_grow(c, digs)) != MP_OKAY) {
29+
/* prepare the destination */
30+
err = (MP_ALIAS(a, c) || MP_ALIAS(b, c))
31+
? mp_init_size((c_ = &tmp), digs)
32+
: mp_grow((c_ = c), digs);
33+
if (err != MP_OKAY) {
3134
return err;
3235
}
3336

3437
/* number of output digits to produce */
3538
pa = MP_MIN(digs, a->used + b->used);
3639

3740
/* clear the carry */
38-
_W = 0;
41+
W = 0;
3942
for (ix = 0; ix < pa; ix++) {
4043
int tx, ty, iy, iz;
4144

@@ -50,29 +53,30 @@ mp_err s_mp_mul_comba(const mp_int *a, const mp_int *b, mp_int *c, int digs)
5053

5154
/* execute loop */
5255
for (iz = 0; iz < iy; ++iz) {
53-
_W += (mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz];
56+
W += (mp_word)a->dp[tx + iz] * (mp_word)b->dp[ty - iz];
5457
}
5558

5659
/* store term */
57-
W[ix] = (mp_digit)_W & MP_MASK;
60+
c_->dp[ix] = (mp_digit)W & MP_MASK;
5861

5962
/* make next carry */
60-
_W = _W >> (mp_word)MP_DIGIT_BIT;
63+
W = W >> (mp_word)MP_DIGIT_BIT;
6164
}
6265

6366
/* setup dest */
64-
oldused = c->used;
65-
c->used = pa;
66-
67-
for (ix = 0; ix < pa; ix++) {
68-
/* now extract the previous digit [below the carry] */
69-
c->dp[ix] = W[ix];
70-
}
67+
oldused = c_->used;
68+
c_->used = pa;
7169

7270
/* clear unused digits [that existed in the old copy of c] */
73-
s_mp_zero_digs(c->dp + c->used, oldused - c->used);
71+
s_mp_zero_digs(c_->dp + c_->used, oldused - c_->used);
72+
73+
mp_clamp(c_);
74+
75+
if (c_ == &tmp) {
76+
mp_clear(c);
77+
*c = *c_;
78+
}
7479

75-
mp_clamp(c);
7680
return MP_OKAY;
7781
}
7882
#endif

s_mp_mul_high.c

-7
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@ mp_err s_mp_mul_high(const mp_int *a, const mp_int *b, mp_int *c, int digs)
1212
int pa, pb, ix;
1313
mp_err err;
1414

15-
/* can we use the fast multiplier? */
16-
if (MP_HAS(S_MP_MUL_HIGH_COMBA)
17-
&& ((a->used + b->used + 1) < MP_WARRAY)
18-
&& (MP_MIN(a->used, b->used) < MP_MAX_COMBA)) {
19-
return s_mp_mul_high_comba(a, b, c, digs);
20-
}
21-
2215
if ((err = mp_init_size(&t, a->used + b->used + 1)) != MP_OKAY) {
2316
return err;
2417
}

s_mp_sqr.c

+23-12
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,36 @@
66
/* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
77
mp_err s_mp_sqr(const mp_int *a, mp_int *b)
88
{
9-
mp_int t;
9+
mp_int tmp, *b_;
1010
int ix, pa;
1111
mp_err err;
1212

1313
pa = a->used;
14-
if ((err = mp_init_size(&t, (2 * pa) + 1)) != MP_OKAY) {
14+
15+
/* prepare the destination */
16+
err = MP_ALIAS(a, b)
17+
? mp_init_size((b_ = &tmp), (2 * pa) + 1)
18+
: mp_grow((b_ = b), (2 * pa + 1));
19+
if (err != MP_OKAY) {
1520
return err;
1621
}
1722

23+
s_mp_zero_digs(b_->dp, b_->used);
24+
1825
/* default used is maximum possible size */
19-
t.used = (2 * pa) + 1;
26+
b_->used = (2 * pa) + 1;
2027

2128
for (ix = 0; ix < pa; ix++) {
2229
mp_digit u;
2330
int iy;
2431

2532
/* first calculate the digit at 2*ix */
2633
/* calculate double precision result */
27-
mp_word r = (mp_word)t.dp[2*ix] +
34+
mp_word r = (mp_word)b_->dp[2*ix] +
2835
((mp_word)a->dp[ix] * (mp_word)a->dp[ix]);
2936

3037
/* store lower part in result */
31-
t.dp[ix+ix] = (mp_digit)(r & (mp_word)MP_MASK);
38+
b_->dp[ix+ix] = (mp_digit)(r & (mp_word)MP_MASK);
3239

3340
/* get the carry */
3441
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
@@ -40,26 +47,30 @@ mp_err s_mp_sqr(const mp_int *a, mp_int *b)
4047
/* now calculate the double precision result, note we use
4148
* addition instead of *2 since it's easier to optimize
4249
*/
43-
r = (mp_word)t.dp[ix + iy] + r + r + (mp_word)u;
50+
r = (mp_word)b_->dp[ix + iy] + r + r + (mp_word)u;
4451

4552
/* store lower part */
46-
t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
53+
b_->dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
4754

4855
/* get carry */
4956
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
5057
}
5158
/* propagate upwards */
5259
while (u != 0uL) {
53-
r = (mp_word)t.dp[ix + iy] + (mp_word)u;
54-
t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
60+
r = (mp_word)b_->dp[ix + iy] + (mp_word)u;
61+
b_->dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
5562
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
5663
++iy;
5764
}
5865
}
5966

60-
mp_clamp(&t);
61-
mp_exch(&t, b);
62-
mp_clear(&t);
67+
mp_clamp(b_);
68+
69+
if (b_ == &tmp) {
70+
mp_clear(b);
71+
*b = *b_;
72+
}
73+
6374
return MP_OKAY;
6475
}
6576
#endif

0 commit comments

Comments
 (0)