Skip to content

Commit 25206aa

Browse files
committed
Address PR comments (factory interface)
1 parent a650e16 commit 25206aa

File tree

2 files changed

+73
-46
lines changed

2 files changed

+73
-46
lines changed

datafusion/core/src/execution/context/mod.rs

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -798,28 +798,48 @@ impl SessionContext {
798798
}
799799

800800
async fn create_function(&self, stmt: CreateFunction) -> Result<DataFrame> {
801-
let function_factory = self.state.read().function_factory.clone();
801+
let function = {
802+
let state = self.state.read().clone();
803+
let function_factory = &state.function_factory;
804+
805+
match function_factory {
806+
Some(f) => f.create(state.config(), stmt).await?,
807+
_ => Err(DataFusionError::Configuration(
808+
"Function factory has not been configured".into(),
809+
))?,
810+
}
811+
};
802812

803-
match function_factory {
804-
Some(f) => f.create(self.state.clone(), stmt).await?,
805-
None => Err(DataFusionError::Configuration(
806-
"Function factory has not been configured".into(),
807-
))?,
813+
match function {
814+
RegisterFunction::Scalar(f) => {
815+
self.state.write().register_udf(f)?;
816+
}
817+
RegisterFunction::Aggregate(f) => {
818+
self.state.write().register_udaf(f)?;
819+
}
820+
RegisterFunction::Window(f) => {
821+
self.state.write().register_udwf(f)?;
822+
}
823+
RegisterFunction::Table(name, f) => self.register_udtf(&name, f),
808824
};
809825

810826
self.return_empty_dataframe()
811827
}
812828

813829
async fn drop_function(&self, stmt: DropFunction) -> Result<DataFrame> {
814-
let function_factory = self.state.read().function_factory.clone();
815-
816-
match function_factory {
817-
Some(f) => f.remove(self.state.clone(), stmt).await?,
818-
None => Err(DataFusionError::Configuration(
819-
"Function factory has not been configured".into(),
820-
))?,
830+
let _function = {
831+
let state = self.state.read().clone();
832+
let function_factory = &state.function_factory;
833+
834+
match function_factory {
835+
Some(f) => f.remove(state.config(), stmt).await?,
836+
None => Err(DataFusionError::Configuration(
837+
"Function factory has not been configured".into(),
838+
))?,
839+
}
821840
};
822841

842+
// TODO: Once we have unregister UDF we need to implement it here
823843
self.return_empty_dataframe()
824844
}
825845

@@ -1289,27 +1309,32 @@ impl QueryPlanner for DefaultQueryPlanner {
12891309
/// ```
12901310
#[async_trait]
12911311
pub trait FunctionFactory: Sync + Send {
1292-
// TODO: I don't like having RwLock Leaking here, who ever implements it
1293-
// has to depend ot `parking_lot`. I'f we expose &mut SessionState it
1294-
// may keep lock of too long.
1295-
//
1296-
// Not sure if there is better approach.
1297-
//
1298-
12991312
/// Handles creation of user defined function specified in [CreateFunction] statement
13001313
async fn create(
13011314
&self,
1302-
state: Arc<RwLock<SessionState>>,
1315+
state: &SessionConfig,
13031316
statement: CreateFunction,
1304-
) -> Result<()>;
1317+
) -> Result<RegisterFunction>;
13051318

13061319
/// Drops user defined function from [SessionState]
1307-
// Naming it `drop`` would make more sense but its already occupied in rust
1320+
// Naming it `drop` would make more sense but its already occupied in rust
13081321
async fn remove(
13091322
&self,
1310-
state: Arc<RwLock<SessionState>>,
1323+
state: &SessionConfig,
13111324
statement: DropFunction,
1312-
) -> Result<()>;
1325+
) -> Result<RegisterFunction>;
1326+
}
1327+
1328+
/// Type of function to create
1329+
pub enum RegisterFunction {
1330+
/// Scalar user defined function
1331+
Scalar(Arc<ScalarUDF>),
1332+
/// Aggregate user defined function
1333+
Aggregate(Arc<AggregateUDF>),
1334+
/// Window user defined function
1335+
Window(Arc<WindowUDF>),
1336+
/// Table user defined function
1337+
Table(String, Arc<dyn TableFunctionImpl>),
13131338
}
13141339
/// Execution context for registering data sources and executing queries.
13151340
/// See [`SessionContext`] for a higher level API.
@@ -1628,9 +1653,9 @@ impl SessionState {
16281653
/// [`FunctionFactory`] trait.
16291654
pub fn with_function_factory(
16301655
mut self,
1631-
create_function_hook: Arc<dyn FunctionFactory>,
1656+
function_factory: Arc<dyn FunctionFactory>,
16321657
) -> Self {
1633-
self.function_factory = Some(create_function_hook);
1658+
self.function_factory = Some(function_factory);
16341659
self
16351660
}
16361661

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use arrow_array::{
2121
};
2222
use arrow_schema::DataType::Float64;
2323
use arrow_schema::{DataType, Field, Schema};
24-
use datafusion::execution::context::{FunctionFactory, SessionState};
24+
use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
2525
use datafusion::prelude::*;
2626
use datafusion::{execution::registry::FunctionRegistry, test_util};
2727
use datafusion_common::cast::as_float64_array;
@@ -34,7 +34,7 @@ use datafusion_expr::{
3434
create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, DropFunction,
3535
ExprSchemable, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
3636
};
37-
use parking_lot::{Mutex, RwLock};
37+
use parking_lot::Mutex;
3838
use rand::{thread_rng, Rng};
3939
use std::any::Any;
4040
use std::iter;
@@ -636,9 +636,9 @@ impl FunctionFactory for MockFunctionFactory {
636636
#[allow(clippy::type_complexity, clippy::type_repetition_in_bounds)]
637637
async fn create(
638638
&self,
639-
state: Arc<RwLock<SessionState>>,
639+
_config: &SessionConfig,
640640
statement: CreateFunction,
641-
) -> datafusion::error::Result<()> {
641+
) -> datafusion::error::Result<RegisterFunction> {
642642
// this function is a mock for testing
643643
// `CreateFunction` should be used to derive this function
644644

@@ -675,22 +675,25 @@ impl FunctionFactory for MockFunctionFactory {
675675
// it has been parsed
676676
*self.captured_expr.lock() = statement.params.return_;
677677

678-
// we may need other infrastructure provided by state, for example:
679-
// state.config().get_extension()
680-
681-
// register mock udf for testing
682-
state.write().register_udf(mock_udf.into())?;
683-
Ok(())
678+
Ok(RegisterFunction::Scalar(Arc::new(mock_udf)))
684679
}
685680

686681
async fn remove(
687682
&self,
688-
_state: Arc<RwLock<SessionState>>,
683+
_config: &SessionConfig,
689684
_statement: DropFunction,
690-
) -> datafusion::error::Result<()> {
691-
// at the moment state does not support unregister
692-
// ignoring for now
693-
Ok(())
685+
) -> datafusion::error::Result<RegisterFunction> {
686+
// TODO: I don't like that remove returns RegisterFunction
687+
// we have to keep two states in FunctionFactory iml and
688+
// SessionState
689+
//
690+
// It would be better to return (function_name, function type) tuple
691+
692+
// at the moment state does not support unregister user defined functions
693+
694+
Err(DataFusionError::NotImplemented(
695+
"remove function has not been implemented".into(),
696+
))
694697
}
695698
}
696699

@@ -722,15 +725,14 @@ async fn create_scalar_function_from_sql_statement() {
722725
.await
723726
.unwrap();
724727

725-
// sql expression should be convert to datafusion expression
726-
// in this case
728+
// check if we sql expr has been converted to datafusion expr
727729
let captured_expression = function_factory.captured_expr.lock().clone().unwrap();
728730

729731
// is there some better way to test this
730732
assert_eq!("$1 + $2", captured_expression.to_string());
731-
println!("{:?}", captured_expression);
732733

733-
ctx.sql("drop function better_add").await.unwrap();
734+
// no support at the moment
735+
// ctx.sql("drop function better_add").await.unwrap();
734736
}
735737

736738
fn create_udf_context() -> SessionContext {

0 commit comments

Comments
 (0)