Skip to content

Commit 6a62811

Browse files
committed
fix top-down recursion, fix unit tests to use real a Optimizer to verify behavior on plans
1 parent 7ee369f commit 6a62811

File tree

1 file changed

+20
-36
lines changed

1 file changed

+20
-36
lines changed

datafusion/optimizer/src/common_subexpr_eliminate.rs

Lines changed: 20 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -368,8 +368,8 @@ impl CommonSubexprEliminate {
368368

369369
None => Ok((new_window_expr_list, input, None)),
370370
})?
371-
// Recurse into the new input. this is similar to top-down optimizer rule's
372-
// logic.
371+
// Recurse into the new input.
372+
// (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
373373
.transform_data(|(new_window_expr_list, new_input, window_expr_list)| {
374374
self.rewrite(new_input, config)?.map_data(|new_input| {
375375
Ok((new_window_expr_list, new_input, window_expr_list))
@@ -467,8 +467,8 @@ impl CommonSubexprEliminate {
467467
None => Ok((new_aggr_expr, new_group_expr, input, None)),
468468
}
469469
})?
470-
// Recurse into the new input. this is similar to top-down optimizer rule's
471-
// logic.
470+
// Recurse into the new input.
471+
// (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
472472
.transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| {
473473
self.rewrite(new_input, config)?.map_data(|new_input| {
474474
Ok((
@@ -636,8 +636,8 @@ impl CommonSubexprEliminate {
636636
None => Ok((new_exprs, input)),
637637
}
638638
})?
639-
// Recurse into the new input. This is similar to top-down optimizer rule's
640-
// logic.
639+
// Recurse into the new input.
640+
// (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
641641
.transform_data(|(new_exprs, new_input)| {
642642
self.rewrite(new_input, config)?
643643
.map_data(|new_input| Ok((new_exprs, new_input)))
@@ -702,7 +702,10 @@ impl OptimizerRule for CommonSubexprEliminate {
702702
}
703703

704704
fn apply_order(&self) -> Option<ApplyOrder> {
705-
Some(ApplyOrder::TopDown)
705+
// This rule handles recursion itself in a `ApplyOrder::TopDown` like manner.
706+
// This is because in some cases adjacent nodes are collected (e.g. `Window`) and
707+
// CSEd as a group, which can't be done in a simple `ApplyOrder::TopDown` rule.
708+
None
706709
}
707710

708711
fn rewrite(
@@ -740,8 +743,9 @@ impl OptimizerRule for CommonSubexprEliminate {
740743
| LogicalPlan::Unnest(_)
741744
| LogicalPlan::RecursiveQuery(_)
742745
| LogicalPlan::Prepare(_) => {
743-
// ApplyOrder::TopDown handles recursion
744-
Transformed::no(plan)
746+
// This rule handles recursion itself in a `ApplyOrder::TopDown` like
747+
// manner.
748+
plan.map_children(|c| self.rewrite(c, config))?
745749
}
746750
};
747751

@@ -1187,42 +1191,22 @@ mod test {
11871191
};
11881192
use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
11891193

1194+
use super::*;
11901195
use crate::optimizer::OptimizerContext;
11911196
use crate::test::*;
1197+
use crate::Optimizer;
11921198
use datafusion_expr::test::function_stub::{avg, sum};
11931199

1194-
use super::*;
1195-
1196-
fn assert_non_optimized_plan_eq(
1197-
expected: &str,
1198-
plan: LogicalPlan,
1199-
config: Option<&dyn OptimizerConfig>,
1200-
) {
1201-
assert_eq!(expected, format!("{plan}"), "Unexpected starting plan");
1202-
let optimizer = CommonSubexprEliminate::new();
1203-
let default_config = OptimizerContext::new();
1204-
let config = config.unwrap_or(&default_config);
1205-
let optimized_plan = optimizer.rewrite(plan, config).unwrap();
1206-
assert!(!optimized_plan.transformed, "unexpectedly optimize plan");
1207-
let optimized_plan = optimized_plan.data;
1208-
assert_eq!(
1209-
expected,
1210-
format!("{optimized_plan}"),
1211-
"Unexpected optimized plan"
1212-
);
1213-
}
1214-
12151200
fn assert_optimized_plan_eq(
12161201
expected: &str,
12171202
plan: LogicalPlan,
12181203
config: Option<&dyn OptimizerConfig>,
12191204
) {
1220-
let optimizer = CommonSubexprEliminate::new();
1205+
let optimizer =
1206+
Optimizer::with_rules(vec![Arc::new(CommonSubexprEliminate::new())]);
12211207
let default_config = OptimizerContext::new();
12221208
let config = config.unwrap_or(&default_config);
1223-
let optimized_plan = optimizer.rewrite(plan, config).unwrap();
1224-
assert!(optimized_plan.transformed, "failed to optimize plan");
1225-
let optimized_plan = optimized_plan.data;
1209+
let optimized_plan = optimizer.optimize(plan, config, |_, _| ()).unwrap();
12261210
let formatted_plan = format!("{optimized_plan}");
12271211
assert_eq!(expected, formatted_plan);
12281212
}
@@ -1612,7 +1596,7 @@ mod test {
16121596
let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\
16131597
\n TableScan: test";
16141598

1615-
assert_non_optimized_plan_eq(expected, plan, None);
1599+
assert_optimized_plan_eq(expected, plan, None);
16161600

16171601
Ok(())
16181602
}
@@ -1630,7 +1614,7 @@ mod test {
16301614
\n Projection: Int32(1) + test.a, test.a\
16311615
\n TableScan: test";
16321616

1633-
assert_non_optimized_plan_eq(expected, plan, None);
1617+
assert_optimized_plan_eq(expected, plan, None);
16341618
Ok(())
16351619
}
16361620

0 commit comments

Comments
 (0)