Skip to content

Commit ec0d407

Browse files
committed
Implement PhysicalExpr CSE
1 parent 8065fb2 commit ec0d407

File tree

28 files changed

+739
-61
lines changed

28 files changed

+739
-61
lines changed

datafusion-cli/Cargo.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/common/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ arrow-schema = { workspace = true }
5656
chrono = { workspace = true }
5757
half = { workspace = true }
5858
hashbrown = { workspace = true }
59-
indexmap = { workspace = true }
6059
libc = "0.2.140"
6160
num_cpus = { workspace = true }
6261
object_store = { workspace = true, optional = true }

datafusion/common/src/config.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@ use std::collections::{BTreeMap, HashMap};
2222
use std::fmt::{self, Display};
2323
use std::str::FromStr;
2424

25+
use crate::alias::AliasGenerator;
2526
use crate::error::_config_err;
2627
use crate::parsers::CompressionTypeVariant;
2728
use crate::{DataFusionError, Result};
29+
use std::sync::Arc;
2830

2931
/// A macro that wraps a configuration struct and automatically derives
3032
/// [`Default`] and [`ConfigField`] for it, allowing it to be used
@@ -674,6 +676,8 @@ pub struct ConfigOptions {
674676
pub explain: ExplainOptions,
675677
/// Optional extensions registered using [`Extensions::insert`]
676678
pub extensions: Extensions,
679+
/// Alias generator used to generate unique aliases
680+
pub alias_generator: Arc<AliasGenerator>,
677681
}
678682

679683
impl ConfigField for ConfigOptions {

datafusion/common/src/cse.rs

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ use crate::tree_node::{
2525
TreeNodeVisitor,
2626
};
2727
use crate::Result;
28-
use indexmap::IndexMap;
2928
use std::collections::HashMap;
3029
use std::hash::{BuildHasher, Hash, Hasher, RandomState};
3130
use std::marker::PhantomData;
@@ -123,11 +122,11 @@ type IdArray<'n, N> = Vec<(usize, Option<Identifier<'n, N>>)>;
123122

124123
/// A map that contains the number of normal and conditional occurrences of [`TreeNode`]s
125124
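/// by their identifiers, together with the index of the node in the list of common nodes
/// once it has been extracted during the rewriting traversal.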
126-
type NodeStats<'n, N> = HashMap<Identifier<'n, N>, (usize, usize)>;
125+
type NodeStats<'n, N> = HashMap<Identifier<'n, N>, (usize, usize, Option<usize>)>;
127126

128127
/// A list that contains the common [`TreeNode`]s and their aliases, in the order they were
129128
/// extracted during the second, rewriting traversal.
130-
type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>;
129+
type CommonNodes<'n, N> = Vec<(N, String)>;
131130

132131
type ChildrenList<N> = (Vec<N>, Vec<N>);
133132

@@ -155,7 +154,7 @@ pub trait CSEController {
155154
fn generate_alias(&self) -> String;
156155

157156
// Replaces a node with the generated alias.
158-
fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node;
157+
fn rewrite(&mut self, node: &Self::Node, alias: &str, index: usize) -> Self::Node;
159158

160159
// A helper method called on each node during top-down traversal during the second,
161160
// rewriting traversal of CSE.
@@ -331,8 +330,8 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
331330
self.id_array[down_index].0 = self.up_index;
332331
if is_valid && !self.controller.is_ignored(node) {
333332
self.id_array[down_index].1 = Some(node_id);
334-
let (count, conditional_count) =
335-
self.node_stats.entry(node_id).or_insert((0, 0));
333+
let (count, conditional_count, _) =
334+
self.node_stats.entry(node_id).or_insert((0, 0, None));
336335
if self.conditional {
337336
*conditional_count += 1;
338337
} else {
@@ -355,7 +354,7 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
355354
/// replaced [`TreeNode`] tree.
356355
struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
357356
/// statistics of [`TreeNode`]s
358-
node_stats: &'a NodeStats<'n, N>,
357+
node_stats: &'a mut NodeStats<'n, N>,
359358

360359
/// cache to speed up second traversal
361360
id_array: &'a IdArray<'n, N>,
@@ -383,7 +382,8 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
383382

384383
// Handle nodes with identifiers only
385384
if let Some(node_id) = node_id {
386-
let (count, conditional_count) = self.node_stats.get(&node_id).unwrap();
385+
let (count, conditional_count, common_index) =
386+
self.node_stats.get_mut(&node_id).unwrap();
387387
if *count > 1 || *count == 1 && *conditional_count > 0 {
388388
// Step the index to skip all sub-nodes (which have smaller series numbers).
389389
while self.down_index < self.id_array.len()
@@ -392,13 +392,15 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
392392
self.down_index += 1;
393393
}
394394

395-
let (node, alias) =
396-
self.common_nodes.entry(node_id).or_insert_with(|| {
397-
let node_alias = self.controller.generate_alias();
398-
(node, node_alias)
399-
});
395+
let index = *common_index.get_or_insert_with(|| {
396+
let index = self.common_nodes.len();
397+
let node_alias = self.controller.generate_alias();
398+
self.common_nodes.push((node, node_alias));
399+
index
400+
});
401+
let (node, alias) = self.common_nodes.get(index).unwrap();
400402

401-
let rewritten = self.controller.rewrite(node, alias);
403+
let rewritten = self.controller.rewrite(node, alias, index);
402404

403405
return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump));
404406
}
@@ -491,7 +493,7 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
491493
&mut self,
492494
node: N,
493495
id_array: &IdArray<'n, N>,
494-
node_stats: &NodeStats<'n, N>,
496+
node_stats: &mut NodeStats<'n, N>,
495497
common_nodes: &mut CommonNodes<'n, N>,
496498
) -> Result<N> {
497499
if id_array.is_empty() {
@@ -514,7 +516,7 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
514516
&mut self,
515517
nodes_list: Vec<Vec<N>>,
516518
arrays_list: &[Vec<IdArray<'n, N>>],
517-
node_stats: &NodeStats<'n, N>,
519+
node_stats: &mut NodeStats<'n, N>,
518520
common_nodes: &mut CommonNodes<'n, N>,
519521
) -> Result<Vec<Vec<N>>> {
520522
nodes_list
@@ -559,13 +561,13 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
559561
// nodes so we have to keep them intact.
560562
nodes_list.clone(),
561563
&id_arrays_list,
562-
&node_stats,
564+
&mut node_stats,
563565
&mut common_nodes,
564566
)?;
565567
assert!(!common_nodes.is_empty());
566568

567569
Ok(FoundCommonNodes::Yes {
568-
common_nodes: common_nodes.into_values().collect(),
570+
common_nodes,
569571
new_nodes_list,
570572
original_nodes_list: nodes_list,
571573
})
@@ -635,7 +637,12 @@ mod test {
635637
self.alias_generator.next(CSE_PREFIX)
636638
}
637639

638-
fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
640+
fn rewrite(
641+
&mut self,
642+
node: &Self::Node,
643+
alias: &str,
644+
_index: usize,
645+
) -> Self::Node {
639646
TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias))
640647
}
641648
}

datafusion/core/src/physical_optimizer/optimizer.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@
1717

1818
//! Physical optimizer traits
1919
20-
use datafusion_physical_optimizer::PhysicalOptimizerRule;
21-
use std::sync::Arc;
22-
2320
use super::projection_pushdown::ProjectionPushdown;
2421
use super::update_aggr_exprs::OptimizeAggregateOrder;
2522
use crate::physical_optimizer::aggregate_statistics::AggregateStatistics;
@@ -33,6 +30,9 @@ use crate::physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggr
3330
use crate::physical_optimizer::output_requirements::OutputRequirements;
3431
use crate::physical_optimizer::sanity_checker::SanityCheckPlan;
3532
use crate::physical_optimizer::topk_aggregation::TopKAggregation;
33+
use datafusion_physical_optimizer::eliminate_common_physical_subexprs::EliminateCommonPhysicalSubexprs;
34+
use datafusion_physical_optimizer::PhysicalOptimizerRule;
35+
use std::sync::Arc;
3636

3737
/// A rule-based physical optimizer.
3838
#[derive(Clone, Debug)]
@@ -103,6 +103,10 @@ impl PhysicalOptimizer {
103103
// replacing operators with fetching variants, or adding limits
104104
// past operators that support limit pushdown.
105105
Arc::new(LimitPushdown::new()),
106+
// The EliminateCommonPhysicalSubexprs rule extracts common physical
107+
// subexpression trees into a `ProjectionExec` node under the actual node to
108+
// calculate the common values only once.
109+
Arc::new(EliminateCommonPhysicalSubexprs::new()),
106110
// The SanityCheckPlan rule checks whether the order and
107111
// distribution requirements of each node in the plan
108112
// is satisfied. It will also reject non-runnable query

datafusion/optimizer/src/common_subexpr_eliminate.rs

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -387,17 +387,15 @@ impl CommonSubexprEliminate {
387387
// Since `group_expr` may have changed, schema may also.
388388
// Use `try_new()` method.
389389
Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
390-
.map(LogicalPlan::Aggregate)
391-
.map(Transformed::no)
390+
.map(|p| Transformed::no(LogicalPlan::Aggregate(p)))
392391
} else {
393392
Aggregate::try_new_with_schema(
394393
new_input,
395394
new_group_expr,
396395
rewritten_aggr_expr,
397396
schema,
398397
)
399-
.map(LogicalPlan::Aggregate)
400-
.map(Transformed::no)
398+
.map(|p| Transformed::no(LogicalPlan::Aggregate(p)))
401399
}
402400
}
403401
}
@@ -617,9 +615,7 @@ impl CSEController for ExprCSEController<'_> {
617615

618616
fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
619617
match node {
620-
// In case of `ScalarFunction`s we don't know which children are surely
621-
// executed so start visiting all children conditionally and stop the
622-
// recursion with `TreeNodeRecursion::Jump`.
618+
// In case of short-circuiting `ScalarFunction`s all children can be conditionally executed.
623619
Expr::ScalarFunction(ScalarFunction { func, args })
624620
if func.short_circuits() =>
625621
{
@@ -689,7 +685,7 @@ impl CSEController for ExprCSEController<'_> {
689685
self.alias_generator.next(CSE_PREFIX)
690686
}
691687

692-
fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
688+
fn rewrite(&mut self, node: &Self::Node, alias: &str, _index: usize) -> Self::Node {
693689
// alias the expressions without an `Alias` ancestor node
694690
if self.alias_counter > 0 {
695691
col(alias)
@@ -1019,10 +1015,14 @@ mod test {
10191015
fn subexpr_in_same_order() -> Result<()> {
10201016
let table_scan = test_table_scan()?;
10211017

1018+
let a = col("a");
1019+
let lit_1 = lit(1);
1020+
let _1_plus_a = lit_1 + a;
1021+
10221022
let plan = LogicalPlanBuilder::from(table_scan)
10231023
.project(vec![
1024-
(lit(1) + col("a")).alias("first"),
1025-
(lit(1) + col("a")).alias("second"),
1024+
_1_plus_a.clone().alias("first"),
1025+
_1_plus_a.alias("second"),
10261026
])?
10271027
.build()?;
10281028

@@ -1039,8 +1039,13 @@ mod test {
10391039
fn subexpr_in_different_order() -> Result<()> {
10401040
let table_scan = test_table_scan()?;
10411041

1042+
let a = col("a");
1043+
let lit_1 = lit(1);
1044+
let _1_plus_a = lit_1.clone() + a.clone();
1045+
let a_plus_1 = a + lit_1;
1046+
10421047
let plan = LogicalPlanBuilder::from(table_scan)
1043-
.project(vec![lit(1) + col("a"), col("a") + lit(1)])?
1048+
.project(vec![_1_plus_a, a_plus_1])?
10441049
.build()?;
10451050

10461051
let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\
@@ -1055,6 +1060,8 @@ mod test {
10551060
fn cross_plans_subexpr() -> Result<()> {
10561061
let table_scan = test_table_scan()?;
10571062

1063+
let _1_plus_col_a = lit(1) + col("a");
1064+
10581065
let plan = LogicalPlanBuilder::from(table_scan)
10591066
.project(vec![lit(1) + col("a"), col("a")])?
10601067
.project(vec![lit(1) + col("a")])?
@@ -1307,9 +1314,12 @@ mod test {
13071314
fn test_volatile() -> Result<()> {
13081315
let table_scan = test_table_scan()?;
13091316

1310-
let extracted_child = col("a") + col("b");
1311-
let rand = rand_func().call(vec![]);
1317+
let a = col("a");
1318+
let b = col("b");
1319+
let extracted_child = a + b;
1320+
let rand = rand_expr();
13121321
let not_extracted_volatile = extracted_child + rand;
1322+
13131323
let plan = LogicalPlanBuilder::from(table_scan)
13141324
.project(vec![
13151325
not_extracted_volatile.clone().alias("c1"),
@@ -1330,13 +1340,19 @@ mod test {
13301340
fn test_volatile_short_circuits() -> Result<()> {
13311341
let table_scan = test_table_scan()?;
13321342

1333-
let rand = rand_func().call(vec![]);
1334-
let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
1343+
let a = col("a");
1344+
let b = col("b");
1345+
let rand = rand_expr();
1346+
let rand_eq_0 = rand.eq(lit(0));
1347+
1348+
let extracted_short_circuit_leg_1 = a.eq(lit(0));
13351349
let not_extracted_volatile_short_circuit_1 =
1336-
extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
1337-
let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
1350+
extracted_short_circuit_leg_1.or(rand_eq_0.clone());
1351+
1352+
let not_extracted_short_circuit_leg_2 = b.eq(lit(0));
13381353
let not_extracted_volatile_short_circuit_2 =
1339-
rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
1354+
rand_eq_0.or(not_extracted_short_circuit_leg_2);
1355+
13401356
let plan = LogicalPlanBuilder::from(table_scan)
13411357
.project(vec![
13421358
not_extracted_volatile_short_circuit_1.clone().alias("c1"),
@@ -1359,7 +1375,10 @@ mod test {
13591375
fn test_non_top_level_common_expression() -> Result<()> {
13601376
let table_scan = test_table_scan()?;
13611377

1362-
let common_expr = col("a") + col("b");
1378+
let a = col("a");
1379+
let b = col("b");
1380+
let common_expr = a + b;
1381+
13631382
let plan = LogicalPlanBuilder::from(table_scan)
13641383
.project(vec![
13651384
common_expr.clone().alias("c1"),
@@ -1382,8 +1401,11 @@ mod test {
13821401
fn test_nested_common_expression() -> Result<()> {
13831402
let table_scan = test_table_scan()?;
13841403

1385-
let nested_common_expr = col("a") + col("b");
1404+
let a = col("a");
1405+
let b = col("b");
1406+
let nested_common_expr = a + b;
13861407
let common_expr = nested_common_expr.clone() * nested_common_expr;
1408+
13871409
let plan = LogicalPlanBuilder::from(table_scan)
13881410
.project(vec![
13891411
common_expr.clone().alias("c1"),
@@ -1406,8 +1428,8 @@ mod test {
14061428
///
14071429
/// Does not use datafusion_functions::rand to avoid introducing a
14081430
/// dependency on that crate.
1409-
fn rand_func() -> ScalarUDF {
1410-
ScalarUDF::new_from_impl(RandomStub::new())
1431+
fn rand_expr() -> Expr {
1432+
ScalarUDF::new_from_impl(RandomStub::new()).call(vec![])
14111433
}
14121434

14131435
#[derive(Debug)]

0 commit comments

Comments
 (0)