Skip to content

Commit 6ab208c

Browse files
authored
Gate dyn comparison of dictionary arrays behind dyn_cmp_dict (#2597)
* Add dyn_cmp_dict feature flag
* Fix tests
* Clippy
1 parent 86446ea commit 6ab208c

File tree

3 files changed

+50
-7
lines changed

3 files changed

+50
-7
lines changed

.github/workflows/arrow.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ jobs:
5151
- name: Test
5252
run: |
5353
cargo test -p arrow
54-
- name: Test --features=force_validate,prettyprint,ipc_compression,ffi
54+
- name: Test --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict
5555
run: |
56-
cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi
56+
cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict
5757
- name: Test --features=nan_ordering
5858
run: |
5959
cargo test -p arrow --features "nan_ordering"
@@ -175,4 +175,4 @@ jobs:
175175
rustup component add clippy
176176
- name: Run clippy
177177
run: |
178-
cargo clippy -p arrow --features=prettyprint,csv,ipc,test_utils,ffi,ipc_compression --all-targets -- -D warnings
178+
cargo clippy -p arrow --features=prettyprint,csv,ipc,test_utils,ffi,ipc_compression,dyn_cmp_dict --all-targets -- -D warnings

arrow/Cargo.toml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ path = "src/lib.rs"
3838
bench = false
3939

4040
[target.'cfg(target_arch = "wasm32")'.dependencies]
41-
ahash = { version = "0.8", default-features = false, features=["compile-time-rng"] }
41+
ahash = { version = "0.8", default-features = false, features = ["compile-time-rng"] }
4242

4343
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
44-
ahash = { version = "0.8", default-features = false, features=["runtime-rng"] }
44+
ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] }
4545

4646
[dependencies]
4747
serde = { version = "1.0", default-features = false }
@@ -90,6 +90,9 @@ force_validate = []
9090
ffi = []
9191
# Enable NaN-ordering behavior on comparison kernels
9292
nan_ordering = []
93+
# Enable dyn-comparison of dictionary arrays with other arrays
94+
# Note: this does not impact comparison against scalars
95+
dyn_cmp_dict = []
9396

9497
[dev-dependencies]
9598
rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] }
@@ -102,7 +105,7 @@ tempfile = { version = "3", default-features = false }
102105
[[example]]
103106
name = "dynamic_types"
104107
required-features = ["prettyprint"]
105-
path="./examples/dynamic_types.rs"
108+
path = "./examples/dynamic_types.rs"
106109

107110
[[bench]]
108111
name = "aggregate_kernels"
@@ -144,7 +147,7 @@ required-features = ["test_utils"]
144147
[[bench]]
145148
name = "comparison_kernels"
146149
harness = false
147-
required-features = ["test_utils"]
150+
required-features = ["test_utils", "dyn_cmp_dict"]
148151

149152
[[bench]]
150153
name = "filter_kernels"

arrow/src/compute/kernels/comparison.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2089,6 +2089,7 @@ where
20892089
compare_op(left_array, right_array, op)
20902090
}
20912091

