Skip to content

Commit bebfa85

Browse files
committed
Add NaN handling in dyn scalar comparison kernels
1 parent e2bf158 commit bebfa85

File tree

1 file changed

+149
-22
lines changed

1 file changed

+149
-22
lines changed

arrow/src/compute/kernels/comparison.rs

Lines changed: 149 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ use crate::datatypes::{
3838
use crate::downcast_dictionary_array;
3939
use crate::error::{ArrowError, Result};
4040
use crate::util::bit_util;
41+
use num::ToPrimitive;
4142
use regex::Regex;
4243
use std::collections::HashMap;
4344

@@ -1328,7 +1329,11 @@ macro_rules! dyn_compare_utf8_scalar {
13281329
}
13291330

13301331
/// Perform `left == right` operation on an array and a numeric scalar
1331-
/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values
1332+
/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values.
1333+
///
1334+
/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
1335+
/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
1336+
/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
13321337
pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
13331338
where
13341339
T: num::ToPrimitive + std::fmt::Debug,
@@ -1342,7 +1347,11 @@ where
13421347
}
13431348

13441349
/// Perform `left < right` operation on an array and a numeric scalar
1345-
/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values
1350+
/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values.
1351+
///
1352+
/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
1353+
/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
1354+
/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
13461355
pub fn lt_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
13471356
where
13481357
T: num::ToPrimitive + std::fmt::Debug,
@@ -1356,7 +1365,11 @@ where
13561365
}
13571366

13581367
/// Perform `left <= right` operation on an array and a numeric scalar
1359-
/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values
1368+
/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values.
1369+
///
1370+
/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
1371+
/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
1372+
/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
13601373
pub fn lt_eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
13611374
where
13621375
T: num::ToPrimitive + std::fmt::Debug,
@@ -1370,7 +1383,11 @@ where
13701383
}
13711384

13721385
/// Perform `left > right` operation on an array and a numeric scalar
1373-
/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values
1386+
/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values.
1387+
///
1388+
/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
1389+
/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
1390+
/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
13741391
pub fn gt_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
13751392
where
13761393
T: num::ToPrimitive + std::fmt::Debug,
@@ -1384,7 +1401,11 @@ where
13841401
}
13851402

13861403
/// Perform `left >= right` operation on an array and a numeric scalar
1387-
/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values
1404+
/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values.
1405+
///
1406+
/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
1407+
/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
1408+
/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
13881409
pub fn gt_eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
13891410
where
13901411
T: num::ToPrimitive + std::fmt::Debug,
@@ -1398,7 +1419,11 @@ where
13981419
}
13991420

14001421
/// Perform `left != right` operation on an array and a numeric scalar
1401-
/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values
1422+
/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values.
1423+
///
1424+
/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
1425+
/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
1426+
/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
14021427
pub fn neq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
14031428
where
14041429
T: num::ToPrimitive + std::fmt::Debug,
@@ -3019,14 +3044,31 @@ where
30193044
}
30203045

30213046
/// Perform `left == right` operation on a [`PrimitiveArray`] and a scalar value.
3047+
///
3048+
/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
3049+
/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
3050+
/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
30223051
pub fn eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
30233052
where
30243053
T: ArrowNumericType,
3054+
T::Native: num::ToPrimitive + std::fmt::Debug,
30253055
{
30263056
#[cfg(feature = "simd")]
30273057
return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b);
30283058
#[cfg(not(feature = "simd"))]
3029-
return compare_op_scalar(left, |a| a == right);
3059+
match left.data_type() {
3060+
DataType::Float32 => {
3061+
let left = as_primitive_array::<Float32Type>(left);
3062+
let right = try_to_type!(right, to_f32)?;
3063+
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_eq())
3064+
}
3065+
DataType::Float64 => {
3066+
let left = as_primitive_array::<Float64Type>(left);
3067+
let right = try_to_type!(right, to_f64)?;
3068+
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_eq())
3069+
}
3070+
_ => compare_op_scalar(left, |a| a == right),
3071+
}
30303072
}
30313073

30323074
/// Applies an unary and infallible comparison function to a primitive array.
@@ -3050,14 +3092,31 @@ where
30503092
}
30513093

30523094
/// Perform `left != right` operation on a [`PrimitiveArray`] and a scalar value.
3095+
///
3096+
/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
3097+
/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
3098+
/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
30533099
pub fn neq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
30543100
where
30553101
T: ArrowNumericType,
3102+
T::Native: num::ToPrimitive + std::fmt::Debug,
30563103
{
30573104
#[cfg(feature = "simd")]
30583105
return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b);
30593106
#[cfg(not(feature = "simd"))]
3060-
return compare_op_scalar(left, |a| a != right);
3107+
match left.data_type() {
3108+
DataType::Float32 => {
3109+
let left = as_primitive_array::<Float32Type>(left);
3110+
let right = try_to_type!(right, to_f32)?;
3111+
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_ne())
3112+
}
3113+
DataType::Float64 => {
3114+
let left = as_primitive_array::<Float64Type>(left);
3115+
let right = try_to_type!(right, to_f64)?;
3116+
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_ne())
3117+
}
3118+
_ => compare_op_scalar(left, |a| a != right),
3119+
}
30613120
}
30623121

