Skip to content

Commit 7e7dae7

Browse files
committed
Introduce new trait based ScalarUDF API
1 parent a1e959d commit 7e7dae7

File tree

9 files changed

+565
-89
lines changed

9 files changed

+565
-89
lines changed

datafusion-examples/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@ cargo run --example csv_sql
5858
- [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3
5959
- [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP
6060
- [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass
61+
- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF)
62+
- [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF)
6163
- [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF)
62-
- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined (scalar) Function (UDF)
6364
- [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF)
6465

6566
## Distributed
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion::{
19+
arrow::{
20+
array::{ArrayRef, Float32Array, Float64Array},
21+
datatypes::DataType,
22+
record_batch::RecordBatch,
23+
},
24+
logical_expr::Volatility,
25+
};
26+
use std::any::Any;
27+
28+
use arrow::array::{new_null_array, Array, AsArray};
29+
use arrow::compute;
30+
use arrow::datatypes::Float64Type;
31+
use datafusion::error::Result;
32+
use datafusion::prelude::*;
33+
use datafusion_common::{internal_err, ScalarValue};
34+
use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature};
35+
use std::sync::Arc;
36+
37+
/// This example shows how to use the full ScalarUDFImpl API to implement a user
38+
/// defined function. As in the `simple_udf.rs` example, this struct implements
39+
/// a function that takes two arguments and returns the first argument raised to
40+
/// the power of the second argument `a^b`.
41+
///
42+
/// To do so, we must implement the `ScalarUDFImpl` trait.
43+
struct PowUdf {
44+
signature: Signature,
45+
aliases: Vec<String>,
46+
}
47+
48+
impl PowUdf {
49+
/// Create a new instance of the `PowUdf` struct
50+
fn new() -> Self {
51+
Self {
52+
signature: Signature::exact(
53+
// this function will always take two arguments of type f64
54+
vec![DataType::Float64, DataType::Float64],
55+
// this function is deterministic and will always return the same
56+
// result for the same input
57+
Volatility::Immutable,
58+
),
59+
// we will also add an alias of "my_pow"
60+
aliases: vec!["my_pow".to_string()],
61+
}
62+
}
63+
}
64+
65+
impl ScalarUDFImpl for PowUdf {
66+
/// We implement as_any so that we can downcast the ScalarUDFImpl trait object
67+
fn as_any(&self) -> &dyn Any {
68+
self
69+
}
70+
71+
/// Return the name of this function
72+
fn name(&self) -> &str {
73+
"pow"
74+
}
75+
76+
/// Return the "signature" of this function -- namely what types of arguments it will take
77+
fn signature(&self) -> &Signature {
78+
&self.signature
79+
}
80+
81+
/// What is the type of value that will be returned by this function? In
82+
/// this case it will always be a constant value, but it could also be a
83+
/// function of the input types.
84+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
85+
Ok(DataType::Float64)
86+
}
87+
88+
/// This is the function that actually calculates the results.
89+
///
90+
/// This is the same way that functions built into DataFusion are invoked,
91+
/// which permits important special cases when one or both of the arguments
92+
/// are single values (constants). For example `pow(a, 2)`
93+
///
94+
/// However, it also means the implementation is more complex than when
95+
/// using `create_udf`.
96+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
97+
// DataFusion has arranged for the correct inputs to be passed to this
98+
// function, but we check again to make sure
99+
assert_eq!(args.len(), 2);
100+
let (base, exp) = (&args[0], &args[1]);
101+
assert_eq!(base.data_type(), DataType::Float64);
102+
assert_eq!(exp.data_type(), DataType::Float64);
103+
104+
match (base, exp) {
105+
// For demonstration purposes we also implement the scalar / scalar
106+
// case here, but it is not typically required for high performance.
107+
//
108+
// For performance it is most important to optimize cases where at
109+
// least one argument is an array. If all arguments are constants,
110+
// the DataFusion expression simplification logic will often invoke
111+
// this path once during planning, and simply use the result during
112+
// execution.
113+
(
114+
ColumnarValue::Scalar(ScalarValue::Float64(base)),
115+
ColumnarValue::Scalar(ScalarValue::Float64(exp)),
116+
) => {
117+
// compute the output. Note DataFusion treats `None` as NULL.
118+
let res = match (base, exp) {
119+
(Some(base), Some(exp)) => Some(base.powf(*exp)),
120+
// one or both arguments were NULL
121+
_ => None,
122+
};
123+
Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
124+
}
125+
// special case if the exponent is a constant
126+
(
127+
ColumnarValue::Array(base_array),
128+
ColumnarValue::Scalar(ScalarValue::Float64(exp)),
129+
) => {
130+
let result_array = match exp {
131+
// a ^ null = null
132+
None => new_null_array(base_array.data_type(), base_array.len()),
133+
// a ^ exp
134+
Some(exp) => {
135+
// DataFusion has ensured both arguments are Float64:
136+
let base_array = base_array.as_primitive::<Float64Type>();
137+
// calculate the result for every row. The `unary` very
138+
// fast, "vectorized" code and handles things like null
139+
// values for us.
140+
let res: Float64Array =
141+
compute::unary(base_array, |base| base.powf(*exp));
142+
Arc::new(res)
143+
}
144+
};
145+
Ok(ColumnarValue::Array(result_array))
146+
}
147+
148+
// special case if the base is a constant (note this code is quite
149+
// similar to the previous case, so we omit comments)
150+
(
151+
ColumnarValue::Scalar(ScalarValue::Float64(base)),
152+
ColumnarValue::Array(exp_array),
153+
) => {
154+
let res = match base {
155+
None => new_null_array(exp_array.data_type(), exp_array.len()),
156+
Some(base) => {
157+
let exp_array = exp_array.as_primitive::<Float64Type>();
158+
let res: Float64Array =
159+
compute::unary(exp_array, |exp| base.powf(exp));
160+
Arc::new(res)
161+
}
162+
};
163+
Ok(ColumnarValue::Array(res))
164+
}
165+
// Both arguments are arrays s we have to perform the calculation for every row
166+
(ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => {
167+
let res: Float64Array = compute::binary(
168+
base_array.as_primitive::<Float64Type>(),
169+
exp_array.as_primitive::<Float64Type>(),
170+
|base, exp| base.powf(exp),
171+
)?;
172+
Ok(ColumnarValue::Array(Arc::new(res)))
173+
}
174+
// if the types were not float, it is a bug in DataFusion
175+
_ => {
176+
use datafusion_common::DataFusionError;
177+
internal_err!("Invalid argument types to pow function")
178+
}
179+
}
180+
}
181+
182+
/// We will also add an alias of "my_pow"
183+
fn aliases(&self) -> &[String] {
184+
&self.aliases
185+
}
186+
}
187+
188+
/// In this example we register `PowUdf` as a user defined function
189+
/// and invoke it via the DataFrame API and SQL
190+
#[tokio::main]
191+
async fn main() -> Result<()> {
192+
let ctx = create_context()?;
193+
194+
// create the UDF
195+
let pow = ScalarUDF::from(PowUdf::new());
196+
197+
// register the UDF with the context so it can be invoked by name and from SQL
198+
ctx.register_udf(pow.clone());
199+
200+
// get a DataFrame from the context for scanning the "t" table
201+
let df = ctx.table("t").await?;
202+
203+
// Call pow(a, 10) using the DataFrame API
204+
let df = df.select(vec![pow.call(vec![col("a"), lit(10i32)])])?;
205+
206+
// note that the second argument is passed as an i32, not f64. DataFusion
207+
// automatically coerces the types to match the UDF's defined signature.
208+
209+
// print the results
210+
df.show().await?;
211+
212+
// You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL
213+
let sql_df = ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t").await?;
214+
sql_df.show().await?;
215+
216+
Ok(())
217+
}
218+
219+
/// create local execution context with an in-memory table:
220+
///
221+
/// ```text
222+
/// +-----+-----+
223+
/// | a | b |
224+
/// +-----+-----+
225+
/// | 2.1 | 1.0 |
226+
/// | 3.1 | 2.0 |
227+
/// | 4.1 | 3.0 |
228+
/// | 5.1 | 4.0 |
229+
/// +-----+-----+
230+
/// ```
231+
fn create_context() -> Result<SessionContext> {
232+
// define data.
233+
let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1]));
234+
let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]));
235+
let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?;
236+
237+
// declare a new context. In spark API, this corresponds to a new spark SQLsession
238+
let ctx = SessionContext::new();
239+
240+
// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
241+
ctx.register_batch("t", batch)?;
242+
Ok(ctx)
243+
}

