Skip to content

Commit f99f8da

Browse files
authored
Fix CI in main (#15917)
1 parent eeea69d commit f99f8da

File tree

1 file changed

+24
-17
lines changed

1 file changed

+24
-17
lines changed

datafusion/spark/src/function/utils.rs

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,20 @@ pub mod test {
2828
let expected: datafusion_common::Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
2929
let func = $FUNC;
3030

31-
let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
31+
let arg_fields_owned = $ARGS
32+
.iter()
33+
.enumerate()
34+
.map(|(idx, arg)| {
35+
36+
let nullable = match arg {
37+
datafusion_expr::ColumnarValue::Scalar(scalar) => scalar.is_null(),
38+
datafusion_expr::ColumnarValue::Array(a) => a.null_count() > 0,
39+
};
40+
41+
arrow::datatypes::Field::new(format!("arg_{idx}"), arg.data_type(), nullable)
42+
})
43+
.collect::<Vec<_>>();
44+
3245
let cardinality = $ARGS
3346
.iter()
3447
.fold(Option::<usize>::None, |acc, arg| match arg {
@@ -43,24 +56,18 @@ pub mod test {
4356
}).collect::<Vec<_>>();
4457
let scalar_arguments_refs = scalar_arguments.iter().map(|arg| arg.as_ref()).collect::<Vec<_>>();
4558

46-
let nullables = $ARGS.iter().map(|arg| match arg {
47-
datafusion_expr::ColumnarValue::Scalar(scalar) => scalar.is_null(),
48-
datafusion_expr::ColumnarValue::Array(a) => a.null_count() > 0,
49-
}).collect::<Vec<_>>();
5059

51-
let return_info = func.return_type_from_args(datafusion_expr::ReturnTypeArgs {
52-
arg_types: &type_array,
53-
scalar_arguments: &scalar_arguments_refs,
54-
nullables: &nullables
60+
let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs {
61+
arg_fields: &arg_fields_owned,
62+
scalar_arguments: &scalar_arguments_refs
5563
});
5664

5765
match expected {
5866
Ok(expected) => {
59-
assert_eq!(return_info.is_ok(), true);
60-
let (return_type, _nullable) = return_info.unwrap().into_parts();
61-
assert_eq!(return_type, $EXPECTED_DATA_TYPE);
67+
let return_field = return_field.unwrap();
68+
assert_eq!(return_field.data_type(), &$EXPECTED_DATA_TYPE);
6269

63-
let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type});
70+
let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_field: &return_field, arg_fields: arg_fields_owned.iter().collect()});
6471
assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err());
6572

6673
let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array");
@@ -74,17 +81,17 @@ pub mod test {
7481
};
7582
}
7683
Err(expected_error) => {
77-
if return_info.is_err() {
78-
match return_info {
84+
if return_field.is_err() {
85+
match return_field {
7986
Ok(_) => assert!(false, "expected error"),
8087
Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); }
8188
}
8289
}
8390
else {
84-
let (return_type, _nullable) = return_info.unwrap().into_parts();
91+
let return_field = return_field.unwrap();
8592

8693
// invoke is expected error - cannot use .expect_err() due to Debug not being implemented
87-
match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type}) {
94+
match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_field: &return_field, arg_fields: arg_fields_owned.iter().collect()}) {
8895
Ok(_) => assert!(false, "expected error"),
8996
Err(error) => {
9097
assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace()));

0 commit comments

Comments
 (0)