|
| 1 | +package math |
| 2 | + |
| 3 | +import java.math.BigInteger |
| 4 | + |
| 5 | +const val ERROR_MESSAGE_FOR_NEGATIVE_INDICES = "Binary Exponentiation cannot be used for negative indices" |
| 6 | + |
| 7 | +internal inline fun Long.toBigInteger(): BigInteger = BigInteger.valueOf(this) |
| 8 | + |
| 9 | +internal fun throwIllegalArgumentExceptionForNegativeIndices(): Nothing = |
| 10 | + throw IllegalArgumentException(ERROR_MESSAGE_FOR_NEGATIVE_INDICES) |
| 11 | + |
| 12 | +/** |
| 13 | + * Calculates a.pow(b) by using the algorithm of binary exponentiation, |
| 14 | + * @param that allows for arbitrarily large indices. |
| 15 | + * However, sufficiently large indices (as a rule of thumb : anything above 2 ^ 1000 as index) can cause a stack-overflow |
| 16 | + * To prevent stack overflow and use arbitrarily large indices, use deep recursive binary exponentiation |
| 17 | + * */ |
| 18 | +infix fun BigInteger.bpow(that: BigInteger): BigInteger { |
| 19 | + if (that < BigInteger.ZERO) { |
| 20 | + throwIllegalArgumentExceptionForNegativeIndices() |
| 21 | + } |
| 22 | + if (that == BigInteger.ZERO) { |
| 23 | + return BigInteger.ONE |
| 24 | + } |
| 25 | + val toSquare = this.bpow(that / 2.toBigInteger()) |
| 26 | + return if (that % 2.toBigInteger() == BigInteger.ZERO) { |
| 27 | + toSquare * toSquare |
| 28 | + } else { |
| 29 | + this * toSquare * toSquare |
| 30 | + } |
| 31 | +} |
| 32 | + |
| 33 | +/** |
| 34 | + * Calculates a.pow(b) by using the algorithm of binary exponentiation, |
| 35 | + * where a is the base and b is the index. |
| 36 | + * If b is odd, a.pow(b) is written as a * (a.pow(b / 2)). |
| 37 | + * If b is even, a.pow(b) is written as (a.pow(b / 2)). |
| 38 | + * We compute (a.pow(b / 2)) recursively. |
| 39 | + * Time Complexity : O(log(n)). |
| 40 | + * Space Complexity : O(1). |
| 41 | + * @see Long.bpow |
| 42 | + * @see BigInteger.bpow |
| 43 | + * @receiver the base of exponentiation |
| 44 | + * @param that : the index of exponentiation |
| 45 | + */ |
| 46 | +infix fun Int.bpow(that: Int): Int { |
| 47 | + if (that < 0) { |
| 48 | + throwIllegalArgumentExceptionForNegativeIndices() |
| 49 | + } |
| 50 | + // a.pow(0) = 1 |
| 51 | + if (that == 0) { |
| 52 | + return 1 |
| 53 | + } |
| 54 | + |
| 55 | + val toSquare = this.bpow(that / 2) |
| 56 | + return if (that % 2 == 0) { |
| 57 | + toSquare * toSquare |
| 58 | + } else { |
| 59 | + this * toSquare * toSquare |
| 60 | + } |
| 61 | +} |
| 62 | + |
| 63 | +/** |
| 64 | + * Calculates a.pow(b) by using the algorithm of binary exponentiation |
| 65 | + * Note that neither the [Int.bpow] nor [Long.bpow] are overflow-proof, |
| 66 | + * they use native ints (32-bit signed integers) and longs (64-bit signed integers). |
| 67 | + * To use overflow-proof exponentiation, use [BigInteger.bpow] |
| 68 | + * @see Int.bpow(that) |
| 69 | + * @see BigInteger.bpow |
| 70 | + * @receiver the base of exponentiation |
| 71 | + * @param that : the index of exponentiation |
| 72 | + * */ |
| 73 | +infix fun Long.bpow(that: Long): Long { |
| 74 | + if (that < 0L) { |
| 75 | + throwIllegalArgumentExceptionForNegativeIndices() |
| 76 | + } |
| 77 | + if (that == 0L) { |
| 78 | + return 1L |
| 79 | + } |
| 80 | + val toSquare = this.bpow(that / 2) |
| 81 | + return if (that % 2L == 0L) { |
| 82 | + toSquare * toSquare |
| 83 | + } else { |
| 84 | + this * toSquare * toSquare |
| 85 | + } |
| 86 | +} |
| 87 | + |
0 commit comments