Skip to content

Commit eee2ee5

Browse files
committed
Implement PhysicalExpr CSE
1 parent 9c12919 commit eee2ee5

25 files changed

+864
-68
lines changed

Cargo.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/common/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ arrow-schema = { workspace = true }
5656
base64 = "0.22.1"
5757
half = { workspace = true }
5858
hashbrown = { workspace = true }
59-
indexmap = { workspace = true }
6059
libc = "0.2.140"
6160
log = { workspace = true }
6261
object_store = { workspace = true, optional = true }

datafusion/common/src/config.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ use std::error::Error;
2323
use std::fmt::{self, Display};
2424
use std::str::FromStr;
2525

26+
use crate::alias::AliasGenerator;
2627
use crate::error::_config_err;
2728
use crate::parsers::CompressionTypeVariant;
2829
use crate::utils::get_available_parallelism;
2930
use crate::{DataFusionError, Result};
31+
use std::sync::Arc;
3032

3133
/// A macro that wraps a configuration struct and automatically derives
3234
/// [`Default`] and [`ConfigField`] for it, allowing it to be used
@@ -736,6 +738,8 @@ pub struct ConfigOptions {
736738
pub explain: ExplainOptions,
737739
/// Optional extensions registered using [`Extensions::insert`]
738740
pub extensions: Extensions,
741+
/// Alias generator used to generate unique aliases
742+
pub alias_generator: Arc<AliasGenerator>,
739743
}
740744

