Skip to content

Commit ae3bcda

Browse files
committed
Set up for FFI Scalar UDF handling
1 parent 86bda2a commit ae3bcda

File tree

1 file changed

+35
-10
lines changed

1 file changed

+35
-10
lines changed

python/datafusion/udf.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import functools
2323
from abc import ABCMeta, abstractmethod
2424
from enum import Enum
25-
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload
25+
from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload
2626

2727
import pyarrow as pa
2828

@@ -77,6 +77,15 @@ def __str__(self) -> str:
7777
return self.name.lower()
7878

7979

80+
class ScalarUDFExportable(Protocol):
    """Type hint for object that has __datafusion_scalar_udf__ PyCapsule.

    Any object that provides a ``__datafusion_scalar_udf__`` method matches
    this protocol structurally; it does not need to inherit from this class.

    NOTE(review): the link below documents the table provider export rather
    than the scalar UDF export. It is kept as a reference for what is
    presumably the same PyCapsule-based pattern -- confirm, and replace it if
    a scalar UDF specific page exists.

    https://datafusion.apache.org/python/user-guide/io/table_provider.html
    """

    def __datafusion_scalar_udf__(self) -> object: ...  # noqa: D105
87+
88+
8089
class ScalarUDF:
8190
"""Class for performing scalar user-defined functions (UDF).
8291
@@ -133,6 +142,10 @@ def udf(
133142
name: Optional[str] = None,
134143
) -> ScalarUDF: ...
135144

145+
@overload
146+
@staticmethod
147+
def udf(func: ScalarUDFExportable) -> ScalarUDF: ...
148+
136149
@staticmethod
137150
def udf(*args: Any, **kwargs: Any): # noqa: D417
138151
"""Create a new User-Defined Function (UDF).
@@ -147,7 +160,10 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
147160
148161
Args:
149162
func (Callable, optional): **Only needed when calling as a function.**
150-
Skip this argument when using `udf` as a decorator.
163+
Skip this argument when using `udf` as a decorator. If you have a Rust
164+
backed ScalarUDF within a PyCapsule, you can pass this parameter
165+
and ignore the rest. They will be determined directly from the
166+
underlying function. See the online documentation for more information.
151167
input_types (list[pa.DataType]): The data types of the arguments
152168
to `func`. This list must be of the same length as the number of
153169
arguments.
@@ -219,21 +235,30 @@ def wrapper(*args: Any, **kwargs: Any):
219235
return decorator
220236

221237
if hasattr(args[0], "__datafusion_scalar_udf__"):
222-
name = str(args[0].__class__)
223-
return ScalarUDF(
224-
name=name,
225-
func=args[0],
226-
input_types=None,
227-
return_type=None,
228-
volatility=None,
229-
)
238+
return ScalarUDF.from_pycapsule(args[0])
230239

231240
if args and callable(args[0]):
232241
# Case 1: Used as a function, require the first parameter to be callable
233242
return _function(*args, **kwargs)
234243
# Case 2: Used as a decorator with parameters
235244
return _decorator(*args, **kwargs)
236245

246+
@staticmethod
def from_pycapsule(func: ScalarUDFExportable) -> ScalarUDF:
    """Create a Scalar UDF from ScalarUDF PyCapsule object.

    This function will instantiate a Scalar UDF that uses a DataFusion
    ScalarUDF that is exported via the FFI bindings.

    Args:
        func: Object exposing ``__datafusion_scalar_udf__``.

    Returns:
        A :py:class:`ScalarUDF` wrapping ``func``. ``input_types``,
        ``return_type`` and ``volatility`` are passed as ``None``; per the
        ``udf`` docstring they are determined from the underlying function.
    """
    # Name the UDF after the class of the exported object. This must read
    # from ``func`` (the parameter); the previous ``udf.__class__`` referred
    # to a name that is not defined in this scope.
    name = str(func.__class__)
    return ScalarUDF(
        name=name,
        func=func,
        input_types=None,
        return_type=None,
        volatility=None,
    )
261+
237262

238263
class Accumulator(metaclass=ABCMeta):
239264
"""Defines how an :py:class:`AggregateUDF` accumulates values."""

0 commit comments

Comments
 (0)