Skip to content

Commit 2d27ff7

Browse files
committed
Implement PhysicalExpr CSE
1 parent f6c92fe commit 2d27ff7

26 files changed

+745
-70
lines changed

datafusion-cli/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/common/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ arrow-schema = { workspace = true }
5656
chrono = { workspace = true }
5757
half = { workspace = true }
5858
hashbrown = { workspace = true }
59-
indexmap = { workspace = true }
6059
libc = "0.2.140"
6160
num_cpus = { workspace = true }
6261
object_store = { workspace = true, optional = true }

datafusion/common/src/config.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@ use std::collections::{BTreeMap, HashMap};
2222
use std::fmt::{self, Display};
2323
use std::str::FromStr;
2424

25+
use crate::alias::AliasGenerator;
2526
use crate::error::_config_err;
2627
use crate::parsers::CompressionTypeVariant;
2728
use crate::{DataFusionError, Result};
29+
use std::sync::Arc;
2830

2931
/// A macro that wraps a configuration struct and automatically derives
3032
/// [`Default`] and [`ConfigField`] for it, allowing it to be used
@@ -693,6 +695,8 @@ pub struct ConfigOptions {
693695
pub explain: ExplainOptions,
694696
/// Optional extensions registered using [`Extensions::insert`]
695697
pub extensions: Extensions,
698+
/// Return alias generator used to generate unique aliases
699+
pub alias_generator: Arc<AliasGenerator>,
696700
}
697701

698702
impl ConfigField for ConfigOptions {

datafusion/common/src/cse.rs

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ use crate::tree_node::{
2525
TreeNodeVisitor,
2626
};
2727
use crate::Result;
28-
use indexmap::IndexMap;
2928
use std::collections::HashMap;
3029
use std::hash::{BuildHasher, Hash, Hasher, RandomState};
3130
use std::marker::PhantomData;
@@ -131,11 +130,13 @@ enum NodeEvaluation {
131130
}
132131

133132
/// A map that contains the evaluation stats of [`TreeNode`]s by their identifiers.
134-
type NodeStats<'n, N> = HashMap<Identifier<'n, N>, NodeEvaluation>;
133+
/// It also contains the position of [`TreeNode`]s in [`CommonNodes`] once a node is
134+
/// found to be common and got extracted.
135+
type NodeStats<'n, N> = HashMap<Identifier<'n, N>, (NodeEvaluation, Option<usize>)>;
135136

136-
/// A map that contains the common [`TreeNode`]s and their alias by their identifiers,
137-
/// extracted during the second, rewriting traversal.
138-
type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>;
137+
/// A list that contains the common [`TreeNode`]s and their alias, extracted during the
138+
/// second, rewriting traversal.
139+
type CommonNodes<'n, N> = Vec<(N, String)>;
139140

140141
type ChildrenList<N> = (Vec<N>, Vec<N>);
141142

@@ -163,7 +164,7 @@ pub trait CSEController {
163164
fn generate_alias(&self) -> String;
164165

165166
// Replaces a node to the generated alias.
166-
fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node;
167+
fn rewrite(&mut self, node: &Self::Node, alias: &str, index: usize) -> Self::Node;
167168

168169
// A helper method called on each node during top-down traversal during the second,
169170
// rewriting traversal of CSE.
@@ -341,7 +342,7 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
341342
self.id_array[down_index].1 = Some(node_id);
342343
self.node_stats
343344
.entry(node_id)
344-
.and_modify(|evaluation| {
345+
.and_modify(|(evaluation, _)| {
345346
if *evaluation == NodeEvaluation::SurelyOnce
346347
|| *evaluation == NodeEvaluation::ConditionallyAtLeastOnce
347348
&& !self.conditional
@@ -351,11 +352,12 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
351352
}
352353
})
353354
.or_insert_with(|| {
354-
if self.conditional {
355+
let evaluation = if self.conditional {
355356
NodeEvaluation::ConditionallyAtLeastOnce
356357
} else {
357358
NodeEvaluation::SurelyOnce
358-
}
359+
};
360+
(evaluation, None)
359361
});
360362
}
361363
self.visit_stack
@@ -371,7 +373,7 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
371373
/// replaced [`TreeNode`] tree.
372374
struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
373375
/// statistics of [`TreeNode`]s
374-
node_stats: &'a NodeStats<'n, N>,
376+
node_stats: &'a mut NodeStats<'n, N>,
375377

376378
/// cache to speed up second traversal
377379
id_array: &'a IdArray<'n, N>,
@@ -399,7 +401,7 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
399401

400402
// Handle nodes with identifiers only
401403
if let Some(node_id) = node_id {
402-
let evaluation = self.node_stats.get(&node_id).unwrap();
404+
let (evaluation, common_index) = self.node_stats.get_mut(&node_id).unwrap();
403405
if *evaluation == NodeEvaluation::Common {
404406
// step index to skip all sub-node (which has smaller series number).
405407
while self.down_index < self.id_array.len()
@@ -408,13 +410,15 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
408410
self.down_index += 1;
409411
}
410412

411-
let (node, alias) =
412-
self.common_nodes.entry(node_id).or_insert_with(|| {
413-
let node_alias = self.controller.generate_alias();
414-
(node, node_alias)
415-
});
413+
let index = *common_index.get_or_insert_with(|| {
414+
let index = self.common_nodes.len();
415+
let node_alias = self.controller.generate_alias();
416+
self.common_nodes.push((node, node_alias));
417+
index
418+
});
419+
let (node, alias) = self.common_nodes.get(index).unwrap();
416420

417-
let rewritten = self.controller.rewrite(node, alias);
421+
let rewritten = self.controller.rewrite(node, alias, index);
418422

419423
return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump));
420424
}
@@ -507,7 +511,7 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
507511
&mut self,
508512
node: N,
509513
id_array: &IdArray<'n, N>,
510-
node_stats: &NodeStats<'n, N>,
514+
node_stats: &mut NodeStats<'n, N>,
511515
common_nodes: &mut CommonNodes<'n, N>,
512516
) -> Result<N> {
513517
if id_array.is_empty() {
@@ -530,7 +534,7 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
530534
&mut self,
531535
nodes_list: Vec<Vec<N>>,
532536
arrays_list: &[Vec<IdArray<'n, N>>],
533-
node_stats: &NodeStats<'n, N>,
537+
node_stats: &mut NodeStats<'n, N>,
534538
common_nodes: &mut CommonNodes<'n, N>,
535539
) -> Result<Vec<Vec<N>>> {
536540
nodes_list
@@ -575,13 +579,13 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
575579
// nodes so we have to keep them intact.
576580
nodes_list.clone(),
577581
&id_arrays_list,
578-
&node_stats,
582+
&mut node_stats,
579583
&mut common_nodes,
580584
)?;
581585
assert!(!common_nodes.is_empty());
582586

583587
Ok(FoundCommonNodes::Yes {
584-
common_nodes: common_nodes.into_values().collect(),
588+
common_nodes,
585589
new_nodes_list,
586590
original_nodes_list: nodes_list,
587591
})
@@ -651,7 +655,12 @@ mod test {
651655
self.alias_generator.next(CSE_PREFIX)
652656
}
653657

654-
fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
658+
fn rewrite(
659+
&mut self,
660+
node: &Self::Node,
661+
alias: &str,
662+
_index: usize,
663+
) -> Self::Node {
655664
TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias))
656665
}
657666
}

