Skip to content

Commit d6634e3

Browse files
committed
fix: only handle ScalarValue::Null instead of all null-ed value
1 parent 2dc0298 commit d6634e3

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

datafusion/substrait/src/logical_plan/producer.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1592,7 +1592,10 @@ pub fn from_cast(
15921592
let Cast { expr, data_type } = cast;
15931593
// since substrait Null must be typed, so if we see a cast(null, dt), we make it a typed null
15941594
if let Expr::Literal(lit) = expr.as_ref() {
1595-
if lit.is_null() {
1595+
// only the untyped(a null scalar value) null literal need this special handling
1596+
// since all other kind of nulls are already typed and can be handled by substrait
1597+
// e.g. null::<Int32Type> or null::<Utf8Type>
1598+
if matches!(lit, ScalarValue::Null) {
15961599
let lit = Literal {
15971600
nullable: true,
15981601
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
@@ -2960,5 +2963,38 @@ mod test {
29602963
} else {
29612964
panic!("Expected expression type");
29622965
}
2966+
2967+
// a typed null should not be folded
2968+
let expr = Expr::Literal(ScalarValue::Int64(None))
2969+
.cast_to(&DataType::Int32, &empty_schema)
2970+
.unwrap();
2971+
2972+
let typed_null =
2973+
to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state)
2974+
.unwrap();
2975+
2976+
if let ExprType::Expression(expr) =
2977+
typed_null.referred_expr[0].expr_type.as_ref().unwrap()
2978+
{
2979+
let cast_expr = substrait::proto::expression::Cast {
2980+
r#type: Some(to_substrait_type(&DataType::Int32, true).unwrap()),
2981+
input: Some(Box::new(Expression {
2982+
rex_type: Some(RexType::Literal(Literal {
2983+
nullable: true,
2984+
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
2985+
literal_type: Some(LiteralType::Null(
2986+
to_substrait_type(&DataType::Int64, true).unwrap(),
2987+
)),
2988+
})),
2989+
})),
2990+
failure_behavior: FailureBehavior::ThrowException as i32,
2991+
};
2992+
let expected = Expression {
2993+
rex_type: Some(RexType::Cast(Box::new(cast_expr))),
2994+
};
2995+
assert_eq!(*expr, expected);
2996+
} else {
2997+
panic!("Expected expression type");
2998+
}
29632999
}
29643000
}

0 commit comments

Comments
 (0)