@@ -280,9 +280,10 @@ where
280
280
F : Fn ( usize ) -> X ,
281
281
X : Into < f64 > ,
282
282
{
283
- if length > ( :: core:: u32:: MAX as usize ) {
283
+ if length > ( core:: u32:: MAX as usize ) {
284
284
sample_efraimidis_spirakis ( rng, length, weight, amount)
285
285
} else {
286
+ assert ! ( amount <= core:: u32 :: MAX as usize ) ;
286
287
let amount = amount as u32 ;
287
288
let length = length as u32 ;
288
289
sample_efraimidis_spirakis ( rng, length, weight, amount)
@@ -310,6 +311,7 @@ where
310
311
F : Fn ( usize ) -> X ,
311
312
X : Into < f64 > ,
312
313
N : UInt ,
314
+ IndexVec : From < Vec < N > > ,
313
315
{
314
316
if amount == N :: zero ( ) {
315
317
return Ok ( IndexVec :: U32 ( Vec :: new ( ) ) ) ;
@@ -345,14 +347,17 @@ where
345
347
#[ cfg( feature = "nightly" ) ]
346
348
{
347
349
let mut candidates = Vec :: with_capacity ( length. as_usize ( ) ) ;
348
- for index in 0 ..length. as_usize ( ) {
349
- let weight = weight ( index) . into ( ) ;
350
+ let mut index = N :: zero ( ) ;
351
+ while index < length {
352
+ let weight = weight ( index. as_usize ( ) ) . into ( ) ;
350
353
if !( weight >= 0. ) {
351
354
return Err ( WeightedError :: InvalidWeight ) ;
352
355
}
353
356
354
357
let key = rng. gen :: < f64 > ( ) . powf ( 1.0 / weight) ;
355
- candidates. push ( Element { index, key } )
358
+ candidates. push ( Element { index, key } ) ;
359
+
360
+ index += N :: one ( ) ;
356
361
}
357
362
358
363
// Partially sort the array to find the `amount` elements with the greatest
@@ -362,7 +367,7 @@ where
362
367
let ( _, mid, greater)
363
368
= candidates. partition_at_index ( length. as_usize ( ) - amount. as_usize ( ) ) ;
364
369
365
- let mut result = Vec :: with_capacity ( amount. as_usize ( ) ) ;
370
+ let mut result: Vec < N > = Vec :: with_capacity ( amount. as_usize ( ) ) ;
366
371
result. push ( mid. index ) ;
367
372
for element in greater {
368
373
result. push ( element. index ) ;
@@ -380,17 +385,20 @@ where
380
385
// Partially sort the array such that the `amount` elements with the largest
381
386
// keys are first using a binary max heap.
382
387
let mut candidates = BinaryHeap :: with_capacity ( length. as_usize ( ) ) ;
383
- for index in 0 ..length. as_usize ( ) {
384
- let weight = weight ( index) . into ( ) ;
385
- if weight < 0.0 || weight. is_nan ( ) {
388
+ let mut index = N :: zero ( ) ;
389
+ while index < length {
390
+ let weight = weight ( index. as_usize ( ) ) . into ( ) ;
391
+ if !( weight >= 0. ) {
386
392
return Err ( WeightedError :: InvalidWeight ) ;
387
393
}
388
394
389
395
let key = rng. gen :: < f64 > ( ) . powf ( 1.0 / weight) ;
390
396
candidates. push ( Element { index, key } ) ;
397
+
398
+ index += N :: one ( ) ;
391
399
}
392
400
393
- let mut result = Vec :: with_capacity ( amount. as_usize ( ) ) ;
401
+ let mut result: Vec < N > = Vec :: with_capacity ( amount. as_usize ( ) ) ;
394
402
while result. len ( ) < amount. as_usize ( ) {
395
403
result. push ( candidates. pop ( ) . unwrap ( ) . index ) ;
396
404
}
@@ -462,8 +470,10 @@ where R: Rng + ?Sized {
462
470
IndexVec :: from ( indices)
463
471
}
464
472
465
- trait UInt : Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + core:: hash:: Hash {
473
+ trait UInt : Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform
474
+ + core:: hash:: Hash + core:: ops:: AddAssign {
466
475
fn zero ( ) -> Self ;
476
+ fn one ( ) -> Self ;
467
477
fn as_usize ( self ) -> usize ;
468
478
}
469
479
impl UInt for u32 {
@@ -472,6 +482,11 @@ impl UInt for u32 {
472
482
0
473
483
}
474
484
485
+ #[ inline]
486
+ fn one ( ) -> Self {
487
+ 1
488
+ }
489
+
475
490
#[ inline]
476
491
fn as_usize ( self ) -> usize {
477
492
self as usize
@@ -483,6 +498,11 @@ impl UInt for usize {
483
498
0
484
499
}
485
500
501
+ #[ inline]
502
+ fn one ( ) -> Self {
503
+ 1
504
+ }
505
+
486
506
#[ inline]
487
507
fn as_usize ( self ) -> usize {
488
508
self
@@ -602,6 +622,26 @@ mod test {
602
622
assert_eq ! ( v1, v2) ;
603
623
}
604
624
625
+ #[ test]
626
+ fn test_sample_weighted ( ) {
627
+ let seed_rng = crate :: test:: rng;
628
+ for & ( amount, len) in & [ ( 0 , 10 ) , ( 5 , 10 ) , ( 10 , 10 ) ] {
629
+ let v = sample_weighted ( & mut seed_rng ( 423 ) , len, |i| i as f64 , amount) . unwrap ( ) ;
630
+ match v {
631
+ IndexVec :: U32 ( mut indices) => {
632
+ assert_eq ! ( indices. len( ) , amount) ;
633
+ indices. sort ( ) ;
634
+ indices. dedup ( ) ;
635
+ assert_eq ! ( indices. len( ) , amount) ;
636
+ for & i in & indices {
637
+ assert ! ( ( i as usize ) < len) ;
638
+ }
639
+ } ,
640
+ IndexVec :: USize ( _) => panic ! ( "expected `IndexVec::U32`" ) ,
641
+ }
642
+ }
643
+ }
644
+
605
645
#[ test]
606
646
fn value_stability_sample ( ) {
607
647
let do_test = |length, amount, values : & [ u32 ] | {
0 commit comments