Skip to content

Commit 7a35f3d

Browse files
committed
Implement extended GCD and modular inverse
1 parent 4d166cb commit 7a35f3d

File tree

1 file changed

+71
-1
lines changed

1 file changed

+71
-1
lines changed

src/lib.rs

+71-1
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ extern crate num_traits as traits;
2323

2424
use core::mem;
2525
use core::ops::Add;
26+
use core::cmp::Ordering;
2627

27-
use traits::{Num, Signed, Zero};
28+
use traits::{Num, NumRef, RefNum, Signed, Zero};
2829

2930
mod roots;
3031
pub use roots::Roots;
@@ -1013,6 +1014,57 @@ impl_integer_for_usize!(usize, test_integer_usize);
10131014
#[cfg(has_i128)]
10141015
impl_integer_for_usize!(u128, test_integer_u128);
10151016

1017+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
1018+
pub struct GcdResult<T> {
1019+
/// Greatest common divisor.
1020+
pub gcd: T,
1021+
/// Coefficients such that: gcd(a, b) = c1*a + c2*b
1022+
pub c1: T, pub c2: T,
1023+
}
1024+
1025+
/// Calculate greatest common divisor and the corresponding coefficients.
1026+
pub fn extended_gcd<T: Integer + NumRef>(a: T, b: T) -> GcdResult<T>
1027+
where for<'a> &'a T: RefNum<T>
1028+
{
1029+
// Euclid's extended algorithm
1030+
let (mut s, mut old_s) = (T::zero(), T::one());
1031+
let (mut t, mut old_t) = (T::one(), T::zero());
1032+
let (mut r, mut old_r) = (b, a);
1033+
1034+
while r != T::zero() {
1035+
let quotient = &old_r / &r;
1036+
old_r = old_r - &quotient * &r; std::mem::swap(&mut old_r, &mut r);
1037+
old_s = old_s - &quotient * &s; std::mem::swap(&mut old_s, &mut s);
1038+
old_t = old_t - quotient * &t; std::mem::swap(&mut old_t, &mut t);
1039+
}
1040+
1041+
let _quotients = (t, s); // == (a, b) / gcd
1042+
1043+
GcdResult { gcd: old_r, c1: old_s, c2: old_t }
1044+
}
1045+
1046+
/// Find the standard representation of a (mod n).
1047+
pub fn normalize<T: Integer + NumRef>(a: T, n: &T) -> T {
1048+
let a = a % n;
1049+
match a.cmp(&T::zero()) {
1050+
Ordering::Less => a + n,
1051+
_ => a,
1052+
}
1053+
}
1054+
1055+
/// Calculate the inverse of a (mod n).
1056+
pub fn inverse<T: Integer + NumRef + Clone>(a: T, n: &T) -> Option<T>
1057+
where for<'a> &'a T: RefNum<T>
1058+
{
1059+
let GcdResult { gcd, c1: c, .. } = extended_gcd(a, n.clone());
1060+
if gcd == T::one() {
1061+
Some(normalize(c, n))
1062+
} else {
1063+
None
1064+
}
1065+
}
1066+
1067+
10161068
/// An iterator over binomial coefficients.
10171069
pub struct IterBinomial<T> {
10181070
a: T,
@@ -1169,6 +1221,24 @@ fn test_lcm_overflow() {
11691221
check!(u64, 0x8000_0000_0000_0000, 0x02, 0x8000_0000_0000_0000);
11701222
}
11711223

1224+
#[test]
1225+
fn test_extended_gcd() {
1226+
assert_eq!(extended_gcd(240, 46), GcdResult { gcd: 2, c1: -9, c2: 47});
1227+
}
1228+
1229+
#[test]
1230+
fn test_normalize() {
1231+
assert_eq!(normalize(10, &7), 3);
1232+
assert_eq!(normalize(7, &7), 0);
1233+
assert_eq!(normalize(5, &7), 5);
1234+
assert_eq!(normalize(-3, &7), 4);
1235+
}
1236+
1237+
#[test]
1238+
fn test_inverse() {
1239+
assert_eq!(inverse(5, &7).unwrap(), 3);
1240+
}
1241+
11721242
#[test]
11731243
fn test_iter_binomial() {
11741244
macro_rules! check_simple {

0 commit comments

Comments
 (0)