Skip to content

Commit 755342f

Browse files
committed
Implement TreeNode::map_children in place
1 parent 4bd7c13 commit 755342f

File tree

4 files changed

+233
-18
lines changed

4 files changed

+233
-18
lines changed

datafusion/common/src/tree_node.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,15 @@ impl<T> Transformed<T> {
530530
})
531531
}
532532
}
533+
534+
/// Discards the data of this [`Transformed`] object transforming it into Transformed<()>
535+
pub fn discard_data(self) -> Transformed<()> {
536+
Transformed {
537+
data: (),
538+
transformed: self.transformed,
539+
tnr: self.tnr,
540+
}
541+
}
533542
}
534543

535544
/// Transformation helper to process a sequence of iterable tree nodes that are siblings.

datafusion/expr/src/logical_plan/ddl.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,24 @@ impl DdlStatement {
112112
}
113113
}
114114

115+
/// Return a mutable reference to the input `LogicalPlan`, if any
116+
pub fn input_mut(&mut self) -> Option<&mut Arc<LogicalPlan>> {
117+
match self {
118+
DdlStatement::CreateMemoryTable(CreateMemoryTable { input, .. }) => {
119+
Some(input)
120+
}
121+
DdlStatement::CreateExternalTable(_) => None,
122+
DdlStatement::CreateView(CreateView { input, .. }) => Some(input),
123+
DdlStatement::CreateCatalogSchema(_) => None,
124+
DdlStatement::CreateCatalog(_) => None,
125+
DdlStatement::DropTable(_) => None,
126+
DdlStatement::DropView(_) => None,
127+
DdlStatement::DropCatalogSchema(_) => None,
128+
DdlStatement::CreateFunction(_) => None,
129+
DdlStatement::DropFunction(_) => None,
130+
}
131+
}
132+
115133
/// Return a `format`able structure with the a human readable
116134
/// description of this LogicalPlan node per node, not including
117135
/// children.

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 199 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
use std::collections::{HashMap, HashSet};
2121
use std::fmt::{self, Debug, Display, Formatter};
2222
use std::hash::{Hash, Hasher};
23-
use std::sync::Arc;
23+
use std::sync::{Arc, OnceLock};
2424

