Skip to content

Commit f6adc6a

Browse files
committed
Add example for [FunctionFactory]
1 parent ea01e56 commit f6adc6a

File tree

1 file changed

+227
-0
lines changed

1 file changed

+227
-0
lines changed
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{ArrayRef, Int64Array, RecordBatch};
19+
use datafusion::error::Result;
20+
use datafusion::execution::config::SessionConfig;
21+
use datafusion::execution::context::{
22+
FunctionFactory, RegisterFunction, SessionContext, SessionState,
23+
};
24+
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
25+
use datafusion_common::tree_node::{Transformed, TreeNode};
26+
use datafusion_common::{exec_err, internal_err, DataFusionError};
27+
use datafusion_expr::simplify::ExprSimplifyResult;
28+
use datafusion_expr::simplify::SimplifyInfo;
29+
use datafusion_expr::{CreateFunction, Expr, ScalarUDF, ScalarUDFImpl, Signature};
30+
use std::result::Result as RResult;
31+
use std::sync::Arc;
32+
33+
/// This example shows how to utilize [FunctionFactory] to register
34+
/// `CREATE FUNCTION` handler. Apart from [FunctionFactory] this
35+
/// example covers [ScalarUDFImpl::simplify()] usage and synergy
36+
/// between those two functionality.
37+
///
38+
/// This example is rather simple, there are many edge cases to be covered
39+
///
40+
41+
#[tokio::main]
42+
async fn main() -> Result<()> {
43+
let runtime_config = RuntimeConfig::new();
44+
let runtime_environment = RuntimeEnv::new(runtime_config)?;
45+
46+
let session_config = SessionConfig::new();
47+
let state =
48+
SessionState::new_with_config_rt(session_config, Arc::new(runtime_environment))
49+
// register custom function factory
50+
.with_function_factory(Arc::new(CustomFunctionFactory::default()));
51+
52+
let ctx = SessionContext::new_with_state(state);
53+
54+
let sql = r#"
55+
CREATE FUNCTION better_add(BIGINT, BIGINT)
56+
RETURNS BIGINT
57+
RETURN $1 + $2
58+
"#;
59+
60+
ctx.sql(sql).await?.show().await?;
61+
62+
let sql = r#"
63+
SELECT better_add(1, 2)
64+
"#;
65+
66+
ctx.sql(sql).await?.show().await?;
67+
68+
let a: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3, 4]));
69+
let b: ArrayRef = Arc::new(Int64Array::from(vec![10, 20, 30, 40]));
70+
let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?;
71+
72+
ctx.register_batch("t", batch)?;
73+
74+
ctx.sql(sql).await?.show().await?;
75+
76+
let sql = r#"
77+
SELECT better_add(a, b) from t
78+
"#;
79+
80+
ctx.sql(sql).await?.show().await?;
81+
82+
Ok(())
83+
}
84+
85+
#[derive(Debug, Default)]
86+
struct CustomFunctionFactory {}
87+
88+
#[async_trait::async_trait]
89+
impl FunctionFactory for CustomFunctionFactory {
90+
async fn create(
91+
&self,
92+
_state: &SessionConfig,
93+
statement: CreateFunction,
94+
) -> Result<RegisterFunction> {
95+
let f: ScalarFunctionWrapper = statement.try_into()?;
96+
97+
Ok(RegisterFunction::Scalar(Arc::new(ScalarUDF::from(f))))
98+
}
99+
}
100+
// a wrapper type to be used to register
101+
// custom function to datafusion context
102+
//
103+
// it also defines custom [ScalarUDFImpl::simplify()]
104+
// to replace ScalarUDF expression with one instance contains.
105+
#[derive(Debug)]
106+
struct ScalarFunctionWrapper {
107+
name: String,
108+
expr: Expr,
109+
signature: Signature,
110+
return_type: arrow_schema::DataType,
111+
}
112+
113+
impl ScalarUDFImpl for ScalarFunctionWrapper {
114+
fn as_any(&self) -> &dyn std::any::Any {
115+
self
116+
}
117+
118+
fn name(&self) -> &str {
119+
&self.name
120+
}
121+
122+
fn signature(&self) -> &datafusion_expr::Signature {
123+
&self.signature
124+
}
125+
126+
fn return_type(
127+
&self,
128+
_arg_types: &[arrow_schema::DataType],
129+
) -> Result<arrow_schema::DataType> {
130+
Ok(self.return_type.clone())
131+
}
132+
133+
fn invoke(
134+
&self,
135+
_args: &[datafusion_expr::ColumnarValue],
136+
) -> Result<datafusion_expr::ColumnarValue> {
137+
internal_err!("This function should not get invoked!")
138+
}
139+
140+
fn simplify(
141+
&self,
142+
args: Vec<Expr>,
143+
_info: &dyn SimplifyInfo,
144+
) -> Result<ExprSimplifyResult> {
145+
let replacement = Self::replacement(&self.expr, &args)?;
146+
147+
Ok(ExprSimplifyResult::Simplified(replacement))
148+
}
149+
150+
fn aliases(&self) -> &[String] {
151+
&[]
152+
}
153+
154+
fn monotonicity(&self) -> Result<Option<datafusion_expr::FuncMonotonicity>> {
155+
Ok(None)
156+
}
157+
}
158+
159+
impl ScalarFunctionWrapper {
160+
// replaces placeholders with actual arguments
161+
fn replacement(expr: &Expr, args: &[Expr]) -> Result<Expr> {
162+
let result = expr.clone().transform(&|e| {
163+
let r = match e {
164+
Expr::Placeholder(placeholder) => {
165+
let placeholder_position =
166+
Self::parse_placeholder_identifier(&placeholder.id)?;
167+
if placeholder_position < args.len() {
168+
Transformed::yes(args[placeholder_position].clone())
169+
} else {
170+
exec_err!(
171+
"Function argument {} not provided, argument missing!",
172+
placeholder.id
173+
)?
174+
}
175+
}
176+
_ => Transformed::no(e),
177+
};
178+
179+
Ok(r)
180+
})?;
181+
182+
Ok(result.data)
183+
}
184+
// Finds placeholder identifier.
185+
// placeholders are in `$X` format where X >= 1
186+
fn parse_placeholder_identifier(placeholder: &str) -> Result<usize> {
187+
if let Some(value) = placeholder.strip_prefix('$') {
188+
Ok(value.parse().map(|v: usize| v - 1).map_err(|e| {
189+
DataFusionError::Execution(format!(
190+
"Placeholder `{}` parsing error: {}!",
191+
placeholder, e
192+
))
193+
})?)
194+
} else {
195+
exec_err!("Placeholder should start with `$`!")
196+
}
197+
}
198+
}
199+
200+
impl TryFrom<CreateFunction> for ScalarFunctionWrapper {
201+
type Error = DataFusionError;
202+
203+
fn try_from(definition: CreateFunction) -> RResult<Self, Self::Error> {
204+
Ok(Self {
205+
name: definition.name,
206+
expr: definition
207+
.params
208+
.return_
209+
.expect("Expression has to be defined!"),
210+
return_type: definition
211+
.return_type
212+
.expect("Return type has to be defined!"),
213+
signature: Signature::exact(
214+
definition
215+
.args
216+
.unwrap_or_default()
217+
.into_iter()
218+
.map(|a| a.data_type)
219+
.collect(),
220+
definition
221+
.params
222+
.behavior
223+
.unwrap_or(datafusion_expr::Volatility::Volatile),
224+
),
225+
})
226+
}
227+
}

0 commit comments

Comments
 (0)