|
20 | 20 | //! [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees
|
21 | 21 | use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result};
|
22 | 22 | use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr};
|
23 |
| -use std::collections::HashMap; |
| 23 | +use std::{borrow::Cow, collections::HashMap}; |
24 | 24 |
|
25 | 25 | use datafusion_physical_expr::intervals::{Interval, IntervalBound, NullableInterval};
|
26 | 26 |
|
@@ -103,37 +103,44 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
|
103 | 103 | }
|
104 | 104 |
|
105 | 105 | Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
|
106 |
| - // We only support comparisons for now |
107 |
| - if !op.is_comparison_operator() { |
108 |
| - return Ok(expr); |
109 |
| - }; |
110 |
| - |
111 |
| - // Check if this is a comparison between a column and literal |
112 |
| - let (col, op, value) = match (left.as_ref(), right.as_ref()) { |
113 |
| - (Expr::Column(_), Expr::Literal(value)) => (left, *op, value), |
114 |
| - (Expr::Literal(value), Expr::Column(_)) => { |
115 |
| - // If we can swap the op, we can simplify the expression |
116 |
| - if let Some(op) = op.swap() { |
117 |
| - (right, op, value) |
| 106 | + // The left or right side of expression might either have a guarantee |
| 107 | + // or be a literal. Either way, we can resolve them to a NullableInterval. |
| 108 | + let left_interval = self |
| 109 | + .guarantees |
| 110 | + .get(left.as_ref()) |
| 111 | + .map(|interval| Cow::Borrowed(*interval)) |
| 112 | + .or_else(|| { |
| 113 | + if let Expr::Literal(value) = left.as_ref() { |
| 114 | + Some(Cow::Owned(value.clone().into())) |
118 | 115 | } else {
|
119 |
| - return Ok(expr); |
| 116 | + None |
| 117 | + } |
| 118 | + }); |
| 119 | + let right_interval = self |
| 120 | + .guarantees |
| 121 | + .get(right.as_ref()) |
| 122 | + .map(|interval| Cow::Borrowed(*interval)) |
| 123 | + .or_else(|| { |
| 124 | + if let Expr::Literal(value) = right.as_ref() { |
| 125 | + Some(Cow::Owned(value.clone().into())) |
| 126 | + } else { |
| 127 | + None |
| 128 | + } |
| 129 | + }); |
| 130 | + |
| 131 | + match (left_interval, right_interval) { |
| 132 | + (Some(left_interval), Some(right_interval)) => { |
| 133 | + let result = |
| 134 | + left_interval.apply_operator(op, right_interval.as_ref())?; |
| 135 | + if result.is_certainly_true() { |
| 136 | + Ok(lit(true)) |
| 137 | + } else if result.is_certainly_false() { |
| 138 | + Ok(lit(false)) |
| 139 | + } else { |
| 140 | + Ok(expr) |
120 | 141 | }
|
121 | 142 | }
|
122 |
| - _ => return Ok(expr), |
123 |
| - }; |
124 |
| - |
125 |
| - if let Some(col_interval) = self.guarantees.get(col.as_ref()) { |
126 |
| - let result = |
127 |
| - col_interval.apply_operator(&op, &value.clone().into())?; |
128 |
| - if result.is_certainly_true() { |
129 |
| - Ok(lit(true)) |
130 |
| - } else if result.is_certainly_false() { |
131 |
| - Ok(lit(false)) |
132 |
| - } else { |
133 |
| - Ok(expr) |
134 |
| - } |
135 |
| - } else { |
136 |
| - Ok(expr) |
| 143 | + _ => Ok(expr), |
137 | 144 | }
|
138 | 145 | }
|
139 | 146 |
|
@@ -262,13 +269,21 @@ mod tests {
|
262 | 269 | values: Interval::make(Some(1_i32), Some(3_i32), (true, false)),
|
263 | 270 | },
|
264 | 271 | ),
|
| 272 | + // s.y ∈ (1, 3] (not null) |
| 273 | + ( |
| 274 | + col("s").field("y"), |
| 275 | + NullableInterval::NotNull { |
| 276 | + values: Interval::make(Some(1_i32), Some(3_i32), (true, false)), |
| 277 | + }, |
| 278 | + ), |
265 | 279 | ];
|
266 | 280 |
|
267 | 281 | let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
|
268 | 282 |
|
269 | 283 | // (original_expr, expected_simplification)
|
270 | 284 | let simplified_cases = &[
|
271 | 285 | (col("x").lt_eq(lit(1)), false),
|
| 286 | + (col("s").field("y").lt_eq(lit(1)), false), |
272 | 287 | (col("x").lt_eq(lit(3)), true),
|
273 | 288 | (col("x").gt(lit(3)), false),
|
274 | 289 | (col("x").gt(lit(1)), true),
|
|
0 commit comments