Skip to content

Commit 0ea6f98

Browse files
committed
Use trait
1 parent bebfa85 commit 0ea6f98

File tree

2 files changed

+78
-95
lines changed

2 files changed

+78
-95
lines changed

arrow/src/compute/kernels/comparison.rs

Lines changed: 20 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,18 @@ use crate::array::*;
2727
use crate::buffer::{buffer_unary_not, Buffer, MutableBuffer};
2828
use crate::compute::util::combine_option_bitmap;
2929
use crate::datatypes::{
30-
ArrowNativeType, ArrowNumericType, DataType, Date32Type, Date64Type, Float32Type,
31-
Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType,
32-
IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Time32MillisecondType,
33-
Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit,
34-
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
35-
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
30+
native_op::ArrowNativeTypeOp, ArrowNativeType, ArrowNumericType, DataType,
31+
Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
32+
Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
33+
IntervalYearMonthType, Time32MillisecondType, Time32SecondType,
34+
Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
35+
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
36+
UInt32Type, UInt64Type, UInt8Type,
3637
};
3738
#[allow(unused_imports)]
3839
use crate::downcast_dictionary_array;
3940
use crate::error::{ArrowError, Result};
4041
use crate::util::bit_util;
41-
use num::ToPrimitive;
4242
use regex::Regex;
4343
use std::collections::HashMap;
4444

@@ -1336,7 +1336,7 @@ macro_rules! dyn_compare_utf8_scalar {
13361336
/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
13371337
pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
13381338
where
1339-
T: num::ToPrimitive + std::fmt::Debug,
1339+
T: ArrowNativeTypeOp,
13401340
{
13411341
match left.data_type() {
13421342
DataType::Dictionary(key_type, _value_type) => {
@@ -3051,24 +3051,12 @@ where
30513051
pub fn eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
30523052
where
30533053
T: ArrowNumericType,
3054-
T::Native: num::ToPrimitive + std::fmt::Debug,
3054+
T::Native: ArrowNativeTypeOp,
30553055
{
30563056
#[cfg(feature = "simd")]
30573057
return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b);
30583058
#[cfg(not(feature = "simd"))]
3059-
match left.data_type() {
3060-
DataType::Float32 => {
3061-
let left = as_primitive_array::<Float32Type>(left);
3062-
let right = try_to_type!(right, to_f32)?;
3063-
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_eq())
3064-
}
3065-
DataType::Float64 => {
3066-
let left = as_primitive_array::<Float64Type>(left);
3067-
let right = try_to_type!(right, to_f64)?;
3068-
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_eq())
3069-
}
3070-
_ => compare_op_scalar(left, |a| a == right),
3071-
}
3059+
return compare_op_scalar(left, |a| a.is_eq(right));
30723060
}
30733061

30743062
/// Applies an unary and infallible comparison function to a primitive array.
@@ -3099,24 +3087,12 @@ where
30993087
pub fn neq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
31003088
where
31013089
T: ArrowNumericType,
3102-
T::Native: num::ToPrimitive + std::fmt::Debug,
3090+
T::Native: ArrowNativeTypeOp,
31033091
{
31043092
#[cfg(feature = "simd")]
31053093
return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b);
31063094
#[cfg(not(feature = "simd"))]
3107-
match left.data_type() {
3108-
DataType::Float32 => {
3109-
let left = as_primitive_array::<Float32Type>(left);
3110-
let right = try_to_type!(right, to_f32)?;
3111-
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_ne())
3112-
}
3113-
DataType::Float64 => {
3114-
let left = as_primitive_array::<Float64Type>(left);
3115-
let right = try_to_type!(right, to_f64)?;
3116-
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_ne())
3117-
}
3118-
_ => compare_op_scalar(left, |a| a != right),
3119-
}
3095+
return compare_op_scalar(left, |a| a.is_ne(right));
31203096
}
31213097

31223098
/// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
@@ -3140,24 +3116,12 @@ where
31403116
pub fn lt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
31413117
where
31423118
T: ArrowNumericType,
3143-
T::Native: num::ToPrimitive + std::fmt::Debug,
3119+
T::Native: ArrowNativeTypeOp,
31443120
{
31453121
#[cfg(feature = "simd")]
31463122
return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b);
31473123
#[cfg(not(feature = "simd"))]
3148-
match left.data_type() {
3149-
DataType::Float32 => {
3150-
let left = as_primitive_array::<Float32Type>(left);
3151-
let right = try_to_type!(right, to_f32)?;
3152-
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_lt())
3153-
}
3154-
DataType::Float64 => {
3155-
let left = as_primitive_array::<Float64Type>(left);
3156-
let right = try_to_type!(right, to_f64)?;
3157-
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_lt())
3158-
}
3159-
_ => compare_op_scalar(left, |a| a < right),
3160-
}
3124+
return compare_op_scalar(left, |a| a.is_lt(right));
31613125
}
31623126

31633127
/// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
@@ -3184,24 +3148,12 @@ where
31843148
pub fn lt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
31853149
where
31863150
T: ArrowNumericType,
3187-
T::Native: num::ToPrimitive + std::fmt::Debug,
3151+
T::Native: ArrowNativeTypeOp,
31883152
{
31893153
#[cfg(feature = "simd")]
31903154
return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b);
31913155
#[cfg(not(feature = "simd"))]
3192-
match left.data_type() {
3193-
DataType::Float32 => {
3194-
let left = as_primitive_array::<Float32Type>(left);
3195-
let right = try_to_type!(right, to_f32)?;
3196-
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_le())
3197-
}
3198-
DataType::Float64 => {
3199-
let left = as_primitive_array::<Float64Type>(left);
3200-
let right = try_to_type!(right, to_f64)?;
3201-
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_le())
3202-
}
3203-
_ => compare_op_scalar(left, |a| a <= right),
3204-
}
3156+
return compare_op_scalar(left, |a| a.is_le(right));
32053157
}
32063158

