Skip to content

Commit f7873d0

Browse files
rand_distr: Fix dirichlet sample method for small alpha.
Generating Dirichlet samples uing the method based on samples from the gamma distribution can result in samples being nan if all the values in alpha are sufficiently small. The fix is to instead use the method based on the marginal distributions being the beta distribution (i.e. the "stick breaking" method) when all values in alpha are small.
1 parent 19404d6 commit f7873d0

File tree

1 file changed

+120
-11
lines changed

1 file changed

+120
-11
lines changed

rand_distr/src/dirichlet.rs

Lines changed: 120 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
//! The dirichlet distribution.
1111
#![cfg(feature = "alloc")]
12-
use num_traits::Float;
13-
use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal};
12+
use num_traits::{Float, NumCast};
13+
use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal};
1414
use rand::Rng;
1515
use core::fmt;
1616
use alloc::{boxed::Box, vec, vec::Vec};
@@ -123,16 +123,56 @@ where
123123
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<F> {
124124
let n = self.alpha.len();
125125
let mut samples = vec![F::zero(); n];
126-
let mut sum = F::zero();
127126

128-
for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) {
129-
let g = Gamma::new(a, F::one()).unwrap();
130-
*s = g.sample(rng);
131-
sum = sum + (*s);
132-
}
133-
let invacc = F::one() / sum;
134-
for s in samples.iter_mut() {
135-
*s = (*s)*invacc;
127+
if self.alpha.iter().all(|x| *x <= NumCast::from(0.1).unwrap()) {
128+
// All the values in alpha are less than 0.1.
129+
//
130+
// When all the alpha parameters are sufficiently small, there
131+
// is a nontrivial probability that the samples from the gamma
132+
// distributions used in the other method will all be 0, which
133+
// results in the dirichlet samples being nan. So instead of
134+
// use that method, use the "stick breaking" method based on the
135+
// marginal beta distributions.
136+
//
137+
// Form the right-to-left cumulative sum of alpha, exluding the
138+
// first element of alpha. E.g. if alpha = [a0, a1, a2, a3], then
139+
// after the call to `alpha_sum_rl.reverse()` below, alpha_sum_rl
140+
// will hold [a1+a2+a3, a2+a3, a3].
141+
let mut alpha_sum_rl: Vec<F> = self
142+
.alpha
143+
.iter()
144+
.skip(1)
145+
.rev()
146+
// scan does the cumulative sum
147+
.scan(F::zero(), |sum, x| {
148+
*sum = *sum + *x;
149+
Some(*sum)
150+
})
151+
.collect();
152+
alpha_sum_rl.reverse();
153+
let mut acc = F::one();
154+
for ((s, &a), &b) in samples
155+
.iter_mut()
156+
.zip(self.alpha.iter())
157+
.zip(alpha_sum_rl.iter())
158+
{
159+
let beta = Beta::new(a, b).unwrap();
160+
let beta_sample = beta.sample(rng);
161+
*s = acc * beta_sample;
162+
acc = acc * (F::one() - beta_sample);
163+
}
164+
samples[n - 1] = acc;
165+
} else {
166+
let mut sum = F::zero();
167+
for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) {
168+
let g = Gamma::new(a, F::one()).unwrap();
169+
*s = g.sample(rng);
170+
sum = sum + (*s);
171+
}
172+
let invacc = F::one() / sum;
173+
for s in samples.iter_mut() {
174+
*s = (*s) * invacc;
175+
}
136176
}
137177
samples
138178
}
@@ -142,6 +182,33 @@ where
142182
mod test {
143183
use super::*;
144184

185+
//
186+
// Check that the means of the components of n samples from
187+
// the Dirichlet distribution agree with the expected means
188+
// with a relative tolerance of rtol.
189+
//
190+
// This is a crude statistical test, but it will catch egregious
191+
// mistakes. It will also also fail if any samples contain nan.
192+
//
193+
fn check_dirichlet_means(alpha: &Vec<f64>, n: i32, rtol: f64, seed: u64) {
194+
let d = Dirichlet::new(&alpha).unwrap();
195+
let alpha_len = d.alpha.len();
196+
let mut rng = crate::test::rng(seed);
197+
let mut sums = vec![0.0; alpha_len];
198+
for _ in 0..n {
199+
let samples = d.sample(&mut rng);
200+
for i in 0..alpha_len {
201+
sums[i] += samples[i];
202+
}
203+
}
204+
let sample_mean: Vec<f64> = sums.iter().map(|x| x / n as f64).collect();
205+
let alpha_sum: f64 = d.alpha.iter().sum();
206+
let expected_mean: Vec<f64> = d.alpha.iter().map(|x| x / alpha_sum).collect();
207+
for i in 0..alpha_len {
208+
assert_almost_eq!(sample_mean[i], expected_mean[i], rtol);
209+
}
210+
}
211+
145212
#[test]
146213
fn test_dirichlet() {
147214
let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
@@ -172,6 +239,48 @@ mod test {
172239
.collect();
173240
}
174241

242+
#[test]
243+
fn test_dirichlet_means() {
244+
// Check the means of 20000 samples for several different alphas.
245+
let alpha_set = vec![
246+
vec![0.5, 0.25],
247+
vec![123.0, 75.0],
248+
vec![2.0, 2.5, 5.0, 7.0],
249+
vec![0.1, 8.0, 1.0, 2.0, 2.0, 0.85, 0.05, 12.5],
250+
];
251+
let n = 20000;
252+
let rtol = 2e-2;
253+
let seed = 1317624576693539401;
254+
for alpha in alpha_set {
255+
check_dirichlet_means(&alpha, n, rtol, seed);
256+
}
257+
}
258+
259+
#[test]
260+
fn test_dirichlet_means_very_small_alpha() {
261+
// With values of alpha that are all 0.001, check that the means of the
262+
// components of 10000 samples are within 1% of the expected means.
263+
// With the sampling method based on gamma variates, this test would
264+
// fail, with about 10% of the samples containing nan.
265+
let alpha = vec![0.001, 0.001, 0.001];
266+
let n = 10000;
267+
let rtol = 1e-2;
268+
let seed = 1317624576693539401;
269+
check_dirichlet_means(&alpha, n, rtol, seed);
270+
}
271+
272+
#[test]
273+
fn test_dirichlet_means_small_alpha() {
274+
// With values of alpha that are all less than 0.1, check that the
275+
// means of the components of 150000 samples are within 0.1% of the
276+
// expected means.
277+
let alpha = vec![0.05, 0.025, 0.075, 0.05];
278+
let n = 150000;
279+
let rtol = 1e-3;
280+
let seed = 1317624576693539401;
281+
check_dirichlet_means(&alpha, n, rtol, seed);
282+
}
283+
175284
#[test]
176285
#[should_panic]
177286
fn test_dirichlet_invalid_length() {

0 commit comments

Comments
 (0)