Skip to content

Commit b7d9057

Browse files
czurnieden authored and sjaeckel committed
optimized s_mp_sqr
1 parent e4b789b commit b7d9057

File tree

2 files changed

+46
-5
lines changed

2 files changed

+46
-5
lines changed

demo/test.c

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1846,6 +1846,33 @@ static mp_err s_fill_with_ones(mp_int *a, int size)
18461846
return err;
18471847
}
18481848

1849+
/*
 * Cross-checks s_mp_sqr() against s_mp_mul(): for every size i in
 * 1 .. MP_MAX_COMBA + 9 the operand is squared with s_mp_sqr() and multiplied
 * by itself with s_mp_mul(), and the two results must compare MP_EQ.
 * Each size is exercised twice, once with the operand produced by
 * s_fill_with_ones() and once with the operand produced by mp_rand().
 *
 * Returns EXIT_SUCCESS when the loop completes, EXIT_FAILURE via LBL_ERR.
 * NOTE(review): DOR/DO/EXPECT are macros defined outside this view; DO and
 * EXPECT presumably jump to LBL_ERR on failure and DOR presumably returns
 * early -- confirm against their definitions.
 */
static int test_s_mp_sqr(void)
1850+
{
1851+
mp_int a, b, c;
1852+
int i;
1853+
1854+
DOR(mp_init_multi(&a, &b, &c, NULL));
1855+
1856+
/* s_mp_mul() has a hardcoded branch to s_mp_mul_comba if s_mp_mul_comba is available,
1857+
so test another 10 sizes beyond MP_MAX_COMBA just in case. */
1858+
for (i = 1; i < MP_MAX_COMBA + 10; i++) {
1859+
DO(s_fill_with_ones(&a, i)); /* presumably sets a to i digits with all bits set -- confirm in s_fill_with_ones */
1860+
DO(s_mp_sqr(&a, &b)); /* b = result of the routine under test */
1861+
DO(s_mp_mul(&a, &a, &c, 2*i + 1)); /* c = reference product a*a; 2*i + 1 is presumably the output digit limit -- TODO confirm */
1862+
EXPECT(mp_cmp(&b, &c) == MP_EQ); /* both routes must agree */
1863+
DO(mp_rand(&a, i)); /* repeat the same check with an operand from mp_rand(), size argument i */
1864+
DO(s_mp_sqr(&a, &b));
1865+
DO(s_mp_mul(&a, &a, &c, 2*i + 1));
1866+
EXPECT(mp_cmp(&b, &c) == MP_EQ);
1867+
}
1868+
1869+
mp_clear_multi(&a, &b, &c, NULL); /* success path: release all three temporaries */
1870+
return EXIT_SUCCESS;
1871+
LBL_ERR: /* failure path: same cleanup, different exit status */
1872+
mp_clear_multi(&a, &b, &c, NULL);
1873+
return EXIT_FAILURE;
1874+
}
1875+
18491876
static int test_s_mp_sqr_comba(void)
18501877
{
18511878
mp_int a, r1, r2;
@@ -2373,6 +2400,7 @@ static int unit_tests(int argc, char **argv)
23732400
T1(mp_xor, MP_XOR),
23742401
T3(s_mp_div_recursive, ONLY_PUBLIC_API, S_MP_DIV_RECURSIVE, S_MP_DIV_SCHOOL),
23752402
T3(s_mp_div_small, ONLY_PUBLIC_API, S_MP_DIV_SMALL, S_MP_DIV_SCHOOL),
2403+
T2(s_mp_sqr, ONLY_PUBLIC_API, S_MP_SQR),
23762404
/* s_mp_mul_comba not (yet) testable because s_mp_mul branches to s_mp_mul_comba automatically */
23772405
T2(s_mp_sqr_comba, ONLY_PUBLIC_API, S_MP_SQR_COMBA),
23782406
T2(s_mp_mul_balance, ONLY_PUBLIC_API, S_MP_MUL_BALANCE),

s_mp_sqr.c

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,10 @@ mp_err s_mp_sqr(const mp_int *a, mp_int *b)
3838
r = (mp_word)a->dp[ix] * (mp_word)a->dp[iy];
3939

4040
/* now calculate the double precision result, note we use
41-
* addition instead of *2 since it's easier to optimize
41+
* addition instead of *2 since it's easier to optimize.
4242
*/
43-
r = (mp_word)t.dp[ix + iy] + r + r + (mp_word)u;
43+
/* Some architectures and/or compilers seem to prefer a bit-shift nowadays */
44+
r = (mp_word)t.dp[ix + iy] + (r<<1) + (mp_word)u;
4445

4546
/* store lower part */
4647
t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
@@ -50,9 +51,21 @@ mp_err s_mp_sqr(const mp_int *a, mp_int *b)
5051
}
5152
/* propagate upwards */
5253
while (u != 0uL) {
53-
r = (mp_word)t.dp[ix + iy] + (mp_word)u;
54-
t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
55-
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
54+
mp_digit tmp;
55+
/*
56+
"u" can get bigger than MP_DIGIT_MAX and would need a bigger type
57+
for the sum (mp_word). That is costly if mp_word is not a native
58+
integer but a bigint from the compiler library. We do a manual
59+
multiword addition instead.
60+
*/
61+
/* t.dp[ix + iy] has been masked off by MP_MASK and is hence of the correct size
62+
and we can just add the lower part of "u". Carry is guaranteed to fit into
63+
the type used for mp_digit, too, so we can extract it later. */
64+
tmp = t.dp[ix + iy] + (u & MP_MASK);
65+
/* t.dp[ix + iy] is set to the result minus the carry, carry is still in "tmp" */
66+
t.dp[ix + iy] = tmp & MP_MASK;
67+
/* Add high part of "u" and the carry from "tmp" to get the next "u" */
68+
u = (u >> MP_DIGIT_BIT) + (tmp >> MP_DIGIT_BIT);
5669
++iy;
5770
}
5871
}

0 commit comments

Comments
 (0)