diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index ddf075c2c27b..4872e5acda5e 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2407,6 +2407,16 @@ pub enum Distinct { On(DistinctOn), } +impl Distinct { + /// return a reference to the nodes input + pub fn input(&self) -> &Arc { + match self { + Distinct::All(input) => input, + Distinct::On(DistinctOn { input, .. }) => input, + } + } +} + /// Removes duplicate rows from the input #[derive(Clone, PartialEq, Eq, Hash)] pub struct DistinctOn { diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 57b38bd0d0fd..b684b5490342 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -14,6 +14,7 @@ //! [`PushDownFilter`] applies filters as early as possible +use indexmap::IndexSet; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -23,10 +24,9 @@ use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::{ - internal_err, plan_datafusion_err, qualified_name, Column, DFSchema, DFSchemaRef, + internal_err, plan_err, qualified_name, Column, DFSchema, DFSchemaRef, JoinConstraint, Result, }; -use datafusion_expr::expr::Alias; use datafusion_expr::expr_rewriter::replace_col; use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::{ @@ -131,7 +131,8 @@ use crate::{OptimizerConfig, OptimizerRule}; #[derive(Default)] pub struct PushDownFilter {} -/// For a given JOIN logical plan, determine whether each side of the join is preserved. +/// For a given JOIN type, determine whether each side of the join is preserved. +/// /// We say a join side is preserved if the join returns all or a subset of the rows from /// the relevant side, such that each row of the output table directly maps to a row of /// the preserved input table. If a table is not preserved, it can provide extra null rows. @@ -150,44 +151,33 @@ pub struct PushDownFilter {} /// non-preserved side it can be more tricky. /// /// Returns a tuple of booleans - (left_preserved, right_preserved). -fn lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> { - match plan { - LogicalPlan::Join(Join { join_type, .. }) => match join_type { - JoinType::Inner => Ok((true, true)), - JoinType::Left => Ok((true, false)), - JoinType::Right => Ok((false, true)), - JoinType::Full => Ok((false, false)), - // No columns from the right side of the join can be referenced in output - // predicates for semi/anti joins, so whether we specify t/f doesn't matter. - JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)), - // No columns from the left side of the join can be referenced in output - // predicates for semi/anti joins, so whether we specify t/f doesn't matter. - JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)), - }, - LogicalPlan::CrossJoin(_) => Ok((true, true)), - _ => internal_err!("lr_is_preserved only valid for JOIN nodes"), +fn lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> { + match join_type { + JoinType::Inner => Ok((true, true)), + JoinType::Left => Ok((true, false)), + JoinType::Right => Ok((false, true)), + JoinType::Full => Ok((false, false)), + // No columns from the right side of the join can be referenced in output + // predicates for semi/anti joins, so whether we specify t/f doesn't matter. + JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)), + // No columns from the left side of the join can be referenced in output + // predicates for semi/anti joins, so whether we specify t/f doesn't matter. + JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)), } } /// For a given JOIN logical plan, determine whether each side of the join is preserved /// in terms on join filtering. -/// /// Predicates from join filter can only be pushed to preserved join side. -fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> { - match plan { - LogicalPlan::Join(Join { join_type, .. }) => match join_type { - JoinType::Inner => Ok((true, true)), - JoinType::Left => Ok((false, true)), - JoinType::Right => Ok((true, false)), - JoinType::Full => Ok((false, false)), - JoinType::LeftSemi | JoinType::RightSemi => Ok((true, true)), - JoinType::LeftAnti => Ok((false, true)), - JoinType::RightAnti => Ok((true, false)), - }, - LogicalPlan::CrossJoin(_) => { - internal_err!("on_lr_is_preserved cannot be applied to CROSSJOIN nodes") - } - _ => internal_err!("on_lr_is_preserved only valid for JOIN nodes"), +fn on_lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> { + match join_type { + JoinType::Inner => Ok((true, true)), + JoinType::Left => Ok((false, true)), + JoinType::Right => Ok((true, false)), + JoinType::Full => Ok((false, false)), + JoinType::LeftSemi | JoinType::RightSemi => Ok((true, true)), + JoinType::LeftAnti => Ok((false, true)), + JoinType::RightAnti => Ok((true, false)), } } @@ -400,23 +390,20 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet) -> Option, - infer_predicates: Vec, - join_plan: &LogicalPlan, - left: &LogicalPlan, - right: &LogicalPlan, + inferred_join_predicates: Vec, + mut join: Join, on_filter: Vec, - is_inner_join: bool, ) -> Result> { - let on_filter_empty = on_filter.is_empty(); + let is_inner_join = join.join_type == JoinType::Inner; // Get pushable predicates from current optimizer state - let (left_preserved, right_preserved) = lr_is_preserved(join_plan)?; + let (left_preserved, right_preserved) = lr_is_preserved(join.join_type)?; // The predicates can be divided to three categories: // 1) can push through join to its children(left or right) // 2) can be converted to join conditions if the join type is Inner // 3) should be kept as filter conditions - let left_schema = left.schema(); - let right_schema = right.schema(); + let left_schema = join.left.schema(); + let right_schema = join.right.schema(); let mut left_push = vec![]; let mut right_push = vec![]; let mut keep_predicates = vec![]; @@ -438,7 +425,7 @@ fn push_down_all_join( } // For infer predicates, if they can not push through join, just drop them - for predicate in infer_predicates { + for predicate in inferred_join_predicates { if left_preserved && can_pushdown_join_predicate(&predicate, left_schema)? { left_push.push(predicate); } else if right_preserved @@ -449,7 +436,7 @@ fn push_down_all_join( } if !on_filter.is_empty() { - let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join_plan)?; + let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type)?; for on in on_filter { if on_left_preserved && can_pushdown_join_predicate(&on, left_schema)? { left_push.push(on) @@ -474,46 +461,29 @@ fn push_down_all_join( right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema)); } - let left = match conjunction(left_push) { - Some(predicate) => { - LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left.clone()))?) - } - None => left.clone(), - }; - let right = match conjunction(right_push) { - Some(predicate) => { - LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(right.clone()))?) - } - None => right.clone(), - }; - // Create a new Join with the new `left` and `right` - // - // expressions() output for Join is a vector consisting of - // 1. join keys - columns mentioned in ON clause - // 2. optional predicate - in case join filter is not empty, - // it always will be the last element, otherwise result - // vector will contain only join keys (without additional - // element representing filter). - let mut exprs = join_plan.expressions(); - if !on_filter_empty { - exprs.pop(); - } - exprs.extend(join_conditions.into_iter().reduce(Expr::and)); - let plan = join_plan.with_new_exprs(exprs, vec![left, right])?; - - // wrap the join on the filter whose predicates must be kept - match conjunction(keep_predicates) { - Some(predicate) => { - let new_filter_plan = Filter::try_new(predicate, Arc::new(plan))?; - Ok(Transformed::yes(LogicalPlan::Filter(new_filter_plan))) - } - None => Ok(Transformed::no(plan)), + if let Some(predicate) = conjunction(left_push) { + join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?)); } + if let Some(predicate) = conjunction(right_push) { + join.right = + Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?)); + } + + // Add any new join conditions as the non join predicates + join.filter = conjunction(join_conditions); + + // wrap the join on the filter whose predicates must be kept, if any + let plan = LogicalPlan::Join(join); + let plan = if let Some(predicate) = conjunction(keep_predicates) { + LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?) + } else { + plan + }; + Ok(Transformed::yes(plan)) } fn push_down_join( - plan: &LogicalPlan, - join: &Join, + join: Join, parent_predicate: Option<&Expr>, ) -> Result> { // Split the parent predicate into individual conjunctive parts. @@ -526,93 +496,102 @@ fn push_down_join( .as_ref() .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone())); - let mut is_inner_join = false; - let infer_predicates = if join.join_type == JoinType::Inner { - is_inner_join = true; - - // Only allow both side key is column. - let join_col_keys = join - .on - .iter() - .filter_map(|(l, r)| { - let left_col = l.try_as_col().cloned()?; - let right_col = r.try_as_col().cloned()?; - Some((left_col, right_col)) - }) - .collect::>(); - - // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down - // For inner joins, duplicate filters for joined columns so filters can be pushed down - // to both sides. Take the following query as an example: - // - // ```sql - // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 - // ``` - // - // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while - // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. - // - // Join clauses with `Using` constraints also take advantage of this logic to make sure - // predicates reference the shared join columns are pushed to both sides. - // This logic should also been applied to conditions in JOIN ON clause - predicates - .iter() - .chain(on_filters.iter()) - .filter_map(|predicate| { - let mut join_cols_to_replace = HashMap::new(); - - let columns = match predicate.to_columns() { - Ok(columns) => columns, - Err(e) => return Some(Err(e)), - }; + // Are there any new join predicates that can be inferred from the filter expressions? + let inferred_join_predicates = + infer_join_predicates(&join, &predicates, &on_filters)?; - for col in columns.iter() { - for (l, r) in join_col_keys.iter() { - if col == l { - join_cols_to_replace.insert(col, r); - break; - } else if col == r { - join_cols_to_replace.insert(col, l); - break; - } - } - } + if on_filters.is_empty() + && predicates.is_empty() + && inferred_join_predicates.is_empty() + { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } - if join_cols_to_replace.is_empty() { - return None; - } + push_down_all_join(predicates, inferred_join_predicates, join, on_filters) +} - let join_side_predicate = - match replace_col(predicate.clone(), &join_cols_to_replace) { - Ok(p) => p, - Err(e) => { - return Some(Err(e)); - } - }; +/// Extracts any equi-join join predicates from the given filter expressions. +/// +/// Parameters +/// * `join` the join in question +/// +/// * `predicates` the pushed down filter expression +/// +/// * `on_filters` filters from the join ON clause that have not already been +/// identified as join predicates +/// +fn infer_join_predicates( + join: &Join, + predicates: &[Expr], + on_filters: &[Expr], +) -> Result> { + if join.join_type != JoinType::Inner { + return Ok(vec![]); + } - Some(Ok(join_side_predicate)) - }) - .collect::>>()? - } else { - vec![] - }; + // Only allow both side key is column. + let join_col_keys = join + .on + .iter() + .filter_map(|(l, r)| { + let left_col = l.try_as_col()?; + let right_col = r.try_as_col()?; + Some((left_col, right_col)) + }) + .collect::>(); - if on_filters.is_empty() && predicates.is_empty() && infer_predicates.is_empty() { - return Ok(Transformed::no(plan.clone())); - } + // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down + // For inner joins, duplicate filters for joined columns so filters can be pushed down + // to both sides. Take the following query as an example: + // + // ```sql + // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 + // ``` + // + // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while + // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. + // + // Join clauses with `Using` constraints also take advantage of this logic to make sure + // predicates reference the shared join columns are pushed to both sides. + // This logic should also been applied to conditions in JOIN ON clause + predicates + .iter() + .chain(on_filters.iter()) + .filter_map(|predicate| { + let mut join_cols_to_replace = HashMap::new(); + + let columns = match predicate.to_columns() { + Ok(columns) => columns, + Err(e) => return Some(Err(e)), + }; + + for col in columns.iter() { + for (l, r) in join_col_keys.iter() { + if col == *l { + join_cols_to_replace.insert(col, *r); + break; + } else if col == *r { + join_cols_to_replace.insert(col, *l); + break; + } + } + } - match push_down_all_join( - predicates, - infer_predicates, - plan, - &join.left, - &join.right, - on_filters, - is_inner_join, - ) { - Ok(plan) => Ok(Transformed::yes(plan.data)), - Err(e) => Err(e), - } + if join_cols_to_replace.is_empty() { + return None; + } + + let join_side_predicate = + match replace_col(predicate.clone(), &join_cols_to_replace) { + Ok(p) => p, + Err(e) => { + return Some(Err(e)); + } + }; + + Some(Ok(join_side_predicate)) + }) + .collect::>>() } impl OptimizerRule for PushDownFilter { @@ -641,46 +620,57 @@ impl OptimizerRule for PushDownFilter { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - let filter = match plan { - LogicalPlan::Filter(ref filter) => filter, - LogicalPlan::Join(ref join) => return push_down_join(&plan, join, None), - _ => return Ok(Transformed::no(plan)), + if let LogicalPlan::Join(join) = plan { + return push_down_join(join, None); + }; + + let plan_schema = plan.schema().clone(); + + let LogicalPlan::Filter(mut filter) = plan else { + return Ok(Transformed::no(plan)); }; - let child_plan = filter.input.as_ref(); - let new_plan = match child_plan { - LogicalPlan::Filter(ref child_filter) => { - let parents_predicates = split_conjunction(&filter.predicate); - let set: HashSet<&&Expr> = parents_predicates.iter().collect(); + match unwrap_arc(filter.input) { + LogicalPlan::Filter(child_filter) => { + let parents_predicates = split_conjunction_owned(filter.predicate); + // remove duplicated filters + let child_predicates = split_conjunction_owned(child_filter.predicate); let new_predicates = parents_predicates - .iter() - .chain( - split_conjunction(&child_filter.predicate) - .iter() - .filter(|e| !set.contains(e)), - ) - .map(|e| (*e).clone()) + .into_iter() + .chain(child_predicates) + // use IndexSet to remove dupes while preserving predicate order + .collect::>() + .into_iter() .collect::>(); - let new_predicate = conjunction(new_predicates).ok_or_else(|| { - plan_datafusion_err!("at least one expression exists") - })?; + + let Some(new_predicate) = conjunction(new_predicates) else { + return plan_err!("at least one expression exists"); + }; let new_filter = LogicalPlan::Filter(Filter::try_new( new_predicate, - child_filter.input.clone(), + child_filter.input, )?); - self.rewrite(new_filter, _config)?.data + self.rewrite(new_filter, _config) } - LogicalPlan::Repartition(_) - | LogicalPlan::Distinct(_) - | LogicalPlan::Sort(_) => { - let new_filter = plan.with_new_exprs( - plan.expressions(), - vec![child_plan.inputs()[0].clone()], - )?; - child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])? + LogicalPlan::Repartition(repartition) => { + let new_filter = + Filter::try_new(filter.predicate, repartition.input.clone()) + .map(LogicalPlan::Filter)?; + insert_below(LogicalPlan::Repartition(repartition), new_filter) } - LogicalPlan::SubqueryAlias(ref subquery_alias) => { + LogicalPlan::Distinct(distinct) => { + let new_filter = + Filter::try_new(filter.predicate, distinct.input().clone()) + .map(LogicalPlan::Filter)?; + insert_below(LogicalPlan::Distinct(distinct), new_filter) + } + LogicalPlan::Sort(sort) => { + let new_filter = Filter::try_new(filter.predicate, sort.input.clone()) + .map(LogicalPlan::Filter)?; + insert_below(LogicalPlan::Sort(sort), new_filter) + } + LogicalPlan::SubqueryAlias(subquery_alias) => { let mut replace_map = HashMap::new(); for (i, (qualifier, field)) in subquery_alias.input.schema().iter().enumerate() @@ -692,15 +682,15 @@ impl OptimizerRule for PushDownFilter { Expr::Column(Column::new(qualifier.cloned(), field.name())), ); } - let new_predicate = - replace_cols_by_name(filter.predicate.clone(), &replace_map)?; + let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?; + let new_filter = LogicalPlan::Filter(Filter::try_new( new_predicate, subquery_alias.input.clone(), )?); - child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])? + insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter) } - LogicalPlan::Projection(ref projection) => { + LogicalPlan::Projection(projection) => { // A projection is filter-commutable if it do not contain volatile predicates or contain volatile // predicates that are not used in the filter. However, we should re-writes all predicate expressions. // collect projection. @@ -711,10 +701,7 @@ impl OptimizerRule for PushDownFilter { .enumerate() .map(|(i, (qualifier, field))| { // strip alias, as they should not be part of filters - let expr = match &projection.expr[i] { - Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(), - expr => expr.clone(), - }; + let expr = projection.expr[i].clone().unalias(); (qualified_name(qualifier, field.name()), expr) }) @@ -741,23 +728,24 @@ impl OptimizerRule for PushDownFilter { )?); match conjunction(keep_predicates) { - None => child_plan.with_new_exprs( - child_plan.expressions(), - vec![new_filter], - )?, - Some(keep_predicate) => { - let child_plan = child_plan.with_new_exprs( - child_plan.expressions(), - vec![new_filter], - )?; - LogicalPlan::Filter(Filter::try_new( - keep_predicate, - Arc::new(child_plan), - )?) - } + None => insert_below( + LogicalPlan::Projection(projection), + new_filter, + ), + Some(keep_predicate) => insert_below( + LogicalPlan::Projection(projection), + new_filter, + )? + .map_data(|child_plan| { + Filter::try_new(keep_predicate, Arc::new(child_plan)) + .map(LogicalPlan::Filter) + }), } } - None => return Ok(Transformed::no(plan)), + None => { + filter.input = Arc::new(LogicalPlan::Projection(projection)); + Ok(Transformed::no(LogicalPlan::Filter(filter))) + } } } LogicalPlan::Union(ref union) => { @@ -780,12 +768,12 @@ impl OptimizerRule for PushDownFilter { input.clone(), )?))) } - LogicalPlan::Union(Union { + Ok(Transformed::yes(LogicalPlan::Union(Union { inputs, - schema: plan.schema().clone(), - }) + schema: plan_schema.clone(), + }))) } - LogicalPlan::Aggregate(ref agg) => { + LogicalPlan::Aggregate(agg) => { // We can push down Predicate which in groupby_expr. let group_expr_columns = agg .group_expr @@ -818,49 +806,33 @@ impl OptimizerRule for PushDownFilter { .map(|expr| replace_cols_by_name(expr.clone(), &replace_map)) .collect::>>()?; - let child = match conjunction(replaced_push_predicates) { - Some(predicate) => LogicalPlan::Filter(Filter::try_new( - predicate, - agg.input.clone(), - )?), - None => (*agg.input).clone(), - }; - let new_agg = filter - .input - .with_new_exprs(filter.input.expressions(), vec![child])?; - match conjunction(keep_predicates) { - Some(predicate) => LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(new_agg), - )?), - None => new_agg, - } - } - LogicalPlan::Join(ref join) => { - push_down_join( - &unwrap_arc(filter.clone().input), - join, - Some(&filter.predicate), - )? - .data + let agg_input = agg.input.clone(); + Transformed::yes(LogicalPlan::Aggregate(agg)) + .transform_data(|new_plan| { + // If we have a filter to push, we push it down to the input of the aggregate + if let Some(predicate) = conjunction(replaced_push_predicates) { + let new_filter = make_filter(predicate, agg_input)?; + insert_below(new_plan, new_filter) + } else { + Ok(Transformed::no(new_plan)) + } + })? + .map_data(|child_plan| { + // if there are any remaining predicates we can't push, add them + // back as a filter + if let Some(predicate) = conjunction(keep_predicates) { + make_filter(predicate, Arc::new(child_plan)) + } else { + Ok(child_plan) + } + }) } - LogicalPlan::CrossJoin(ref cross_join) => { + LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)), + LogicalPlan::CrossJoin(cross_join) => { let predicates = split_conjunction_owned(filter.predicate.clone()); - let join = convert_cross_join_to_inner_join(cross_join.clone())?; - let join_plan = LogicalPlan::Join(join); - let inputs = join_plan.inputs(); - let left = inputs[0]; - let right = inputs[1]; - let plan = push_down_all_join( - predicates, - vec![], - &join_plan, - left, - right, - vec![], - true, - )?; - convert_to_cross_join_if_beneficial(plan.data)? + let join = convert_cross_join_to_inner_join(cross_join)?; + let plan = push_down_all_join(predicates, vec![], join, vec![])?; + convert_to_cross_join_if_beneficial(plan.data) } LogicalPlan::TableScan(ref scan) => { let filter_predicates = split_conjunction(&filter.predicate); @@ -901,25 +873,47 @@ impl OptimizerRule for PushDownFilter { fetch: scan.fetch, }); - match conjunction(new_predicate) { - Some(predicate) => LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(new_scan), - )?), - None => new_scan, - } + Transformed::yes(new_scan).transform_data(|new_scan| { + if let Some(predicate) = conjunction(new_predicate) { + make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes) + } else { + Ok(Transformed::no(new_scan)) + } + }) } - LogicalPlan::Extension(ref extension_plan) => { + LogicalPlan::Extension(extension_plan) => { let prevent_cols = extension_plan.node.prevent_predicate_push_down_columns(); - let predicates = split_conjunction_owned(filter.predicate.clone()); + // determine if we can push any predicates down past the extension node + + // each element is true for push, false to keep + let predicate_push_or_keep = split_conjunction(&filter.predicate) + .iter() + .map(|expr| { + let cols = expr.to_columns()?; + if cols.iter().any(|c| prevent_cols.contains(&c.name)) { + Ok(false) // No push (keep) + } else { + Ok(true) // push + } + }) + .collect::>>()?; + // all predicates are kept, no changes needed + if predicate_push_or_keep.iter().all(|&x| !x) { + filter.input = Arc::new(LogicalPlan::Extension(extension_plan)); + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + + // going to push some predicates down, so split the predicates let mut keep_predicates = vec![]; let mut push_predicates = vec![]; - for expr in predicates { - let cols = expr.to_columns()?; - if cols.iter().any(|c| prevent_cols.contains(&c.name)) { + for (push, expr) in predicate_push_or_keep + .into_iter() + .zip(split_conjunction_owned(filter.predicate).into_iter()) + { + if !push { keep_predicates.push(expr); } else { push_predicates.push(expr); @@ -941,22 +935,65 @@ impl OptimizerRule for PushDownFilter { None => extension_plan.node.inputs().into_iter().cloned().collect(), }; // extension with new inputs. + let child_plan = LogicalPlan::Extension(extension_plan); let new_extension = child_plan.with_new_exprs(child_plan.expressions(), new_children)?; - match conjunction(keep_predicates) { + let new_plan = match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( predicate, Arc::new(new_extension), )?), None => new_extension, - } + }; + Ok(Transformed::yes(new_plan)) } - _ => return Ok(Transformed::no(plan)), - }; + child => { + filter.input = Arc::new(child); + Ok(Transformed::no(LogicalPlan::Filter(filter))) + } + } + } +} + +/// Creates a new LogicalPlan::Filter node. +pub fn make_filter(predicate: Expr, input: Arc) -> Result { + Filter::try_new(predicate, input).map(LogicalPlan::Filter) +} - Ok(Transformed::yes(new_plan)) +/// Replace the existing child of the single input node with `new_child`. +/// +/// Starting: +/// ```text +/// plan +/// child +/// ``` +/// +/// Ending: +/// ```text +/// plan +/// new_child +/// ``` +fn insert_below( + plan: LogicalPlan, + new_child: LogicalPlan, +) -> Result> { + let mut new_child = Some(new_child); + let transformed_plan = plan.map_children(|_child| { + if let Some(new_child) = new_child.take() { + Ok(Transformed::yes(new_child)) + } else { + // already took the new child + internal_err!("node had more than one input") + } + })?; + + // make sure we did the actual replacement + if new_child.is_some() { + return internal_err!("node had no inputs"); } + + Ok(transformed_plan) } impl PushDownFilter { @@ -985,21 +1022,27 @@ fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result { /// Converts the given inner join with an empty equality predicate and an /// empty filter condition to a cross join. -fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result { - if let LogicalPlan::Join(join) = &plan { +fn convert_to_cross_join_if_beneficial( + plan: LogicalPlan, +) -> Result> { + match plan { // Can be converted back to cross join - if join.on.is_empty() && join.filter.is_none() { - return LogicalPlanBuilder::from(join.left.as_ref().clone()) - .cross_join(join.right.as_ref().clone())? - .build(); + LogicalPlan::Join(join) if join.on.is_empty() && join.filter.is_none() => { + LogicalPlanBuilder::from(unwrap_arc(join.left)) + .cross_join(unwrap_arc(join.right))? + .build() + .map(Transformed::yes) } - } else if let LogicalPlan::Filter(filter) = &plan { - let new_input = - convert_to_cross_join_if_beneficial(filter.input.as_ref().clone())?; - return Filter::try_new(filter.predicate.clone(), Arc::new(new_input)) - .map(LogicalPlan::Filter); + LogicalPlan::Filter(filter) => convert_to_cross_join_if_beneficial(unwrap_arc( + filter.input, + ))? + .transform_data(|child_plan| { + Filter::try_new(filter.predicate, Arc::new(child_plan)) + .map(LogicalPlan::Filter) + .map(Transformed::yes) + }), + plan => Ok(Transformed::no(plan)), } - Ok(plan) } /// replaces columns by its name on the projection.