Skip to content

Commit b4a9ffd

Browse files
committed
Implement TreeNode::map_children in place
1 parent 63888e8 commit b4a9ffd

File tree

5 files changed

+261
-16
lines changed

5 files changed

+261
-16
lines changed

datafusion/common/src/tree_node.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,15 @@ impl<T> Transformed<T> {
530530
})
531531
}
532532
}
533+
534+
/// Discards the data of this [`Transformed`] object transforming it into Transformed<()>
535+
pub fn discard_data(self) -> Transformed<()> {
536+
Transformed {
537+
data: (),
538+
transformed: self.transformed,
539+
tnr: self.tnr,
540+
}
541+
}
533542
}
534543

535544
/// Transformation helper to process a sequence of iterable tree nodes that are siblings.

datafusion/expr/src/logical_plan/ddl.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,24 @@ impl DdlStatement {
112112
}
113113
}
114114

115+
/// Return a mutable reference to the input `LogicalPlan`, if any
116+
pub fn input_mut(&mut self) -> Option<&mut Arc<LogicalPlan>> {
117+
match self {
118+
DdlStatement::CreateMemoryTable(CreateMemoryTable { input, .. }) => {
119+
Some(input)
120+
}
121+
DdlStatement::CreateExternalTable(_) => None,
122+
DdlStatement::CreateView(CreateView { input, .. }) => Some(input),
123+
DdlStatement::CreateCatalogSchema(_) => None,
124+
DdlStatement::CreateCatalog(_) => None,
125+
DdlStatement::DropTable(_) => None,
126+
DdlStatement::DropView(_) => None,
127+
DdlStatement::DropCatalogSchema(_) => None,
128+
DdlStatement::CreateFunction(_) => None,
129+
DdlStatement::DropFunction(_) => None,
130+
}
131+
}
132+
115133
/// Return a `format`able structure with a human readable
116134
/// description of this LogicalPlan node per node, not including
117135
/// children.

datafusion/expr/src/logical_plan/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ pub mod display;
2121
pub mod dml;
2222
mod extension;
2323
mod plan;
24+
mod rewrite;
2425
mod statement;
2526

