diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs index 059b1452ff3d..b97974c85999 100644 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs @@ -19,6 +19,8 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::internal_err; +use datafusion_common::tree_node::Transformed; use datafusion_common::Result; use datafusion_expr::expr::BinaryExpr; use datafusion_expr::logical_plan::Filter; @@ -56,29 +58,34 @@ use datafusion_expr::{Expr, LogicalPlan, Operator}; /// /// ```sql /// where -/// p_partkey = l_partkey -/// and l_shipmode in (‘AIR’, ‘AIR REG’) -/// and l_shipinstruct = ‘DELIVER IN PERSON’ -/// and ( /// ( +/// p_partkey = l_partkey /// and p_brand = ‘[BRAND1]’ /// and p_container in ( ‘SM CASE’, ‘SM BOX’, ‘SM PACK’, ‘SM PKG’) /// and l_quantity >= [QUANTITY1] and l_quantity <= [QUANTITY1] + 10 /// and p_size between 1 and 5 +/// and l_shipmode in (‘AIR’, ‘AIR REG’) +/// and l_shipinstruct = ‘DELIVER IN PERSON’ /// ) /// or /// ( +/// p_partkey = l_partkey /// and p_brand = ‘[BRAND2]’ /// and p_container in (‘MED BAG’, ‘MED BOX’, ‘MED PKG’, ‘MED PACK’) /// and l_quantity >= [QUANTITY2] and l_quantity <= [QUANTITY2] + 10 /// and p_size between 1 and 10 +/// and l_shipmode in (‘AIR’, ‘AIR REG’) +/// and l_shipinstruct = ‘DELIVER IN PERSON’ /// ) /// or /// ( +/// p_partkey = l_partkey /// and p_brand = ‘[BRAND3]’ /// and p_container in ( ‘LG CASE’, ‘LG BOX’, ‘LG PACK’, ‘LG PKG’) /// and l_quantity >= [QUANTITY3] and l_quantity <= [QUANTITY3] + 10 /// and p_size between 1 and 15 +/// and l_shipmode in (‘AIR’, ‘AIR REG’) +/// and l_shipinstruct = ‘DELIVER IN PERSON’ /// ) /// ) /// ``` @@ -128,21 +135,10 @@ impl RewriteDisjunctivePredicate { impl OptimizerRule for RewriteDisjunctivePredicate { fn try_optimize( &self, - plan: &LogicalPlan, + _plan: &LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Filter(filter) => { - let predicate = predicate(&filter.predicate)?; - let rewritten_predicate = rewrite_predicate(predicate); - let rewritten_expr = normalize_predicate(rewritten_predicate); - Ok(Some(LogicalPlan::Filter(Filter::try_new( - rewritten_expr, - filter.input.clone(), - )?))) - } - _ => Ok(None), - } + internal_err!("Should have called RewriteDisjunctivePredicate::rewrite") } fn name(&self) -> &str { @@ -152,6 +148,29 @@ impl OptimizerRule for RewriteDisjunctivePredicate { fn apply_order(&self) -> Option { Some(ApplyOrder::TopDown) } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + match plan { + LogicalPlan::Filter(filter) => { + let predicate = predicate(filter.predicate)?; + let rewritten_predicate = rewrite_predicate(predicate); + let rewritten_expr = normalize_predicate(rewritten_predicate); + Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new( + rewritten_expr, + filter.input, + )?))) + } + _ => Ok(Transformed::no(plan)), + } + } } #[derive(Clone, PartialEq, Debug)] @@ -161,27 +180,23 @@ enum Predicate { Other { expr: Box }, } -fn predicate(expr: &Expr) -> Result { +fn predicate(expr: Expr) -> Result { match expr { Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { Operator::And => { - let args = vec![predicate(left)?, predicate(right)?]; + let args = vec![predicate(*left)?, predicate(*right)?]; Ok(Predicate::And { args }) } Operator::Or => { - let args = vec![predicate(left)?, predicate(right)?]; + let args = vec![predicate(*left)?, predicate(*right)?]; Ok(Predicate::Or { args }) } _ => Ok(Predicate::Other { - expr: Box::new(Expr::BinaryExpr(BinaryExpr::new( - left.clone(), - *op, - right.clone(), - ))), + expr: Box::new(Expr::BinaryExpr(BinaryExpr::new(left, op, right))), }), }, _ => Ok(Predicate::Other { - expr: Box::new(expr.clone()), + expr: Box::new(expr), }), } } @@ -210,8 +225,8 @@ fn rewrite_predicate(predicate: Predicate) -> Predicate { match predicate { Predicate::And { args } => { let mut rewritten_args = Vec::with_capacity(args.len()); - for arg in args.iter() { - rewritten_args.push(rewrite_predicate(arg.clone())); + for arg in args.into_iter() { + rewritten_args.push(rewrite_predicate(arg)); } rewritten_args = flatten_and_predicates(rewritten_args); Predicate::And { @@ -220,15 +235,13 @@ fn rewrite_predicate(predicate: Predicate) -> Predicate { } Predicate::Or { args } => { let mut rewritten_args = vec![]; - for arg in args.iter() { - rewritten_args.push(rewrite_predicate(arg.clone())); + for arg in args.into_iter() { + rewritten_args.push(rewrite_predicate(arg)); } rewritten_args = flatten_or_predicates(rewritten_args); - delete_duplicate_predicates(&rewritten_args) + delete_duplicate_predicates(rewritten_args) } - Predicate::Other { expr } => Predicate::Other { - expr: Box::new(*expr), - }, + Predicate::Other { expr } => Predicate::Other { expr }, } } @@ -239,8 +252,7 @@ fn flatten_and_predicates( for predicate in and_predicates { match predicate { Predicate::And { args } => { - flattened_predicates - .extend_from_slice(flatten_and_predicates(args).as_slice()); + flattened_predicates.append(&mut flatten_and_predicates(args)); } _ => { flattened_predicates.push(predicate); @@ -257,8 +269,7 @@ fn flatten_or_predicates( for predicate in or_predicates { match predicate { Predicate::Or { args } => { - flattened_predicates - .extend_from_slice(flatten_or_predicates(args).as_slice()); + flattened_predicates.append(&mut flatten_or_predicates(args)); } _ => { flattened_predicates.push(predicate); @@ -268,7 +279,7 @@ fn flatten_or_predicates( flattened_predicates } -fn delete_duplicate_predicates(or_predicates: &[Predicate]) -> Predicate { +fn delete_duplicate_predicates(or_predicates: Vec) -> Predicate { let mut shortest_exprs: Vec = vec![]; let mut shortest_exprs_len = 0; // choose the shortest AND predicate @@ -305,23 +316,22 @@ fn delete_duplicate_predicates(or_predicates: &[Predicate]) -> Predicate { } if exist_exprs.is_empty() { return Predicate::Or { - args: or_predicates.to_vec(), + args: or_predicates, }; } // Rebuild the OR predicate. // (A AND B) OR A will be optimized to A. let mut new_or_predicates = vec![]; - for or_predicate in or_predicates.iter() { + for or_predicate in or_predicates.into_iter() { match or_predicate { - Predicate::And { args } => { - let mut new_args = (*args).clone(); - new_args.retain(|expr| !exist_exprs.contains(expr)); - if !new_args.is_empty() { - if new_args.len() == 1 { - new_or_predicates.push(new_args[0].clone()); + Predicate::And { mut args } => { + args.retain(|expr| !exist_exprs.contains(expr)); + if !args.is_empty() { + if args.len() == 1 { + new_or_predicates.push(args.remove(0)); } else { - new_or_predicates.push(Predicate::And { args: new_args }); + new_or_predicates.push(Predicate::And { args }); } } else { new_or_predicates.clear(); @@ -329,7 +339,7 @@ fn delete_duplicate_predicates(or_predicates: &[Predicate]) -> Predicate { } } _ => { - if exist_exprs.contains(or_predicate) { + if exist_exprs.contains(&or_predicate) { new_or_predicates.clear(); break; } @@ -338,7 +348,7 @@ fn delete_duplicate_predicates(or_predicates: &[Predicate]) -> Predicate { } if !new_or_predicates.is_empty() { if new_or_predicates.len() == 1 { - exist_exprs.push(new_or_predicates[0].clone()); + exist_exprs.push(new_or_predicates.remove(0)); } else { exist_exprs.push(Predicate::Or { args: flatten_or_predicates(new_or_predicates), @@ -347,7 +357,7 @@ fn delete_duplicate_predicates(or_predicates: &[Predicate]) -> Predicate { } if exist_exprs.len() == 1 { - exist_exprs[0].clone() + exist_exprs.remove(0) } else { Predicate::And { args: flatten_and_predicates(exist_exprs), @@ -373,7 +383,7 @@ mod tests { and(equi_expr.clone(), gt_expr.clone()), and(equi_expr.clone(), lt_expr.clone()), ); - let predicate = predicate(&expr)?; + let predicate = predicate(expr)?; assert_eq!( predicate, Predicate::Or {