diff --git a/datafusion/spark/src/function/utils.rs b/datafusion/spark/src/function/utils.rs index b05e4acd20ae..57505170c2fd 100644 --- a/datafusion/spark/src/function/utils.rs +++ b/datafusion/spark/src/function/utils.rs @@ -28,7 +28,20 @@ pub mod test { let expected: datafusion_common::Result> = $EXPECTED; let func = $FUNC; - let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); + let arg_fields_owned = $ARGS + .iter() + .enumerate() + .map(|(idx, arg)| { + + let nullable = match arg { + datafusion_expr::ColumnarValue::Scalar(scalar) => scalar.is_null(), + datafusion_expr::ColumnarValue::Array(a) => a.null_count() > 0, + }; + + arrow::datatypes::Field::new(format!("arg_{idx}"), arg.data_type(), nullable) + }) + .collect::>(); + let cardinality = $ARGS .iter() .fold(Option::::None, |acc, arg| match arg { @@ -43,24 +56,18 @@ pub mod test { }).collect::>(); let scalar_arguments_refs = scalar_arguments.iter().map(|arg| arg.as_ref()).collect::>(); - let nullables = $ARGS.iter().map(|arg| match arg { - datafusion_expr::ColumnarValue::Scalar(scalar) => scalar.is_null(), - datafusion_expr::ColumnarValue::Array(a) => a.null_count() > 0, - }).collect::>(); - let return_info = func.return_type_from_args(datafusion_expr::ReturnTypeArgs { - arg_types: &type_array, - scalar_arguments: &scalar_arguments_refs, - nullables: &nullables + let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { + arg_fields: &arg_fields_owned, + scalar_arguments: &scalar_arguments_refs }); match expected { Ok(expected) => { - assert_eq!(return_info.is_ok(), true); - let (return_type, _nullable) = return_info.unwrap().into_parts(); - assert_eq!(return_type, $EXPECTED_DATA_TYPE); + let return_field = return_field.unwrap(); + assert_eq!(return_field.data_type(), &$EXPECTED_DATA_TYPE); - let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type}); + let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_field: &return_field, arg_fields: arg_fields_owned.iter().collect()}); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array"); @@ -74,17 +81,17 @@ pub mod test { }; } Err(expected_error) => { - if return_info.is_err() { - match return_info { + if return_field.is_err() { + match return_field { Ok(_) => assert!(false, "expected error"), Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); } } } else { - let (return_type, _nullable) = return_info.unwrap().into_parts(); + let return_field = return_field.unwrap(); // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type}) { + match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_field: &return_field, arg_fields: arg_fields_owned.iter().collect()}) { Ok(_) => assert!(false, "expected error"), Err(error) => { assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace()));