From 9fff25ca8ed3f2fb68d7553fc876c875c569c25a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 8 Aug 2024 11:08:22 -0400 Subject: [PATCH 1/2] Extract the result of find_common_exprs into a struct --- .../optimizer/src/common_subexpr_eliminate.rs | 111 ++++++++++++------ 1 file changed, 74 insertions(+), 37 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 3a38d9f8eb03..b8b3dded44c7 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -143,7 +143,22 @@ pub struct CommonSubexprEliminate { random_state: RandomState, } -type FindCommonExprResult = Option<(Vec<(Expr, String)>, Vec>)>; +/// The result of potentially rewriting a list of expressions to eliminate common +/// subexpressions. +#[derive(Debug)] +enum FoundCommonExprs { + /// No common expressions were found + No { original_exprs: Vec> }, + /// Common expressions were found + Yes { + /// extracted common expressions + common_exprs: Vec<(Expr, String)>, + /// new expressions with common subexpressions replaced + new_exprs: Vec>, + /// original expressions + original_exprs: Vec>, + }, +} impl CommonSubexprEliminate { pub fn new() -> Self { @@ -251,7 +266,7 @@ impl CommonSubexprEliminate { exprs_list: Vec>, config: &dyn OptimizerConfig, expr_mask: ExprMask, - ) -> Result>, FindCommonExprResult)>> { + ) -> Result> { let mut found_common = false; let mut expr_stats = ExprStats::new(); let id_arrays_list = exprs_list @@ -279,12 +294,15 @@ impl CommonSubexprEliminate { )?; assert!(!common_exprs.is_empty()); - Ok(Transformed::yes(( - new_exprs_list, - Some((common_exprs.into_values().collect(), exprs_list)), - ))) + Ok(Transformed::yes(FoundCommonExprs::Yes { + common_exprs: common_exprs.into_values().collect(), + new_exprs: new_exprs_list, + original_exprs: exprs_list, + })) } else { - Ok(Transformed::no((exprs_list, None))) + Ok(Transformed::no(FoundCommonExprs::No { + original_exprs: exprs_list, + })) } } @@ -356,17 +374,22 @@ impl CommonSubexprEliminate { // Extract common sub-expressions from the list. self.find_common_exprs(window_expr_list, config, ExprMask::Normal)? - .map_data(|(new_window_expr_list, common)| match common { + .map_data(|common| match common { // If there are common sub-expressions, then the insert a projection node // with the common expressions between the new window nodes and the // original input. - Some((common_exprs, window_expr_list)) => { + FoundCommonExprs::Yes { + common_exprs, + new_exprs: new_window_expr_list, + original_exprs: window_expr_list, + } => { build_common_expr_project_plan(input, common_exprs).map(|new_input| { (new_window_expr_list, new_input, Some(window_expr_list)) }) } - - None => Ok((new_window_expr_list, input, None)), + FoundCommonExprs::No { original_exprs } => { + Ok((original_exprs, input, None)) + } })? // Recurse into the new input. // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) @@ -441,19 +464,22 @@ impl CommonSubexprEliminate { let input = unwrap_arc(input); // Extract common sub-expressions from the aggregate and grouping expressions. self.find_common_exprs(vec![group_expr, aggr_expr], config, ExprMask::Normal)? - .map_data(|(mut new_expr_list, common)| { - let new_aggr_expr = new_expr_list.pop().unwrap(); - let new_group_expr = new_expr_list.pop().unwrap(); - + .map_data(|common| { match common { // If there are common sub-expressions, then insert a projection node // with the common expressions between the new aggregate node and the // original input. - Some((common_exprs, mut expr_list)) => { + FoundCommonExprs::Yes { + common_exprs, + new_exprs: mut new_expr_list, + original_exprs: mut expr_list, + } => { + let new_aggr_expr = new_expr_list.pop().unwrap(); + let new_group_expr = new_expr_list.pop().unwrap(); + build_common_expr_project_plan(input, common_exprs).map( |new_input| { let aggr_expr = expr_list.pop().unwrap(); - ( new_aggr_expr, new_group_expr, @@ -464,7 +490,12 @@ impl CommonSubexprEliminate { ) } - None => Ok((new_aggr_expr, new_group_expr, input, None)), + FoundCommonExprs::No { mut original_exprs } => { + let new_aggr_expr = original_exprs.pop().unwrap(); + let new_group_expr = original_exprs.pop().unwrap(); + + Ok((new_aggr_expr, new_group_expr, input, None)) + } } })? // Recurse into the new input. @@ -487,13 +518,14 @@ impl CommonSubexprEliminate { config, ExprMask::NormalAndAggregates, )? - .map_data(|(mut new_aggr_list, common)| { - let rewritten_aggr_expr = new_aggr_list.pop().unwrap(); - + .map_data(|common| { match common { - // If there are common aggregate sub-expressions, then insert a - // projection above the new rebuilt aggregate node. - Some((common_aggr_exprs, mut aggr_list)) => { + FoundCommonExprs::Yes { + common_exprs: common_aggr_exprs, + new_exprs: mut new_aggr_list, + original_exprs: mut aggr_list, + } => { + let rewritten_aggr_expr = new_aggr_list.pop().unwrap(); let new_aggr_expr = aggr_list.pop().unwrap(); let mut agg_exprs = common_aggr_exprs @@ -552,7 +584,11 @@ impl CommonSubexprEliminate { // If there aren't any common aggregate sub-expressions, then just // rebuild the aggregate node. - None => { + FoundCommonExprs::No { + original_exprs: mut new_aggr_list, + } => { + let rewritten_aggr_expr = new_aggr_list.pop().unwrap(); + // If there were common expressions extracted, then we need to // make sure we restore the original column names. // TODO: Although `find_common_exprs()` inserts aliases around @@ -622,18 +658,19 @@ impl CommonSubexprEliminate { ) -> Result, LogicalPlan)>> { // Extract common sub-expressions from the expressions. self.find_common_exprs(vec![exprs], config, ExprMask::Normal)? - .map_data(|(mut new_exprs_list, common)| { - let new_exprs = new_exprs_list.pop().unwrap(); - - match common { - // If there are common sub-expressions, then insert a projection node - // with the common expressions between the original node and the - // original input. - Some((common_exprs, _)) => { - build_common_expr_project_plan(input, common_exprs) - .map(|new_input| (new_exprs, new_input)) - } - None => Ok((new_exprs, input)), + .map_data(|common| match common { + FoundCommonExprs::Yes { + common_exprs, + new_exprs: mut new_exprs_list, + original_exprs: _, + } => { + let new_exprs = new_exprs_list.pop().unwrap(); + build_common_expr_project_plan(input, common_exprs) + .map(|new_input| (new_exprs, new_input)) + } + FoundCommonExprs::No { mut original_exprs } => { + let new_exprs = original_exprs.pop().unwrap(); + Ok((new_exprs, input)) } })? // Recurse into the new input. From 03acbc4064c1b96ed588f319a4cdf667f1127811 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 8 Aug 2024 11:16:49 -0400 Subject: [PATCH 2/2] Make naming consistent --- .../optimizer/src/common_subexpr_eliminate.rs | 73 ++++++++++--------- 1 file changed, 37 insertions(+), 36 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index b8b3dded44c7..b3281c4e0592 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -148,15 +148,15 @@ pub struct CommonSubexprEliminate { #[derive(Debug)] enum FoundCommonExprs { /// No common expressions were found - No { original_exprs: Vec> }, + No { original_exprs_list: Vec> }, /// Common expressions were found Yes { /// extracted common expressions common_exprs: Vec<(Expr, String)>, /// new expressions with common subexpressions replaced - new_exprs: Vec>, + new_exprs_list: Vec>, /// original expressions - original_exprs: Vec>, + original_exprs_list: Vec>, }, } @@ -257,10 +257,7 @@ impl CommonSubexprEliminate { /// Extracts common sub-expressions and rewrites `exprs_list`. /// - /// Returns a tuple of: - /// 1. The rewritten expressions - /// 2. An optional tuple that contains the extracted common sub-expressions and the - /// original `exprs_list`. + /// Returns `FoundCommonExprs` recording the result of the extraction fn find_common_exprs( &self, exprs_list: Vec>, @@ -296,12 +293,12 @@ impl CommonSubexprEliminate { Ok(Transformed::yes(FoundCommonExprs::Yes { common_exprs: common_exprs.into_values().collect(), - new_exprs: new_exprs_list, - original_exprs: exprs_list, + new_exprs_list, + original_exprs_list: exprs_list, })) } else { Ok(Transformed::no(FoundCommonExprs::No { - original_exprs: exprs_list, + original_exprs_list: exprs_list, })) } } @@ -380,16 +377,16 @@ impl CommonSubexprEliminate { // original input. FoundCommonExprs::Yes { common_exprs, - new_exprs: new_window_expr_list, - original_exprs: window_expr_list, + new_exprs_list, + original_exprs_list, } => { build_common_expr_project_plan(input, common_exprs).map(|new_input| { - (new_window_expr_list, new_input, Some(window_expr_list)) + (new_exprs_list, new_input, Some(original_exprs_list)) }) } - FoundCommonExprs::No { original_exprs } => { - Ok((original_exprs, input, None)) - } + FoundCommonExprs::No { + original_exprs_list, + } => Ok((original_exprs_list, input, None)), })? // Recurse into the new input. // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) @@ -471,15 +468,15 @@ impl CommonSubexprEliminate { // original input. FoundCommonExprs::Yes { common_exprs, - new_exprs: mut new_expr_list, - original_exprs: mut expr_list, + mut new_exprs_list, + mut original_exprs_list, } => { - let new_aggr_expr = new_expr_list.pop().unwrap(); - let new_group_expr = new_expr_list.pop().unwrap(); + let new_aggr_expr = new_exprs_list.pop().unwrap(); + let new_group_expr = new_exprs_list.pop().unwrap(); build_common_expr_project_plan(input, common_exprs).map( |new_input| { - let aggr_expr = expr_list.pop().unwrap(); + let aggr_expr = original_exprs_list.pop().unwrap(); ( new_aggr_expr, new_group_expr, @@ -490,9 +487,11 @@ impl CommonSubexprEliminate { ) } - FoundCommonExprs::No { mut original_exprs } => { - let new_aggr_expr = original_exprs.pop().unwrap(); - let new_group_expr = original_exprs.pop().unwrap(); + FoundCommonExprs::No { + mut original_exprs_list, + } => { + let new_aggr_expr = original_exprs_list.pop().unwrap(); + let new_group_expr = original_exprs_list.pop().unwrap(); Ok((new_aggr_expr, new_group_expr, input, None)) } @@ -521,14 +520,14 @@ impl CommonSubexprEliminate { .map_data(|common| { match common { FoundCommonExprs::Yes { - common_exprs: common_aggr_exprs, - new_exprs: mut new_aggr_list, - original_exprs: mut aggr_list, + common_exprs, + mut new_exprs_list, + mut original_exprs_list, } => { - let rewritten_aggr_expr = new_aggr_list.pop().unwrap(); - let new_aggr_expr = aggr_list.pop().unwrap(); + let rewritten_aggr_expr = new_exprs_list.pop().unwrap(); + let new_aggr_expr = original_exprs_list.pop().unwrap(); - let mut agg_exprs = common_aggr_exprs + let mut agg_exprs = common_exprs .into_iter() .map(|(expr, expr_alias)| expr.alias(expr_alias)) .collect::>(); @@ -585,9 +584,9 @@ impl CommonSubexprEliminate { // If there aren't any common aggregate sub-expressions, then just // rebuild the aggregate node. FoundCommonExprs::No { - original_exprs: mut new_aggr_list, + mut original_exprs_list, } => { - let rewritten_aggr_expr = new_aggr_list.pop().unwrap(); + let rewritten_aggr_expr = original_exprs_list.pop().unwrap(); // If there were common expressions extracted, then we need to // make sure we restore the original column names. @@ -661,15 +660,17 @@ impl CommonSubexprEliminate { .map_data(|common| match common { FoundCommonExprs::Yes { common_exprs, - new_exprs: mut new_exprs_list, - original_exprs: _, + mut new_exprs_list, + original_exprs_list: _, } => { let new_exprs = new_exprs_list.pop().unwrap(); build_common_expr_project_plan(input, common_exprs) .map(|new_input| (new_exprs, new_input)) } - FoundCommonExprs::No { mut original_exprs } => { - let new_exprs = original_exprs.pop().unwrap(); + FoundCommonExprs::No { + mut original_exprs_list, + } => { + let new_exprs = original_exprs_list.pop().unwrap(); Ok((new_exprs, input)) } })?