diff --git a/rand_distr/CHANGELOG.md b/rand_distr/CHANGELOG.md index c9ab729be8c..6dbb892f0e2 100644 --- a/rand_distr/CHANGELOG.md +++ b/rand_distr/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Move some of the computations in Binomial from `sample` to `new` (#1484) - Add Kolmogorov Smirnov test for sampling of `Normal` and `Binomial` (#1494) - Add Kolmogorov Smirnov test for more distributions (#1504) +- Add `Distribution` support for `Zeta`, `Zipf` (#1516) ### Added - Add plots for `rand_distr` distributions to documentation (#1434) diff --git a/rand_distr/src/zeta.rs b/rand_distr/src/zeta.rs index 3c2f55546e5..ee1d420092a 100644 --- a/rand_distr/src/zeta.rs +++ b/rand_distr/src/zeta.rs @@ -132,6 +132,14 @@ where } } +impl Distribution for Zeta { + #[inline] + fn sample(&self, rng: &mut R) -> u64 { + // `as` from float to int saturates + as Distribution>::sample(self, rng) as u64 + } +} + #[cfg(test)] mod tests { use super::*; @@ -163,7 +171,7 @@ mod tests { let d = Zeta::new(a).unwrap(); let mut rng = crate::test::rng(1); for _ in 0..1000 { - let r = d.sample(&mut rng); + let r: f64 = d.sample(&mut rng); assert!(r >= 1.); } } @@ -174,7 +182,7 @@ mod tests { let d = Zeta::new(a).unwrap(); let mut rng = crate::test::rng(2); for _ in 0..1000 { - let r = d.sample(&mut rng); + let r: f64 = d.sample(&mut rng); assert!(r >= 1.); } } diff --git a/rand_distr/src/zipf.rs b/rand_distr/src/zipf.rs index 0a56fdf8182..0ee2c8c3863 100644 --- a/rand_distr/src/zipf.rs +++ b/rand_distr/src/zipf.rs @@ -150,6 +150,14 @@ where } } +impl Distribution for Zipf { + #[inline] + fn sample(&self, rng: &mut R) -> u64 { + // `as` from float to int saturates + as Distribution>::sample(self, rng) as u64 + } +} + #[cfg(test)] mod tests { use super::*; @@ -186,7 +194,7 @@ mod tests { let d = Zipf::new(10, 0.5).unwrap(); let mut rng = crate::test::rng(2); for _ in 0..1000 { - let r = d.sample(&mut rng); + let r: f64 = d.sample(&mut rng); assert!(r >= 1.); } } @@ -196,7 +204,7 @@ mod tests { let d = Zipf::new(10, 1.).unwrap(); let mut rng = crate::test::rng(2); for _ in 0..1000 { - let r = d.sample(&mut rng); + let r: f64 = d.sample(&mut rng); assert!(r >= 1.); } } @@ -206,7 +214,7 @@ mod tests { let d = Zipf::new(10, 0.).unwrap(); let mut rng = crate::test::rng(2); for _ in 0..1000 { - let r = d.sample(&mut rng); + let r: f64 = d.sample(&mut rng); assert!(r >= 1.); } // TODO: verify that this is a uniform distribution @@ -217,7 +225,7 @@ mod tests { let d = Zipf::new(u64::MAX, 1.5).unwrap(); let mut rng = crate::test::rng(2); for _ in 0..1000 { - let r = d.sample(&mut rng); + let r: f64 = d.sample(&mut rng); assert!(r >= 1.); } // TODO: verify that this is a zeta distribution diff --git a/rand_distr/tests/cdf.rs b/rand_distr/tests/cdf.rs index 8eb22740e2b..585625202d3 100644 --- a/rand_distr/tests/cdf.rs +++ b/rand_distr/tests/cdf.rs @@ -366,7 +366,8 @@ fn zeta() { for (seed, s) in parameters.into_iter().enumerate() { let dist = rand_distr::Zeta::new(s).unwrap(); - test_discrete(seed as u64, dist, |k| cdf(k, s)); + test_discrete::(seed as u64, dist, |k| cdf(k, s)); + test_discrete::(seed as u64, dist, |k| cdf(k, s)); } } @@ -386,7 +387,8 @@ fn zipf() { for (seed, (n, x)) in parameters.into_iter().enumerate() { let dist = rand_distr::Zipf::new(n, x).unwrap(); - test_discrete(seed as u64, dist, |k| cdf(k, n, x)); + test_discrete::(seed as u64, dist, |k| cdf(k, n, x)); + test_discrete::(seed as u64, dist, |k| cdf(k, n, x)); } }