Skip to content

Commit 863729e

Browse files
tim-onearhadthedev
andauthored
bpo-46218: Change long_pow() to sliding window algorithm (pythonGH-30319)
* bpo-46218: Change long_pow() to sliding window algorithm The primary motivation is to eliminate long_pow's reliance on that the number of bits in a long "digit" is a multiple of 5. Now it no longer cares how many bits are in a digit. But the sliding window approach also allows cutting the precomputed table of small powers in half, which reduces initialization overhead enough that the approach pays off for smaller exponents too. Depending on exponent bit patterns, a sliding window may also be able to save some bigint multiplies (sometimes when at least 5 consecutive exponent bits are 0, regardless of their starting bit position modulo 5). Note: boosting the window width to 6 didn't work well overall. It give marginal speed improvements for huge exponents, but the increased overhead (the small-power table needs twice as many entries) made it a loss for smaller exponents. Co-authored-by: Oleg Iarygin <[email protected]>
1 parent ce4d25f commit 863729e

File tree

3 files changed

+106
-30
lines changed

3 files changed

+106
-30
lines changed

Include/cpython/longintrepr.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ extern "C" {
2121
PyLong_SHIFT. The majority of the code doesn't care about the precise
2222
value of PyLong_SHIFT, but there are some notable exceptions:
2323
24-
- long_pow() requires that PyLong_SHIFT be divisible by 5
25-
2624
- PyLong_{As,From}ByteArray require that PyLong_SHIFT be at least 8
2725
2826
- long_hash() requires that PyLong_SHIFT is *strictly* less than the number
@@ -63,10 +61,6 @@ typedef long stwodigits; /* signed variant of twodigits */
6361
#define PyLong_BASE ((digit)1 << PyLong_SHIFT)
6462
#define PyLong_MASK ((digit)(PyLong_BASE - 1))
6563

66-
#if PyLong_SHIFT % 5 != 0
67-
#error "longobject.c requires that PyLong_SHIFT be divisible by 5"
68-
#endif
69-
7064
/* Long integer representation.
7165
The absolute value of a number is equal to
7266
SUM(for i=0 through abs(ob_size)-1) ob_digit[i] * 2**(SHIFT*i)

Lib/test/test_pow.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,28 @@ def test_other(self):
9393
pow(int(i),j,k)
9494
)
9595

96+
def test_big_exp(self):
97+
import random
98+
self.assertEqual(pow(2, 50000), 1 << 50000)
99+
# Randomized modular tests, checking the identities
100+
# a**(b1 + b2) == a**b1 * a**b2
101+
# a**(b1 * b2) == (a**b1)**b2
102+
prime = 1000000000039 # for speed, relatively small prime modulus
103+
for i in range(10):
104+
a = random.randrange(1000, 1000000)
105+
bpower = random.randrange(1000, 50000)
106+
b = random.randrange(1 << (bpower - 1), 1 << bpower)
107+
b1 = random.randrange(1, b)
108+
b2 = b - b1
109+
got1 = pow(a, b, prime)
110+
got2 = pow(a, b1, prime) * pow(a, b2, prime) % prime
111+
if got1 != got2:
112+
self.fail(f"{a=:x} {b1=:x} {b2=:x} {got1=:x} {got2=:x}")
113+
got3 = pow(a, b1 * b2, prime)
114+
got4 = pow(pow(a, b1, prime), b2, prime)
115+
if got3 != got4:
116+
self.fail(f"{a=:x} {b1=:x} {b2=:x} {got3=:x} {got4=:x}")
117+
96118
def test_bug643260(self):
97119
class TestRpow:
98120
def __rpow__(self, other):

Objects/longobject.c

Lines changed: 84 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,34 @@ maybe_small_long(PyLongObject *v)
7474
#define KARATSUBA_CUTOFF 70
7575
#define KARATSUBA_SQUARE_CUTOFF (2 * KARATSUBA_CUTOFF)
7676

77-
/* For exponentiation, use the binary left-to-right algorithm
78-
* unless the exponent contains more than FIVEARY_CUTOFF digits.
79-
* In that case, do 5 bits at a time. The potential drawback is that
80-
* a table of 2**5 intermediate results is computed.
77+
/* For exponentiation, use the binary left-to-right algorithm unless the
78+
^ exponent contains more than HUGE_EXP_CUTOFF bits. In that case, do
79+
* (no more than) EXP_WINDOW_SIZE bits at a time. The potential drawback is
80+
* that a table of 2**(EXP_WINDOW_SIZE - 1) intermediate results is
81+
* precomputed.
8182
*/
82-
#define FIVEARY_CUTOFF 8
83+
#define EXP_WINDOW_SIZE 5
84+
#define EXP_TABLE_LEN (1 << (EXP_WINDOW_SIZE - 1))
85+
/* Suppose the exponent has bit length e. All ways of doing this
86+
* need e squarings. The binary method also needs a multiply for
87+
* each bit set. In a k-ary method with window width w, a multiply
88+
* for each non-zero window, so at worst (and likely!)
89+
* ceiling(e/w). The k-ary sliding window method has the same
90+
* worst case, but the window slides so it can sometimes skip
91+
* over an all-zero window that the fixed-window method can't
92+
* exploit. In addition, the windowing methods need multiplies
93+
* to precompute a table of small powers.
94+
*
95+
* For the sliding window method with width 5, 16 precomputation
96+
* multiplies are needed. Assuming about half the exponent bits
97+
* are set, then, the binary method needs about e/2 extra mults
98+
* and the window method about 16 + e/5.
99+
*
100+
* The latter is smaller for e > 53 1/3. We don't have direct
101+
* access to the bit length, though, so call it 60, which is a
102+
* multiple of a long digit's max bit length (15 or 30 so far).
103+
*/
104+
#define HUGE_EXP_CUTOFF 60
83105

84106
#define SIGCHECK(PyTryBlock) \
85107
do { \
@@ -4172,14 +4194,15 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
41724194
int negativeOutput = 0; /* if x<0 return negative output */
41734195

41744196
PyLongObject *z = NULL; /* accumulated result */
4175-
Py_ssize_t i, j, k; /* counters */
4197+
Py_ssize_t i, j; /* counters */
41764198
PyLongObject *temp = NULL;
4199+
PyLongObject *a2 = NULL; /* may temporarily hold a**2 % c */
41774200

4178-
/* 5-ary values. If the exponent is large enough, table is
4179-
* precomputed so that table[i] == a**i % c for i in range(32).
4201+
/* k-ary values. If the exponent is large enough, table is
4202+
* precomputed so that table[i] == a**(2*i+1) % c for i in
4203+
* range(EXP_TABLE_LEN).
41804204
*/
4181-
PyLongObject *table[32] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
4182-
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
4205+
PyLongObject *table[EXP_TABLE_LEN] = {0};
41834206

41844207
/* a, b, c = v, w, x */
41854208
CHECK_BINOP(v, w);
@@ -4332,7 +4355,7 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
43324355
}
43334356
/* else bi is 0, and z==1 is correct */
43344357
}
4335-
else if (i <= FIVEARY_CUTOFF) {
4358+
else if (i <= HUGE_EXP_CUTOFF / PyLong_SHIFT ) {
43364359
/* Left-to-right binary exponentiation (HAC Algorithm 14.79) */
43374360
/* http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf */
43384361

@@ -4366,23 +4389,59 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
43664389
}
43674390
}
43684391
else {
4369-
/* Left-to-right 5-ary exponentiation (HAC Algorithm 14.82) */
4370-
Py_INCREF(z); /* still holds 1L */
4371-
table[0] = z;
4372-
for (i = 1; i < 32; ++i)
4373-
MULT(table[i-1], a, table[i]);
4392+
/* Left-to-right k-ary sliding window exponentiation
4393+
* (Handbook of Applied Cryptography (HAC) Algorithm 14.85)
4394+
*/
4395+
Py_INCREF(a);
4396+
table[0] = a;
4397+
MULT(a, a, a2);
4398+
/* table[i] == a**(2*i + 1) % c */
4399+
for (i = 1; i < EXP_TABLE_LEN; ++i)
4400+
MULT(table[i-1], a2, table[i]);
4401+
Py_CLEAR(a2);
4402+
4403+
/* Repeatedly extract the next (no more than) EXP_WINDOW_SIZE bits
4404+
* into `pending`, starting with the next 1 bit. The current bit
4405+
* length of `pending` is `blen`.
4406+
*/
4407+
int pending = 0, blen = 0;
4408+
#define ABSORB_PENDING do { \
4409+
int ntz = 0; /* number of trailing zeroes in `pending` */ \
4410+
assert(pending && blen); \
4411+
assert(pending >> (blen - 1)); \
4412+
assert(pending >> blen == 0); \
4413+
while ((pending & 1) == 0) { \
4414+
++ntz; \
4415+
pending >>= 1; \
4416+
} \
4417+
assert(ntz < blen); \
4418+
blen -= ntz; \
4419+
do { \
4420+
MULT(z, z, z); \
4421+
} while (--blen); \
4422+
MULT(z, table[pending >> 1], z); \
4423+
while (ntz-- > 0) \
4424+
MULT(z, z, z); \
4425+
assert(blen == 0); \
4426+
pending = 0; \
4427+
} while(0)
43744428

