Skip to content

Commit c504fa8

Browse files
Merge #56
56: Implement Roots for BigInt and BigUint r=cuviper a=mancabizjak Supersedes #51 . Since there is now a `Roots` trait with `sqrt`, `cbrt` and `nth_root` methods in the `num-integer` crate, this PR implements it for `BigInt` and `BigUint` types. I also added inherent methods on both types to allow the users access to all these functions without having to import `Roots`. PS: `nth_root` currently uses `num_traits::pow`. Should we perhaps wait for #54 to get merged, and then replace the call to use the new `pow::Pow` implementation on `BigUint`? Co-authored-by: Manca Bizjak <[email protected]>
2 parents 86e019b + 1d45ca9 commit c504fa8

File tree

5 files changed

+276
-4
lines changed

5 files changed

+276
-4
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ name = "shootout-pidigits"
3131
[dependencies]
3232

3333
[dependencies.num-integer]
34-
version = "0.1.38"
34+
version = "0.1.39"
3535
default-features = false
3636

3737
[dependencies.num-traits]

benches/bigint.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
extern crate test;
55
extern crate num_bigint;
66
extern crate num_traits;
7+
extern crate num_integer;
78
extern crate rand;
89

910
use std::mem::replace;
@@ -342,3 +343,27 @@ fn modpow_even(b: &mut Bencher) {
342343

343344
b.iter(|| base.modpow(&e, &m));
344345
}
346+
347+
#[bench]
348+
fn roots_sqrt(b: &mut Bencher) {
349+
let mut rng = get_rng();
350+
let x = rng.gen_biguint(2048);
351+
352+
b.iter(|| x.sqrt());
353+
}
354+
355+
#[bench]
356+
fn roots_cbrt(b: &mut Bencher) {
357+
let mut rng = get_rng();
358+
let x = rng.gen_biguint(2048);
359+
360+
b.iter(|| x.cbrt());
361+
}
362+
363+
#[bench]
364+
fn roots_nth_100(b: &mut Bencher) {
365+
let mut rng = get_rng();
366+
let x = rng.gen_biguint(2048);
367+
368+
b.iter(|| x.nth_root(100));
369+
}

src/bigint.rs

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use std::iter::{Product, Sum};
1616
#[cfg(feature = "serde")]
1717
use serde;
1818

19-
use integer::Integer;
19+
use integer::{Integer, Roots};
2020
use traits::{ToPrimitive, FromPrimitive, Num, CheckedAdd, CheckedSub,
2121
CheckedMul, CheckedDiv, Signed, Zero, One};
2222

@@ -1802,6 +1802,25 @@ impl Integer for BigInt {
18021802
}
18031803
}
18041804

1805+
impl Roots for BigInt {
1806+
fn nth_root(&self, n: u32) -> Self {
1807+
assert!(!(self.is_negative() && n.is_even()),
1808+
"root of degree {} is imaginary", n);
1809+
1810+
BigInt::from_biguint(self.sign, self.data.nth_root(n))
1811+
}
1812+
1813+
fn sqrt(&self) -> Self {
1814+
assert!(!self.is_negative(), "square root is imaginary");
1815+
1816+
BigInt::from_biguint(self.sign, self.data.sqrt())
1817+
}
1818+
1819+
fn cbrt(&self) -> Self {
1820+
BigInt::from_biguint(self.sign, self.data.cbrt())
1821+
}
1822+
}
1823+
18051824
impl ToPrimitive for BigInt {
18061825
#[inline]
18071826
fn to_i64(&self) -> Option<i64> {
@@ -2538,6 +2557,24 @@ impl BigInt {
25382557
};
25392558
BigInt::from_biguint(sign, mag)
25402559
}
2560+
2561+
/// Returns the truncated principal square root of `self` --
2562+
/// see [Roots::sqrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.sqrt).
2563+
pub fn sqrt(&self) -> Self {
2564+
Roots::sqrt(self)
2565+
}
2566+
2567+
/// Returns the truncated principal cube root of `self` --
2568+
/// see [Roots::cbrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.cbrt).
2569+
pub fn cbrt(&self) -> Self {
2570+
Roots::cbrt(self)
2571+
}
2572+
2573+
/// Returns the truncated principal `n`th root of `self` --
2574+
/// See [Roots::nth_root](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#tymethod.nth_root).
2575+
pub fn nth_root(&self, n: u32) -> Self {
2576+
Roots::nth_root(self, n)
2577+
}
25412578
}
25422579

