From 808d6ab3ceb0281d055965a330b8ffb1c47fa65b Mon Sep 17 00:00:00 2001 From: alan910127 <70696274+alan910127@users.noreply.github.com> Date: Sun, 9 Mar 2025 22:16:46 +0800 Subject: [PATCH 01/11] perf: unwrap cast for comparing ints =/!= strings --- .../simplify_expressions/expr_simplifier.rs | 4 +- .../src/simplify_expressions/unwrap_cast.rs | 66 +++++++++++++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index d5a1b84e6aff..fabba0eca5cc 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1758,7 +1758,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // try_cast/cast(expr as data_type) op literal Expr::BinaryExpr(BinaryExpr { left, op, right }) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary( - info, &left, &right, + info, &left, &right, op, ) && op.supports_propagation() => { unwrap_cast_in_comparison_for_binary(info, left, right, op)? 
@@ -1768,7 +1768,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // try_cast/cast(expr as data_type) op_swap literal Expr::BinaryExpr(BinaryExpr { left, op, right }) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary( - info, &right, &left, + info, &right, &left, op, ) && op.supports_propagation() && op.swap().is_some() => { diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index 7670bdf98bb4..29c19cfa850c 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -81,6 +81,16 @@ pub(super) fn unwrap_cast_in_comparison_for_binary( let Ok(expr_type) = info.get_data_type(&expr) else { return internal_err!("Can't get the data type of the expr {:?}", &expr); }; + + if let Some(value) = cast_literal_to_type_with_op(&lit_value, &expr_type, op) + { + return Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { + left: expr, + op, + right: Box::new(lit(value)), + }))); + }; + // if the lit_value can be casted to the type of internal_left_expr // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal let Some(value) = try_cast_literal_to_type(&lit_value, &expr_type) else { @@ -106,6 +116,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary< info: &S, expr: &Expr, literal: &Expr, + op: Operator, ) -> bool { match (expr, literal) { ( @@ -125,6 +136,10 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary< return false; }; + if cast_literal_to_type_with_op(lit_val, &expr_type, op).is_some() { + return true; + } + try_cast_literal_to_type(lit_val, &expr_type).is_some() && is_supported_type(&expr_type) && is_supported_type(&lit_type) @@ -177,6 +192,33 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist< true } +fn cast_literal_to_type_with_op( + lit_value: &ScalarValue, + target_type: &DataType, 
+ op: Operator, +) -> Option { + dbg!(lit_value, target_type, op); + match (op, lit_value) { + ( + Operator::Eq | Operator::NotEq, + ScalarValue::Utf8(Some(ref str)) + | ScalarValue::Utf8View(Some(ref str)) + | ScalarValue::LargeUtf8(Some(ref str)), + ) => match target_type { + DataType::Int8 => str.parse::().ok().map(ScalarValue::from), + DataType::Int16 => str.parse::().ok().map(ScalarValue::from), + DataType::Int32 => str.parse::().ok().map(ScalarValue::from), + DataType::Int64 => str.parse::().ok().map(ScalarValue::from), + DataType::UInt8 => str.parse::().ok().map(ScalarValue::from), + DataType::UInt16 => str.parse::().ok().map(ScalarValue::from), + DataType::UInt32 => str.parse::().ok().map(ScalarValue::from), + DataType::UInt64 => str.parse::().ok().map(ScalarValue::from), + _ => None, + }, + _ => None, + } +} + /// Returns true if unwrap_cast_in_comparison supports this data type fn is_supported_type(data_type: &DataType) -> bool { is_supported_numeric_type(data_type) @@ -468,6 +510,10 @@ mod tests { // the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(99999999999i64)); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + + // cast(c1, UTF8) < 123, only eq/not_eq should be optimized + let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("123")); + assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); } #[test] @@ -496,6 +542,16 @@ mod tests { let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32)); let expected = null_bool(); assert_eq!(optimize_test(lit_lt_lit, &schema), expected); + + // cast(c1, UTF8) = "123" => c1 = 123 + let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("123")); + let expected = col("c1").eq(lit(123i32)); + assert_eq!(optimize_test(expr_input, &schema), expected); + + // cast(c1, UTF8) != "123" => c1 != 123 + let expr_input = cast(col("c1"), DataType::Utf8).not_eq(lit("123")); + 
let expected = col("c1").not_eq(lit(123i32)); + assert_eq!(optimize_test(expr_input, &schema), expected); } #[test] @@ -505,6 +561,16 @@ mod tests { let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64)); let expected = col("c6").eq(lit(0u32)); assert_eq!(optimize_test(expr_input, &schema), expected); + + // cast(c6, UTF8) = "123" => c6 = 123 + let expr_input = cast(col("c6"), DataType::Utf8).eq(lit("123")); + let expected = col("c6").eq(lit(123u32)); + assert_eq!(optimize_test(expr_input, &schema), expected); + + // cast(c6, UTF8) != "123" => c6 != 123 + let expr_input = cast(col("c6"), DataType::Utf8).not_eq(lit("123")); + let expected = col("c6").not_eq(lit(123u32)); + assert_eq!(optimize_test(expr_input, &schema), expected); } #[test] From dd71241c103943736a21fbd7aa75e46c83566461 Mon Sep 17 00:00:00 2001 From: alan910127 <70696274+alan910127@users.noreply.github.com> Date: Mon, 10 Mar 2025 21:59:43 +0800 Subject: [PATCH 02/11] fix: update casting logic --- .../src/simplify_expressions/unwrap_cast.rs | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index 29c19cfa850c..5ce211f5e0f8 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -197,22 +197,34 @@ fn cast_literal_to_type_with_op( target_type: &DataType, op: Operator, ) -> Option { - dbg!(lit_value, target_type, op); + macro_rules! 
cast_or_else_return_none { + ($value:ident, $ty:expr) => {{ + let opts = arrow::compute::CastOptions { + safe: false, + format_options: Default::default(), + }; + let array = ScalarValue::to_array($value).ok()?; + let casted = arrow::compute::cast_with_options(&array, &$ty, &opts).ok()?; + let scalar = ScalarValue::try_from_array(&casted, 0).ok()?; + Some(scalar) + }}; + } + match (op, lit_value) { ( Operator::Eq | Operator::NotEq, - ScalarValue::Utf8(Some(ref str)) - | ScalarValue::Utf8View(Some(ref str)) - | ScalarValue::LargeUtf8(Some(ref str)), + ScalarValue::Utf8(Some(_)) + | ScalarValue::Utf8View(Some(_)) + | ScalarValue::LargeUtf8(Some(_)), ) => match target_type { - DataType::Int8 => str.parse::().ok().map(ScalarValue::from), - DataType::Int16 => str.parse::().ok().map(ScalarValue::from), - DataType::Int32 => str.parse::().ok().map(ScalarValue::from), - DataType::Int64 => str.parse::().ok().map(ScalarValue::from), - DataType::UInt8 => str.parse::().ok().map(ScalarValue::from), - DataType::UInt16 => str.parse::().ok().map(ScalarValue::from), - DataType::UInt32 => str.parse::().ok().map(ScalarValue::from), - DataType::UInt64 => str.parse::().ok().map(ScalarValue::from), + DataType::Int8 => cast_or_else_return_none!(lit_value, DataType::Int8), + DataType::Int16 => cast_or_else_return_none!(lit_value, DataType::Int16), + DataType::Int32 => cast_or_else_return_none!(lit_value, DataType::Int32), + DataType::Int64 => cast_or_else_return_none!(lit_value, DataType::Int64), + DataType::UInt8 => cast_or_else_return_none!(lit_value, DataType::UInt8), + DataType::UInt16 => cast_or_else_return_none!(lit_value, DataType::UInt16), + DataType::UInt32 => cast_or_else_return_none!(lit_value, DataType::UInt32), + DataType::UInt64 => cast_or_else_return_none!(lit_value, DataType::UInt64), _ => None, }, _ => None, From dc93afeb708b83a6c55c994f1636311af483675d Mon Sep 17 00:00:00 2001 From: alan910127 <70696274+alan910127@users.noreply.github.com> Date: Mon, 10 Mar 2025 22:00:02 
+0800 Subject: [PATCH 03/11] test: add more unit test and new sqllogictest --- .../src/simplify_expressions/unwrap_cast.rs | 14 ++++++++++++ .../sqllogictest/test_files/explain.slt | 22 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index 5ce211f5e0f8..5545d3741fab 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -526,6 +526,15 @@ mod tests { // cast(c1, UTF8) < 123, only eq/not_eq should be optimized let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("123")); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + + // cast(c1, UTF8) = 'not a number', should not be able to cast to column type + let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("not a number")); + assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); + + // cast(c1, UTF8) = '99999999999', where '99999999999' does not fit into int32, so it will + // not be optimized to integer comparison + let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("99999999999")); + assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); } #[test] @@ -564,6 +573,11 @@ mod tests { let expr_input = cast(col("c1"), DataType::Utf8).not_eq(lit("123")); let expected = col("c1").not_eq(lit(123i32)); assert_eq!(optimize_test(expr_input, &schema), expected); + + // cast(c1, UTF8) = NULL => c1 = NULL + let expr_input = cast(col("c1"), DataType::Utf8).eq(lit(ScalarValue::Utf8(None))); + let expected = col("c1").eq(lit(ScalarValue::Int32(None))); + assert_eq!(optimize_test(expr_input, &schema), expected); } #[test] diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index cab7308f6ff8..adc25b98dbb0 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ 
b/datafusion/sqllogictest/test_files/explain.slt @@ -430,6 +430,28 @@ physical_plan 03)--ProjectionExec: expr=[] 04)----PlaceholderRowExec +query TT +explain select a from t1 where a = '100'; +---- +logical_plan +01)Filter: t1.a = Int32(100) +02)--TableScan: t1 projection=[a] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: a@0 = 100 +03)----DataSourceExec: partitions=1, partition_sizes=[0] + +query TT +explain select a from t1 where a = '99999999999'; +---- +logical_plan +01)Filter: CAST(t1.a AS Utf8) = Utf8("99999999999") +02)--TableScan: t1 projection=[a] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: CAST(a@0 AS Utf8) = 99999999999 +03)----DataSourceExec: partitions=1, partition_sizes=[0] + statement ok drop table t1; From 8fab722c2d8eb8fccbf857a306e835e7ca9665b5 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 10 Mar 2025 14:43:44 -0400 Subject: [PATCH 04/11] Tweak slt tests --- .../sqllogictest/test_files/explain.slt | 22 -------- .../test_files/push_down_filter.slt | 52 +++++++++++++++++++ 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index adc25b98dbb0..cab7308f6ff8 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -430,28 +430,6 @@ physical_plan 03)--ProjectionExec: expr=[] 04)----PlaceholderRowExec -query TT -explain select a from t1 where a = '100'; ----- -logical_plan -01)Filter: t1.a = Int32(100) -02)--TableScan: t1 projection=[a] -physical_plan -01)CoalesceBatchesExec: target_batch_size=8192 -02)--FilterExec: a@0 = 100 -03)----DataSourceExec: partitions=1, partition_sizes=[0] - -query TT -explain select a from t1 where a = '99999999999'; ----- -logical_plan -01)Filter: CAST(t1.a AS Utf8) = Utf8("99999999999") -02)--TableScan: t1 projection=[a] -physical_plan -01)CoalesceBatchesExec: target_batch_size=8192 
-02)--FilterExec: CAST(a@0 AS Utf8) = 99999999999 -03)----DataSourceExec: partitions=1, partition_sizes=[0] - statement ok drop table t1; diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt index 521aa3340981..180737d0cab3 100644 --- a/datafusion/sqllogictest/test_files/push_down_filter.slt +++ b/datafusion/sqllogictest/test_files/push_down_filter.slt @@ -188,6 +188,7 @@ select * from test_filter_with_limit where value = 2 limit 1; ---- 2 2 + # Tear down test_filter_with_limit table: statement ok DROP TABLE test_filter_with_limit; @@ -195,3 +196,54 @@ DROP TABLE test_filter_with_limit; # Tear down src_table table: statement ok DROP TABLE src_table; + + +query I +COPY (VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10)) +TO 'test_files/scratch/push_down_filter/t.parquet' +STORED AS PARQUET; +---- +10 + +statement ok +CREATE EXTERNAL TABLE t +( + a INT +) +STORED AS PARQUET +LOCATION 'test_files/scratch/push_down_filter/t.parquet'; + + +# The predicate should not have a column cast when the value is a valid i32 +query TT +explain select a from t where a = '100'; +---- +logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)] + +# The predicate should not have a column cast when the value is a valid i32 +query TT +explain select a from t where a != '100'; +---- +logical_plan TableScan: t projection=[a], full_filters=[t.a != Int32(100)] + +# The predicate should still have the column cast when the value is a NOT valid i32 +query TT +explain select a from t where a = '99999999999'; +---- +logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99999999999")] + +# The predicate should still have the column cast when the value is a NOT valid i32 +query TT +explain select a from t where a = '99.99'; +---- +logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99.99")] + +# The predicate should still have the column cast 
when the value is a NOT valid i32 +query TT +explain select a from t where a = ''; +---- +logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("")] + + +statement ok +drop table t; From a878b07c6f6974a185c97f222637ff0b933d5d63 Mon Sep 17 00:00:00 2001 From: alan910127 <70696274+alan910127@users.noreply.github.com> Date: Wed, 12 Mar 2025 00:32:39 +0800 Subject: [PATCH 05/11] Revert "perf: unwrap cast for comparing ints =/!= strings" This reverts commit 808d6ab3ceb0281d055965a330b8ffb1c47fa65b. --- .../simplify_expressions/expr_simplifier.rs | 4 +- .../src/simplify_expressions/unwrap_cast.rs | 92 ------------------- 2 files changed, 2 insertions(+), 94 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index fabba0eca5cc..d5a1b84e6aff 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1758,7 +1758,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // try_cast/cast(expr as data_type) op literal Expr::BinaryExpr(BinaryExpr { left, op, right }) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary( - info, &left, &right, op, + info, &left, &right, ) && op.supports_propagation() => { unwrap_cast_in_comparison_for_binary(info, left, right, op)? 
@@ -1768,7 +1768,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // try_cast/cast(expr as data_type) op_swap literal Expr::BinaryExpr(BinaryExpr { left, op, right }) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary( - info, &right, &left, op, + info, &right, &left, ) && op.supports_propagation() && op.swap().is_some() => { diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index 5545d3741fab..7670bdf98bb4 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -81,16 +81,6 @@ pub(super) fn unwrap_cast_in_comparison_for_binary( let Ok(expr_type) = info.get_data_type(&expr) else { return internal_err!("Can't get the data type of the expr {:?}", &expr); }; - - if let Some(value) = cast_literal_to_type_with_op(&lit_value, &expr_type, op) - { - return Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { - left: expr, - op, - right: Box::new(lit(value)), - }))); - }; - // if the lit_value can be casted to the type of internal_left_expr // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal let Some(value) = try_cast_literal_to_type(&lit_value, &expr_type) else { @@ -116,7 +106,6 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary< info: &S, expr: &Expr, literal: &Expr, - op: Operator, ) -> bool { match (expr, literal) { ( @@ -136,10 +125,6 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary< return false; }; - if cast_literal_to_type_with_op(lit_val, &expr_type, op).is_some() { - return true; - } - try_cast_literal_to_type(lit_val, &expr_type).is_some() && is_supported_type(&expr_type) && is_supported_type(&lit_type) @@ -192,45 +177,6 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist< true } -fn cast_literal_to_type_with_op( - lit_value: &ScalarValue, - target_type: &DataType, 
- op: Operator, -) -> Option { - macro_rules! cast_or_else_return_none { - ($value:ident, $ty:expr) => {{ - let opts = arrow::compute::CastOptions { - safe: false, - format_options: Default::default(), - }; - let array = ScalarValue::to_array($value).ok()?; - let casted = arrow::compute::cast_with_options(&array, &$ty, &opts).ok()?; - let scalar = ScalarValue::try_from_array(&casted, 0).ok()?; - Some(scalar) - }}; - } - - match (op, lit_value) { - ( - Operator::Eq | Operator::NotEq, - ScalarValue::Utf8(Some(_)) - | ScalarValue::Utf8View(Some(_)) - | ScalarValue::LargeUtf8(Some(_)), - ) => match target_type { - DataType::Int8 => cast_or_else_return_none!(lit_value, DataType::Int8), - DataType::Int16 => cast_or_else_return_none!(lit_value, DataType::Int16), - DataType::Int32 => cast_or_else_return_none!(lit_value, DataType::Int32), - DataType::Int64 => cast_or_else_return_none!(lit_value, DataType::Int64), - DataType::UInt8 => cast_or_else_return_none!(lit_value, DataType::UInt8), - DataType::UInt16 => cast_or_else_return_none!(lit_value, DataType::UInt16), - DataType::UInt32 => cast_or_else_return_none!(lit_value, DataType::UInt32), - DataType::UInt64 => cast_or_else_return_none!(lit_value, DataType::UInt64), - _ => None, - }, - _ => None, - } -} - /// Returns true if unwrap_cast_in_comparison supports this data type fn is_supported_type(data_type: &DataType) -> bool { is_supported_numeric_type(data_type) @@ -522,19 +468,6 @@ mod tests { // the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(99999999999i64)); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); - - // cast(c1, UTF8) < 123, only eq/not_eq should be optimized - let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("123")); - assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); - - // cast(c1, UTF8) = 'not a number', should not be able to cast to column type - let 
expr_input = cast(col("c1"), DataType::Utf8).eq(lit("not a number")); - assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); - - // cast(c1, UTF8) = '99999999999', where '99999999999' does not fit into int32, so it will - // not be optimized to integer comparison - let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("99999999999")); - assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); } #[test] @@ -563,21 +496,6 @@ mod tests { let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32)); let expected = null_bool(); assert_eq!(optimize_test(lit_lt_lit, &schema), expected); - - // cast(c1, UTF8) = "123" => c1 = 123 - let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("123")); - let expected = col("c1").eq(lit(123i32)); - assert_eq!(optimize_test(expr_input, &schema), expected); - - // cast(c1, UTF8) != "123" => c1 != 123 - let expr_input = cast(col("c1"), DataType::Utf8).not_eq(lit("123")); - let expected = col("c1").not_eq(lit(123i32)); - assert_eq!(optimize_test(expr_input, &schema), expected); - - // cast(c1, UTF8) = NULL => c1 = NULL - let expr_input = cast(col("c1"), DataType::Utf8).eq(lit(ScalarValue::Utf8(None))); - let expected = col("c1").eq(lit(ScalarValue::Int32(None))); - assert_eq!(optimize_test(expr_input, &schema), expected); } #[test] @@ -587,16 +505,6 @@ mod tests { let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64)); let expected = col("c6").eq(lit(0u32)); assert_eq!(optimize_test(expr_input, &schema), expected); - - // cast(c6, UTF8) = "123" => c6 = 123 - let expr_input = cast(col("c6"), DataType::Utf8).eq(lit("123")); - let expected = col("c6").eq(lit(123u32)); - assert_eq!(optimize_test(expr_input, &schema), expected); - - // cast(c6, UTF8) != "123" => c6 != 123 - let expr_input = cast(col("c6"), DataType::Utf8).not_eq(lit("123")); - let expected = col("c6").not_eq(lit(123u32)); - assert_eq!(optimize_test(expr_input, &schema), expected); } #[test] From 
b212292bf6d4c6f431a501ed6d22cdc621a76cfa Mon Sep 17 00:00:00 2001 From: alan910127 <70696274+alan910127@users.noreply.github.com> Date: Wed, 12 Mar 2025 00:38:14 +0800 Subject: [PATCH 06/11] fix: eliminate column cast and cast literal before coercion --- .../optimizer/src/analyzer/type_coercion.rs | 53 +++++++++++++++++++ .../test_files/push_down_filter.slt | 5 ++ 2 files changed, 58 insertions(+) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 538ef98ac7be..1f24bd3fdeb2 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -287,12 +287,29 @@ impl<'a> TypeCoercionRewriter<'a> { right: Expr, right_schema: &DFSchema, ) -> Result<(Expr, Expr)> { + if let Expr::Literal(ref lit_value) = left { + if let Some(casted) = + try_cast_literal_to_type(lit_value, op, &right.get_type(right_schema)?) + { + return Ok((casted, right)); + }; + } + + if let Expr::Literal(ref lit_value) = right { + if let Some(casted) = + try_cast_literal_to_type(lit_value, op, &left.get_type(left_schema)?) + { + return Ok((left, casted)); + }; + } + let (left_type, right_type) = BinaryTypeCoercer::new( &left.get_type(left_schema)?, &op, &right.get_type(right_schema)?, ) .get_input_types()?; + Ok(( left.cast_to(&left_type, left_schema)?, right.cast_to(&right_type, right_schema)?, @@ -300,6 +317,42 @@ impl<'a> TypeCoercionRewriter<'a> { } } +fn try_cast_literal_to_type( + lit_value: &ScalarValue, + op: Operator, + target_type: &DataType, +) -> Option { + match (op, lit_value) { + ( + Operator::Eq | Operator::NotEq, + ScalarValue::Utf8(Some(_)) + | ScalarValue::Utf8View(Some(_)) + | ScalarValue::LargeUtf8(Some(_)), + ) => { + // Only try for integer types (TODO can we do this for other types like timestamps)? 
+ use DataType::*; + if matches!( + target_type, + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 + ) { + let opts = arrow::compute::CastOptions { + safe: false, + format_options: Default::default(), + }; + let array = ScalarValue::to_array(lit_value).ok()?; + let casted = + arrow::compute::cast_with_options(&array, target_type, &opts).ok()?; + ScalarValue::try_from_array(&casted, 0) + .ok() + .map(Expr::Literal) + } else { + None + } + } + _ => None, + } +} + impl TreeNodeRewriter for TypeCoercionRewriter<'_> { type Node = Expr; diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt index 180737d0cab3..3e5a493c2cd7 100644 --- a/datafusion/sqllogictest/test_files/push_down_filter.slt +++ b/datafusion/sqllogictest/test_files/push_down_filter.slt @@ -244,6 +244,11 @@ explain select a from t where a = ''; ---- logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("")] +# The predicate should still have the column cast when user explicitly cast the column +query TT +explain select a from t where cast(a as string) = '100'; +---- +logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("100")] statement ok drop table t; From 3059d5c215d45ea81939f07254d363e995b1b74f Mon Sep 17 00:00:00 2001 From: alan910127 <70696274+alan910127@users.noreply.github.com> Date: Wed, 12 Mar 2025 02:43:04 +0800 Subject: [PATCH 07/11] fix: physical expr coercion test --- datafusion/core/tests/expr_api/mod.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index aef10379da07..f771e929abc5 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -330,12 +330,12 @@ async fn test_create_physical_expr_coercion() { create_expr_test(lit(1i32).eq(col("id")), "CAST(1 AS Utf8) = id@0"); // compare int col to string literal 
`i = '202410'` // Note this casts the column (not the field) - create_expr_test(col("i").eq(lit("202410")), "CAST(i@1 AS Utf8) = 202410"); - create_expr_test(lit("202410").eq(col("i")), "202410 = CAST(i@1 AS Utf8)"); + create_expr_test(col("i").eq(lit("202410")), "i@1 = 202410"); + create_expr_test(lit("202410").eq(col("i")), "202410 = i@1"); // however, when simplified the casts on i should removed // https://github.com/apache/datafusion/issues/14944 - create_simplified_expr_test(col("i").eq(lit("202410")), "CAST(i@1 AS Utf8) = 202410"); - create_simplified_expr_test(lit("202410").eq(col("i")), "CAST(i@1 AS Utf8) = 202410"); + create_simplified_expr_test(col("i").eq(lit("202410")), "i@1 = 202410"); + create_simplified_expr_test(lit("202410").eq(col("i")), "i@1 = 202410"); } /// Evaluates the specified expr as an aggregate and compares the result to the From 0a4e95dc201382d8c1b8ae9c60d733317b631135 Mon Sep 17 00:00:00 2001 From: alan910127 <70696274+alan910127@users.noreply.github.com> Date: Mon, 17 Mar 2025 10:20:37 +0800 Subject: [PATCH 08/11] feat: unwrap cast after round-trip cast verification --- .../simplify_expressions/expr_simplifier.rs | 4 +- .../src/simplify_expressions/unwrap_cast.rs | 121 ++++++++++++++++++ .../test_files/push_down_filter.slt | 7 + 3 files changed, 130 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index d5a1b84e6aff..d27b53a2e09f 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1758,7 +1758,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // try_cast/cast(expr as data_type) op literal Expr::BinaryExpr(BinaryExpr { left, op, right }) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary( - info, &left, &right, + info, &left, op, &right, ) && op.supports_propagation() => { 
unwrap_cast_in_comparison_for_binary(info, left, right, op)? @@ -1768,7 +1768,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // try_cast/cast(expr as data_type) op_swap literal Expr::BinaryExpr(BinaryExpr { left, op, right }) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary( - info, &right, &left, + info, &right, op, &left, ) && op.supports_propagation() && op.swap().is_some() => { diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index 7670bdf98bb4..8dc9d1834571 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -81,6 +81,16 @@ pub(super) fn unwrap_cast_in_comparison_for_binary( let Ok(expr_type) = info.get_data_type(&expr) else { return internal_err!("Can't get the data type of the expr {:?}", &expr); }; + + if let Some(value) = cast_literal_to_type_with_op(&lit_value, &expr_type, op) + { + return Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { + left: expr, + op, + right: Box::new(lit(value)), + }))); + }; + // if the lit_value can be casted to the type of internal_left_expr // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal let Some(value) = try_cast_literal_to_type(&lit_value, &expr_type) else { @@ -105,6 +115,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary< >( info: &S, expr: &Expr, + op: Operator, literal: &Expr, ) -> bool { match (expr, literal) { @@ -125,6 +136,10 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary< return false; }; + if cast_literal_to_type_with_op(lit_val, &expr_type, op).is_some() { + return true; + } + try_cast_literal_to_type(lit_val, &expr_type).is_some() && is_supported_type(&expr_type) && is_supported_type(&lit_type) @@ -215,6 +230,69 @@ fn is_supported_dictionary_type(data_type: &DataType) -> bool { DataType::Dictionary(_, inner) if 
is_supported_type(inner)) } +/// Try to move a cast from a column to the other side of a `=` / `!=` operator +/// +/// Specifically, rewrites +/// ```sql +/// cast(col) <op> <literal> +/// ``` +/// +/// To +/// +/// ```sql +/// col <op> cast(<literal>) +/// col <op> <casted_literal> +/// ``` +fn cast_literal_to_type_with_op( + lit_value: &ScalarValue, + target_type: &DataType, + op: Operator, +) -> Option<ScalarValue> { + match (op, lit_value) { + ( + Operator::Eq | Operator::NotEq, + ScalarValue::Utf8(Some(_)) + | ScalarValue::Utf8View(Some(_)) + | ScalarValue::LargeUtf8(Some(_)), + ) => { + // Only try for integer types (TODO can we do this for other types + // like timestamps)? + use DataType::*; + if matches!( + target_type, + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 + ) { + let opts = arrow::compute::CastOptions { + safe: false, + format_options: Default::default(), + }; + + let array = ScalarValue::to_array(lit_value).ok()?; + let casted = + arrow::compute::cast_with_options(&array, target_type, &opts).ok()?; + + // Perform a round-trip cast: literal -> target_type -> original_type + // Ensures cast expressions involving values like '0123' are not unwrapped for correctness (e.g., `cast(c1, UTF8) = '0123'`) + let round_tripped = arrow::compute::cast_with_options( + &casted, + &lit_value.data_type(), + &opts, + ) + .ok()?; + + if array != round_tripped { + return None; + } + + ScalarValue::try_from_array(&casted, 0).ok() + } else { + None + } + } + _ => None, + } +} + /// Convert a literal value from one data type to another pub(super) fn try_cast_literal_to_type( lit_value: &ScalarValue, @@ -468,6 +546,24 @@ mod tests { // the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(99999999999i64)); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + + // cast(c1, UTF8) < '123', only eq/not_eq should be optimized + let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("123")); +
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + + // cast(c1, UTF8) = '0123', cast(cast('0123', Int32), UTF8) != '0123', so '0123' should not + // be casted + let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("0123")); + assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); + + // cast(c1, UTF8) = 'not a number', should not be able to cast to column type + let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("not a number")); + assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); + + // cast(c1, UTF8) = '99999999999', where '99999999999' does not fit into int32, so it will + // not be optimized to integer comparison + let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("99999999999")); + assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); } #[test] @@ -496,6 +592,21 @@ mod tests { let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32)); let expected = null_bool(); assert_eq!(optimize_test(lit_lt_lit, &schema), expected); + + // cast(c1, UTF8) = '123' => c1 = 123 + let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("123")); + let expected = col("c1").eq(lit(123i32)); + assert_eq!(optimize_test(expr_input, &schema), expected); + + // cast(c1, UTF8) != '123' => c1 != 123 + let expr_input = cast(col("c1"), DataType::Utf8).not_eq(lit("123")); + let expected = col("c1").not_eq(lit(123i32)); + assert_eq!(optimize_test(expr_input, &schema), expected); + + // cast(c1, UTF8) = NULL => c1 = NULL + let expr_input = cast(col("c1"), DataType::Utf8).eq(lit(ScalarValue::Utf8(None))); + let expected = col("c1").eq(lit(ScalarValue::Int32(None))); + assert_eq!(optimize_test(expr_input, &schema), expected); } #[test] @@ -505,6 +616,16 @@ mod tests { let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64)); let expected = col("c6").eq(lit(0u32)); assert_eq!(optimize_test(expr_input, &schema), expected); + + // cast(c6, UTF8) = "123" => c6 = 123 + let expr_input = cast(col("c6"), 
DataType::Utf8).eq(lit("123")); + let expected = col("c6").eq(lit(123u32)); + assert_eq!(optimize_test(expr_input, &schema), expected); + + // cast(c6, UTF8) != "123" => c6 != 123 + let expr_input = cast(col("c6"), DataType::Utf8).not_eq(lit("123")); + let expected = col("c6").not_eq(lit(123u32)); + assert_eq!(optimize_test(expr_input, &schema), expected); } #[test] diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt index 3e5a493c2cd7..24d64e3facbe 100644 --- a/datafusion/sqllogictest/test_files/push_down_filter.slt +++ b/datafusion/sqllogictest/test_files/push_down_filter.slt @@ -250,5 +250,12 @@ explain select a from t where cast(a as string) = '100'; ---- logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("100")] +# The predicate should still have the column cast when the value's string representation is not the same after a round-trip cast +query TT +explain select a from t where CAST(a AS string) = '0123'; +---- +logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("0123")] + + statement ok drop table t; From 2afaf8c7fc42aca88e2cca2cf97d117fd098e3ae Mon Sep 17 00:00:00 2001 From: alan910127 <70696274+alan910127@users.noreply.github.com> Date: Mon, 17 Mar 2025 13:09:18 +0800 Subject: [PATCH 09/11] fix: unwrap cast on round-trip cast stable strings --- datafusion/sqllogictest/test_files/push_down_filter.slt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt index 24d64e3facbe..67965146e76b 100644 --- a/datafusion/sqllogictest/test_files/push_down_filter.slt +++ b/datafusion/sqllogictest/test_files/push_down_filter.slt @@ -244,13 +244,13 @@ explain select a from t where a = ''; ---- logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("")] -# The predicate should still 
have the column cast when user explicitly cast the column +# The predicate should not have a column cast when the operator is = or != and the literal can be round-trip casted without losing information. query TT explain select a from t where cast(a as string) = '100'; ---- -logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("100")] +logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)] -# The predicate should still have the column cast when the value's string representation is not the same after a round-trip cast +# The predicate should still have the column cast when the literal alters its string representation after round-trip casting (leading zero lost). query TT explain select a from t where CAST(a AS string) = '0123'; ---- From 126502fe77e052be403e8c23a059bb0d628a5877 Mon Sep 17 00:00:00 2001 From: alan910127 <70696274+alan910127@users.noreply.github.com> Date: Wed, 26 Mar 2025 10:20:17 +0800 Subject: [PATCH 10/11] revert: remove avoid cast changes --- datafusion/core/tests/expr_api/mod.rs | 8 +-- .../optimizer/src/analyzer/type_coercion.rs | 52 ------------------- 2 files changed, 4 insertions(+), 56 deletions(-) diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index f771e929abc5..aef10379da07 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -330,12 +330,12 @@ async fn test_create_physical_expr_coercion() { create_expr_test(lit(1i32).eq(col("id")), "CAST(1 AS Utf8) = id@0"); // compare int col to string literal `i = '202410'` // Note this casts the column (not the field) - create_expr_test(col("i").eq(lit("202410")), "i@1 = 202410"); - create_expr_test(lit("202410").eq(col("i")), "202410 = i@1"); + create_expr_test(col("i").eq(lit("202410")), "CAST(i@1 AS Utf8) = 202410"); + create_expr_test(lit("202410").eq(col("i")), "202410 = CAST(i@1 AS Utf8)"); // however, when simplified the casts on i should removed // 
https://github.com/apache/datafusion/issues/14944 - create_simplified_expr_test(col("i").eq(lit("202410")), "i@1 = 202410"); - create_simplified_expr_test(lit("202410").eq(col("i")), "i@1 = 202410"); + create_simplified_expr_test(col("i").eq(lit("202410")), "CAST(i@1 AS Utf8) = 202410"); + create_simplified_expr_test(lit("202410").eq(col("i")), "CAST(i@1 AS Utf8) = 202410"); } /// Evaluates the specified expr as an aggregate and compares the result to the diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 3a63f2a8c7ca..a77249424f71 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -290,22 +290,6 @@ impl<'a> TypeCoercionRewriter<'a> { right: Expr, right_schema: &DFSchema, ) -> Result<(Expr, Expr)> { - if let Expr::Literal(ref lit_value) = left { - if let Some(casted) = - try_cast_literal_to_type(lit_value, op, &right.get_type(right_schema)?) - { - return Ok((casted, right)); - }; - } - - if let Expr::Literal(ref lit_value) = right { - if let Some(casted) = - try_cast_literal_to_type(lit_value, op, &left.get_type(left_schema)?) - { - return Ok((left, casted)); - }; - } - let (left_type, right_type) = BinaryTypeCoercer::new( &left.get_type(left_schema)?, &op, @@ -320,42 +304,6 @@ impl<'a> TypeCoercionRewriter<'a> { } } -fn try_cast_literal_to_type( - lit_value: &ScalarValue, - op: Operator, - target_type: &DataType, -) -> Option { - match (op, lit_value) { - ( - Operator::Eq | Operator::NotEq, - ScalarValue::Utf8(Some(_)) - | ScalarValue::Utf8View(Some(_)) - | ScalarValue::LargeUtf8(Some(_)), - ) => { - // Only try for integer types (TODO can we do this for other types like timestamps)? 
- use DataType::*; - if matches!( - target_type, - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 - ) { - let opts = arrow::compute::CastOptions { - safe: false, - format_options: Default::default(), - }; - let array = ScalarValue::to_array(lit_value).ok()?; - let casted = - arrow::compute::cast_with_options(&array, target_type, &opts).ok()?; - ScalarValue::try_from_array(&casted, 0) - .ok() - .map(Expr::Literal) - } else { - None - } - } - _ => None, - } -} - impl TreeNodeRewriter for TypeCoercionRewriter<'_> { type Node = Expr; From bb8a3414fe0c1be0d297ff7892169248b185a324 Mon Sep 17 00:00:00 2001 From: alan910127 <70696274+alan910127@users.noreply.github.com> Date: Wed, 26 Mar 2025 10:26:41 +0800 Subject: [PATCH 11/11] refactor: apply review suggestions --- .../src/simplify_expressions/unwrap_cast.rs | 27 ++++--------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index 8dc9d1834571..be71a8cd19b0 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -230,7 +230,7 @@ fn is_supported_dictionary_type(data_type: &DataType) -> bool { DataType::Dictionary(_, inner) if is_supported_type(inner)) } -/// Try to move a cast from a column to the other side of a `=` / `!=` operator +///// Tries to move a cast from an expression (such as column) to the literal other side of a comparison operator./ /// /// Specifically, rewrites /// ```sql @@ -262,29 +262,12 @@ fn cast_literal_to_type_with_op( target_type, Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 ) { - let opts = arrow::compute::CastOptions { - safe: false, - format_options: Default::default(), - }; - - let array = ScalarValue::to_array(lit_value).ok()?; - let casted = - arrow::compute::cast_with_options(&array, target_type, &opts).ok()?; - - // Perform a 
round-trip cast: literal -> target_type -> original_type - // Ensures cast expressions involving values like '0123' are not unwrapped for correctness (e.g., `cast(c1, UTF8) = '0123'`) - let round_tripped = arrow::compute::cast_with_options( - &casted, - &lit_value.data_type(), - &opts, - ) - .ok()?; - - if array != round_tripped { + let casted = lit_value.cast_to(target_type).ok()?; + let round_tripped = casted.cast_to(&lit_value.data_type()).ok()?; + if lit_value != &round_tripped { return None; } - - ScalarValue::try_from_array(&casted, 0).ok() + Some(casted) } else { None }