30633122
/// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
@@ -3074,14 +3133,31 @@ where
30743133

30753134
/// Perform `left < right` operation on a [`PrimitiveArray`] and a scalar value.
30763135
/// Null values are less than non-null values.
3136+
///
3137+
/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
3138+
/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
3139+
/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
30773140
pub fn lt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
30783141
where
30793142
T: ArrowNumericType,
3143+
T::Native: num::ToPrimitive + std::fmt::Debug,
30803144
{
30813145
#[cfg(feature = "simd")]
30823146
return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b);
30833147
#[cfg(not(feature = "simd"))]
3084-
return compare_op_scalar(left, |a| a < right);
3148+
match left.data_type() {
3149+
DataType::Float32 => {
3150+
let left = as_primitive_array::<Float32Type>(left);
3151+
let right = try_to_type!(right, to_f32)?;
3152+
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_lt())
3153+
}
3154+
DataType::Float64 => {
3155+
let left = as_primitive_array::<Float64Type>(left);
3156+
let right = try_to_type!(right, to_f64)?;
3157+
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_lt())
3158+
}
3159+
_ => compare_op_scalar(left, |a| a < right),
3160+
}
30853161
}
30863162

30873163
/// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
@@ -3101,14 +3177,31 @@ where
31013177

31023178
/// Perform `left <= right` operation on a [`PrimitiveArray`] and a scalar value.
31033179
/// Null values are less than non-null values.
3180+
///
3181+
/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
3182+
/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
3183+
/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
31043184
pub fn lt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
31053185
where
31063186
T: ArrowNumericType,
3187+
T::Native: num::ToPrimitive + std::fmt::Debug,
31073188
{
31083189
#[cfg(feature = "simd")]
31093190
return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b);
31103191
#[cfg(not(feature = "simd"))]
3111-
return compare_op_scalar(left, |a| a <= right);
3192+
match left.data_type() {
3193+
DataType::Float32 => {
3194+
let left = as_primitive_array::<Float32Type>(left);
3195+
let right = try_to_type!(right, to_f32)?;
3196+
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_le())
3197+
}
3198+
DataType::Float64 => {
3199+
let left = as_primitive_array::<Float64Type>(left);
3200+
let right = try_to_type!(right, to_f64)?;
3201+
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_le())
3202+
}
3203+
_ => compare_op_scalar(left, |a| a <= right),
3204+
}
31123205
}
31133206

31143207
/// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
@@ -3125,14 +3218,31 @@ where
31253218

31263219
/// Perform `left > right` operation on a [`PrimitiveArray`] and a scalar value.
31273220
/// Non-null values are greater than null values.
3221+
///
3222+
/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
3223+
/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
3224+
/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
31283225
pub fn gt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
31293226
where
31303227
T: ArrowNumericType,
3228+
T::Native: num::ToPrimitive + std::fmt::Debug,
31313229
{
31323230
#[cfg(feature = "simd")]
31333231
return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b);
31343232
#[cfg(not(feature = "simd"))]
3135-
return compare_op_scalar(left, |a| a > right);
3233+
match left.data_type() {
3234+
DataType::Float32 => {
3235+
let left = as_primitive_array::<Float32Type>(left);
3236+
let right = try_to_type!(right, to_f32)?;
3237+
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_gt())
3238+
}
3239+
DataType::Float64 => {
3240+
let left = as_primitive_array::<Float64Type>(left);
3241+
let right = try_to_type!(right, to_f64)?;
3242+
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_gt())
3243+
}
3244+
_ => compare_op_scalar(left, |a| a > right),
3245+
}
31363246
}
31373247

31383248
/// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
@@ -3152,14 +3262,31 @@ where
31523262

31533263
/// Perform `left >= right` operation on a [`PrimitiveArray`] and a scalar value.
31543264
/// Non-null values are greater than null values.
3265+
///
3266+
/// For floating values like f32 and f64, this comparison produces an ordering in accordance to
3267+
/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard.
3268+
/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
31553269
pub fn gt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
31563270
where
31573271
T: ArrowNumericType,
3272+
T::Native: num::ToPrimitive + std::fmt::Debug,
31583273
{
31593274
#[cfg(feature = "simd")]
31603275
return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b);
31613276
#[cfg(not(feature = "simd"))]
3162-
return compare_op_scalar(left, |a| a >= right);
3277+
match left.data_type() {
3278+
DataType::Float32 => {
3279+
let left = as_primitive_array::<Float32Type>(left);
3280+
let right = try_to_type!(right, to_f32)?;
3281+
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_ge())
3282+
}
3283+
DataType::Float64 => {
3284+
let left = as_primitive_array::<Float64Type>(left);
3285+
let right = try_to_type!(right, to_f64)?;
3286+
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_ge())
3287+
}
3288+
_ => compare_op_scalar(left, |a| a >= right),
3289+
}
31633290
}
31643291

