9
9
10
10
//! The dirichlet distribution.
11
11
#![ cfg( feature = "alloc" ) ]
12
- use num_traits:: Float ;
13
- use crate :: { Distribution , Exp1 , Gamma , Open01 , StandardNormal } ;
12
+ use num_traits:: { Float , NumCast } ;
13
+ use crate :: { Beta , Distribution , Exp1 , Gamma , Open01 , StandardNormal } ;
14
14
use rand:: Rng ;
15
15
use core:: fmt;
16
16
use alloc:: { boxed:: Box , vec, vec:: Vec } ;
@@ -123,16 +123,56 @@ where
123
123
fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> Vec < F > {
124
124
let n = self . alpha . len ( ) ;
125
125
let mut samples = vec ! [ F :: zero( ) ; n] ;
126
- let mut sum = F :: zero ( ) ;
127
126
128
- for ( s, & a) in samples. iter_mut ( ) . zip ( self . alpha . iter ( ) ) {
129
- let g = Gamma :: new ( a, F :: one ( ) ) . unwrap ( ) ;
130
- * s = g. sample ( rng) ;
131
- sum = sum + ( * s) ;
132
- }
133
- let invacc = F :: one ( ) / sum;
134
- for s in samples. iter_mut ( ) {
135
- * s = ( * s) * invacc;
127
+ if self . alpha . iter ( ) . all ( |x| * x <= NumCast :: from ( 0.1 ) . unwrap ( ) ) {
128
+ // All the values in alpha are less than 0.1.
129
+ //
130
+ // When all the alpha parameters are sufficiently small, there
131
+ // is a nontrivial probability that the samples from the gamma
132
+ // distributions used in the other method will all be 0, which
133
+ // results in the dirichlet samples being nan. So instead of
134
+ // use that method, use the "stick breaking" method based on the
135
+ // marginal beta distributions.
136
+ //
137
+ // Form the right-to-left cumulative sum of alpha, exluding the
138
+ // first element of alpha. E.g. if alpha = [a0, a1, a2, a3], then
139
+ // after the call to `alpha_sum_rl.reverse()` below, alpha_sum_rl
140
+ // will hold [a1+a2+a3, a2+a3, a3].
141
+ let mut alpha_sum_rl: Vec < F > = self
142
+ . alpha
143
+ . iter ( )
144
+ . skip ( 1 )
145
+ . rev ( )
146
+ // scan does the cumulative sum
147
+ . scan ( F :: zero ( ) , |sum, x| {
148
+ * sum = * sum + * x;
149
+ Some ( * sum)
150
+ } )
151
+ . collect ( ) ;
152
+ alpha_sum_rl. reverse ( ) ;
153
+ let mut acc = F :: one ( ) ;
154
+ for ( ( s, & a) , & b) in samples
155
+ . iter_mut ( )
156
+ . zip ( self . alpha . iter ( ) )
157
+ . zip ( alpha_sum_rl. iter ( ) )
158
+ {
159
+ let beta = Beta :: new ( a, b) . unwrap ( ) ;
160
+ let beta_sample = beta. sample ( rng) ;
161
+ * s = acc * beta_sample;
162
+ acc = acc * ( F :: one ( ) - beta_sample) ;
163
+ }
164
+ samples[ n - 1 ] = acc;
165
+ } else {
166
+ let mut sum = F :: zero ( ) ;
167
+ for ( s, & a) in samples. iter_mut ( ) . zip ( self . alpha . iter ( ) ) {
168
+ let g = Gamma :: new ( a, F :: one ( ) ) . unwrap ( ) ;
169
+ * s = g. sample ( rng) ;
170
+ sum = sum + ( * s) ;
171
+ }
172
+ let invacc = F :: one ( ) / sum;
173
+ for s in samples. iter_mut ( ) {
174
+ * s = ( * s) * invacc;
175
+ }
136
176
}
137
177
samples
138
178
}
@@ -142,6 +182,33 @@ where
142
182
mod test {
143
183
use super :: * ;
144
184
185
+ //
186
+ // Check that the means of the components of n samples from
187
+ // the Dirichlet distribution agree with the expected means
188
+ // with a relative tolerance of rtol.
189
+ //
190
+ // This is a crude statistical test, but it will catch egregious
191
+ // mistakes. It will also also fail if any samples contain nan.
192
+ //
193
+ fn check_dirichlet_means ( alpha : & Vec < f64 > , n : i32 , rtol : f64 , seed : u64 ) {
194
+ let d = Dirichlet :: new ( & alpha) . unwrap ( ) ;
195
+ let alpha_len = d. alpha . len ( ) ;
196
+ let mut rng = crate :: test:: rng ( seed) ;
197
+ let mut sums = vec ! [ 0.0 ; alpha_len] ;
198
+ for _ in 0 ..n {
199
+ let samples = d. sample ( & mut rng) ;
200
+ for i in 0 ..alpha_len {
201
+ sums[ i] += samples[ i] ;
202
+ }
203
+ }
204
+ let sample_mean: Vec < f64 > = sums. iter ( ) . map ( |x| x / n as f64 ) . collect ( ) ;
205
+ let alpha_sum: f64 = d. alpha . iter ( ) . sum ( ) ;
206
+ let expected_mean: Vec < f64 > = d. alpha . iter ( ) . map ( |x| x / alpha_sum) . collect ( ) ;
207
+ for i in 0 ..alpha_len {
208
+ assert_almost_eq ! ( sample_mean[ i] , expected_mean[ i] , rtol) ;
209
+ }
210
+ }
211
+
145
212
#[ test]
146
213
fn test_dirichlet ( ) {
147
214
let d = Dirichlet :: new ( & [ 1.0 , 2.0 , 3.0 ] ) . unwrap ( ) ;
@@ -172,6 +239,48 @@ mod test {
172
239
. collect ( ) ;
173
240
}
174
241
242
+ #[ test]
243
+ fn test_dirichlet_means ( ) {
244
+ // Check the means of 20000 samples for several different alphas.
245
+ let alpha_set = vec ! [
246
+ vec![ 0.5 , 0.25 ] ,
247
+ vec![ 123.0 , 75.0 ] ,
248
+ vec![ 2.0 , 2.5 , 5.0 , 7.0 ] ,
249
+ vec![ 0.1 , 8.0 , 1.0 , 2.0 , 2.0 , 0.85 , 0.05 , 12.5 ] ,
250
+ ] ;
251
+ let n = 20000 ;
252
+ let rtol = 2e-2 ;
253
+ let seed = 1317624576693539401 ;
254
+ for alpha in alpha_set {
255
+ check_dirichlet_means ( & alpha, n, rtol, seed) ;
256
+ }
257
+ }
258
+
259
+ #[ test]
260
+ fn test_dirichlet_means_very_small_alpha ( ) {
261
+ // With values of alpha that are all 0.001, check that the means of the
262
+ // components of 10000 samples are within 1% of the expected means.
263
+ // With the sampling method based on gamma variates, this test would
264
+ // fail, with about 10% of the samples containing nan.
265
+ let alpha = vec ! [ 0.001 , 0.001 , 0.001 ] ;
266
+ let n = 10000 ;
267
+ let rtol = 1e-2 ;
268
+ let seed = 1317624576693539401 ;
269
+ check_dirichlet_means ( & alpha, n, rtol, seed) ;
270
+ }
271
+
272
+ #[ test]
273
+ fn test_dirichlet_means_small_alpha ( ) {
274
+ // With values of alpha that are all less than 0.1, check that the
275
+ // means of the components of 150000 samples are within 0.1% of the
276
+ // expected means.
277
+ let alpha = vec ! [ 0.05 , 0.025 , 0.075 , 0.05 ] ;
278
+ let n = 150000 ;
279
+ let rtol = 1e-3 ;
280
+ let seed = 1317624576693539401 ;
281
+ check_dirichlet_means ( & alpha, n, rtol, seed) ;
282
+ }
283
+
175
284
#[ test]
176
285
#[ should_panic]
177
286
fn test_dirichlet_invalid_length ( ) {
0 commit comments