Skip to content

Commit dbbc1bf

Browse files
authored
Merge pull request #1218 from Will-Low/master
Making distributions comparable by deriving PartialEq
2 parents a407bdf + 9f20df0 commit dbbc1bf

22 files changed

+182
-39
lines changed

rand_distr/src/binomial.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use num_traits::Float;
3030
/// let v = bin.sample(&mut rand::thread_rng());
3131
/// println!("{} is from a binomial distribution", v);
3232
/// ```
33-
#[derive(Clone, Copy, Debug)]
33+
#[derive(Clone, Copy, Debug, PartialEq)]
3434
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
3535
pub struct Binomial {
3636
/// Number of trials.
@@ -347,4 +347,9 @@ mod test {
347347
fn test_binomial_invalid_lambda_neg() {
348348
Binomial::new(20, -10.0).unwrap();
349349
}
350+
351+
#[test]
352+
fn binomial_distributions_can_be_compared() {
353+
assert_eq!(Binomial::new(1, 1.0), Binomial::new(1, 1.0));
354+
}
350355
}

rand_distr/src/cauchy.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use core::fmt;
3131
/// let v = cau.sample(&mut rand::thread_rng());
3232
/// println!("{} is from a Cauchy(2, 5) distribution", v);
3333
/// ```
34-
#[derive(Clone, Copy, Debug)]
34+
#[derive(Clone, Copy, Debug, PartialEq)]
3535
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
3636
pub struct Cauchy<F>
3737
where F: Float + FloatConst, Standard: Distribution<F>
@@ -164,4 +164,9 @@ mod test {
164164
assert_almost_eq!(*a, *b, 1e-5);
165165
}
166166
}
167+
168+
#[test]
169+
fn cauchy_distributions_can_be_compared() {
170+
assert_eq!(Cauchy::new(1.0, 2.0), Cauchy::new(1.0, 2.0));
171+
}
167172
}

rand_distr/src/dirichlet.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use alloc::{boxed::Box, vec, vec::Vec};
3232
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
3333
/// ```
3434
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
35-
#[derive(Clone, Debug)]
35+
#[derive(Clone, Debug, PartialEq)]
3636
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
3737
pub struct Dirichlet<F>
3838
where
@@ -183,4 +183,9 @@ mod test {
183183
fn test_dirichlet_invalid_alpha() {
184184
Dirichlet::new_with_size(0.0f64, 2).unwrap();
185185
}
186+
187+
#[test]
188+
fn dirichlet_distributions_can_be_compared() {
189+
assert_eq!(Dirichlet::new(&[1.0, 2.0]), Dirichlet::new(&[1.0, 2.0]));
190+
}
186191
}

rand_distr/src/exponential.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ impl Distribution<f64> for Exp1 {
9191
/// let v = exp.sample(&mut rand::thread_rng());
9292
/// println!("{} is from a Exp(2) distribution", v);
9393
/// ```
94-
#[derive(Clone, Copy, Debug)]
94+
#[derive(Clone, Copy, Debug, PartialEq)]
9595
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
9696
pub struct Exp<F>
9797
where F: Float, Exp1: Distribution<F>
@@ -178,4 +178,9 @@ mod test {
178178
fn test_exp_invalid_lambda_nan() {
179179
Exp::new(f64::nan()).unwrap();
180180
}
181+
182+
#[test]
183+
fn exponential_distributions_can_be_compared() {
184+
assert_eq!(Exp::new(1.0), Exp::new(1.0));
185+
}
181186
}

rand_distr/src/frechet.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use rand::Rng;
2727
/// let val: f64 = thread_rng().sample(Frechet::new(0.0, 1.0, 1.0).unwrap());
2828
/// println!("{}", val);
2929
/// ```
30-
#[derive(Clone, Copy, Debug)]
30+
#[derive(Clone, Copy, Debug, PartialEq)]
3131
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
3232
pub struct Frechet<F>
3333
where
@@ -182,4 +182,9 @@ mod tests {
182182
.zip(&probabilities)
183183
.all(|(p_hat, p)| (p_hat - p).abs() < 0.003))
184184
}
185+
186+
#[test]
187+
fn frechet_distributions_can_be_compared() {
188+
assert_eq!(Frechet::new(1.0, 2.0, 3.0), Frechet::new(1.0, 2.0, 3.0));
189+
}
185190
}

