17
17
18
18
use std:: sync:: Arc ;
19
19
20
+ use datafusion_ffi:: udf:: { FFI_ScalarUDF , ForeignScalarUDF } ;
21
+ use pyo3:: types:: PyCapsule ;
20
22
use pyo3:: { prelude:: * , types:: PyTuple } ;
21
23
22
24
use datafusion:: arrow:: array:: { make_array, Array , ArrayData , ArrayRef } ;
@@ -29,8 +31,9 @@ use datafusion::logical_expr::ScalarUDF;
29
31
use datafusion:: logical_expr:: { create_udf, ColumnarValue } ;
30
32
31
33
use crate :: errors:: to_datafusion_err;
34
+ use crate :: errors:: { py_datafusion_err, PyDataFusionResult } ;
32
35
use crate :: expr:: PyExpr ;
33
- use crate :: utils:: parse_volatility;
36
+ use crate :: utils:: { parse_volatility, validate_pycapsule } ;
34
37
35
38
/// Create a Rust callable function from a python function that expects pyarrow arrays
36
39
fn pyarrow_function_to_rust (
@@ -105,6 +108,28 @@ impl PyScalarUDF {
105
108
Ok ( Self { function } )
106
109
}
107
110
111
+ #[ staticmethod]
112
+ pub fn from_pycapsule (
113
+ func : Bound < ' _ , PyAny > ,
114
+ ) -> PyDataFusionResult < Self > {
115
+ if func. hasattr ( "__datafusion_scalar_udf__" ) ? {
116
+ let capsule = func. getattr ( "__datafusion_scalar_udf__" ) ?. call0 ( ) ?;
117
+ let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
118
+ validate_pycapsule ( capsule, "datafusion_scalar_udf" ) ?;
119
+
120
+ let udf = unsafe { capsule. reference :: < FFI_ScalarUDF > ( ) } ;
121
+ let udf: ForeignScalarUDF = udf. try_into ( ) ?;
122
+
123
+ Ok ( Self { function : udf. into ( ) } )
124
+ } else {
125
+ Err ( crate :: errors:: PyDataFusionError :: Common (
126
+ "__datafusion_scalar_udf__ does not exist on ScalarUDF object."
127
+ . to_string ( ) ,
128
+ ) )
129
+ }
130
+ }
131
+
132
+
108
133
/// creates a new PyExpr with the call of the udf
109
134
#[ pyo3( signature = ( * args) ) ]
110
135
fn __call__ ( & self , args : Vec < PyExpr > ) -> PyResult < PyExpr > {
0 commit comments