diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 3bdc71a8eb99..e8ef34c2afe7 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -355,7 +355,8 @@ async fn csv_explain_verbose() { async fn csv_explain_inlist_verbose() { let ctx = SessionContext::new(); register_aggregate_csv_by_sql(&ctx).await; - let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 in (1,2,4)"; + // Inlist len <=3 case will be transformed to OR List so we test with len=4 + let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 in (1,2,4,5)"; let actual = execute(&ctx, sql).await; // Optimized by PreCastLitInComparisonExpressions rule @@ -368,12 +369,12 @@ async fn csv_explain_inlist_verbose() { // before optimization (Int64 literals) assert_contains!( &actual, - "aggregate_test_100.c2 IN ([Int64(1), Int64(2), Int64(4)])" + "aggregate_test_100.c2 IN ([Int64(1), Int64(2), Int64(4), Int64(5)])" ); // after optimization (casted to Int8) assert_contains!( &actual, - "aggregate_test_100.c2 IN ([Int8(1), Int8(2), Int8(4)])" + "aggregate_test_100.c2 IN ([Int8(1), Int8(2), Int8(4), Int8(5)])" ); } diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 61ca9b31cd29..1280bf2f466e 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -60,7 +60,6 @@ pub mod replace_distinct_aggregate; pub mod scalar_subquery_to_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; -pub mod unwrap_cast_in_comparison; pub mod utils; #[cfg(test)] diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 49bce3c1ce82..018ad8ace0e3 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -54,7 +54,6 @@ use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; use crate::simplify_expressions::SimplifyExpressions; use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; -use crate::unwrap_cast_in_comparison::UnwrapCastInComparison; use crate::utils::log_plan; /// `OptimizerRule`s transforms one [`LogicalPlan`] into another which @@ -243,7 +242,6 @@ impl Optimizer { let rules: Vec> = vec![ Arc::new(EliminateNestedUnion::new()), Arc::new(SimplifyExpressions::new()), - Arc::new(UnwrapCastInComparison::new()), Arc::new(ReplaceDistinctWithAggregate::new()), Arc::new(EliminateJoin::new()), Arc::new(DecorrelatePredicateSubquery::new()), @@ -266,7 +264,6 @@ impl Optimizer { // The previous optimizations added expressions and projections, // that might benefit from the following rules Arc::new(SimplifyExpressions::new()), - Arc::new(UnwrapCastInComparison::new()), Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateGroupByConstant::new()), Arc::new(OptimizeProjections::new()), diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 840c108905a9..d5a1b84e6aff 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -32,7 +32,6 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; -use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, WindowFunctionDefinition, @@ -42,14 +41,23 @@ use datafusion_expr::{ expr::{InList, InSubquery, WindowFunction}, utils::{iter_conjunction, iter_conjunction_owned}, }; +use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; -use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; +use crate::simplify_expressions::unwrap_cast::{ + is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary, + is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist, + unwrap_cast_in_comparison_for_binary, +}; use crate::simplify_expressions::SimplifyInfo; +use crate::{ + analyzer::type_coercion::TypeCoercionRewriter, + simplify_expressions::unwrap_cast::try_cast_literal_to_type, +}; use indexmap::IndexSet; use regex::Regex; @@ -1742,6 +1750,86 @@ impl TreeNodeRewriter for Simplifier<'_, S> { } } + // ======================================= + // unwrap_cast_in_comparison + // ======================================= + // + // For case: + // try_cast/cast(expr as data_type) op literal + Expr::BinaryExpr(BinaryExpr { left, op, right }) + if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary( + info, &left, &right, + ) && op.supports_propagation() => + { + unwrap_cast_in_comparison_for_binary(info, left, right, op)? + } + // literal op try_cast/cast(expr as data_type) + // --> + // try_cast/cast(expr as data_type) op_swap literal + Expr::BinaryExpr(BinaryExpr { left, op, right }) + if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary( + info, &right, &left, + ) && op.supports_propagation() + && op.swap().is_some() => + { + unwrap_cast_in_comparison_for_binary( + info, + right, + left, + op.swap().unwrap(), + )? + } + // For case: + // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) + Expr::InList(InList { + expr: mut left, + list, + negated, + }) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist( + info, &left, &list, + ) => + { + let (Expr::TryCast(TryCast { + expr: left_expr, .. + }) + | Expr::Cast(Cast { + expr: left_expr, .. + })) = left.as_mut() + else { + return internal_err!("Expect cast expr, but got {:?}", left)?; + }; + + let expr_type = info.get_data_type(left_expr)?; + let right_exprs = list + .into_iter() + .map(|right| { + match right { + Expr::Literal(right_lit_value) => { + // if the right_lit_value can be casted to the type of internal_left_expr + // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal + let Some(value) = try_cast_literal_to_type(&right_lit_value, &expr_type) else { + internal_err!( + "Can't cast the list expr {:?} to type {:?}", + right_lit_value, &expr_type + )? + }; + Ok(lit(value)) + } + other_expr => internal_err!( + "Only support literal expr to optimize, but the expr is {:?}", + &other_expr + ), + } + }) + .collect::>>()?; + + Transformed::yes(Expr::InList(InList { + expr: std::mem::take(left_expr), + list: right_exprs, + negated, + })) + } + // no additional rewrites possible expr => Transformed::no(expr), }) diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index 46c066c11c0f..5fbee02e3909 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -23,6 +23,7 @@ mod guarantees; mod inlist_simplifier; mod regex; pub mod simplify_exprs; +mod unwrap_cast; mod utils; // backwards compatibility diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs similarity index 79% rename from datafusion/optimizer/src/unwrap_cast_in_comparison.rs rename to datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index e2b8a966cb92..7670bdf98bb4 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -15,274 +15,176 @@ // specific language governing permissions and limitations // under the License. -//! [`UnwrapCastInComparison`] rewrites `CAST(col) = lit` to `col = CAST(lit)` +//! Unwrap casts in binary comparisons +//! +//! The functions in this module attempt to remove casts from +//! comparisons to literals ([`ScalarValue`]s) by applying the casts +//! to the literals if possible. It is inspired by the optimizer rule +//! `UnwrapCastInBinaryComparison` of Spark. +//! +//! Removing casts often improves performance because: +//! 1. The cast is done once (to the literal) rather than to every value +//! 2. Can enable other optimizations such as predicate pushdown that +//! don't support casting +//! +//! The rule is applied to expressions of the following forms: +//! +//! 1. `cast(left_expr as data_type) comparison_op literal_expr` +//! 2. `literal_expr comparison_op cast(left_expr as data_type)` +//! 3. `cast(literal_expr) IN (expr1, expr2, ...)` +//! 4. `literal_expr IN (cast(expr1) , cast(expr2), ...)` +//! +//! If the expression matches one of the forms above, the rule will +//! ensure the value of `literal` is in range(min, max) of the +//! expr's data_type, and if the scalar is within range, the literal +//! will be casted to the data type of expr on the other side, and the +//! cast will be removed from the other side. +//! +//! # Example +//! +//! If the DataType of c1 is INT32. Given the filter +//! +//! ```text +//! cast(c1 as INT64) > INT64(10)` +//! ``` +//! +//! This rule will remove the cast and rewrite the expression to: +//! +//! ```text +//! c1 > INT32(10) +//! ``` +//! use std::cmp::Ordering; -use std::mem; -use std::sync::Arc; -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; - -use crate::utils::NamePreserver; use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; -use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, ScalarValue}; -use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; -use datafusion_expr::utils::merge_schema; -use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan}; - -/// [`UnwrapCastInComparison`] attempts to remove casts from -/// comparisons to literals ([`ScalarValue`]s) by applying the casts -/// to the literals if possible. It is inspired by the optimizer rule -/// `UnwrapCastInBinaryComparison` of Spark. -/// -/// Removing casts often improves performance because: -/// 1. The cast is done once (to the literal) rather than to every value -/// 2. Can enable other optimizations such as predicate pushdown that -/// don't support casting -/// -/// The rule is applied to expressions of the following forms: -/// -/// 1. `cast(left_expr as data_type) comparison_op literal_expr` -/// 2. `literal_expr comparison_op cast(left_expr as data_type)` -/// 3. `cast(literal_expr) IN (expr1, expr2, ...)` -/// 4. `literal_expr IN (cast(expr1) , cast(expr2), ...)` -/// -/// If the expression matches one of the forms above, the rule will -/// ensure the value of `literal` is in range(min, max) of the -/// expr's data_type, and if the scalar is within range, the literal -/// will be casted to the data type of expr on the other side, and the -/// cast will be removed from the other side. -/// -/// # Example -/// -/// If the DataType of c1 is INT32. Given the filter -/// -/// ```text -/// Filter: cast(c1 as INT64) > INT64(10)` -/// ``` -/// -/// This rule will remove the cast and rewrite the expression to: -/// -/// ```text -/// Filter: c1 > INT32(10) -/// ``` -/// -#[derive(Default, Debug)] -pub struct UnwrapCastInComparison {} - -impl UnwrapCastInComparison { - pub fn new() -> Self { - Self::default() +use datafusion_common::{internal_err, tree_node::Transformed}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{lit, BinaryExpr}; +use datafusion_expr::{simplify::SimplifyInfo, Cast, Expr, Operator, TryCast}; + +pub(super) fn unwrap_cast_in_comparison_for_binary( + info: &S, + cast_expr: Box, + literal: Box, + op: Operator, +) -> Result> { + match (*cast_expr, *literal) { + ( + Expr::TryCast(TryCast { expr, .. }) | Expr::Cast(Cast { expr, .. }), + Expr::Literal(lit_value), + ) => { + let Ok(expr_type) = info.get_data_type(&expr) else { + return internal_err!("Can't get the data type of the expr {:?}", &expr); + }; + // if the lit_value can be casted to the type of internal_left_expr + // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal + let Some(value) = try_cast_literal_to_type(&lit_value, &expr_type) else { + return internal_err!( + "Can't cast the literal expr {:?} to type {:?}", + &lit_value, + &expr_type + ); + }; + Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { + left: expr, + op, + right: Box::new(lit(value)), + }))) + } + _ => internal_err!("Expect cast expr and literal"), } } -impl OptimizerRule for UnwrapCastInComparison { - fn name(&self) -> &str { - "unwrap_cast_in_comparison" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) - } +pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary< + S: SimplifyInfo, +>( + info: &S, + expr: &Expr, + literal: &Expr, +) -> bool { + match (expr, literal) { + ( + Expr::TryCast(TryCast { + expr: left_expr, .. + }) + | Expr::Cast(Cast { + expr: left_expr, .. + }), + Expr::Literal(lit_val), + ) => { + let Ok(expr_type) = info.get_data_type(left_expr) else { + return false; + }; - fn supports_rewrite(&self) -> bool { - true - } + let Ok(lit_type) = info.get_data_type(literal) else { + return false; + }; - fn rewrite( - &self, - plan: LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - let mut schema = merge_schema(&plan.inputs()); - - if let LogicalPlan::TableScan(ts) = &plan { - let source_schema = DFSchema::try_from_qualified_schema( - ts.table_name.clone(), - &ts.source.schema(), - )?; - schema.merge(&source_schema); + try_cast_literal_to_type(lit_val, &expr_type).is_some() + && is_supported_type(&expr_type) + && is_supported_type(&lit_type) } + _ => false, + } +} - schema.merge(plan.schema()); +pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist< + S: SimplifyInfo, +>( + info: &S, + expr: &Expr, + list: &[Expr], +) -> bool { + let (Expr::TryCast(TryCast { + expr: left_expr, .. + }) + | Expr::Cast(Cast { + expr: left_expr, .. + })) = expr + else { + return false; + }; - let mut expr_rewriter = UnwrapCastExprRewriter { - schema: Arc::new(schema), - }; + let Ok(expr_type) = info.get_data_type(left_expr) else { + return false; + }; - let name_preserver = NamePreserver::new(&plan); - plan.map_expressions(|expr| { - let original_name = name_preserver.save(&expr); - expr.rewrite(&mut expr_rewriter) - .map(|transformed| transformed.update_data(|e| original_name.restore(e))) - }) + if !is_supported_type(&expr_type) { + return false; } -} -struct UnwrapCastExprRewriter { - schema: DFSchemaRef, -} + for right in list { + let Ok(right_type) = info.get_data_type(right) else { + return false; + }; -impl TreeNodeRewriter for UnwrapCastExprRewriter { - type Node = Expr; - - fn f_up(&mut self, mut expr: Expr) -> Result> { - match &mut expr { - // For case: - // try_cast/cast(expr as data_type) op literal - // literal op try_cast/cast(expr as data_type) - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if { - let Ok(left_type) = left.get_type(&self.schema) else { - return Ok(Transformed::no(expr)); - }; - let Ok(right_type) = right.get_type(&self.schema) else { - return Ok(Transformed::no(expr)); - }; - is_supported_type(&left_type) - && is_supported_type(&right_type) - && op.supports_propagation() - } => - { - match (left.as_mut(), right.as_mut()) { - ( - Expr::Literal(left_lit_value), - Expr::TryCast(TryCast { - expr: right_expr, .. - }) - | Expr::Cast(Cast { - expr: right_expr, .. - }), - ) => { - // if the left_lit_value can be cast to the type of expr - // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal - let Ok(expr_type) = right_expr.get_type(&self.schema) else { - return Ok(Transformed::no(expr)); - }; - match expr_type { - // https://github.com/apache/datafusion/issues/12180 - DataType::Utf8View => Ok(Transformed::no(expr)), - _ => { - let Some(value) = - try_cast_literal_to_type(left_lit_value, &expr_type) - else { - return Ok(Transformed::no(expr)); - }; - **left = lit(value); - // unwrap the cast/try_cast for the right expr - **right = mem::take(right_expr); - Ok(Transformed::yes(expr)) - } - } - } - ( - Expr::TryCast(TryCast { - expr: left_expr, .. - }) - | Expr::Cast(Cast { - expr: left_expr, .. - }), - Expr::Literal(right_lit_value), - ) => { - // if the right_lit_value can be cast to the type of expr - // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal - let Ok(expr_type) = left_expr.get_type(&self.schema) else { - return Ok(Transformed::no(expr)); - }; - match expr_type { - // https://github.com/apache/datafusion/issues/12180 - DataType::Utf8View => Ok(Transformed::no(expr)), - _ => { - let Some(value) = - try_cast_literal_to_type(right_lit_value, &expr_type) - else { - return Ok(Transformed::no(expr)); - }; - // unwrap the cast/try_cast for the left expr - **left = mem::take(left_expr); - **right = lit(value); - Ok(Transformed::yes(expr)) - } - } - } - _ => Ok(Transformed::no(expr)), - } - } - // For case: - // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) - Expr::InList(InList { - expr: left, list, .. - }) => { - let (Expr::TryCast(TryCast { - expr: left_expr, .. - }) - | Expr::Cast(Cast { - expr: left_expr, .. - })) = left.as_mut() - else { - return Ok(Transformed::no(expr)); - }; - let Ok(expr_type) = left_expr.get_type(&self.schema) else { - return Ok(Transformed::no(expr)); - }; - if !is_supported_type(&expr_type) { - return Ok(Transformed::no(expr)); - } - let Ok(right_exprs) = list - .iter() - .map(|right| { - let right_type = right.get_type(&self.schema)?; - if !is_supported_type(&right_type) { - internal_err!( - "The type of list expr {} is not supported", - &right_type - )?; - } - match right { - Expr::Literal(right_lit_value) => { - // if the right_lit_value can be casted to the type of internal_left_expr - // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal - let Some(value) = try_cast_literal_to_type(right_lit_value, &expr_type) else { - internal_err!( - "Can't cast the list expr {:?} to type {:?}", - right_lit_value, &expr_type - )? - }; - Ok(lit(value)) - } - other_expr => internal_err!( - "Only support literal expr to optimize, but the expr is {:?}", - &other_expr - ), - } - }) - .collect::>>() else { - return Ok(Transformed::no(expr)) - }; - **left = mem::take(left_expr); - *list = right_exprs; - Ok(Transformed::yes(expr)) - } - // TODO: handle other expr type and dfs visit them - _ => Ok(Transformed::no(expr)), + if !is_supported_type(&right_type) { + return false; + } + + match right { + Expr::Literal(lit_val) + if try_cast_literal_to_type(lit_val, &expr_type).is_some() => {} + _ => return false, } } + + true } -/// Returns true if [UnwrapCastExprRewriter] supports this data type +/// Returns true if unwrap_cast_in_comparison supports this data type fn is_supported_type(data_type: &DataType) -> bool { is_supported_numeric_type(data_type) || is_supported_string_type(data_type) || is_supported_dictionary_type(data_type) } -/// Returns true if [[UnwrapCastExprRewriter]] support this numeric type +/// Returns true if unwrap_cast_in_comparison support this numeric type fn is_supported_numeric_type(data_type: &DataType) -> bool { matches!( data_type, @@ -299,7 +201,7 @@ fn is_supported_numeric_type(data_type: &DataType) -> bool { ) } -/// Returns true if [UnwrapCastExprRewriter] supports casting this value as a string +/// Returns true if unwrap_cast_in_comparison supports casting this value as a string fn is_supported_string_type(data_type: &DataType) -> bool { matches!( data_type, @@ -307,14 +209,14 @@ fn is_supported_string_type(data_type: &DataType) -> bool { ) } -/// Returns true if [UnwrapCastExprRewriter] supports casting this value as a dictionary +/// Returns true if unwrap_cast_in_comparison supports casting this value as a dictionary fn is_supported_dictionary_type(data_type: &DataType) -> bool { matches!(data_type, DataType::Dictionary(_, inner) if is_supported_type(inner)) } /// Convert a literal value from one data type to another -fn try_cast_literal_to_type( +pub(super) fn try_cast_literal_to_type( lit_value: &ScalarValue, target_type: &DataType, ) -> Option { @@ -540,13 +442,16 @@ fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option #[cfg(test)] mod tests { - use std::collections::HashMap; - use super::*; + use std::collections::HashMap; + use std::sync::Arc; + use crate::simplify_expressions::ExprSimplifier; use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::Field; - use datafusion_common::tree_node::TransformedResult; + use datafusion_common::{DFSchema, DFSchemaRef}; + use datafusion_expr::execution_props::ExecutionProps; + use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{cast, col, in_list, try_cast}; #[test] @@ -587,9 +492,9 @@ mod tests { let expected = col("c1").lt(null_i32()); assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); - // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) + // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) => BOOL(NULL) let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32)); - let expected = null_i8().lt(lit(12i8)); + let expected = null_bool(); assert_eq!(optimize_test(lit_lt_lit, &schema), expected); } @@ -623,7 +528,7 @@ mod tests { // Verify reversed argument order // arrow_cast('value', 'Dictionary') = cast(str1 as Dictionary) => Utf8('value1') = str1 let expr_input = lit(dict.clone()).eq(cast(col("str1"), dict.data_type())); - let expected = lit("value").eq(col("str1")); + let expected = col("str1").eq(lit("value")); assert_eq!(optimize_test(expr_input, &schema), expected); } @@ -740,15 +645,27 @@ mod tests { #[test] fn test_unwrap_list_cast_comparison() { let schema = expr_test_schema(); - // INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24)) - let expr_lt = - cast(col("c1"), DataType::Int64).in_list(vec![lit(12i64), lit(24i64)], false); - let expected = col("c1").in_list(vec![lit(12i32), lit(24i32)], false); + // INT32(C1) IN (INT32(12),INT64(23),INT64(34),INT64(56),INT64(78)) -> + // INT32(C1) IN (INT32(12),INT32(23),INT32(34),INT32(56),INT32(78)) + let expr_lt = cast(col("c1"), DataType::Int64).in_list( + vec![lit(12i64), lit(23i64), lit(34i64), lit(56i64), lit(78i64)], + false, + ); + let expected = col("c1").in_list( + vec![lit(12i32), lit(23i32), lit(34i32), lit(56i32), lit(78i32)], + false, + ); assert_eq!(optimize_test(expr_lt, &schema), expected); - // INT32(C2) IN (INT64(NULL),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24)) - let expr_lt = - cast(col("c2"), DataType::Int32).in_list(vec![null_i32(), lit(14i32)], false); - let expected = col("c2").in_list(vec![null_i64(), lit(14i64)], false); + // INT32(C2) IN (INT64(NULL),INT64(24),INT64(34),INT64(56),INT64(78)) -> + // INT32(C2) IN (INT32(NULL),INT32(24),INT32(34),INT32(56),INT32(78)) + let expr_lt = cast(col("c2"), DataType::Int32).in_list( + vec![null_i32(), lit(24i32), lit(34i64), lit(56i64), lit(78i64)], + false, + ); + let expected = col("c2").in_list( + vec![null_i64(), lit(24i64), lit(34i64), lit(56i64), lit(78i64)], + false, + ); assert_eq!(optimize_test(expr_lt, &schema), expected); @@ -774,10 +691,14 @@ mod tests { ); assert_eq!(optimize_test(expr_lt, &schema), expected); - // cast(INT32(12), INT64) IN (.....) - let expr_lt = cast(lit(12i32), DataType::Int64) - .in_list(vec![lit(13i64), lit(12i64)], false); - let expected = lit(12i32).in_list(vec![lit(13i32), lit(12i32)], false); + // cast(INT32(12), INT64) IN (.....) => + // INT64(12) IN (INT64(12),INT64(13),INT64(14),INT64(15),INT64(16)) + // => true + let expr_lt = cast(lit(12i32), DataType::Int64).in_list( + vec![lit(12i64), lit(13i64), lit(14i64), lit(15i64), lit(16i64)], + false, + ); + let expected = lit(true); assert_eq!(optimize_test(expr_lt, &schema), expected); } @@ -815,8 +736,12 @@ mod tests { assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); // inlist for unsupported data type - let expr_input = - in_list(cast(col("c6"), DataType::Float64), vec![lit(0f64)], false); + let expr_input = in_list( + cast(col("c6"), DataType::Float64), + // need more literals to avoid rewriting to binary expr + vec![lit(0f64), lit(1f64), lit(2f64), lit(3f64), lit(4f64)], + false, + ); assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); } @@ -833,10 +758,12 @@ mod tests { } fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { - let mut expr_rewriter = UnwrapCastExprRewriter { - schema: Arc::clone(schema), - }; - expr.rewrite(&mut expr_rewriter).data().unwrap() + let props = ExecutionProps::new(); + let simplifier = ExprSimplifier::new( + SimplifyContext::new(&props).with_schema(Arc::clone(schema)), + ); + + simplifier.simplify(expr).unwrap() } fn expr_test_schema() -> DFSchemaRef { @@ -862,6 +789,10 @@ mod tests { ) } + fn null_bool() -> Expr { + lit(ScalarValue::Boolean(None)) + } + fn null_i8() -> Expr { lit(ScalarValue::Int8(None)) } diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 16c61a1db6ee..06a5a41675d1 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -181,7 +181,6 @@ logical_plan after type_coercion SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE logical_plan after eliminate_nested_union SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE logical_plan after eliminate_join SAME TEXT AS ABOVE logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE @@ -200,13 +199,11 @@ logical_plan after push_down_limit SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] logical_plan after eliminate_nested_union SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE logical_plan after eliminate_join SAME TEXT AS ABOVE logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE @@ -225,7 +222,6 @@ logical_plan after push_down_limit SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after optimize_projections SAME TEXT AS ABOVE