2525
use super::dml::CopyTo;
2626
use super::DdlStatement;
@@ -45,7 +45,8 @@ use crate::{
4545

4646
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
4747
use datafusion_common::tree_node::{
48-
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
48+
Transformed, TransformedIterator, TransformedResult, TreeNode, TreeNodeRecursion,
49+
TreeNodeVisitor,
4950
};
5051
use datafusion_common::{
5152
aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
@@ -1131,6 +1132,202 @@ impl LogicalPlan {
11311132
})?;
11321133
Ok(())
11331134
}
1135+
}
1136+
1137+
// TODO put this somewhere better than here
1138+
1139+
/// A temporary node that is left in place while rewriting the children of a
1140+
/// [`LogicalPlan`]. This is necessary to ensure that the `LogicalPlan` is
1141+
/// always in a valid state (from the Rust perspective)
1142+
static PLACEHOLDER: OnceLock<Arc<LogicalPlan>> = OnceLock::new();
1143+
1144+
/// its inputs, so this code would not be needed. However, for now we try and
1145+
/// unwrap the `Arc` which avoids `clone`ing in most cases.
1146+
///
1147+
/// On error, node be left with a placeholder logical plan
1148+
fn rewrite_arc<F>(
1149+
node: &mut Arc<LogicalPlan>,
1150+
mut f: F,
1151+
) -> Result<Transformed<&mut Arc<LogicalPlan>>>
1152+
where
1153+
F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
1154+
{
1155+
// We need to leave a valid node in the Arc, while we rewrite the existing
1156+
// one, so use a single global static placeholder node
1157+
let mut new_node = PLACEHOLDER
1158+
.get_or_init(|| {
1159+
Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1160+
produce_one_row: false,
1161+
schema: DFSchemaRef::new(DFSchema::empty()),
1162+
}))
1163+
})
1164+
.clone();
1165+
1166+
// take the old value out of the Arc
1167+
std::mem::swap(node, &mut new_node);
1168+
1169+
// try to update existing node, if it isn't shared with others
1170+
let new_node = Arc::try_unwrap(new_node)
1171+
// if None is returned, there is another reference to this
1172+
// LogicalPlan, so we must clone instead
1173+
.unwrap_or_else(|node| node.as_ref().clone());
1174+
1175+
// apply the actual transform
1176+
let result = f(new_node)?;
1177+
1178+
// put the new value back into the Arc
1179+
let mut new_node = Arc::new(result.data);
1180+
std::mem::swap(node, &mut new_node);
1181+
1182+
// return the `node` back
1183+
Ok(Transformed::new(node, result.transformed, result.tnr))
1184+
}
1185+
1186+
/// Rewrite the arc and discard the contents of Transformed
1187+
fn rewrite_arc_no_data<F>(node: &mut Arc<LogicalPlan>, f: F) -> Result<Transformed<()>>
1188+
where
1189+
F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
1190+
{
1191+
rewrite_arc(node, f).map(|res| res.discard_data())
1192+
}
1193+
1194+
/// Rewrites all inputs for an Extension node "in place"
1195+
/// (it currently has to copy values because there are no APIs for in place modification)
1196+
///
1197+
/// Should be removed when we have an API for in place modifications of the
1198+
/// extension to avoid these copies
1199+
fn rewrite_extension_inputs<F>(
1200+
node: &mut Arc<dyn UserDefinedLogicalNode>,
1201+
f: F,
1202+
) -> Result<Transformed<()>>
1203+
where
1204+
F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
1205+
{
1206+
let Transformed {
1207+
data: new_inputs,
1208+
transformed,
1209+
tnr,
1210+
} = node
1211+
.inputs()
1212+
.into_iter()
1213+
.cloned()
1214+
.map_until_stop_and_collect(f)?;
1215+
1216+
let exprs = node.expressions();
1217+
let mut new_node = node.from_template(&exprs, &new_inputs);
1218+
std::mem::swap(node, &mut new_node);
1219+
Ok(Transformed {
1220+
data: (),
1221+
transformed,
1222+
tnr,
1223+
})
1224+
}
1225+
1226+
impl LogicalPlan {
1227+
/// applies `f` to each input of this plan node, rewriting them *in place.*
1228+
///
1229+
/// # Notes
1230+
/// Inputs include both direct children as well as any embedded subquery
1231+
/// `LogicalPlan`s, for example such as are in [`Expr::Exists`].
1232+
///
1233+
/// If `f` returns an `Err`, that Err is returned, and the inputs are left
1234+
/// in a partially modified state
1235+
pub(crate) fn rewrite_children<F>(&mut self, mut f: F) -> Result<Transformed<()>>
1236+
where
1237+
F: FnMut(Self) -> Result<Transformed<Self>>,
1238+
{
1239+
let children_result = match self {
1240+
LogicalPlan::Projection(Projection { input, .. }) => {
1241+
rewrite_arc_no_data(input, &mut f)
1242+
}
1243+
LogicalPlan::Filter(Filter { input, .. }) => {
1244+
rewrite_arc_no_data(input, &mut f)
1245+
}
1246+
LogicalPlan::Repartition(Repartition { input, .. }) => {
1247+
rewrite_arc_no_data(input, &mut f)
1248+
}
1249+
LogicalPlan::Window(Window { input, .. }) => {
1250+
rewrite_arc_no_data(input, &mut f)
1251+
}
1252+
LogicalPlan::Aggregate(Aggregate { input, .. }) => {
1253+
rewrite_arc_no_data(input, &mut f)
1254+
}
1255+
LogicalPlan::Sort(Sort { input, .. }) => rewrite_arc_no_data(input, &mut f),
1256+
LogicalPlan::Join(Join { left, right, .. }) => {
1257+
let results = [left, right]
1258+
.into_iter()
1259+
.map_until_stop_and_collect(|input| rewrite_arc(input, &mut f))?;
1260+
Ok(results.discard_data())
1261+
}
1262+
LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
1263+
let results = [left, right]
1264+
.into_iter()
1265+
.map_until_stop_and_collect(|input| rewrite_arc(input, &mut f))?;
1266+
Ok(results.discard_data())
1267+
}
1268+
LogicalPlan::Limit(Limit { input, .. }) => rewrite_arc_no_data(input, &mut f),
1269+
LogicalPlan::Subquery(Subquery { subquery, .. }) => {
1270+
rewrite_arc_no_data(subquery, &mut f)
1271+
}
1272+
LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => {
1273+
rewrite_arc_no_data(input, &mut f)
1274+
}
1275+
LogicalPlan::Extension(extension) => {
1276+
rewrite_extension_inputs(&mut extension.node, &mut f)
1277+
}
1278+
LogicalPlan::Union(Union { inputs, .. }) => {
1279+
let results = inputs
1280+
.iter_mut()
1281+
.map_until_stop_and_collect(|input| rewrite_arc(input, &mut f))?;
1282+
Ok(results.discard_data())
1283+
}
1284+
LogicalPlan::Distinct(
1285+
Distinct::All(input) | Distinct::On(DistinctOn { input, .. }),
1286+
) => rewrite_arc_no_data(input, &mut f),
1287+
LogicalPlan::Explain(explain) => {
1288+
rewrite_arc_no_data(&mut explain.plan, &mut f)
1289+
}
1290+
LogicalPlan::Analyze(analyze) => {
1291+
rewrite_arc_no_data(&mut analyze.input, &mut f)
1292+
}
1293+
LogicalPlan::Dml(write) => rewrite_arc_no_data(&mut write.input, &mut f),
1294+
LogicalPlan::Copy(copy) => rewrite_arc_no_data(&mut copy.input, &mut f),
1295+
LogicalPlan::Ddl(ddl) => {
1296+
if let Some(input) = ddl.input_mut() {
1297+
rewrite_arc_no_data(input, &mut f)
1298+
} else {
1299+
Ok(Transformed::no(()))
1300+
}
1301+
}
1302+
LogicalPlan::Unnest(Unnest { input, .. }) => {
1303+
rewrite_arc_no_data(input, &mut f)
1304+
}
1305+
LogicalPlan::Prepare(Prepare { input, .. }) => {
1306+
rewrite_arc_no_data(input, &mut f)
1307+
}
1308+
LogicalPlan::RecursiveQuery(RecursiveQuery {
1309+
static_term,
1310+
recursive_term,
1311+
..
1312+
}) => {
1313+
let results = [static_term, recursive_term]
1314+
.into_iter()
1315+
.map_until_stop_and_collect(|input| rewrite_arc(input, &mut f))?;
1316+
Ok(results.discard_data())
1317+
}
1318+
// plans without inputs
1319+
LogicalPlan::TableScan { .. }
1320+
| LogicalPlan::Statement { .. }
1321+
| LogicalPlan::EmptyRelation { .. }
1322+
| LogicalPlan::Values { .. }
1323+
| LogicalPlan::DescribeTable(_) => Ok(Transformed::no(())),
1324+
}?;
1325+
1326+
// after visiting the actual children we we need to visit any subqueries
1327+
// that are inside the expressions
1328+
// children_result.and_then(|| self.rewrite_subqueries(&mut f))
1329+
Ok(children_result)
1330+
}
11341331