31653292
/// Checks if a [`GenericListArray`] contains a value in the [`PrimitiveArray`]
@@ -5852,12 +5979,12 @@ mod tests {
58525979
.map(Some)
58535980
.collect();
58545981
let expected = BooleanArray::from(
5855-
vec![Some(false), Some(false), Some(false), Some(false), Some(false)],
5982+
vec![Some(true), Some(false), Some(false), Some(false), Some(false)],
58565983
);
58575984
assert_eq!(eq_dyn_scalar(&array, f32::NAN).unwrap(), expected);
58585985

58595986
let expected = BooleanArray::from(
5860-
vec![Some(true), Some(true), Some(true), Some(true), Some(true)],
5987+
vec![Some(false), Some(true), Some(true), Some(true), Some(true)],
58615988
);
58625989
assert_eq!(neq_dyn_scalar(&array, f32::NAN).unwrap(), expected);
58635990

@@ -5866,12 +5993,12 @@ mod tests {
58665993
.map(Some)
58675994
.collect();
58685995
let expected = BooleanArray::from(
5869-
vec![Some(false), Some(false), Some(false), Some(false), Some(false)],
5996+
vec![Some(true), Some(false), Some(false), Some(false), Some(false)],
58705997
);
58715998
assert_eq!(eq_dyn_scalar(&array, f64::NAN).unwrap(), expected);
58725999

58736000
let expected = BooleanArray::from(
5874-
vec![Some(true), Some(true), Some(true), Some(true), Some(true)],
6001+
vec![Some(false), Some(true), Some(true), Some(true), Some(true)],
58756002
);
58766003
assert_eq!(neq_dyn_scalar(&array, f64::NAN).unwrap(), expected);
58776004
}
@@ -5883,12 +6010,12 @@ mod tests {
58836010
.map(Some)
58846011
.collect();
58856012
let expected = BooleanArray::from(
5886-
vec![Some(false), Some(false), Some(false), Some(false), Some(false)],
6013+
vec![Some(false), Some(true), Some(true), Some(true), Some(true)],
58876014
);
58886015
assert_eq!(lt_dyn_scalar(&array, f32::NAN).unwrap(), expected);
58896016

58906017
let expected = BooleanArray::from(
5891-
vec![Some(false), Some(false), Some(false), Some(false), Some(false)],
6018+
vec![Some(true), Some(true), Some(true), Some(true), Some(true)],
58926019
);
58936020
assert_eq!(lt_eq_dyn_scalar(&array, f32::NAN).unwrap(), expected);
58946021

@@ -5897,12 +6024,12 @@ mod tests {
58976024
.map(Some)
58986025
.collect();
58996026
let expected = BooleanArray::from(
5900-
vec![Some(false), Some(false), Some(false), Some(false), Some(false)],
6027+
vec![Some(false), Some(true), Some(true), Some(true), Some(true)],
59016028
);
59026029
assert_eq!(lt_dyn_scalar(&array, f64::NAN).unwrap(), expected);
59036030

59046031
let expected = BooleanArray::from(
5905-
vec![Some(false), Some(false), Some(false), Some(false), Some(false)],
6032+
vec![Some(true), Some(true), Some(true), Some(true), Some(true)],
59066033
);
59076034
assert_eq!(lt_eq_dyn_scalar(&array, f64::NAN).unwrap(), expected);
59086035
}
@@ -5919,7 +6046,7 @@ mod tests {
59196046
assert_eq!(gt_dyn_scalar(&array, f32::NAN).unwrap(), expected);
59206047

59216048
let expected = BooleanArray::from(
5922-
vec![Some(false), Some(false), Some(false), Some(false), Some(false)],
6049+
vec![Some(true), Some(false), Some(false), Some(false), Some(false)],
59236050
);
59246051
assert_eq!(gt_eq_dyn_scalar(&array, f32::NAN).unwrap(), expected);
59256052

@@ -5933,7 +6060,7 @@ mod tests {
59336060
assert_eq!(gt_dyn_scalar(&array, f64::NAN).unwrap(), expected);
59346061

59356062
let expected = BooleanArray::from(
5936-
vec![Some(false), Some(false), Some(false), Some(false), Some(false)],
6063+
vec![Some(true), Some(false), Some(false), Some(false), Some(false)],
59376064
);
59386065
assert_eq!(gt_eq_dyn_scalar(&array, f64::NAN).unwrap(), expected);
59396066
}

0 commit comments

Comments
 (0)