Skip to content

Commit b9bf00e

Browse files
committed
Address PR comments (factory interface)
1 parent a650e16 commit b9bf00e

File tree

2 files changed

+72
-40
lines changed

2 files changed

+72
-40
lines changed

datafusion/core/src/execution/context/mod.rs

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -798,28 +798,48 @@ impl SessionContext {
798798
}
799799

800800
async fn create_function(&self, stmt: CreateFunction) -> Result<DataFrame> {
801-
let function_factory = self.state.read().function_factory.clone();
801+
let function = {
802+
let state = self.state.read().clone();
803+
let function_factory = &state.function_factory;
804+
805+
match function_factory {
806+
Some(f) => f.create(state.config(), stmt).await?,
807+
_ => Err(DataFusionError::Configuration(
808+
"Function factory has not been configured".into(),
809+
))?,
810+
}
811+
};
802812

803-
match function_factory {
804-
Some(f) => f.create(self.state.clone(), stmt).await?,
805-
None => Err(DataFusionError::Configuration(
806-
"Function factory has not been configured".into(),
807-
))?,
813+
match function {
814+
RegisterFunction::Scalar(f) => {
815+
self.state.write().register_udf(f)?;
816+
}
817+
RegisterFunction::Aggregate(f) => {
818+
self.state.write().register_udaf(f)?;
819+
}
820+
RegisterFunction::Window(f) => {
821+
self.state.write().register_udwf(f)?;
822+
}
823+
RegisterFunction::Table(name, f) => self.register_udtf(&name, f),
808824
};
809825

810826
self.return_empty_dataframe()
811827
}
812828

813829
async fn drop_function(&self, stmt: DropFunction) -> Result<DataFrame> {
814-
let function_factory = self.state.read().function_factory.clone();
815-
816-
match function_factory {
817-
Some(f) => f.remove(self.state.clone(), stmt).await?,
818-
None => Err(DataFusionError::Configuration(
819-
"Function factory has not been configured".into(),
820-
))?,
830+
let _function = {
831+
let state = self.state.read().clone();
832+
let function_factory = &state.function_factory;
833+
834+
match function_factory {
835+
Some(f) => f.remove(state.config(), stmt).await?,
836+
None => Err(DataFusionError::Configuration(
837+
"Function factory has not been configured".into(),
838+
))?,
839+
}
821840
};
822841

842+
// TODO: Once we have unregister UDF we need to implement it here
823843
self.return_empty_dataframe()
824844
}
825845

@@ -1289,27 +1309,36 @@ impl QueryPlanner for DefaultQueryPlanner {
12891309
/// ```
12901310
#[async_trait]
12911311
pub trait FunctionFactory: Sync + Send {
1292-
// TODO: I don't like having RwLock Leaking here, who ever implements it
1293-
// has to depend ot `parking_lot`. I'f we expose &mut SessionState it
1294-
// may keep lock of too long.
12951312
//
1296-
// Not sure if there is better approach.
1313+
// This api holds a read lock for state
12971314
//
12981315

12991316
/// Handles creation of user defined function specified in [CreateFunction] statement
13001317
async fn create(
13011318
&self,
1302-
state: Arc<RwLock<SessionState>>,
1319+
state: &SessionConfig,
13031320
statement: CreateFunction,
1304-
) -> Result<()>;
1321+
) -> Result<RegisterFunction>;
13051322

13061323
/// Drops user defined function from [SessionState]
13071324
// Naming it `drop`` would make more sense but its already occupied in rust
13081325
async fn remove(
13091326
&self,
1310-
state: Arc<RwLock<SessionState>>,
1327+
state: &SessionConfig,
13111328
statement: DropFunction,
1312-
) -> Result<()>;
1329+
) -> Result<RegisterFunction>;
1330+
}
1331+
1332+
/// Type of function to create
1333+
pub enum RegisterFunction {
1334+
/// Scalar user defined function
1335+
Scalar(Arc<ScalarUDF>),
1336+
/// Aggregate user defined function
1337+
Aggregate(Arc<AggregateUDF>),
1338+
/// Window user defined function
1339+
Window(Arc<WindowUDF>),
1340+
/// Table user defined function
1341+
Table(String, Arc<dyn TableFunctionImpl>),
13131342
}
13141343
/// Execution context for registering data sources and executing queries.
13151344
/// See [`SessionContext`] for a higher level API.

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use arrow_array::{
2121
};
2222
use arrow_schema::DataType::Float64;
2323
use arrow_schema::{DataType, Field, Schema};
24-
use datafusion::execution::context::{FunctionFactory, SessionState};
24+
use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
2525
use datafusion::prelude::*;
2626
use datafusion::{execution::registry::FunctionRegistry, test_util};
2727
use datafusion_common::cast::as_float64_array;
@@ -34,7 +34,7 @@ use datafusion_expr::{
3434
create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, DropFunction,
3535
ExprSchemable, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
3636
};
37-
use parking_lot::{Mutex, RwLock};
37+
use parking_lot::Mutex;
3838
use rand::{thread_rng, Rng};
3939
use std::any::Any;
4040
use std::iter;
@@ -636,9 +636,9 @@ impl FunctionFactory for MockFunctionFactory {
636636
#[allow(clippy::type_complexity, clippy::type_repetition_in_bounds)]
637637
async fn create(
638638
&self,
639-
state: Arc<RwLock<SessionState>>,
639+
_config: &SessionConfig,
640640
statement: CreateFunction,
641-
) -> datafusion::error::Result<()> {
641+
) -> datafusion::error::Result<RegisterFunction> {
642642
// this function is a mock for testing
643643
// `CreateFunction` should be used to derive this function
644644

@@ -675,22 +675,26 @@ impl FunctionFactory for MockFunctionFactory {
675675
// it has been parsed
676676
*self.captured_expr.lock() = statement.params.return_;
677677

678-
// we may need other infrastructure provided by state, for example:
679-
// state.config().get_extension()
680-
681-
// register mock udf for testing
682-
state.write().register_udf(mock_udf.into())?;
683-
Ok(())
678+
Ok(RegisterFunction::Scalar(Arc::new(mock_udf)))
684679
}
685680

686681
async fn remove(
687682
&self,
688-
_state: Arc<RwLock<SessionState>>,
683+
_config: &SessionConfig,
689684
_statement: DropFunction,
690-
) -> datafusion::error::Result<()> {
691-
// at the moment state does not support unregister
692-
// ignoring for now
693-
Ok(())
685+
) -> datafusion::error::Result<RegisterFunction> {
686+
687+
// TODO: I don't like that remove returns RegisterFunction
688+
// we have to keep two states in FunctionFactory iml and
689+
// SessionState
690+
//
691+
// It would be better to return (function_name, function type) tuple
692+
693+
// at the moment state does not support unregister user defined functions
694+
695+
Err(DataFusionError::NotImplemented(
696+
"remove function has not been implemented".into(),
697+
))
694698
}
695699
}
696700

@@ -722,15 +726,14 @@ async fn create_scalar_function_from_sql_statement() {
722726
.await
723727
.unwrap();
724728

725-
// sql expression should be convert to datafusion expression
726-
// in this case
729+
// check if we sql expr has been converted to datafusion expr
727730
let captured_expression = function_factory.captured_expr.lock().clone().unwrap();
728731

729732
// is there some better way to test this
730733
assert_eq!("$1 + $2", captured_expression.to_string());
731-
println!("{:?}", captured_expression);
732734

733-
ctx.sql("drop function better_add").await.unwrap();
735+
// no support at the moment
736+
// ctx.sql("drop function better_add").await.unwrap();
734737
}
735738

736739
fn create_udf_context() -> SessionContext {

0 commit comments

Comments
 (0)