Skip to content

Commit b664e64

Browse files
authored
Merge pull request #785 from dhardy/distr
Make distributions generic / impl for f32
2 parents 01343e1 + 19829a4 commit b664e64

File tree

4 files changed

+122
-73
lines changed

4 files changed

+122
-73
lines changed

rand_distr/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ appveyor = { repository = "rust-random/rand" }
2020

2121
[dependencies]
2222
rand = { path = "..", version = ">=0.5, <=0.7" }
23+
num-traits = "0.2"

rand_distr/src/exponential.rs

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
use rand::Rng;
1313
use crate::{ziggurat_tables, Distribution};
1414
use crate::utils::ziggurat;
15+
use num_traits::Float;
1516

1617
/// Samples floating-point numbers according to the exponential distribution,
1718
/// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or
@@ -39,6 +40,15 @@ use crate::utils::ziggurat;
3940
#[derive(Clone, Copy, Debug)]
4041
pub struct Exp1;
4142

43+
impl Distribution<f32> for Exp1 {
44+
#[inline]
45+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 {
46+
// TODO: use optimal 32-bit implementation
47+
let x: f64 = self.sample(rng);
48+
x as f32
49+
}
50+
}
51+
4252
// This could be done via `-rng.gen::<f64>().ln()` but that is slower.
4353
impl Distribution<f64> for Exp1 {
4454
#[inline]
@@ -76,9 +86,9 @@ impl Distribution<f64> for Exp1 {
7686
/// println!("{} is from a Exp(2) distribution", v);
7787
/// ```
7888
#[derive(Clone, Copy, Debug)]
79-
pub struct Exp {
89+
pub struct Exp<N> {
8090
/// `lambda` stored as `1/lambda`, since this is what we scale by.
81-
lambda_inverse: f64
91+
lambda_inverse: N
8292
}
8393

8494
/// Error type returned from `Exp::new`.
@@ -88,22 +98,25 @@ pub enum Error {
8898
LambdaTooSmall,
8999
}
90100

91-
impl Exp {
101+
impl<N: Float> Exp<N>
102+
where Exp1: Distribution<N>
103+
{
92104
/// Construct a new `Exp` with the given shape parameter
93105
/// `lambda`.
94106
#[inline]
95-
pub fn new(lambda: f64) -> Result<Exp, Error> {
96-
if !(lambda > 0.0) {
107+
pub fn new(lambda: N) -> Result<Exp<N>, Error> {
108+
if !(lambda > N::zero()) {
97109
return Err(Error::LambdaTooSmall);
98110
}
99-
Ok(Exp { lambda_inverse: 1.0 / lambda })
111+
Ok(Exp { lambda_inverse: N::one() / lambda })
100112
}
101113
}
102114

103-
impl Distribution<f64> for Exp {
104-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
105-
let n: f64 = rng.sample(Exp1);
106-
n * self.lambda_inverse
115+
impl<N: Float> Distribution<N> for Exp<N>
116+
where Exp1: Distribution<N>
117+
{
118+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
119+
rng.sample(Exp1) * self.lambda_inverse
107120
}
108121
}
109122

rand_distr/src/gamma.rs

Lines changed: 62 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ use self::ChiSquaredRepr::*;
1414

1515
use rand::Rng;
1616
use crate::normal::StandardNormal;
17-
use crate::{Distribution, Exp, Open01};
17+
use crate::{Distribution, Exp1, Exp, Open01};
18+
use num_traits::Float;
1819

1920
/// The Gamma distribution `Gamma(shape, scale)` distribution.
2021
///
@@ -47,8 +48,8 @@ use crate::{Distribution, Exp, Open01};
4748
/// (September 2000), 363-372.
4849
/// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414)
4950
#[derive(Clone, Copy, Debug)]
50-
pub struct Gamma {
51-
repr: GammaRepr,
51+
pub struct Gamma<N> {
52+
repr: GammaRepr<N>,
5253
}
5354

5455
/// Error type returned from `Gamma::new`.
@@ -63,10 +64,10 @@ pub enum Error {
6364
}
6465

6566
#[derive(Clone, Copy, Debug)]
66-
enum GammaRepr {
67-
Large(GammaLargeShape),
68-
One(Exp),
69-
Small(GammaSmallShape)
67+
enum GammaRepr<N> {
68+
Large(GammaLargeShape<N>),
69+
One(Exp<N>),
70+
Small(GammaSmallShape<N>)
7071
}
7172

7273
// These two helpers could be made public, but saving the
@@ -84,37 +85,39 @@ enum GammaRepr {
8485
/// See `Gamma` for sampling from a Gamma distribution with general
8586
/// shape parameters.
8687
#[derive(Clone, Copy, Debug)]
87-
struct GammaSmallShape {
88-
inv_shape: f64,
89-
large_shape: GammaLargeShape
88+
struct GammaSmallShape<N> {
89+
inv_shape: N,
90+
large_shape: GammaLargeShape<N>
9091
}
9192

9293
/// Gamma distribution where the shape parameter is larger than 1.
9394
///
9495
/// See `Gamma` for sampling from a Gamma distribution with general
9596
/// shape parameters.
9697
#[derive(Clone, Copy, Debug)]
97-
struct GammaLargeShape {
98-
scale: f64,
99-
c: f64,
100-
d: f64
98+
struct GammaLargeShape<N> {
99+
scale: N,
100+
c: N,
101+
d: N
101102
}
102103

103-
impl Gamma {
104+
impl<N: Float> Gamma<N>
105+
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
106+
{
104107
/// Construct an object representing the `Gamma(shape, scale)`
105108
/// distribution.
106109
#[inline]
107-
pub fn new(shape: f64, scale: f64) -> Result<Gamma, Error> {
108-
if !(shape > 0.0) {
110+
pub fn new(shape: N, scale: N) -> Result<Gamma<N>, Error> {
111+
if !(shape > N::zero()) {
109112
return Err(Error::ShapeTooSmall);
110113
}
111-
if !(scale > 0.0) {
114+
if !(scale > N::zero()) {
112115
return Err(Error::ScaleTooSmall);
113116
}
114117

115-
let repr = if shape == 1.0 {
116-
One(Exp::new(1.0 / scale).map_err(|_| Error::ScaleTooLarge)?)
117-
} else if shape < 1.0 {
118+
let repr = if shape == N::one() {
119+
One(Exp::new(N::one() / scale).map_err(|_| Error::ScaleTooLarge)?)
120+
} else if shape < N::one() {
118121
Small(GammaSmallShape::new_raw(shape, scale))
119122
} else {
120123
Large(GammaLargeShape::new_raw(shape, scale))
@@ -123,57 +126,69 @@ impl Gamma {
123126
}
124127
}
125128

126-
impl GammaSmallShape {
127-
fn new_raw(shape: f64, scale: f64) -> GammaSmallShape {
129+
impl<N: Float> GammaSmallShape<N>
130+
where StandardNormal: Distribution<N>, Open01: Distribution<N>
131+
{
132+
fn new_raw(shape: N, scale: N) -> GammaSmallShape<N> {
128133
GammaSmallShape {
129-
inv_shape: 1. / shape,
130-
large_shape: GammaLargeShape::new_raw(shape + 1.0, scale)
134+
inv_shape: N::one() / shape,
135+
large_shape: GammaLargeShape::new_raw(shape + N::one(), scale)
131136
}
132137
}
133138
}
134139

135-
impl GammaLargeShape {
136-
fn new_raw(shape: f64, scale: f64) -> GammaLargeShape {
137-
let d = shape - 1. / 3.;
140+
impl<N: Float> GammaLargeShape<N>
141+
where StandardNormal: Distribution<N>, Open01: Distribution<N>
142+
{
143+
fn new_raw(shape: N, scale: N) -> GammaLargeShape<N> {
144+
let d = shape - N::from(1. / 3.).unwrap();
138145
GammaLargeShape {
139146
scale,
140-
c: 1. / (9. * d).sqrt(),
147+
c: N::one() / (N::from(9.).unwrap() * d).sqrt(),
141148
d
142149
}
143150
}
144151
}
145152

146-
impl Distribution<f64> for Gamma {
147-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
153+
impl<N: Float> Distribution<N> for Gamma<N>
154+
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
155+
{
156+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
148157
match self.repr {
149158
Small(ref g) => g.sample(rng),
150159
One(ref g) => g.sample(rng),
151160
Large(ref g) => g.sample(rng),
152161
}
153162
}
154163
}
155-
impl Distribution<f64> for GammaSmallShape {
156-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
157-
let u: f64 = rng.sample(Open01);
164+
impl<N: Float> Distribution<N> for GammaSmallShape<N>
165+
where StandardNormal: Distribution<N>, Open01: Distribution<N>
166+
{
167+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
168+
let u: N = rng.sample(Open01);
158169

159170
self.large_shape.sample(rng) * u.powf(self.inv_shape)
160171
}
161172
}
162-
impl Distribution<f64> for GammaLargeShape {
163-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
173+
impl<N: Float> Distribution<N> for GammaLargeShape<N>
174+
where StandardNormal: Distribution<N>, Open01: Distribution<N>
175+
{
176+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
177+
// Marsaglia & Tsang method, 2000
164178
loop {
165-
let x = rng.sample(StandardNormal);
166-
let v_cbrt = 1.0 + self.c * x;
167-
if v_cbrt <= 0.0 { // a^3 <= 0 iff a <= 0
179+
let x: N = rng.sample(StandardNormal);
180+
let v_cbrt = N::one() + self.c * x;
181+
if v_cbrt <= N::zero() { // a^3 <= 0 iff a <= 0
168182
continue
169183
}
170184

171185
let v = v_cbrt * v_cbrt * v_cbrt;
172-
let u: f64 = rng.sample(Open01);
186+
let u: N = rng.sample(Open01);
173187

174188
let x_sqr = x * x;
175-
if u < 1.0 - 0.0331 * x_sqr * x_sqr ||
176-
u.ln() < 0.5 * x_sqr + self.d * (1.0 - v + v.ln()) {
189+
if u < N::one() - N::from(0.0331).unwrap() * x_sqr * x_sqr ||
190+
u.ln() < N::from(0.5).unwrap() * x_sqr + self.d * (N::one() - v + v.ln())
191+
{
177192
return self.d * v * self.scale
178193
}
179194
}
@@ -215,7 +230,7 @@ enum ChiSquaredRepr {
215230
// e.g. when alpha = 1/2 as it would be for this case, so special-
216231
// casing and using the definition of N(0,1)^2 is faster.
217232
DoFExactlyOne,
218-
DoFAnythingElse(Gamma),
233+
DoFAnythingElse(Gamma<f64>),
219234
}
220235

221236
impl ChiSquared {
@@ -238,7 +253,7 @@ impl Distribution<f64> for ChiSquared {
238253
match self.repr {
239254
DoFExactlyOne => {
240255
// k == 1 => N(0,1)^2
241-
let norm = rng.sample(StandardNormal);
256+
let norm: f64 = rng.sample(StandardNormal);
242257
norm * norm
243258
}
244259
DoFAnythingElse(ref g) => g.sample(rng)
@@ -332,7 +347,7 @@ impl StudentT {
332347
}
333348
impl Distribution<f64> for StudentT {
334349
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
335-
let norm = rng.sample(StandardNormal);
350+
let norm: f64 = rng.sample(StandardNormal);
336351
norm * (self.dof / self.chi.sample(rng)).sqrt()
337352
}
338353
}
@@ -350,8 +365,8 @@ impl Distribution<f64> for StudentT {
350365
/// ```
351366
#[derive(Clone, Copy, Debug)]
352367
pub struct Beta {
353-
gamma_a: Gamma,
354-
gamma_b: Gamma,
368+
gamma_a: Gamma<f64>,
369+
gamma_b: Gamma<f64>,
355370
}
356371

357372
/// Error type returned from `Beta::new`.

0 commit comments

Comments
 (0)