Skip to content

Commit f0e96c6

Browse files
feat: run expression simplifier in a loop until a fixedpoint or 3 cycles (#10358)
* feat: run expression simplifier in a loop * change max_simplifier_iterations to u32 * use simplify_inner to explicitly test iteration count * refactor simplify_inner loop * const evaluator should return transformed=false on literals * update tests * Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs Co-authored-by: Andrew Lamb <[email protected]> * run shorten_in_list_simplifier once at the end of the loop * move UDF test case to core integration tests * documentation and naming updates * documentation and naming updates * remove unused import and minor doc formatting change * Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 9fd697c commit f0e96c6

File tree

2 files changed

+182
-24
lines changed

2 files changed

+182
-24
lines changed

datafusion/core/tests/simplification.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,29 @@ fn test_simplify(input_expr: Expr, expected_expr: Expr) {
508508
"Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}"
509509
);
510510
}
511+
fn test_simplify_with_cycle_count(
512+
input_expr: Expr,
513+
expected_expr: Expr,
514+
expected_count: u32,
515+
) {
516+
let info: MyInfo = MyInfo {
517+
schema: expr_test_schema(),
518+
execution_props: ExecutionProps::new(),
519+
};
520+
let simplifier = ExprSimplifier::new(info);
521+
let (simplified_expr, count) = simplifier
522+
.simplify_with_cycle_count(input_expr.clone())
523+
.expect("successfully evaluated");
524+
525+
assert_eq!(
526+
simplified_expr, expected_expr,
527+
"Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}"
528+
);
529+
assert_eq!(
530+
count, expected_count,
531+
"Mismatch simplifier cycle count\n Expected: {expected_count}\n Got:{count}"
532+
);
533+
}
511534

512535
#[test]
513536
fn test_simplify_log() {
@@ -658,3 +681,11 @@ fn test_simplify_concat() {
658681
let expected = concat(vec![col("c0"), lit("hello rust"), col("c1")]);
659682
test_simplify(expr, expected)
660683
}
684+
#[test]
685+
fn test_simplify_cycles() {
686+
// cast(now() as int64) < cast(to_timestamp(0) as int64) + i64::MAX
687+
let expr = cast(now(), DataType::Int64)
688+
.lt(cast(to_timestamp(vec![lit(0)]), DataType::Int64) + lit(i64::MAX));
689+
let expected = lit(true);
690+
test_simplify_with_cycle_count(expr, expected, 3);
691+
}

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 151 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,12 @@ pub struct ExprSimplifier<S> {
9292
/// Should expressions be canonicalized before simplification? Defaults to
9393
/// true
9494
canonicalize: bool,
95+
/// Maximum number of simplifier cycles
96+
max_simplifier_cycles: u32,
9597
}
9698

9799
pub const THRESHOLD_INLINE_INLIST: usize = 3;
100+
pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3;
98101

99102
impl<S: SimplifyInfo> ExprSimplifier<S> {
100103
/// Create a new `ExprSimplifier` with the given `info` such as an
@@ -107,10 +110,11 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
107110
info,
108111
guarantees: vec![],
109112
canonicalize: true,
113+
max_simplifier_cycles: DEFAULT_MAX_SIMPLIFIER_CYCLES,
110114
}
111115
}
112116

