Skip to content

Commit d6ab343

Browse files
viirya and alamb authored
Add helper function for processing scalar function input (#8962)
* Add helper function for scalar function * Update datafusion/physical-expr/src/functions.rs Co-authored-by: Andrew Lamb <[email protected]> * Fix * Fix --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent d81c82d commit d6ab343

File tree

4 files changed

+59
-43
lines changed

4 files changed

+59
-43
lines changed

datafusion-examples/examples/simple_udf.rs

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use datafusion::error::Result;
2828
use datafusion::prelude::*;
2929
use datafusion_common::cast::as_float64_array;
3030
use datafusion_expr::ColumnarValue;
31+
use datafusion_physical_expr::functions::columnar_values_to_array;
3132
use std::sync::Arc;
3233

3334
/// create local execution context with an in-memory table:
@@ -70,22 +71,11 @@ async fn main() -> Result<()> {
7071
// this is guaranteed by DataFusion based on the function's signature.
7172
assert_eq!(args.len(), 2);
7273

73-
// Try to obtain row number
74-
let len = args
75-
.iter()
76-
.fold(Option::<usize>::None, |acc, arg| match arg {
77-
ColumnarValue::Scalar(_) => acc,
78-
ColumnarValue::Array(a) => Some(a.len()),
79-
});
80-
81-
let inferred_length = len.unwrap_or(1);
82-
83-
let arg0 = args[0].clone().into_array(inferred_length)?;
84-
let arg1 = args[1].clone().into_array(inferred_length)?;
74+
let args = columnar_values_to_array(args)?;
8575

8676
// 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics!
87-
let base = as_float64_array(&arg0).expect("cast failed");
88-
let exponent = as_float64_array(&arg1).expect("cast failed");
77+
let base = as_float64_array(&args[0]).expect("cast failed");
78+
let exponent = as_float64_array(&args[1]).expect("cast failed");
8979

9080
// this is guaranteed by DataFusion. We place it just to make it obvious.
9181
assert_eq!(exponent.len(), base.len());

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,6 +1376,7 @@ mod tests {
13761376
use datafusion_physical_expr::execution_props::ExecutionProps;
13771377

13781378
use chrono::{DateTime, TimeZone, Utc};
1379+
use datafusion_physical_expr::functions::columnar_values_to_array;
13791380

13801381
// ------------------------------
13811382
// --- ExprSimplifier tests -----
@@ -1489,30 +1490,10 @@ mod tests {
14891490
let return_type = Arc::new(DataType::Int32);
14901491

14911492
let fun = Arc::new(|args: &[ColumnarValue]| {
1492-
let len = args
1493-
.iter()
1494-
.fold(Option::<usize>::None, |acc, arg| match arg {
1495-
ColumnarValue::Scalar(_) => acc,
1496-
ColumnarValue::Array(a) => Some(a.len()),
1497-
});
1498-
1499-
let inferred_length = len.unwrap_or(1);
1500-
1501-
let arg0 = match &args[0] {
1502-
ColumnarValue::Array(array) => array.clone(),
1503-
ColumnarValue::Scalar(scalar) => {
1504-
scalar.to_array_of_size(inferred_length).unwrap()
1505-
}
1506-
};
1507-
let arg1 = match &args[1] {
1508-
ColumnarValue::Array(array) => array.clone(),
1509-
ColumnarValue::Scalar(scalar) => {
1510-
scalar.to_array_of_size(inferred_length).unwrap()
1511-
}
1512-
};
1493+
let args = columnar_values_to_array(args)?;
15131494

1514-
let arg0 = as_int32_array(&arg0)?;
1515-
let arg1 = as_int32_array(&arg1)?;
1495+
let arg0 = as_int32_array(&args[0])?;
1496+
let arg1 = as_int32_array(&args[1])?;
15161497

15171498
// 2. perform the computation
15181499
let array = arg0

datafusion/physical-expr/src/functions.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ use arrow::{
4242
compute::kernels::length::{bit_length, length},
4343
datatypes::{DataType, Int32Type, Int64Type, Schema},
4444
};
45+
use arrow_array::Array;
4546
use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
4647
pub use datafusion_expr::FuncMonotonicity;
4748
use datafusion_expr::{
@@ -191,6 +192,51 @@ pub(crate) enum Hint {
191192
AcceptsSingular,
192193
}
193194

195+
/// A helper function used to infer the length of arguments of Scalar functions and convert
196+
/// [`ColumnarValue`]s to [`ArrayRef`]s with the inferred length. Note that this function
197+
/// only works for functions that accept either that all arguments are scalars or all arguments
198+
/// are arrays with same length. Otherwise, it will return an error.
199+
pub fn columnar_values_to_array(args: &[ColumnarValue]) -> Result<Vec<ArrayRef>> {
200+
if args.is_empty() {
201+
return Ok(vec![]);
202+
}
203+
204+
let len = args
205+
.iter()
206+
.fold(Option::<usize>::None, |acc, arg| match arg {
207+
ColumnarValue::Scalar(_) if acc.is_none() => Some(1),
208+
ColumnarValue::Scalar(_) => {
209+
if let Some(1) = acc {
210+
acc
211+
} else {
212+
None
213+
}
214+
}
215+
ColumnarValue::Array(a) => {
216+
if let Some(l) = acc {
217+
if l == a.len() {
218+
acc
219+
} else {
220+
None
221+
}
222+
} else {
223+
Some(a.len())
224+
}
225+
}
226+
});
227+
228+
let inferred_length = len.ok_or(DataFusionError::Internal(
229+
"Arguments has mixed length".to_string(),
230+
))?;
231+
232+
let args = args
233+
.iter()
234+
.map(|arg| arg.clone().into_array(inferred_length))
235+
.collect::<Result<Vec<_>>>()?;
236+
237+
Ok(args)
238+
}
239+
194240
/// Decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function
195241
/// and vice-versa after evaluation.
196242
/// Note that this function makes a scalar function with no arguments or all scalar inputs return a scalar.

docs/source/library-user-guide/adding-udfs.md

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@ use std::sync::Arc;
4141

4242
use datafusion::arrow::array::{ArrayRef, Int64Array};
4343
use datafusion::common::Result;
44-
4544
use datafusion::common::cast::as_int64_array;
45+
use datafusion::physical_plan::functions::columnar_values_to_array;
4646

47-
pub fn add_one(args: &[ArrayRef]) -> Result<ArrayRef> {
47+
pub fn add_one(args: &[ColumnarValue]) -> Result<ArrayRef> {
4848
// Error handling omitted for brevity
49-
49+
let args = columnar_values_to_array(args)?;
5050
let i64s = as_int64_array(&args[0])?;
5151

5252
let new_array = i64s
@@ -82,7 +82,6 @@ There is a lower level API with more functionality but is more complex, that is
8282

8383
```rust
8484
use datafusion::logical_expr::{Volatility, create_udf};
85-
use datafusion::physical_plan::functions::make_scalar_function;
8685
use datafusion::arrow::datatypes::DataType;
8786
use std::sync::Arc;
8887

@@ -91,13 +90,13 @@ let udf = create_udf(
9190
vec![DataType::Int64],
9291
Arc::new(DataType::Int64),
9392
Volatility::Immutable,
94-
make_scalar_function(add_one),
93+
Arc::new(add_one),
9594
);
9695
```
9796

9897
[`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html
9998
[`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html
100-
[`make_scalar_function`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.make_scalar_function.html
99+
[`columnar_values_to_array`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.columnar_values_to_array.html
101100
[`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
102101

103102
A few things to note:

0 commit comments

Comments
 (0)