rand_distr/src/gamma.rs

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ use serde::{Serialize, Deserialize};
5454
/// Generating Gamma Variables" *ACM Trans. Math. Softw.* 26, 3
5555
/// (September 2000), 363-372.
5656
/// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414)
57-
#[derive(Clone, Copy, Debug)]
57+
#[derive(Clone, Copy, Debug, PartialEq)]
5858
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
5959
pub struct Gamma<F>
6060
where
@@ -91,7 +91,7 @@ impl fmt::Display for Error {
9191
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
9292
impl std::error::Error for Error {}
9393

94-
#[derive(Clone, Copy, Debug)]
94+
#[derive(Clone, Copy, Debug, PartialEq)]
9595
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
9696
enum GammaRepr<F>
9797
where
@@ -119,7 +119,7 @@ where
119119
///
120120
/// See `Gamma` for sampling from a Gamma distribution with general
121121
/// shape parameters.
122-
#[derive(Clone, Copy, Debug)]
122+
#[derive(Clone, Copy, Debug, PartialEq)]
123123
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
124124
struct GammaSmallShape<F>
125125
where
@@ -135,7 +135,7 @@ where
135135
///
136136
/// See `Gamma` for sampling from a Gamma distribution with general
137137
/// shape parameters.
138-
#[derive(Clone, Copy, Debug)]
138+
#[derive(Clone, Copy, Debug, PartialEq)]
139139
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
140140
struct GammaLargeShape<F>
141141
where
@@ -280,7 +280,7 @@ where
280280
/// let v = chi.sample(&mut rand::thread_rng());
281281
/// println!("{} is from a χ²(11) distribution", v)
282282
/// ```
283-
#[derive(Clone, Copy, Debug)]
283+
#[derive(Clone, Copy, Debug, PartialEq)]
284284
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
285285
pub struct ChiSquared<F>
286286
where
@@ -314,7 +314,7 @@ impl fmt::Display for ChiSquaredError {
314314
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
315315
impl std::error::Error for ChiSquaredError {}
316316

317-
#[derive(Clone, Copy, Debug)]
317+
#[derive(Clone, Copy, Debug, PartialEq)]
318318
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
319319
enum ChiSquaredRepr<F>
320320
where
@@ -385,7 +385,7 @@ where
385385
/// let v = f.sample(&mut rand::thread_rng());
386386
/// println!("{} is from an F(2, 32) distribution", v)
387387
/// ```
388-
#[derive(Clone, Copy, Debug)]
388+
#[derive(Clone, Copy, Debug, PartialEq)]
389389
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
390390
pub struct FisherF<F>
391391
where
@@ -472,7 +472,7 @@ where
472472
/// let v = t.sample(&mut rand::thread_rng());
473473
/// println!("{} is from a t(11) distribution", v)
474474
/// ```
475-
#[derive(Clone, Copy, Debug)]
475+
#[derive(Clone, Copy, Debug, PartialEq)]
476476
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
477477
pub struct StudentT<F>
478478
where
@@ -522,15 +522,15 @@ where
522522
/// Generating beta variates with nonintegral shape parameters.
523523
/// Communications of the ACM 21, 317-322.
524524
/// https://doi.org/10.1145/359460.359482
525-
#[derive(Clone, Copy, Debug)]
525+
#[derive(Clone, Copy, Debug, PartialEq)]
526526
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
527527
enum BetaAlgorithm<N> {
528528
BB(BB<N>),
529529
BC(BC<N>),
530530
}
531531

532532
/// Algorithm BB for `min(alpha, beta) > 1`.
533-
#[derive(Clone, Copy, Debug)]
533+
#[derive(Clone, Copy, Debug, PartialEq)]
534534
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
535535
struct BB<N> {
536536
alpha: N,
@@ -539,7 +539,7 @@ struct BB<N> {
539539
}
540540

541541
/// Algorithm BC for `min(alpha, beta) <= 1`.
542-
#[derive(Clone, Copy, Debug)]
542+
#[derive(Clone, Copy, Debug, PartialEq)]
543543
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
544544
struct BC<N> {
545545
alpha: N,
@@ -560,7 +560,7 @@ struct BC<N> {
560560
/// let v = beta.sample(&mut rand::thread_rng());
561561
/// println!("{} is from a Beta(2, 5) distribution", v);
562562
/// ```
563-
#[derive(Clone, Copy, Debug)]
563+
#[derive(Clone, Copy, Debug, PartialEq)]
564564
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
565565
pub struct Beta<F>
566566
where
@@ -811,4 +811,29 @@ mod test {
811811
assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i);
812812
}
813813
}
814+
815+
#[test]
816+
fn gamma_distributions_can_be_compared() {
817+
assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0));
818+
}
819+
820+
#[test]
821+
fn beta_distributions_can_be_compared() {
822+
assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0));
823+
}
824+
825+
#[test]
826+
fn chi_squared_distributions_can_be_compared() {
827+
assert_eq!(ChiSquared::new(1.0), ChiSquared::new(1.0));
828+
}
829+
830+
#[test]
831+
fn fisher_f_distributions_can_be_compared() {
832+
assert_eq!(FisherF::new(1.0, 2.0), FisherF::new(1.0, 2.0));
833+
}
834+
835+
#[test]
836+
fn student_t_distributions_can_be_compared() {
837+
assert_eq!(StudentT::new(1.0), StudentT::new(1.0));
838+
}
814839
}

