@@ -59,28 +59,26 @@ impl Barrett {
59
59
///
60
60
/// * `a` `0 <= a < m`
61
61
/// * `b` `0 <= b < m`
62
- /// * `m` `1 <= m <= 2^31 `
63
- /// * `im` = ceil(2^64 / `m`)
62
+ /// * `m` `1 <= m < 2^32 `
63
+ /// * `im` = ceil(2^64 / `m`) = floor((2^64 - 1) / `m`) + 1
64
64
#[ allow( clippy:: many_single_char_names) ]
65
65
pub ( crate ) fn mul_mod ( a : u32 , b : u32 , m : u32 , im : u64 ) -> u32 {
66
66
// [1] m = 1
67
67
// a = b = im = 0, so okay
68
68
69
69
// [2] m >= 2
70
- // im = ceil(2^64 / m)
70
+ // im = ceil(2^64 / m) = floor((2^64 - 1) / m) + 1
71
71
// -> im * m = 2^64 + r (0 <= r < m)
72
72
// let z = a*b = c*m + d (0 <= c, d < m)
73
73
// a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im
74
74
// c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1) < 2^64 * 2
75
75
// ((ab * im) >> 64) == c or c + 1
76
- let mut z = a as u64 ;
77
- z *= b as u64 ;
76
+ let z = ( a as u64 ) * ( b as u64 ) ;
78
77
let x = ( ( ( z as u128 ) * ( im as u128 ) ) >> 64 ) as u64 ;
79
- let mut v = z . wrapping_sub ( x. wrapping_mul ( m as u64 ) ) as u32 ;
80
- if m <= v {
81
- v = v . wrapping_add ( m ) ;
78
+ match z . overflowing_sub ( x. wrapping_mul ( m as u64 ) ) {
79
+ ( v , true ) => ( v as u32 ) . wrapping_add ( m ) ,
80
+ ( v , false ) => v as u32 ,
82
81
}
83
- v
84
82
}
85
83
86
84
/// # Parameters
@@ -320,6 +318,17 @@ mod tests {
320
318
let b = Barrett :: new ( 2147483647 ) ;
321
319
assert_eq ! ( b. umod( ) , 2147483647 ) ;
322
320
assert_eq ! ( b. mul( 1073741824 , 2147483645 ) , 2147483646 ) ;
321
+
322
+ // test `2^31 < self._m < 2^32` case.
323
+ // https://github.com/rust-lang-ja/ac-library-rs/pull/112
324
+ // https://github.com/atcoder/ac-library/issues/149
325
+ // https://github.com/atcoder/ac-library/pull/163
326
+ let b = Barrett :: new ( 3221225471 ) ;
327
+ assert_eq ! ( b. umod( ) , 3221225471 ) ;
328
+ assert_eq ! ( b. mul( 3188445886 , 2844002853 ) , 1840468257 ) ;
329
+ assert_eq ! ( b. mul( 2834869488 , 2779159607 ) , 2084027561 ) ;
330
+ assert_eq ! ( b. mul( 3032263594 , 3039996727 ) , 2130247251 ) ;
331
+ assert_eq ! ( b. mul( 3029175553 , 3140869278 ) , 1892378237 ) ;
323
332
}
324
333
325
334
#[ test]
0 commit comments