Skip to content

Commit 2f55003

Browse files
peter-tothalamb
andauthored
Simplify Expr::map_children (#9876)
* add map_until_stop_and_collect macro * fix clippy * simplify * Update datafusion/common/src/tree_node.rs Co-authored-by: Andrew Lamb <[email protected]> * add documentation * fix macro --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent daf182d commit 2f55003

File tree

2 files changed

+171
-137
lines changed

2 files changed

+171
-137
lines changed

datafusion/common/src/tree_node.rs

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -532,8 +532,20 @@ impl<T> Transformed<T> {
532532
}
533533
}
534534

535-
/// Transformation helper to process tree nodes that are siblings.
535+
/// Transformation helper to process a sequence of iterable tree nodes that are siblings.
536536
pub trait TransformedIterator: Iterator {
537+
/// Apples `f` to each item in this iterator
538+
///
539+
/// Visits all items in the iterator unless
540+
/// `f` returns an error or `f` returns TreeNodeRecursion::stop.
541+
///
542+
/// # Returns
543+
/// Error if `f` returns an error
544+
///
545+
/// Ok(Transformed) such that:
546+
/// 1. `transformed` is true if any return from `f` had transformed true
547+
/// 2. `data` from the last invocation of `f`
548+
/// 3. `tnr` from the last invocation of `f` or `Continue` if the iterator is empty
537549
fn map_until_stop_and_collect<
538550
F: FnMut(Self::Item) -> Result<Transformed<Self::Item>>,
539551
>(
@@ -551,22 +563,64 @@ impl<I: Iterator> TransformedIterator for I {
551563
) -> Result<Transformed<Vec<Self::Item>>> {
552564
let mut tnr = TreeNodeRecursion::Continue;
553565
let mut transformed = false;
554-
let data = self
555-
.map(|item| match tnr {
556-
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
557-
f(item).map(|result| {
558-
tnr = result.tnr;
559-
transformed |= result.transformed;
560-
result.data
561-
})
562-
}
563-
TreeNodeRecursion::Stop => Ok(item),
564-
})
565-
.collect::<Result<Vec<_>>>()?;
566-
Ok(Transformed::new(data, transformed, tnr))
566+
self.map(|item| match tnr {
567+
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
568+
f(item).map(|result| {
569+
tnr = result.tnr;
570+
transformed |= result.transformed;
571+
result.data
572+
})
573+
}
574+
TreeNodeRecursion::Stop => Ok(item),
575+
})
576+
.collect::<Result<Vec<_>>>()
577+
.map(|data| Transformed::new(data, transformed, tnr))
567578
}
568579
}
569580

