@@ -798,28 +798,48 @@ impl SessionContext {
798
798
}
799
799
800
800
async fn create_function ( & self , stmt : CreateFunction ) -> Result < DataFrame > {
801
- let function_factory = self . state . read ( ) . function_factory . clone ( ) ;
801
+ let function = {
802
+ let state = self . state . read ( ) . clone ( ) ;
803
+ let function_factory = & state. function_factory ;
804
+
805
+ match function_factory {
806
+ Some ( f) => f. create ( state. config ( ) , stmt) . await ?,
807
+ _ => Err ( DataFusionError :: Configuration (
808
+ "Function factory has not been configured" . into ( ) ,
809
+ ) ) ?,
810
+ }
811
+ } ;
802
812
803
- match function_factory {
804
- Some ( f) => f. create ( self . state . clone ( ) , stmt) . await ?,
805
- None => Err ( DataFusionError :: Configuration (
806
- "Function factory has not been configured" . into ( ) ,
807
- ) ) ?,
813
+ match function {
814
+ RegisterFunction :: Scalar ( f) => {
815
+ self . state . write ( ) . register_udf ( f) ?;
816
+ }
817
+ RegisterFunction :: Aggregate ( f) => {
818
+ self . state . write ( ) . register_udaf ( f) ?;
819
+ }
820
+ RegisterFunction :: Window ( f) => {
821
+ self . state . write ( ) . register_udwf ( f) ?;
822
+ }
823
+ RegisterFunction :: Table ( name, f) => self . register_udtf ( & name, f) ,
808
824
} ;
809
825
810
826
self . return_empty_dataframe ( )
811
827
}
812
828
813
829
async fn drop_function ( & self , stmt : DropFunction ) -> Result < DataFrame > {
814
- let function_factory = self . state . read ( ) . function_factory . clone ( ) ;
815
-
816
- match function_factory {
817
- Some ( f) => f. remove ( self . state . clone ( ) , stmt) . await ?,
818
- None => Err ( DataFusionError :: Configuration (
819
- "Function factory has not been configured" . into ( ) ,
820
- ) ) ?,
830
+ let _function = {
831
+ let state = self . state . read ( ) . clone ( ) ;
832
+ let function_factory = & state. function_factory ;
833
+
834
+ match function_factory {
835
+ Some ( f) => f. remove ( state. config ( ) , stmt) . await ?,
836
+ None => Err ( DataFusionError :: Configuration (
837
+ "Function factory has not been configured" . into ( ) ,
838
+ ) ) ?,
839
+ }
821
840
} ;
822
841
842
+ // TODO: Once we have unregister UDF we need to implement it here
823
843
self . return_empty_dataframe ( )
824
844
}
825
845
@@ -1289,27 +1309,32 @@ impl QueryPlanner for DefaultQueryPlanner {
1289
1309
/// ```
1290
1310
#[ async_trait]
1291
1311
pub trait FunctionFactory : Sync + Send {
1292
- // TODO: I don't like having RwLock Leaking here, who ever implements it
1293
- // has to depend ot `parking_lot`. I'f we expose &mut SessionState it
1294
- // may keep lock of too long.
1295
- //
1296
- // Not sure if there is better approach.
1297
- //
1298
-
1299
1312
/// Handles creation of user defined function specified in [CreateFunction] statement
1300
1313
async fn create (
1301
1314
& self ,
1302
- state : Arc < RwLock < SessionState > > ,
1315
+ state : & SessionConfig ,
1303
1316
statement : CreateFunction ,
1304
- ) -> Result < ( ) > ;
1317
+ ) -> Result < RegisterFunction > ;
1305
1318
1306
1319
/// Drops user defined function from [SessionState]
1307
- // Naming it `drop`` would make more sense but its already occupied in rust
1320
+ // Naming it `drop` would make more sense but its already occupied in rust
1308
1321
async fn remove (
1309
1322
& self ,
1310
- state : Arc < RwLock < SessionState > > ,
1323
+ state : & SessionConfig ,
1311
1324
statement : DropFunction ,
1312
- ) -> Result < ( ) > ;
1325
+ ) -> Result < RegisterFunction > ;
1326
+ }
1327
+
1328
+ /// Type of function to create
1329
+ pub enum RegisterFunction {
1330
+ /// Scalar user defined function
1331
+ Scalar ( Arc < ScalarUDF > ) ,
1332
+ /// Aggregate user defined function
1333
+ Aggregate ( Arc < AggregateUDF > ) ,
1334
+ /// Window user defined function
1335
+ Window ( Arc < WindowUDF > ) ,
1336
+ /// Table user defined function
1337
+ Table ( String , Arc < dyn TableFunctionImpl > ) ,
1313
1338
}
1314
1339
/// Execution context for registering data sources and executing queries.
1315
1340
/// See [`SessionContext`] for a higher level API.
@@ -1628,9 +1653,9 @@ impl SessionState {
1628
1653
/// [`FunctionFactory`] trait.
1629
1654
pub fn with_function_factory (
1630
1655
mut self ,
1631
- create_function_hook : Arc < dyn FunctionFactory > ,
1656
+ function_factory : Arc < dyn FunctionFactory > ,
1632
1657
) -> Self {
1633
- self . function_factory = Some ( create_function_hook ) ;
1658
+ self . function_factory = Some ( function_factory ) ;
1634
1659
self
1635
1660
}
1636
1661
0 commit comments