113-
/// Simplifies this [`Expr`]`s as much as possible, evaluating
117+
/// Simplifies this [`Expr`] as much as possible, evaluating
114118
/// constants and applying algebraic simplifications.
115119
///
116120
/// The types of the expression must match what operators expect,
@@ -171,7 +175,18 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
171175
/// let expr = simplifier.simplify(expr).unwrap();
172176
/// assert_eq!(expr, b_lt_2);
173177
/// ```
174-
pub fn simplify(&self, mut expr: Expr) -> Result<Expr> {
178+
pub fn simplify(&self, expr: Expr) -> Result<Expr> {
179+
Ok(self.simplify_with_cycle_count(expr)?.0)
180+
}
181+
182+
/// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
183+
/// constants and applying algebraic simplifications. Additionally returns a `u32`
184+
/// representing the number of simplification cycles performed, which can be useful for testing
185+
/// optimizations.
186+
///
187+
/// See [Self::simplify] for details and usage examples.
188+
///
189+
pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> {
175190
let mut simplifier = Simplifier::new(&self.info);
176191
let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
177192
let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
@@ -181,24 +196,26 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
181196
expr = expr.rewrite(&mut Canonicalizer::new()).data()?
182197
}
183198

184-
// TODO iterate until no changes are made during rewrite
185-
// (evaluating constants can enable new simplifications and
186-
// simplifications can enable new constant evaluation)
187-
// https://github.com/apache/datafusion/issues/1160
188-
expr.rewrite(&mut const_evaluator)
189-
.data()?
190-
.rewrite(&mut simplifier)
191-
.data()?
192-
.rewrite(&mut guarantee_rewriter)
193-
.data()?
194-
// run both passes twice to try and minimize simplifications that we missed
195-
.rewrite(&mut const_evaluator)
196-
.data()?
197-
.rewrite(&mut simplifier)
198-
.data()?
199-
// shorten inlist should be started after other inlist rules are applied
200-
.rewrite(&mut shorten_in_list_simplifier)
201-
.data()
199+
// Evaluating constants can enable new simplifications and
200+
// simplifications can enable new constant evaluation
201+
// see `Self::with_max_cycles`
202+
let mut num_cycles = 0;
203+
loop {
204+
let Transformed {
205+
data, transformed, ..
206+
} = expr
207+
.rewrite(&mut const_evaluator)?
208+
.transform_data(|expr| expr.rewrite(&mut simplifier))?
209+
.transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?;
210+
expr = data;
211+
num_cycles += 1;
212+
if !transformed || num_cycles >= self.max_simplifier_cycles {
213+
break;
214+
}
215+
}
216+
// shorten inlist should be started after other inlist rules are applied
217+
expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;
218+
Ok((expr, num_cycles))
202219
}
203220

204221
/// Apply type coercion to an [`Expr`] so that it can be
@@ -323,6 +340,63 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
323340
self.canonicalize = canonicalize;
324341
self
325342
}
343+
344+
/// Specifies the maximum number of simplification cycles to run.
345+
///
346+
/// The simplifier can perform multiple passes of simplification. This is
347+
/// because the output of one simplification step can allow more optimizations
348+
/// in another simplification step. For example, constant evaluation can allow more
349+
/// expression simplifications, and expression simplifications can allow more constant
350+
/// evaluations.
351+
///
352+
/// This method specifies the maximum number of allowed iteration cycles before the simplifier
353+
/// returns an [Expr] output. However, it does not always perform the maximum number of cycles.
354+
/// The simplifier will attempt to detect when an [Expr] is unchanged by all the simplification
355+
/// passes, and return early. This avoids wasting time on unnecessary [Expr] tree traversals.
356+
///
357+
/// If no maximum is specified, the value of [DEFAULT_MAX_SIMPLIFIER_CYCLES] is used
358+
/// instead.
359+
///
360+
/// ```rust
361+
/// use arrow::datatypes::{DataType, Field, Schema};
362+
/// use datafusion_expr::{col, lit, Expr};
363+
/// use datafusion_common::{Result, ScalarValue, ToDFSchema};
364+
/// use datafusion_expr::execution_props::ExecutionProps;
365+
/// use datafusion_expr::simplify::SimplifyContext;
366+
/// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
367+
///
368+
/// let schema = Schema::new(vec![
369+
/// Field::new("a", DataType::Int64, false),
370+
/// ])
371+
/// .to_dfschema_ref().unwrap();
372+
///
373+
/// // Create the simplifier
374+
/// let props = ExecutionProps::new();
375+
/// let context = SimplifyContext::new(&props)
376+
/// .with_schema(schema);
377+
/// let simplifier = ExprSimplifier::new(context);
378+
///
379+
/// // Expression: a IS NOT NULL
380+
/// let expr = col("a").is_not_null();
381+
///
382+
/// // When using default maximum cycles, 2 cycles will be performed.
383+
/// let (simplified_expr, count) = simplifier.simplify_with_cycle_count(expr.clone()).unwrap();
384+
/// assert_eq!(simplified_expr, lit(true));
385+
/// // 2 cycles were executed, but only 1 was needed
386+
/// assert_eq!(count, 2);
387+
///
388+
/// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1.
389+
/// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count(expr.clone()).unwrap();
390+
/// // Expression has been simplified to: true
391+
/// assert_eq!(simplified_expr, lit(true));
392+
/// // Only 1 cycle was executed
393+
/// assert_eq!(count, 1);
394+
///
395+
/// ```
396+
pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self {
397+
self.max_simplifier_cycles = max_simplifier_cycles;
398+
self
399+
}
326400
}
327401