rand_distr/src/geometric.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use num_traits::Float;
2727
/// let v = geo.sample(&mut rand::thread_rng());
2828
/// println!("{} is from a Geometric(0.25) distribution", v);
2929
/// ```
30-
#[derive(Copy, Clone, Debug)]
30+
#[derive(Copy, Clone, Debug, PartialEq)]
3131
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
3232
pub struct Geometric
3333
{
@@ -235,4 +235,9 @@ mod test {
235235
results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
236236
assert!((variance - expected_variance).abs() < expected_variance / 10.0);
237237
}
238+
239+
#[test]
240+
fn geometric_distributions_can_be_compared() {
241+
assert_eq!(Geometric::new(1.0), Geometric::new(1.0));
242+
}
238243
}

rand_distr/src/gumbel.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use rand::Rng;
2727
/// let val: f64 = thread_rng().sample(Gumbel::new(0.0, 1.0).unwrap());
2828
/// println!("{}", val);
2929
/// ```
30-
#[derive(Clone, Copy, Debug)]
30+
#[derive(Clone, Copy, Debug, PartialEq)]
3131
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
3232
pub struct Gumbel<F>
3333
where
@@ -152,4 +152,9 @@ mod tests {
152152
.zip(&probabilities)
153153
.all(|(p_hat, p)| (p_hat - p).abs() < 0.003))
154154
}
155+
156+
#[test]
157+
fn gumbel_distributions_can_be_compared() {
158+
assert_eq!(Gumbel::new(1.0, 2.0), Gumbel::new(1.0, 2.0));
159+
}
155160
}

rand_distr/src/hypergeometric.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use core::fmt;
77
#[allow(unused_imports)]
88
use num_traits::Float;
99

