From 4d7387d45290c19f8b9a931b66cb29d9c7a67eaa Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Sun, 2 Mar 2025 10:19:15 +0100 Subject: [PATCH 01/16] Multidistribution trait --- src/multi/mod.rs | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 src/multi/mod.rs diff --git a/src/multi/mod.rs b/src/multi/mod.rs new file mode 100644 index 0000000..3d2f69d --- /dev/null +++ b/src/multi/mod.rs @@ -0,0 +1,8 @@ +//! Contains Multi-dimensional distributions. +//! +//! We provide a trait `MultiDistribution` which allows to sample from a multi-dimensional distribution without extra allocations. +//! All multi-dimensional distributions should implement this trait addidionally to the `Distribution` trait returning a `Vec` of samples. + +pub trait MultiDistribution { + fn sample(&self, rng: &mut R, output: &mut S); +} \ No newline at end of file From 0048240152c74aa2c2e723135b3dabc350baf838 Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Sun, 2 Mar 2025 10:22:17 +0100 Subject: [PATCH 02/16] documentation for MultiDistr --- src/multi/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/multi/mod.rs b/src/multi/mod.rs index 3d2f69d..4020785 100644 --- a/src/multi/mod.rs +++ b/src/multi/mod.rs @@ -3,6 +3,8 @@ //! We provide a trait `MultiDistribution` which allows to sample from a multi-dimensional distribution without extra allocations. //! All multi-dimensional distributions should implement this trait addidionally to the `Distribution` trait returning a `Vec` of samples. +/// This trait allows to sample from a multi-dimensional distribution without extra allocations. +/// Typically distributions will implement `MultiDistribution<[F]>` where `F` is the type of the samples. pub trait MultiDistribution { fn sample(&self, rng: &mut R, output: &mut S); } \ No newline at end of file From 5c9661781fb0c14e22bd3a1d6bb22ebac07f415d Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Sun, 2 Mar 2025 10:25:33 +0100 Subject: [PATCH 03/16] add to lib.rs and more doc --- src/lib.rs | 4 ++++ src/multi/mod.rs | 12 ++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index ef1109b..8ece5a4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -96,6 +96,8 @@ pub use rand::distr::{ StandardUniform, Uniform, }; +pub use multi::MultiDistribution; + pub use self::beta::{Beta, Error as BetaError}; pub use self::binomial::{Binomial, Error as BinomialError}; pub use self::cauchy::{Cauchy, Error as CauchyError}; @@ -133,6 +135,8 @@ pub use num_traits; #[cfg(feature = "alloc")] pub mod weighted; +pub mod multi; + #[cfg(test)] #[macro_use] mod test { diff --git a/src/multi/mod.rs b/src/multi/mod.rs index 4020785..44e5f87 100644 --- a/src/multi/mod.rs +++ b/src/multi/mod.rs @@ -1,10 +1,14 @@ //! Contains Multi-dimensional distributions. -//! +//! //! We provide a trait `MultiDistribution` which allows to sample from a multi-dimensional distribution without extra allocations. //! All multi-dimensional distributions should implement this trait addidionally to the `Distribution` trait returning a `Vec` of samples. +use rand::Rng; + /// This trait allows to sample from a multi-dimensional distribution without extra allocations. /// Typically distributions will implement `MultiDistribution<[F]>` where `F` is the type of the samples. -pub trait MultiDistribution { - fn sample(&self, rng: &mut R, output: &mut S); -} \ No newline at end of file +pub trait MultiDistribution { + /// Sample from the distribution using the given random number generator and write the result to `output`. + /// The method panics if the buffer is too small to hold the samples. + fn sample(&self, rng: &mut R, output: &mut S); +} From 6a5bd806d8a3e227127dee3d92025e733e85f1d7 Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Sun, 2 Mar 2025 10:36:28 +0100 Subject: [PATCH 04/16] better doc --- src/multi/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multi/mod.rs b/src/multi/mod.rs index 44e5f87..f979ce6 100644 --- a/src/multi/mod.rs +++ b/src/multi/mod.rs @@ -9,6 +9,6 @@ use rand::Rng; /// Typically distributions will implement `MultiDistribution<[F]>` where `F` is the type of the samples. pub trait MultiDistribution { /// Sample from the distribution using the given random number generator and write the result to `output`. - /// The method panics if the buffer is too small to hold the samples. + /// The method should panic if the buffer is too small to hold the samples. fn sample(&self, rng: &mut R, output: &mut S); } From c9df79cbc8e2ac5dbfeb67a1cd40b1a5672f57da Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Sun, 2 Mar 2025 16:46:56 +0100 Subject: [PATCH 05/16] remove pub use of MultiDistribution --- src/dirichlet.rs | 13 +++++++++++++ src/lib.rs | 2 -- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/dirichlet.rs b/src/dirichlet.rs index ac17fa2..54f50e6 100644 --- a/src/dirichlet.rs +++ b/src/dirichlet.rs @@ -330,6 +330,19 @@ where } } +impl crate::multi::MultiDistribution<[F]> for Dirichlet +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R, output: &mut [F]) { + let samples = Distribution::sample(self, rng); + output.copy_from_slice(&samples); + } +} + #[cfg(test)] mod test { use super::*; diff --git a/src/lib.rs b/src/lib.rs index 8ece5a4..a62044f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -96,8 +96,6 @@ pub use rand::distr::{ StandardUniform, Uniform, }; -pub use multi::MultiDistribution; - pub use self::beta::{Beta, Error as BetaError}; pub use self::binomial::{Binomial, Error as BinomialError}; pub use self::cauchy::{Cauchy, Error as CauchyError}; From e61da6972711a085badeadb1ce737f3fdb50951e Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Sun, 2 Mar 2025 16:48:32 +0100 Subject: [PATCH 06/16] remove test impl of MultiDistribution again --- src/dirichlet.rs | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/dirichlet.rs b/src/dirichlet.rs index 54f50e6..cd4ffe4 100644 --- a/src/dirichlet.rs +++ b/src/dirichlet.rs @@ -330,18 +330,6 @@ where } } -impl crate::multi::MultiDistribution<[F]> for Dirichlet -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R, output: &mut [F]) { - let samples = Distribution::sample(self, rng); - output.copy_from_slice(&samples); - } -} #[cfg(test)] mod test { From 4a78dbc277ed554108d60936402ec892a36bcf65 Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Sun, 2 Mar 2025 16:51:08 +0100 Subject: [PATCH 07/16] fmt --- src/dirichlet.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dirichlet.rs b/src/dirichlet.rs index cd4ffe4..ac17fa2 100644 --- a/src/dirichlet.rs +++ b/src/dirichlet.rs @@ -330,7 +330,6 @@ where } } - #[cfg(test)] mod test { use super::*; From cb4c824d5b5bb49e196bcc0545ee5a7d7c9627fb Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Mon, 3 Mar 2025 22:25:02 +0100 Subject: [PATCH 08/16] move dirichlet --- src/lib.rs | 3 --- src/{ => multi}/dirichlet.rs | 2 +- src/multi/mod.rs | 4 ++++ tests/value_stability.rs | 6 +++--- 4 files changed, 8 insertions(+), 7 deletions(-) rename src/{ => multi}/dirichlet.rs (99%) diff --git a/src/lib.rs b/src/lib.rs index a62044f..f3f4e3b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -100,8 +100,6 @@ pub use self::beta::{Beta, Error as BetaError}; pub use self::binomial::{Binomial, Error as BinomialError}; pub use self::cauchy::{Cauchy, Error as CauchyError}; pub use self::chi_squared::{ChiSquared, Error as ChiSquaredError}; -#[cfg(feature = "alloc")] -pub use self::dirichlet::{Dirichlet, Error as DirichletError}; pub use self::exponential::{Error as ExpError, Exp, Exp1}; pub use self::fisher_f::{Error as FisherFError, FisherF}; pub use self::frechet::{Error as FrechetError, Frechet}; @@ -190,7 +188,6 @@ mod beta; mod binomial; mod cauchy; mod chi_squared; -mod dirichlet; mod exponential; mod fisher_f; mod frechet; diff --git a/src/dirichlet.rs b/src/multi/dirichlet.rs similarity index 99% rename from src/dirichlet.rs rename to src/multi/dirichlet.rs index ac17fa2..60999db 100644 --- a/src/dirichlet.rs +++ b/src/multi/dirichlet.rs @@ -403,7 +403,7 @@ mod test { let alpha_sum: f64 = alpha.iter().sum(); let expected_mean = alpha.map(|x| x / alpha_sum); for i in 0..N { - assert_almost_eq!(sample_mean[i], expected_mean[i], rtol); + average::assert_almost_eq!(sample_mean[i], expected_mean[i], rtol); } } diff --git a/src/multi/mod.rs b/src/multi/mod.rs index f979ce6..531c510 100644 --- a/src/multi/mod.rs +++ b/src/multi/mod.rs @@ -12,3 +12,7 @@ pub trait MultiDistribution { /// The method should panic if the buffer is too small to hold the samples. fn sample(&self, rng: &mut R, output: &mut S); } + +pub use dirichlet::Dirichlet; + +mod dirichlet; \ No newline at end of file diff --git a/tests/value_stability.rs b/tests/value_stability.rs index 2eb263e..9cabd83 100644 --- a/tests/value_stability.rs +++ b/tests/value_stability.rs @@ -502,11 +502,11 @@ fn weibull_stability() { fn dirichlet_stability() { let mut rng = get_rng(223); assert_eq!( - rng.sample(Dirichlet::new([1.0, 2.0, 3.0]).unwrap()), + rng.sample(multi::Dirichlet::new([1.0, 2.0, 3.0]).unwrap()), [0.12941567177708177, 0.4702121891675036, 0.4003721390554146] ); assert_eq!( - rng.sample(Dirichlet::new([8.0; 5]).unwrap()), + rng.sample(multi::Dirichlet::new([8.0; 5]).unwrap()), [ 0.17684200044809556, 0.29915953935953055, @@ -517,7 +517,7 @@ fn dirichlet_stability() { ); // Test stability for the case where all alphas are less than 0.1. assert_eq!( - rng.sample(Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap()), + rng.sample(multi::Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap()), [ 0.00027580456855692104, 2.296135759821706e-20, From 7572ce3cb63e5a233264a4374c4cab904b50fa69 Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Mon, 3 Mar 2025 22:31:56 +0100 Subject: [PATCH 09/16] MultiDistribution in Dirichlet, still const generics --- src/multi/dirichlet.rs | 49 ++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/src/multi/dirichlet.rs b/src/multi/dirichlet.rs index 60999db..3c55404 100644 --- a/src/multi/dirichlet.rs +++ b/src/multi/dirichlet.rs @@ -10,7 +10,7 @@ //! The dirichlet distribution `Dirichlet(α₁, α₂, ..., αₙ)`. #![cfg(feature = "alloc")] -use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal}; +use crate::{multi::MultiDistribution, Beta, Distribution, Exp1, Gamma, Open01, StandardNormal}; use core::fmt; use num_traits::{Float, NumCast}; use rand::Rng; @@ -68,26 +68,24 @@ where } } -impl Distribution<[F; N]> for DirichletFromGamma +impl MultiDistribution<[F; N]> for DirichletFromGamma where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - fn sample(&self, rng: &mut R) -> [F; N] { - let mut samples = [F::zero(); N]; + fn sample(&self, rng: &mut R, output: &mut [F; N]) { let mut sum = F::zero(); - for (s, g) in samples.iter_mut().zip(self.samplers.iter()) { + for (s, g) in output.iter_mut().zip(self.samplers.iter()) { *s = g.sample(rng); sum = sum + *s; } let invacc = F::one() / sum; - for s in samples.iter_mut() { + for s in output.iter_mut() { *s = *s * invacc; } - samples } } @@ -149,24 +147,22 @@ where } } -impl Distribution<[F; N]> for DirichletFromBeta +impl MultiDistribution<[F; N]> for DirichletFromBeta where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - fn sample(&self, rng: &mut R) -> [F; N] { - let mut samples = [F::zero(); N]; + fn sample(&self, rng: &mut R, output: &mut [F; N]) { let mut acc = F::one(); - for (s, beta) in samples.iter_mut().zip(self.samplers.iter()) { + for (s, beta) in output.iter_mut().zip(self.samplers.iter()) { let beta_sample = beta.sample(rng); *s = acc * beta_sample; acc = acc * (F::one() - beta_sample); } - samples[N - 1] = acc; - samples + output[N - 1] = acc; } } @@ -315,21 +311,36 @@ where } } -impl Distribution<[F; N]> for Dirichlet +impl MultiDistribution<[F; N]> for Dirichlet where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - fn sample(&self, rng: &mut R) -> [F; N] { + fn sample(&self, rng: &mut R, output: &mut [F; N]) { match &self.repr { - DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng), - DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng), + DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng, output), + DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng, output), } } } +impl Distribution<[F; N]> for Dirichlet +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, + Dirichlet: MultiDistribution<[F; N]>, +{ + fn sample(&self, rng: &mut R) -> [F; N] { + let mut output = [F::zero(); N]; + MultiDistribution::sample(self, rng, &mut output); + output + } +} + #[cfg(test)] mod test { use super::*; @@ -338,7 +349,7 @@ mod test { fn test_dirichlet() { let d = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); let mut rng = crate::test::rng(221); - let samples = d.sample(&mut rng); + let samples = Distribution::sample(&d, &mut rng); assert!(samples.into_iter().all(|x: f64| x > 0.0)); } @@ -394,7 +405,7 @@ mod test { let mut rng = crate::test::rng(seed); let mut sums = [0.0; N]; for _ in 0..n { - let samples = d.sample(&mut rng); + let samples = Distribution::sample(&d, &mut rng); for i in 0..N { sums[i] += samples[i]; } From b1e663d30908e0c90910d530c1d0eeb930ecbfdb Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Thu, 6 Mar 2025 17:18:46 +0100 Subject: [PATCH 10/16] new MultiDistribution still const gen Dirichlet --- src/lib.rs | 2 +- src/multi/dirichlet.rs | 44 ++++++++++++++++++------------------------ src/multi/mod.rs | 26 ++++++++++++++++++------- 3 files changed, 39 insertions(+), 33 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index f3f4e3b..522aff3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -130,7 +130,7 @@ pub use num_traits; #[cfg(feature = "alloc")] pub mod weighted; - +#[cfg(feature = "alloc")] pub mod multi; #[cfg(test)] diff --git a/src/multi/dirichlet.rs b/src/multi/dirichlet.rs index 3c55404..38e5b09 100644 --- a/src/multi/dirichlet.rs +++ b/src/multi/dirichlet.rs @@ -68,14 +68,17 @@ where } } -impl MultiDistribution<[F; N]> for DirichletFromGamma +impl MultiDistribution for DirichletFromGamma where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - fn sample(&self, rng: &mut R, output: &mut [F; N]) { + fn sample_len(&self) -> usize { + N + } + fn sample_to_buf(&self, rng: &mut R, output: &mut [F]) { let mut sum = F::zero(); for (s, g) in output.iter_mut().zip(self.samplers.iter()) { @@ -147,14 +150,17 @@ where } } -impl MultiDistribution<[F; N]> for DirichletFromBeta +impl MultiDistribution for DirichletFromBeta where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - fn sample(&self, rng: &mut R, output: &mut [F; N]) { + fn sample_len(&self) -> usize { + N + } + fn sample_to_buf(&self, rng: &mut R, output: &mut [F]) { let mut acc = F::one(); for (s, beta) in output.iter_mut().zip(self.samplers.iter()) { @@ -311,36 +317,24 @@ where } } -impl MultiDistribution<[F; N]> for Dirichlet +impl MultiDistribution for Dirichlet where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - fn sample(&self, rng: &mut R, output: &mut [F; N]) { + fn sample_len(&self) -> usize { + N + } + fn sample_to_buf(&self, rng: &mut R, output: &mut [F]) { match &self.repr { - DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng, output), - DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng, output), + DirichletRepr::FromGamma(dirichlet) => dirichlet.sample_to_buf(rng, output), + DirichletRepr::FromBeta(dirichlet) => dirichlet.sample_to_buf(rng, output), } } } -impl Distribution<[F; N]> for Dirichlet -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, - Dirichlet: MultiDistribution<[F; N]>, -{ - fn sample(&self, rng: &mut R) -> [F; N] { - let mut output = [F::zero(); N]; - MultiDistribution::sample(self, rng, &mut output); - output - } -} - #[cfg(test)] mod test { use super::*; @@ -349,7 +343,7 @@ mod test { fn test_dirichlet() { let d = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); let mut rng = crate::test::rng(221); - let samples = Distribution::sample(&d, &mut rng); + let samples = d.sample(&mut rng); assert!(samples.into_iter().all(|x: f64| x > 0.0)); } @@ -405,7 +399,7 @@ mod test { let mut rng = crate::test::rng(seed); let mut sums = [0.0; N]; for _ in 0..n { - let samples = Distribution::sample(&d, &mut rng); + let samples = d.sample(&mut rng); for i in 0..N { sums[i] += samples[i]; } diff --git a/src/multi/mod.rs b/src/multi/mod.rs index 531c510..b191440 100644 --- a/src/multi/mod.rs +++ b/src/multi/mod.rs @@ -1,18 +1,30 @@ //! Contains Multi-dimensional distributions. //! //! We provide a trait `MultiDistribution` which allows to sample from a multi-dimensional distribution without extra allocations. -//! All multi-dimensional distributions should implement this trait addidionally to the `Distribution` trait returning a `Vec` of samples. +//! All multi-dimensional distributions implement `MultiDistribution` instead of the `Distribution` trait. +use alloc::vec::Vec; use rand::Rng; /// This trait allows to sample from a multi-dimensional distribution without extra allocations. -/// Typically distributions will implement `MultiDistribution<[F]>` where `F` is the type of the samples. -pub trait MultiDistribution { - /// Sample from the distribution using the given random number generator and write the result to `output`. - /// The method should panic if the buffer is too small to hold the samples. - fn sample(&self, rng: &mut R, output: &mut S); +/// For convenience it also provides a `sample` method which returns the result as a `Vec`. +pub trait MultiDistribution { + /// returns the length of one sample (dimension of the distribution) + fn sample_len(&self) -> usize; + /// samples from the distribution and writes the result to `buf` + fn sample_to_buf(&self, rng: &mut R, buf: &mut [T]); + /// samples from the distribution and returns the result as a `Vec`, to avoid extra allocations use `sample_to_buf` + fn sample(&self, rng: &mut R) -> Vec + where + T: Default, + { + let mut buf = Vec::new(); + buf.resize_with(self.sample_len(), || T::default()); + self.sample_to_buf(rng, &mut buf); + buf + } } pub use dirichlet::Dirichlet; -mod dirichlet; \ No newline at end of file +mod dirichlet; From b53aeda7fc90967f4c077594064ba4d429c3f714 Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Thu, 6 Mar 2025 17:20:07 +0100 Subject: [PATCH 11/16] fmt --- src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 522aff3..7227c1a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -128,10 +128,10 @@ pub use student_t::StudentT; pub use num_traits; -#[cfg(feature = "alloc")] -pub mod weighted; #[cfg(feature = "alloc")] pub mod multi; +#[cfg(feature = "alloc")] +pub mod weighted; #[cfg(test)] #[macro_use] From 2435d3f189583966a4aa5a4d3d515cf71670d3cf Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Thu, 6 Mar 2025 17:20:44 +0100 Subject: [PATCH 12/16] doc --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 7227c1a..057d995 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -72,7 +72,7 @@ //! - [`Beta`] distribution //! - [`Triangular`] distribution //! - Multivariate probability distributions -//! - [`Dirichlet`] distribution +//! - [`multi::Dirichlet`] distribution //! - [`UnitSphere`] distribution //! - [`UnitBall`] distribution //! - [`UnitCircle`] distribution From c4acb6d865c5d4cd416f5ca4206647de3455ba6e Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Thu, 6 Mar 2025 17:22:51 +0100 Subject: [PATCH 13/16] dirichlet usage --- tests/value_stability.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/value_stability.rs b/tests/value_stability.rs index 9cabd83..8563e54 100644 --- a/tests/value_stability.rs +++ b/tests/value_stability.rs @@ -500,13 +500,15 @@ fn weibull_stability() { #[cfg(feature = "alloc")] #[test] fn dirichlet_stability() { + use rand_distr::multi::MultiDistribution; + let mut rng = get_rng(223); assert_eq!( - rng.sample(multi::Dirichlet::new([1.0, 2.0, 3.0]).unwrap()), + multi::Dirichlet::new([1.0, 2.0, 3.0]).unwrap().sample(&mut rng), [0.12941567177708177, 0.4702121891675036, 0.4003721390554146] ); assert_eq!( - rng.sample(multi::Dirichlet::new([8.0; 5]).unwrap()), + multi::Dirichlet::new([8.0; 5]).unwrap().sample(&mut rng), [ 0.17684200044809556, 0.29915953935953055, @@ -517,7 +519,7 @@ fn dirichlet_stability() { ); // Test stability for the case where all alphas are less than 0.1. assert_eq!( - rng.sample(multi::Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap()), + multi::Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap().sample(&mut rng), [ 0.00027580456855692104, 2.296135759821706e-20, From 860897cf81d66ff28a704e8d1d6106a81d3ddb5e Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Thu, 6 Mar 2025 17:23:07 +0100 Subject: [PATCH 14/16] fmt --- tests/value_stability.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/value_stability.rs b/tests/value_stability.rs index 8563e54..002b8b4 100644 --- a/tests/value_stability.rs +++ b/tests/value_stability.rs @@ -504,7 +504,9 @@ fn dirichlet_stability() { let mut rng = get_rng(223); assert_eq!( - multi::Dirichlet::new([1.0, 2.0, 3.0]).unwrap().sample(&mut rng), + multi::Dirichlet::new([1.0, 2.0, 3.0]) + .unwrap() + .sample(&mut rng), [0.12941567177708177, 0.4702121891675036, 0.4003721390554146] ); assert_eq!( @@ -519,7 +521,9 @@ fn dirichlet_stability() { ); // Test stability for the case where all alphas are less than 0.1. assert_eq!( - multi::Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap().sample(&mut rng), + multi::Dirichlet::new([0.05, 0.025, 0.075, 0.05]) + .unwrap() + .sample(&mut rng), [ 0.00027580456855692104, 2.296135759821706e-20, From 3219cd6366a23254db1b6a9ba93ddceac155da56 Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Thu, 6 Mar 2025 17:26:56 +0100 Subject: [PATCH 15/16] doctest --- src/multi/dirichlet.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/multi/dirichlet.rs b/src/multi/dirichlet.rs index 38e5b09..f3fa0e6 100644 --- a/src/multi/dirichlet.rs +++ b/src/multi/dirichlet.rs @@ -210,7 +210,8 @@ where /// /// ``` /// use rand::prelude::*; -/// use rand_distr::Dirichlet; +/// use rand_distr::multi::Dirichlet; +/// use rand_distr::multi::MultiDistribution; /// /// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); /// let samples = dirichlet.sample(&mut rand::rng()); From e74b4b4e2c08ee632c69a152aa7c4a515ec1accd Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Thu, 6 Mar 2025 17:42:58 +0100 Subject: [PATCH 16/16] typo --- src/multi/dirichlet.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multi/dirichlet.rs b/src/multi/dirichlet.rs index f3fa0e6..558f64e 100644 --- a/src/multi/dirichlet.rs +++ b/src/multi/dirichlet.rs @@ -262,7 +262,7 @@ impl fmt::Display for Error { "failed to create required Gamma distribution for Dirichlet distribution" } Error::FailedToCreateBeta => { - "failed to create required Beta distribition for Dirichlet distribution" + "failed to create required Beta distribution for Dirichlet distribution" } }) }