43754429
for (i = Py_SIZE(b) - 1; i >= 0; --i) {
43764430
const digit bi = b->ob_digit[i];
4377-
4378-
for (j = PyLong_SHIFT - 5; j >= 0; j -= 5) {
4379-
const int index = (bi >> j) & 0x1f;
4380-
for (k = 0; k < 5; ++k)
4431+
for (j = PyLong_SHIFT - 1; j >= 0; --j) {
4432+
const int bit = (bi >> j) & 1;
4433+
pending = (pending << 1) | bit;
4434+
if (pending) {
4435+
++blen;
4436+
if (blen == EXP_WINDOW_SIZE)
4437+
ABSORB_PENDING;
4438+
}
4439+
else /* absorb strings of 0 bits */
43814440
MULT(z, z, z);
4382-
if (index)
4383-
MULT(z, table[index], z);
43844441
}
43854442
}
4443+
if (pending)
4444+
ABSORB_PENDING;
43864445
}
43874446

43884447
if (negativeOutput && (Py_SIZE(z) != 0)) {
@@ -4399,13 +4458,14 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
43994458
Py_CLEAR(z);
44004459
/* fall through */
44014460
Done:
4402-
if (Py_SIZE(b) > FIVEARY_CUTOFF) {
4403-
for (i = 0; i < 32; ++i)
4461+
if (Py_SIZE(b) > HUGE_EXP_CUTOFF / PyLong_SHIFT) {
4462+
for (i = 0; i < EXP_TABLE_LEN; ++i)
44044463
Py_XDECREF(table[i]);
44054464
}
44064465
Py_DECREF(a);
44074466
Py_DECREF(b);
44084467
Py_XDECREF(c);
4468+
Py_XDECREF(a2);
44094469
Py_XDECREF(temp);
44104470
return (PyObject *)z;
44114471
}

0 commit comments

Comments
 (0)