Skip to content

Extract result of find_common_exprs into a struct #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 82 additions & 44 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,22 @@ pub struct CommonSubexprEliminate {
random_state: RandomState,
}

type FindCommonExprResult = Option<(Vec<(Expr, String)>, Vec<Vec<Expr>>)>;
/// The result of potentially rewriting a list of expressions to eliminate common
/// subexpressions.
///
/// Both variants carry `original_exprs_list`, so the expressions handed to the
/// rewrite remain available to the caller whichever outcome occurred.
#[derive(Debug)]
enum FoundCommonExprs {
/// No common expressions were found, so nothing was rewritten; only the
/// original expression lists are returned.
No { original_exprs_list: Vec<Vec<Expr>> },
/// Common expressions were found and extracted.
Yes {
/// Extracted common expressions, each paired with a `String`
/// (NOTE(review): appears to be the alias the extracted expression is
/// projected under; confirm against `build_common_expr_project_plan`).
common_exprs: Vec<(Expr, String)>,
/// New expressions with common subexpressions replaced.
new_exprs_list: Vec<Vec<Expr>>,
/// Original expressions, as they were before the rewrite.
original_exprs_list: Vec<Vec<Expr>>,
},
}

impl CommonSubexprEliminate {
pub fn new() -> Self {
Expand Down Expand Up @@ -242,16 +257,13 @@ impl CommonSubexprEliminate {

/// Extracts common sub-expressions and rewrites `exprs_list`.
///
/// Returns a tuple of:
/// 1. The rewritten expressions
/// 2. An optional tuple that contains the extracted common sub-expressions and the
/// original `exprs_list`.
/// Returns `FoundCommonExprs` recording the result of the extraction
fn find_common_exprs(
&self,
exprs_list: Vec<Vec<Expr>>,
config: &dyn OptimizerConfig,
expr_mask: ExprMask,
) -> Result<Transformed<(Vec<Vec<Expr>>, FindCommonExprResult)>> {
) -> Result<Transformed<FoundCommonExprs>> {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point of the PR is to make this change; the rest of the changes are a consequence of doing so.

let mut found_common = false;
let mut expr_stats = ExprStats::new();
let id_arrays_list = exprs_list
Expand Down Expand Up @@ -279,12 +291,15 @@ impl CommonSubexprEliminate {
)?;
assert!(!common_exprs.is_empty());

Ok(Transformed::yes((
Ok(Transformed::yes(FoundCommonExprs::Yes {
common_exprs: common_exprs.into_values().collect(),
new_exprs_list,
Some((common_exprs.into_values().collect(), exprs_list)),
)))
original_exprs_list: exprs_list,
}))
} else {
Ok(Transformed::no((exprs_list, None)))
Ok(Transformed::no(FoundCommonExprs::No {
original_exprs_list: exprs_list,
}))
}
}

Expand Down Expand Up @@ -356,17 +371,22 @@ impl CommonSubexprEliminate {

// Extract common sub-expressions from the list.
self.find_common_exprs(window_expr_list, config, ExprMask::Normal)?
.map_data(|(new_window_expr_list, common)| match common {
.map_data(|common| match common {
// If there are common sub-expressions, then insert a projection node
// with the common expressions between the new window nodes and the
// original input.
Some((common_exprs, window_expr_list)) => {
FoundCommonExprs::Yes {
common_exprs,
new_exprs_list,
original_exprs_list,
} => {
build_common_expr_project_plan(input, common_exprs).map(|new_input| {
(new_window_expr_list, new_input, Some(window_expr_list))
(new_exprs_list, new_input, Some(original_exprs_list))
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also renamed the local variables to follow the naming of the fields

})
}

None => Ok((new_window_expr_list, input, None)),
FoundCommonExprs::No {
original_exprs_list,
} => Ok((original_exprs_list, input, None)),
})?
// Recurse into the new input.
// (This is similar to what an `ApplyOrder::TopDown` optimizer rule would do.)
Expand Down Expand Up @@ -441,19 +461,22 @@ impl CommonSubexprEliminate {
let input = unwrap_arc(input);
// Extract common sub-expressions from the aggregate and grouping expressions.
self.find_common_exprs(vec![group_expr, aggr_expr], config, ExprMask::Normal)?
.map_data(|(mut new_expr_list, common)| {
let new_aggr_expr = new_expr_list.pop().unwrap();
let new_group_expr = new_expr_list.pop().unwrap();

.map_data(|common| {
match common {
// If there are common sub-expressions, then insert a projection node
// with the common expressions between the new aggregate node and the
// original input.
Some((common_exprs, mut expr_list)) => {
FoundCommonExprs::Yes {
common_exprs,
mut new_exprs_list,
mut original_exprs_list,
} => {
let new_aggr_expr = new_exprs_list.pop().unwrap();
let new_group_expr = new_exprs_list.pop().unwrap();

build_common_expr_project_plan(input, common_exprs).map(
|new_input| {
let aggr_expr = expr_list.pop().unwrap();

let aggr_expr = original_exprs_list.pop().unwrap();
(
new_aggr_expr,
new_group_expr,
Expand All @@ -464,7 +487,14 @@ impl CommonSubexprEliminate {
)
}

None => Ok((new_aggr_expr, new_group_expr, input, None)),
FoundCommonExprs::No {
mut original_exprs_list,
} => {
let new_aggr_expr = original_exprs_list.pop().unwrap();
let new_group_expr = original_exprs_list.pop().unwrap();

Ok((new_aggr_expr, new_group_expr, input, None))
}
}
})?
// Recurse into the new input.
Expand All @@ -487,16 +517,17 @@ impl CommonSubexprEliminate {
config,
ExprMask::NormalAndAggregates,
)?
.map_data(|(mut new_aggr_list, common)| {
let rewritten_aggr_expr = new_aggr_list.pop().unwrap();

.map_data(|common| {
match common {
// If there are common aggregate sub-expressions, then insert a
// projection above the new rebuilt aggregate node.
Some((common_aggr_exprs, mut aggr_list)) => {
let new_aggr_expr = aggr_list.pop().unwrap();
FoundCommonExprs::Yes {
common_exprs,
mut new_exprs_list,
mut original_exprs_list,
} => {
let rewritten_aggr_expr = new_exprs_list.pop().unwrap();
let new_aggr_expr = original_exprs_list.pop().unwrap();

let mut agg_exprs = common_aggr_exprs
let mut agg_exprs = common_exprs
.into_iter()
.map(|(expr, expr_alias)| expr.alias(expr_alias))
.collect::<Vec<_>>();
Expand Down Expand Up @@ -552,7 +583,11 @@ impl CommonSubexprEliminate {

// If there aren't any common aggregate sub-expressions, then just
// rebuild the aggregate node.
None => {
FoundCommonExprs::No {
mut original_exprs_list,
} => {
let rewritten_aggr_expr = original_exprs_list.pop().unwrap();

// If there were common expressions extracted, then we need to
// make sure we restore the original column names.
// TODO: Although `find_common_exprs()` inserts aliases around
Expand Down Expand Up @@ -622,18 +657,21 @@ impl CommonSubexprEliminate {
) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
// Extract common sub-expressions from the expressions.
self.find_common_exprs(vec![exprs], config, ExprMask::Normal)?
.map_data(|(mut new_exprs_list, common)| {
let new_exprs = new_exprs_list.pop().unwrap();

match common {
// If there are common sub-expressions, then insert a projection node
// with the common expressions between the original node and the
// original input.
Some((common_exprs, _)) => {
build_common_expr_project_plan(input, common_exprs)
.map(|new_input| (new_exprs, new_input))
}
None => Ok((new_exprs, input)),
.map_data(|common| match common {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't quite as concise as the previous version but I do think it is more explicit and easier to follow

FoundCommonExprs::Yes {
common_exprs,
mut new_exprs_list,
original_exprs_list: _,
} => {
let new_exprs = new_exprs_list.pop().unwrap();
build_common_expr_project_plan(input, common_exprs)
.map(|new_input| (new_exprs, new_input))
}
FoundCommonExprs::No {
mut original_exprs_list,
} => {
let new_exprs = original_exprs_list.pop().unwrap();
Ok((new_exprs, input))
}
})?
// Recurse into the new input.
Expand Down