Skip to content

MultiDistribution #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 16 commits into
base: master
Choose a base branch
from
7 changes: 3 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
//! - [`Beta`] distribution
//! - [`Triangular`] distribution
//! - Multivariate probability distributions
//! - [`Dirichlet`] distribution
//! - [`multi::Dirichlet`] distribution
//! - [`UnitSphere`] distribution
//! - [`UnitBall`] distribution
//! - [`UnitCircle`] distribution
Expand Down Expand Up @@ -100,8 +100,6 @@ pub use self::beta::{Beta, Error as BetaError};
pub use self::binomial::{Binomial, Error as BinomialError};
pub use self::cauchy::{Cauchy, Error as CauchyError};
pub use self::chi_squared::{ChiSquared, Error as ChiSquaredError};
#[cfg(feature = "alloc")]
pub use self::dirichlet::{Dirichlet, Error as DirichletError};
pub use self::exponential::{Error as ExpError, Exp, Exp1};
pub use self::fisher_f::{Error as FisherFError, FisherF};
pub use self::frechet::{Error as FrechetError, Frechet};
Expand Down Expand Up @@ -130,6 +128,8 @@ pub use student_t::StudentT;

pub use num_traits;

#[cfg(feature = "alloc")]
pub mod multi;
#[cfg(feature = "alloc")]
pub mod weighted;

Expand Down Expand Up @@ -188,7 +188,6 @@ mod beta;
mod binomial;
mod cauchy;
mod chi_squared;
mod dirichlet;
mod exponential;
mod fisher_f;
mod frechet;
Expand Down
46 changes: 26 additions & 20 deletions src/dirichlet.rs → src/multi/dirichlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
//! The dirichlet distribution `Dirichlet(α₁, α₂, ..., αₙ)`.

#![cfg(feature = "alloc")]
use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal};
use crate::{multi::MultiDistribution, Beta, Distribution, Exp1, Gamma, Open01, StandardNormal};
use core::fmt;
use num_traits::{Float, NumCast};
use rand::Rng;
Expand Down Expand Up @@ -68,26 +68,27 @@ where
}
}

impl<F, const N: usize> Distribution<[F; N]> for DirichletFromGamma<F, N>
impl<F, const N: usize> MultiDistribution<F> for DirichletFromGamma<F, N>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] {
let mut samples = [F::zero(); N];
fn sample_len(&self) -> usize {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be a const fn.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nevermind, this is not supported by traits.

N
}
fn sample_to_buf<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
let mut sum = F::zero();

for (s, g) in samples.iter_mut().zip(self.samplers.iter()) {
for (s, g) in output.iter_mut().zip(self.samplers.iter()) {
*s = g.sample(rng);
sum = sum + *s;
}
let invacc = F::one() / sum;
for s in samples.iter_mut() {
for s in output.iter_mut() {
*s = *s * invacc;
}
samples
}
}

Expand Down Expand Up @@ -149,24 +150,25 @@ where
}
}

impl<F, const N: usize> Distribution<[F; N]> for DirichletFromBeta<F, N>
impl<F, const N: usize> MultiDistribution<F> for DirichletFromBeta<F, N>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] {
let mut samples = [F::zero(); N];
fn sample_len(&self) -> usize {
N
}
fn sample_to_buf<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
let mut acc = F::one();

for (s, beta) in samples.iter_mut().zip(self.samplers.iter()) {
for (s, beta) in output.iter_mut().zip(self.samplers.iter()) {
let beta_sample = beta.sample(rng);
*s = acc * beta_sample;
acc = acc * (F::one() - beta_sample);
}
samples[N - 1] = acc;
samples
output[N - 1] = acc;
}
}

Expand Down Expand Up @@ -208,7 +210,8 @@ where
///
/// ```
/// use rand::prelude::*;
/// use rand_distr::Dirichlet;
/// use rand_distr::multi::Dirichlet;
/// use rand_distr::multi::MultiDistribution;
///
/// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap();
/// let samples = dirichlet.sample(&mut rand::rng());
Expand Down Expand Up @@ -259,7 +262,7 @@ impl fmt::Display for Error {
"failed to create required Gamma distribution for Dirichlet distribution"
}
Error::FailedToCreateBeta => {
"failed to create required Beta distribition for Dirichlet distribution"
"failed to create required Beta distribution for Dirichlet distribution"
}
})
}
Expand Down Expand Up @@ -315,17 +318,20 @@ where
}
}

