diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 07eb795462c1..a77249424f71 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -296,6 +296,7 @@ impl<'a> TypeCoercionRewriter<'a> {
             &right.get_type(right_schema)?,
         )
         .get_input_types()?;
+
         Ok((
             left.cast_to(&left_type, left_schema)?,
             right.cast_to(&right_type, right_schema)?,
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index d5a1b84e6aff..d27b53a2e09f 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -1758,7 +1758,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> {
             // try_cast/cast(expr as data_type) op literal
             Expr::BinaryExpr(BinaryExpr { left, op, right })
                 if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
-                    info, &left, &right,
+                    info, &left, op, &right,
                 ) && op.supports_propagation() =>
             {
                 unwrap_cast_in_comparison_for_binary(info, left, right, op)?
@@ -1768,7 +1768,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> {
             // try_cast/cast(expr as data_type) op_swap literal
             Expr::BinaryExpr(BinaryExpr { left, op, right })
                 if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
-                    info, &right, &left,
+                    info, &right, op, &left,
                 ) && op.supports_propagation()
                     && op.swap().is_some() =>
             {
diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
index 7670bdf98bb4..be71a8cd19b0 100644
--- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
+++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
@@ -81,6 +81,16 @@ pub(super) fn unwrap_cast_in_comparison_for_binary(
     let Ok(expr_type) = info.get_data_type(&expr) else {
         return internal_err!("Can't get the data type of the expr {:?}", &expr);
     };
+
+    if let Some(value) = cast_literal_to_type_with_op(&lit_value, &expr_type, op)
+    {
+        return Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
+            left: expr,
+            op,
+            right: Box::new(lit(value)),
+        })));
+    };
+
     // if the lit_value can be casted to the type of internal_left_expr
     // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
     let Some(value) = try_cast_literal_to_type(&lit_value, &expr_type) else {
@@ -105,6 +115,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
 >(
     info: &S,
     expr: &Expr,
+    op: Operator,
     literal: &Expr,
 ) -> bool {
     match (expr, literal) {
@@ -125,6 +136,10 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
                 return false;
             };
 
+            if cast_literal_to_type_with_op(lit_val, &expr_type, op).is_some() {
+                return true;
+            }
+
             try_cast_literal_to_type(lit_val, &expr_type).is_some()
                 && is_supported_type(&expr_type)
                 && is_supported_type(&lit_type)
@@ -215,6 +230,60 @@ fn is_supported_dictionary_type(data_type: &DataType) -> bool {
                     DataType::Dictionary(_, inner) if is_supported_type(inner))
 }
 
+/// Tries to move a cast from an expression (such as a column) to the literal
+/// on the other side of a comparison operator.
+///
+/// Specifically, rewrites
+/// ```sql
+/// cast(col) <op> <literal>
+/// ```
+///
+/// to
+///
+/// ```sql
+/// col <op> cast(<literal>)
+/// col <op> <casted_literal>
+/// ```
+///
+/// Returns `None` (no rewrite) unless `op` is `=` or `!=`, `lit_value` is a
+/// non-null string literal, `target_type` is an integer type, and casting the
+/// literal to `target_type` and back reproduces the original literal.
+fn cast_literal_to_type_with_op(
+    lit_value: &ScalarValue,
+    target_type: &DataType,
+    op: Operator,
+) -> Option<ScalarValue> {
+    match (op, lit_value) {
+        (
+            Operator::Eq | Operator::NotEq,
+            ScalarValue::Utf8(Some(_))
+            | ScalarValue::Utf8View(Some(_))
+            | ScalarValue::LargeUtf8(Some(_)),
+        ) => {
+            // Only try for integer types (TODO can we do this for other types
+            // like timestamps)?
+            use DataType::*;
+            if matches!(
+                target_type,
+                Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
+            ) {
+                let casted = lit_value.cast_to(target_type).ok()?;
+                // Only rewrite when the round trip reproduces the literal:
+                // '0123' casts to 123, which casts back to '123', so turning
+                // `cast(col) = '0123'` into `col = 123` would change the result.
+                let round_tripped = casted.cast_to(&lit_value.data_type()).ok()?;
+                if lit_value != &round_tripped {
+                    return None;
+                }
+                Some(casted)
+            } else {
+                None
+            }
+        }
+        _ => None,
+    }
+}
+
 /// Convert a literal value from one data type to another
 pub(super) fn try_cast_literal_to_type(
     lit_value: &ScalarValue,
@@ -468,6 +537,24 @@ mod tests {
         // the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type
         let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(99999999999i64));
         assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
+        // cast(c1, UTF8) < '123', only eq/not_eq should be optimized
+        let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("123"));
+        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
+        // cast(c1, UTF8) = '0123', cast(cast('0123', Int32), UTF8) != '0123', so '0123' should not
+        // be casted
+        let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("0123"));
+        assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
+
+        // cast(c1, UTF8) = 'not a number', should not be able to cast to column type
+        let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("not a number"));
+        assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
+
+        // cast(c1, UTF8) = '99999999999', where '99999999999' does not fit into int32, so it will
+        // not be optimized to integer comparison
+        let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("99999999999"));
+        assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
     }
 
     #[test]
