@@ -13,7 +13,6 @@ use num_traits::Float;
13
13
use crate :: { Distribution , Exp1 , Gamma , Open01 , StandardNormal } ;
14
14
use rand:: Rng ;
15
15
use core:: fmt;
16
- use alloc:: { boxed:: Box , vec, vec:: Vec } ;
17
16
18
17
/// The Dirichlet distribution `Dirichlet(alpha)`.
19
18
///
@@ -27,22 +26,22 @@ use alloc::{boxed::Box, vec, vec::Vec};
27
26
/// use rand::prelude::*;
28
27
/// use rand_distr::Dirichlet;
29
28
///
30
- /// let dirichlet = Dirichlet::new(& [1.0, 2.0, 3.0]).unwrap();
29
+ /// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap();
31
30
/// let samples = dirichlet.sample(&mut rand::thread_rng());
32
31
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
33
32
/// ```
34
33
#[ cfg_attr( doc_cfg, doc( cfg( feature = "alloc" ) ) ) ]
35
34
#[ derive( Clone , Debug , PartialEq ) ]
36
35
#[ cfg_attr( feature = "serde1" , derive( serde:: Serialize , serde:: Deserialize ) ) ]
37
- pub struct Dirichlet < F >
36
+ pub struct Dirichlet < F , const N : usize >
38
37
where
39
38
F : Float ,
40
39
StandardNormal : Distribution < F > ,
41
40
Exp1 : Distribution < F > ,
42
41
Open01 : Distribution < F > ,
43
42
{
44
43
/// Concentration parameters (alpha)
45
- alpha : Box < [ F ] > ,
44
+ alpha : [ F ; N ] ,
46
45
}
47
46
48
47
/// Error type returned from `Dirchlet::new`.
@@ -72,7 +71,7 @@ impl fmt::Display for Error {
72
71
#[ cfg_attr( doc_cfg, doc( cfg( feature = "std" ) ) ) ]
73
72
impl std:: error:: Error for Error { }
74
73
75
- impl < F > Dirichlet < F >
74
+ impl < F , const N : usize > Dirichlet < F , N >
76
75
where
77
76
F : Float ,
78
77
StandardNormal : Distribution < F > ,
83
82
///
84
83
/// Requires `alpha.len() >= 2`.
85
84
#[ inline]
86
- pub fn new ( alpha : & [ F ] ) -> Result < Dirichlet < F > , Error > {
87
- if alpha . len ( ) < 2 {
85
+ pub fn new ( alpha : [ F ; N ] ) -> Result < Dirichlet < F , N > , Error > {
86
+ if N < 2 {
88
87
return Err ( Error :: AlphaTooShort ) ;
89
88
}
90
89
for & ai in alpha. iter ( ) {
@@ -93,36 +92,19 @@ where
93
92
}
94
93
}
95
94
96
- Ok ( Dirichlet { alpha : alpha. to_vec ( ) . into_boxed_slice ( ) } )
97
- }
98
-
99
- /// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`.
100
- ///
101
- /// Requires `size >= 2`.
102
- #[ inline]
103
- pub fn new_with_size ( alpha : F , size : usize ) -> Result < Dirichlet < F > , Error > {
104
- if !( alpha > F :: zero ( ) ) {
105
- return Err ( Error :: AlphaTooSmall ) ;
106
- }
107
- if size < 2 {
108
- return Err ( Error :: SizeTooSmall ) ;
109
- }
110
- Ok ( Dirichlet {
111
- alpha : vec ! [ alpha; size] . into_boxed_slice ( ) ,
112
- } )
95
+ Ok ( Dirichlet { alpha } )
113
96
}
114
97
}
115
98
116
- impl < F > Distribution < Vec < F > > for Dirichlet < F >
99
+ impl < F , const N : usize > Distribution < [ F ; N ] > for Dirichlet < F , N >
117
100
where
118
101
F : Float ,
119
102
StandardNormal : Distribution < F > ,
120
103
Exp1 : Distribution < F > ,
121
104
Open01 : Distribution < F > ,
122
105
{
123
- fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> Vec < F > {
124
- let n = self . alpha . len ( ) ;
125
- let mut samples = vec ! [ F :: zero( ) ; n] ;
106
+ fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> [ F ; N ] {
107
+ let mut samples = [ F :: zero ( ) ; N ] ;
126
108
let mut sum = F :: zero ( ) ;
127
109
128
110
for ( s, & a) in samples. iter_mut ( ) . zip ( self . alpha . iter ( ) ) {
@@ -144,23 +126,7 @@ mod test {
144
126
145
127
#[ test]
146
128
fn test_dirichlet ( ) {
147
- let d = Dirichlet :: new ( & [ 1.0 , 2.0 , 3.0 ] ) . unwrap ( ) ;
148
- let mut rng = crate :: test:: rng ( 221 ) ;
149
- let samples = d. sample ( & mut rng) ;
150
- let _: Vec < f64 > = samples
151
- . into_iter ( )
152
- . map ( |x| {
153
- assert ! ( x > 0.0 ) ;
154
- x
155
- } )
156
- . collect ( ) ;
157
- }
158
-
159
- #[ test]
160
- fn test_dirichlet_with_param ( ) {
161
- let alpha = 0.5f64 ;
162
- let size = 2 ;
163
- let d = Dirichlet :: new_with_size ( alpha, size) . unwrap ( ) ;
129
+ let d = Dirichlet :: new ( [ 1.0 , 2.0 , 3.0 ] ) . unwrap ( ) ;
164
130
let mut rng = crate :: test:: rng ( 221 ) ;
165
131
let samples = d. sample ( & mut rng) ;
166
132
let _: Vec < f64 > = samples
@@ -175,17 +141,17 @@ mod test {
175
141
#[ test]
176
142
#[ should_panic]
177
143
fn test_dirichlet_invalid_length ( ) {
178
- Dirichlet :: new_with_size ( 0.5f64 , 1 ) . unwrap ( ) ;
144
+ Dirichlet :: new ( [ 0.5 ] ) . unwrap ( ) ;
179
145
}
180
146
181
147
#[ test]
182
148
#[ should_panic]
183
149
fn test_dirichlet_invalid_alpha ( ) {
184
- Dirichlet :: new_with_size ( 0.0f64 , 2 ) . unwrap ( ) ;
150
+ Dirichlet :: new ( [ 0.1 , 0.0 , 0.3 ] ) . unwrap ( ) ;
185
151
}
186
152
187
153
#[ test]
188
154
fn dirichlet_distributions_can_be_compared ( ) {
189
- assert_eq ! ( Dirichlet :: new( & [ 1.0 , 2.0 ] ) , Dirichlet :: new( & [ 1.0 , 2.0 ] ) ) ;
155
+ assert_eq ! ( Dirichlet :: new( [ 1.0 , 2.0 ] ) , Dirichlet :: new( [ 1.0 , 2.0 ] ) ) ;
190
156
}
191
157
}
0 commit comments