Skip to content

Avoid LogicalPlan::clone() in LogicalPlan::map_children when possible #9999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
383 changes: 363 additions & 20 deletions datafusion/expr/src/logical_plan/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,19 @@

//! Tree node implementation for logical plan

use crate::LogicalPlan;
use crate::{
Aggregate, Analyze, CreateMemoryTable, CreateView, CrossJoin, DdlStatement, Distinct,
DistinctOn, DmlStatement, Explain, Extension, Filter, Join, Limit, LogicalPlan,
Prepare, Projection, RecursiveQuery, Repartition, Sort, Subquery, SubqueryAlias,
Union, Unnest, Window,
};
use std::sync::Arc;

use crate::dml::CopyTo;
use datafusion_common::tree_node::{
Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion,
};
use datafusion_common::Result;
use datafusion_common::{map_until_stop_and_collect, Result};

impl TreeNode for LogicalPlan {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
Expand All @@ -32,23 +39,359 @@ impl TreeNode for LogicalPlan {
self.inputs().into_iter().apply_until_stop(f)
}

fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: F,
) -> Result<Transformed<Self>> {
let new_children = self
.inputs()
.into_iter()
.cloned()
Copy link
Contributor Author

@alamb alamb Apr 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing this call to cloned() (and the self.with_new_exprs below) is the point of this PR.

As a subsequent PRs I plan to rewrite the Optimizer and the various passes to use this method to rewrite the plans without copying them

.map_until_stop_and_collect(f)?;
// Propagate up `new_children.transformed` and `new_children.tnr`
// along with the node containing transformed children.
if new_children.transformed {
new_children.map_data(|new_children| {
self.with_new_exprs(self.expressions(), new_children)
})
} else {
Ok(new_children.update_data(|_| self))
}
/// Applies `f` to each child (input) of this plan node, rewriting them *in place.*
///
/// # Notes
///
/// Inputs include ONLY direct children, not embedded `LogicalPlan`s for
/// subqueries, for example such as are in [`Expr::Exists`].
///
/// [`Expr::Exists`]: crate::Expr::Exists
fn map_children<F>(self, mut f: F) -> Result<Transformed<Self>>
where
F: FnMut(Self) -> Result<Transformed<Self>>,
{
Ok(match self {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This follows the (really very cool) pattern @peter-toth came up with in #9913

LogicalPlan::Projection(Projection {
expr,
input,
schema,
}) => rewrite_arc(input, f)?.update_data(|input| {
LogicalPlan::Projection(Projection {
expr,
input,
schema,
})
}),
LogicalPlan::Filter(Filter { predicate, input }) => rewrite_arc(input, f)?
.update_data(|input| LogicalPlan::Filter(Filter { predicate, input })),
LogicalPlan::Repartition(Repartition {
input,
partitioning_scheme,
}) => rewrite_arc(input, f)?.update_data(|input| {
LogicalPlan::Repartition(Repartition {
input,
partitioning_scheme,
})
}),
LogicalPlan::Window(Window {
input,
window_expr,
schema,
}) => rewrite_arc(input, f)?.update_data(|input| {
LogicalPlan::Window(Window {
input,
window_expr,
schema,
})
}),
LogicalPlan::Aggregate(Aggregate {
input,
group_expr,
aggr_expr,
schema,
}) => rewrite_arc(input, f)?.update_data(|input| {
LogicalPlan::Aggregate(Aggregate {
input,
group_expr,
aggr_expr,
schema,
})
}),
LogicalPlan::Sort(Sort { expr, input, fetch }) => rewrite_arc(input, f)?
.update_data(|input| LogicalPlan::Sort(Sort { expr, input, fetch })),
LogicalPlan::Join(Join {
left,
right,
on,
filter,
join_type,
join_constraint,
schema,
null_equals_null,
}) => map_until_stop_and_collect!(
rewrite_arc(left, &mut f),
right,
rewrite_arc(right, &mut f)
)?
.update_data(|(left, right)| {
LogicalPlan::Join(Join {
left,
right,
on,
filter,
join_type,
join_constraint,
schema,
null_equals_null,
})
}),
LogicalPlan::CrossJoin(CrossJoin {
left,
right,
schema,
}) => map_until_stop_and_collect!(
rewrite_arc(left, &mut f),
right,
rewrite_arc(right, &mut f)
)?
.update_data(|(left, right)| {
LogicalPlan::CrossJoin(CrossJoin {
left,
right,
schema,
})
}),
LogicalPlan::Limit(Limit { skip, fetch, input }) => rewrite_arc(input, f)?
.update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, input })),
LogicalPlan::Subquery(Subquery {
subquery,
outer_ref_columns,
}) => rewrite_arc(subquery, f)?.update_data(|subquery| {
LogicalPlan::Subquery(Subquery {
subquery,
outer_ref_columns,
})
}),
LogicalPlan::SubqueryAlias(SubqueryAlias {
input,
alias,
schema,
}) => rewrite_arc(input, f)?.update_data(|input| {
LogicalPlan::SubqueryAlias(SubqueryAlias {
input,
alias,
schema,
})
}),
LogicalPlan::Extension(extension) => rewrite_extension_inputs(extension, f)?
.update_data(LogicalPlan::Extension),
LogicalPlan::Union(Union { inputs, schema }) => rewrite_arcs(inputs, f)?
.update_data(|inputs| LogicalPlan::Union(Union { inputs, schema })),
LogicalPlan::Distinct(distinct) => match distinct {
Distinct::All(input) => rewrite_arc(input, f)?.update_data(Distinct::All),
Distinct::On(DistinctOn {
on_expr,
select_expr,
sort_expr,
input,
schema,
}) => rewrite_arc(input, f)?.update_data(|input| {
Distinct::On(DistinctOn {
on_expr,
select_expr,
sort_expr,
input,
schema,
})
}),
}
.update_data(LogicalPlan::Distinct),
LogicalPlan::Explain(Explain {
verbose,
plan,
stringified_plans,
schema,
logical_optimization_succeeded,
}) => rewrite_arc(plan, f)?.update_data(|plan| {
LogicalPlan::Explain(Explain {
verbose,
plan,
stringified_plans,
schema,
logical_optimization_succeeded,
})
}),
LogicalPlan::Analyze(Analyze {
verbose,
input,
schema,
}) => rewrite_arc(input, f)?.update_data(|input| {
LogicalPlan::Analyze(Analyze {
verbose,
input,
schema,
})
}),
LogicalPlan::Dml(DmlStatement {
table_name,
table_schema,
op,
input,
}) => rewrite_arc(input, f)?.update_data(|input| {
LogicalPlan::Dml(DmlStatement {
table_name,
table_schema,
op,
input,
})
}),
LogicalPlan::Copy(CopyTo {
input,
output_url,
partition_by,
format_options,
options,
}) => rewrite_arc(input, f)?.update_data(|input| {
LogicalPlan::Copy(CopyTo {
input,
output_url,
partition_by,
format_options,
options,
})
}),
LogicalPlan::Ddl(ddl) => {
match ddl {
DdlStatement::CreateMemoryTable(CreateMemoryTable {
name,
constraints,
input,
if_not_exists,
or_replace,
column_defaults,
}) => rewrite_arc(input, f)?.update_data(|input| {
DdlStatement::CreateMemoryTable(CreateMemoryTable {
name,
constraints,
input,
if_not_exists,
or_replace,
column_defaults,
})
}),
DdlStatement::CreateView(CreateView {
name,
input,
or_replace,
definition,
}) => rewrite_arc(input, f)?.update_data(|input| {
DdlStatement::CreateView(CreateView {
name,
input,
or_replace,
definition,
})
}),
// no inputs in these statements
DdlStatement::CreateExternalTable(_)
| DdlStatement::CreateCatalogSchema(_)
| DdlStatement::CreateCatalog(_)
| DdlStatement::DropTable(_)
| DdlStatement::DropView(_)
| DdlStatement::DropCatalogSchema(_)
| DdlStatement::CreateFunction(_)
| DdlStatement::DropFunction(_) => Transformed::no(ddl),
}
.update_data(LogicalPlan::Ddl)
}
LogicalPlan::Unnest(Unnest {
input,
column,
schema,
options,
}) => rewrite_arc(input, f)?.update_data(|input| {
LogicalPlan::Unnest(Unnest {
input,
column,
schema,
options,
})
}),
LogicalPlan::Prepare(Prepare {
name,
data_types,
input,
}) => rewrite_arc(input, f)?.update_data(|input| {
LogicalPlan::Prepare(Prepare {
name,
data_types,
input,
})
}),
LogicalPlan::RecursiveQuery(RecursiveQuery {
name,
static_term,
recursive_term,
is_distinct,
}) => map_until_stop_and_collect!(
rewrite_arc(static_term, &mut f),
recursive_term,
rewrite_arc(recursive_term, &mut f)
)?
.update_data(|(static_term, recursive_term)| {
LogicalPlan::RecursiveQuery(RecursiveQuery {
name,
static_term,
recursive_term,
is_distinct,
})
}),
// plans without inputs
LogicalPlan::TableScan { .. }
| LogicalPlan::Statement { .. }
| LogicalPlan::EmptyRelation { .. }
| LogicalPlan::Values { .. }
| LogicalPlan::DescribeTable(_) => Transformed::no(self),
})
}
}