25432580
impl_sum_iter_type!(BigInt);

src/biguint.rs

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ use std::ascii::AsciiExt;
1717
#[cfg(feature = "serde")]
1818
use serde;
1919

20-
use integer::Integer;
20+
use integer::{Integer, Roots};
2121
use traits::{ToPrimitive, FromPrimitive, Float, Num, Unsigned, CheckedAdd, CheckedSub, CheckedMul,
22-
CheckedDiv, Zero, One};
22+
CheckedDiv, Zero, One, pow};
2323

2424
use big_digit::{self, BigDigit, DoubleBigDigit};
2525

@@ -1026,6 +1026,94 @@ impl Integer for BigUint {
10261026
}
10271027
}
10281028

1029+
impl Roots for BigUint {
1030+
// nth_root, sqrt and cbrt use Newton's method to compute
1031+
// principal root of a given degree for a given integer.
1032+
1033+
// Reference:
1034+
// Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.14
1035+
fn nth_root(&self, n: u32) -> Self {
1036+
assert!(n > 0, "root degree n must be at least 1");
1037+
1038+
if self.is_zero() || self.is_one() {
1039+
return self.clone()
1040+
}
1041+
1042+
match n { // Optimize for small n
1043+
1 => return self.clone(),
1044+
2 => return self.sqrt(),
1045+
3 => return self.cbrt(),
1046+
_ => (),
1047+
}
1048+
1049+
let n = n as usize;
1050+
let n_min_1 = n - 1;
1051+
1052+
let guess = BigUint::one() << (self.bits()/n + 1);
1053+
1054+
let mut u = guess;
1055+
let mut s: BigUint;
1056+
1057+
loop {
1058+
s = u;
1059+
let q = self / pow(s.clone(), n_min_1);
1060+
let t: BigUint = n_min_1 * &s + q;
1061+
1062+
u = t / n;
1063+
1064+
if u >= s { break; }
1065+
}
1066+
1067+
s
1068+
}
1069+
1070+
// Reference:
1071+
// Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
1072+
fn sqrt(&self) -> Self {
1073+
if self.is_zero() || self.is_one() {
1074+
return self.clone()
1075+
}
1076+
1077+
let guess = BigUint::one() << (self.bits()/2 + 1);
1078+
1079+
let mut u = guess;
1080+
let mut s: BigUint;
1081+
1082+
loop {
1083+
s = u;
1084+
let q = self / &s;
1085+
let t: BigUint = &s + q;
1086+
u = t >> 1;
1087+
1088+
if u >= s { break; }
1089+
}
1090+
1091+
s
1092+
}
1093+
1094+
fn cbrt(&self) -> Self {
1095+
if self.is_zero() || self.is_one() {
1096+
return self.clone()
1097+
}
1098+
1099+
let guess = BigUint::one() << (self.bits()/3 + 1);
1100+
1101+
let mut u = guess;
1102+
let mut s: BigUint;
1103+
1104+
loop {
1105+
s = u;
1106+
let q = self / (&s * &s);
1107+
let t: BigUint = (&s << 1) + q;
1108+
u = t / 3u32;
1109+
1110+
if u >= s { break; }
1111+
}
1112+
1113+
s
1114+
}
1115+
}
1116+
10291117
fn high_bits_to_u64(v: &BigUint) -> u64 {
10301118
match v.data.len() {
10311119
0 => 0,
@@ -1749,6 +1837,24 @@ impl BigUint {
17491837
}
17501838
acc
17511839
}
1840+
1841+
/// Returns the truncated principal square root of `self` --
1842+
/// see [Roots::sqrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.sqrt)
1843+
pub fn sqrt(&self) -> Self {
1844+
Roots::sqrt(self)
1845+
}
1846+
1847+
/// Returns the truncated principal cube root of `self` --
1848+
/// see [Roots::cbrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.cbrt).
1849+
pub fn cbrt(&self) -> Self {
1850+
Roots::cbrt(self)
1851+
}
1852+
1853+
/// Returns the truncated principal `n`th root of `self` --
1854+
/// see [Roots::nth_root](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#tymethod.nth_root).
1855+
pub fn nth_root(&self, n: u32) -> Self {
1856+
Roots::nth_root(self, n)
1857+
}
17521858
}
17531859

