22
22
import functools
23
23
from abc import ABCMeta , abstractmethod
24
24
from enum import Enum
25
- from typing import TYPE_CHECKING , Any , Callable , Optional , TypeVar , overload
25
+ from typing import TYPE_CHECKING , Any , Callable , Optional , Protocol , TypeVar , overload
26
26
27
27
import pyarrow as pa
28
28
@@ -77,6 +77,15 @@ def __str__(self) -> str:
77
77
return self .name .lower ()
78
78
79
79
80
class ScalarUDFExportable(Protocol):
    """Type hint for object that has __datafusion_scalar_udf__ PyCapsule.

    This is a structural type: any object that defines a
    ``__datafusion_scalar_udf__`` method satisfies it, without having to
    inherit from this class. ``ScalarUDF.udf`` detects such objects with a
    ``hasattr`` check on that attribute name.
    """

    # NOTE(review): presumably returns a PyCapsule wrapping the exported
    # scalar UDF; the annotation only promises ``object`` -- confirm against
    # the FFI bindings.
    def __datafusion_scalar_udf__(self) -> object: ...  # noqa: D105
87
+
88
+
80
89
class ScalarUDF :
81
90
"""Class for performing scalar user-defined functions (UDF).
82
91
@@ -133,6 +142,10 @@ def udf(
133
142
name : Optional [str ] = None ,
134
143
) -> ScalarUDF : ...
135
144
145
@overload
@staticmethod
def udf(func: ScalarUDFExportable) -> ScalarUDF:
    """Typing overload: build a UDF from an object exposing ``__datafusion_scalar_udf__``."""
    ...
148
+
136
149
@staticmethod
137
150
def udf (* args : Any , ** kwargs : Any ): # noqa: D417
138
151
"""Create a new User-Defined Function (UDF).
@@ -147,7 +160,10 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
147
160
148
161
Args:
149
162
func (Callable, optional): **Only needed when calling as a function.**
150
- Skip this argument when using `udf` as a decorator.
163
+ Skip this argument when using `udf` as a decorator. If you have a Rust
164
+ backed ScalarUDF within a PyCapsule, you can pass this parameter
165
+ and ignore the rest. They will be determined directly from the
166
+ underlying function. See the online documentation for more information.
151
167
input_types (list[pa.DataType]): The data types of the arguments
152
168
to `func`. This list must be of the same length as the number of
153
169
arguments.
@@ -219,21 +235,30 @@ def wrapper(*args: Any, **kwargs: Any):
219
235
return decorator
220
236
221
237
if hasattr (args [0 ], "__datafusion_scalar_udf__" ):
222
- name = str (args [0 ].__class__ )
223
- return ScalarUDF (
224
- name = name ,
225
- func = args [0 ],
226
- input_types = None ,
227
- return_type = None ,
228
- volatility = None ,
229
- )
238
+ return ScalarUDF .from_pycapsule (args [0 ])
230
239
231
240
if args and callable (args [0 ]):
232
241
# Case 1: Used as a function, require the first parameter to be callable
233
242
return _function (* args , ** kwargs )
234
243
# Case 2: Used as a decorator with parameters
235
244
return _decorator (* args , ** kwargs )
236
245
246
@staticmethod
def from_pycapsule(func: ScalarUDFExportable) -> ScalarUDF:
    """Create a Scalar UDF from ScalarUDF PyCapsule object.

    This function will instantiate a Scalar UDF that uses a DataFusion
    ScalarUDF that is exported via the FFI bindings.

    Args:
        func: Object exposing ``__datafusion_scalar_udf__``.

    Returns:
        A :py:class:`ScalarUDF` wrapping ``func``. Its name is the string
        form of ``func``'s class.
    """
    # Derive the name from the object that was passed in. The earlier code
    # read ``udf.__class__``, which is not this function's argument.
    name = str(func.__class__)
    # Input types, return type and volatility are left as ``None``; per the
    # ``udf`` docstring they are determined from the underlying function.
    return ScalarUDF(
        name=name,
        func=func,
        input_types=None,
        return_type=None,
        volatility=None,
    )
261
+
237
262
238
263
class Accumulator (metaclass = ABCMeta ):
239
264
"""Defines how an :py:class:`AggregateUDF` accumulates values."""
0 commit comments