2092+
#[cfg(feature = "dyn_cmp_dict")]
20922093
macro_rules! typed_dict_non_dict_cmp {
20932094
($LEFT: expr, $RIGHT: expr, $LEFT_KEY_TYPE: expr, $RIGHT_TYPE: tt, $OP_BOOL: expr, $OP: expr) => {{
20942095
match $LEFT_KEY_TYPE {
@@ -2132,6 +2133,7 @@ macro_rules! typed_dict_non_dict_cmp {
21322133
}};
21332134
}
21342135

2136+
#[cfg(feature = "dyn_cmp_dict")]
21352137
macro_rules! typed_cmp_dict_non_dict {
21362138
($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr) => {{
21372139
match ($LEFT.data_type(), $RIGHT.data_type()) {
@@ -2182,6 +2184,16 @@ macro_rules! typed_cmp_dict_non_dict {
21822184
}};
21832185
}
21842186

2187+
#[cfg(not(feature = "dyn_cmp_dict"))]
2188+
macro_rules! typed_cmp_dict_non_dict {
2189+
($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr) => {{
2190+
Err(ArrowError::CastError(format!(
2191+
"Comparing dictionary array of type {} with array of type {} requires \"dyn_cmp_dict\" feature",
2192+
$LEFT.data_type(), $RIGHT.data_type()
2193+
)))
2194+
}}
2195+
}
2196+
21852197
macro_rules! typed_compares {
21862198
($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr, $OP_FLOAT: expr) => {{
21872199
match ($LEFT.data_type(), $RIGHT.data_type()) {
@@ -2298,6 +2310,7 @@ macro_rules! typed_compares {
22982310
}
22992311

23002312
/// Applies $OP to $LEFT and $RIGHT which are two dictionaries which have (the same) key type $KT
2313+
#[cfg(feature = "dyn_cmp_dict")]
23012314
macro_rules! typed_dict_cmp {
23022315
($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr, $KT: tt) => {{
23032316
match ($LEFT.value_type(), $RIGHT.value_type()) {
@@ -2430,6 +2443,7 @@ macro_rules! typed_dict_cmp {
24302443
}};
24312444
}
24322445

2446+
#[cfg(feature = "dyn_cmp_dict")]
24332447
macro_rules! typed_dict_compares {
24342448
// Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray`
24352449
($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr) => {{
@@ -2494,8 +2508,19 @@ macro_rules! typed_dict_compares {
24942508
}};
24952509
}
24962510

2511+
#[cfg(not(feature = "dyn_cmp_dict"))]
2512+
macro_rules! typed_dict_compares {
2513+
($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr) => {{
2514+
Err(ArrowError::CastError(format!(
2515+
"Comparing array of type {} with array of type {} requires \"dyn_cmp_dict\" feature",
2516+
$LEFT.data_type(), $RIGHT.data_type()
2517+
)))
2518+
}}
2519+
}
2520+
24972521
/// Perform given operation on `DictionaryArray` and `PrimitiveArray`. The value
24982522
/// type of `DictionaryArray` is same as `PrimitiveArray`'s type.
2523+
#[cfg(feature = "dyn_cmp_dict")]
24992524
fn cmp_dict_primitive<K, T, F>(
25002525
left: &DictionaryArray<K>,
25012526
right: &dyn Array,
@@ -2516,6 +2541,7 @@ where
25162541
/// Perform given operation on two `DictionaryArray`s which value type is
25172542
/// primitive type. Returns an error if the two arrays have different value
25182543
/// type
2544+
#[cfg(feature = "dyn_cmp_dict")]
25192545
pub fn cmp_dict<K, T, F>(
25202546
left: &DictionaryArray<K>,
25212547
right: &DictionaryArray<K>,
@@ -2535,6 +2561,7 @@ where
25352561

25362562
/// Perform the given operation on two `DictionaryArray`s which value type is
25372563
/// `DataType::Boolean`.
2564+
#[cfg(feature = "dyn_cmp_dict")]
25382565
pub fn cmp_dict_bool<K, F>(
25392566
left: &DictionaryArray<K>,
25402567
right: &DictionaryArray<K>,
@@ -2553,6 +2580,7 @@ where
25532580

25542581
/// Perform the given operation on two `DictionaryArray`s which value type is
25552582
/// `DataType::Utf8` or `DataType::LargeUtf8`.
2583+
#[cfg(feature = "dyn_cmp_dict")]
25562584
pub fn cmp_dict_utf8<K, OffsetSize: OffsetSizeTrait, F>(
25572585
left: &DictionaryArray<K>,
25582586
right: &DictionaryArray<K>,
@@ -2574,6 +2602,7 @@ where
25742602

25752603
/// Perform the given operation on two `DictionaryArray`s which value type is
25762604
/// `DataType::Binary` or `DataType::LargeBinary`.
2605+
#[cfg(feature = "dyn_cmp_dict")]
25772606
pub fn cmp_dict_binary<K, OffsetSize: OffsetSizeTrait, F>(
25782607
left: &DictionaryArray<K>,
25792608
right: &DictionaryArray<K>,
@@ -5476,6 +5505,7 @@ mod tests {
54765505
}
54775506

54785507
#[test]
5508+
#[cfg(feature = "dyn_cmp_dict")]
54795509
fn test_eq_dyn_neq_dyn_dictionary_i8_array() {
54805510
// Construct a value array
54815511
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
@@ -5496,6 +5526,7 @@ mod tests {
54965526
}
54975527

54985528
#[test]
5529+
#[cfg(feature = "dyn_cmp_dict")]
54995530
fn test_eq_dyn_neq_dyn_dictionary_u64_array() {
55005531
let values = UInt64Array::from_iter_values([10_u64, 11, 12, 13, 14, 15, 16, 17]);
55015532

@@ -5517,6 +5548,7 @@ mod tests {
55175548
}
55185549

55195550
#[test]
5551+
#[cfg(feature = "dyn_cmp_dict")]
55205552
fn test_eq_dyn_neq_dyn_dictionary_utf8_array() {
55215553
let test1 = vec!["a", "a", "b", "c"];
55225554
let test2 = vec!["a", "b", "b", "c"];
@@ -5544,6 +5576,7 @@ mod tests {
55445576
}
55455577

55465578
#[test]
5579+
#[cfg(feature = "dyn_cmp_dict")]
55475580
fn test_eq_dyn_neq_dyn_dictionary_binary_array() {
55485581
let values: BinaryArray = ["hello", "", "parquet"]
55495582
.into_iter()
@@ -5568,6 +5601,7 @@ mod tests {
55685601
}
55695602

55705603
#[test]
5604+
#[cfg(feature = "dyn_cmp_dict")]
55715605
fn test_eq_dyn_neq_dyn_dictionary_interval_array() {
55725606
let values = IntervalDayTimeArray::from(vec![1, 6, 10, 2, 3, 5]);
55735607

@@ -5589,6 +5623,7 @@ mod tests {
55895623
}
55905624

55915625
#[test]
5626+
#[cfg(feature = "dyn_cmp_dict")]
55925627
fn test_eq_dyn_neq_dyn_dictionary_date_array() {
55935628
let values = Date32Array::from(vec![1, 6, 10, 2, 3, 5]);
55945629

@@ -5610,6 +5645,7 @@ mod tests {
56105645
}
56115646

56125647
#[test]
5648+
#[cfg(feature = "dyn_cmp_dict")]
56135649
fn test_eq_dyn_neq_dyn_dictionary_bool_array() {
56145650
let values = BooleanArray::from(vec![true, false]);
56155651

@@ -5631,6 +5667,7 @@ mod tests {
56315667
}
56325668

56335669
#[test]
5670+
#[cfg(feature = "dyn_cmp_dict")]
56345671
fn test_lt_dyn_gt_dyn_dictionary_i8_array() {
56355672
// Construct a value array
56365673
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
@@ -5660,6 +5697,7 @@ mod tests {
56605697
}
56615698

56625699
#[test]
5700+
#[cfg(feature = "dyn_cmp_dict")]
56635701
fn test_lt_dyn_gt_dyn_dictionary_bool_array() {
56645702
let values = BooleanArray::from(vec![true, false]);
56655703

@@ -5702,6 +5740,7 @@ mod tests {
57025740
}
57035741

57045742
#[test]
5743+
#[cfg(feature = "dyn_cmp_dict")]
57055744
fn test_eq_dyn_neq_dyn_dictionary_i8_i8_array() {
57065745
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
57075746
let keys = Int8Array::from_iter_values([2_i8, 3, 4]);
@@ -5736,6 +5775,7 @@ mod tests {
57365775
}
57375776

57385777
#[test]
5778+
#[cfg(feature = "dyn_cmp_dict")]
57395779
fn test_lt_dyn_lt_eq_dyn_gt_dyn_gt_eq_dyn_dictionary_i8_i8_array() {
57405780
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
57415781
let keys = Int8Array::from_iter_values([2_i8, 3, 4]);

0 commit comments

Comments
 (0)