Skip to content

Commit 606a6b0

Browse files
committed
POC
1 parent d67c0bb commit 606a6b0

40 files changed

+1101
-981
lines changed

datafusion-examples/examples/rewrite_expr.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule {
9191

9292
impl MyAnalyzerRule {
9393
fn analyze_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
94-
plan.transform(&|plan| {
94+
plan.transform_up(&|plan| {
9595
Ok(match plan {
9696
LogicalPlan::Filter(filter) => {
9797
let predicate = Self::analyze_expr(filter.predicate.clone())?;
@@ -106,7 +106,7 @@ impl MyAnalyzerRule {
106106
}
107107

108108
fn analyze_expr(expr: Expr) -> Result<Expr> {
109-
expr.transform(&|expr| {
109+
expr.transform_up(&|expr| {
110110
// closure is invoked for all sub expressions
111111
Ok(match expr {
112112
Expr::Literal(ScalarValue::Int64(i)) => {
@@ -161,7 +161,7 @@ impl OptimizerRule for MyOptimizerRule {
161161

162162
/// use rewrite_expr to modify the expression tree.
163163
fn my_rewrite(expr: Expr) -> Result<Expr> {
164-
expr.transform(&|expr| {
164+
expr.transform_up(&|expr| {
165165
// closure is invoked for all sub expressions
166166
Ok(match expr {
167167
Expr::Between(Between {

datafusion/common/src/tree_node.rs

Lines changed: 296 additions & 107 deletions
Large diffs are not rendered by default.

datafusion/core/src/datasource/listing/helpers.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ use crate::{error::Result, scalar::ScalarValue};
3737
use super::PartitionedFile;
3838
use crate::datasource::listing::ListingTableUrl;
3939
use crate::execution::context::SessionState;
40-
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
40+
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
4141
use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError};
4242
use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility};
4343
use datafusion_physical_expr::create_physical_expr;
@@ -52,17 +52,18 @@ use object_store::{ObjectMeta, ObjectStore};
5252
/// was performed
5353
pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
5454
let mut is_applicable = true;
55-
expr.apply(&mut |expr| {
55+
expr.visit_down(&mut |expr| {
5656
match expr {
5757
Expr::Column(Column { ref name, .. }) => {
5858
is_applicable &= col_names.contains(name);
5959
if is_applicable {
60-
Ok(VisitRecursion::Skip)
60+
Ok(TreeNodeRecursion::Prune)
6161
} else {
62-
Ok(VisitRecursion::Stop)
62+
Ok(TreeNodeRecursion::Stop)
6363
}
6464
}
65-
Expr::Literal(_)
65+
Expr::Nop
66+
| Expr::Literal(_)
6667
| Expr::Alias(_)
6768
| Expr::OuterReferenceColumn(_, _)
6869
| Expr::ScalarVariable(_, _)
@@ -88,27 +89,27 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
8889
| Expr::ScalarSubquery(_)
8990
| Expr::GetIndexedField { .. }
9091
| Expr::GroupingSet(_)
91-
| Expr::Case { .. } => Ok(VisitRecursion::Continue),
92+
| Expr::Case { .. } => Ok(TreeNodeRecursion::Continue),
9293

9394
Expr::ScalarFunction(scalar_function) => {
9495
match &scalar_function.func_def {
9596
ScalarFunctionDefinition::BuiltIn(fun) => {
9697
match fun.volatility() {
97-
Volatility::Immutable => Ok(VisitRecursion::Continue),
98+
Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
9899
// TODO: Stable functions could be `applicable`, but that would require access to the context
99100
Volatility::Stable | Volatility::Volatile => {
100101
is_applicable = false;
101-
Ok(VisitRecursion::Stop)
102+
Ok(TreeNodeRecursion::Stop)
102103
}
103104
}
104105
}
105106
ScalarFunctionDefinition::UDF(fun) => {
106107
match fun.signature().volatility {
107-
Volatility::Immutable => Ok(VisitRecursion::Continue),
108+
Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
108109
// TODO: Stable functions could be `applicable`, but that would require access to the context
109110
Volatility::Stable | Volatility::Volatile => {
110111
is_applicable = false;
111-
Ok(VisitRecursion::Stop)
112+
Ok(TreeNodeRecursion::Stop)
112113
}
113114
}
114115
}
@@ -128,7 +129,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
128129
| Expr::Wildcard { .. }
129130
| Expr::Placeholder(_) => {
130131
is_applicable = false;
131-
Ok(VisitRecursion::Stop)
132+
Ok(TreeNodeRecursion::Stop)
132133
}
133134
}
134135
})

datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use arrow::{array::ArrayRef, datatypes::Schema};
1919
use arrow_schema::FieldRef;
20-
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
20+
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
2121
use datafusion_common::{Column, DataFusionError, Result, ScalarValue};
2222
use parquet::file::metadata::ColumnChunkMetaData;
2323
use parquet::schema::types::SchemaDescriptor;
@@ -259,7 +259,7 @@ impl BloomFilterPruningPredicate {
259259

260260
fn get_predicate_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<String> {
261261
let mut columns = HashSet::new();
262-
expr.apply(&mut |expr| {
262+
expr.visit_down(&mut |expr| {
263263
if let Some(binary_expr) =
264264
expr.as_any().downcast_ref::<phys_expr::BinaryExpr>()
265265
{
@@ -269,7 +269,7 @@ impl BloomFilterPruningPredicate {
269269
columns.insert(column.name().to_string());
270270
}
271271
}
272-
Ok(VisitRecursion::Continue)
272+
Ok(TreeNodeRecursion::Continue)
273273
})
274274
// no way to fail as only Ok(VisitRecursion::Continue) is returned
275275
.unwrap();

datafusion/core/src/execution/context/mod.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ use crate::{
3838
use datafusion_common::{
3939
alias::AliasGenerator,
4040
exec_err, not_impl_err, plan_datafusion_err, plan_err,
41-
tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion},
41+
tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor},
4242
};
4343
use datafusion_execution::registry::SerializerRegistry;
4444
use datafusion_expr::{
@@ -2098,9 +2098,9 @@ impl<'a> BadPlanVisitor<'a> {
20982098
}
20992099

21002100
impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> {
2101-
type N = LogicalPlan;
2101+
type Node = LogicalPlan;
21022102

2103-
fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion> {
2103+
fn pre_visit(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
21042104
match node {
21052105
LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => {
21062106
plan_err!("DDL not supported: {}", ddl.name())
@@ -2114,9 +2114,13 @@ impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> {
21142114
LogicalPlan::Statement(stmt) if !self.options.allow_statements => {
21152115
plan_err!("Statement not supported: {}", stmt.name())
21162116
}
2117-
_ => Ok(VisitRecursion::Continue),
2117+
_ => Ok(TreeNodeRecursion::Continue),
21182118
}
21192119
}
2120+
2121+
fn post_visit(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
2122+
Ok(TreeNodeRecursion::Continue)
2123+
}
21202124
}
21212125

21222126
#[cfg(test)]

datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs {
178178
fn discard_column_index(group_expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
179179
group_expr
180180
.clone()
181-
.transform(&|expr| {
181+
.transform_up(&|expr| {
182182
let normalized_form: Option<Arc<dyn PhysicalExpr>> =
183183
match expr.as_any().downcast_ref::<Column>() {
184184
Some(column) => Some(Arc::new(Column::new(column.name(), 0))),

datafusion/core/src/physical_optimizer/enforce_distribution.rs

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ use crate::physical_plan::{
4747
};
4848

4949
use arrow::compute::SortOptions;
50-
use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
50+
use datafusion_common::tree_node::{
51+
Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator,
52+
};
5153
use datafusion_expr::logical_plan::JoinType;
5254
use datafusion_physical_expr::expressions::{Column, NoOp};
5355
use datafusion_physical_expr::utils::map_columns_before_projection;
@@ -1476,18 +1478,11 @@ impl DistributionContext {
14761478
}
14771479

14781480
impl TreeNode for DistributionContext {
1479-
fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
1481+
fn apply_children<F>(&self, f: &mut F) -> Result<TreeNodeRecursion>
14801482
where
1481-
F: FnMut(&Self) -> Result<VisitRecursion>,
1483+
F: FnMut(&Self) -> Result<TreeNodeRecursion>,
14821484
{
1483-
for child in self.children() {
1484-
match op(&child)? {
1485-
VisitRecursion::Continue => {}
1486-
VisitRecursion::Skip => return Ok(VisitRecursion::Continue),
1487-
VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
1488-
}
1489-
}
1490-
Ok(VisitRecursion::Continue)
1485+
self.children().iter().for_each_till_continue(f)
14911486
}
14921487

14931488
fn map_children<F>(self, transform: F) -> Result<Self>
@@ -1505,6 +1500,23 @@ impl TreeNode for DistributionContext {
15051500
DistributionContext::new_from_children_nodes(children_nodes, self.plan)
15061501
}
15071502
}
1503+
1504+
fn transform_children<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
1505+
where
1506+
F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
1507+
{
1508+
let mut children = self.children();
1509+
if children.is_empty() {
1510+
Ok(TreeNodeRecursion::Continue)
1511+
} else {
1512+
let tnr = children.iter_mut().for_each_till_continue(f)?;
1513+
*self = DistributionContext::new_from_children_nodes(
1514+
children,
1515+
self.plan.clone(),
1516+
)?;
1517+
Ok(tnr)
1518+
}
1519+
}
15081520
}
15091521

15101522
/// implement Display method for `DistributionContext` struct.
@@ -1566,20 +1578,11 @@ impl PlanWithKeyRequirements {
15661578
}
15671579

15681580
impl TreeNode for PlanWithKeyRequirements {
1569-
fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
1581+
fn apply_children<F>(&self, f: &mut F) -> Result<TreeNodeRecursion>
15701582
where
1571-
F: FnMut(&Self) -> Result<VisitRecursion>,
1583+
F: FnMut(&Self) -> Result<TreeNodeRecursion>,
15721584
{
1573-
let children = self.children();
1574-
for child in children {
1575-
match op(&child)? {
1576-
VisitRecursion::Continue => {}
1577-
VisitRecursion::Skip => return Ok(VisitRecursion::Continue),
1578-
VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
1579-
}
1580-
}
1581-
1582-
Ok(VisitRecursion::Continue)
1585+
self.children().iter().for_each_till_continue(f)
15831586
}
15841587

15851588
fn map_children<F>(self, transform: F) -> Result<Self>
@@ -1605,6 +1608,22 @@ impl TreeNode for PlanWithKeyRequirements {
16051608
Ok(self)
16061609
}
16071610
}
1611+
1612+
fn transform_children<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
1613+
where
1614+
F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
1615+
{
1616+
let mut children = self.children();
1617+
if !children.is_empty() {
1618+
let tnr = children.iter_mut().for_each_till_continue(f)?;
1619+
let children_plans = children.into_iter().map(|c| c.plan).collect();
1620+
self.plan =
1621+
with_new_children_if_necessary(self.plan.clone(), children_plans)?.into();
1622+
Ok(tnr)
1623+
} else {
1624+
Ok(TreeNodeRecursion::Continue)
1625+
}
1626+
}
16081627
}
16091628

16101629
/// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on

datafusion/core/src/physical_optimizer/enforce_sorting.rs

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ use crate::physical_plan::{
5757
with_new_children_if_necessary, Distribution, ExecutionPlan, InputOrderMode,
5858
};
5959

60-
use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
60+
use datafusion_common::tree_node::{
61+
Transformed, TreeNode, TreeNodeRecursion, VisitRecursionIterator,
62+
};
6163
use datafusion_common::{plan_err, DataFusionError};
6264
use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement};
6365

@@ -157,20 +159,11 @@ impl PlanWithCorrespondingSort {
157159
}
158160

159161
impl TreeNode for PlanWithCorrespondingSort {
160-
fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
162+
fn apply_children<F>(&self, f: &mut F) -> Result<TreeNodeRecursion>
161163
where
162-
F: FnMut(&Self) -> Result<VisitRecursion>,
164+
F: FnMut(&Self) -> Result<TreeNodeRecursion>,
163165
{
164-
let children = self.children();
165-
for child in children {
166-
match op(&child)? {
167-
VisitRecursion::Continue => {}
168-
VisitRecursion::Skip => return Ok(VisitRecursion::Continue),
169-
VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
170-
}
171-
}
172-
173-
Ok(VisitRecursion::Continue)
166+
self.children().iter().for_each_till_continue(f)
174167
}
175168

176169
fn map_children<F>(self, transform: F) -> Result<Self>
@@ -188,6 +181,23 @@ impl TreeNode for PlanWithCorrespondingSort {
188181
PlanWithCorrespondingSort::new_from_children_nodes(children_nodes, self.plan)
189182
}
190183
}
184+
185+
fn transform_children<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
186+
where
187+
F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
188+
{
189+
let mut children = self.children();
190+
if children.is_empty() {
191+
Ok(TreeNodeRecursion::Continue)
192+
} else {
193+
let tnr = children.iter_mut().for_each_till_continue(f)?;
194+
*self = PlanWithCorrespondingSort::new_from_children_nodes(
195+
children,
196+
self.plan.clone(),
197+
)?;
198+
Ok(tnr)
199+
}
200+
}
191201
}
192202

193203
/// This object is used within the [EnforceSorting] rule to track the closest
@@ -273,20 +283,11 @@ impl PlanWithCorrespondingCoalescePartitions {
273283
}
274284

275285
impl TreeNode for PlanWithCorrespondingCoalescePartitions {
276-
fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
286+
fn apply_children<F>(&self, f: &mut F) -> Result<TreeNodeRecursion>
277287
where
278-
F: FnMut(&Self) -> Result<VisitRecursion>,
288+
F: FnMut(&Self) -> Result<TreeNodeRecursion>,
279289
{
280-
let children = self.children();
281-
for child in children {
282-
match op(&child)? {
283-
VisitRecursion::Continue => {}
284-
VisitRecursion::Skip => return Ok(VisitRecursion::Continue),
285-
VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
286-
}
287-
}
288-
289-
Ok(VisitRecursion::Continue)
290+
self.children().iter().for_each_till_continue(f)
290291
}
291292

292293
fn map_children<F>(self, transform: F) -> Result<Self>
@@ -307,6 +308,23 @@ impl TreeNode for PlanWithCorrespondingCoalescePartitions {
307308
)
308309
}
309310
}
311+
312+
fn transform_children<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
313+
where
314+
F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
315+
{
316+
let mut children = self.children();
317+
if children.is_empty() {
318+
Ok(TreeNodeRecursion::Continue)
319+
} else {
320+
let tnr = children.iter_mut().for_each_till_continue(f)?;
321+
*self = PlanWithCorrespondingCoalescePartitions::new_from_children_nodes(
322+
children,
323+
self.plan.clone(),
324+
)?;
325+
Ok(tnr)
326+
}
327+
}
310328
}
311329

312330
/// The boolean flag `repartition_sorts` defined in the config indicates

0 commit comments

Comments
 (0)