Skip to content

Commit ff116c3

Browse files
authored
Support filter for List (#11091)
* support basic list cmp Signed-off-by: jayzhan211 <[email protected]> * add more ops Signed-off-by: jayzhan211 <[email protected]> * add distinct Signed-off-by: jayzhan211 <[email protected]> * nested Signed-off-by: jayzhan211 <[email protected]> * add comment Signed-off-by: jayzhan211 <[email protected]> --------- Signed-off-by: jayzhan211 <[email protected]>
1 parent 2d1e850 commit ff116c3

File tree

7 files changed

+312
-67
lines changed

7 files changed

+312
-67
lines changed
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::BooleanArray;
19+
use arrow::array::{make_comparator, ArrayRef, Datum};
20+
use arrow::buffer::NullBuffer;
21+
use arrow::compute::SortOptions;
22+
use arrow::error::ArrowError;
23+
use datafusion_common::internal_err;
24+
use datafusion_common::{Result, ScalarValue};
25+
use datafusion_expr::{ColumnarValue, Operator};
26+
use std::sync::Arc;
27+
28+
/// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs`
29+
///
30+
/// This maps arrow-rs' [`Datum`] kernels to DataFusion's [`ColumnarValue`] abstraction
31+
pub fn apply(
32+
lhs: &ColumnarValue,
33+
rhs: &ColumnarValue,
34+
f: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
35+
) -> Result<ColumnarValue> {
36+
match (&lhs, &rhs) {
37+
(ColumnarValue::Array(left), ColumnarValue::Array(right)) => {
38+
Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?))
39+
}
40+
(ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok(
41+
ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?),
42+
),
43+
(ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok(
44+
ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?),
45+
),
46+
(ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => {
47+
let array = f(&left.to_scalar()?, &right.to_scalar()?)?;
48+
let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?;
49+
Ok(ColumnarValue::Scalar(scalar))
50+
}
51+
}
52+
}
53+
54+
/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs`
55+
pub fn apply_cmp(
56+
lhs: &ColumnarValue,
57+
rhs: &ColumnarValue,
58+
f: impl Fn(&dyn Datum, &dyn Datum) -> Result<BooleanArray, ArrowError>,
59+
) -> Result<ColumnarValue> {
60+
apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?)))
61+
}
62+
63+
/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` for nested type like
64+
/// List, FixedSizeList, LargeList, Struct, Union, Map, or a dictionary of a nested type
65+
pub fn apply_cmp_for_nested(
66+
op: Operator,
67+
lhs: &ColumnarValue,
68+
rhs: &ColumnarValue,
69+
) -> Result<ColumnarValue> {
70+
if matches!(
71+
op,
72+
Operator::Eq
73+
| Operator::NotEq
74+
| Operator::Lt
75+
| Operator::Gt
76+
| Operator::LtEq
77+
| Operator::GtEq
78+
| Operator::IsDistinctFrom
79+
| Operator::IsNotDistinctFrom
80+
) {
81+
apply(lhs, rhs, |l, r| {
82+
Ok(Arc::new(compare_op_for_nested(op, l, r)?))
83+
})
84+
} else {
85+
internal_err!("invalid operator for nested")
86+
}
87+
}
88+
89+
/// Compare on nested type List, Struct, and so on
90+
fn compare_op_for_nested(
91+
op: Operator,
92+
lhs: &dyn Datum,
93+
rhs: &dyn Datum,
94+
) -> Result<BooleanArray> {
95+
let (l, is_l_scalar) = lhs.get();
96+
let (r, is_r_scalar) = rhs.get();
97+
let l_len = l.len();
98+
let r_len = r.len();
99+
100+
if l_len != r_len && !is_l_scalar && !is_r_scalar {
101+
return internal_err!("len mismatch");
102+
}
103+
104+
let len = match is_l_scalar {
105+
true => r_len,
106+
false => l_len,
107+
};
108+
109+
// fast path, if compare with one null and operator is not 'distinct', then we can return null array directly
110+
if !matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom)
111+
&& (is_l_scalar && l.null_count() == 1 || is_r_scalar && r.null_count() == 1)
112+
{
113+
return Ok(BooleanArray::new_null(len));
114+
}
115+
116+
// TODO: make SortOptions configurable
117+
// we choose the default behaviour from arrow-rs which has null-first that follow spark's behaviour
118+
let cmp = make_comparator(l, r, SortOptions::default())?;
119+
120+
let cmp_with_op = |i, j| match op {
121+
Operator::Eq | Operator::IsNotDistinctFrom => cmp(i, j).is_eq(),
122+
Operator::Lt => cmp(i, j).is_lt(),
123+
Operator::Gt => cmp(i, j).is_gt(),
124+
Operator::LtEq => !cmp(i, j).is_gt(),
125+
Operator::GtEq => !cmp(i, j).is_lt(),
126+
Operator::NotEq | Operator::IsDistinctFrom => !cmp(i, j).is_eq(),
127+
_ => unreachable!("unexpected operator found"),
128+
};
129+
130+
let values = match (is_l_scalar, is_r_scalar) {
131+
(false, false) => (0..len).map(|i| cmp_with_op(i, i)).collect(),
132+
(true, false) => (0..len).map(|i| cmp_with_op(0, i)).collect(),
133+
(false, true) => (0..len).map(|i| cmp_with_op(i, 0)).collect(),
134+
(true, true) => std::iter::once(cmp_with_op(0, 0)).collect(),
135+
};
136+
137+
// Distinct understand how to compare with NULL
138+
// i.e NULL is distinct from NULL -> false
139+
if matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) {
140+
Ok(BooleanArray::new(values, None))
141+
} else {
142+
// If one of the side is NULL, we returns NULL
143+
// i.e. NULL eq NULL -> NULL
144+
let nulls = NullBuffer::union(l.nulls(), r.nulls());
145+
Ok(BooleanArray::new(values, nulls))
146+
}
147+
}
148+
149+
#[cfg(test)]
mod tests {
    use super::compare_op_for_nested;
    use arrow::{array::ListArray, datatypes::Int32Type};
    use datafusion_expr::Operator;