datafusion-examples/examples/simple_udf.rs

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,23 +29,23 @@ use datafusion::{error::Result, physical_plan::functions::make_scalar_function};
2929
use datafusion_common::cast::as_float64_array;
3030
use std::sync::Arc;
3131

32-
// create local execution context with an in-memory table
32+
/// create local execution context with an in-memory table:
33+
///
34+
/// ```text
35+
/// +-----+-----+
36+
/// | a | b |
37+
/// +-----+-----+
38+
/// | 2.1 | 1.0 |
39+
/// | 3.1 | 2.0 |
40+
/// | 4.1 | 3.0 |
41+
/// | 5.1 | 4.0 |
42+
/// +-----+-----+
43+
/// ```
3344
fn create_context() -> Result<SessionContext> {
34-
use datafusion::arrow::datatypes::{Field, Schema};
35-
// define a schema.
36-
let schema = Arc::new(Schema::new(vec![
37-
Field::new("a", DataType::Float32, false),
38-
Field::new("b", DataType::Float64, false),
39-
]));
40-
4145
// define data.
42-
let batch = RecordBatch::try_new(
43-
schema,
44-
vec![
45-
Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])),
46-
Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
47-
],
48-
)?;
46+
let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1]));
47+
let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]));
48+
let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?;
4949

5050
// declare a new context. In spark API, this corresponds to a new spark SQLsession
5151
let ctx = SessionContext::new();
@@ -140,5 +140,11 @@ async fn main() -> Result<()> {
140140
// print the results
141141
df.show().await?;
142142

143+
// Given that `pow` is registered in the context, we can also use it in SQL:
144+
let sql_df = ctx.sql("SELECT pow(a, b) FROM t").await?;
145+
146+
// print the results
147+
sql_df.show().await?;
148+
143149
Ok(())
144150
}

