Skip to content

Commit 3eb198b

Browse files
Add PyExpr to_variant conversions (#793)
* make PyExpr::to_variant arms explicit * update PyInList to wrap expr::InList * update PyExists to wrap expr::Exists * update PyInSubquery to wrap expr::InSubquery * update Placeholder to wrap expr::Placeholder * make PyLogicalPlan::to_variant match arms explicit * add PySortExpr wrapper * add PyUnnestExpr wrapper * update PyAlias to wrap upstream Alias * return not implemented error for unimplemented variants in PyExpr::to_variant * added to_variant python test from the GH issue * remove unused import * return unsupported_variants for unimplemented variants in PyLogicalPlan::to_variant
1 parent 9a6805e commit 3eb198b

File tree

10 files changed

+273
-71
lines changed

10 files changed

+273
-71
lines changed

python/datafusion/tests/test_expr.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,31 @@ def test_relational_expr(test_ctx):
139139
assert df.filter(col("b") != "beta").count() == 2
140140

141141
assert df.filter(col("a") == "beta").count() == 0
142+
143+
144+
def test_expr_to_variant():
145+
# Taken from https://github.com/apache/datafusion-python/issues/781
146+
from datafusion import SessionContext
147+
from datafusion.expr import Filter
148+
149+
150+
def traverse_logical_plan(plan):
151+
cur_node = plan.to_variant()
152+
if isinstance(cur_node, Filter):
153+
return cur_node.predicate().to_variant()
154+
if hasattr(plan, 'inputs'):
155+
for input_plan in plan.inputs():
156+
res = traverse_logical_plan(input_plan)
157+
if res is not None:
158+
return res
159+
160+
ctx = SessionContext()
161+
data = {'id': [1, 2, 3], 'name': ['Alice', 'Bob', 'Charlie']}
162+
ctx.from_pydict(data, name='table1')
163+
query = "SELECT * FROM table1 t1 WHERE t1.name IN ('dfa', 'ad', 'dfre', 'vsa')"
164+
logical_plan = ctx.sql(query).optimized_logical_plan()
165+
variant = traverse_logical_plan(logical_plan)
166+
assert variant is not None
167+
assert variant.expr().to_variant().qualified_name() == 'table1.name'
168+
assert str(variant.list()) == '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]'
169+
assert not variant.negated()

src/expr.rs

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use datafusion_expr::{
3333
};
3434

3535
use crate::common::data_type::{DataTypeMap, RexType};
36-
use crate::errors::{py_runtime_err, py_type_err, DataFusionError};
36+
use crate::errors::{py_runtime_err, py_type_err, py_unsupported_variant_err, DataFusionError};
3737
use crate::expr::aggregate_expr::PyAggregateFunction;
3838
use crate::expr::binary_expr::PyBinaryExpr;
3939
use crate::expr::column::PyColumn;
@@ -84,11 +84,13 @@ pub mod scalar_subquery;
8484
pub mod scalar_variable;
8585
pub mod signature;
8686
pub mod sort;
87+
pub mod sort_expr;
8788
pub mod subquery;
8889
pub mod subquery_alias;
8990
pub mod table_scan;
9091
pub mod union;
9192
pub mod unnest;
93+
pub mod unnest_expr;
9294
pub mod window;
9395

9496
/// A PyExpr that can be used on a DataFrame
@@ -119,8 +121,9 @@ pub fn py_expr_list(expr: &[Expr]) -> PyResult<Vec<PyExpr>> {
119121
impl PyExpr {
120122
/// Return the specific expression
121123
fn to_variant(&self, py: Python) -> PyResult<PyObject> {
122-
Python::with_gil(|_| match &self.expr {
123-
Expr::Alias(alias) => Ok(PyAlias::new(&alias.expr, &alias.name).into_py(py)),
124+
Python::with_gil(|_| {
125+
match &self.expr {
126+
Expr::Alias(alias) => Ok(PyAlias::from(alias.clone()).into_py(py)),
124127
Expr::Column(col) => Ok(PyColumn::from(col.clone()).into_py(py)),
125128
Expr::ScalarVariable(data_type, variables) => {
126129
Ok(PyScalarVariable::new(data_type, variables).into_py(py))
@@ -141,10 +144,44 @@ impl PyExpr {
141144
Expr::AggregateFunction(expr) => {
142145
Ok(PyAggregateFunction::from(expr.clone()).into_py(py))
143146
}
144-
other => Err(py_runtime_err(format!(
145-
"Cannot convert this Expr to a Python object: {:?}",
146-
other
147+
Expr::SimilarTo(value) => Ok(PySimilarTo::from(value.clone()).into_py(py)),
148+
Expr::Between(value) => Ok(between::PyBetween::from(value.clone()).into_py(py)),
149+
Expr::Case(value) => Ok(case::PyCase::from(value.clone()).into_py(py)),
150+
Expr::Cast(value) => Ok(cast::PyCast::from(value.clone()).into_py(py)),
151+
Expr::TryCast(value) => Ok(cast::PyTryCast::from(value.clone()).into_py(py)),
152+
Expr::Sort(value) => Ok(sort_expr::PySortExpr::from(value.clone()).into_py(py)),
153+
Expr::ScalarFunction(value) => Err(py_unsupported_variant_err(format!(
154+
"Converting Expr::ScalarFunction to a Python object is not implemented: {:?}",
155+
value
147156
))),
157+
Expr::WindowFunction(value) => Err(py_unsupported_variant_err(format!(
158+
"Converting Expr::WindowFunction to a Python object is not implemented: {:?}",
159+
value
160+
))),
161+
Expr::InList(value) => Ok(in_list::PyInList::from(value.clone()).into_py(py)),
162+
Expr::Exists(value) => Ok(exists::PyExists::from(value.clone()).into_py(py)),
163+
Expr::InSubquery(value) => {
164+
Ok(in_subquery::PyInSubquery::from(value.clone()).into_py(py))
165+
}
166+
Expr::ScalarSubquery(value) => {
167+
Ok(scalar_subquery::PyScalarSubquery::from(value.clone()).into_py(py))
168+
}
169+
Expr::Wildcard { qualifier } => Err(py_unsupported_variant_err(format!(
170+
"Converting Expr::Wildcard to a Python object is not implemented : {:?}",
171+
qualifier
172+
))),
173+
Expr::GroupingSet(value) => {
174+
Ok(grouping_set::PyGroupingSet::from(value.clone()).into_py(py))
175+
}
176+
Expr::Placeholder(value) => {
177+
Ok(placeholder::PyPlaceholder::from(value.clone()).into_py(py))
178+
}
179+
Expr::OuterReferenceColumn(data_type, column) => Err(py_unsupported_variant_err(format!(
180+
"Converting Expr::OuterReferenceColumn to a Python object is not implemented: {:?} - {:?}",
181+
data_type, column
182+
))),
183+
Expr::Unnest(value) => Ok(unnest_expr::PyUnnestExpr::from(value.clone()).into_py(py)),
184+
}
148185
})
149186
}
150187

@@ -599,13 +636,15 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
599636
m.add_class::<cross_join::PyCrossJoin>()?;
600637
m.add_class::<union::PyUnion>()?;
601638
m.add_class::<unnest::PyUnnest>()?;
639+
m.add_class::<unnest_expr::PyUnnestExpr>()?;
602640
m.add_class::<extension::PyExtension>()?;
603641
m.add_class::<filter::PyFilter>()?;
604642
m.add_class::<projection::PyProjection>()?;
605643
m.add_class::<table_scan::PyTableScan>()?;
606644
m.add_class::<create_memory_table::PyCreateMemoryTable>()?;
607645
m.add_class::<create_view::PyCreateView>()?;
608646
m.add_class::<distinct::PyDistinct>()?;
647+
m.add_class::<sort_expr::PySortExpr>()?;
609648
m.add_class::<subquery_alias::PySubqueryAlias>()?;
610649
m.add_class::<drop_table::PyDropTable>()?;
611650
m.add_class::<repartition::PyPartitioning>()?;

src/expr/alias.rs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,24 @@ use crate::expr::PyExpr;
1919
use pyo3::prelude::*;
2020
use std::fmt::{self, Display, Formatter};
2121

22-
use datafusion_expr::Expr;
22+
use datafusion_expr::expr::Alias;
2323

2424
#[pyclass(name = "Alias", module = "datafusion.expr", subclass)]
2525
#[derive(Clone)]
2626
pub struct PyAlias {
27-
expr: PyExpr,
28-
alias_name: String,
27+
alias: Alias,
28+
}
29+
30+
impl From<Alias> for PyAlias {
31+
fn from(alias: Alias) -> Self {
32+
Self { alias }
33+
}
34+
}
35+
36+
impl From<PyAlias> for Alias {
37+
fn from(py_alias: PyAlias) -> Self {
38+
py_alias.alias
39+
}
2940
}
3041

3142
impl Display for PyAlias {
@@ -35,29 +46,20 @@ impl Display for PyAlias {
3546
"Alias
3647
\nExpr: `{:?}`
3748
\nAlias Name: `{}`",
38-
&self.expr, &self.alias_name
49+
&self.alias.expr, &self.alias.name
3950
)
4051
}
4152
}
4253

43-
impl PyAlias {
44-
pub fn new(expr: &Expr, alias_name: &String) -> Self {
45-
Self {
46-
expr: expr.clone().into(),
47-
alias_name: alias_name.to_owned(),
48-
}
49-
}
50-
}
51-
5254
#[pymethods]
5355
impl PyAlias {
5456
/// Retrieve the "name" of the alias
5557
fn alias(&self) -> PyResult<String> {
56-
Ok(self.alias_name.clone())
58+
Ok(self.alias.name.clone())
5759
}
5860

5961
fn expr(&self) -> PyResult<PyExpr> {
60-
Ok(self.expr.clone())
62+
Ok((*self.alias.expr.clone()).into())
6163
}
6264

6365
/// Get a String representation of this alias

src/expr/exists.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,30 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use datafusion_expr::Subquery;
18+
use datafusion_expr::expr::Exists;
1919
use pyo3::prelude::*;
2020

2121
use super::subquery::PySubquery;
2222

2323
#[pyclass(name = "Exists", module = "datafusion.expr", subclass)]
2424
#[derive(Clone)]
2525
pub struct PyExists {
26-
subquery: Subquery,
27-
negated: bool,
26+
exists: Exists,
2827
}
2928

30-
impl PyExists {
31-
pub fn new(subquery: Subquery, negated: bool) -> Self {
32-
Self { subquery, negated }
29+
impl From<Exists> for PyExists {
30+
fn from(exists: Exists) -> Self {
31+
PyExists { exists }
3332
}
3433
}
3534

3635
#[pymethods]
3736
impl PyExists {
3837
fn subquery(&self) -> PySubquery {
39-
self.subquery.clone().into()
38+
self.exists.subquery.clone().into()
4039
}
4140

4241
fn negated(&self) -> bool {
43-
self.negated
42+
self.exists.negated
4443
}
4544
}

src/expr/in_list.rs

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,32 @@
1616
// under the License.
1717

1818
use crate::expr::PyExpr;
19-
use datafusion_expr::Expr;
19+
use datafusion_expr::expr::InList;
2020
use pyo3::prelude::*;
2121

2222
#[pyclass(name = "InList", module = "datafusion.expr", subclass)]
2323
#[derive(Clone)]
2424
pub struct PyInList {
25-
expr: Box<Expr>,
26-
list: Vec<Expr>,
27-
negated: bool,
25+
in_list: InList,
2826
}
2927

30-
impl PyInList {
31-
pub fn new(expr: Box<Expr>, list: Vec<Expr>, negated: bool) -> Self {
32-
Self {
33-
expr,
34-
list,
35-
negated,
36-
}
28+
impl From<InList> for PyInList {
29+
fn from(in_list: InList) -> Self {
30+
PyInList { in_list }
3731
}
3832
}
3933

4034
#[pymethods]
4135
impl PyInList {
4236
fn expr(&self) -> PyExpr {
43-
(*self.expr).clone().into()
37+
(*self.in_list.expr).clone().into()
4438
}
4539

4640
fn list(&self) -> Vec<PyExpr> {
47-
self.list.iter().map(|e| e.clone().into()).collect()
41+
self.in_list.list.iter().map(|e| e.clone().into()).collect()
4842
}
4943

5044
fn negated(&self) -> bool {
51-
self.negated
45+
self.in_list.negated
5246
}
5347
}

src/expr/in_subquery.rs

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,40 +15,34 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use datafusion_expr::{Expr, Subquery};
18+
use datafusion_expr::expr::InSubquery;
1919
use pyo3::prelude::*;
2020

2121
use super::{subquery::PySubquery, PyExpr};
2222

2323
#[pyclass(name = "InSubquery", module = "datafusion.expr", subclass)]
2424
#[derive(Clone)]
2525
pub struct PyInSubquery {
26-
expr: Box<Expr>,
27-
subquery: Subquery,
28-
negated: bool,
26+
in_subquery: InSubquery,
2927
}
3028

31-
impl PyInSubquery {
32-
pub fn new(expr: Box<Expr>, subquery: Subquery, negated: bool) -> Self {
33-
Self {
34-
expr,
35-
subquery,
36-
negated,
37-
}
29+
impl From<InSubquery> for PyInSubquery {
30+
fn from(in_subquery: InSubquery) -> Self {
31+
PyInSubquery { in_subquery }
3832
}
3933
}
4034

4135
#[pymethods]
4236
impl PyInSubquery {
4337
fn expr(&self) -> PyExpr {
44-
(*self.expr).clone().into()
38+
(*self.in_subquery.expr).clone().into()
4539
}
4640

4741
fn subquery(&self) -> PySubquery {
48-
self.subquery.clone().into()
42+
self.in_subquery.subquery.clone().into()
4943
}
5044

5145
fn negated(&self) -> bool {
52-
self.negated
46+
self.in_subquery.negated
5347
}
5448
}

src/expr/placeholder.rs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,34 +15,33 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use datafusion::arrow::datatypes::DataType;
18+
use datafusion_expr::expr::Placeholder;
1919
use pyo3::prelude::*;
2020

2121
use crate::common::data_type::PyDataType;
2222

2323
#[pyclass(name = "Placeholder", module = "datafusion.expr", subclass)]
2424
#[derive(Clone)]
2525
pub struct PyPlaceholder {
26-
id: String,
27-
data_type: Option<DataType>,
26+
placeholder: Placeholder,
2827
}
2928

30-
impl PyPlaceholder {
31-
pub fn new(id: String, data_type: DataType) -> Self {
32-
Self {
33-
id,
34-
data_type: Some(data_type),
35-
}
29+
impl From<Placeholder> for PyPlaceholder {
30+
fn from(placeholder: Placeholder) -> Self {
31+
PyPlaceholder { placeholder }
3632
}
3733
}
3834

3935
#[pymethods]
4036
impl PyPlaceholder {
4137
fn id(&self) -> String {
42-
self.id.clone()
38+
self.placeholder.id.clone()
4339
}
4440

4541
fn data_type(&self) -> Option<PyDataType> {
46-
self.data_type.as_ref().map(|e| e.clone().into())
42+
self.placeholder
43+
.data_type
44+
.as_ref()
45+
.map(|e| e.clone().into())
4746
}
4847
}

0 commit comments

Comments
 (0)