    /// Builds a four-row list array whose second row is NULL and whose third
    /// row contains a NULL element.
    fn sample_lists() -> ListArray {
        let data = vec![
            Some(vec![Some(0), Some(1), Some(2)]),
            None,
            Some(vec![Some(3), None, Some(5)]),
            Some(vec![Some(6), Some(7)]),
        ];
        ListArray::from_iter_primitive::<Int32Type, _, _>(data)
    }

    /// Runs `op` over two identical list arrays and collects the result.
    fn compare_identical(op: Operator) -> Vec<Option<bool>> {
        let a = sample_lists();
        let b = sample_lists();
        compare_op_for_nested(op, &a, &b)
            .unwrap()
            .iter()
            .collect()
    }

    #[test]
    fn eq_on_identical_lists_propagates_null_rows() {
        // Row 1 is NULL on both sides, so `=` yields NULL for that row.
        assert_eq!(
            compare_identical(Operator::Eq),
            vec![Some(true), None, Some(true), Some(true)]
        );
    }

    #[test]
    fn distinct_operators_never_return_null() {
        // NULL is not distinct from NULL, so every row is a definite answer.
        assert_eq!(
            compare_identical(Operator::IsNotDistinctFrom),
            vec![Some(true); 4]
        );
        assert_eq!(
            compare_identical(Operator::IsDistinctFrom),
            vec![Some(false); 4]
        );
    }
}

datafusion/physical-expr-common/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
pub mod aggregate;
1919
pub mod binary_map;
20+
pub mod datum;
2021
pub mod expressions;
2122
pub mod physical_expr;
2223
pub mod sort_expr;

datafusion/physical-expr/src/expressions/binary.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ mod kernels;
2020
use std::hash::{Hash, Hasher};
2121
use std::{any::Any, sync::Arc};
2222

23-
use crate::expressions::datum::{apply, apply_cmp};
2423
use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
2524
use crate::physical_expr::down_cast_any_ref;
2625
use crate::PhysicalExpr;
@@ -40,6 +39,7 @@ use datafusion_expr::interval_arithmetic::{apply_operator, Interval};
4039
use datafusion_expr::sort_properties::ExprProperties;
4140
use datafusion_expr::type_coercion::binary::get_result_type;
4241
use datafusion_expr::{ColumnarValue, Operator};
42+
use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested};
4343

4444
use kernels::{
4545
bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar,
@@ -265,6 +265,13 @@ impl PhysicalExpr for BinaryExpr {
265265
let schema = batch.schema();
266266
let input_schema = schema.as_ref();
267267

268+
if left_data_type.is_nested() {
269+
if right_data_type != left_data_type {
270+
return internal_err!("type mismatch");
271+
}
272+
return apply_cmp_for_nested(self.op, &lhs, &rhs);
273+
}
274+
268275
match self.op {
269276
Operator::Plus => return apply(&lhs, &rhs, add_wrapping),
270277
Operator::Minus => return apply(&lhs, &rhs, sub_wrapping),

datafusion/physical-expr/src/expressions/datum.rs

Lines changed: 0 additions & 58 deletions
This file was deleted.

datafusion/physical-expr/src/expressions/like.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ use std::{any::Any, sync::Arc};
2020

2121
use crate::{physical_expr::down_cast_any_ref, PhysicalExpr};
2222

23-
use crate::expressions::datum::apply_cmp;
2423
use arrow::record_batch::RecordBatch;
2524
use arrow_schema::{DataType, Schema};
2625
use datafusion_common::{internal_err, Result};
2726
use datafusion_expr::ColumnarValue;
27+
use datafusion_physical_expr_common::datum::apply_cmp;
2828

2929
// Like expression
3030
#[derive(Debug, Hash)]

datafusion/physical-expr/src/expressions/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
mod binary;
2222
mod case;
2323
mod column;
24-
mod datum;
2524
mod in_list;
2625
mod is_not_null;
2726
mod is_null;

0 commit comments

Comments
 (0)