Skip to content

Commit 0faff20

Browse files
authored
Merge pull request #518 from sicking/weighted
Implement weighted sampling API
2 parents cd16da4 + bbb037a commit 0faff20

File tree

5 files changed

+334
-26
lines changed

5 files changed

+334
-26
lines changed

benches/distributions.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,11 @@ distr_int!(distr_binomial, u64, Binomial::new(20, 0.7));
115115
distr_int!(distr_poisson, u64, Poisson::new(4.0));
116116
distr!(distr_bernoulli, bool, Bernoulli::new(0.18));
117117

118+
// Weighted
119+
distr_int!(distr_weighted_i8, usize, WeightedIndex::new(&[1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap());
120+
distr_int!(distr_weighted_u32, usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap());
121+
distr_int!(distr_weighted_f64, usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
122+
distr_int!(distr_weighted_large_set, usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap());
118123

119124
// construct and sample from a range
120125
macro_rules! gen_range_int {

src/distributions/mod.rs

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@
7373
//! numbers of the `char` type; in contrast [`Standard`] may sample any valid
7474
//! `char`.
7575
//!
76+
//! [`WeightedIndex`] can be used to do weighted sampling from a set of items,
77+
//! such as from an array.
7678
//!
7779
//! # Non-uniform probability distributions
7880
//!
@@ -167,12 +169,15 @@
167169
//! [`Uniform`]: struct.Uniform.html
168170
//! [`Uniform::new`]: struct.Uniform.html#method.new
169171
//! [`Uniform::new_inclusive`]: struct.Uniform.html#method.new_inclusive
172+
//! [`WeightedIndex`]: struct.WeightedIndex.html
170173
171174
use Rng;
172175

173176
#[doc(inline)] pub use self::other::Alphanumeric;
174177
#[doc(inline)] pub use self::uniform::Uniform;
175178
#[doc(inline)] pub use self::float::{OpenClosed01, Open01};
179+
#[cfg(feature="alloc")]
180+
#[doc(inline)] pub use self::weighted::WeightedIndex;
176181
#[cfg(feature="std")]
177182
#[doc(inline)] pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT};
178183
#[cfg(feature="std")]
@@ -192,6 +197,8 @@ use Rng;
192197
#[doc(inline)] pub use self::dirichlet::Dirichlet;
193198

194199
pub mod uniform;
200+
#[cfg(feature="alloc")]
201+
#[doc(hidden)] pub mod weighted;
195202
#[cfg(feature="std")]
196203
#[doc(hidden)] pub mod gamma;
197204
#[cfg(feature="std")]
@@ -373,6 +380,8 @@ pub struct Standard;
373380

374381

375382
/// A value with a particular weight for use with `WeightedChoice`.
383+
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
384+
#[allow(deprecated)]
376385
#[derive(Copy, Clone, Debug)]
377386
pub struct Weighted<T> {
378387
/// The numerical weight of this item
@@ -383,34 +392,19 @@ pub struct Weighted<T> {
383392

384393
/// A distribution that selects from a finite collection of weighted items.
385394
///
386-
/// Each item has an associated weight that influences how likely it
387-
/// is to be chosen: higher weight is more likely.
395+
/// Deprecated: use [`WeightedIndex`] instead.
388396
///
389-
/// The `Clone` restriction is a limitation of the `Distribution` trait.
390-
/// Note that `&T` is (cheaply) `Clone` for all `T`, as is `u32`, so one can
391-
/// store references or indices into another vector.
392-
///
393-
/// # Example
394-
///
395-
/// ```
396-
/// use rand::distributions::{Weighted, WeightedChoice, Distribution};
397-
///
398-
/// let mut items = vec!(Weighted { weight: 2, item: 'a' },
399-
/// Weighted { weight: 4, item: 'b' },
400-
/// Weighted { weight: 1, item: 'c' });
401-
/// let wc = WeightedChoice::new(&mut items);
402-
/// let mut rng = rand::thread_rng();
403-
/// for _ in 0..16 {
404-
/// // on average prints 'a' 4 times, 'b' 8 and 'c' twice.
405-
/// println!("{}", wc.sample(&mut rng));
406-
/// }
407-
/// ```
397+
/// [`WeightedIndex`]: struct.WeightedIndex.html
398+
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
399+
#[allow(deprecated)]
408400
#[derive(Debug)]
409401
pub struct WeightedChoice<'a, T:'a> {
410402
items: &'a mut [Weighted<T>],
411403
weight_range: Uniform<u32>,
412404
}
413405

406+
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
407+
#[allow(deprecated)]
414408
impl<'a, T: Clone> WeightedChoice<'a, T> {
415409
/// Create a new `WeightedChoice`.
416410
///
@@ -448,6 +442,8 @@ impl<'a, T: Clone> WeightedChoice<'a, T> {
448442
}
449443
}
450444

445+
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
446+
#[allow(deprecated)]
451447
impl<'a, T: Clone> Distribution<T> for WeightedChoice<'a, T> {
452448
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
453449
// we want to find the first element that has cumulative
@@ -557,9 +553,11 @@ fn ziggurat<R: Rng + ?Sized, P, Z>(
557553
#[cfg(test)]
558554
mod tests {
559555
use rngs::mock::StepRng;
556+
#[allow(deprecated)]
560557
use super::{WeightedChoice, Weighted, Distribution};
561558

562559
#[test]
560+
#[allow(deprecated)]
563561
fn test_weighted_choice() {
564562
// this makes assumptions about the internal implementation of
565563
// WeightedChoice. It may fail when the implementation in
@@ -619,6 +617,7 @@ mod tests {
619617
}
620618

621619
#[test]
620+
#[allow(deprecated)]
622621
fn test_weighted_clone_initialization() {
623622
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
624623
let clone = initial.clone();
@@ -627,6 +626,7 @@ mod tests {
627626
}
628627

629628
#[test] #[should_panic]
629+
#[allow(deprecated)]
630630
fn test_weighted_clone_change_weight() {
631631
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
632632
let mut clone = initial.clone();
@@ -635,6 +635,7 @@ mod tests {
635635
}
636636

637637
#[test] #[should_panic]
638+
#[allow(deprecated)]
638639
fn test_weighted_clone_change_item() {
639640
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
640641
let mut clone = initial.clone();
@@ -644,15 +645,18 @@ mod tests {
644645
}
645646

646647
#[test] #[should_panic]
648+
#[allow(deprecated)]
647649
fn test_weighted_choice_no_items() {
648650
WeightedChoice::<isize>::new(&mut []);
649651
}
650652
#[test] #[should_panic]
653+
#[allow(deprecated)]
651654
fn test_weighted_choice_zero_weight() {
652655
WeightedChoice::new(&mut [Weighted { weight: 0, item: 0},
653656
Weighted { weight: 0, item: 1}]);
654657
}
655658
#[test] #[should_panic]
659+
#[allow(deprecated)]
656660
fn test_weighted_choice_weight_overflows() {
657661
let x = ::core::u32::MAX / 2; // x + x + 2 is the overflow
658662
WeightedChoice::new(&mut [Weighted { weight: x, item: 0 },

src/distributions/weighted.rs

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
// Copyright 2018 The Rust Project Developers. See the COPYRIGHT
2+
// file at the top-level directory of this distribution and at
3+
// https://rust-lang.org/COPYRIGHT.
4+
//
5+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8+
// option. This file may not be copied, modified, or distributed
9+
// except according to those terms.
10+
11+
use Rng;
12+
use distributions::Distribution;
13+
use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
14+
use ::core::cmp::PartialOrd;
15+
use ::{Error, ErrorKind};
16+
17+
// Note that this whole module is only imported if feature="alloc" is enabled.
18+
#[cfg(not(feature="std"))] use alloc::Vec;
19+
20+
/// A distribution using weighted sampling to pick an index from a discrete set of items.
21+
///
22+
/// Sampling a `WeightedIndex` distribution returns the index of a randomly
23+
/// selected element from the iterator used when the `WeightedIndex` was
24+
/// created. The chance of a given element being picked is proportional to the
25+
/// value of the element. The weights can use any type `X` for which an
26+
/// implementation of [`Uniform<X>`] exists.
27+
///
28+
/// # Example
29+
///
30+
/// ```
31+
/// use rand::prelude::*;
32+
/// use rand::distributions::WeightedIndex;
33+
///
34+
/// let choices = ['a', 'b', 'c'];
35+
/// let weights = [2, 1, 1];
36+
/// let dist = WeightedIndex::new(&weights).unwrap();
37+
/// let mut rng = thread_rng();
38+
/// for _ in 0..100 {
39+
/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
40+
/// println!("{}", choices[dist.sample(&mut rng)]);
41+
/// }
42+
///
43+
/// let items = [('a', 0), ('b', 3), ('c', 7)];
44+
/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
45+
/// for _ in 0..100 {
46+
/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
47+
/// println!("{}", items[dist2.sample(&mut rng)].0);
48+
/// }
49+
/// ```
50+
#[derive(Debug, Clone)]
51+
pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
52+
cumulative_weights: Vec<X>,
53+
weight_distribution: X::Sampler,
54+
}
55+
56+
impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
57+
/// Creates a new `WeightedIndex` [`Distribution`] using the values
58+
/// in `weights`. The weights can use any type `X` for which an
59+
/// implementation of [`Uniform<X>`] exists.
60+
///
61+
/// Returns an error if the iterator is empty, if any weight is `< 0`, or
62+
/// if the total of all weights is 0.
63+
///
64+
/// [`Distribution`]: trait.Distribution.html
65+
/// [`Uniform<X>`]: struct.Uniform.html
66+
pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, Error>
67+
where I: IntoIterator,
68+
I::Item: SampleBorrow<X>,
69+
X: for<'a> ::core::ops::AddAssign<&'a X> +
70+
Clone +
71+
Default {
72+
let mut iter = weights.into_iter();
73+
let mut total_weight: X = iter.next()
74+
.ok_or(Error::new(ErrorKind::Unexpected, "Empty iterator in WeightedIndex::new"))?
75+
.borrow()
76+
.clone();
77+
78+
let zero = <X as Default>::default();
79+
if total_weight < zero {
80+
return Err(Error::new(ErrorKind::Unexpected, "Negative weight in WeightedIndex::new"));
81+
}
82+
83+
let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
84+
for w in iter {
85+
if *w.borrow() < zero {
86+
return Err(Error::new(ErrorKind::Unexpected, "Negative weight in WeightedIndex::new"));
87+
}
88+
weights.push(total_weight.clone());
89+
total_weight += w.borrow();
90+
}
91+
92+
if total_weight == zero {
93+
return Err(Error::new(ErrorKind::Unexpected, "Total weight is zero in WeightedIndex::new"));
94+
}
95+
let distr = X::Sampler::new(zero, total_weight);
96+
97+
Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr })
98+
}
99+
}
100+
101+
impl<X> Distribution<usize> for WeightedIndex<X> where
102+
X: SampleUniform + PartialOrd {
103+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
104+
use ::core::cmp::Ordering;
105+
let chosen_weight = self.weight_distribution.sample(rng);
106+
// Find the first item whose cumulative weight is *higher* than the chosen weight.
107+
self.cumulative_weights.binary_search_by(
108+
|w| if *w <= chosen_weight { Ordering::Less } else { Ordering::Greater }).unwrap_err()
109+
}
110+
}
111+
112+
#[cfg(test)]
113+
mod test {
114+
use super::*;
115+
116+
#[test]
117+
fn test_weightedindex() {
118+
let mut r = ::test::rng(700);
119+
const N_REPS: u32 = 5000;
120+
let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
121+
let total_weight = weights.iter().sum::<u32>() as f32;
122+
123+
let verify = |result: [i32; 14]| {
124+
for (i, count) in result.iter().enumerate() {
125+
let exp = (weights[i] * N_REPS) as f32 / total_weight;
126+
let mut err = (*count as f32 - exp).abs();
127+
if err != 0.0 {
128+
err /= exp;
129+
}
130+
assert!(err <= 0.25);
131+
}
132+
};
133+
134+
// WeightedIndex from vec
135+
let mut chosen = [0i32; 14];
136+
let distr = WeightedIndex::new(weights.to_vec()).unwrap();
137+
for _ in 0..N_REPS {
138+
chosen[distr.sample(&mut r)] += 1;
139+
}
140+
verify(chosen);
141+
142+
// WeightedIndex from slice
143+
chosen = [0i32; 14];
144+
let distr = WeightedIndex::new(&weights[..]).unwrap();
145+
for _ in 0..N_REPS {
146+
chosen[distr.sample(&mut r)] += 1;
147+
}
148+
verify(chosen);
149+
150+
// WeightedIndex from iterator
151+
chosen = [0i32; 14];
152+
let distr = WeightedIndex::new(weights.iter()).unwrap();
153+
for _ in 0..N_REPS {
154+
chosen[distr.sample(&mut r)] += 1;
155+
}
156+
verify(chosen);
157+
158+
for _ in 0..5 {
159+
assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1);
160+
assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0);
161+
assert_eq!(WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r), 4);
162+
}
163+
164+
assert!(WeightedIndex::new(&[10][0..0]).is_err());
165+
assert!(WeightedIndex::new(&[0]).is_err());
166+
assert!(WeightedIndex::new(&[10, 20, -1, 30]).is_err());
167+
assert!(WeightedIndex::new(&[-10, 20, 1, 30]).is_err());
168+
assert!(WeightedIndex::new(&[-10]).is_err());
169+
}
170+
}

src/lib.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,6 @@
134134
//!
135135
//! For more slice/sequence related functionality, look in the [`seq` module].
136136
//!
137-
//! There is also [`distributions::WeightedChoice`], which can be used to pick
138-
//! elements at random with some probability. But it does not work well at the
139-
//! moment and is going through a redesign.
140-
//!
141137
//!
142138
//! # Error handling
143139
//!
@@ -187,7 +183,6 @@
187183
//!
188184
//!
189185
//! [`distributions` module]: distributions/index.html
190-
//! [`distributions::WeightedChoice`]: distributions/struct.WeightedChoice.html
191186
//! [`EntropyRng`]: rngs/struct.EntropyRng.html
192187
//! [`Error`]: struct.Error.html
193188
//! [`gen_range`]: trait.Rng.html#method.gen_range

0 commit comments

Comments
 (0)