Skip to content

Commit cc2e7dd

Browse files
wschellavks
authored andcommitted
Handle NaN in WeightedIndex with error instead of panic
1 parent 39a37f0 commit cc2e7dd

File tree

1 file changed

+31
-5
lines changed

1 file changed

+31
-5
lines changed

src/distributions/weighted_index.rs

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,15 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
102102
let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();
103103

104104
let zero = <X as Default>::default();
105-
if total_weight < zero {
105+
if !(total_weight >= zero) {
106106
return Err(WeightedError::InvalidWeight);
107107
}
108108

109109
let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
110110
for w in iter {
111-
if *w.borrow() < zero {
111+
// Note that `!(w >= x)` is not equivalent to `w < x` for partially
112+
// ordered types due to NaNs which are equal to nothing.
113+
if !(w.borrow() >= &zero) {
112114
return Err(WeightedError::InvalidWeight);
113115
}
114116
weights.push(total_weight.clone());
@@ -158,7 +160,7 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
158160
return Err(WeightedError::InvalidWeight);
159161
}
160162
}
161-
if *w < zero {
163+
if !(*w >= zero) {
162164
return Err(WeightedError::InvalidWeight);
163165
}
164166
if i >= self.cumulative_weights.len() + 1 {
@@ -256,6 +258,30 @@ mod test {
256258
assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight);
257259
}
258260

261+
#[test]
262+
fn test_accepting_nan(){
263+
assert_eq!(
264+
WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(),
265+
WeightedError::InvalidWeight,
266+
);
267+
assert_eq!(
268+
WeightedIndex::new(&[core::f32::NAN]).unwrap_err(),
269+
WeightedError::InvalidWeight,
270+
);
271+
assert_eq!(
272+
WeightedIndex::new(&[0.5, core::f32::NAN]).unwrap_err(),
273+
WeightedError::InvalidWeight,
274+
);
275+
276+
assert_eq!(
277+
WeightedIndex::new(&[0.5, 7.0])
278+
.unwrap()
279+
.update_weights(&[(0, &core::f32::NAN)])
280+
.unwrap_err(),
281+
WeightedError::InvalidWeight,
282+
)
283+
}
284+
259285

260286
#[test]
261287
#[cfg_attr(miri, ignore)] // Miri is too slow
@@ -399,8 +425,8 @@ pub enum WeightedError {
399425
/// The provided weight collection contains no items.
400426
NoItem,
401427

402-
/// A weight is either less than zero, greater than the supported maximum or
403-
/// otherwise invalid.
428+
/// A weight is either less than zero, greater than the supported maximum,
429+
/// NaN, or otherwise invalid.
404430
InvalidWeight,
405431

406432
/// All items in the provided weight collection are zero.

0 commit comments

Comments
 (0)