diff --git a/Cargo.lock b/Cargo.lock
index 194483b7ab3a..126eeb2123c6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -745,9 +745,9 @@ dependencies = [
 
 [[package]]
 name = "aws-sdk-sts"
-version = "1.74.0"
+version = "1.75.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "19d440e1d368759bd10df0dbdddbfff6473d7cd73e9d9ef2363dc9995ac2d711"
+checksum = "e3258fa707f2f585ee3049d9550954b959002abd59176975150a01d5cf38ae3f"
 dependencies = [
  "aws-credential-types",
  "aws-runtime",
@@ -3964,9 +3964,9 @@ checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
 
 [[package]]
 name = "libmimalloc-sys"
-version = "0.1.43"
+version = "0.1.42"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bf88cd67e9de251c1781dbe2f641a1a3ad66eaae831b8a2c38fbdc5ddae16d4d"
+checksum = "ec9d6fac27761dabcd4ee73571cdb06b7022dc99089acbe5435691edffaac0f4"
 dependencies = [
  "cc",
  "libc",
@@ -4097,9 +4097,9 @@ dependencies = [
 
 [[package]]
 name = "mimalloc"
-version = "0.1.47"
+version = "0.1.46"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b1791cbe101e95af5764f06f20f6760521f7158f69dbf9d6baf941ee1bf6bc40"
+checksum = "995942f432bbb4822a7e9c3faa87a695185b0d09273ba85f097b54f4e458f2af"
 dependencies = [
  "libmimalloc-sys",
 ]
@@ -4325,9 +4325,9 @@ dependencies = [
 
 [[package]]
 name = "object_store"
-version = "0.12.2"
+version = "0.12.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7781f96d79ed0f961a7021424ab01840efbda64ae7a505aaea195efc91eaaec4"
+checksum = "d94ac16b433c0ccf75326388c893d2835ab7457ea35ab8ba5d745c053ef5fa16"
 dependencies = [
  "async-trait",
  "base64 0.22.1",
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 97f83305dcbe..c50268d99676 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -2069,6 +2069,15 @@ impl Expr {
             _ => None,
         }
     }
+
+    /// Returns a reference to the wrapped [`ScalarValue`] when this expression is an
+    /// `Expr::Literal`, and `None` for every other kind of expression.
+    pub fn as_literal(&self) -> Option<&ScalarValue> {
+        match self {
+            Expr::Literal(value, _) => Some(value),
+            _ => None,
+        }
+    }
 }
 
 impl Normalizeable for Expr {
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 7c4a02678899..bcb867f6e7fa 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -40,6 +40,7 @@ use datafusion_expr::{
 };
 
 use crate::optimizer::ApplyOrder;
+use crate::simplify_expressions::simplify_predicates;
 use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
 use crate::{OptimizerConfig, OptimizerRule};
 
@@ -779,6 +780,18 @@ impl OptimizerRule for PushDownFilter {
             return Ok(Transformed::no(plan));
         };
 
+        let conjuncts = split_conjunction_owned(filter.predicate.clone());
+        let original_len = conjuncts.len();
+        let simplified = simplify_predicates(conjuncts)?;
+        if simplified.len() != original_len {
+            match conjunction(simplified) {
+                Some(new_predicate) => filter.predicate = new_predicate,
+                // No conjuncts are left, so the filter itself is redundant:
+                // return the child plan without the filter
+                None => return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input))),
+            }
+        }
+
         match Arc::unwrap_or_clone(filter.input) {
             LogicalPlan::Filter(child_filter) => {
                 let parents_predicates = split_conjunction_owned(filter.predicate);
diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs
index 5fbee02e3909..7ae38eec9a3a 100644
--- a/datafusion/optimizer/src/simplify_expressions/mod.rs
+++ b/datafusion/optimizer/src/simplify_expressions/mod.rs
@@ -23,6 +23,7 @@ mod guarantees;
 mod inlist_simplifier;
 mod regex;
 pub mod simplify_exprs;
+mod simplify_predicates;
 mod unwrap_cast;
 mod utils;
 
@@ -31,6 +32,7 @@ pub use datafusion_expr::simplify::{SimplifyContext, SimplifyInfo};
 
 pub use expr_simplifier::*;
 pub use simplify_exprs::*;
+pub use
simplify_predicates::simplify_predicates;
 
 // Export for test in datafusion/core/tests/optimizer_integration.rs
 pub use guarantees::GuaranteeRewriter;
diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs b/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs
new file mode 100644
index 000000000000..32b2315e15d5
--- /dev/null
+++ b/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs
@@ -0,0 +1,378 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Simplifies predicates by reducing redundant or overlapping conditions.
+//!
+//! This module provides functionality to optimize logical predicates used in query planning
+//! by eliminating redundant conditions, thus reducing the number of predicates to evaluate.
+//! Unlike the simplifier in `simplify_expressions/simplify_exprs.rs`, which focuses on
+//! general expression simplification (e.g., constant folding and algebraic simplifications),
+//! this module specifically targets predicate optimization by handling containment relationships.
+//! For example, it can simplify `x > 5 AND x > 6` to just `x > 6`, as the latter condition
+//! encompasses the former, resulting in fewer checks during query execution.
+
+use datafusion_common::{Column, Result, ScalarValue};
+use datafusion_expr::{BinaryExpr, Cast, Expr, Operator};
+use std::cmp::Ordering;
+use std::collections::BTreeMap;
+
+/// Simplifies a list of predicates by removing redundancies.
+///
+/// The predicates are interpreted as the conjuncts of a filter, i.e. they are implicitly
+/// `AND`-ed together. Comparisons (`>`, `>=`, `<`, `<=`, `=`) between a column (optionally
+/// wrapped in casts) and a literal are grouped by the column they reference and analyzed
+/// to remove redundant conditions. For instance, `x > 5 AND x > 6` is simplified to
+/// `x > 6`, and contradictory equalities such as `x = 5 AND x = 6` are replaced by a
+/// literal `false`. Other predicates that do not fit this pattern are retained as-is.
+///
+/// The result lists the retained-as-is predicates first, followed by the simplified
+/// comparisons ordered by column.
+///
+/// # Arguments
+/// * `predicates` - A vector of `Expr` representing the predicates to simplify.
+///
+/// # Returns
+/// A `Result` containing a vector of simplified `Expr` predicates.
+pub fn simplify_predicates(predicates: Vec<Expr>) -> Result<Vec<Expr>> {
+    // Early return for simple cases
+    if predicates.len() <= 1 {
+        return Ok(predicates);
+    }
+
+    // Group predicates by their column reference
+    let mut column_predicates: BTreeMap<Column, Vec<Expr>> = BTreeMap::new();
+    let mut other_predicates = Vec::new();
+
+    for pred in predicates {
+        let column = as_comparison(&pred)
+            .and_then(|comparison| extract_column_from_expr(comparison.operand));
+        match column {
+            Some(col) => column_predicates.entry(col).or_default().push(pred),
+            None => other_predicates.push(pred),
+        }
+    }
+
+    // Process each column's predicates to remove redundancies
+    let mut result = other_predicates;
+    for preds in column_predicates.into_values() {
+        result.extend(simplify_column_predicates(preds)?);
+    }
+
+    Ok(result)
+}
+
+/// Simplifies predicates related to a single column.
+///
+/// All `predicates` reference the same column, but not necessarily through the same
+/// expression: `x > 5` and `CAST(x AS Int64) > 6` compare different values, so neither
+/// makes the other redundant. The predicates are therefore bucketed by the exact
+/// expression that is compared against the literal, and only predicates within the
+/// same bucket are simplified against each other.
+///
+/// # Arguments
+/// * `predicates` - A vector of `Expr` representing predicates for a single column.
+///
+/// # Returns
+/// A `Result` containing a vector of simplified `Expr` predicates for the column.
+fn simplify_column_predicates(predicates: Vec<Expr>) -> Result<Vec<Expr>> {
+    if predicates.len() <= 1 {
+        return Ok(predicates);
+    }
+
+    // A column rarely has more than a handful of predicates, so a linear scan that
+    // preserves the first-seen order of the operands is sufficient
+    let mut groups: Vec<Vec<Expr>> = Vec::new();
+    for pred in predicates {
+        let slot = groups
+            .iter()
+            .position(|group| has_same_operand(&group[0], &pred));
+        match slot {
+            Some(idx) => groups[idx].push(pred),
+            None => groups.push(vec![pred]),
+        }
+    }
+
+    let mut result = Vec::new();
+    for group in groups {
+        result.extend(simplify_operand_predicates(group)?);
+    }
+
+    Ok(result)
+}
+
+/// Simplifies comparisons that all share the same non-literal operand.
+///
+/// The predicates are split into greater-than (>, >=), less-than (<, <=), and equality (=)
+/// categories, and the most restrictive condition in each category is selected to reduce
+/// redundancy. For example, among `x > 5` and `x > 6`, only `x > 6` is retained as it is
+/// more restrictive. Whenever the literals of a category cannot be ordered reliably, all
+/// predicates of that category are kept unchanged.
+fn simplify_operand_predicates(predicates: Vec<Expr>) -> Result<Vec<Expr>> {
+    if predicates.len() <= 1 {
+        return Ok(predicates);
+    }
+
+    // Group by operator type, but combining similar operators
+    let mut greater_predicates = Vec::new(); // Combines > and >=
+    let mut less_predicates = Vec::new(); // Combines < and <=
+    let mut eq_predicates = Vec::new();
+    let mut result = Vec::new();
+
+    for pred in predicates {
+        match as_comparison(&pred).map(|comparison| comparison.op) {
+            Some(Operator::Gt | Operator::GtEq) => greater_predicates.push(pred),
+            Some(Operator::Lt | Operator::LtEq) => less_predicates.push(pred),
+            Some(Operator::Eq) => eq_predicates.push(pred),
+            // Not a supported comparison: keep it untouched rather than panicking
+            _ => result.push(pred),
+        }
+    }
+
+    if !eq_predicates.is_empty() {
+        match equality_literals_agree(&eq_predicates) {
+            // Every equality uses the same value, so a single predicate is enough
+            Some(true) => result.push(eq_predicates.swap_remove(0)),
+            // The operand cannot equal two different values: add a false predicate
+            Some(false) => {
+                result.push(Expr::Literal(ScalarValue::Boolean(Some(false)), None))
+            }
+            // Undecidable (e.g. literals of different types): keep all of them
+            None => result.extend(eq_predicates),
+        }
+    }
+
+    // Handle all greater-than-style predicates (keep the most restrictive - highest value)
+    if !greater_predicates.is_empty() {
+        match find_most_restrictive_predicate(&greater_predicates, true)? {
+            Some(most_restrictive) => result.push(most_restrictive),
+            None => result.extend(greater_predicates),
+        }
+    }
+
+    // Handle all less-than-style predicates (keep the most restrictive - lowest value)
+    if !less_predicates.is_empty() {
+        match find_most_restrictive_predicate(&less_predicates, false)? {
+            Some(most_restrictive) => result.push(most_restrictive),
+            None => result.extend(less_predicates),
+        }
+    }
+
+    Ok(result)
+}
+
+/// Finds the most restrictive predicate from a list based on literal values.
+///
+/// This function iterates through a list of predicates to identify the most restrictive one
+/// by comparing their literal values. For greater-than predicates, the highest value is most
+/// restrictive, while for less-than predicates, the lowest value is most restrictive. When
+/// two predicates use the same value, the strict comparison (`>` or `<`) is more restrictive
+/// than the inclusive one (`>=` or `<=`).
+///
+/// # Arguments
+/// * `predicates` - A slice of `Expr` representing predicates to compare.
+/// * `find_greater` - A boolean indicating whether to find the highest value (true for >, >=)
+///   or the lowest value (false for <, <=).
+///
+/// # Returns
+/// A `Result` containing an `Option` with the most restrictive predicate. The `Option` is
+/// `None` when `predicates` is empty or when no predicate can safely replace the others
+/// (a literal is NULL, two literals cannot be ordered, or a predicate is not a supported
+/// comparison); the caller must then keep every predicate.
+fn find_most_restrictive_predicate(
+    predicates: &[Expr],
+    find_greater: bool,
+) -> Result<Option<Expr>> {
+    // (index, literal, is_strict) of the most restrictive predicate seen so far
+    let mut best: Option<(usize, &ScalarValue, bool)> = None;
+
+    for (idx, pred) in predicates.iter().enumerate() {
+        let Some(comparison) = as_comparison(pred) else {
+            return Ok(None);
+        };
+        // A comparison with NULL is never true, so it must not be dropped in favor of
+        // another bound
+        if comparison.value.is_null() {
+            return Ok(None);
+        }
+        let is_strict = matches!(comparison.op, Operator::Gt | Operator::Lt);
+
+        let is_better = match best {
+            None => true,
+            Some((_, best_value, best_is_strict)) => {
+                // Literals of different types cannot be ordered
+                let Some(ordering) = comparison.value.partial_cmp(best_value) else {
+                    return Ok(None);
+                };
+                // Normalize so that `Greater` always means "more restrictive"
+                let ordering = if find_greater {
+                    ordering
+                } else {
+                    ordering.reverse()
+                };
+                match ordering {
+                    Ordering::Greater => true,
+                    Ordering::Equal => is_strict && !best_is_strict,
+                    Ordering::Less => false,
+                }
+            }
+        };
+
+        if is_better {
+            best = Some((idx, comparison.value, is_strict));
+        }
+    }
+
+    Ok(best.map(|(idx, _, _)| predicates[idx].clone()))
+}
+
+/// A comparison between a non-literal expression and a literal, normalized so that the
+/// literal is on the right-hand side (`5 < x` is represented as `x > 5`).
+struct Comparison<'a> {
+    /// The non-literal side of the comparison
+    operand: &'a Expr,
+    /// The operator, as if `operand` were on the left-hand side
+    op: Operator,
+    /// The literal that `operand` is compared against
+    value: &'a ScalarValue,
+}
+
+/// Interprets `pred` as a comparison (`>`, `>=`, `<`, `<=`, `=`) between an expression
+/// and a literal, regardless of which side the literal is on.
+///
+/// Returns `None` for any other kind of expression.
+fn as_comparison(pred: &Expr) -> Option<Comparison<'_>> {
+    let Expr::BinaryExpr(BinaryExpr { left, op, right }) = pred else {
+        return None;
+    };
+    if !matches!(
+        op,
+        Operator::Gt | Operator::GtEq | Operator::Lt | Operator::LtEq | Operator::Eq
+    ) {
+        return None;
+    }
+
+    if let Some(value) = right.as_literal() {
+        Some(Comparison {
+            operand: left.as_ref(),
+            op: *op,
+            value,
+        })
+    } else if let Some(value) = left.as_literal() {
+        // `literal <op> expr` is equivalent to `expr <swapped op> literal`
+        Some(Comparison {
+            operand: right.as_ref(),
+            op: op.swap()?,
+            value,
+        })
+    } else {
+        None
+    }
+}
+
+/// Returns `true` when both predicates are comparisons of the identical non-literal
+/// expression against a literal.
+fn has_same_operand(lhs: &Expr, rhs: &Expr) -> bool {
+    match (as_comparison(lhs), as_comparison(rhs)) {
+        (Some(lhs), Some(rhs)) => lhs.operand == rhs.operand,
+        _ => false,
+    }
+}
+
+/// Checks whether all equality predicates compare their operand against the same value.
+///
+/// Returns `Some(true)` when all literals are equal, `Some(false)` when two literals are
+/// known to differ, and `None` when this cannot be decided (a literal cannot be ordered
+/// against the first one, or a predicate is not a supported comparison).
+fn equality_literals_agree(predicates: &[Expr]) -> Option<bool> {
+    let mut values = predicates
+        .iter()
+        .map(|pred| as_comparison(pred).map(|comparison| comparison.value));
+    let first = values.next()??;
+    for value in values {
+        if first.partial_cmp(value?)? != Ordering::Equal {
+            return Some(false);
+        }
+    }
+    Some(true)
+}
+
+/// Extracts a column reference from an expression, if present.
+///
+/// This function checks if the given expression is a column reference or contains one,
+/// such as within a cast operation. It returns the `Column` if found.
+///
+/// # Arguments
+/// * `expr` - A reference to an `Expr` to inspect for a column reference.
+///
+/// # Returns
+/// An `Option` containing the column reference if found, otherwise `None`.
+fn extract_column_from_expr(expr: &Expr) -> Option<Column> {
+    match expr {
+        Expr::Column(col) => Some(col.clone()),
+        // Handle cases where the column might be wrapped in a cast
+        Expr::Cast(Cast { expr, .. }) => extract_column_from_expr(expr),
+        _ => None,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::datatypes::DataType;
+    use datafusion_expr::{cast, col, lit};
+
+    #[test]
+    fn strict_bound_wins_over_inclusive_bound_with_same_value() -> Result<()> {
+        let lower = vec![col("x").gt_eq(lit(5)), col("x").gt(lit(5))];
+        assert_eq!(simplify_predicates(lower)?, vec![col("x").gt(lit(5))]);
+
+        let upper = vec![col("x").lt_eq(lit(5)), col("x").lt(lit(5))];
+        assert_eq!(simplify_predicates(upper)?, vec![col("x").lt(lit(5))]);
+        Ok(())
+    }
+
+    #[test]
+    fn equalities_are_compared_by_value() -> Result<()> {
+        // `x = 5 AND 5 = x` is redundant, not contradictory
+        let preds = vec![col("x").eq(lit(5)), lit(5).eq(col("x"))];
+        assert_eq!(simplify_predicates(preds)?, vec![col("x").eq(lit(5))]);
+        Ok(())
+    }
+
+    #[test]
+    fn predicates_on_different_operands_are_kept() -> Result<()> {
+        let preds = vec![
+            col("x").gt(lit(5)),
+            cast(col("x"), DataType::Int64).gt(lit(6_i64)),
+        ];
+        assert_eq!(simplify_predicates(preds.clone())?, preds);
+        Ok(())
+    }
+
+    #[test]
+    fn incomparable_or_null_literals_are_kept() -> Result<()> {
+        let mixed_types = vec![col("x").gt(lit(5)), col("x").gt(lit(6_i64))];
+        assert_eq!(simplify_predicates(mixed_types.clone())?, mixed_types);
+
+        let with_null = vec![
+            col("x").gt(lit(5)),
+            col("x").gt(lit(ScalarValue::Int32(None))),
+        ];
+        assert_eq!(simplify_predicates(with_null.clone())?, with_null);
+        Ok(())
+    }
+}
diff --git a/datafusion/sqllogictest/test_files/simplify_predicates.slt b/datafusion/sqllogictest/test_files/simplify_predicates.slt
new file mode 100644
index 000000000000..0dd551d96d0c
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/simplify_predicates.slt
@@ -0,0 +1,234 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Test cases for predicate simplification feature
+# Basic redundant comparison simplification
+
+statement ok
+set datafusion.explain.logical_plan_only=true;
+
+statement ok
+CREATE TABLE test_data (
+    int_col INT,
+    float_col FLOAT,
+    str_col VARCHAR,
+    date_col DATE,
+    bool_col BOOLEAN
+);
+
+# x > 5 AND x > 6 should simplify to x > 6
+query TT
+EXPLAIN SELECT * FROM test_data WHERE int_col > 5 AND int_col > 6;
+----
+logical_plan
+01)Filter: test_data.int_col > Int32(6)
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+# x > 5 AND x >= 6 should simplify to x >= 6
+query TT
+EXPLAIN SELECT * FROM test_data WHERE int_col > 5 AND int_col >= 6;
+----
+logical_plan
+01)Filter: test_data.int_col >= Int32(6)
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+# x < 10 AND x <= 8 should simplify to x <= 8
+query TT
+EXPLAIN SELECT * FROM test_data WHERE int_col < 10 AND int_col <= 8;
+----
+logical_plan
+01)Filter: test_data.int_col <= Int32(8)
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+# x > 5 AND x > 6 AND x > 7 should simplify to x > 7
+query TT
+EXPLAIN SELECT * FROM test_data WHERE int_col > 5 AND int_col > 6 AND int_col > 7;
+----
+logical_plan
+01)Filter: test_data.int_col > Int32(7)
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+# x > 5 AND y < 10 AND x > 6 AND y < 8 should simplify to x > 6 AND y < 8
+query TT
+EXPLAIN SELECT * FROM test_data WHERE int_col > 5 AND float_col < 10 AND int_col > 6 AND float_col < 8;
+----
+logical_plan
+01)Filter: test_data.float_col < Float32(8) AND test_data.int_col > Int32(6)
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+# x = 7 AND x = 7 should simplify to x = 7
+query TT
+EXPLAIN SELECT * FROM test_data WHERE int_col = 7 AND int_col = 7;
+----
+logical_plan
+01)Filter: test_data.int_col = Int32(7)
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+# x = 7 AND x = 6 should simplify to false
+query TT
+EXPLAIN SELECT * FROM test_data WHERE int_col = 7 AND int_col = 6;
+----
+logical_plan EmptyRelation
+
+# TODO: x = 7 AND x < 2 should simplify to false
+query TT
+EXPLAIN SELECT * FROM test_data WHERE int_col = 7 AND int_col < 2;
+----
+logical_plan
+01)Filter: test_data.int_col = Int32(7) AND test_data.int_col < Int32(2)
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+
+# TODO: x = 7 AND x > 5 should simplify to x = 7
+query TT
+EXPLAIN SELECT * FROM test_data WHERE int_col = 7 AND int_col > 5;
+----
+logical_plan
+01)Filter: test_data.int_col = Int32(7) AND test_data.int_col > Int32(5)
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+# str_col > 'apple' AND str_col > 'banana' should simplify to str_col > 'banana'
+query TT
+EXPLAIN SELECT * FROM test_data WHERE str_col > 'apple' AND str_col > 'banana';
+----
+logical_plan
+01)Filter: test_data.str_col > Utf8View("banana")
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+# date_col > '2023-01-01' AND date_col > '2023-02-01' should simplify to date_col > '2023-02-01'
+query TT
+EXPLAIN SELECT * FROM test_data WHERE date_col > '2023-01-01' AND date_col > '2023-02-01';
+----
+logical_plan
+01)Filter: test_data.date_col > Date32("2023-02-01")
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+# TODO: bool_col = true AND bool_col = false is contradictory but is not folded to false yet
+query TT
+EXPLAIN SELECT * FROM test_data WHERE bool_col = true AND bool_col = false;
+----
+logical_plan
+01)Filter: test_data.bool_col AND NOT test_data.bool_col
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+# This shouldn't be simplified since they're different relationships
+query TT
+EXPLAIN SELECT * FROM test_data WHERE int_col > float_col AND int_col > 5;
+----
+logical_plan
+01)Filter: CAST(test_data.int_col AS Float32) > test_data.float_col AND test_data.int_col > Int32(5)
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+# Should simplify the int_col predicates but preserve the others
+query TT
+EXPLAIN SELECT * FROM test_data
+WHERE int_col > 5
+  AND int_col > 10
+  AND str_col LIKE 'A%'
+  AND float_col BETWEEN 1 AND 100;
+----
+logical_plan
+01)Filter: test_data.str_col LIKE Utf8View("A%") AND test_data.float_col >= Float32(1) AND test_data.float_col <= Float32(100) AND test_data.int_col > Int32(10)
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+statement ok
+CREATE TABLE test_data2 (
+    id INT,
+    value INT
+);
+
+query TT
+EXPLAIN SELECT t1.int_col, t2.value
+FROM test_data t1
+JOIN test_data2 t2 ON t1.int_col = t2.id
+WHERE t1.int_col > 5
+  AND t1.int_col > 10
+  AND t2.value < 100
+  AND t2.value < 50;
+----
+logical_plan
+01)Projection: t1.int_col, t2.value
+02)--Inner Join: t1.int_col = t2.id
+03)----SubqueryAlias: t1
+04)------Filter: test_data.int_col > Int32(10)
+05)--------TableScan: test_data projection=[int_col]
+06)----SubqueryAlias: t2
+07)------Filter: test_data2.value < Int32(50) AND test_data2.id > Int32(10)
+08)--------TableScan: test_data2 projection=[id, value]
+
+# Handling negated predicates
+# NOT (x < 10) AND NOT (x < 5) should simplify to NOT (x < 10), which is planned as x >= 10
+query TT
+EXPLAIN SELECT * FROM test_data WHERE NOT (int_col < 10) AND NOT (int_col < 5);
+----
+logical_plan
+01)Filter: test_data.int_col >= Int32(10)
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+# x > 5 AND x < 10 should be preserved (can't be simplified)
+query TT
+EXPLAIN SELECT * FROM test_data WHERE int_col > 5 AND int_col < 10;
+----
+logical_plan
+01)Filter: test_data.int_col > Int32(5) AND test_data.int_col < Int32(10)
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+# 5 < x AND 3 < x should simplify to 5 < x
+query TT
+EXPLAIN SELECT * FROM test_data WHERE 5 < int_col AND 3 < int_col;
+----
+logical_plan
+01)Filter: test_data.int_col > Int32(5)
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+# CAST(x AS FLOAT) > 5.0 AND CAST(x AS FLOAT) > 6.0 should simplify
+query TT
+EXPLAIN SELECT * FROM test_data WHERE CAST(int_col AS FLOAT) > 5.0 AND CAST(int_col AS FLOAT) > 6.0;
+----
+logical_plan
+01)Filter: CAST(CAST(test_data.int_col AS Float32) AS Float64) > Float64(6)
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+# x = 5 AND x = 6 (logically impossible)
+query TT
+EXPLAIN SELECT * FROM test_data WHERE int_col = 5 AND int_col = 6;
+----
+logical_plan EmptyRelation
+
+# (x > 5 OR y < 10) AND (x > 6 OR y < 8)
+# Each conjunct is an OR expression, so the filter is expected to stay unchanged
+query TT
+EXPLAIN SELECT * FROM test_data
+WHERE (int_col > 5 OR float_col < 10)
+  AND (int_col > 6 OR float_col < 8);
+----
+logical_plan
+01)Filter: (test_data.int_col > Int32(5) OR test_data.float_col < Float32(10)) AND (test_data.int_col > Int32(6) OR test_data.float_col < Float32(8))
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+# Redundant comparisons nested under OR are expected to stay unchanged (not simplified yet)
+query TT
+EXPLAIN SELECT * FROM test_data
+WHERE (int_col > 5 AND int_col > 6)
+  OR (float_col < 10 AND float_col < 8);
+----
+logical_plan
+01)Filter: test_data.int_col > Int32(5) AND test_data.int_col > Int32(6) OR test_data.float_col < Float32(10) AND test_data.float_col < Float32(8)
+02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col]
+
+statement ok
+set datafusion.explain.logical_plan_only=false;