32073159
/// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
@@ -3225,24 +3177,12 @@ where
32253177
pub fn gt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
32263178
where
32273179
T: ArrowNumericType,
3228-
T::Native: num::ToPrimitive + std::fmt::Debug,
3180+
T::Native: ArrowNativeTypeOp,
32293181
{
32303182
#[cfg(feature = "simd")]
32313183
return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b);
32323184
#[cfg(not(feature = "simd"))]
3233-
match left.data_type() {
3234-
DataType::Float32 => {
3235-
let left = as_primitive_array::<Float32Type>(left);
3236-
let right = try_to_type!(right, to_f32)?;
3237-
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_gt())
3238-
}
3239-
DataType::Float64 => {
3240-
let left = as_primitive_array::<Float64Type>(left);
3241-
let right = try_to_type!(right, to_f64)?;
3242-
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_gt())
3243-
}
3244-
_ => compare_op_scalar(left, |a| a > right),
3245-
}
3185+
return compare_op_scalar(left, |a| a.is_gt(right));
32463186
}
32473187

32483188
/// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
@@ -3269,24 +3209,12 @@ where
32693209
pub fn gt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
32703210
where
32713211
T: ArrowNumericType,
3272-
T::Native: num::ToPrimitive + std::fmt::Debug,
3212+
T::Native: ArrowNativeTypeOp,
32733213
{
32743214
#[cfg(feature = "simd")]
32753215
return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b);
32763216
#[cfg(not(feature = "simd"))]
3277-
match left.data_type() {
3278-
DataType::Float32 => {
3279-
let left = as_primitive_array::<Float32Type>(left);
3280-
let right = try_to_type!(right, to_f32)?;
3281-
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_ge())
3282-
}
3283-
DataType::Float64 => {
3284-
let left = as_primitive_array::<Float64Type>(left);
3285-
let right = try_to_type!(right, to_f64)?;
3286-
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_ge())
3287-
}
3288-
_ => compare_op_scalar(left, |a| a >= right),
3289-
}
3217+
return compare_op_scalar(left, |a| a.is_ge(right));
32903218
}
32913219

32923220
/// Checks if a [`GenericListArray`] contains a value in the [`PrimitiveArray`]

arrow/src/datatypes/native.rs

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ pub(crate) mod native_op {
4545
+ Mul<Output = Self>
4646
+ Div<Output = Self>
4747
+ Zero
48+
+ num::ToPrimitive
4849
{
4950
fn add_checked(self, rhs: Self) -> Result<Self> {
5051
Ok(self + rhs)
@@ -81,6 +82,30 @@ pub(crate) mod native_op {
8182
fn div_wrapping(self, rhs: Self) -> Self {
8283
self / rhs
8384
}
85+
86+
fn is_eq(self, rhs: Self) -> bool {
87+
self == rhs
88+
}
89+
90+
fn is_ne(self, rhs: Self) -> bool {
91+
self != rhs
92+
}
93+
94+
fn is_lt(self, rhs: Self) -> bool {
95+
self < rhs
96+
}
97+
98+
fn is_le(self, rhs: Self) -> bool {
99+
self <= rhs
100+
}
101+
102+
fn is_gt(self, rhs: Self) -> bool {
103+
self > rhs
104+
}
105+
106+
fn is_ge(self, rhs: Self) -> bool {
107+
self >= rhs
108+
}
84109
}
85110
}
86111

@@ -156,6 +181,36 @@ native_type_op!(u16);
156181
native_type_op!(u32);
157182
native_type_op!(u64);
158183

159-
impl native_op::ArrowNativeTypeOp for f16 {}
160-
impl native_op::ArrowNativeTypeOp for f32 {}
161-
impl native_op::ArrowNativeTypeOp for f64 {}
184+
macro_rules! native_type_float_op {
185+
($t:tt) => {
186+
impl native_op::ArrowNativeTypeOp for $t {
187+
fn is_eq(self, rhs: Self) -> bool {
188+
self.total_cmp(&rhs).is_eq()
189+
}
190+
191+
fn is_ne(self, rhs: Self) -> bool {
192+
self.total_cmp(&rhs).is_ne()
193+
}
194+
195+
fn is_lt(self, rhs: Self) -> bool {
196+
self.total_cmp(&rhs).is_lt()
197+
}
198+
199+
fn is_le(self, rhs: Self) -> bool {
200+
self.total_cmp(&rhs).is_le()
201+
}
202+
203+
fn is_gt(self, rhs: Self) -> bool {
204+
self.total_cmp(&rhs).is_gt()
205+
}
206+
207+
fn is_ge(self, rhs: Self) -> bool {
208+
self.total_cmp(&rhs).is_ge()
209+
}
210+
}
211+
};
212+
}
213+
214+
native_type_float_op!(f16);
215+
native_type_float_op!(f32);
216+
native_type_float_op!(f64);

0 commit comments

Comments
 (0)