diff --git a/rand_distr/CHANGELOG.md b/rand_distr/CHANGELOG.md index 99bfc66d7fc..bd85462c2de 100644 --- a/rand_distr/CHANGELOG.md +++ b/rand_distr/CHANGELOG.md @@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] +- New `Beta` sampling algorithm for improved performance and accuracy (#1000) + ## [0.3.0] - 2020-08-25 - Move alias method for `WeightedIndex` from `rand` (#945) - Rename `WeightedIndex` to `WeightedAliasIndex` (#1008) diff --git a/rand_distr/benches/distributions.rs b/rand_distr/benches/distributions.rs index b0cd96dc6ab..6776901d224 100644 --- a/rand_distr/benches/distributions.rs +++ b/rand_distr/benches/distributions.rs @@ -20,7 +20,7 @@ use std::mem::size_of; use test::Bencher; use rand::prelude::*; -use rand_distr::{weighted::WeightedIndex, *}; +use rand_distr::*; // At this time, distributions are optimised for 64-bit platforms. use rand_pcg::Pcg64Mcg; @@ -112,11 +112,15 @@ distr_float!(distr_normal, f64, Normal::new(-1.23, 4.56).unwrap()); distr_float!(distr_log_normal, f64, LogNormal::new(-1.23, 4.56).unwrap()); distr_float!(distr_gamma_large_shape, f64, Gamma::new(10., 1.0).unwrap()); distr_float!(distr_gamma_small_shape, f64, Gamma::new(0.1, 1.0).unwrap()); +distr_float!(distr_beta_small_param, f64, Beta::new(0.1, 0.1).unwrap()); +distr_float!(distr_beta_large_param_similar, f64, Beta::new(101., 95.).unwrap()); +distr_float!(distr_beta_large_param_different, f64, Beta::new(10., 1000.).unwrap()); +distr_float!(distr_beta_mixed_param, f64, Beta::new(0.5, 100.).unwrap()); distr_float!(distr_cauchy, f64, Cauchy::new(4.2, 6.9).unwrap()); distr_float!(distr_triangular, f64, Triangular::new(0., 1., 0.9).unwrap()); distr_int!(distr_binomial, u64, Binomial::new(20, 0.7).unwrap()); distr_int!(distr_binomial_small, u64, Binomial::new(1000000, 1e-30).unwrap()); -distr_int!(distr_poisson, u64, Poisson::new(4.0).unwrap()); +distr_float!(distr_poisson, f64, Poisson::new(4.0).unwrap()); distr!(distr_bernoulli, bool, Bernoulli::new(0.18).unwrap()); distr_arr!(distr_circle, [f64; 2], UnitCircle); distr_arr!(distr_sphere, [f64; 3], UnitSphere); @@ -127,10 +131,10 @@ distr_int!(distr_weighted_u32, usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0, distr_int!(distr_weighted_f64, usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); distr_int!(distr_weighted_large_set, usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap()); -distr_int!(distr_weighted_alias_method_i8, usize, weighted::alias_method::WeightedIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); -distr_int!(distr_weighted_alias_method_u32, usize, weighted::alias_method::WeightedIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); -distr_int!(distr_weighted_alias_method_f64, usize, weighted::alias_method::WeightedIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); -distr_int!(distr_weighted_alias_method_large_set, usize, weighted::alias_method::WeightedIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap()); +distr_int!(distr_weighted_alias_method_i8, usize, WeightedAliasIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); +distr_int!(distr_weighted_alias_method_u32, usize, WeightedAliasIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); +distr_int!(distr_weighted_alias_method_f64, usize, WeightedAliasIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); +distr_int!(distr_weighted_alias_method_large_set, usize, WeightedAliasIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap()); #[bench] diff --git a/rand_distr/src/gamma.rs b/rand_distr/src/gamma.rs index 34cb45dfb36..5e98dbdfcfc 100644 --- a/rand_distr/src/gamma.rs +++ b/rand_distr/src/gamma.rs @@ -495,6 +495,38 @@ where } } +/// The algorithm used for sampling the Beta distribution. +/// +/// Reference: +/// +/// R. C. H. Cheng (1978). +/// Generating beta variates with nonintegral shape parameters. +/// Communications of the ACM 21, 317-322. +/// https://doi.org/10.1145/359460.359482 +#[derive(Clone, Copy, Debug)] +enum BetaAlgorithm { + BB(BB), + BC(BC), +} + +/// Algorithm BB for `min(alpha, beta) > 1`. +#[derive(Clone, Copy, Debug)] +struct BB { + alpha: N, + beta: N, + gamma: N, +} + +/// Algorithm BC for `min(alpha, beta) <= 1`. +#[derive(Clone, Copy, Debug)] +struct BC { + alpha: N, + beta: N, + delta: N, + kappa1: N, + kappa2: N, +} + /// The Beta distribution with shape parameters `alpha` and `beta`. /// /// # Example @@ -510,12 +542,10 @@ where pub struct Beta where F: Float, - StandardNormal: Distribution, - Exp1: Distribution, Open01: Distribution, { - gamma_a: Gamma, - gamma_b: Gamma, + a: F, b: F, switched_params: bool, + algorithm: BetaAlgorithm, } /// Error type returned from `Beta::new`. @@ -542,31 +572,142 @@ impl std::error::Error for BetaError {} impl Beta where F: Float, - StandardNormal: Distribution, - Exp1: Distribution, Open01: Distribution, { /// Construct an object representing the `Beta(alpha, beta)` /// distribution. pub fn new(alpha: F, beta: F) -> Result, BetaError> { - Ok(Beta { - gamma_a: Gamma::new(alpha, F::one()).map_err(|_| BetaError::AlphaTooSmall)?, - gamma_b: Gamma::new(beta, F::one()).map_err(|_| BetaError::BetaTooSmall)?, - }) + if !(alpha > F::zero()) { + return Err(BetaError::AlphaTooSmall); + } + if !(beta > F::zero()) { + return Err(BetaError::BetaTooSmall); + } + // From now on, we use the notation from the reference, + // i.e. `alpha` and `beta` are renamed to `a0` and `b0`. + let (a0, b0) = (alpha, beta); + let (a, b, switched_params) = if a0 < b0 { + (a0, b0, false) + } else { + (b0, a0, true) + }; + if a > F::one() { + // Algorithm BB + let alpha = a + b; + let beta = ((alpha - F::from(2.).unwrap()) + / (F::from(2.).unwrap()*a*b - alpha)).sqrt(); + let gamma = a + F::one() / beta; + + Ok(Beta { + a, b, switched_params, + algorithm: BetaAlgorithm::BB(BB { + alpha, beta, gamma, + }) + }) + } else { + // Algorithm BC + // + // Here `a` is the maximum instead of the minimum. + let (a, b, switched_params) = (b, a, !switched_params); + let alpha = a + b; + let beta = F::one() / b; + let delta = F::one() + a - b; + let kappa1 = delta + * (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap()*b) + / (a*beta - F::from(14. / 18.).unwrap()); + let kappa2 = F::from(0.25).unwrap() + + (F::from(0.5).unwrap() + F::from(0.25).unwrap()/delta)*b; + + Ok(Beta { + a, b, switched_params, + algorithm: BetaAlgorithm::BC(BC { + alpha, beta, delta, kappa1, kappa2, + }) + }) + } } } impl Distribution for Beta where F: Float, - StandardNormal: Distribution, - Exp1: Distribution, Open01: Distribution, { fn sample(&self, rng: &mut R) -> F { - let x = self.gamma_a.sample(rng); - let y = self.gamma_b.sample(rng); - x / (x + y) + let mut w; + match self.algorithm { + BetaAlgorithm::BB(algo) => { + loop { + // 1. + let u1 = rng.sample(Open01); + let u2 = rng.sample(Open01); + let v = algo.beta * (u1 / (F::one() - u1)).ln(); + w = self.a * v.exp(); + let z = u1*u1 * u2; + let r = algo.gamma * v - F::from(4.).unwrap().ln(); + let s = self.a + r - w; + // 2. + if s + F::one() + F::from(5.).unwrap().ln() + >= F::from(5.).unwrap() * z { + break; + } + // 3. + let t = z.ln(); + if s >= t { + break; + } + // 4. + if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) { + break; + } + } + }, + BetaAlgorithm::BC(algo) => { + loop { + let z; + // 1. + let u1 = rng.sample(Open01); + let u2 = rng.sample(Open01); + if u1 < F::from(0.5).unwrap() { + // 2. + let y = u1 * u2; + z = u1 * y; + if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 { + continue; + } + } else { + // 3. + z = u1 * u1 * u2; + if z <= F::from(0.25).unwrap() { + let v = algo.beta * (u1 / (F::one() - u1)).ln(); + w = self.a * v.exp(); + break; + } + // 4. + if z >= algo.kappa2 { + continue; + } + } + // 5. + let v = algo.beta * (u1 / (F::one() - u1)).ln(); + w = self.a * v.exp(); + if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v) + - F::from(4.).unwrap().ln() < z.ln()) { + break; + }; + } + }, + }; + // 5. for BB, 6. for BC + if !self.switched_params { + if w == F::infinity() { + // Assuming `b` is finite, for large `w`: + return F::one(); + } + w / (self.b + w) + } else { + self.b / (self.b + w) + } } } @@ -636,4 +777,13 @@ mod test { fn test_beta_invalid_dof() { Beta::new(0., 0.).unwrap(); } + + #[test] + fn test_beta_small_param() { + let beta = Beta::::new(1e-3, 1e-3).unwrap(); + let mut rng = crate::test::rng(206); + for i in 0..1000 { + assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i); + } + } } diff --git a/rand_distr/tests/value_stability.rs b/rand_distr/tests/value_stability.rs index 192ba748b7f..7d7316096ca 100644 --- a/rand_distr/tests/value_stability.rs +++ b/rand_distr/tests/value_stability.rs @@ -121,11 +121,11 @@ fn normal_inverse_gaussian_stability() { fn pert_stability() { // mean = 4, var = 12/7 test_samples(860, Pert::new(2., 10., 3.).unwrap(), &[ - 4.631484136029422f64, - 3.307201472321789f64, - 3.29995019556348f64, - 3.66835483991721f64, - 3.514246139933899f64, + 4.908681667460367, + 4.014196196158352, + 2.6489397149197234, + 3.4569780580044727, + 4.242864311947118, ]); } @@ -200,15 +200,21 @@ fn gamma_stability() { -2.377641221169782, ]); - // Beta has same special cases as Gamma on each param + // Beta has two special cases: + // + // 1. min(alpha, beta) <= 1 + // 2. min(alpha, beta) > 1 test_samples(223, Beta::new(1.0, 0.8).unwrap(), &[ - 0.6444564f32, 0.357635, 0.4110078, 0.7347192, - ]); - test_samples(223, Beta::new(0.7, 1.2).unwrap(), &[ - 0.6433129944095513f64, - 0.5373371199711573, - 0.10313293199269491, - 0.002472280249144378, + 0.8300703726659456, + 0.8134131062097899, + 0.47912589330631555, + 0.25323238071138526, + ]); + test_samples(223, Beta::new(3.0, 1.2).unwrap(), &[ + 0.49563509121756827, + 0.9551305482256759, + 0.5151181353461637, + 0.7551732971235077, ]); } diff --git a/utils/ci/script.sh b/utils/ci/script.sh index efefc2adc5e..caef0767ca9 100644 --- a/utils/ci/script.sh +++ b/utils/ci/script.sh @@ -29,6 +29,7 @@ main() { if [ "0$NIGHTLY" -ge 1 ]; then $CARGO test $TARGET --all-features $CARGO test $TARGET --benches --features=nightly + $CARGO test $TARGET --manifest-path rand_distr/Cargo.toml --benches else # all stable features: $CARGO test $TARGET --features=serde1,log,small_rng