@@ -496,6 +583,21 @@ mod tests {
         let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32));
         let expected = null_bool();
         assert_eq!(optimize_test(lit_lt_lit, &schema), expected);
+
+        // cast(c1, UTF8) = '123' => c1 = 123
+        let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("123"));
+        let expected = col("c1").eq(lit(123i32));
+        assert_eq!(optimize_test(expr_input, &schema), expected);
+
+        // cast(c1, UTF8) != '123' => c1 != 123
+        let expr_input = cast(col("c1"), DataType::Utf8).not_eq(lit("123"));
+        let expected = col("c1").not_eq(lit(123i32));
+        assert_eq!(optimize_test(expr_input, &schema), expected);
+
+        // cast(c1, UTF8) = NULL => c1 = NULL
+        let expr_input = cast(col("c1"), DataType::Utf8).eq(lit(ScalarValue::Utf8(None)));
+        let expected = col("c1").eq(lit(ScalarValue::Int32(None)));
+        assert_eq!(optimize_test(expr_input, &schema), expected);
     }
 
     #[test]
@@ -505,6 +607,16 @@ mod tests {
         let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64));
         let expected = col("c6").eq(lit(0u32));
         assert_eq!(optimize_test(expr_input, &schema), expected);
+
+        // cast(c6, UTF8) = "123" => c6 = 123
+        let expr_input = cast(col("c6"), DataType::Utf8).eq(lit("123"));
+        let expected = col("c6").eq(lit(123u32));
+        assert_eq!(optimize_test(expr_input, &schema), expected);
+
+        // cast(c6, UTF8) != "123" => c6 != 123
+        let expr_input = cast(col("c6"), DataType::Utf8).not_eq(lit("123"));
+        let expected = col("c6").not_eq(lit(123u32));
+        assert_eq!(optimize_test(expr_input, &schema), expected);
     }
 
     #[test]
diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt
index 521aa3340981..67965146e76b 100644
--- a/datafusion/sqllogictest/test_files/push_down_filter.slt
+++ b/datafusion/sqllogictest/test_files/push_down_filter.slt
@@ -188,6 +188,7 @@ select * from test_filter_with_limit where value = 2 limit 1;
 ----
 2 2
 
+
 # Tear down test_filter_with_limit table:
 statement ok
 DROP TABLE test_filter_with_limit;
@@ -195,3 +196,66 @@ DROP TABLE test_filter_with_limit;
 # Tear down src_table table:
 statement ok
 DROP TABLE src_table;
+
+
+query I
+COPY (VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10))
+TO 'test_files/scratch/push_down_filter/t.parquet'
+STORED AS PARQUET;
+----
+10
+
+statement ok
+CREATE EXTERNAL TABLE t
+(
+  a INT
+)
+STORED AS PARQUET
+LOCATION 'test_files/scratch/push_down_filter/t.parquet';
+
+
+# The predicate should not have a column cast when the value is a valid i32
+query TT
+explain select a from t where a = '100';
+----
+logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)]
+
+# The predicate should not have a column cast when the value is a valid i32
+query TT
+explain select a from t where a != '100';
+----
+logical_plan TableScan: t projection=[a], full_filters=[t.a != Int32(100)]
+
+# The predicate should still have the column cast when the value is NOT a valid i32
+query TT
+explain select a from t where a = '99999999999';
+----
+logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99999999999")]
+
+# The predicate should still have the column cast when the value is NOT a valid i32
+query TT
+explain select a from t where a = '99.99';
+----
+logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99.99")]
+
+# The predicate should still have the column cast when the value is NOT a valid i32
+query TT
+explain select a from t where a = '';
+----
+logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("")]
+
+# The predicate should not have a column cast when the operator is = or != and the literal can be round-trip casted without losing information.
+query TT
+explain select a from t where cast(a as string) = '100';
+----
+logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)]
+
+# The predicate should still have the column cast when the literal alters its string representation after round-trip casting (leading zero lost).
+query TT
+explain select a from t where CAST(a AS string) = '0123';
+----
+logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("0123")]
+
+
+statement ok
+drop table t;