From 693639444bbd97b901c6a7f7325f88450737eb1b Mon Sep 17 00:00:00 2001 From: landa Date: Fri, 16 May 2025 17:38:02 +0300 Subject: [PATCH] Fix: common_sub_expression_eliminate optimizer rule failed Common_sub_expression_eliminate rule failed with error: `SchemaError(FieldNotFound {field: }, valid_fields: []})` due to the schema being changed by the second application of `find_common_exprs` As I understood the source of the problem was in sequential call of `find_common_exprs`. First call returned original names as `aggr_expr` and changed names as `new_aggr_expr`. Second call takes into account only `new_aggr_expr` and if names was already changed by first call will return changed names as `aggr_expr`(original ones) and put them into Projection logic. I used NamePreserver mechanism to restore original schema names and generate Projection with original name at the end of aggregate optimization. --- .../optimizer/src/common_subexpr_eliminate.rs | 26 +++++++++++++++++-- .../sqllogictest/test_files/aggregate.slt | 25 ++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index d526b63ae5d2..63fb46fa6383 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -316,6 +316,19 @@ impl CommonSubexprEliminate { } => { let rewritten_aggr_expr = new_exprs_list.pop().unwrap(); let new_aggr_expr = original_exprs_list.pop().unwrap(); + let saved_names = if let Some(aggr_expr) = aggr_expr { + let name_preserver = NamePreserver::new_for_projection(); + aggr_expr + .iter() + .map(|expr| Some(name_preserver.save(expr))) + .collect::>() + } else { + new_aggr_expr + .clone() + .into_iter() + .map(|_| None) + .collect::>() + }; let mut agg_exprs = common_exprs .into_iter() @@ -326,10 +339,19 @@ impl CommonSubexprEliminate { for expr in &new_group_expr { extract_expressions(expr, &mut proj_exprs) } - for (expr_rewritten, expr_orig) in - rewritten_aggr_expr.into_iter().zip(new_aggr_expr) + for ((expr_rewritten, expr_orig), saved_name) in + rewritten_aggr_expr + .into_iter() + .zip(new_aggr_expr) + .zip(saved_names) { if expr_rewritten == expr_orig { + let expr_rewritten = if let Some(saved_name) = saved_name + { + saved_name.restore(expr_rewritten) + } else { + expr_rewritten + }; if let Expr::Alias(Alias { expr, name, .. }) = expr_rewritten { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 19f92ed72e0b..3760f251dabd 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -690,6 +690,31 @@ SELECT c2, var_samp(c12) FILTER (WHERE c12 > 0.95) FROM aggregate_test_100 GROUP 4 NULL 5 NULL +statement ok +CREATE TABLE t ( + a DOUBLE, + b BIGINT, + c INT +) AS VALUES +(1.0, 10, -5), +(2.0, 20, -5), +(3.0, 20, 4); + +# https://github.com/apache/datafusion/issues/15291 +query III +WITH s AS ( + SELECT + COUNT(a) FILTER (WHERE (b * b) - 3600 <= b), + COUNT(a) FILTER (WHERE (b * b) - 3000 <= b AND (c >= 0)), + COUNT(a) FILTER (WHERE (b * b) - 3000 <= b AND (c >= 0) AND (c >= 0)) + FROM t +) SELECT * FROM s +---- +3 1 1 + +statement ok +DROP TABLE t + # Restore the default dialect statement ok set datafusion.sql_parser.dialect = 'Generic';