@@ -102,13 +102,15 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
102
102
let mut total_weight: X = iter. next ( ) . ok_or ( WeightedError :: NoItem ) ?. borrow ( ) . clone ( ) ;
103
103
104
104
let zero = <X as Default >:: default ( ) ;
105
- if total_weight < zero {
105
+ if ! ( total_weight >= zero) {
106
106
return Err ( WeightedError :: InvalidWeight ) ;
107
107
}
108
108
109
109
let mut weights = Vec :: < X > :: with_capacity ( iter. size_hint ( ) . 0 ) ;
110
110
for w in iter {
111
- if * w. borrow ( ) < zero {
111
+ // Note that `!(w >= x)` is not equivalent to `w < x` for partially
112
+ // ordered types due to NaNs which are equal to nothing.
113
+ if !( w. borrow ( ) >= & zero) {
112
114
return Err ( WeightedError :: InvalidWeight ) ;
113
115
}
114
116
weights. push ( total_weight. clone ( ) ) ;
@@ -158,7 +160,7 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
158
160
return Err ( WeightedError :: InvalidWeight ) ;
159
161
}
160
162
}
161
- if * w < zero {
163
+ if ! ( * w >= zero) {
162
164
return Err ( WeightedError :: InvalidWeight ) ;
163
165
}
164
166
if i >= self . cumulative_weights . len ( ) + 1 {
@@ -256,6 +258,30 @@ mod test {
256
258
assert_eq ! ( de_weighted_index. total_weight, weighted_index. total_weight) ;
257
259
}
258
260
261
+ #[ test]
262
+ fn test_accepting_nan ( ) {
263
+ assert_eq ! (
264
+ WeightedIndex :: new( & [ core:: f32 :: NAN , 0.5 ] ) . unwrap_err( ) ,
265
+ WeightedError :: InvalidWeight ,
266
+ ) ;
267
+ assert_eq ! (
268
+ WeightedIndex :: new( & [ core:: f32 :: NAN ] ) . unwrap_err( ) ,
269
+ WeightedError :: InvalidWeight ,
270
+ ) ;
271
+ assert_eq ! (
272
+ WeightedIndex :: new( & [ 0.5 , core:: f32 :: NAN ] ) . unwrap_err( ) ,
273
+ WeightedError :: InvalidWeight ,
274
+ ) ;
275
+
276
+ assert_eq ! (
277
+ WeightedIndex :: new( & [ 0.5 , 7.0 ] )
278
+ . unwrap( )
279
+ . update_weights( & [ ( 0 , & core:: f32 :: NAN ) ] )
280
+ . unwrap_err( ) ,
281
+ WeightedError :: InvalidWeight ,
282
+ )
283
+ }
284
+
259
285
260
286
#[ test]
261
287
#[ cfg_attr( miri, ignore) ] // Miri is too slow
@@ -399,8 +425,8 @@ pub enum WeightedError {
399
425
/// The provided weight collection contains no items.
400
426
NoItem ,
401
427
402
- /// A weight is either less than zero, greater than the supported maximum or
403
- /// otherwise invalid.
428
+ /// A weight is either less than zero, greater than the supported maximum,
429
+ /// NaN, or otherwise invalid.
404
430
InvalidWeight ,
405
431
406
432
/// All items in the provided weight collection are zero.
0 commit comments