2627
pub use builder::{
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Methods for rewriting logical plans
19+
20+
use crate::{
21+
Aggregate, CrossJoin, Distinct, DistinctOn, EmptyRelation, Filter, Join, Limit,
22+
LogicalPlan, Prepare, Projection, RecursiveQuery, Repartition, Sort, Subquery,
23+
SubqueryAlias, Union, Unnest, UserDefinedLogicalNode, Window,
24+
};
25+
use datafusion_common::tree_node::{Transformed, TransformedIterator};
26+
use datafusion_common::{DFSchema, DFSchemaRef, Result};
27+
use std::sync::{Arc, OnceLock};
28+
29+
/// A temporary node that is left in place while rewriting the children of a
30+
/// [`LogicalPlan`]. This is necessary to ensure that the `LogicalPlan` is
31+
/// always in a valid state (from the Rust perspective)
///
/// The value is created lazily, on first use inside `rewrite_arc`, as an
/// `EmptyRelation` with an empty schema; every later use clones the same `Arc`.
32+
static PLACEHOLDER: OnceLock<Arc<LogicalPlan>> = OnceLock::new();
33+
34+
/// its inputs, so this code would not be needed. However, for now we try and
35+
/// unwrap the `Arc` which avoids `clone`ing in most cases.
36+
///
37+
/// On error, node be left with a placeholder logical plan
38+
fn rewrite_arc<F>(
39+
node: &mut Arc<LogicalPlan>,
40+
mut f: F,
41+
) -> datafusion_common::Result<Transformed<&mut Arc<LogicalPlan>>>
42+
where
43+
F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
44+
{
45+
// We need to leave a valid node in the Arc, while we rewrite the existing
46+
// one, so use a single global static placeholder node
47+
let mut new_node = PLACEHOLDER
48+
.get_or_init(|| {
49+
Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
50+
produce_one_row: false,
51+
schema: DFSchemaRef::new(DFSchema::empty()),
52+
}))
53+
})
54+
.clone();
55+
56+
// take the old value out of the Arc
57+
std::mem::swap(node, &mut new_node);
58+
59+
// try to update existing node, if it isn't shared with others
60+
let new_node = Arc::try_unwrap(new_node)
61+
// if None is returned, there is another reference to this
62+
// LogicalPlan, so we must clone instead
63+
.unwrap_or_else(|node| node.as_ref().clone());
64+
65+
// apply the actual transform
66+
let result = f(new_node)?;
67+
68+
// put the new value back into the Arc
69+
let mut new_node = Arc::new(result.data);
70+
std::mem::swap(node, &mut new_node);
71+
72+
// return the `node` back
73+
Ok(Transformed::new(node, result.transformed, result.tnr))
74+
}
75+
76+
/// Rewrite the arc and discard the contents of Transformed
77+
fn rewrite_arc_no_data<F>(
78+
node: &mut Arc<LogicalPlan>,
79+
f: F,
80+
) -> datafusion_common::Result<Transformed<()>>
81+
where
82+
F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
83+
{
84+
rewrite_arc(node, f).map(|res| res.discard_data())
85+
}
86+
87+
/// Rewrites all inputs for an Extension node "in place"
88+
/// (it currently has to copy values because there are no APIs for in place modification)
89+
///
90+
/// Should be removed when we have an API for in place modifications of the
91+
/// extension to avoid these copies
92+
fn rewrite_extension_inputs<F>(
93+
node: &mut Arc<dyn UserDefinedLogicalNode>,
94+
f: F,
95+
) -> datafusion_common::Result<Transformed<()>>
96+
where
97+
F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
98+
{
99+
let Transformed {
100+
data: new_inputs,
101+
transformed,
102+
tnr,
103+
} = node
104+
.inputs()
105+
.into_iter()
106+
.cloned()
107+
.map_until_stop_and_collect(f)?;
108+
109+
let exprs = node.expressions();
110+
let mut new_node = node.from_template(&exprs, &new_inputs);
111+
std::mem::swap(node, &mut new_node);
112+
Ok(Transformed {
113+
data: (),
114+
transformed,
115+
tnr,
116+
})
117+
}
118+
119+
impl LogicalPlan {
120+
/// Applies `f` to each child (input) of this plan node, rewriting them *in place.*
121+
///
122+
/// Note that this function returns `Transformed<()>` because it it does not
123+
/// consume `self`, but instead modifies it in place. However, `F` transforms
124+
/// the children by ownership
125+
///
126+
/// # Notes
127+
///
128+
/// Inputs include both direct children as well as any embedded subquery
129+
/// `LogicalPlan`s, for example such as are in [`Expr::Exists`].
130+
pub(crate) fn rewrite_children<F>(&mut self, mut f: F) -> Result<Transformed<()>>
131+
where
132+
F: FnMut(Self) -> Result<Transformed<Self>>,
133+
{
134+
let children_result = match self {
135+
LogicalPlan::Projection(Projection { input, .. }) => {
136+
rewrite_arc_no_data(input, &mut f)
137+
}
138+
LogicalPlan::Filter(Filter { input, .. }) => {
139+
rewrite_arc_no_data(input, &mut f)
140+
}
141+
LogicalPlan::Repartition(Repartition { input, .. }) => {
142+
rewrite_arc_no_data(input, &mut f)
143+
}
144+
LogicalPlan::Window(Window { input, .. }) => {
145+
rewrite_arc_no_data(input, &mut f)
146+
}
147+
LogicalPlan::Aggregate(Aggregate { input, .. }) => {
148+
rewrite_arc_no_data(input, &mut f)
149+
}
150+
LogicalPlan::Sort(Sort { input, .. }) => rewrite_arc_no_data(input, &mut f),
151+
LogicalPlan::Join(Join { left, right, .. }) => {
152+
let results = [left, right]
153+
.into_iter()
154+
.map_until_stop_and_collect(|input| rewrite_arc(input, &mut f))?;
155+
Ok(results.discard_data())
156+
}
157+
LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
158+
let results = [left, right]
159+
.into_iter()
160+
.map_until_stop_and_collect(|input| rewrite_arc(input, &mut f))?;
161+
Ok(results.discard_data())
162+
}
163+
LogicalPlan::Limit(Limit { input, .. }) => rewrite_arc_no_data(input, &mut f),
164+
LogicalPlan::Subquery(Subquery { subquery, .. }) => {
165+
rewrite_arc_no_data(subquery, &mut f)
166+
}
167+
LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => {
168+
rewrite_arc_no_data(input, &mut f)
169+
}
170+
LogicalPlan::Extension(extension) => {
171+
rewrite_extension_inputs(&mut extension.node, &mut f)
172+
}
173+
LogicalPlan::Union(Union { inputs, .. }) => {
174+
let results = inputs
175+
.iter_mut()
176+
.map_until_stop_and_collect(|input| rewrite_arc(input, &mut f))?;
177+
Ok(results.discard_data())
178+
}
179+
LogicalPlan::Distinct(
180+
Distinct::All(input) | Distinct::On(DistinctOn { input, .. }),
181+
) => rewrite_arc_no_data(input, &mut f),
182+
LogicalPlan::Explain(explain) => {
183+
rewrite_arc_no_data(&mut explain.plan, &mut f)
184+
}
185+
LogicalPlan::Analyze(analyze) => {
186+
rewrite_arc_no_data(&mut analyze.input, &mut f)
187+
}
188+
LogicalPlan::Dml(write) => rewrite_arc_no_data(&mut write.input, &mut f),
189+
LogicalPlan::Copy(copy) => rewrite_arc_no_data(&mut copy.input, &mut f),
190+
LogicalPlan::Ddl(ddl) => {
191+
if let Some(input) = ddl.input_mut() {
192+
rewrite_arc_no_data(input, &mut f)
193+
} else {
194+
Ok(Transformed::no(()))
195+
}
196+
}
197+
LogicalPlan::Unnest(Unnest { input, .. }) => {
198+
rewrite_arc_no_data(input, &mut f)
199+
}
200+
LogicalPlan::Prepare(Prepare { input, .. }) => {
201+
rewrite_arc_no_data(input, &mut f)
202+
}
203+
LogicalPlan::RecursiveQuery(RecursiveQuery {
204+
static_term,
205+
recursive_term,
206+
..
207+
}) => {
208+
let results = [static_term, recursive_term]
209+
.into_iter()
210+
.map_until_stop_and_collect(|input| rewrite_arc(input, &mut f))?;
211+
Ok(results.discard_data())
212+
}
213+
// plans without inputs
214+
LogicalPlan::TableScan { .. }
215+
| LogicalPlan::Statement { .. }
216+
| LogicalPlan::EmptyRelation { .. }
217+
| LogicalPlan::Values { .. }
218+
| LogicalPlan::DescribeTable(_) => Ok(Transformed::no(())),
219+
}?;
220+
221+
// after visiting the actual children we we need to visit any subqueries
222+
// that are inside the expressions
223+
// TODO use pattern introduced in https://github.com/apache/arrow-datafusion/pull/9913
224+
Ok(children_result)
225+
}
226+
}