datafusion/expr/src/expr.rs

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1724,13 +1724,13 @@ mod test {
17241724
use crate::expr::Cast;
17251725
use crate::expr_fn::col;
17261726
use crate::{
1727-
case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ReturnTypeFunction,
1728-
ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature,
1729-
Volatility,
1727+
case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ScalarFunctionDefinition,
1728+
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
17301729
};
17311730
use arrow::datatypes::DataType;
17321731
use datafusion_common::Column;
17331732
use datafusion_common::{Result, ScalarValue};
1733+
use std::any::Any;
17341734
use std::sync::Arc;
17351735

17361736
#[test]
@@ -1848,24 +1848,41 @@ mod test {
18481848
);
18491849

18501850
// UDF
1851-
let return_type: ReturnTypeFunction =
1852-
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
1853-
let fun: ScalarFunctionImplementation =
1854-
Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));
1855-
let udf = Arc::new(ScalarUDF::new(
1856-
"TestScalarUDF",
1857-
&Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1858-
&return_type,
1859-
&fun,
1860-
));
1851+
struct TestScalarUDF {
1852+
signature: Signature,
1853+
}
1854+
impl ScalarUDFImpl for TestScalarUDF {
1855+
fn as_any(&self) -> &dyn Any {
1856+
self
1857+
}
1858+
fn name(&self) -> &str {
1859+
"TestScalarUDF"
1860+
}
1861+
1862+
fn signature(&self) -> &Signature {
1863+
&self.signature
1864+
}
1865+
1866+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1867+
Ok(DataType::Utf8)
1868+
}
1869+
1870+
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
1871+
Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
1872+
}
1873+
}
1874+
let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
1875+
signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1876+
}));
18611877
assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
18621878

1863-
let udf = Arc::new(ScalarUDF::new(
1864-
"TestScalarUDF",
1865-
&Signature::uniform(1, vec![DataType::Float32], Volatility::Volatile),
1866-
&return_type,
1867-
&fun,
1868-
));
1879+
let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
1880+
signature: Signature::uniform(
1881+
1,
1882+
vec![DataType::Float32],
1883+
Volatility::Volatile,
1884+
),
1885+
}));
18691886
assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
18701887

18711888
// Unresolved function

0 commit comments

Comments
 (0)