Skip to content

Commit cca3135

Browse files
committed
- refactor transform_down() and transform_up() to work on mutable TreeNodes and use them in a few examples
- add `transform_down_with_payload()`, `transform_up_with_payload()`, `transform_with_payload()` and use it in `EnforceSorting` as an example
1 parent c0990de commit cca3135

30 files changed

+370
-391
lines changed

datafusion-examples/examples/rewrite_expr.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule {
9191

9292
impl MyAnalyzerRule {
9393
fn analyze_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
94-
plan.transform_up(&|plan| {
94+
plan.transform_up_old(&|plan| {
9595
Ok(match plan {
9696
LogicalPlan::Filter(filter) => {
9797
let predicate = Self::analyze_expr(filter.predicate.clone())?;
@@ -106,7 +106,7 @@ impl MyAnalyzerRule {
106106
}
107107

108108
fn analyze_expr(expr: Expr) -> Result<Expr> {
109-
expr.transform_up(&|expr| {
109+
expr.transform_up_old(&|expr| {
110110
// closure is invoked for all sub expressions
111111
Ok(match expr {
112112
Expr::Literal(ScalarValue::Int64(i)) => {
@@ -161,7 +161,7 @@ impl OptimizerRule for MyOptimizerRule {
161161

162162
/// use rewrite_expr to modify the expression tree.
163163
fn my_rewrite(expr: Expr) -> Result<Expr> {
164-
expr.transform_up(&|expr| {
164+
expr.transform_up_old(&|expr| {
165165
// closure is invoked for all sub expressions
166166
Ok(match expr {
167167
Expr::Between(Between {

datafusion/common/src/tree_node.rs

Lines changed: 146 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,19 @@ pub trait TreeNode: Sized {
5757
F: FnMut(&Self) -> Result<TreeNodeRecursion>,
5858
{
5959
// Apply `f` on self.
60-
f(self)
60+
f(self)?
6161
// If it returns continue (not prune or stop or stop all) then continue
6262
// traversal on inner children and children.
6363
.and_then_on_continue(|| {
6464
// Run the recursive `visit_down` on each inner child, but as they are
6565
// unrelated root nodes of inner trees if any returns stop then continue
6666
// with the next one.
67-
self.apply_inner_children(&mut |c| c.visit_down(f).continue_on_stop())
67+
self.apply_inner_children(&mut |c| c.visit_down(f)?.continue_on_stop())?
6868
// Run the recursive `apply` on each children.
6969
.and_then_on_continue(|| {
7070
self.apply_children(&mut |c| c.visit_down(f))
7171
})
72-
})
72+
})?
7373
// Applying `f` on self might have returned prune, but we need to propagate
7474
// continue.
7575
.continue_on_prune()
@@ -107,21 +107,21 @@ pub trait TreeNode: Sized {
107107
) -> Result<TreeNodeRecursion> {
108108
// Apply `pre_visit` on self.
109109
visitor
110-
.pre_visit(self)
110+
.pre_visit(self)?
111111
// If it returns continue (not prune or stop or stop all) then continue
112112
// traversal on inner children and children.
113113
.and_then_on_continue(|| {
114114
// Run the recursive `visit` on each inner child, but as they are
115115
// unrelated subquery plans if any returns stop then continue with the
116116
// next one.
117-
self.apply_inner_children(&mut |c| c.visit(visitor).continue_on_stop())
117+
self.apply_inner_children(&mut |c| c.visit(visitor)?.continue_on_stop())?
118118
// Run the recursive `visit` on each children.
119119
.and_then_on_continue(|| {
120120
self.apply_children(&mut |c| c.visit(visitor))
121-
})
121+
})?
122122
// Apply `post_visit` on self.
123123
.and_then_on_continue(|| visitor.post_visit(self))
124-
})
124+
})?
125125
// Applying `pre_visit` or `post_visit` on self might have returned prune,
126126
// but we need to propagate continue.
127127
.continue_on_prune()
@@ -133,31 +133,144 @@ pub trait TreeNode: Sized {
133133
) -> Result<TreeNodeRecursion> {
134134
// Apply `pre_transform` on self.
135135
transformer
136-
.pre_transform(self)
136+
.pre_transform(self)?
137137
// If it returns continue (not prune or stop or stop all) then continue
138138
// traversal on inner children and children.
139139
.and_then_on_continue(||
140140
// Run the recursive `transform` on each child.
141141
self
142-
.transform_children(&mut |c| c.transform(transformer))
142+
.transform_children(&mut |c| c.transform(transformer))?
143143
// Apply `post_transform` on new self.
144-
.and_then_on_continue(|| {
145-
transformer.post_transform(self)
146-
}))
144+
.and_then_on_continue(|| transformer.post_transform(self)))?
147145
// Applying `pre_transform` or `post_transform` on self might have returned
148146
// prune, but we need to propagate continue.
149147
.continue_on_prune()
150148
}
151149

150+
fn transform_with_payload<FD, PD, FU, PU>(
151+
&mut self,
152+
f_down: &mut FD,
153+
payload_down: Option<PD>,
154+
f_up: &mut FU,
155+
) -> Result<(TreeNodeRecursion, Option<PU>)>
156+
where
157+
FD: FnMut(&mut Self, Option<PD>) -> Result<(TreeNodeRecursion, Vec<PD>)>,
158+
FU: FnMut(&mut Self, Vec<PU>) -> Result<(TreeNodeRecursion, PU)>,
159+
{
160+
// Apply `f_down` on self.
161+
let (tnr, new_payload_down) = f_down(self, payload_down)?;
162+
let mut new_payload_down_iter = new_payload_down.into_iter();
163+
// If it returns continue (not prune or stop or stop all) then continue traversal
164+
// on inner children and children.
165+
let mut new_payload_up = None;
166+
tnr.and_then_on_continue(|| {
167+
// Run the recursive `transform_with_payload` on each child.
168+
let mut payload_up = vec![];
169+
let tnr = self.transform_children(&mut |c| {
170+
let (tnr, p) =
171+
c.transform_with_payload(f_down, new_payload_down_iter.next(), f_up)?;
172+
p.into_iter().for_each(|p| payload_up.push(p));
173+
Ok(tnr)
174+
})?;
175+
// Apply `f_up` on self.
176+
tnr.and_then_on_continue(|| {
177+
let (tnr, np) = f_up(self, payload_up)?;
178+
new_payload_up = Some(np);
179+
Ok(tnr)
180+
})
181+
})?
182+
// Applying `f_down` or `f_up` on self might have returned prune, but we need to propagate
183+
// continue.
184+
.continue_on_prune()
185+
.map(|tnr| (tnr, new_payload_up))
186+
}
187+
188+
fn transform_down<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
189+
where
190+
F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
191+
{
192+
// Apply `f` on self.
193+
f(self)?
194+
// If it returns continue (not prune or stop or stop all) then continue
195+
// traversal on inner children and children.
196+
.and_then_on_continue(||
197+
// Run the recursive `transform_down` on each child.
198+
self.transform_children(&mut |c| c.transform_down(f)))?
199+
// Applying `f` on self might have returned prune, but we need to propagate
200+
// continue.
201+
.continue_on_prune()
202+
}
203+
204+
fn transform_down_with_payload<F, P>(
205+
&mut self,
206+
f: &mut F,
207+
payload: P,
208+
) -> Result<TreeNodeRecursion>
209+
where
210+
F: FnMut(&mut Self, P) -> Result<(TreeNodeRecursion, Vec<P>)>,
211+
{
212+
// Apply `f` on self.
213+
let (tnr, new_payload) = f(self, payload)?;
214+
let mut new_payload_iter = new_payload.into_iter();
215+
// If it returns continue (not prune or stop or stop all) then continue
216+
// traversal on inner children and children.
217+
tnr.and_then_on_continue(||
218+
// Run the recursive `transform_down_with_payload` on each child.
219+
self.transform_children(&mut |c| c.transform_down_with_payload(f, new_payload_iter.next().unwrap())))?
220+
// Applying `f` on self might have returned prune, but we need to propagate
221+
// continue.
222+
.continue_on_prune()
223+
}
224+
225+
fn transform_up<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
226+
where
227+
F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
228+
{
229+
// Run the recursive `transform_up` on each child.
230+
self.transform_children(&mut |c| c.transform_up(f))?
231+
// Apply `f` on self.
232+
.and_then_on_continue(|| f(self))?
233+
// Applying `f` on self might have returned prune, but we need to propagate
234+
// continue.
235+
.continue_on_prune()
236+
}
237+
238+
fn transform_up_with_payload<F, P>(
239+
&mut self,
240+
f: &mut F,
241+
) -> Result<(TreeNodeRecursion, Option<P>)>
242+
where
243+
F: FnMut(&mut Self, Vec<P>) -> Result<(TreeNodeRecursion, P)>,
244+
{
245+
// Run the recursive `transform_up_with_payload` on each child.
246+
let mut payload = vec![];
247+
let tnr = self.transform_children(&mut |c| {
248+
let (tnr, p) = c.transform_up_with_payload(f)?;
249+
p.into_iter().for_each(|p| payload.push(p));
250+
Ok(tnr)
251+
})?;
252+
let mut new_payload = None;
253+
// Apply `f` on self.
254+
tnr.and_then_on_continue(|| {
255+
let (tnr, np) = f(self, payload)?;
256+
new_payload = Some(np);
257+
Ok(tnr)
258+
})?
259+
// Applying `f` on self might have returned prune, but we need to propagate
260+
// continue.
261+
.continue_on_prune()
262+
.map(|tnr| (tnr, new_payload))
263+
}
264+
152265
/// Convenience utility for writing optimizer rules: recursively apply the given `op` to the node and all of its
153266
/// children (preorder traversal).
154267
/// When the `op` does not apply to a given node, it is left unchanged.
155-
fn transform_down<F>(self, op: &F) -> Result<Self>
268+
fn transform_down_old<F>(self, op: &F) -> Result<Self>
156269
where
157270
F: Fn(Self) -> Result<Transformed<Self>>,
158271
{
159272
let after_op = op(self)?.into();
160-
after_op.map_children(|node| node.transform_down(op))
273+
after_op.map_children(|node| node.transform_down_old(op))
161274
}
162275

163276
/// Convenience utility for writing optimizer rules: recursively apply the given `op` to the node and all of its
@@ -174,11 +287,11 @@ pub trait TreeNode: Sized {
174287
/// Convenience utility for writing optimizer rules: recursively apply the given `op` first to all of its
175288
/// children and then itself (postorder traversal).
176289
/// When the `op` does not apply to a given node, it is left unchanged.
177-
fn transform_up<F>(self, op: &F) -> Result<Self>
290+
fn transform_up_old<F>(self, op: &F) -> Result<Self>
178291
where
179292
F: Fn(Self) -> Result<Transformed<Self>>,
180293
{
181-
let after_op_children = self.map_children(|node| node.transform_up(op))?;
294+
let after_op_children = self.map_children(|node| node.transform_up_old(op))?;
182295

183296
let new_node = op(after_op_children)?.into();
184297
Ok(new_node)
@@ -402,63 +515,35 @@ pub enum TreeNodeRecursion {
402515
}
403516

404517
impl TreeNodeRecursion {
405-
fn continue_on_prune(self) -> TreeNodeRecursion {
406-
match self {
407-
TreeNodeRecursion::Prune => TreeNodeRecursion::Continue,
408-
o => o,
409-
}
410-
}
411-
412-
fn fail_on_prune(self) -> TreeNodeRecursion {
413-
match self {
414-
TreeNodeRecursion::Prune => panic!("Recursion can't prune."),
415-
o => o,
416-
}
417-
}
418-
419-
fn continue_on_stop(self) -> TreeNodeRecursion {
420-
match self {
421-
TreeNodeRecursion::Stop => TreeNodeRecursion::Continue,
422-
o => o,
423-
}
424-
}
425-
}
426-
427-
/// This helper trait provide functions to control recursion on
428-
/// [`Result<TreeNodeRecursion>`].
429-
pub trait TreeNodeRecursionResult: Sized {
430-
fn and_then_on_continue<F>(self, f: F) -> Result<TreeNodeRecursion>
431-
where
432-
F: FnOnce() -> Result<TreeNodeRecursion>;
433-
434-
fn continue_on_prune(self) -> Result<TreeNodeRecursion>;
435-
436-
fn fail_on_prune(self) -> Result<TreeNodeRecursion>;
437-
438-
fn continue_on_stop(self) -> Result<TreeNodeRecursion>;
439-
}
440-
441-
impl TreeNodeRecursionResult for Result<TreeNodeRecursion> {
442-
fn and_then_on_continue<F>(self, f: F) -> Result<TreeNodeRecursion>
518+
pub fn and_then_on_continue<F>(self, f: F) -> Result<TreeNodeRecursion>
443519
where
444520
F: FnOnce() -> Result<TreeNodeRecursion>,
445521
{
446-
match self? {
522+
match self {
447523
TreeNodeRecursion::Continue => f(),
448524
o => Ok(o),
449525
}
450526
}
451527

452-
fn continue_on_prune(self) -> Result<TreeNodeRecursion> {
453-
self.map(|tnr| tnr.continue_on_prune())
528+
pub fn continue_on_prune(self) -> Result<TreeNodeRecursion> {
529+
Ok(match self {
530+
TreeNodeRecursion::Prune => TreeNodeRecursion::Continue,
531+
o => o,
532+
})
454533
}
455534

456-
fn fail_on_prune(self) -> Result<TreeNodeRecursion> {
457-
self.map(|tnr| tnr.fail_on_prune())
535+
pub fn fail_on_prune(self) -> Result<TreeNodeRecursion> {
536+
Ok(match self {
537+
TreeNodeRecursion::Prune => panic!("Recursion can't prune."),
538+
o => o,
539+
})
458540
}
459541

460-
fn continue_on_stop(self) -> Result<TreeNodeRecursion> {
461-
self.map(|tnr| tnr.continue_on_stop())
542+
pub fn continue_on_stop(self) -> Result<TreeNodeRecursion> {
543+
Ok(match self {
544+
TreeNodeRecursion::Stop => TreeNodeRecursion::Continue,
545+
o => o,
546+
})
462547
}
463548
}
464549

datafusion/core/src/physical_optimizer/coalesce_batches.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ impl PhysicalOptimizerRule for CoalesceBatches {
5252
}
5353

5454
let target_batch_size = config.execution.batch_size;
55-
plan.transform_up(&|plan| {
55+
plan.transform_up_old(&|plan| {
5656
let plan_any = plan.as_any();
5757
// The goal here is to detect operators that could produce small batches and only
5858
// wrap those ones with a CoalesceBatchesExec operator. An alternate approach here

0 commit comments

Comments
 (0)