Skip to content

Commit 05a7ab3

Browse files
authored
Migrate rand_distr to num-traits for no_std support (#987)
* replace custom Float trait with num-traits::Float * enable no_std support via num-traits math functions * remove Distribution<u64> impl for poisson * move stability tests * add copyright notice * tweak dirichlet and alias_method to use boxed slice instead of vec
1 parent dda1780 commit 05a7ab3

22 files changed

+897
-1040
lines changed

rand_distr/Cargo.toml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,17 @@ travis-ci = { repository = "rust-random/rand" }
2020
appveyor = { repository = "rust-random/rand" }
2121

2222
[dependencies]
23-
rand = { path = "..", version = "0.7" }
23+
rand = { path = "..", version = "0.7", default-features = false }
24+
num-traits = { version = "0.2", default-features = false, features = ["libm"] }
25+
26+
[features]
27+
default = ["std"]
28+
std = ["alloc"]
29+
alloc = []
2430

2531
[dev-dependencies]
2632
rand_pcg = { version = "0.2", path = "../rand_pcg" }
33+
# For inline examples
34+
rand = { path = "..", version = "0.7", default-features = false, features = ["std_rng", "std"] }
2735
# Histogram implementation for testing uniformity
2836
average = "0.10.3"

rand_distr/src/binomial.rs

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
1212
use crate::{Distribution, Uniform};
1313
use rand::Rng;
14-
use std::{error, fmt};
14+
use core::fmt;
1515

1616
/// The binomial distribution `Binomial(n, p)`.
1717
///
@@ -53,7 +53,8 @@ impl fmt::Display for Error {
5353
}
5454
}
5555

56-
impl error::Error for Error {}
56+
#[cfg(feature = "std")]
57+
impl std::error::Error for Error {}
5758