impl<F, const N: usize> Distribution<[F; N]> for Dirichlet<F, N>
impl<F, const N: usize> MultiDistribution<F> for Dirichlet<F, N>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] {
fn sample_len(&self) -> usize {
N
}
fn sample_to_buf<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [F]) {
match &self.repr {
DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng),
DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng),
DirichletRepr::FromGamma(dirichlet) => dirichlet.sample_to_buf(rng, output),
DirichletRepr::FromBeta(dirichlet) => dirichlet.sample_to_buf(rng, output),
}
}
}
Expand Down Expand Up @@ -403,7 +409,7 @@ mod test {
let alpha_sum: f64 = alpha.iter().sum();
let expected_mean = alpha.map(|x| x / alpha_sum);
for i in 0..N {
assert_almost_eq!(sample_mean[i], expected_mean[i], rtol);
average::assert_almost_eq!(sample_mean[i], expected_mean[i], rtol);
}
}

Expand Down
30 changes: 30 additions & 0 deletions src/multi/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//! Contains Multi-dimensional distributions.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop the word 'contains' and remove the full-stop since this isn't a sentence.

//!
//! We provide a trait `MultiDistribution` which allows to sample from a multi-dimensional distribution without extra allocations.
//! All multi-dimensional distributions implement `MultiDistribution` instead of the `Distribution` trait.
Comment on lines +3 to +4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We generally wrap comments at 80 chars width (sometimes up to 100 if the line already has a large indent).

The wording could be a little better, e.g.

The MultiDistribution trait allows sampling a multi-dimensional distribution to a pre-allocated buffer or to a new [Vec].


use alloc::vec::Vec;
use rand::Rng;

/// This trait allows to sample from a multi-dimensional distribution without extra allocations.
/// For convenience it also provides a `sample` method which returns the result as a `Vec`.
pub trait MultiDistribution<T> {
Comment on lines +9 to +11
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Items have a short one-line description, with additional details in new paragraphs.

/// returns the length of one sample (dimension of the distribution)
fn sample_len(&self) -> usize;
Comment on lines +12 to +13
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Capitalise the first letter of 'returns'

/// samples from the distribution and writes the result to `buf`
fn sample_to_buf<R: Rng + ?Sized>(&self, rng: &mut R, buf: &mut [T]);
/// samples from the distribution and returns the result as a `Vec`, to avoid extra allocations use `sample_to_buf`
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<T>
where
T: Default,
{
let mut buf = Vec::new();
buf.resize_with(self.sample_len(), || T::default());
self.sample_to_buf(rng, &mut buf);
buf
}
}

pub use dirichlet::Dirichlet;

mod dirichlet;
12 changes: 9 additions & 3 deletions tests/value_stability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -500,13 +500,17 @@ fn weibull_stability() {
#[cfg(feature = "alloc")]
#[test]
fn dirichlet_stability() {
use rand_distr::multi::MultiDistribution;

let mut rng = get_rng(223);
assert_eq!(
rng.sample(Dirichlet::new([1.0, 2.0, 3.0]).unwrap()),
multi::Dirichlet::new([1.0, 2.0, 3.0])
.unwrap()
.sample(&mut rng),
[0.12941567177708177, 0.4702121891675036, 0.4003721390554146]
);
assert_eq!(
rng.sample(Dirichlet::new([8.0; 5]).unwrap()),
multi::Dirichlet::new([8.0; 5]).unwrap().sample(&mut rng),
[
0.17684200044809556,
0.29915953935953055,
Expand All @@ -517,7 +521,9 @@ fn dirichlet_stability() {
);
// Test stability for the case where all alphas are less than 0.1.
assert_eq!(
rng.sample(Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap()),
multi::Dirichlet::new([0.05, 0.025, 0.075, 0.05])
.unwrap()
.sample(&mut rng),
[
0.00027580456855692104,
2.296135759821706e-20,
Expand Down