diff --git a/rand_distr/src/dirichlet.rs b/rand_distr/src/dirichlet.rs index 526480411df..413c00476ab 100644 --- a/rand_distr/src/dirichlet.rs +++ b/rand_distr/src/dirichlet.rs @@ -9,12 +9,182 @@ //! The dirichlet distribution. #![cfg(feature = "alloc")] -use num_traits::Float; -use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal}; -use rand::Rng; +use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal}; use core::fmt; -#[cfg(feature = "serde_with")] -use serde_with::serde_as; +use num_traits::{Float, NumCast}; +use rand::Rng; +#[cfg(feature = "serde_with")] use serde_with::serde_as; + +use alloc::{boxed::Box, vec, vec::Vec}; + +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde_with", serde_as)] +struct DirichletFromGamma +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + samplers: [Gamma; N], +} + +/// Error type returned from `DirichletFromGamma::new`. +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum DirichletFromGammaError { + /// Gamma::new(a, 1) failed. + GammmaNewFailed, + + /// gamma_dists.try_into() failed (in theory, this should not happen). + GammaArrayCreationFailed, +} + +impl DirichletFromGamma +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Construct a new `DirichletFromGamma` with the given parameters `alpha`. + /// + /// This function is part of a private implementation detail. + /// It assumes that the input is correct, so no validation of alpha is done. 
+ #[inline] + fn new(alpha: [F; N]) -> Result, DirichletFromGammaError> { + let mut gamma_dists = Vec::new(); + for a in alpha { + let dist = + Gamma::new(a, F::one()).map_err(|_| DirichletFromGammaError::GammmaNewFailed)?; + gamma_dists.push(dist); + } + Ok(DirichletFromGamma { + samplers: gamma_dists + .try_into() + .map_err(|_| DirichletFromGammaError::GammaArrayCreationFailed)?, + }) + } +} + +impl Distribution<[F; N]> for DirichletFromGamma +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> [F; N] { + let mut samples = [F::zero(); N]; + let mut sum = F::zero(); + + for (s, g) in samples.iter_mut().zip(self.samplers.iter()) { + *s = g.sample(rng); + sum = sum + *s; + } + let invacc = F::one() / sum; + for s in samples.iter_mut() { + *s = *s * invacc; + } + samples + } +} + +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] +struct DirichletFromBeta +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + samplers: Box<[Beta]>, +} + +/// Error type returned from `DirichletFromBeta::new`. +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum DirichletFromBetaError { + /// Beta::new(a, b) failed. + BetaNewFailed, +} + +impl DirichletFromBeta +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Construct a new `DirichletFromBeta` with the given parameters `alpha`. + /// + /// This function is part of a private implementation detail. + /// It assumes that the input is correct, so no validation of alpha is done. + #[inline] + fn new(alpha: [F; N]) -> Result, DirichletFromBetaError> { + // `alpha_rev_csum` is the reverse of the cumulative sum of the + // reverse of `alpha[1..]`. E.g. if `alpha = [a0, a1, a2, a3]`, then + // `alpha_rev_csum` is `[a1 + a2 + a3, a2 + a3, a3]`. 
+ // Note that instances of DirichletFromBeta will always have N >= 2, + // so the subtractions of 1, 2 and 3 from N in the following are safe. + let mut alpha_rev_csum = vec![alpha[N - 1]; N - 1]; + for k in 0..(N - 2) { + alpha_rev_csum[N - 3 - k] = alpha_rev_csum[N - 2 - k] + alpha[N - 2 - k]; + } + + // Zip `alpha[..(N-1)]` and `alpha_rev_csum`; for the example + // `alpha = [a0, a1, a2, a3]`, the zip result holds the tuples + // `[(a0, a1+a2+a3), (a1, a2+a3), (a2, a3)]`. + // Then pass each tuple to `Beta::new()` to create the `Beta` + // instances. + let mut beta_dists = Vec::new(); + for (&a, &b) in alpha[..(N - 1)].iter().zip(alpha_rev_csum.iter()) { + let dist = Beta::new(a, b).map_err(|_| DirichletFromBetaError::BetaNewFailed)?; + beta_dists.push(dist); + } + Ok(DirichletFromBeta { + samplers: beta_dists.into_boxed_slice(), + }) + } +} + +impl Distribution<[F; N]> for DirichletFromBeta +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> [F; N] { + let mut samples = [F::zero(); N]; + let mut acc = F::one(); + + for (s, beta) in samples.iter_mut().zip(self.samplers.iter()) { + let beta_sample = beta.sample(rng); + *s = acc * beta_sample; + acc = acc * (F::one() - beta_sample); + } + samples[N - 1] = acc; + samples + } +} + +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde_with", serde_as)] +enum DirichletRepr +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Dirichlet distribution that generates samples using the Gamma distribution. + FromGamma(DirichletFromGamma), + + /// Dirichlet distribution that generates samples using the Beta distribution. + FromBeta(DirichletFromBeta), +} /// The Dirichlet distribution `Dirichlet(alpha)`. 
/// @@ -42,9 +212,7 @@ where Exp1: Distribution, Open01: Distribution, { - /// Concentration parameters (alpha) - #[cfg_attr(feature = "serde_with", serde_as(as = "[_; N]"))] - alpha: [F; N], + repr: DirichletRepr, } /// Error type returned from `Dirchlet::new`. @@ -55,6 +223,15 @@ pub enum Error { AlphaTooShort, /// `alpha <= 0.0` or `nan`. AlphaTooSmall, + /// `alpha` is subnormal. + /// Variate generation methods are not reliable with subnormal inputs. + AlphaSubnormal, + /// `alpha` is infinite. + AlphaInfinite, + /// Failed to create required Gamma distribution(s). + FailedToCreateGamma, + /// Failed to create required Beta distribution(s). + FailedToCreateBeta, /// `size < 2`. SizeTooSmall, } @@ -66,6 +243,14 @@ impl fmt::Display for Error { "less than 2 dimensions in Dirichlet distribution" } Error::AlphaTooSmall => "alpha is not positive in Dirichlet distribution", + Error::AlphaSubnormal => "alpha contains a subnormal value in Dirichlet distribution", + Error::AlphaInfinite => "alpha contains an infinite value in Dirichlet distribution", + Error::FailedToCreateGamma => { + "failed to create required Gamma distribution for Dirichlet distribution" + } + Error::FailedToCreateBeta => { + "failed to create required Beta distribition for Dirichlet distribution" + } }) } } @@ -83,7 +268,8 @@ where { /// Construct a new `Dirichlet` with the given alpha parameter `alpha`. /// - /// Requires `alpha.len() >= 2`. + /// Requires `alpha.len() >= 2`, and each value in `alpha` must be positive, + /// finite and not subnormal. #[inline] pub fn new(alpha: [F; N]) -> Result, Error> { if N < 2 { @@ -91,11 +277,32 @@ where } for &ai in alpha.iter() { if !(ai > F::zero()) { + // This also catches nan. 
return Err(Error::AlphaTooSmall); } + if ai.is_infinite() { + return Err(Error::AlphaInfinite); + } + if !ai.is_normal() { + return Err(Error::AlphaSubnormal); + } } - Ok(Dirichlet { alpha }) + if alpha.iter().all(|&x| x <= NumCast::from(0.1).unwrap()) { + // Use the Beta method when all the alphas are at most 0.1. This + // threshold provides a reasonable compromise between using the faster + // Gamma method for as wide a range as possible while ensuring that + // the probability of generating nans is negligibly small. + let dist = DirichletFromBeta::new(alpha).map_err(|_| Error::FailedToCreateBeta)?; + Ok(Dirichlet { + repr: DirichletRepr::FromBeta(dist), + }) + } else { + let dist = DirichletFromGamma::new(alpha).map_err(|_| Error::FailedToCreateGamma)?; + Ok(Dirichlet { + repr: DirichletRepr::FromGamma(dist), + }) + } } } @@ -107,26 +314,17 @@ where Open01: Distribution, { fn sample(&self, rng: &mut R) -> [F; N] { - let mut samples = [F::zero(); N]; - let mut sum = F::zero(); - - for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) { - let g = Gamma::new(a, F::one()).unwrap(); - *s = g.sample(rng); - sum = sum + (*s); + match &self.repr { + DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng), + DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng), } - let invacc = F::one() / sum; - for s in samples.iter_mut() { - *s = (*s)*invacc; - } - samples } } #[cfg(test)] mod test { - use alloc::vec::Vec; use super::*; + use alloc::vec::Vec; #[test] fn test_dirichlet() { @@ -150,12 +348,97 @@ mod test { #[test] #[should_panic] - fn test_dirichlet_invalid_alpha() { + fn test_dirichlet_alpha_zero() { Dirichlet::new([0.1, 0.0, 0.3]).unwrap(); } + #[test] + #[should_panic] + fn test_dirichlet_alpha_negative() { + Dirichlet::new([0.1, -1.5, 0.3]).unwrap(); + } + + #[test] + #[should_panic] + fn test_dirichlet_alpha_nan() { + Dirichlet::new([0.5, f64::NAN, 0.25]).unwrap(); + } + + #[test] + #[should_panic] + fn test_dirichlet_alpha_subnormal() { + 
Dirichlet::new([0.5, 1.5e-321, 0.25]).unwrap(); + } + + #[test] + #[should_panic] + fn test_dirichlet_alpha_inf() { + Dirichlet::new([0.5, f64::INFINITY, 0.25]).unwrap(); + } + #[test] fn dirichlet_distributions_can_be_compared() { assert_eq!(Dirichlet::new([1.0, 2.0]), Dirichlet::new([1.0, 2.0])); } + + /// Check that the means of the components of n samples from + /// the Dirichlet distribution agree with the expected means + /// with a relative tolerance of rtol. + /// + /// This is a crude statistical test, but it will catch egregious + /// mistakes. It will also fail if any samples contain nan. + fn check_dirichlet_means(alpha: [f64; N], n: i32, rtol: f64, seed: u64) { + let d = Dirichlet::new(alpha).unwrap(); + let mut rng = crate::test::rng(seed); + let mut sums = [0.0; N]; + for _ in 0..n { + let samples = d.sample(&mut rng); + for i in 0..N { + sums[i] += samples[i]; + } + } + let sample_mean = sums.map(|x| x / n as f64); + let alpha_sum: f64 = alpha.iter().sum(); + let expected_mean = alpha.map(|x| x / alpha_sum); + for i in 0..N { + assert_almost_eq!(sample_mean[i], expected_mean[i], rtol); + } + } + + #[test] + fn test_dirichlet_means() { + // Check the means of 20000 samples for several different alphas. + let n = 20000; + let rtol = 2e-2; + let seed = 1317624576693539401; + check_dirichlet_means([0.5, 0.25], n, rtol, seed); + check_dirichlet_means([123.0, 75.0], n, rtol, seed); + check_dirichlet_means([2.0, 2.5, 5.0, 7.0], n, rtol, seed); + check_dirichlet_means([0.1, 8.0, 1.0, 2.0, 2.0, 0.85, 0.05, 12.5], n, rtol, seed); + } + + #[test] + fn test_dirichlet_means_very_small_alpha() { + // With values of alpha that are all 0.001, check that the means of the + // components of 10000 samples are within 1% of the expected means. + // With the sampling method based on gamma variates, this test would + // fail, with about 10% of the samples containing nan. 
+ let alpha = [0.001; 3]; + let n = 10000; + let rtol = 1e-2; + let seed = 1317624576693539401; + check_dirichlet_means(alpha, n, rtol, seed); + } + + #[test] + fn test_dirichlet_means_small_alpha() { + // With values of alpha that are all less than 0.1, check that the + // means of the components of 150000 samples are within 0.1% of the + // expected means. + let alpha = [0.05, 0.025, 0.075, 0.05]; + let n = 150000; + let rtol = 1e-3; + let seed = 1317624576693539401; + check_dirichlet_means(alpha, n, rtol, seed); + } } diff --git a/rand_distr/tests/value_stability.rs b/rand_distr/tests/value_stability.rs index 4b9490a6581..88fe7d9ecab 100644 --- a/rand_distr/tests/value_stability.rs +++ b/rand_distr/tests/value_stability.rs @@ -358,6 +358,16 @@ fn dirichlet_stability() { 0.1425623503573967, 0.19815030417417595 ]); + // Test stability for the case where all alphas are less than 0.1. + assert_eq!( + rng.sample(Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap()), + [ + 0.00027580456855692104, + 2.296135759821706e-20, + 3.004118281150937e-9, + 0.9997241924273248 + ] + ); } #[test]