diff --git a/rand_distr/CHANGELOG.md b/rand_distr/CHANGELOG.md index 570474bee44..5ff10d934e2 100644 --- a/rand_distr/CHANGELOG.md +++ b/rand_distr/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Add plots for `rand_distr` distributions to documentation (#1434) - Add `PertBuilder`, fix case where mode ≅ mean (#1452) +- Add `Multinomial` distribution ## [0.5.0-alpha.1] - 2024-03-18 - Target `rand` version `0.9.0-alpha.1` diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index 90a534ff8cb..eb527c716a3 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -48,6 +48,7 @@ //! - [`Cauchy`] distribution //! - Related to Bernoulli trials (yes/no events, with a given probability): //! - [`Binomial`] distribution +//! - [`MultinomialConst`] and [`MultinomialDyn`] distributions //! - [`Geometric`] distribution //! - [`Hypergeometric`] distribution //! - Related to positive real-valued quantities that grow exponentially @@ -112,6 +113,9 @@ pub use self::geometric::{Error as GeoError, Geometric, StandardGeometric}; pub use self::gumbel::{Error as GumbelError, Gumbel}; pub use self::hypergeometric::{Error as HyperGeoError, Hypergeometric}; pub use self::inverse_gaussian::{Error as InverseGaussianError, InverseGaussian}; +#[cfg(feature = "alloc")] +pub use self::multinomial::MultinomialDyn; +pub use self::multinomial::{Error as MultinomialError, MultinomialConst}; pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal}; pub use self::normal_inverse_gaussian::{ Error as NormalInverseGaussianError, NormalInverseGaussian, @@ -207,6 +211,7 @@ mod geometric; mod gumbel; mod hypergeometric; mod inverse_gaussian; +mod multinomial; mod normal; mod normal_inverse_gaussian; mod pareto; diff --git a/rand_distr/src/multinomial.rs b/rand_distr/src/multinomial.rs new file mode 100644 index 00000000000..9986603fb2d --- /dev/null +++ b/rand_distr/src/multinomial.rs @@ -0,0 +1,253 
@@
+// Copyright 2018 Developers of the Rand project.
+// Copyright 2013 The Rust Project Developers.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+//! The multinomial distribution.
+
+use core::borrow::Borrow;
+
+use crate::{Binomial, Distribution};
+use num_traits::AsPrimitive;
+use rand::Rng;
+
+/// Error type returned from `MultinomialConst::new` and `MultinomialDyn::new`.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum Error {
+    /// At least one weight is negative or NaN.
+    ProbabilityNegative,
+    /// The sum of the weights is not finite (it overflowed to infinity).
+    SumOverflow,
+    /// The sum of the weights is zero (this includes an empty list of weights).
+    SumZero,
+}
+
+impl core::fmt::Display for Error {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.write_str(match self {
+            Error::ProbabilityNegative => "One of the weights is negative or Nan",
+            Error::SumOverflow => "Sum of weights overflows to inf",
+            Error::SumZero => "Sum of weights is zero",
+        })
+    }
+}
+
+/// Validates `weights` and returns their sum.
+///
+/// Shared by both constructors so that they reject exactly the same inputs.
+///
+/// # Errors
+///
+/// - [`Error::ProbabilityNegative`] if any weight is negative or NaN.
+/// - [`Error::SumOverflow`] if the sum of the weights is not finite.
+/// - [`Error::SumZero`] if the sum of the weights is zero.
+fn checked_weight_sum(weights: &[f64]) -> Result<f64, Error> {
+    // `w >= 0.0` is false for NaN, so NaN weights are rejected here as well.
+    if !weights.iter().all(|&w| w >= 0.0) {
+        return Err(Error::ProbabilityNegative);
+    }
+
+    let sum: f64 = weights.iter().sum();
+
+    if !sum.is_finite() {
+        return Err(Error::SumOverflow);
+    }
+
+    if sum == 0.0 {
+        return Err(Error::SumZero);
+    }
+
+    Ok(sum)
+}
+
+/// The [Multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution) distribution
+/// `Multinomial(n, w)` with the number of categories `K` known at compile time.
+///
+/// Sampling yields an array `[I; K]` holding the number of draws that fell into each category.
+#[derive(Debug, Clone, PartialEq)]
+pub struct MultinomialConst<const K: usize, I> {
+    /// number of draws
+    n: I,
+    /// weights for the multinomial distribution
+    weights: [f64; K],
+    /// sum of the weights
+    sum: f64,
+}
+
+impl<const K: usize, I> MultinomialConst<K, I> {
+    /// Constructs a new `MultinomialConst` distribution which samples from `K` categories.
+    ///
+    /// `n` is the number of draws.
+    ///
+    /// `weights` have to be non-negative. They do not need to sum to 1; they are
+    /// treated as relative weights.
+    ///
+    /// `K` has to be known at compile time.
+    ///
+    /// NOTE(review): `n` is not validated. For a signed `I`, a negative `n` is cast to
+    /// `u64` during sampling; confirm whether negative values should be rejected here.
+    ///
+    /// # Errors
+    ///
+    /// - [`Error::ProbabilityNegative`] if any weight is negative or NaN.
+    /// - [`Error::SumOverflow`] if the sum of the weights is not finite.
+    /// - [`Error::SumZero`] if the sum of the weights is zero.
+    pub fn new(n: I, weights: [f64; K]) -> Result<Self, Error>
+    where
+        I: num_traits::PrimInt,
+        u64: num_traits::AsPrimitive<I>,
+        I: num_traits::AsPrimitive<u64>,
+    {
+        let sum = checked_weight_sum(&weights)?;
+        Ok(Self { n, weights, sum })
+    }
+}
+
+#[cfg(feature = "alloc")]
+/// The [Multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution) distribution
+/// `Multinomial(n, w)` with the number of categories determined at runtime.
+///
+/// Sampling yields a `Vec<I>` with one count per weight.
+#[derive(Debug, Clone, PartialEq)]
+pub struct MultinomialDyn<I> {
+    /// number of draws
+    n: I,
+    /// weights for the multinomial distribution
+    weights: alloc::boxed::Box<[f64]>,
+    /// sum of the weights
+    sum: f64,
+}
+
+#[cfg(feature = "alloc")]
+impl<I> MultinomialDyn<I> {
+    /// Constructs a new `MultinomialDyn` distribution with one category per weight.
+    ///
+    /// `n` is the number of draws.
+    ///
+    /// `weights` have to be non-negative. They do not need to sum to 1; they are
+    /// treated as relative weights. The number of categories is the number of items
+    /// yielded by `weights`, so it can be chosen at runtime.
+    ///
+    /// NOTE(review): `n` is not validated. For a signed `I`, a negative `n` is cast to
+    /// `u64` during sampling; confirm whether negative values should be rejected here.
+    ///
+    /// # Errors
+    ///
+    /// - [`Error::ProbabilityNegative`] if any weight is negative or NaN.
+    /// - [`Error::SumOverflow`] if the sum of the weights is not finite.
+    /// - [`Error::SumZero`] if the sum of the weights is zero, which includes an
+    ///   empty `weights`.
+    pub fn new<B: Borrow<f64>>(
+        n: I,
+        weights: impl IntoIterator<Item = B>,
+    ) -> Result<Self, Error> {
+        let weights: alloc::boxed::Box<[f64]> = weights.into_iter().map(|x| *x.borrow()).collect();
+        let sum = checked_weight_sum(&weights)?;
+        Ok(Self { n, weights, sum })
+    }
+}
+
+/// Fills `result` with one multinomial sample of `n` draws over `weights`.
+///
+/// `sum` has to be the sum of the weights; it is passed in as a performance
+/// optimization so that it is not recomputed for every sample.
+///
+/// `result` is expected to have the same length as `weights`.
+fn sample<R: Rng + ?Sized, I>(rng: &mut R, n: I, weights: &[f64], sum: f64, result: &mut [I])
+where
+    I: num_traits::PrimInt,
+    u64: num_traits::AsPrimitive<I>,
+    I: num_traits::AsPrimitive<u64>,
+{
+    // This follows the binomial approach in "The computer generation of multinomial random
+    // variates" by Charles S. Davis: each category is drawn from a binomial distribution
+    // conditioned on the draws already assigned to the previous categories.
+
+    // The last category with a positive weight receives all remaining draws directly.
+    // Drawing it from a binomial would rely on `weight / (sum - sum_p)` being exactly 1.0,
+    // which floating-point rounding does not guarantee, so the counts could fail to add
+    // up to `n`.
+    let last_positive = weights.iter().rposition(|&w| w > 0.0);
+
+    let mut sum_p = 0.0;
+    let mut sum_n: I = 0.as_();
+
+    for (k, (count, &weight)) in result.iter_mut().zip(weights).enumerate() {
+        *count = match last_positive {
+            Some(last) if k == last => n - sum_n,
+            // `sum - sum_p` can round to zero while tiny weights remain; guard the division.
+            Some(last) if k < last && sum - sum_p > 0.0 => {
+                let prob = (weight / (sum - sum_p)).min(1.0);
+                let binomial = Binomial::new((n - sum_n).as_(), prob)
+                    .expect("We know that prob is between 0.0 and 1.0");
+                binomial.sample(rng).as_()
+            }
+            // Categories after the last positive weight (or with no probability mass left).
+            _ => 0.as_(),
+        };
+        sum_n = sum_n + *count;
+        sum_p += weight;
+    }
+}
+
+impl<const K: usize, I> Distribution<[I; K]> for MultinomialConst<K, I>
+where
+    I: num_traits::PrimInt,
+    u64: num_traits::AsPrimitive<I>,
+    I: num_traits::AsPrimitive<u64>,
+{
+    /// Returns the number of draws per category; the entries sum to `n`.
+    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [I; K] {
+        let mut result = [0.as_(); K];
+        sample(rng, self.n, &self.weights, self.sum, &mut result);
+        result
+    }
+}
+
+#[cfg(feature = "alloc")]
+impl<I> Distribution<alloc::vec::Vec<I>> for MultinomialDyn<I>
+where
+    I: num_traits::PrimInt,
+    u64: num_traits::AsPrimitive<I>,
+    I: num_traits::AsPrimitive<u64>,
+{
+    /// Returns the number of draws per category; the entries sum to `n`.
+    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> alloc::vec::Vec<I> {
+        let mut result = alloc::vec![0.as_(); self.weights.len()];
+        sample(rng, self.n, &self.weights, self.sum, &mut result);
+        result
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    /// The per-category counts of one sample add up to the number of draws.
+    #[test]
+    fn test_multinomial_const() {
+        let n: i32 = 1000;
+        let weights = [0.1, 0.2, 0.3, 0.4];
+        let mut rng = crate::test::rng(123);
+        let multinomial = MultinomialConst::new(n, weights).unwrap();
+        let sample = multinomial.sample(&mut rng);
+        assert_eq!(sample.iter().sum::<i32>(), n);
+    }
+
+    /// A single tiny positive weight still receives every draw.
+    #[test]
+    fn test_almost_zero_dist() {
+        let n: i32 = 1000;
+        let weights = [0.0, 0.0, 0.0, 0.000000001];
+        let multinomial = MultinomialConst::new(n, weights).unwrap();
+        let sample = multinomial.sample(&mut crate::test::rng(123));
+        assert_eq!(sample[3], n);
+    }
+
+    /// Categories with zero weight after the last positive one never receive draws.
+    #[test]
+    fn test_trailing_zero_weight() {
+        let n: i32 = 1000;
+        let weights = [0.5, 0.5, 0.0];
+        let multinomial = MultinomialConst::new(n, weights).unwrap();
+        let sample = multinomial.sample(&mut crate::test::rng(123));
+        assert_eq!(sample[2], 0);
+        assert_eq!(sample.iter().sum::<i32>(), n);
+    }
+
+    /// All-zero weights are rejected.
+    #[test]
+    fn test_zero_dist() {
+        let n: i32 = 1000;
+        let weights = [0.0, 0.0, 0.0, 0.0];
+        let multinomial = MultinomialConst::new(n, weights);
+        assert_eq!(multinomial, Err(Error::SumZero));
+    }
+
+    /// A negative weight is rejected.
+    #[test]
+    fn test_negative_dist() {
+        let n: i32 = 1000;
+        let weights = [0.1, 0.2, 0.3, -0.6];
+        let multinomial = MultinomialConst::new(n, weights);
+        assert_eq!(multinomial, Err(Error::ProbabilityNegative));
+    }
+
+    /// Weights whose sum overflows to infinity are rejected.
+    #[test]
+    fn test_overflow() {
+        let n: i32 = 1000;
+        let weights = [f64::MAX, f64::MAX, f64::MAX, f64::MAX];
+        let multinomial = MultinomialConst::new(n, weights);
+        assert_eq!(multinomial, Err(Error::SumOverflow));
+    }
+
+    /// The runtime-sized variant also distributes exactly `n` draws.
+    #[cfg(feature = "alloc")]
+    #[test]
+    fn test_multinomial_dyn() {
+        let n: i32 = 1000;
+        let weights = alloc::vec![0.1, 0.2, 0.3, 0.4];
+        let mut rng = crate::test::rng(123);
+        let multinomial = MultinomialDyn::new(n, weights).unwrap();
+        let sample = multinomial.sample(&mut rng);
+        assert_eq!(sample.iter().sum::<i32>(), n);
+    }
+}