741745
impl ConfigField for ConfigOptions {

datafusion/common/src/cse.rs

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ use crate::tree_node::{
2525
TreeNodeVisitor,
2626
};
2727
use crate::Result;
28-
use indexmap::IndexMap;
2928
use std::collections::HashMap;
3029
use std::hash::{BuildHasher, Hash, Hasher, RandomState};
3130
use std::marker::PhantomData;
@@ -59,6 +58,12 @@ pub trait Normalizeable {
5958
fn can_normalize(&self) -> bool;
6059
}
6160

61+
impl<T: Normalizeable + ?Sized> Normalizeable for Arc<T> {
62+
fn can_normalize(&self) -> bool {
63+
(**self).can_normalize()
64+
}
65+
}
66+
6267
/// The `NormalizeEq` trait extends `Eq` and `Normalizeable` to provide a method for comparing
6368
/// normalized nodes in optimizations like Common Subexpression Elimination (CSE).
6469
///
@@ -71,6 +76,12 @@ pub trait NormalizeEq: Eq + Normalizeable {
7176
fn normalize_eq(&self, other: &Self) -> bool;
7277
}
7378

79+
impl<T: NormalizeEq + ?Sized> NormalizeEq for Arc<T> {
80+
fn normalize_eq(&self, other: &Self) -> bool {
81+
(**self).normalize_eq(other)
82+
}
83+
}
84+
7485
/// Identifier that represents a [`TreeNode`] tree.
7586
///
7687
/// This identifier is designed to be efficient and "hash", "accumulate", "equal" and
@@ -161,11 +172,13 @@ enum NodeEvaluation {
161172
}
162173

163174
/// A map that contains the evaluation stats of [`TreeNode`]s by their identifiers.
164-
type NodeStats<'n, N> = HashMap<Identifier<'n, N>, NodeEvaluation>;
175+
/// It also contains the position of [`TreeNode`]s in [`CommonNodes`] once a node is
176+
/// found to be common and has been extracted.
177+
type NodeStats<'n, N> = HashMap<Identifier<'n, N>, (NodeEvaluation, Option<usize>)>;
165178

166-
/// A map that contains the common [`TreeNode`]s and their alias by their identifiers,
167-
/// extracted during the second, rewriting traversal.
168-
type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>;
179+
/// A list that contains the common [`TreeNode`]s and their alias, extracted during the
180+
/// second, rewriting traversal.
181+
type CommonNodes<'n, N> = Vec<(N, String)>;
169182

170183
type ChildrenList<N> = (Vec<N>, Vec<N>);
171184

@@ -193,7 +206,7 @@ pub trait CSEController {
193206
fn generate_alias(&self) -> String;
194207

195208
// Replaces a node with the generated alias.
196-
fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node;
209+
fn rewrite(&mut self, node: &Self::Node, alias: &str, index: usize) -> Self::Node;
197210

198211
// A helper method called on each node during top-down traversal during the second,
199212
// rewriting traversal of CSE.
@@ -394,7 +407,7 @@ where
394407
self.id_array[down_index].1 = Some(node_id);
395408
self.node_stats
396409
.entry(node_id)
397-
.and_modify(|evaluation| {
410+
.and_modify(|(evaluation, _)| {
398411
if *evaluation == NodeEvaluation::SurelyOnce
399412
|| *evaluation == NodeEvaluation::ConditionallyAtLeastOnce
400413
&& !self.conditional
@@ -404,11 +417,12 @@ where
404417
}
405418
})
406419
.or_insert_with(|| {
407-
if self.conditional {
420+
let evaluation = if self.conditional {
408421
NodeEvaluation::ConditionallyAtLeastOnce
409422
} else {
410423
NodeEvaluation::SurelyOnce
411-
}
424+
};
425+
(evaluation, None)
412426
});
413427
}
414428
self.visit_stack
@@ -428,7 +442,7 @@ where
428442
C: CSEController<Node = N>,
429443
{
430444
/// statistics of [`TreeNode`]s
431-
node_stats: &'a NodeStats<'n, N>,
445+
node_stats: &'a mut NodeStats<'n, N>,
432446

433447
/// cache to speed up second traversal
434448
id_array: &'a IdArray<'n, N>,
@@ -458,7 +472,7 @@ where
458472

459473
// Handle nodes with identifiers only
460474
if let Some(node_id) = node_id {
461-
let evaluation = self.node_stats.get(&node_id).unwrap();
475+
let (evaluation, common_index) = self.node_stats.get_mut(&node_id).unwrap();
462476
if *evaluation == NodeEvaluation::Common {
463477
// step index to skip all sub-nodes (which have smaller series numbers).
464478
while self.down_index < self.id_array.len()
@@ -482,13 +496,15 @@ where
482496
//
483497
// This way, we can efficiently handle semantically equivalent expressions without
484498
// incorrectly treating them as identical.
485-
let rewritten = if let Some((_, alias)) = self.common_nodes.get(&node_id)
486-
{
487-
self.controller.rewrite(&node, alias)
499+
let rewritten = if let Some(index) = common_index {
500+
let (_, alias) = self.common_nodes.get(*index).unwrap();
501+
self.controller.rewrite(&node, alias, *index)
488502
} else {
489-
let node_alias = self.controller.generate_alias();
490-
let rewritten = self.controller.rewrite(&node, &node_alias);
491-
self.common_nodes.insert(node_id, (node, node_alias));
503+
let index = self.common_nodes.len();
504+
let alias = self.controller.generate_alias();
505+
let rewritten = self.controller.rewrite(&node, &alias, index);
506+
*common_index = Some(index);
507+
self.common_nodes.push((node, alias));
492508
rewritten
493509
};
494510

@@ -587,7 +603,7 @@ where
587603
&mut self,
588604
node: N,
589605
id_array: &IdArray<'n, N>,
590-
node_stats: &NodeStats<'n, N>,
606+
node_stats: &mut NodeStats<'n, N>,
591607
common_nodes: &mut CommonNodes<'n, N>,
592608
) -> Result<N> {
593609
if id_array.is_empty() {
@@ -610,7 +626,7 @@ where
610626
&mut self,
611627
nodes_list: Vec<Vec<N>>,
612628
arrays_list: &[Vec<IdArray<'n, N>>],
613-
node_stats: &NodeStats<'n, N>,
629+
node_stats: &mut NodeStats<'n, N>,
614630
common_nodes: &mut CommonNodes<'n, N>,
615631
) -> Result<Vec<Vec<N>>> {
616632
nodes_list
@@ -656,13 +672,13 @@ where
656672
// nodes so we have to keep them intact.
657673
nodes_list.clone(),
658674
&id_arrays_list,
659-
&node_stats,
675+
&mut node_stats,
660676
&mut common_nodes,
661677
)?;
662678
assert!(!common_nodes.is_empty());
663679

664680
Ok(FoundCommonNodes::Yes {
665-
common_nodes: common_nodes.into_values().collect(),
681+
common_nodes,
666682
new_nodes_list,
667683
original_nodes_list: nodes_list,
668684
})
@@ -735,7 +751,12 @@ mod test {
735751
self.alias_generator.next(CSE_PREFIX)
736752
}
737753

738-
fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
754+
fn rewrite(
755+
&mut self,
756+
node: &Self::Node,
757+
alias: &str,
758+
_index: usize,
759+
) -> Self::Node {
739760
TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias))
740761
}
741762
}

datafusion/optimizer/src/common_subexpr_eliminate.rs

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ impl CSEController for ExprCSEController<'_> {
699699
self.alias_generator.next(CSE_PREFIX)
700700
}
701701

702-
fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
702+
fn rewrite(&mut self, node: &Self::Node, alias: &str, _index: usize) -> Self::Node {
703703
// alias the expressions without an `Alias` ancestor node
704704
if self.alias_counter > 0 {
705705
col(alias)
@@ -1030,10 +1030,14 @@ mod test {
10301030
fn subexpr_in_same_order() -> Result<()> {
10311031
let table_scan = test_table_scan()?;
10321032

1033+
let a = col("a");
1034+
let lit_1 = lit(1);
1035+
let _1_plus_a = lit_1 + a;
1036+
10331037
let plan = LogicalPlanBuilder::from(table_scan)
10341038
.project(vec![
1035-
(lit(1) + col("a")).alias("first"),
1036-
(lit(1) + col("a")).alias("second"),
1039+
_1_plus_a.clone().alias("first"),
1040+
_1_plus_a.alias("second"),
10371041
])?
10381042
.build()?;
10391043

@@ -1050,8 +1054,13 @@ mod test {
10501054
fn subexpr_in_different_order() -> Result<()> {
10511055
let table_scan = test_table_scan()?;
10521056

1057+
let a = col("a");
1058+
let lit_1 = lit(1);
1059+
let _1_plus_a = lit_1.clone() + a.clone();
1060+
let a_plus_1 = a + lit_1;
1061+
10531062
let plan = LogicalPlanBuilder::from(table_scan)
1054-
.project(vec![lit(1) + col("a"), col("a") + lit(1)])?
1063+
.project(vec![_1_plus_a, a_plus_1])?
10551064
.build()?;
10561065

10571066
let expected = "Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)\
@@ -1067,6 +1076,8 @@ mod test {
10671076
fn cross_plans_subexpr() -> Result<()> {
10681077
let table_scan = test_table_scan()?;
10691078

1079+
let _1_plus_col_a = lit(1) + col("a");
1080+
10701081
let plan = LogicalPlanBuilder::from(table_scan)
10711082
.project(vec![lit(1) + col("a"), col("a")])?
10721083
.project(vec![lit(1) + col("a")])?
@@ -1284,10 +1295,13 @@ mod test {
12841295
fn test_short_circuits() -> Result<()> {
12851296
let table_scan = test_table_scan()?;
12861297

1287-
let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
1288-
let extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0));
1289-
let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0));
1290-
let extracted_short_circuit_leg_3 = (col("a") * col("b")).eq(lit(0));
1298+
let a = col("a");
1299+
let b = col("b");
1300+
1301+
let extracted_short_circuit = a.clone().eq(lit(0)).or(b.clone().eq(lit(0)));
1302+
let extracted_short_circuit_leg_1 = (a.clone() + b.clone()).eq(lit(0));
1303+
let not_extracted_short_circuit_leg_2 = (a.clone() - b.clone()).eq(lit(0));
1304+
let extracted_short_circuit_leg_3 = (a * b).eq(lit(0));
12911305
let plan = LogicalPlanBuilder::from(table_scan)
12921306
.project(vec![
12931307
extracted_short_circuit.clone().alias("c1"),
@@ -1319,9 +1333,12 @@ mod test {
13191333
fn test_volatile() -> Result<()> {
13201334
let table_scan = test_table_scan()?;
13211335

1322-
let extracted_child = col("a") + col("b");
1323-
let rand = rand_func().call(vec![]);
1336+
let a = col("a");
1337+
let b = col("b");
1338+
let extracted_child = a + b;
1339+
let rand = rand_expr();
13241340
let not_extracted_volatile = extracted_child + rand;
1341+
13251342
let plan = LogicalPlanBuilder::from(table_scan)
13261343
.project(vec![
13271344
not_extracted_volatile.clone().alias("c1"),
@@ -1342,13 +1359,19 @@ mod test {
13421359
fn test_volatile_short_circuits() -> Result<()> {
13431360
let table_scan = test_table_scan()?;
13441361

1345-
let rand = rand_func().call(vec![]);
1346-
let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
1362+
let a = col("a");
1363+
let b = col("b");
1364+
let rand = rand_expr();
1365+
let rand_eq_0 = rand.eq(lit(0));
1366+
1367+
let extracted_short_circuit_leg_1 = a.eq(lit(0));
13471368
let not_extracted_volatile_short_circuit_1 =
1348-
extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
1349-
let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
1369+
extracted_short_circuit_leg_1.or(rand_eq_0.clone());
1370+
1371+
let not_extracted_short_circuit_leg_2 = b.eq(lit(0));
13501372
let not_extracted_volatile_short_circuit_2 =
1351-
rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
1373+
rand_eq_0.or(not_extracted_short_circuit_leg_2);
1374+
13521375
let plan = LogicalPlanBuilder::from(table_scan)
13531376
.project(vec![
13541377
not_extracted_volatile_short_circuit_1.clone().alias("c1"),
@@ -1371,7 +1394,10 @@ mod test {
13711394
fn test_non_top_level_common_expression() -> Result<()> {
13721395
let table_scan = test_table_scan()?;
13731396

1374-
let common_expr = col("a") + col("b");
1397+
let a = col("a");
1398+
let b = col("b");
1399+
let common_expr = a + b;
1400+
13751401
let plan = LogicalPlanBuilder::from(table_scan)
13761402
.project(vec![
13771403
common_expr.clone().alias("c1"),
@@ -1394,8 +1420,11 @@ mod test {
13941420
fn test_nested_common_expression() -> Result<()> {
13951421
let table_scan = test_table_scan()?;
13961422

1397-
let nested_common_expr = col("a") + col("b");
1423+
let a = col("a");
1424+
let b = col("b");
1425+
let nested_common_expr = a + b;
13981426
let common_expr = nested_common_expr.clone() * nested_common_expr;
1427+
13991428
let plan = LogicalPlanBuilder::from(table_scan)
14001429
.project(vec![
14011430
common_expr.clone().alias("c1"),
@@ -1671,8 +1700,8 @@ mod test {
16711700
///
16721701
/// Does not use datafusion_functions::rand to avoid introducing a
16731702
/// dependency on that crate.
1674-
fn rand_func() -> ScalarUDF {
1675-
ScalarUDF::new_from_impl(RandomStub::new())
1703+
fn rand_expr() -> Expr {
1704+
ScalarUDF::new_from_impl(RandomStub::new()).call(vec![])
16761705
}
16771706

16781707
#[derive(Debug)]

0 commit comments

Comments
 (0)