/// Converts a `Arc<LogicalPlan>` without copying, if possible. Copies the plan
Copy link
Contributor Author

@alamb alamb Apr 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the code that avoids copying for Arc<LogicalPlan> when possible (my performance results show it is possible most of the time)

Also, you can see it with a local change like this:

diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs
index 97e2f7f56..a8570c0af 100644
--- a/datafusion/expr/src/logical_plan/tree_node.rs
+++ b/datafusion/expr/src/logical_plan/tree_node.rs
@@ -338,10 +338,16 @@ impl TreeNode for LogicalPlan {
 /// Converts a `Arc<LogicalPlan>` without copying, if possible. Copies the plan
 /// if there is a shared reference
 fn unwrap_arc(plan: Arc<LogicalPlan>) -> LogicalPlan {
-    Arc::try_unwrap(plan)
-        // if None is returned, there is another reference to this
-        // LogicalPlan, so we can not own it, and must clone instead
-        .unwrap_or_else(|node| node.as_ref().clone())
+    match Arc::try_unwrap(plan) {
+        Ok(plan) => {
+            println!("unwrapped!");
+            plan
+        }
+        Err(plan) => {
+            println!("BOO copying");
+            plan.as_ref().clone()
+        }
+    }
 }

 /// Applies `f` to rewrite a `Arc<LogicalPlan>` without copying, if possible

And then running

cargo test --test sqllogictests
...
unwrapped!
unwrapped!
unwrapped!
unwrapped!
unwrapped!
unwrapped!
unwrapped!
unwrapped!
unwrapped!
unwrapped!
unwrapped!
unwrapped!
unwrapped!
unwrapped!
unwrapped!
unwrapped!
unwrapped!
unwrapped!
unwrapped!

There still is plenty of copying going on (957 copies), but there are 22,913 less copies!

andrewlamb@Andrews-MacBook-Pro:~/Software/arrow-datafusion$ cargo test --test sqllogictests  | grep BOO | wc -l
    Finished test [unoptimized + debuginfo] target(s) in 0.13s
     Running bin/sqllogictests.rs (target/debug/deps/sqllogictests-518eef2279430877)
     957
andrewlamb@Andrews-MacBook-Pro:~/Software/arrow-datafusion$ cargo test --test sqllogictests  | grep unwrapped | wc -l
    Finished test [unoptimized + debuginfo] target(s) in 0.13s
     Running bin/sqllogictests.rs (target/debug/deps/sqllogictests-518eef2279430877)
   22913

/// if there is a shared reference
fn unwrap_arc(plan: Arc<LogicalPlan>) -> LogicalPlan {
Arc::try_unwrap(plan)
// if None is returned, there is another reference to this
// LogicalPlan, so we can not own it, and must clone instead
.unwrap_or_else(|node| node.as_ref().clone())
}

/// Applies `f` to rewrite a `Arc<LogicalPlan>` without copying, if possible
fn rewrite_arc<F>(
plan: Arc<LogicalPlan>,
mut f: F,
) -> Result<Transformed<Arc<LogicalPlan>>>
where
F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
{
f(unwrap_arc(plan))?.map_data(|new_plan| Ok(Arc::new(new_plan)))
}

/// rewrite a `Vec` of `Arc<LogicalPlan>` without copying, if possible
fn rewrite_arcs<F>(
input_plans: Vec<Arc<LogicalPlan>>,
mut f: F,
) -> Result<Transformed<Vec<Arc<LogicalPlan>>>>
where
F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
{
input_plans
.into_iter()
.map_until_stop_and_collect(|plan| rewrite_arc(plan, &mut f))
}

/// Rewrites all inputs for an Extension node "in place"
/// (it currently has to copy values because there are no APIs for in place modification)
///
/// Should be removed when we have an API for in place modifications of the
/// extension to avoid these copies
fn rewrite_extension_inputs<F>(
extension: Extension,
f: F,
) -> Result<Transformed<Extension>>
where
F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
{
let Extension { node } = extension;

node.inputs()
.into_iter()
.cloned()
.map_until_stop_and_collect(f)?
.map_data(|new_inputs| {
let exprs = node.expressions();
Ok(Extension {
node: node.from_template(&exprs, &new_inputs),
})
})
}
Loading