@@ -14,7 +14,8 @@ use self::ChiSquaredRepr::*;
14
14
15
15
use rand:: Rng ;
16
16
use crate :: normal:: StandardNormal ;
17
- use crate :: { Distribution , Exp , Open01 } ;
17
+ use crate :: { Distribution , Exp1 , Exp , Open01 } ;
18
+ use num_traits:: Float ;
18
19
19
20
/// The Gamma distribution `Gamma(shape, scale)` distribution.
20
21
///
@@ -47,8 +48,8 @@ use crate::{Distribution, Exp, Open01};
47
48
/// (September 2000), 363-372.
48
49
/// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414)
49
50
#[ derive( Clone , Copy , Debug ) ]
50
- pub struct Gamma {
51
- repr : GammaRepr ,
51
+ pub struct Gamma < N > {
52
+ repr : GammaRepr < N > ,
52
53
}
53
54
54
55
/// Error type returned from `Gamma::new`.
@@ -63,10 +64,10 @@ pub enum Error {
63
64
}
64
65
65
66
#[ derive( Clone , Copy , Debug ) ]
66
- enum GammaRepr {
67
- Large ( GammaLargeShape ) ,
68
- One ( Exp ) ,
69
- Small ( GammaSmallShape )
67
+ enum GammaRepr < N > {
68
+ Large ( GammaLargeShape < N > ) ,
69
+ One ( Exp < N > ) ,
70
+ Small ( GammaSmallShape < N > )
70
71
}
71
72
72
73
// These two helpers could be made public, but saving the
@@ -84,37 +85,39 @@ enum GammaRepr {
84
85
/// See `Gamma` for sampling from a Gamma distribution with general
85
86
/// shape parameters.
86
87
#[ derive( Clone , Copy , Debug ) ]
87
- struct GammaSmallShape {
88
- inv_shape : f64 ,
89
- large_shape : GammaLargeShape
88
+ struct GammaSmallShape < N > {
89
+ inv_shape : N ,
90
+ large_shape : GammaLargeShape < N >
90
91
}
91
92
92
93
/// Gamma distribution where the shape parameter is larger than 1.
93
94
///
94
95
/// See `Gamma` for sampling from a Gamma distribution with general
95
96
/// shape parameters.
96
97
#[ derive( Clone , Copy , Debug ) ]
97
- struct GammaLargeShape {
98
- scale : f64 ,
99
- c : f64 ,
100
- d : f64
98
+ struct GammaLargeShape < N > {
99
+ scale : N ,
100
+ c : N ,
101
+ d : N
101
102
}
102
103
103
- impl Gamma {
104
+ impl < N : Float > Gamma < N >
105
+ where StandardNormal : Distribution < N > , Exp1 : Distribution < N > , Open01 : Distribution < N >
106
+ {
104
107
/// Construct an object representing the `Gamma(shape, scale)`
105
108
/// distribution.
106
109
#[ inline]
107
- pub fn new ( shape : f64 , scale : f64 ) -> Result < Gamma , Error > {
108
- if !( shape > 0.0 ) {
110
+ pub fn new ( shape : N , scale : N ) -> Result < Gamma < N > , Error > {
111
+ if !( shape > N :: zero ( ) ) {
109
112
return Err ( Error :: ShapeTooSmall ) ;
110
113
}
111
- if !( scale > 0.0 ) {
114
+ if !( scale > N :: zero ( ) ) {
112
115
return Err ( Error :: ScaleTooSmall ) ;
113
116
}
114
117
115
- let repr = if shape == 1.0 {
116
- One ( Exp :: new ( 1.0 / scale) . map_err ( |_| Error :: ScaleTooLarge ) ?)
117
- } else if shape < 1.0 {
118
+ let repr = if shape == N :: one ( ) {
119
+ One ( Exp :: new ( N :: one ( ) / scale) . map_err ( |_| Error :: ScaleTooLarge ) ?)
120
+ } else if shape < N :: one ( ) {
118
121
Small ( GammaSmallShape :: new_raw ( shape, scale) )
119
122
} else {
120
123
Large ( GammaLargeShape :: new_raw ( shape, scale) )
@@ -123,57 +126,69 @@ impl Gamma {
123
126
}
124
127
}
125
128
126
- impl GammaSmallShape {
127
- fn new_raw ( shape : f64 , scale : f64 ) -> GammaSmallShape {
129
+ impl < N : Float > GammaSmallShape < N >
130
+ where StandardNormal : Distribution < N > , Open01 : Distribution < N >
131
+ {
132
+ fn new_raw ( shape : N , scale : N ) -> GammaSmallShape < N > {
128
133
GammaSmallShape {
129
- inv_shape : 1. / shape,
130
- large_shape : GammaLargeShape :: new_raw ( shape + 1.0 , scale)
134
+ inv_shape : N :: one ( ) / shape,
135
+ large_shape : GammaLargeShape :: new_raw ( shape + N :: one ( ) , scale)
131
136
}
132
137
}
133
138
}
134
139
135
- impl GammaLargeShape {
136
- fn new_raw ( shape : f64 , scale : f64 ) -> GammaLargeShape {
137
- let d = shape - 1. / 3. ;
140
+ impl < N : Float > GammaLargeShape < N >
141
+ where StandardNormal : Distribution < N > , Open01 : Distribution < N >
142
+ {
143
+ fn new_raw ( shape : N , scale : N ) -> GammaLargeShape < N > {
144
+ let d = shape - N :: from ( 1. / 3. ) . unwrap ( ) ;
138
145
GammaLargeShape {
139
146
scale,
140
- c : 1. / ( 9. * d) . sqrt ( ) ,
147
+ c : N :: one ( ) / ( N :: from ( 9. ) . unwrap ( ) * d) . sqrt ( ) ,
141
148
d
142
149
}
143
150
}
144
151
}
145
152
146
- impl Distribution < f64 > for Gamma {
147
- fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> f64 {
153
+ impl < N : Float > Distribution < N > for Gamma < N >
154
+ where StandardNormal : Distribution < N > , Exp1 : Distribution < N > , Open01 : Distribution < N >
155
+ {
156
+ fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> N {
148
157
match self . repr {
149
158
Small ( ref g) => g. sample ( rng) ,
150
159
One ( ref g) => g. sample ( rng) ,
151
160
Large ( ref g) => g. sample ( rng) ,
152
161
}
153
162
}
154
163
}
155
- impl Distribution < f64 > for GammaSmallShape {
156
- fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> f64 {
157
- let u: f64 = rng. sample ( Open01 ) ;
164
+ impl < N : Float > Distribution < N > for GammaSmallShape < N >
165
+ where StandardNormal : Distribution < N > , Open01 : Distribution < N >
166
+ {
167
+ fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> N {
168
+ let u: N = rng. sample ( Open01 ) ;
158
169
159
170
self . large_shape . sample ( rng) * u. powf ( self . inv_shape )
160
171
}
161
172
}
162
- impl Distribution < f64 > for GammaLargeShape {
163
- fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> f64 {
173
+ impl < N : Float > Distribution < N > for GammaLargeShape < N >
174
+ where StandardNormal : Distribution < N > , Open01 : Distribution < N >
175
+ {
176
+ fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> N {
177
+ // Marsaglia & Tsang method, 2000
164
178
loop {
165
- let x = rng. sample ( StandardNormal ) ;
166
- let v_cbrt = 1.0 + self . c * x;
167
- if v_cbrt <= 0.0 { // a^3 <= 0 iff a <= 0
179
+ let x: N = rng. sample ( StandardNormal ) ;
180
+ let v_cbrt = N :: one ( ) + self . c * x;
181
+ if v_cbrt <= N :: zero ( ) { // a^3 <= 0 iff a <= 0
168
182
continue
169
183
}
170
184
171
185
let v = v_cbrt * v_cbrt * v_cbrt;
172
- let u: f64 = rng. sample ( Open01 ) ;
186
+ let u: N = rng. sample ( Open01 ) ;
173
187
174
188
let x_sqr = x * x;
175
- if u < 1.0 - 0.0331 * x_sqr * x_sqr ||
176
- u. ln ( ) < 0.5 * x_sqr + self . d * ( 1.0 - v + v. ln ( ) ) {
189
+ if u < N :: one ( ) - N :: from ( 0.0331 ) . unwrap ( ) * x_sqr * x_sqr ||
190
+ u. ln ( ) < N :: from ( 0.5 ) . unwrap ( ) * x_sqr + self . d * ( N :: one ( ) - v + v. ln ( ) )
191
+ {
177
192
return self . d * v * self . scale
178
193
}
179
194
}
@@ -215,7 +230,7 @@ enum ChiSquaredRepr {
215
230
// e.g. when alpha = 1/2 as it would be for this case, so special-
216
231
// casing and using the definition of N(0,1)^2 is faster.
217
232
DoFExactlyOne ,
218
- DoFAnythingElse ( Gamma ) ,
233
+ DoFAnythingElse ( Gamma < f64 > ) ,
219
234
}
220
235
221
236
impl ChiSquared {
@@ -238,7 +253,7 @@ impl Distribution<f64> for ChiSquared {
238
253
match self . repr {
239
254
DoFExactlyOne => {
240
255
// k == 1 => N(0,1)^2
241
- let norm = rng. sample ( StandardNormal ) ;
256
+ let norm: f64 = rng. sample ( StandardNormal ) ;
242
257
norm * norm
243
258
}
244
259
DoFAnythingElse ( ref g) => g. sample ( rng)
@@ -332,7 +347,7 @@ impl StudentT {
332
347
}
333
348
impl Distribution < f64 > for StudentT {
334
349
fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> f64 {
335
- let norm = rng. sample ( StandardNormal ) ;
350
+ let norm: f64 = rng. sample ( StandardNormal ) ;
336
351
norm * ( self . dof / self . chi . sample ( rng) ) . sqrt ( )
337
352
}
338
353
}
@@ -350,8 +365,8 @@ impl Distribution<f64> for StudentT {
350
365
/// ```
351
366
#[ derive( Clone , Copy , Debug ) ]
352
367
pub struct Beta {
353
- gamma_a : Gamma ,
354
- gamma_b : Gamma ,
368
+ gamma_a : Gamma < f64 > ,
369
+ gamma_b : Gamma < f64 > ,
355
370
}
356
371
357
372
/// Error type returned from `Beta::new`.
0 commit comments