@@ -1592,7 +1592,10 @@ pub fn from_cast(
1592
1592
let Cast { expr, data_type } = cast;
1593
1593
// since substrait Null must be typed, so if we see a cast(null, dt), we make it a typed null
1594
1594
if let Expr :: Literal ( lit) = expr. as_ref ( ) {
1595
- if lit. is_null ( ) {
1595
+ // only the untyped(a null scalar value) null literal need this special handling
1596
+ // since all other kind of nulls are already typed and can be handled by substrait
1597
+ // e.g. null::<Int32Type> or null::<Utf8Type>
1598
+ if matches ! ( lit, ScalarValue :: Null ) {
1596
1599
let lit = Literal {
1597
1600
nullable : true ,
1598
1601
type_variation_reference : DEFAULT_TYPE_VARIATION_REF ,
@@ -2960,5 +2963,38 @@ mod test {
2960
2963
} else {
2961
2964
panic ! ( "Expected expression type" ) ;
2962
2965
}
2966
+
2967
+ // a typed null should not be folded
2968
+ let expr = Expr :: Literal ( ScalarValue :: Int64 ( None ) )
2969
+ . cast_to ( & DataType :: Int32 , & empty_schema)
2970
+ . unwrap ( ) ;
2971
+
2972
+ let typed_null =
2973
+ to_substrait_extended_expr ( & [ ( & expr, & field) ] , & empty_schema, & state)
2974
+ . unwrap ( ) ;
2975
+
2976
+ if let ExprType :: Expression ( expr) =
2977
+ typed_null. referred_expr [ 0 ] . expr_type . as_ref ( ) . unwrap ( )
2978
+ {
2979
+ let cast_expr = substrait:: proto:: expression:: Cast {
2980
+ r#type : Some ( to_substrait_type ( & DataType :: Int32 , true ) . unwrap ( ) ) ,
2981
+ input : Some ( Box :: new ( Expression {
2982
+ rex_type : Some ( RexType :: Literal ( Literal {
2983
+ nullable : true ,
2984
+ type_variation_reference : DEFAULT_TYPE_VARIATION_REF ,
2985
+ literal_type : Some ( LiteralType :: Null (
2986
+ to_substrait_type ( & DataType :: Int64 , true ) . unwrap ( ) ,
2987
+ ) ) ,
2988
+ } ) ) ,
2989
+ } ) ) ,
2990
+ failure_behavior : FailureBehavior :: ThrowException as i32 ,
2991
+ } ;
2992
+ let expected = Expression {
2993
+ rex_type : Some ( RexType :: Cast ( Box :: new ( cast_expr) ) ) ,
2994
+ } ;
2995
+ assert_eq ! ( * expr, expected) ;
2996
+ } else {
2997
+ panic ! ( "Expected expression type" ) ;
2998
+ }
2963
2999
}
2964
3000
}
0 commit comments