datafusion/expr/src/tree_node/plan.rs

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
use crate::LogicalPlan;
2121

2222
use datafusion_common::tree_node::{
23-
Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
23+
Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
2424
};
2525
use datafusion_common::{handle_visit_recursion, Result};
2626

@@ -91,23 +91,14 @@ impl TreeNode for LogicalPlan {
9191
Ok(tnr)
9292
}
9393

94-
fn map_children<F>(self, f: F) -> Result<Transformed<Self>>
94+
fn map_children<F>(mut self, f: F) -> Result<Transformed<Self>>
9595
where
9696
F: FnMut(Self) -> Result<Transformed<Self>>,
9797
{
98-
let new_children = self
99-
.inputs()
100-
.iter()
101-
.map(|&c| c.clone())
102-
.map_until_stop_and_collect(f)?;
103-
// Propagate up `new_children.transformed` and `new_children.tnr`
104-
// along with the node containing transformed children.
105-
if new_children.transformed {
106-
new_children.map_data(|new_children| {
107-
self.with_new_exprs(self.expressions(), new_children)
108-
})
109-
} else {
110-
Ok(new_children.update_data(|_| self))
111-
}
98+
// Rewrite each child *in place* via `rewrite_children`, instead of
// cloning every input and rebuilding this node from the new children
99+
let result = self.rewrite_children(f)?;
100+
101+
// Return `self`, whose children were rewritten above, in place of the
// `()` payload of `result`.
// NOTE(review): assumes `update_data` keeps `result`'s `transformed` and
// `tnr` flags; confirm against `Transformed::update_data`
102+
Ok(result.update_data(|_| self))
112103
}
113104
}

0 commit comments

Comments
 (0)