328402
/// Canonicalize any BinaryExprs that are not in canonical form
@@ -404,6 +478,8 @@ struct ConstEvaluator<'a> {
404478
enum ConstSimplifyResult {
405479
// Expr was simplified and contains the new expression
406480
Simplified(ScalarValue),
481+
// Expr was not simplified and original value is returned
482+
NotSimplified(ScalarValue),
407483
// Evaluation encountered an error, contains the original expression
408484
SimplifyRuntimeError(DataFusionError, Expr),
409485
}
@@ -450,6 +526,9 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> {
450526
ConstSimplifyResult::Simplified(s) => {
451527
Ok(Transformed::yes(Expr::Literal(s)))
452528
}
529+
ConstSimplifyResult::NotSimplified(s) => {
530+
Ok(Transformed::no(Expr::Literal(s)))
531+
}
453532
ConstSimplifyResult::SimplifyRuntimeError(_, expr) => {
454533
Ok(Transformed::yes(expr))
455534
}
@@ -548,7 +627,7 @@ impl<'a> ConstEvaluator<'a> {
548627
/// Internal helper to evaluate an Expr
549628
pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult {
550629
if let Expr::Literal(s) = expr {
551-
return ConstSimplifyResult::Simplified(s);
630+
return ConstSimplifyResult::NotSimplified(s);
552631
}
553632

554633
let phys_expr =
@@ -1672,15 +1751,14 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
16721751

16731752
#[cfg(test)]
16741753
mod tests {
1754+
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
1755+
use datafusion_expr::{interval_arithmetic::Interval, *};
16751756
use std::{
16761757
collections::HashMap,
16771758
ops::{BitAnd, BitOr, BitXor},
16781759
sync::Arc,
16791760
};
16801761

1681-
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
1682-
use datafusion_expr::{interval_arithmetic::Interval, *};
1683-
16841762
use crate::simplify_expressions::SimplifyContext;
16851763
use crate::test::test_table_scan_with_name;
16861764

@@ -2868,6 +2946,19 @@ mod tests {
28682946
try_simplify(expr).unwrap()
28692947
}
28702948

2949+
fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> {
2950+
let schema = expr_test_schema();
2951+
let execution_props = ExecutionProps::new();
2952+
let simplifier = ExprSimplifier::new(
2953+
SimplifyContext::new(&execution_props).with_schema(schema),
2954+
);
2955+
simplifier.simplify_with_cycle_count(expr)
2956+
}
2957+
2958+
fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) {
2959+
try_simplify_with_cycle_count(expr).unwrap()
2960+
}
2961+
28712962
fn simplify_with_guarantee(
28722963
expr: Expr,
28732964
guarantees: Vec<(Expr, NullableInterval)>,
@@ -3575,4 +3666,40 @@ mod tests {
35753666

35763667
assert_eq!(simplify(expr), expected);
35773668
}
3669+
3670+
#[test]
3671+
fn test_simplify_cycles() {
3672+
// TRUE
3673+
let expr = lit(true);
3674+
let expected = lit(true);
3675+
let (expr, num_iter) = simplify_with_cycle_count(expr);
3676+
assert_eq!(expr, expected);
3677+
assert_eq!(num_iter, 1);
3678+
3679+
// (true != NULL) OR (5 > 10)
3680+
let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10)));
3681+
let expected = lit_bool_null();
3682+
let (expr, num_iter) = simplify_with_cycle_count(expr);
3683+
assert_eq!(expr, expected);
3684+
assert_eq!(num_iter, 2);
3685+
3686+
// NOTE: this currently does not simplify
3687+
// (((c4 - 10) + 10) *100) / 100
3688+
let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100);
3689+
let expected = expr.clone();
3690+
let (expr, num_iter) = simplify_with_cycle_count(expr);
3691+
assert_eq!(expr, expected);
3692+
assert_eq!(num_iter, 1);
3693+
3694+
// ((c4<1 or c3<2) and c3_non_null<3) and false
3695+
let expr = col("c4")
3696+
.lt(lit(1))
3697+
.or(col("c3").lt(lit(2)))
3698+
.and(col("c3_non_null").lt(lit(3)))
3699+
.and(lit(false));
3700+
let expected = lit(false);
3701+
let (expr, num_iter) = simplify_with_cycle_count(expr);
3702+
assert_eq!(expr, expected);
3703+
assert_eq!(num_iter, 2);
3704+
}
35783705
}

0 commit comments

Comments
 (0)