17541860
/// Returns the number of least-significant bits that are zero,

tests/roots.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
extern crate num_bigint;
2+
extern crate num_integer;
3+
extern crate num_traits;
4+
5+
mod biguint {
6+
use num_bigint::BigUint;
7+
use num_traits::pow;
8+
use std::str::FromStr;
9+
10+
fn check(x: u64, n: u32) {
11+
let big_x = BigUint::from(x);
12+
let res = big_x.nth_root(n);
13+
14+
if n == 2 {
15+
assert_eq!(&res, &big_x.sqrt())
16+
} else if n == 3 {
17+
assert_eq!(&res, &big_x.cbrt())
18+
}
19+
20+
assert!(pow(res.clone(), n as usize) <= big_x);
21+
assert!(pow(res.clone() + 1u32, n as usize) > big_x);
22+
}
23+
24+
#[test]
25+
fn test_sqrt() {
26+
check(99, 2);
27+
check(100, 2);
28+
check(120, 2);
29+
}
30+
31+
#[test]
32+
fn test_cbrt() {
33+
check(8, 3);
34+
check(26, 3);
35+
}
36+
37+
#[test]
38+
fn test_nth_root() {
39+
check(0, 1);
40+
check(10, 1);
41+
check(100, 4);
42+
}
43+
44+
#[test]
45+
#[should_panic]
46+
fn test_nth_root_n_is_zero() {
47+
check(4, 0);
48+
}
49+
50+
#[test]
51+
fn test_nth_root_big() {
52+
let x = BigUint::from_str("123_456_789").unwrap();
53+
let expected = BigUint::from(6u32);
54+
55+
assert_eq!(x.nth_root(10), expected);
56+
}
57+
}
58+
59+
mod bigint {
60+
use num_bigint::BigInt;
61+
use num_traits::{Signed, pow};
62+
63+
fn check(x: i64, n: u32) {
64+
let big_x = BigInt::from(x);
65+
let res = big_x.nth_root(n);
66+
67+
if n == 2 {
68+
assert_eq!(&res, &big_x.sqrt())
69+
} else if n == 3 {
70+
assert_eq!(&res, &big_x.cbrt())
71+
}
72+
73+
if big_x.is_negative() {
74+
assert!(pow(res.clone() - 1u32, n as usize) < big_x);
75+
assert!(pow(res.clone(), n as usize) >= big_x);
76+
} else {
77+
assert!(pow(res.clone(), n as usize) <= big_x);
78+
assert!(pow(res.clone() + 1u32, n as usize) > big_x);
79+
}
80+
}
81+
82+
#[test]
83+
fn test_nth_root() {
84+
check(-100, 3);
85+
}
86+
87+
#[test]
88+
#[should_panic]
89+
fn test_nth_root_x_neg_n_even() {
90+
check(-100, 4);
91+
}
92+
93+
#[test]
94+
#[should_panic]
95+
fn test_sqrt_x_neg() {
96+
check(-4, 2);
97+
}
98+
99+
#[test]
100+
fn test_cbrt() {
101+
check(8, 3);
102+
check(-8, 3);
103+
}
104+
}

0 commit comments

Comments
 (0)