10-
#[derive(Clone, Copy, Debug)]
10+
#[derive(Clone, Copy, Debug, PartialEq)]
1111
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
1212
enum SamplingMethod {
1313
InverseTransform{ initial_p: f64, initial_x: i64 },
@@ -45,7 +45,7 @@ enum SamplingMethod {
4545
/// let v = hypergeo.sample(&mut rand::thread_rng());
4646
/// println!("{} is from a hypergeometric distribution", v);
4747
/// ```
48-
#[derive(Copy, Clone, Debug)]
48+
#[derive(Copy, Clone, Debug, PartialEq)]
4949
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
5050
pub struct Hypergeometric {
5151
n1: u64,
@@ -419,4 +419,9 @@ mod test {
419419
test_hypergeometric_mean_and_variance(10100, 10000, 1000, &mut rng);
420420
test_hypergeometric_mean_and_variance(100100, 100, 10000, &mut rng);
421421
}
422+
423+
#[test]
424+
fn hypergeometric_distributions_can_be_compared() {
425+
assert_eq!(Hypergeometric::new(1, 2, 3), Hypergeometric::new(1, 2, 3));
426+
}
422427
}

rand_distr/src/inverse_gaussian.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ impl fmt::Display for Error {
2626
impl std::error::Error for Error {}
2727

2828
/// The [inverse Gaussian distribution](https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution)
29-
#[derive(Debug, Clone, Copy)]
29+
#[derive(Debug, Clone, Copy, PartialEq)]
3030
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
3131
pub struct InverseGaussian<F>
3232
where
@@ -109,4 +109,9 @@ mod tests {
109109
assert!(InverseGaussian::new(1.0, -1.0).is_err());
110110
assert!(InverseGaussian::new(1.0, 1.0).is_ok());
111111
}
112+
113+
#[test]
114+
fn inverse_gaussian_distributions_can_be_compared() {
115+
assert_eq!(InverseGaussian::new(1.0, 2.0), InverseGaussian::new(1.0, 2.0));
116+
}
112117
}

rand_distr/src/normal.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ impl Distribution<f64> for StandardNormal {
112112
/// ```
113113
///
114114
/// [`StandardNormal`]: crate::StandardNormal
115-
#[derive(Clone, Copy, Debug)]
115+
#[derive(Clone, Copy, Debug, PartialEq)]
116116
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
117117
pub struct Normal<F>
118118
where F: Float, StandardNormal: Distribution<F>
@@ -227,7 +227,7 @@ where F: Float, StandardNormal: Distribution<F>
227227
/// let v = log_normal.sample(&mut rand::thread_rng());
228228
/// println!("{} is from an ln N(2, 9) distribution", v)
229229
/// ```
230-
#[derive(Clone, Copy, Debug)]
230+
#[derive(Clone, Copy, Debug, PartialEq)]
231231
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
232232
pub struct LogNormal<F>
233233
where F: Float, StandardNormal: Distribution<F>
@@ -368,4 +368,14 @@ mod tests {
368368
assert!(LogNormal::from_mean_cv(0.0, 1.0).is_err());
369369
assert!(LogNormal::from_mean_cv(1.0, -1.0).is_err());
370370
}
371+
372+
#[test]
373+
fn normal_distributions_can_be_compared() {
374+
assert_eq!(Normal::new(1.0, 2.0), Normal::new(1.0, 2.0));
375+
}
376+
377+
#[test]
378+
fn log_normal_distributions_can_be_compared() {
379+
assert_eq!(LogNormal::new(1.0, 2.0), LogNormal::new(1.0, 2.0));
380+
}
371381
}

rand_distr/src/normal_inverse_gaussian.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ impl fmt::Display for Error {
2626
impl std::error::Error for Error {}
2727

2828
/// The [normal-inverse Gaussian distribution](https://en.wikipedia.org/wiki/Normal-inverse_Gaussian_distribution)
29-
#[derive(Debug, Clone, Copy)]
29+
#[derive(Debug, Clone, Copy, PartialEq)]
3030
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
3131
pub struct NormalInverseGaussian<F>
3232
where
@@ -104,4 +104,9 @@ mod tests {
104104
assert!(NormalInverseGaussian::new(1.0, 2.0).is_err());
105105
assert!(NormalInverseGaussian::new(2.0, 1.0).is_ok());
106106
}
107+
108+
#[test]
109+
fn normal_inverse_gaussian_distributions_can_be_compared() {
110+
assert_eq!(NormalInverseGaussian::new(1.0, 2.0), NormalInverseGaussian::new(1.0, 2.0));
111+
}
107112
}

rand_distr/src/pareto.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use core::fmt;
2323
/// let val: f64 = thread_rng().sample(Pareto::new(1., 2.).unwrap());
2424
/// println!("{}", val);
2525
/// ```
26-
#[derive(Clone, Copy, Debug)]
26+
#[derive(Clone, Copy, Debug, PartialEq)]
2727
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
2828
pub struct Pareto<F>
2929
where F: Float, OpenClosed01: Distribution<F>
@@ -131,4 +131,9 @@ mod tests {
131131
105.8826669383772,
132132
]);
133133
}
134+
135+
#[test]
136+
fn pareto_distributions_can_be_compared() {
137+
assert_eq!(Pareto::new(1.0, 2.0), Pareto::new(1.0, 2.0));
138+
}
134139
}

0 commit comments

Comments
 (0)