diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index b5e9a555c2da..003a3ed36a60 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -195,6 +195,7 @@ impl ScalarUDF { /// See [`ScalarUDFImpl::invoke`] for more details. #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")] pub fn invoke(&self, args: &[ColumnarValue]) -> Result { + #[allow(deprecated)] self.inner.invoke(args) } @@ -218,6 +219,7 @@ impl ScalarUDF { /// See [`ScalarUDFImpl::invoke_no_args`] for more details. #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")] pub fn invoke_no_args(&self, number_rows: usize) -> Result { + #[allow(deprecated)] self.inner.invoke_no_args(number_rows) } @@ -226,6 +228,7 @@ impl ScalarUDF { #[deprecated(since = "42.0.0", note = "Use `invoke_batch` instead")] pub fn fun(&self) -> ScalarFunctionImplementation { let captured = Arc::clone(&self.inner); + #[allow(deprecated)] Arc::new(move |args| captured.invoke(args)) } @@ -480,6 +483,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// to arrays, which will likely be simpler code, but be slower. /// /// [invoke_no_args]: ScalarUDFImpl::invoke_no_args + #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")] fn invoke(&self, _args: &[ColumnarValue]) -> Result { not_impl_err!( "Function {} does not implement invoke but called", @@ -489,19 +493,40 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// Invoke the function with `args` and the number of rows, /// returning the appropriate result. + /// + /// The function will be invoked with the slice of [`ColumnarValue`] + /// (either scalar or array). + /// + /// # Performance + /// + /// For the best performance, the implementations should handle the common case + /// when one or more of their arguments are constant values (aka + /// [`ColumnarValue::Scalar`]). + /// + /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments + /// to arrays, which will likely be simpler code, but be slower. fn invoke_batch( &self, args: &[ColumnarValue], number_rows: usize, ) -> Result { match args.is_empty() { - true => self.invoke_no_args(number_rows), - false => self.invoke(args), + true => + { + #[allow(deprecated)] + self.invoke_no_args(number_rows) + } + false => + { + #[allow(deprecated)] + self.invoke(args) + } } } /// Invoke the function without `args`, instead the number of rows are provided, /// returning the appropriate result. + #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")] fn invoke_no_args(&self, _number_rows: usize) -> Result { not_impl_err!( "Function {} does not implement invoke_no_args but called", @@ -725,10 +750,12 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { } fn invoke(&self, args: &[ColumnarValue]) -> Result { + #[allow(deprecated)] self.inner.invoke(args) } fn invoke_no_args(&self, number_rows: usize) -> Result { + #[allow(deprecated)] self.inner.invoke_no_args(number_rows) } diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index a721836bb68c..5df5d9c7dee2 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -29,7 +29,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("random_1M_rows_batch_8192", |b| { b.iter(|| { for _ in 0..iterations { - black_box(random_func.invoke_no_args(8192).unwrap()); + black_box(random_func.invoke_batch(&[], 8192).unwrap()); } }) }); @@ -39,7 +39,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("random_1M_rows_batch_128", |b| { b.iter(|| { for _ in 0..iterations_128 { - black_box(random_func.invoke_no_args(128).unwrap()); + black_box(random_func.invoke_batch(&[], 128).unwrap()); } }) }); diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index e8d065df8633..065201e1caa3 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -491,6 +491,7 @@ mod tests { use chrono::TimeDelta; #[test] + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch fn test_date_bin() { let res = DateBinFunc::new().invoke(&[ ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { @@ -781,6 +782,7 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::() .with_timezone_opt(tz_opt.clone()); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = DateBinFunc::new() .invoke(&[ ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)), diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index 4808f020e0ca..f8abef601f70 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -724,6 +724,7 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::() .with_timezone_opt(tz_opt.clone()); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = DateTruncFunc::new() .invoke(&[ ColumnarValue::Scalar(ScalarValue::from("day")), @@ -882,6 +883,7 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::() .with_timezone_opt(tz_opt.clone()); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = DateTruncFunc::new() .invoke(&[ ColumnarValue::Scalar(ScalarValue::from("hour")), diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index ed9858106c52..29b2f29b14c2 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -162,6 +162,7 @@ mod test { fn test_without_timezone() { let args = [ColumnarValue::Scalar(Int64(Some(1729900800)))]; + #[allow(deprecated)] // TODO use invoke_batch let result = FromUnixtimeFunc::new().invoke(&args).unwrap(); match result { @@ -181,6 +182,7 @@ mod test { ))), ]; + #[allow(deprecated)] // TODO use invoke_batch let result = FromUnixtimeFunc::new().invoke(&args).unwrap(); match result { diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index c8ef349dfbeb..6b246cb088a2 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -234,6 +234,7 @@ mod tests { #[test] fn test_make_date() { + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let res = MakeDateFunc::new() .invoke(&[ ColumnarValue::Scalar(ScalarValue::Int32(Some(2024))), @@ -248,6 +249,7 @@ mod tests { panic!("Expected a scalar value") } + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let res = MakeDateFunc::new() .invoke(&[ ColumnarValue::Scalar(ScalarValue::Int64(Some(2024))), @@ -262,6 +264,7 @@ mod tests { panic!("Expected a scalar value") } + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let res = MakeDateFunc::new() .invoke(&[ ColumnarValue::Scalar(ScalarValue::Utf8(Some("2024".to_string()))), @@ -279,6 +282,7 @@ mod tests { let years = Arc::new((2021..2025).map(Some).collect::()); let months = Arc::new((1..5).map(Some).collect::()); let days = Arc::new((11..15).map(Some).collect::()); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let res = MakeDateFunc::new() .invoke(&[ ColumnarValue::Array(years), @@ -304,6 +308,7 @@ mod tests { // // invalid number of arguments + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let res = MakeDateFunc::new() .invoke(&[ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))]); assert_eq!( @@ -312,6 +317,7 @@ mod tests { ); // invalid type + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let res = MakeDateFunc::new().invoke(&[ ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), @@ -323,6 +329,7 @@ mod tests { ); // overflow of month + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let res = MakeDateFunc::new().invoke(&[ ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), ColumnarValue::Scalar(ScalarValue::UInt64(Some(u64::MAX))), @@ -334,6 +341,7 @@ mod tests { ); // overflow of day + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let res = MakeDateFunc::new().invoke(&[ ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index f0c4a02c1523..ef5d6a4f6990 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -384,6 +384,7 @@ mod tests { ]; for (value, format, expected) in scalar_data { + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ToCharFunc::new() .invoke(&[ColumnarValue::Scalar(value), ColumnarValue::Scalar(format)]) .expect("that to_char parsed values without error"); @@ -458,6 +459,7 @@ mod tests { ]; for (value, format, expected) in scalar_array_data { + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ToCharFunc::new() .invoke(&[ ColumnarValue::Scalar(value), @@ -583,6 +585,7 @@ mod tests { ]; for (value, format, expected) in array_scalar_data { + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ToCharFunc::new() .invoke(&[ ColumnarValue::Array(value as ArrayRef), @@ -599,6 +602,7 @@ mod tests { } for (value, format, expected) in array_array_data { + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ToCharFunc::new() .invoke(&[ ColumnarValue::Array(value), @@ -619,6 +623,7 @@ mod tests { // // invalid number of arguments + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ToCharFunc::new() .invoke(&[ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))]); assert_eq!( @@ -627,6 +632,7 @@ mod tests { ); // invalid type + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ToCharFunc::new().invoke(&[ ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index 82e189698c5e..8f72100416e8 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -213,6 +213,7 @@ mod tests { } fn test_scalar(sv: ScalarValue, tc: &TestCase) { + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let to_date_result = ToDateFunc::new().invoke(&[ColumnarValue::Scalar(sv)]); match to_date_result { @@ -233,6 +234,7 @@ mod tests { A: From> + Array + 'static, { let date_array = A::from(vec![tc.date_str]); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let to_date_result = ToDateFunc::new().invoke(&[ColumnarValue::Array(Arc::new(date_array))]); @@ -323,6 +325,7 @@ mod tests { fn test_scalar(sv: ScalarValue, tc: &TestCase) { let format_scalar = ScalarValue::Utf8(Some(tc.format_str.to_string())); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let to_date_result = ToDateFunc::new().invoke(&[ ColumnarValue::Scalar(sv), ColumnarValue::Scalar(format_scalar), @@ -347,6 +350,7 @@ mod tests { let date_array = A::from(vec![tc.formatted_date]); let format_array = A::from(vec![tc.format_str]); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let to_date_result = ToDateFunc::new().invoke(&[ ColumnarValue::Array(Arc::new(date_array)), ColumnarValue::Array(Arc::new(format_array)), @@ -382,6 +386,7 @@ mod tests { let format1_scalar = ScalarValue::Utf8(Some("%Y-%m-%d".into())); let format2_scalar = ScalarValue::Utf8(Some("%Y/%m/%d".into())); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let to_date_result = ToDateFunc::new().invoke(&[ ColumnarValue::Scalar(formatted_date_scalar), ColumnarValue::Scalar(format1_scalar), @@ -410,6 +415,7 @@ mod tests { for date_str in test_cases { let formatted_date_scalar = ScalarValue::Utf8(Some(date_str.into())); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let to_date_result = ToDateFunc::new().invoke(&[ColumnarValue::Scalar(formatted_date_scalar)]); @@ -428,6 +434,7 @@ mod tests { let date_str = "20241231"; let date_scalar = ScalarValue::Utf8(Some(date_str.into())); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let to_date_result = ToDateFunc::new().invoke(&[ColumnarValue::Scalar(date_scalar)]); @@ -449,6 +456,7 @@ mod tests { let date_str = "202412311"; let date_scalar = ScalarValue::Utf8(Some(date_str.into())); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let to_date_result = ToDateFunc::new().invoke(&[ColumnarValue::Scalar(date_scalar)]); diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 376cb6f5f2f8..fef1eb9a60c8 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -558,7 +558,7 @@ mod tests { fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { let res = ToLocalTimeFunc::new() - .invoke(&[ColumnarValue::Scalar(input)]) + .invoke_batch(&[ColumnarValue::Scalar(input)], 1) .unwrap(); match res { ColumnarValue::Scalar(res) => { @@ -616,8 +616,9 @@ mod tests { .iter() .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::(); + let batch_size = input.len(); let result = ToLocalTimeFunc::new() - .invoke(&[ColumnarValue::Array(Arc::new(input))]) + .invoke_batch(&[ColumnarValue::Array(Arc::new(input))], batch_size) .unwrap(); if let ColumnarValue::Array(result) = result { assert_eq!( diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 60482ee3c74a..f15fad701c55 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -636,7 +636,6 @@ mod tests { use arrow::array::{ArrayRef, Int64Array, StringBuilder}; use arrow::datatypes::TimeUnit; use chrono::Utc; - use datafusion_common::{assert_contains, DataFusionError, ScalarValue}; use datafusion_expr::ScalarFunctionImplementation; @@ -1011,7 +1010,7 @@ mod tests { assert!(matches!(rt, Timestamp(_, Some(_)))); let res = udf - .invoke(&[array.clone()]) + .invoke_batch(&[array.clone()], 1) .expect("that to_timestamp parsed values without error"); let array = match res { ColumnarValue::Array(res) => res, @@ -1054,7 +1053,7 @@ mod tests { assert!(matches!(rt, Timestamp(_, None))); let res = udf - .invoke(&[array.clone()]) + .invoke_batch(&[array.clone()], 1) .expect("that to_timestamp parsed values without error"); let array = match res { ColumnarValue::Array(res) => res, diff --git a/datafusion/functions/src/datetime/to_unixtime.rs b/datafusion/functions/src/datetime/to_unixtime.rs index 10f0f87a4ab1..dd90ce6a6c96 100644 --- a/datafusion/functions/src/datetime/to_unixtime.rs +++ b/datafusion/functions/src/datetime/to_unixtime.rs @@ -15,17 +15,16 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{DataType, TimeUnit}; -use std::any::Any; -use std::sync::OnceLock; - use super::to_timestamp::ToTimestampSecondsFunc; use crate::datetime::common::*; +use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::{exec_err, Result}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct ToUnixtimeFunc { @@ -63,7 +62,11 @@ impl ScalarUDFImpl for ToUnixtimeFunc { Ok(DataType::Int64) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + batch_size: usize, + ) -> Result { if args.is_empty() { return exec_err!("to_unixtime function requires 1 or more arguments, got 0"); } @@ -81,7 +84,7 @@ impl ScalarUDFImpl for ToUnixtimeFunc { .cast_to(&DataType::Timestamp(TimeUnit::Second, None), None)? .cast_to(&DataType::Int64, None), DataType::Utf8 => ToTimestampSecondsFunc::new() - .invoke(args)? + .invoke_batch(args, batch_size)? .cast_to(&DataType::Int64, None), other => { exec_err!("Unsupported data type {:?} for function to_unixtime", other) diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 9d2e1be3df9d..9110f9f532d8 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -278,7 +278,7 @@ mod tests { ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), ]; - let _ = LogFunc::new().invoke(&args); + let _ = LogFunc::new().invoke_batch(&args, 4); } #[test] @@ -287,7 +287,7 @@ mod tests { ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num ]; - let result = LogFunc::new().invoke(&args); + let result = LogFunc::new().invoke_batch(&args, 1); result.expect_err("expected error"); } @@ -298,7 +298,7 @@ mod tests { ]; let result = LogFunc::new() - .invoke(&args) + .invoke_batch(&args, 1) .expect("failed to initialize function log"); match result { @@ -322,7 +322,7 @@ mod tests { ]; let result = LogFunc::new() - .invoke(&args) + .invoke_batch(&args, 1) .expect("failed to initialize function log"); match result { @@ -347,7 +347,7 @@ mod tests { ]; let result = LogFunc::new() - .invoke(&args) + .invoke_batch(&args, 1) .expect("failed to initialize function log"); match result { @@ -372,7 +372,7 @@ mod tests { ]; let result = LogFunc::new() - .invoke(&args) + .invoke_batch(&args, 1) .expect("failed to initialize function log"); match result { @@ -398,7 +398,7 @@ mod tests { ]; let result = LogFunc::new() - .invoke(&args) + .invoke_batch(&args, 4) .expect("failed to initialize function log"); match result { @@ -427,7 +427,7 @@ mod tests { ]; let result = LogFunc::new() - .invoke(&args) + .invoke_batch(&args, 4) .expect("failed to initialize function log"); match result { @@ -457,7 +457,7 @@ mod tests { ]; let result = LogFunc::new() - .invoke(&args) + .invoke_batch(&args, 4) .expect("failed to initialize function log"); match result { @@ -487,7 +487,7 @@ mod tests { ]; let result = LogFunc::new() - .invoke(&args) + .invoke_batch(&args, 4) .expect("failed to initialize function log"); match result { diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 9bb6006d55b9..a24c613f5259 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -207,7 +207,7 @@ mod tests { ]; let result = PowerFunc::new() - .invoke(&args) + .invoke_batch(&args, 4) .expect("failed to initialize function power"); match result { @@ -234,7 +234,7 @@ mod tests { ]; let result = PowerFunc::new() - .invoke(&args) + .invoke_batch(&args, 4) .expect("failed to initialize function power"); match result { diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index ac881eb42f26..7f21297712c7 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -155,7 +155,7 @@ mod test { #[test] fn test_signum_f32() { - let args = [ColumnarValue::Array(Arc::new(Float32Array::from(vec![ + let array = Arc::new(Float32Array::from(vec![ -1.0, -0.0, 0.0, @@ -165,10 +165,10 @@ mod test { f32::NAN, f32::INFINITY, f32::NEG_INFINITY, - ])))]; - + ])); + let batch_size = array.len(); let result = SignumFunc::new() - .invoke(&args) + .invoke_batch(&[ColumnarValue::Array(array)], batch_size) .expect("failed to initialize function signum"); match result { @@ -195,7 +195,7 @@ mod test { #[test] fn test_signum_f64() { - let args = [ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + let array = Arc::new(Float64Array::from(vec![ -1.0, -0.0, 0.0, @@ -205,10 +205,10 @@ mod test { f64::NAN, f64::INFINITY, f64::NEG_INFINITY, - ])))]; - + ])); + let batch_size = array.len(); let result = SignumFunc::new() - .invoke(&args) + .invoke_batch(&[ColumnarValue::Array(array)], batch_size) .expect("failed to initialize function signum"); match result { diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 7f7896ecd923..7c4313effffb 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -651,8 +651,10 @@ mod tests { let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new() - .invoke(&[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)]); + let re = RegexpCountFunc::new().invoke_batch( + &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], + 1, + ); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -664,8 +666,10 @@ mod tests { let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - let re = RegexpCountFunc::new() - .invoke(&[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)]); + let re = RegexpCountFunc::new().invoke_batch( + &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], + 1, + ); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -677,8 +681,10 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - let re = RegexpCountFunc::new() - .invoke(&[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)]); + let re = RegexpCountFunc::new().invoke_batch( + &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], + 1, + ); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -701,11 +707,14 @@ mod tests { let start_sv = ScalarValue::Int64(Some(start)); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke(&[ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ]); + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ], + 1, + ); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -717,11 +726,14 @@ mod tests { let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke(&[ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ]); + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ], + 1, + ); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -733,11 +745,14 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke(&[ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv), - ]); + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv), + ], + 1, + ); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -762,12 +777,15 @@ mod tests { let flags_sv = ScalarValue::Utf8(Some(flags.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke(&[ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ]); + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ], + 1, + ); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -780,12 +798,15 @@ mod tests { let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string())); - let re = RegexpCountFunc::new().invoke(&[ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ]); + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ], + 1, + ); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -798,12 +819,15 @@ mod tests { let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); let flags_sv = ScalarValue::Utf8View(Some(flags.to_string())); - let re = RegexpCountFunc::new().invoke(&[ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv), - ColumnarValue::Scalar(flags_sv.clone()), - ]); + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv), + ColumnarValue::Scalar(flags_sv.clone()), + ], + 1, + ); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -877,12 +901,15 @@ mod tests { let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| f.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke(&[ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ]); + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ], + 1, + ); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -895,12 +922,15 @@ mod tests { let regex_sv = ScalarValue::LargeUtf8(regex.get(pos).map(|s| s.to_string())); let flags_sv = ScalarValue::LargeUtf8(flags.get(pos).map(|f| f.to_string())); - let re = RegexpCountFunc::new().invoke(&[ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ]); + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ], + 1, + ); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -913,12 +943,15 @@ mod tests { let regex_sv = ScalarValue::Utf8View(regex.get(pos).map(|s| s.to_string())); let flags_sv = ScalarValue::Utf8View(flags.get(pos).map(|f| f.to_string())); - let re = RegexpCountFunc::new().invoke(&[ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv), - ColumnarValue::Scalar(flags_sv.clone()), - ]); + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv), + ColumnarValue::Scalar(flags_sv.clone()), + ], + 1, + ); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index a4218c39e7b2..e3834b291896 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -408,6 +408,7 @@ mod tests { ]))); let args = &[c0, c1, c2]; + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ConcatFunc::new().invoke(args)?; let expected = Arc::new(StringArray::from(vec!["foo,x", "bar,", "baz,z"])) as ArrayRef; diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 8d966f495663..811939c1699b 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -467,6 +467,7 @@ mod tests { ]))); let args = &[c0, c1, c2]; + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ConcatWsFunc::new().invoke(args)?; let expected = Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; @@ -492,6 +493,7 @@ mod tests { ]))); let args = &[c0, c1, c2]; + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ConcatWsFunc::new().invoke(args)?; let expected = Arc::new(StringArray::from(vec![Some("foo,x"), None, Some("baz+z")])) diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index d0e63bb0f353..0c665a139152 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -145,6 +145,7 @@ mod test { Some("yyy?()"), ]))); let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string()))); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let actual = udf.invoke(&[array, scalar]).unwrap(); let expect = ColumnarValue::Array(Arc::new(BooleanArray::from(vec![ Some(true), diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index b07189a832dc..ef56120c582a 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -105,6 +105,7 @@ mod tests { fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = LowerFunc::new(); let args = vec![ColumnarValue::Array(input)]; + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = match func.invoke(&args)? { ColumnarValue::Array(result) => result, _ => unreachable!(), diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 042c26b2e3da..68a9d60a1663 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -105,6 +105,7 @@ mod tests { fn to_upper(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = UpperFunc::new(); let args = vec![ColumnarValue::Array(input)]; + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = match func.invoke(&args)? { ColumnarValue::Array(result) => result, _ => unreachable!(), diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 4d6574d2bd6c..87180cb77de7 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -134,6 +134,13 @@ pub mod test { let func = $FUNC; let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); + let cardinality = $ARGS + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }) + .unwrap_or(1); let return_type = func.return_type(&type_array); match expected { @@ -141,17 +148,10 @@ pub mod test { assert_eq!(return_type.is_ok(), true); assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE); - let result = func.invoke($ARGS); + let result = func.invoke_batch($ARGS, cardinality); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); - let len = $ARGS - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - let inferred_length = len.unwrap_or(1); - let result = result.unwrap().clone().into_array(inferred_length).expect("Failed to convert to array"); + let result = result.unwrap().clone().into_array(cardinality).expect("Failed to convert to array"); let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type"); // value is correct @@ -169,7 +169,7 @@ pub mod test { } else { // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke($ARGS) { + match func.invoke_batch($ARGS, cardinality) { Ok(_) => assert!(false, "expected error"), Err(error) => { assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace()));