
Commit 8fa80e7

- refactor EnforceDistribution using transform_down_with_payload()
1 parent 5c61470 commit 8fa80e7


datafusion/core/src/physical_optimizer/enforce_distribution.rs

Lines changed: 88 additions & 163 deletions
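The refactor replaces the dedicated `PlanWithKeyRequirements` tree wrapper with a payload-carrying top-down traversal: `adjust_input_keys_ordering` is now called directly on each plan node together with the key requirement pushed down by its parent, and hands back one requirement per child for the traversal to distribute. The toy sketch below illustrates that general contract on a self-contained tree; the traversal function, the `Recursion` enum, and the `Node` type here are illustrative assumptions, not DataFusion's actual `TreeNode` API.

// Minimal sketch (not DataFusion's API): a pre-order traversal in which the
// visitor receives a payload from its parent and returns one payload per
// child, mirroring the contract assumed for transform_down_with_payload().
struct Node {
    name: String,
    children: Vec<Node>,
}

// Simplified stand-in for TreeNodeRecursion: keep descending or stop.
#[allow(dead_code)]
enum Recursion {
    Continue,
    Stop,
}

fn transform_down_with_payload<P, F>(node: &mut Node, payload: P, f: &mut F)
where
    F: FnMut(&mut Node, P) -> (Recursion, Vec<P>),
{
    let (recursion, child_payloads) = f(node, payload);
    if let Recursion::Stop = recursion {
        return;
    }
    // The visitor decides what each child receives: one payload per child.
    assert_eq!(child_payloads.len(), node.children.len());
    for (child, p) in node.children.iter_mut().zip(child_payloads) {
        transform_down_with_payload(child, p, f);
    }
}

fn main() {
    let mut plan = Node {
        name: "join".into(),
        children: vec![
            Node { name: "scan_a".into(), children: vec![] },
            Node { name: "scan_b".into(), children: vec![] },
        ],
    };
    // Payload: an optional "required key ordering" pushed down by the parent,
    // playing the role of RequiredKeyOrdering in the diff below.
    transform_down_with_payload(&mut plan, None::<Vec<String>>, &mut |node, required| {
        println!("{}: requirement from parent = {:?}", node.name, required);
        let child_req = Some(vec![node.name.clone()]);
        (Recursion::Continue, vec![child_req; node.children.len()])
    });
}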
@@ -191,17 +191,15 @@ impl EnforceDistribution {
 impl PhysicalOptimizerRule for EnforceDistribution {
     fn optimize(
         &self,
-        plan: Arc<dyn ExecutionPlan>,
+        mut plan: Arc<dyn ExecutionPlan>,
         config: &ConfigOptions,
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let top_down_join_key_reordering = config.optimizer.top_down_join_key_reordering;
 
         let adjusted = if top_down_join_key_reordering {
             // Run a top-down process to adjust input key ordering recursively
-            let plan_requirements = PlanWithKeyRequirements::new(plan);
-            let adjusted =
-                plan_requirements.transform_down_old(&adjust_input_keys_ordering)?;
-            adjusted.plan
+            plan.transform_down_with_payload(&mut adjust_input_keys_ordering, None)?;
+            plan
         } else {
             // Run a bottom-up process
             plan.transform_up_old(&|plan| {
@@ -269,12 +267,15 @@ impl PhysicalOptimizerRule for EnforceDistribution {
 /// 4) If the current plan is Projection, transform the requirements to the columns before the Projection and push down requirements
 /// 5) For other types of operators, by default, pushdown the parent requirements to children.
 ///
+type RequiredKeyOrdering = Option<Vec<Arc<dyn PhysicalExpr>>>;
+
 fn adjust_input_keys_ordering(
-    requirements: PlanWithKeyRequirements,
-) -> Result<Transformed<PlanWithKeyRequirements>> {
-    let parent_required = requirements.required_key_ordering.clone();
-    let plan_any = requirements.plan.as_any();
-    let transformed = if let Some(HashJoinExec {
+    plan: &mut Arc<dyn ExecutionPlan>,
+    required_key_ordering: RequiredKeyOrdering,
+) -> Result<(TreeNodeRecursion, Vec<RequiredKeyOrdering>)> {
+    let parent_required = required_key_ordering.unwrap_or_default().clone();
+    let plan_any = plan.as_any();
+    if let Some(HashJoinExec {
         left,
         right,
         on,
@@ -299,13 +300,15 @@ fn adjust_input_keys_ordering(
                         *null_equals_null,
                     )?) as Arc<dyn ExecutionPlan>)
                 };
-                Some(reorder_partitioned_join_keys(
-                    requirements.plan.clone(),
+                let (new_plan, request_key_ordering) = reorder_partitioned_join_keys(
+                    plan.clone(),
                     &parent_required,
                     on,
                     vec![],
                     &join_constructor,
-                )?)
+                )?;
+                *plan = new_plan;
+                Ok((TreeNodeRecursion::Continue, request_key_ordering))
             }
             PartitionMode::CollectLeft => {
                 let new_right_request = match join_type {
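A side effect of the new callback shape is that a node can rewrite itself in place: the callback receives `&mut Arc<dyn ExecutionPlan>`, builds a replacement (here the join with reordered keys), and assigns it back with `*plan = new_plan;`. Below is a minimal, self-contained sketch of that swap; the trait and operator types are illustrative stand-ins, not DataFusion's.

use std::sync::Arc;

// Illustrative stand-in for DataFusion's ExecutionPlan trait object.
trait ExecutionPlan {
    fn name(&self) -> &'static str;
}

struct HashJoin;
impl ExecutionPlan for HashJoin {
    fn name(&self) -> &'static str { "HashJoin" }
}

struct ReorderedHashJoin;
impl ExecutionPlan for ReorderedHashJoin {
    fn name(&self) -> &'static str { "ReorderedHashJoin" }
}

// The callback gets the node by mutable reference to the Arc, so it can
// replace the operator without returning a new tree to the caller.
fn rewrite(plan: &mut Arc<dyn ExecutionPlan>) {
    let new_plan: Arc<dyn ExecutionPlan> = Arc::new(ReorderedHashJoin);
    *plan = new_plan; // same pattern as `*plan = new_plan;` in the diff
}

fn main() {
    let mut plan: Arc<dyn ExecutionPlan> = Arc::new(HashJoin);
    rewrite(&mut plan);
    assert_eq!(plan.name(), "ReorderedHashJoin");
    println!("rewritten to {}", plan.name());
}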
@@ -323,30 +326,28 @@ fn adjust_input_keys_ordering(
                 };
 
                 // Push down requirements to the right side
-                Some(PlanWithKeyRequirements {
-                    plan: requirements.plan.clone(),
-                    required_key_ordering: vec![],
-                    request_key_ordering: vec![None, new_right_request],
-                })
+                Ok((TreeNodeRecursion::Continue, vec![None, new_right_request]))
             }
             PartitionMode::Auto => {
                 // Can not satisfy, clear the current requirements and generate new empty requirements
-                Some(PlanWithKeyRequirements::new(requirements.plan.clone()))
+                Ok((
+                    TreeNodeRecursion::Continue,
+                    vec![None; plan.children().len()],
+                ))
             }
         }
     } else if let Some(CrossJoinExec { left, .. }) =
         plan_any.downcast_ref::<CrossJoinExec>()
     {
         let left_columns_len = left.schema().fields().len();
         // Push down requirements to the right side
-        Some(PlanWithKeyRequirements {
-            plan: requirements.plan.clone(),
-            required_key_ordering: vec![],
-            request_key_ordering: vec![
+        Ok((
+            TreeNodeRecursion::Continue,
+            vec![
                 None,
                 shift_right_required(&parent_required, left_columns_len),
             ],
-        })
+        ))
     } else if let Some(SortMergeJoinExec {
         left,
         right,
@@ -368,26 +369,38 @@ fn adjust_input_keys_ordering(
                 *null_equals_null,
             )?) as Arc<dyn ExecutionPlan>)
         };
-        Some(reorder_partitioned_join_keys(
-            requirements.plan.clone(),
+        let (new_plan, request_key_ordering) = reorder_partitioned_join_keys(
+            plan.clone(),
             &parent_required,
             on,
             sort_options.clone(),
             &join_constructor,
-        )?)
+        )?;
+        *plan = new_plan;
+        Ok((TreeNodeRecursion::Continue, request_key_ordering))
     } else if let Some(aggregate_exec) = plan_any.downcast_ref::<AggregateExec>() {
         if !parent_required.is_empty() {
             match aggregate_exec.mode() {
-                AggregateMode::FinalPartitioned => Some(reorder_aggregate_keys(
-                    requirements.plan.clone(),
-                    &parent_required,
-                    aggregate_exec,
-                )?),
-                _ => Some(PlanWithKeyRequirements::new(requirements.plan.clone())),
+                AggregateMode::FinalPartitioned => {
+                    let (new_plan, request_key_ordering) = reorder_aggregate_keys(
+                        plan.clone(),
+                        &parent_required,
+                        aggregate_exec,
+                    )?;
+                    *plan = new_plan;
+                    Ok((TreeNodeRecursion::Continue, request_key_ordering))
+                }
+                _ => Ok((
+                    TreeNodeRecursion::Continue,
+                    vec![None; plan.children().len()],
+                )),
             }
         } else {
             // Keep everything unchanged
-            None
+            Ok((
+                TreeNodeRecursion::Continue,
+                vec![None; plan.children().len()],
+            ))
         }
     } else if let Some(proj) = plan_any.downcast_ref::<ProjectionExec>() {
         let expr = proj.expr();
@@ -396,34 +409,33 @@ fn adjust_input_keys_ordering(
         // Construct a mapping from new name to the the orginal Column
         let new_required = map_columns_before_projection(&parent_required, expr);
         if new_required.len() == parent_required.len() {
-            Some(PlanWithKeyRequirements {
-                plan: requirements.plan.clone(),
-                required_key_ordering: vec![],
-                request_key_ordering: vec![Some(new_required.clone())],
-            })
+            Ok((
+                TreeNodeRecursion::Continue,
+                vec![Some(new_required.clone())],
+            ))
         } else {
             // Can not satisfy, clear the current requirements and generate new empty requirements
-            Some(PlanWithKeyRequirements::new(requirements.plan.clone()))
+            Ok((
+                TreeNodeRecursion::Continue,
+                vec![None; plan.children().len()],
+            ))
         }
     } else if plan_any.downcast_ref::<RepartitionExec>().is_some()
         || plan_any.downcast_ref::<CoalescePartitionsExec>().is_some()
         || plan_any.downcast_ref::<WindowAggExec>().is_some()
     {
-        Some(PlanWithKeyRequirements::new(requirements.plan.clone()))
+        Ok((
+            TreeNodeRecursion::Continue,
+            vec![None; plan.children().len()],
+        ))
     } else {
         // By default, push down the parent requirements to children
-        let children_len = requirements.plan.children().len();
-        Some(PlanWithKeyRequirements {
-            plan: requirements.plan.clone(),
-            required_key_ordering: vec![],
-            request_key_ordering: vec![Some(parent_required.clone()); children_len],
-        })
-    };
-    Ok(if let Some(transformed) = transformed {
-        Transformed::Yes(transformed)
-    } else {
-        Transformed::No(requirements)
-    })
+        let children_len = plan.children().len();
+        Ok((
+            TreeNodeRecursion::Continue,
+            vec![Some(parent_required.clone()); children_len],
+        ))
+    }
 }
 
 fn reorder_partitioned_join_keys<F>(
@@ -432,7 +444,7 @@ fn reorder_partitioned_join_keys<F>(
     on: &[(Column, Column)],
     sort_options: Vec<SortOptions>,
     join_constructor: &F,
-) -> Result<PlanWithKeyRequirements>
+) -> Result<(Arc<dyn ExecutionPlan>, Vec<RequiredKeyOrdering>)>
 where
     F: Fn((Vec<(Column, Column)>, Vec<SortOptions>)) -> Result<Arc<dyn ExecutionPlan>>,
 {
@@ -455,35 +467,29 @@ where
                 new_sort_options.push(sort_options[new_positions[idx]])
             }
 
-            Ok(PlanWithKeyRequirements {
-                plan: join_constructor((new_join_on, new_sort_options))?,
-                required_key_ordering: vec![],
-                request_key_ordering: vec![Some(left_keys), Some(right_keys)],
-            })
+            Ok((
+                join_constructor((new_join_on, new_sort_options))?,
+                vec![Some(left_keys), Some(right_keys)],
+            ))
         } else {
-            Ok(PlanWithKeyRequirements {
-                plan: join_plan,
-                required_key_ordering: vec![],
-                request_key_ordering: vec![Some(left_keys), Some(right_keys)],
-            })
+            Ok((join_plan, vec![Some(left_keys), Some(right_keys)]))
         }
     } else {
-        Ok(PlanWithKeyRequirements {
-            plan: join_plan,
-            required_key_ordering: vec![],
-            request_key_ordering: vec![
+        Ok((
+            join_plan,
+            vec![
                 Some(join_key_pairs.left_keys),
                 Some(join_key_pairs.right_keys),
             ],
-        })
+        ))
     }
 }
 
 fn reorder_aggregate_keys(
     agg_plan: Arc<dyn ExecutionPlan>,
     parent_required: &[Arc<dyn PhysicalExpr>],
     agg_exec: &AggregateExec,
-) -> Result<PlanWithKeyRequirements> {
+) -> Result<(Arc<dyn ExecutionPlan>, Vec<RequiredKeyOrdering>)> {
     let output_columns = agg_exec
         .group_by()
         .expr()
@@ -501,11 +507,15 @@ fn reorder_aggregate_keys(
         || !agg_exec.group_by().null_expr().is_empty()
         || physical_exprs_equal(&output_exprs, parent_required)
     {
-        Ok(PlanWithKeyRequirements::new(agg_plan))
+        let request_key_ordering = vec![None; agg_plan.children().len()];
+        Ok((agg_plan, request_key_ordering))
     } else {
         let new_positions = expected_expr_positions(&output_exprs, parent_required);
         match new_positions {
-            None => Ok(PlanWithKeyRequirements::new(agg_plan)),
+            None => {
+                let request_key_ordering = vec![None; agg_plan.children().len()];
+                Ok((agg_plan, request_key_ordering))
+            }
             Some(positions) => {
                 let new_partial_agg = if let Some(agg_exec) =
                     agg_exec.input().as_any().downcast_ref::<AggregateExec>()
@@ -577,11 +587,13 @@ fn reorder_aggregate_keys(
                         .push((Arc::new(Column::new(name, idx)) as _, name.clone()))
                 }
                 // TODO merge adjacent Projections if there are
-                Ok(PlanWithKeyRequirements::new(Arc::new(
-                    ProjectionExec::try_new(proj_exprs, new_final_agg)?,
-                )))
+                let new_plan =
+                    Arc::new(ProjectionExec::try_new(proj_exprs, new_final_agg)?);
+                let request_key_ordering = vec![None; new_plan.children().len()];
+                Ok((new_plan, request_key_ordering))
             } else {
-                Ok(PlanWithKeyRequirements::new(agg_plan))
+                let request_key_ordering = vec![None; agg_plan.children().len()];
+                Ok((agg_plan, request_key_ordering))
             }
         }
     }
@@ -1539,93 +1551,6 @@ struct JoinKeyPairs {
     right_keys: Vec<Arc<dyn PhysicalExpr>>,
 }
 
-#[derive(Debug, Clone)]
-struct PlanWithKeyRequirements {
-    plan: Arc<dyn ExecutionPlan>,
-    /// Parent required key ordering
-    required_key_ordering: Vec<Arc<dyn PhysicalExpr>>,
-    /// The request key ordering to children
-    request_key_ordering: Vec<Option<Vec<Arc<dyn PhysicalExpr>>>>,
-}
-
-impl PlanWithKeyRequirements {
-    fn new(plan: Arc<dyn ExecutionPlan>) -> Self {
-        let children_len = plan.children().len();
-        PlanWithKeyRequirements {
-            plan,
-            required_key_ordering: vec![],
-            request_key_ordering: vec![None; children_len],
-        }
-    }
-
-    fn children(&self) -> Vec<PlanWithKeyRequirements> {
-        let plan_children = self.plan.children();
-        assert_eq!(plan_children.len(), self.request_key_ordering.len());
-        plan_children
-            .into_iter()
-            .zip(self.request_key_ordering.clone())
-            .map(|(child, required)| {
-                let from_parent = required.unwrap_or_default();
-                let length = child.children().len();
-                PlanWithKeyRequirements {
-                    plan: child,
-                    required_key_ordering: from_parent,
-                    request_key_ordering: vec![None; length],
-                }
-            })
-            .collect()
-    }
-}
-
-impl TreeNode for PlanWithKeyRequirements {
-    fn apply_children<F>(&self, f: &mut F) -> Result<TreeNodeRecursion>
-    where
-        F: FnMut(&Self) -> Result<TreeNodeRecursion>,
-    {
-        self.children().iter().for_each_till_continue(f)
-    }
-
-    fn map_children<F>(self, transform: F) -> Result<Self>
-    where
-        F: FnMut(Self) -> Result<Self>,
-    {
-        let children = self.children();
-        if !children.is_empty() {
-            let new_children: Result<Vec<_>> =
-                children.into_iter().map(transform).collect();
-
-            let children_plans = new_children?
-                .into_iter()
-                .map(|child| child.plan)
-                .collect::<Vec<_>>();
-            let new_plan = with_new_children_if_necessary(self.plan, children_plans)?;
-            Ok(PlanWithKeyRequirements {
-                plan: new_plan.into(),
-                required_key_ordering: self.required_key_ordering,
-                request_key_ordering: self.request_key_ordering,
-            })
-        } else {
-            Ok(self)
-        }
-    }
-
-    fn transform_children<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
-    where
-        F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
-    {
-        let mut children = self.children();
-        if !children.is_empty() {
-            let tnr = children.iter_mut().for_each_till_continue(f)?;
-            let children_plans = children.into_iter().map(|c| c.plan).collect();
-            self.plan =
-                with_new_children_if_necessary(self.plan.clone(), children_plans)?.into();
-            Ok(tnr)
-        } else {
-            Ok(TreeNodeRecursion::Continue)
-        }
-    }
-}
-
 /// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on
 #[cfg(feature = "parquet")]
 #[cfg(test)]
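With the payload mechanism in place, the deleted `PlanWithKeyRequirements` wrapper and its hand-written `TreeNode` plumbing (child re-wrapping, `with_new_children_if_necessary` calls) have no remaining purpose: the per-child requirements it used to carry now travel as the traversal payload. A small sketch of the surviving idea, with illustrative types only (not DataFusion code):

// Illustrative only: the "compute one requirement per child" step that used to
// live in PlanWithKeyRequirements::children() now needs no wrapper type at all.
#[allow(dead_code)]
#[derive(Clone, Debug)]
struct Key(String);

struct PlanNode {
    children: Vec<PlanNode>,
}

type RequiredKeyOrdering = Option<Vec<Key>>;

// Default behaviour from the diff: push the parent's requirement to every child.
fn requirements_for_children(
    plan: &PlanNode,
    from_parent: &RequiredKeyOrdering,
) -> Vec<RequiredKeyOrdering> {
    vec![from_parent.clone(); plan.children.len()]
}

fn main() {
    let plan = PlanNode {
        children: vec![PlanNode { children: vec![] }, PlanNode { children: vec![] }],
    };
    let parent_req = Some(vec![Key("a".into())]);
    let child_reqs = requirements_for_children(&plan, &parent_req);
    println!("{child_reqs:?}"); // one entry per child, each Some([Key("a")])
}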
