Skip to content

Commit 1780feb

Browse files
committed
Use trait
1 parent b53d4a1 commit 1780feb

File tree

2 files changed

+78
-95
lines changed

2 files changed

+78
-95
lines changed

arrow/src/compute/kernels/comparison.rs

Lines changed: 20 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,18 @@ use crate::array::*;
2727
use crate::buffer::{buffer_unary_not, Buffer, MutableBuffer};
2828
use crate::compute::util::combine_option_bitmap;
2929
use crate::datatypes::{
30-
ArrowNativeType, ArrowNumericType, DataType, Date32Type, Date64Type, Float32Type,
31-
Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType,
32-
IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Time32MillisecondType,
33-
Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit,
34-
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
35-
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
30+
native_op::ArrowNativeTypeOp, ArrowNativeType, ArrowNumericType, DataType,
31+
Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
32+
Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
33+
IntervalYearMonthType, Time32MillisecondType, Time32SecondType,
34+
Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
35+
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
36+
UInt32Type, UInt64Type, UInt8Type,
3637
};
3738
#[allow(unused_imports)]
3839
use crate::downcast_dictionary_array;
3940
use crate::error::{ArrowError, Result};
4041
use crate::util::bit_util;
41-
use num::ToPrimitive;
4242
use regex::Regex;
4343
use std::collections::HashMap;
4444

@@ -1336,7 +1336,7 @@ macro_rules! dyn_compare_utf8_scalar {
13361336
/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
13371337
pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
13381338
where
1339-
T: num::ToPrimitive + std::fmt::Debug,
1339+
T: ArrowNativeTypeOp,
13401340
{
13411341
match left.data_type() {
13421342
DataType::Dictionary(key_type, _value_type) => {
@@ -3048,24 +3048,12 @@ where
30483048
pub fn eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
30493049
where
30503050
T: ArrowNumericType,
3051-
T::Native: num::ToPrimitive + std::fmt::Debug,
3051+
T::Native: ArrowNativeTypeOp,
30523052
{
30533053
#[cfg(feature = "simd")]
30543054
return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b);
30553055
#[cfg(not(feature = "simd"))]
3056-
match left.data_type() {
3057-
DataType::Float32 => {
3058-
let left = as_primitive_array::<Float32Type>(left);
3059-
let right = try_to_type!(right, to_f32)?;
3060-
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_eq())
3061-
}
3062-
DataType::Float64 => {
3063-
let left = as_primitive_array::<Float64Type>(left);
3064-
let right = try_to_type!(right, to_f64)?;
3065-
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_eq())
3066-
}
3067-
_ => compare_op_scalar(left, |a| a == right),
3068-
}
3056+
return compare_op_scalar(left, |a| a.is_eq(right));
30693057
}
30703058

30713059
/// Applies an unary and infallible comparison function to a primitive array.
@@ -3096,24 +3084,12 @@ where
30963084
pub fn neq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
30973085
where
30983086
T: ArrowNumericType,
3099-
T::Native: num::ToPrimitive + std::fmt::Debug,
3087+
T::Native: ArrowNativeTypeOp,
31003088
{
31013089
#[cfg(feature = "simd")]
31023090
return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b);
31033091
#[cfg(not(feature = "simd"))]
3104-
match left.data_type() {
3105-
DataType::Float32 => {
3106-
let left = as_primitive_array::<Float32Type>(left);
3107-
let right = try_to_type!(right, to_f32)?;
3108-
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_ne())
3109-
}
3110-
DataType::Float64 => {
3111-
let left = as_primitive_array::<Float64Type>(left);
3112-
let right = try_to_type!(right, to_f64)?;
3113-
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_ne())
3114-
}
3115-
_ => compare_op_scalar(left, |a| a != right),
3116-
}
3092+
return compare_op_scalar(left, |a| a.is_ne(right));
31173093
}
31183094

31193095
/// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
@@ -3137,24 +3113,12 @@ where
31373113
pub fn lt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
31383114
where
31393115
T: ArrowNumericType,
3140-
T::Native: num::ToPrimitive + std::fmt::Debug,
3116+
T::Native: ArrowNativeTypeOp,
31413117
{
31423118
#[cfg(feature = "simd")]
31433119
return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b);
31443120
#[cfg(not(feature = "simd"))]
3145-
match left.data_type() {
3146-
DataType::Float32 => {
3147-
let left = as_primitive_array::<Float32Type>(left);
3148-
let right = try_to_type!(right, to_f32)?;
3149-
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_lt())
3150-
}
3151-
DataType::Float64 => {
3152-
let left = as_primitive_array::<Float64Type>(left);
3153-
let right = try_to_type!(right, to_f64)?;
3154-
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_lt())
3155-
}
3156-
_ => compare_op_scalar(left, |a| a < right),
3157-
}
3121+
return compare_op_scalar(left, |a| a.is_lt(right));
31583122
}
31593123

31603124
/// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
@@ -3181,24 +3145,12 @@ where
31813145
pub fn lt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
31823146
where
31833147
T: ArrowNumericType,
3184-
T::Native: num::ToPrimitive + std::fmt::Debug,
3148+
T::Native: ArrowNativeTypeOp,
31853149
{
31863150
#[cfg(feature = "simd")]
31873151
return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b);
31883152
#[cfg(not(feature = "simd"))]
3189-
match left.data_type() {
3190-
DataType::Float32 => {
3191-
let left = as_primitive_array::<Float32Type>(left);
3192-
let right = try_to_type!(right, to_f32)?;
3193-
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_le())
3194-
}
3195-
DataType::Float64 => {
3196-
let left = as_primitive_array::<Float64Type>(left);
3197-
let right = try_to_type!(right, to_f64)?;
3198-
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_le())
3199-
}
3200-
_ => compare_op_scalar(left, |a| a <= right),
3201-
}
3153+
return compare_op_scalar(left, |a| a.is_le(right));
32023154
}
32033155

