Skip to content

Commit 7c99379

Browse files
committed
Intermediate work adding ffi scalar udf
1 parent 10600fb commit 7c99379

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,9 @@ crate-type = ["cdylib", "rlib"]
6161
[profile.release]
6262
lto = true
6363
codegen-units = 1
64+
65+
[patch.crates-io]
66+
datafusion = { path = "../datafusion/datafusion/core" }
67+
datafusion-substrait = { path = "../datafusion/datafusion/substrait" }
68+
datafusion-proto = { path = "../datafusion/datafusion/proto" }
69+
datafusion-ffi = { path = "../datafusion/datafusion/ffi" }

src/udf.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
use std::sync::Arc;
1919

20+
use datafusion_ffi::udf::{FFI_ScalarUDF, ForeignScalarUDF};
21+
use pyo3::types::PyCapsule;
2022
use pyo3::{prelude::*, types::PyTuple};
2123

2224
use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
@@ -29,8 +31,9 @@ use datafusion::logical_expr::ScalarUDF;
2931
use datafusion::logical_expr::{create_udf, ColumnarValue};
3032

3133
use crate::errors::to_datafusion_err;
34+
use crate::errors::{py_datafusion_err, PyDataFusionResult};
3235
use crate::expr::PyExpr;
33-
use crate::utils::parse_volatility;
36+
use crate::utils::{parse_volatility, validate_pycapsule};
3437

3538
/// Create a Rust callable function from a python function that expects pyarrow arrays
3639
fn pyarrow_function_to_rust(
@@ -105,6 +108,28 @@ impl PyScalarUDF {
105108
Ok(Self { function })
106109
}
107110

111+
#[staticmethod]
112+
pub fn from_pycapsule(
113+
func: Bound<'_, PyAny>,
114+
) -> PyDataFusionResult<Self> {
115+
if func.hasattr("__datafusion_scalar_udf__")? {
116+
let capsule = func.getattr("__datafusion_scalar_udf__")?.call0()?;
117+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
118+
validate_pycapsule(capsule, "datafusion_scalar_udf")?;
119+
120+
let udf = unsafe { capsule.reference::<FFI_ScalarUDF>() };
121+
let udf: ForeignScalarUDF = udf.try_into()?;
122+
123+
Ok(Self { function: udf.into() })
124+
} else {
125+
Err(crate::errors::PyDataFusionError::Common(
126+
"__datafusion_scalar_udf__ does not exist on ScalarUDF object."
127+
.to_string(),
128+
))
129+
}
130+
}
131+
132+
108133
/// creates a new PyExpr with the call of the udf
109134
#[pyo3(signature = (*args))]
110135
fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {

0 commit comments

Comments
 (0)