Skip to content

Commit 9a3f2f4

Browse files
committed
sample_weighted: Make sure the correct IndexVec is generated
Also add some tests.
1 parent 29e78c7 commit 9a3f2f4

File tree

1 file changed

+50
-10
lines changed

1 file changed

+50
-10
lines changed

src/seq/index.rs

Lines changed: 50 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -280,9 +280,10 @@ where
280280
F: Fn(usize) -> X,
281281
X: Into<f64>,
282282
{
283-
if length > (::core::u32::MAX as usize) {
283+
if length > (core::u32::MAX as usize) {
284284
sample_efraimidis_spirakis(rng, length, weight, amount)
285285
} else {
286+
assert!(amount <= core::u32::MAX as usize);
286287
let amount = amount as u32;
287288
let length = length as u32;
288289
sample_efraimidis_spirakis(rng, length, weight, amount)
@@ -310,6 +311,7 @@ where
310311
F: Fn(usize) -> X,
311312
X: Into<f64>,
312313
N: UInt,
314+
IndexVec: From<Vec<N>>,
313315
{
314316
if amount == N::zero() {
315317
return Ok(IndexVec::U32(Vec::new()));
@@ -345,14 +347,17 @@ where
345347
#[cfg(feature = "nightly")]
346348
{
347349
let mut candidates = Vec::with_capacity(length.as_usize());
348-
for index in 0..length.as_usize() {
349-
let weight = weight(index).into();
350+
let mut index = N::zero();
351+
while index < length {
352+
let weight = weight(index.as_usize()).into();
350353
if !(weight >= 0.) {
351354
return Err(WeightedError::InvalidWeight);
352355
}
353356

354357
let key = rng.gen::<f64>().powf(1.0 / weight);
355-
candidates.push(Element { index, key })
358+
candidates.push(Element { index, key });
359+
360+
index += N::one();
356361
}
357362

358363
// Partially sort the array to find the `amount` elements with the greatest
@@ -362,7 +367,7 @@ where
362367
let (_, mid, greater)
363368
= candidates.partition_at_index(length.as_usize() - amount.as_usize());
364369

365-
let mut result = Vec::with_capacity(amount.as_usize());
370+
let mut result: Vec<N> = Vec::with_capacity(amount.as_usize());
366371
result.push(mid.index);
367372
for element in greater {
368373
result.push(element.index);
@@ -380,17 +385,20 @@ where
380385
// Partially sort the array such that the `amount` elements with the largest
381386
// keys are first using a binary max heap.
382387
let mut candidates = BinaryHeap::with_capacity(length.as_usize());
383-
for index in 0..length.as_usize() {
384-
let weight = weight(index).into();
385-
if weight < 0.0 || weight.is_nan() {
388+
let mut index = N::zero();
389+
while index < length {
390+
let weight = weight(index.as_usize()).into();
391+
if !(weight >= 0.) {
386392
return Err(WeightedError::InvalidWeight);
387393
}
388394

389395
let key = rng.gen::<f64>().powf(1.0 / weight);
390396
candidates.push(Element { index, key });
397+
398+
index += N::one();
391399
}
392400

393-
let mut result = Vec::with_capacity(amount.as_usize());
401+
let mut result: Vec<N> = Vec::with_capacity(amount.as_usize());
394402
while result.len() < amount.as_usize() {
395403
result.push(candidates.pop().unwrap().index);
396404
}
@@ -462,8 +470,10 @@ where R: Rng + ?Sized {
462470
IndexVec::from(indices)
463471
}
464472

465-
trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + core::hash::Hash {
473+
trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform
474+
+ core::hash::Hash + core::ops::AddAssign {
466475
fn zero() -> Self;
476+
fn one() -> Self;
467477
fn as_usize(self) -> usize;
468478
}
469479
impl UInt for u32 {
@@ -472,6 +482,11 @@ impl UInt for u32 {
472482
0
473483
}
474484

485+
#[inline]
486+
fn one() -> Self {
487+
1
488+
}
489+
475490
#[inline]
476491
fn as_usize(self) -> usize {
477492
self as usize
@@ -483,6 +498,11 @@ impl UInt for usize {
483498
0
484499
}
485500

501+
#[inline]
502+
fn one() -> Self {
503+
1
504+
}
505+
486506
#[inline]
487507
fn as_usize(self) -> usize {
488508
self
@@ -602,6 +622,26 @@ mod test {
602622
assert_eq!(v1, v2);
603623
}
604624

625+
#[test]
626+
fn test_sample_weighted() {
627+
let seed_rng = crate::test::rng;
628+
for &(amount, len) in &[(0, 10), (5, 10), (10, 10)] {
629+
let v = sample_weighted(&mut seed_rng(423), len, |i| i as f64, amount).unwrap();
630+
match v {
631+
IndexVec::U32(mut indices) => {
632+
assert_eq!(indices.len(), amount);
633+
indices.sort();
634+
indices.dedup();
635+
assert_eq!(indices.len(), amount);
636+
for &i in &indices {
637+
assert!((i as usize) < len);
638+
}
639+
},
640+
IndexVec::USize(_) => panic!("expected `IndexVec::U32`"),
641+
}
642+
}
643+
}
644+
605645
#[test]
606646
fn value_stability_sample() {
607647
let do_test = |length, amount, values: &[u32]| {

0 commit comments

Comments
 (0)