32043156
/// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
@@ -3222,24 +3174,12 @@ where
32223174
pub fn gt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
32233175
where
32243176
T: ArrowNumericType,
3225-
T::Native: num::ToPrimitive + std::fmt::Debug,
3177+
T::Native: ArrowNativeTypeOp,
32263178
{
32273179
#[cfg(feature = "simd")]
32283180
return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b);
32293181
#[cfg(not(feature = "simd"))]
3230-
match left.data_type() {
3231-
DataType::Float32 => {
3232-
let left = as_primitive_array::<Float32Type>(left);
3233-
let right = try_to_type!(right, to_f32)?;
3234-
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_gt())
3235-
}
3236-
DataType::Float64 => {
3237-
let left = as_primitive_array::<Float64Type>(left);
3238-
let right = try_to_type!(right, to_f64)?;
3239-
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_gt())
3240-
}
3241-
_ => compare_op_scalar(left, |a| a > right),
3242-
}
3182+
return compare_op_scalar(left, |a| a.is_gt(right));
32433183
}
32443184

32453185
/// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
@@ -3266,24 +3206,12 @@ where
32663206
pub fn gt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
32673207
where
32683208
T: ArrowNumericType,
3269-
T::Native: num::ToPrimitive + std::fmt::Debug,
3209+
T::Native: ArrowNativeTypeOp,
32703210
{
32713211
#[cfg(feature = "simd")]
32723212
return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b);
32733213
#[cfg(not(feature = "simd"))]
3274-
match left.data_type() {
3275-
DataType::Float32 => {
3276-
let left = as_primitive_array::<Float32Type>(left);
3277-
let right = try_to_type!(right, to_f32)?;
3278-
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_ge())
3279-
}
3280-
DataType::Float64 => {
3281-
let left = as_primitive_array::<Float64Type>(left);
3282-
let right = try_to_type!(right, to_f64)?;
3283-
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_ge())
3284-
}
3285-
_ => compare_op_scalar(left, |a| a >= right),
3286-
}
3214+
return compare_op_scalar(left, |a| a.is_ge(right));
32873215
}
32883216

32893217
/// Checks if a [`GenericListArray`] contains a value in the [`PrimitiveArray`]

arrow/src/datatypes/native.rs

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ pub(crate) mod native_op {
4646
+ Div<Output = Self>
4747
+ Rem<Output = Self>
4848
+ Zero
49+
+ num::ToPrimitive
4950
{
5051
fn add_checked(self, rhs: Self) -> Result<Self> {
5152
Ok(self + rhs)
@@ -94,6 +95,30 @@ pub(crate) mod native_op {
9495
fn mod_wrapping(self, rhs: Self) -> Self {
9596
self % rhs
9697
}
98+
99+
fn is_eq(self, rhs: Self) -> bool {
100+
self == rhs
101+
}
102+
103+
fn is_ne(self, rhs: Self) -> bool {
104+
self != rhs
105+
}
106+
107+
fn is_lt(self, rhs: Self) -> bool {
108+
self < rhs
109+
}
110+
111+
fn is_le(self, rhs: Self) -> bool {
112+
self <= rhs
113+
}
114+
115+
fn is_gt(self, rhs: Self) -> bool {
116+
self > rhs
117+
}
118+
119+
fn is_ge(self, rhs: Self) -> bool {
120+
self >= rhs
121+
}
97122
}
98123
}
99124

@@ -186,6 +211,36 @@ native_type_op!(u16);
186211
native_type_op!(u32);
187212
native_type_op!(u64);
188213

189-
impl native_op::ArrowNativeTypeOp for f16 {}
190-
impl native_op::ArrowNativeTypeOp for f32 {}
191-
impl native_op::ArrowNativeTypeOp for f64 {}
214+
macro_rules! native_type_float_op {
215+
($t:tt) => {
216+
impl native_op::ArrowNativeTypeOp for $t {
217+
fn is_eq(self, rhs: Self) -> bool {
218+
self.total_cmp(&rhs).is_eq()
219+
}
220+
221+
fn is_ne(self, rhs: Self) -> bool {
222+
self.total_cmp(&rhs).is_ne()
223+
}
224+
225+
fn is_lt(self, rhs: Self) -> bool {
226+
self.total_cmp(&rhs).is_lt()
227+
}
228+
229+
fn is_le(self, rhs: Self) -> bool {
230+
self.total_cmp(&rhs).is_le()
231+
}
232+
233+
fn is_gt(self, rhs: Self) -> bool {
234+
self.total_cmp(&rhs).is_gt()
235+
}
236+
237+
fn is_ge(self, rhs: Self) -> bool {
238+
self.total_cmp(&rhs).is_ge()
239+
}
240+
}
241+
};
242+
}
243+
244+
native_type_float_op!(f16);
245+
native_type_float_op!(f32);
246+
native_type_float_op!(f64);

0 commit comments

Comments
 (0)