11351332
/// Return a `LogicalPlan` with all placeholders (e.g $1 $2,
11361333
/// ...) replaced with corresponding values provided in

datafusion/expr/src/tree_node/plan.rs

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
use crate::LogicalPlan;
2121

2222
use datafusion_common::tree_node::{
23-
Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
23+
Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
2424
};
2525
use datafusion_common::{handle_visit_recursion, Result};
2626

@@ -91,23 +91,14 @@ impl TreeNode for LogicalPlan {
9191
Ok(tnr)
9292
}
9393

94-
fn map_children<F>(self, f: F) -> Result<Transformed<Self>>
94+
fn map_children<F>(mut self, f: F) -> Result<Transformed<Self>>
9595
where
9696
F: FnMut(Self) -> Result<Transformed<Self>>,
9797
{
98-
let new_children = self
99-
.inputs()
100-
.iter()
101-
.map(|&c| c.clone())
102-
.map_until_stop_and_collect(f)?;
103-
// Propagate up `new_children.transformed` and `new_children.tnr`
104-
// along with the node containing transformed children.
105-
if new_children.transformed {
106-
new_children.map_data(|new_children| {
107-
self.with_new_exprs(self.expressions(), new_children)
108-
})
109-
} else {
110-
Ok(new_children.update_data(|_| self))
111-
}
98+
// Apply the rewrites in place for each child
99+
let result = self.rewrite_children(f)?;
100+
101+
// return a reference to ourself
102+
Ok(result.update_data(|_| self))
112103
}
113104
}

0 commit comments

Comments
 (0)