datafusion/optimizer/src/common_subexpr_eliminate.rs

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -399,17 +399,15 @@ impl CommonSubexprEliminate {
399399
// Since `group_expr` may have changed, schema may also.
400400
// Use `try_new()` method.
401401
Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
402-
.map(LogicalPlan::Aggregate)
403-
.map(Transformed::no)
402+
.map(|p| Transformed::no(LogicalPlan::Aggregate(p)))
404403
} else {
405404
Aggregate::try_new_with_schema(
406405
new_input,
407406
new_group_expr,
408407
rewritten_aggr_expr,
409408
schema,
410409
)
411-
.map(LogicalPlan::Aggregate)
412-
.map(Transformed::no)
410+
.map(|p| Transformed::no(LogicalPlan::Aggregate(p)))
413411
}
414412
}
415413
}
@@ -628,9 +626,7 @@ impl CSEController for ExprCSEController<'_> {
628626

629627
fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
630628
match node {
631-
// In case of `ScalarFunction`s we don't know which children are surely
632-
// executed so start visiting all children conditionally and stop the
633-
// recursion with `TreeNodeRecursion::Jump`.
629+
// In case of `ScalarFunction`s all children can be conditionally executed.
634630
Expr::ScalarFunction(ScalarFunction { func, args })
635631
if func.short_circuits() =>
636632
{
@@ -700,7 +696,7 @@ impl CSEController for ExprCSEController<'_> {
700696
self.alias_generator.next(CSE_PREFIX)
701697
}
702698

703-
fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
699+
fn rewrite(&mut self, node: &Self::Node, alias: &str, _index: usize) -> Self::Node {
704700
// alias the expressions without an `Alias` ancestor node
705701
if self.alias_counter > 0 {
706702
col(alias)
@@ -1030,10 +1026,14 @@ mod test {
10301026
fn subexpr_in_same_order() -> Result<()> {
10311027
let table_scan = test_table_scan()?;
10321028

1029+
let a = col("a");
1030+
let lit_1 = lit(1);
1031+
let _1_plus_a = lit_1 + a;
1032+
10331033
let plan = LogicalPlanBuilder::from(table_scan)
10341034
.project(vec![
1035-
(lit(1) + col("a")).alias("first"),
1036-
(lit(1) + col("a")).alias("second"),
1035+
_1_plus_a.clone().alias("first"),
1036+
_1_plus_a.alias("second"),
10371037
])?
10381038
.build()?;
10391039

@@ -1050,8 +1050,13 @@ mod test {
10501050
fn subexpr_in_different_order() -> Result<()> {
10511051
let table_scan = test_table_scan()?;
10521052

1053+
let a = col("a");
1054+
let lit_1 = lit(1);
1055+
let _1_plus_a = lit_1.clone() + a.clone();
1056+
let a_plus_1 = a + lit_1;
1057+
10531058
let plan = LogicalPlanBuilder::from(table_scan)
1054-
.project(vec![lit(1) + col("a"), col("a") + lit(1)])?
1059+
.project(vec![_1_plus_a, a_plus_1])?
10551060
.build()?;
10561061

10571062
let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\
@@ -1066,6 +1071,8 @@ mod test {
10661071
fn cross_plans_subexpr() -> Result<()> {
10671072
let table_scan = test_table_scan()?;
10681073

1074+
let _1_plus_col_a = lit(1) + col("a");
1075+
10691076
let plan = LogicalPlanBuilder::from(table_scan)
10701077
.project(vec![lit(1) + col("a"), col("a")])?
10711078
.project(vec![lit(1) + col("a")])?
@@ -1318,9 +1325,12 @@ mod test {
13181325
fn test_volatile() -> Result<()> {
13191326
let table_scan = test_table_scan()?;
13201327

1321-
let extracted_child = col("a") + col("b");
1322-
let rand = rand_func().call(vec![]);
1328+
let a = col("a");
1329+
let b = col("b");
1330+
let extracted_child = a + b;
1331+
let rand = rand_expr();
13231332
let not_extracted_volatile = extracted_child + rand;
1333+
13241334
let plan = LogicalPlanBuilder::from(table_scan)
13251335
.project(vec![
13261336
not_extracted_volatile.clone().alias("c1"),
@@ -1341,13 +1351,19 @@ mod test {
13411351
fn test_volatile_short_circuits() -> Result<()> {
13421352
let table_scan = test_table_scan()?;
13431353

1344-
let rand = rand_func().call(vec![]);
1345-
let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
1354+
let a = col("a");
1355+
let b = col("b");
1356+
let rand = rand_expr();
1357+
let rand_eq_0 = rand.eq(lit(0));
1358+
1359+
let extracted_short_circuit_leg_1 = a.eq(lit(0));
13461360
let not_extracted_volatile_short_circuit_1 =
1347-
extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
1348-
let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
1361+
extracted_short_circuit_leg_1.or(rand_eq_0.clone());
1362+
1363+
let not_extracted_short_circuit_leg_2 = b.eq(lit(0));
13491364
let not_extracted_volatile_short_circuit_2 =
1350-
rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
1365+
rand_eq_0.or(not_extracted_short_circuit_leg_2);
1366+
13511367
let plan = LogicalPlanBuilder::from(table_scan)
13521368
.project(vec![
13531369
not_extracted_volatile_short_circuit_1.clone().alias("c1"),
@@ -1370,7 +1386,10 @@ mod test {
13701386
fn test_non_top_level_common_expression() -> Result<()> {
13711387
let table_scan = test_table_scan()?;
13721388

1373-
let common_expr = col("a") + col("b");
1389+
let a = col("a");
1390+
let b = col("b");
1391+
let common_expr = a + b;
1392+
13741393
let plan = LogicalPlanBuilder::from(table_scan)
13751394
.project(vec![
13761395
common_expr.clone().alias("c1"),
@@ -1393,8 +1412,11 @@ mod test {
13931412
fn test_nested_common_expression() -> Result<()> {
13941413
let table_scan = test_table_scan()?;
13951414

1396-
let nested_common_expr = col("a") + col("b");
1415+
let a = col("a");
1416+
let b = col("b");
1417+
let nested_common_expr = a + b;
13971418
let common_expr = nested_common_expr.clone() * nested_common_expr;
1419+
13981420
let plan = LogicalPlanBuilder::from(table_scan)
13991421
.project(vec![
14001422
common_expr.clone().alias("c1"),
@@ -1417,8 +1439,8 @@ mod test {
14171439
///
14181440
/// Does not use datafusion_functions::rand to avoid introducing a
14191441
/// dependency on that crate.
1420-
fn rand_func() -> ScalarUDF {
1421-
ScalarUDF::new_from_impl(RandomStub::new())
1442+
fn rand_expr() -> Expr {
1443+
ScalarUDF::new_from_impl(RandomStub::new()).call(vec![])
14221444
}
14231445

14241446
#[derive(Debug)]

0 commit comments

Comments
 (0)