Skip to content

perf: unwrap cast for comparing ints =/!= strings #15110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ impl<'a> TypeCoercionRewriter<'a> {
&right.get_type(right_schema)?,
)
.get_input_types()?;

Ok((
left.cast_to(&left_type, left_schema)?,
right.cast_to(&right_type, right_schema)?,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1758,7 +1758,7 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
// try_cast/cast(expr as data_type) op literal
Expr::BinaryExpr(BinaryExpr { left, op, right })
if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
info, &left, &right,
info, &left, op, &right,
) && op.supports_propagation() =>
{
unwrap_cast_in_comparison_for_binary(info, left, right, op)?
Expand All @@ -1768,7 +1768,7 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
// try_cast/cast(expr as data_type) op_swap literal
Expr::BinaryExpr(BinaryExpr { left, op, right })
if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
info, &right, &left,
info, &right, op, &left,
) && op.supports_propagation()
&& op.swap().is_some() =>
{
Expand Down
104 changes: 104 additions & 0 deletions datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ pub(super) fn unwrap_cast_in_comparison_for_binary<S: SimplifyInfo>(
let Ok(expr_type) = info.get_data_type(&expr) else {
return internal_err!("Can't get the data type of the expr {:?}", &expr);
};

if let Some(value) = cast_literal_to_type_with_op(&lit_value, &expr_type, op)
{
return Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
left: expr,
op,
right: Box::new(lit(value)),
})));
};

// if the lit_value can be casted to the type of internal_left_expr
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
let Some(value) = try_cast_literal_to_type(&lit_value, &expr_type) else {
Expand All @@ -105,6 +115,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
>(
info: &S,
expr: &Expr,
op: Operator,
literal: &Expr,
) -> bool {
match (expr, literal) {
Expand All @@ -125,6 +136,10 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
return false;
};

if cast_literal_to_type_with_op(lit_val, &expr_type, op).is_some() {
return true;
}

try_cast_literal_to_type(lit_val, &expr_type).is_some()
&& is_supported_type(&expr_type)
&& is_supported_type(&lit_type)
Expand Down Expand Up @@ -215,6 +230,52 @@ fn is_supported_dictionary_type(data_type: &DataType) -> bool {
DataType::Dictionary(_, inner) if is_supported_type(inner))
}

///// Tries to move a cast from an expression (such as column) to the literal other side of a comparison operator./
///
/// Specifically, rewrites
/// ```sql
/// cast(col) <op> <literal>
/// ```
///
/// To
///
/// ```sql
/// col <op> cast(<literal>)
/// col <op> <casted_literal>
/// ```
fn cast_literal_to_type_with_op(
lit_value: &ScalarValue,
target_type: &DataType,
op: Operator,
) -> Option<ScalarValue> {
match (op, lit_value) {
(
Operator::Eq | Operator::NotEq,
ScalarValue::Utf8(Some(_))
| ScalarValue::Utf8View(Some(_))
| ScalarValue::LargeUtf8(Some(_)),
) => {
// Only try for integer types (TODO can we do this for other types
// like timestamps)?
use DataType::*;
if matches!(
target_type,
Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
) {
let casted = lit_value.cast_to(target_type).ok()?;
let round_tripped = casted.cast_to(&lit_value.data_type()).ok()?;
if lit_value != &round_tripped {
return None;
}
Some(casted)
} else {
None
}
}
_ => None,
}
}

/// Convert a literal value from one data type to another
pub(super) fn try_cast_literal_to_type(
lit_value: &ScalarValue,
Expand Down Expand Up @@ -468,6 +529,24 @@ mod tests {
// the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type
let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(99999999999i64));
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);

// cast(c1, UTF8) < '123', only eq/not_eq should be optimized
let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("123"));
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);

// cast(c1, UTF8) = '0123', cast(cast('0123', Int32), UTF8) != '0123', so '0123' should not
// be casted
let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("0123"));
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);

// cast(c1, UTF8) = 'not a number', should not be able to cast to column type
let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("not a number"));
assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);

// cast(c1, UTF8) = '99999999999', where '99999999999' does not fit into int32, so it will
// not be optimized to integer comparison
let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("99999999999"));
assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
}

#[test]
Expand Down Expand Up @@ -496,6 +575,21 @@ mod tests {
let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32));
let expected = null_bool();
assert_eq!(optimize_test(lit_lt_lit, &schema), expected);

// cast(c1, UTF8) = '123' => c1 = 123
let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("123"));
let expected = col("c1").eq(lit(123i32));
assert_eq!(optimize_test(expr_input, &schema), expected);

// cast(c1, UTF8) != '123' => c1 != 123
let expr_input = cast(col("c1"), DataType::Utf8).not_eq(lit("123"));
let expected = col("c1").not_eq(lit(123i32));
assert_eq!(optimize_test(expr_input, &schema), expected);

// cast(c1, UTF8) = NULL => c1 = NULL
let expr_input = cast(col("c1"), DataType::Utf8).eq(lit(ScalarValue::Utf8(None)));
let expected = col("c1").eq(lit(ScalarValue::Int32(None)));
assert_eq!(optimize_test(expr_input, &schema), expected);
}

#[test]
Expand All @@ -505,6 +599,16 @@ mod tests {
let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64));
let expected = col("c6").eq(lit(0u32));
assert_eq!(optimize_test(expr_input, &schema), expected);

// cast(c6, UTF8) = "123" => c6 = 123
let expr_input = cast(col("c6"), DataType::Utf8).eq(lit("123"));
let expected = col("c6").eq(lit(123u32));
assert_eq!(optimize_test(expr_input, &schema), expected);

// cast(c6, UTF8) != "123" => c6 != 123
let expr_input = cast(col("c6"), DataType::Utf8).not_eq(lit("123"));
let expected = col("c6").not_eq(lit(123u32));
assert_eq!(optimize_test(expr_input, &schema), expected);
}

#[test]
Expand Down
64 changes: 64 additions & 0 deletions datafusion/sqllogictest/test_files/push_down_filter.slt
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,74 @@ select * from test_filter_with_limit where value = 2 limit 1;
----
2 2


# Tear down test_filter_with_limit table:
statement ok
DROP TABLE test_filter_with_limit;

# Tear down src_table table:
statement ok
DROP TABLE src_table;


query I
COPY (VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10))
TO 'test_files/scratch/push_down_filter/t.parquet'
STORED AS PARQUET;
----
10

statement ok
CREATE EXTERNAL TABLE t
(
a INT
)
STORED AS PARQUET
LOCATION 'test_files/scratch/push_down_filter/t.parquet';


# The predicate should not have a column cast when the value is a valid i32
query TT
explain select a from t where a = '100';
----
logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)]

# The predicate should not have a column cast when the value is a valid i32
query TT
explain select a from t where a != '100';
----
logical_plan TableScan: t projection=[a], full_filters=[t.a != Int32(100)]

# The predicate should still have the column cast when the value is a NOT valid i32
query TT
explain select a from t where a = '99999999999';
----
logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99999999999")]

# The predicate should still have the column cast when the value is a NOT valid i32
query TT
explain select a from t where a = '99.99';
----
logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99.99")]

# The predicate should still have the column cast when the value is a NOT valid i32
query TT
explain select a from t where a = '';
----
logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("")]

# The predicate should not have a column cast when the operator is = or != and the literal can be round-trip casted without losing information.
query TT
explain select a from t where cast(a as string) = '100';
----
logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)]

# The predicate should still have the column cast when the literal alters its string representation after round-trip casting (leading zero lost).
query TT
explain select a from t where CAST(a AS string) = '0123';
----
logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("0123")]


statement ok
drop table t;