5859
impl Binomial {
5960
/// Construct a new `Binomial` with the given shape parameters `n` (number
@@ -72,7 +73,7 @@ impl Binomial {
7273
/// Convert a `f64` to an `i64`, panicing on overflow.
7374
// In the future (Rust 1.34), this might be replaced with `TryFrom`.
7475
fn f64_to_i64(x: f64) -> i64 {
75-
assert!(x < (::std::i64::MAX as f64));
76+
assert!(x < (core::i64::MAX as f64));
7677
x as i64
7778
}
7879

@@ -106,7 +107,7 @@ impl Distribution<u64> for Binomial {
106107
// Ranlib uses 30, and GSL uses 14.
107108
const BINV_THRESHOLD: f64 = 10.;
108109

109-
if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (::std::i32::MAX as u64) {
110+
if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (core::i32::MAX as u64) {
110111
// Use the BINV algorithm.
111112
let s = p / q;
112113
let a = ((self.n + 1) as f64) * s;
@@ -338,22 +339,4 @@ mod test {
338339
fn test_binomial_invalid_lambda_neg() {
339340
Binomial::new(20, -10.0).unwrap();
340341
}
341-
342-
#[test]
343-
fn value_stability() {
344-
fn test_samples(n: u64, p: f64, expected: &[u64]) {
345-
let distr = Binomial::new(n, p).unwrap();
346-
let mut rng = crate::test::rng(353);
347-
let mut buf = [0; 4];
348-
for x in &mut buf {
349-
*x = rng.sample(&distr);
350-
}
351-
assert_eq!(buf, expected);
352-
}
353-
354-
// We have multiple code paths: np < 10, p > 0.5
355-
test_samples(2, 0.7, &[1, 1, 2, 1]);
356-
test_samples(20, 0.3, &[7, 7, 5, 7]);
357-
test_samples(2000, 0.6, &[1194, 1208, 1192, 1210]);
358-
}
359342
}

rand_distr/src/cauchy.rs

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99

1010
//! The Cauchy distribution.
1111
12-
use crate::utils::Float;
12+
use num_traits::{Float, FloatConst};
1313
use crate::{Distribution, Standard};
1414
use rand::Rng;
15-
use std::{error, fmt};
15+
use core::fmt;
1616

1717
/// The Cauchy distribution `Cauchy(median, scale)`.
1818
///
@@ -32,9 +32,11 @@ use std::{error, fmt};
3232
/// println!("{} is from a Cauchy(2, 5) distribution", v);
3333
/// ```
3434
#[derive(Clone, Copy, Debug)]
35-
pub struct Cauchy<N> {
36-
median: N,
37-
scale: N,
35+
pub struct Cauchy<F>
36+
where F: Float + FloatConst, Standard: Distribution<F>
37+
{
38+
median: F,
39+
scale: F,
3840
}
3941

4042
/// Error type returned from `Cauchy::new`.
@@ -52,30 +54,31 @@ impl fmt::Display for Error {
5254
}
5355
}
5456

55-
impl error::Error for Error {}
57+
#[cfg(feature = "std")]
58+
impl std::error::Error for Error {}
5659

57-
impl<N: Float> Cauchy<N>
58-
where Standard: Distribution<N>
60+
impl<F> Cauchy<F>
61+
where F: Float + FloatConst, Standard: Distribution<F>
5962
{
6063
/// Construct a new `Cauchy` with the given shape parameters
6164
/// `median` the peak location and `scale` the scale factor.
62-
pub fn new(median: N, scale: N) -> Result<Cauchy<N>, Error> {
63-
if !(scale > N::from(0.0)) {
65+
pub fn new(median: F, scale: F) -> Result<Cauchy<F>, Error> {
66+
if !(scale > F::zero()) {
6467
return Err(Error::ScaleTooSmall);
6568
}
6669
Ok(Cauchy { median, scale })
6770
}
6871
}
6972

70-
impl<N: Float> Distribution<N> for Cauchy<N>
71-
where Standard: Distribution<N>
73+
impl<F> Distribution<F> for Cauchy<F>
74+
where F: Float + FloatConst, Standard: Distribution<F>
7275
{
73-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
76+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
7477
// sample from [0, 1)
7578
let x = Standard.sample(rng);
7679
// get standard cauchy random number
7780
// note that π/2 is not exactly representable, even if x=0.5 the result is finite
78-
let comp_dev = (N::pi() * x).tan();
81+
let comp_dev = (F::PI() * x).tan();
7982
// shift and scale according to parameters
8083
self.median + self.scale * comp_dev
8184
}
@@ -108,10 +111,12 @@ mod test {
108111
sum += numbers[i];
109112
}
110113
let median = median(&mut numbers);
111-
println!("Cauchy median: {}", median);
114+
#[cfg(feature = "std")]
115+
std::println!("Cauchy median: {}", median);
112116
assert!((median - 10.0).abs() < 0.4); // not 100% certain, but probable enough
113117
let mean = sum / 1000.0;
114-
println!("Cauchy mean: {}", mean);
118+
#[cfg(feature = "std")]
119+
std::println!("Cauchy mean: {}", mean);
115120
// for a Cauchy distribution the mean should not converge
116121
assert!((mean - 10.0).abs() > 0.4); // not 100% certain, but probable enough
117122
}
@@ -130,8 +135,8 @@ mod test {
130135

131136
#[test]
132137
fn value_stability() {
133-
fn gen_samples<N: Float + core::fmt::Debug>(m: N, s: N, buf: &mut [N])
134-
where Standard: Distribution<N> {
138+
fn gen_samples<F: Float + FloatConst + core::fmt::Debug>(m: F, s: F, buf: &mut [F])
139+
where Standard: Distribution<F> {
135140
let distr = Cauchy::new(m, s).unwrap();
136141
let mut rng = crate::test::rng(353);
137142
for x in buf {

rand_distr/src/dirichlet.rs

Lines changed: 41 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
// except according to those terms.
99

1010
//! The dirichlet distribution.
11-
12-
use crate::utils::Float;
11+
#![cfg(feature = "alloc")]
12+
use num_traits::Float;
1313
use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal};
1414
use rand::Rng;
15-
use std::{error, fmt};
15+
use core::fmt;
16+
use alloc::{boxed::Box, vec, vec::Vec};
1617

1718
/// The Dirichlet distribution `Dirichlet(alpha)`.
1819
///
@@ -26,14 +27,20 @@ use std::{error, fmt};
2627
/// use rand::prelude::*;
2728
/// use rand_distr::Dirichlet;
2829
///
29-
/// let dirichlet = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
30+
/// let dirichlet = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
3031
/// let samples = dirichlet.sample(&mut rand::thread_rng());
3132
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
3233
/// ```
3334
#[derive(Clone, Debug)]
34-
pub struct Dirichlet<N> {
35+
pub struct Dirichlet<F>
36+
where
37+
F: Float,
38+
StandardNormal: Distribution<F>,
39+
Exp1: Distribution<F>,
40+
Open01: Distribution<F>,
41+
{
3542
/// Concentration parameters (alpha)
36-
alpha: Vec<N>,
43+
alpha: Box<[F]>,
3744
}
3845

3946
/// Error type returned from `Dirchlet::new`.
@@ -58,68 +65,70 @@ impl fmt::Display for Error {
5865
}
5966
}
6067

61-
impl error::Error for Error {}
68+
#[cfg(feature = "std")]
69+
impl std::error::Error for Error {}
6270

63-
impl<N: Float> Dirichlet<N>
71+
impl<F> Dirichlet<F>
6472
where
65-
StandardNormal: Distribution<N>,
66-
Exp1: Distribution<N>,
67-
Open01: Distribution<N>,
73+
F: Float,
74+
StandardNormal: Distribution<F>,
75+
Exp1: Distribution<F>,
76+
Open01: Distribution<F>,
6877
{
6978
/// Construct a new `Dirichlet` with the given alpha parameter `alpha`.
7079
///
7180
/// Requires `alpha.len() >= 2`.
7281
#[inline]
73-
pub fn new<V: Into<Vec<N>>>(alpha: V) -> Result<Dirichlet<N>, Error> {
74-
let a = alpha.into();
75-
if a.len() < 2 {
82+
pub fn new(alpha: &[F]) -> Result<Dirichlet<F>, Error> {
83+
if alpha.len() < 2 {
7684
return Err(Error::AlphaTooShort);
7785
}
78-
for &ai in &a {
79-
if !(ai > N::from(0.0)) {
86+
for &ai in alpha.iter() {
87+
if !(ai > F::zero()) {
8088
return Err(Error::AlphaTooSmall);
8189
}
8290
}
8391

84-
Ok(Dirichlet { alpha: a })
92+
Ok(Dirichlet { alpha: alpha.to_vec().into_boxed_slice() })
8593
}
8694

8795
/// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`.
8896
///
8997
/// Requires `size >= 2`.
9098
#[inline]
91-
pub fn new_with_size(alpha: N, size: usize) -> Result<Dirichlet<N>, Error> {
92-
if !(alpha > N::from(0.0)) {
99+
pub fn new_with_size(alpha: F, size: usize) -> Result<Dirichlet<F>, Error> {
100+
if !(alpha > F::zero()) {
93101
return Err(Error::AlphaTooSmall);
94102
}
95103
if size < 2 {
96104
return Err(Error::SizeTooSmall);
97105
}
98106
Ok(Dirichlet {
99-
alpha: vec![alpha; size],
107+
alpha: vec![alpha; size].into_boxed_slice(),
100108
})
101109
}
102110
}
103111

104-
impl<N: Float> Distribution<Vec<N>> for Dirichlet<N>
112+
impl<F> Distribution<Vec<F>> for Dirichlet<F>
105113
where
106-
StandardNormal: Distribution<N>,
107-
Exp1: Distribution<N>,
108-
Open01: Distribution<N>,
114+
F: Float,
115+
StandardNormal: Distribution<F>,
116+
Exp1: Distribution<F>,
117+
Open01: Distribution<F>,
109118
{
110-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<N> {
119+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<F> {
111120
let n = self.alpha.len();
112-
let mut samples = vec![N::from(0.0); n];
113-
let mut sum = N::from(0.0);
121+
let mut samples = vec![F::zero(); n];
122+
let mut sum = F::zero();
114123

115124
for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) {
116-
let g = Gamma::new(a, N::from(1.0)).unwrap();
125+
let g = Gamma::new(a, F::one()).unwrap();
117126
*s = g.sample(rng);
118-
sum += *s;
127+
sum = sum + (*s);
119128
}
120-
let invacc = N::from(1.0) / sum;
129+
let invacc = F::one() / sum;
121130
for s in samples.iter_mut() {
122-
*s *= invacc;
131+
*s = (*s)*invacc;
123132
}
124133
samples
125134
}
@@ -131,7 +140,7 @@ mod test {
131140

132141
#[test]
133142
fn test_dirichlet() {
134-
let d = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
143+
let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
135144
let mut rng = crate::test::rng(221);
136145
let samples = d.sample(&mut rng);
137146
let _: Vec<f64> = samples
@@ -170,20 +179,4 @@ mod test {
170179
fn test_dirichlet_invalid_alpha() {
171180
Dirichlet::new_with_size(0.0f64, 2).unwrap();
172181
}
173-
174-
#[test]
175-
fn value_stability() {
176-
let mut rng = crate::test::rng(223);
177-
assert_eq!(
178-
rng.sample(Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap()),
179-
vec![0.12941567177708177, 0.4702121891675036, 0.4003721390554146]
180-
);
181-
assert_eq!(rng.sample(Dirichlet::new_with_size(8.0, 5).unwrap()), vec![
182-
0.17684200044809556,
183-
0.29915953935953055,
184-
0.1832858056608014,
185-
0.1425623503573967,
186-
0.19815030417417595
187-
]);
188-
}
189182
}

0 commit comments

Comments
 (0)