581+
/// Transformation helper to process a heterogeneous sequence of tree node containing
582+
/// expressions.
583+
/// This macro is very similar to [TransformedIterator::map_until_stop_and_collect] to
584+
/// process nodes that are siblings, but it accepts an initial transformation (`F0`) and
585+
/// a sequence of pairs. Each pair is made of an expression (`EXPR`) and its
586+
/// transformation (`F`).
587+
///
588+
/// The macro builds up a tuple that contains `Transformed.data` result of `F0` as the
589+
/// first element and further elements from the sequence of pairs. An element from a pair
590+
/// is either the value of `EXPR` or the `Transformed.data` result of `F`, depending on
591+
/// the `Transformed.tnr` result of previous `F`s (`F0` initially).
592+
///
593+
/// # Returns
594+
/// Error if any of the transformations returns an error
595+
///
596+
/// Ok(Transformed<(data0, ..., dataN)>) such that:
597+
/// 1. `transformed` is true if any of the transformations had transformed true
598+
/// 2. `(data0, ..., dataN)`, where `data0` is the `Transformed.data` from `F0` and
599+
/// `data1` ... `dataN` are from either `EXPR` or the `Transformed.data` of `F`
600+
/// 3. `tnr` from `F0` or the last invocation of `F`
601+
#[macro_export]
602+
macro_rules! map_until_stop_and_collect {
603+
($F0:expr, $($EXPR:expr, $F:expr),*) => {{
604+
$F0.and_then(|Transformed { data: data0, mut transformed, mut tnr }| {
605+
let all_datas = (
606+
data0,
607+
$(
608+
if tnr == TreeNodeRecursion::Continue || tnr == TreeNodeRecursion::Jump {
609+
$F.map(|result| {
610+
tnr = result.tnr;
611+
transformed |= result.transformed;
612+
result.data
613+
})?
614+
} else {
615+
$EXPR
616+
},
617+
)*
618+
);
619+
Ok(Transformed::new(all_datas, transformed, tnr))
620+
})
621+
}}
622+
}
623+
570624
/// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
571625
pub trait TransformedResult<T> {
572626
fn data(self) -> Result<T>;

datafusion/expr/src/tree_node/expr.rs

Lines changed: 103 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ use crate::{Expr, GetFieldAccess};
2727
use datafusion_common::tree_node::{
2828
Transformed, TransformedIterator, TreeNode, TreeNodeRecursion,
2929
};
30-
use datafusion_common::{handle_visit_recursion, internal_err, Result};
30+
use datafusion_common::{
31+
handle_visit_recursion, internal_err, map_until_stop_and_collect, Result,
32+
};
3133

3234
impl TreeNode for Expr {
3335
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
@@ -167,58 +169,55 @@ impl TreeNode for Expr {
167169
Expr::InSubquery(InSubquery::new(be, subquery, negated))
168170
}),
169171
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
170-
transform_box(left, &mut f)?
171-
.update_data(|new_left| (new_left, right))
172-
.try_transform_node(|(new_left, right)| {
173-
Ok(transform_box(right, &mut f)?
174-
.update_data(|new_right| (new_left, new_right)))
175-
})?
176-
.update_data(|(new_left, new_right)| {
177-
Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
178-
})
172+
map_until_stop_and_collect!(
173+
transform_box(left, &mut f),
174+
right,
175+
transform_box(right, &mut f)
176+
)?
177+
.update_data(|(new_left, new_right)| {
178+
Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
179+
})
179180
}
180181
Expr::Like(Like {
181182
negated,
182183
expr,
183184
pattern,
184185
escape_char,
185186
case_insensitive,
186-
}) => transform_box(expr, &mut f)?
187-
.update_data(|new_expr| (new_expr, pattern))
188-
.try_transform_node(|(new_expr, pattern)| {
189-
Ok(transform_box(pattern, &mut f)?
190-
.update_data(|new_pattern| (new_expr, new_pattern)))
191-
})?
192-
.update_data(|(new_expr, new_pattern)| {
193-
Expr::Like(Like::new(
194-
negated,
195-
new_expr,
196-
new_pattern,
197-
escape_char,
198-
case_insensitive,
199-
))
200-
}),
187+
}) => map_until_stop_and_collect!(
188+
transform_box(expr, &mut f),
189+
pattern,
190+
transform_box(pattern, &mut f)
191+
)?
192+
.update_data(|(new_expr, new_pattern)| {
193+
Expr::Like(Like::new(
194+
negated,
195+
new_expr,
196+
new_pattern,
197+
escape_char,
198+
case_insensitive,
199+
))
200+
}),
201201
Expr::SimilarTo(Like {
202202
negated,
203203
expr,
204204
pattern,
205205
escape_char,
206206
case_insensitive,
207-
}) => transform_box(expr, &mut f)?
208-
.update_data(|new_expr| (new_expr, pattern))
209-
.try_transform_node(|(new_expr, pattern)| {
210-
Ok(transform_box(pattern, &mut f)?
211-
.update_data(|new_pattern| (new_expr, new_pattern)))
212-
})?
213-
.update_data(|(new_expr, new_pattern)| {
214-
Expr::SimilarTo(Like::new(
215-
negated,
216-
new_expr,
217-
new_pattern,
218-
escape_char,
219-
case_insensitive,
220-
))
221-
}),
207+
}) => map_until_stop_and_collect!(
208+
transform_box(expr, &mut f),
209+
pattern,
210+
transform_box(pattern, &mut f)
211+
)?
212+
.update_data(|(new_expr, new_pattern)| {
213+
Expr::SimilarTo(Like::new(
214+
negated,
215+
new_expr,
216+
new_pattern,
217+
escape_char,
218+
case_insensitive,
219+
))
220+
}),
222221
Expr::Not(expr) => transform_box(expr, &mut f)?.update_data(Expr::Not),
223222
Expr::IsNotNull(expr) => {
224223
transform_box(expr, &mut f)?.update_data(Expr::IsNotNull)
@@ -248,48 +247,38 @@ impl TreeNode for Expr {
248247
negated,
249248
low,
250249
high,
251-
}) => transform_box(expr, &mut f)?
252-
.update_data(|new_expr| (new_expr, low, high))
253-
.try_transform_node(|(new_expr, low, high)| {
254-
Ok(transform_box(low, &mut f)?
255-
.update_data(|new_low| (new_expr, new_low, high)))
256-
})?
257-
.try_transform_node(|(new_expr, new_low, high)| {
258-
Ok(transform_box(high, &mut f)?
259-
.update_data(|new_high| (new_expr, new_low, new_high)))
260-
})?
261-
.update_data(|(new_expr, new_low, new_high)| {
262-
Expr::Between(Between::new(new_expr, negated, new_low, new_high))
263-
}),
250+
}) => map_until_stop_and_collect!(
251+
transform_box(expr, &mut f),
252+
low,
253+
transform_box(low, &mut f),
254+
high,
255+
transform_box(high, &mut f)
256+
)?
257+
.update_data(|(new_expr, new_low, new_high)| {
258+
Expr::Between(Between::new(new_expr, negated, new_low, new_high))
259+
}),
264260
Expr::Case(Case {
265261
expr,
266262
when_then_expr,
267263
else_expr,
268-
}) => transform_option_box(expr, &mut f)?
269-
.update_data(|new_expr| (new_expr, when_then_expr, else_expr))
270-
.try_transform_node(|(new_expr, when_then_expr, else_expr)| {
271-
Ok(when_then_expr
272-
.into_iter()
273-
.map_until_stop_and_collect(|(when, then)| {
274-
transform_box(when, &mut f)?
275-
.update_data(|new_when| (new_when, then))
276-
.try_transform_node(|(new_when, then)| {
277-
Ok(transform_box(then, &mut f)?
278-
.update_data(|new_then| (new_when, new_then)))
279-
})
280-
})?
281-
.update_data(|new_when_then_expr| {
282-
(new_expr, new_when_then_expr, else_expr)
283-
}))
284-
})?
285-
.try_transform_node(|(new_expr, new_when_then_expr, else_expr)| {
286-
Ok(transform_option_box(else_expr, &mut f)?.update_data(
287-
|new_else_expr| (new_expr, new_when_then_expr, new_else_expr),
288-
))
289-
})?
290-
.update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
291-
Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
292-
}),
264+
}) => map_until_stop_and_collect!(
265+
transform_option_box(expr, &mut f),
266+
when_then_expr,
267+
when_then_expr
268+
.into_iter()
269+
.map_until_stop_and_collect(|(when, then)| {
270+
map_until_stop_and_collect!(
271+
transform_box(when, &mut f),
272+
then,
273+
transform_box(then, &mut f)
274+
)
275+
}),
276+
else_expr,
277+
transform_option_box(else_expr, &mut f)
278+
)?
279+
.update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
280+
Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
281+
}),
293282
Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut f)?
294283
.update_data(|be| Expr::Cast(Cast::new(be, data_type))),
295284
Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)?
@@ -320,48 +309,39 @@ impl TreeNode for Expr {
320309
order_by,
321310
window_frame,
322311
null_treatment,
323-
}) => transform_vec(args, &mut f)?
324-
.update_data(|new_args| (new_args, partition_by, order_by))
325-
.try_transform_node(|(new_args, partition_by, order_by)| {
326-
Ok(transform_vec(partition_by, &mut f)?.update_data(
327-
|new_partition_by| (new_args, new_partition_by, order_by),
328-
))
329-
})?
330-
.try_transform_node(|(new_args, new_partition_by, order_by)| {
331-
Ok(
332-
transform_vec(order_by, &mut f)?.update_data(|new_order_by| {
333-
(new_args, new_partition_by, new_order_by)
334-
}),
335-
)
336-
})?
337-
.update_data(|(new_args, new_partition_by, new_order_by)| {
338-
Expr::WindowFunction(WindowFunction::new(
339-
fun,
340-
new_args,
341-
new_partition_by,
342-
new_order_by,
343-
window_frame,
344-
null_treatment,
345-
))
346-
}),
312+
}) => map_until_stop_and_collect!(
313+
transform_vec(args, &mut f),
314+
partition_by,
315+
transform_vec(partition_by, &mut f),
316+
order_by,
317+
transform_vec(order_by, &mut f)
318+
)?
319+
.update_data(|(new_args, new_partition_by, new_order_by)| {
320+
Expr::WindowFunction(WindowFunction::new(
321+
fun,
322+
new_args,
323+
new_partition_by,
324+
new_order_by,
325+
window_frame,
326+
null_treatment,
327+
))
328+
}),
347329
Expr::AggregateFunction(AggregateFunction {
348330
args,
349331
func_def,
350332
distinct,
351333
filter,
352334
order_by,
353335
null_treatment,
354-
}) => transform_vec(args, &mut f)?
355-
.update_data(|new_args| (new_args, filter, order_by))
356-
.try_transform_node(|(new_args, filter, order_by)| {
357-
Ok(transform_option_box(filter, &mut f)?
358-
.update_data(|new_filter| (new_args, new_filter, order_by)))
359-
})?
360-
.try_transform_node(|(new_args, new_filter, order_by)| {
361-
Ok(transform_option_vec(order_by, &mut f)?
362-
.update_data(|new_order_by| (new_args, new_filter, new_order_by)))
363-
})?
364-
.map_data(|(new_args, new_filter, new_order_by)| match func_def {
336+
}) => map_until_stop_and_collect!(
337+
transform_vec(args, &mut f),
338+
filter,
339+
transform_option_box(filter, &mut f),
340+
order_by,
341+
transform_option_vec(order_by, &mut f)
342+
)?
343+
.map_data(
344+
|(new_args, new_filter, new_order_by)| match func_def {
365345
AggregateFunctionDefinition::BuiltIn(fun) => {
366346
Ok(Expr::AggregateFunction(AggregateFunction::new(
367347
fun,
@@ -385,7 +365,8 @@ impl TreeNode for Expr {
385365
AggregateFunctionDefinition::Name(_) => {
386366
internal_err!("Function `Expr` with name should be resolved.")
387367
}
388-
})?,
368+
},
369+
)?,
389370
Expr::GroupingSet(grouping_set) => match grouping_set {
390371
GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)?
391372
.update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
@@ -402,15 +383,14 @@ impl TreeNode for Expr {
402383
expr,
403384
list,
404385
negated,
405-
}) => transform_box(expr, &mut f)?
406-
.update_data(|new_expr| (new_expr, list))
407-
.try_transform_node(|(new_expr, list)| {
408-
Ok(transform_vec(list, &mut f)?
409-
.update_data(|new_list| (new_expr, new_list)))
410-
})?
411-
.update_data(|(new_expr, new_list)| {
412-
Expr::InList(InList::new(new_expr, new_list, negated))
413-
}),
386+
}) => map_until_stop_and_collect!(
387+
transform_box(expr, &mut f),
388+
list,
389+
transform_vec(list, &mut f)
390+
)?
391+
.update_data(|(new_expr, new_list)| {
392+
Expr::InList(InList::new(new_expr, new_list, negated))
393+
}),
414394
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
415395
transform_box(expr, &mut f)?.update_data(|be| {
416396
Expr::GetIndexedField(GetIndexedField::new(be, field))

0 commit comments

Comments
 (0)