Skip to content

Commit d974ee1

Browse files
committed
Factor out Substrait consumers into separate files
1 parent 280997d commit d974ee1

36 files changed

+4321
-3452
lines changed

datafusion/substrait/src/logical_plan/consumer.rs

-3,452
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use super::{from_substrait_func_args, substrait_fun_name, SubstraitConsumer};
19+
use datafusion::common::{not_impl_datafusion_err, plan_err, DFSchema, ScalarValue};
20+
use datafusion::execution::FunctionRegistry;
21+
use datafusion::logical_expr::{expr, Expr, SortExpr};
22+
use std::sync::Arc;
23+
use substrait::proto::AggregateFunction;
24+
25+
/// Convert Substrait AggregateFunction to DataFusion Expr
26+
pub async fn from_substrait_agg_func(
27+
consumer: &impl SubstraitConsumer,
28+
f: &AggregateFunction,
29+
input_schema: &DFSchema,
30+
filter: Option<Box<Expr>>,
31+
order_by: Option<Vec<SortExpr>>,
32+
distinct: bool,
33+
) -> datafusion::common::Result<Arc<Expr>> {
34+
let Some(fn_signature) = consumer
35+
.get_extensions()
36+
.functions
37+
.get(&f.function_reference)
38+
else {
39+
return plan_err!(
40+
"Aggregate function not registered: function anchor = {:?}",
41+
f.function_reference
42+
);
43+
};
44+
45+
let fn_name = substrait_fun_name(fn_signature);
46+
let udaf = consumer.get_function_registry().udaf(fn_name);
47+
let udaf = udaf.map_err(|_| {
48+
not_impl_datafusion_err!(
49+
"Aggregate function {} is not supported: function anchor = {:?}",
50+
fn_signature,
51+
f.function_reference
52+
)
53+
})?;
54+
55+
let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?;
56+
57+
// Datafusion does not support aggregate functions with no arguments, so
58+
// we inject a dummy argument that does not affect the query, but allows
59+
// us to bypass this limitation.
60+
let args = if udaf.name() == "count" && args.is_empty() {
61+
vec![Expr::Literal(ScalarValue::Int64(Some(1)))]
62+
} else {
63+
args
64+
};
65+
66+
Ok(Arc::new(Expr::AggregateFunction(
67+
expr::AggregateFunction::new_udf(udaf, args, distinct, filter, order_by, None),
68+
)))
69+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use super::grouping::from_substrait_grouping;
19+
use super::SubstraitConsumer;
20+
use super::{from_substrait_agg_func, from_substrait_sorts};
21+
use datafusion::common::not_impl_err;
22+
use datafusion::logical_expr::{Expr, GroupingSet, LogicalPlan, LogicalPlanBuilder};
23+
use substrait::proto::aggregate_function::AggregationInvocation;
24+
use substrait::proto::AggregateRel;
25+
26+
pub async fn from_aggregate_rel(
27+
consumer: &impl SubstraitConsumer,
28+
agg: &AggregateRel,
29+
) -> datafusion::common::Result<LogicalPlan> {
30+
if let Some(input) = agg.input.as_ref() {
31+
let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?);
32+
let mut ref_group_exprs = vec![];
33+
34+
for e in &agg.grouping_expressions {
35+
let x = consumer.consume_expression(e, input.schema()).await?;
36+
ref_group_exprs.push(x);
37+
}
38+
39+
let mut group_exprs = vec![];
40+
let mut aggr_exprs = vec![];
41+
42+
match agg.groupings.len() {
43+
1 => {
44+
group_exprs.extend_from_slice(
45+
&from_substrait_grouping(
46+
consumer,
47+
&agg.groupings[0],
48+
&ref_group_exprs,
49+
input.schema(),
50+
)
51+
.await?,
52+
);
53+
}
54+
_ => {
55+
let mut grouping_sets = vec![];
56+
for grouping in &agg.groupings {
57+
let grouping_set = from_substrait_grouping(
58+
consumer,
59+
grouping,
60+
&ref_group_exprs,
61+
input.schema(),
62+
)
63+
.await?;
64+
grouping_sets.push(grouping_set);
65+
}
66+
// Single-element grouping expression of type Expr::GroupingSet.
67+
// Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when
68+
// parsed by the producer and consumer, since Substrait does not have a type dedicated
69+
// to ROLLUP. Only vector of Groupings (grouping sets) is available.
70+
group_exprs
71+
.push(Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)));
72+
}
73+
};
74+
75+
for m in &agg.measures {
76+
let filter = match &m.filter {
77+
Some(fil) => Some(Box::new(
78+
consumer.consume_expression(fil, input.schema()).await?,
79+
)),
80+
None => None,
81+
};
82+
let agg_func = match &m.measure {
83+
Some(f) => {
84+
let distinct = match f.invocation {
85+
_ if f.invocation == AggregationInvocation::Distinct as i32 => {
86+
true
87+
}
88+
_ if f.invocation == AggregationInvocation::All as i32 => false,
89+
_ => false,
90+
};
91+
let order_by = if !f.sorts.is_empty() {
92+
Some(
93+
from_substrait_sorts(consumer, &f.sorts, input.schema())
94+
.await?,
95+
)
96+
} else {
97+
None
98+
};
99+
100+
from_substrait_agg_func(
101+
consumer,
102+
f,
103+
input.schema(),
104+
filter,
105+
order_by,
106+
distinct,
107+
)
108+
.await
109+
}
110+
None => {
111+
not_impl_err!("Aggregate without aggregate function is not supported")
112+
}
113+
};
114+
aggr_exprs.push(agg_func?.as_ref().clone());
115+
}
116+
input.aggregate(group_exprs, aggr_exprs)?.build()
117+
} else {
118+
not_impl_err!("Aggregate without an input is not valid")
119+
}
120+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion::common::{plan_err, substrait_err, ScalarValue};
19+
use datafusion::logical_expr::WindowFrameBound;
20+
use substrait::proto::expression::{
21+
window_function::bound as SubstraitBound, window_function::bound::Kind as BoundKind,
22+
window_function::Bound,
23+
};
24+
25+
pub(super) fn from_substrait_bound(
26+
bound: &Option<Bound>,
27+
is_lower: bool,
28+
) -> datafusion::common::Result<WindowFrameBound> {
29+
match bound {
30+
Some(b) => match &b.kind {
31+
Some(k) => match k {
32+
BoundKind::CurrentRow(SubstraitBound::CurrentRow {}) => {
33+
Ok(WindowFrameBound::CurrentRow)
34+
}
35+
BoundKind::Preceding(SubstraitBound::Preceding { offset }) => {
36+
if *offset <= 0 {
37+
return plan_err!("Preceding bound must be positive");
38+
}
39+
Ok(WindowFrameBound::Preceding(ScalarValue::UInt64(Some(
40+
*offset as u64,
41+
))))
42+
}
43+
BoundKind::Following(SubstraitBound::Following { offset }) => {
44+
if *offset <= 0 {
45+
return plan_err!("Following bound must be positive");
46+
}
47+
Ok(WindowFrameBound::Following(ScalarValue::UInt64(Some(
48+
*offset as u64,
49+
))))
50+
}
51+
BoundKind::Unbounded(SubstraitBound::Unbounded {}) => {
52+
if is_lower {
53+
Ok(WindowFrameBound::Preceding(ScalarValue::Null))
54+
} else {
55+
Ok(WindowFrameBound::Following(ScalarValue::Null))
56+
}
57+
}
58+
},
59+
None => substrait_err!("WindowFunction missing Substrait Bound kind"),
60+
},
61+
None => {
62+
if is_lower {
63+
Ok(WindowFrameBound::Preceding(ScalarValue::Null))
64+
} else {
65+
Ok(WindowFrameBound::Following(ScalarValue::Null))
66+
}
67+
}
68+
}
69+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use super::r#type::from_substrait_type_without_names;
19+
use super::SubstraitConsumer;
20+
use datafusion::common::{substrait_err, DFSchema};
21+
use datafusion::logical_expr::{Cast, Expr, TryCast};
22+
use substrait::proto::expression as substrait_expression;
23+
use substrait::proto::expression::cast::FailureBehavior::ReturnNull;
24+
25+
pub async fn from_cast(
26+
consumer: &impl SubstraitConsumer,
27+
cast: &substrait_expression::Cast,
28+
input_schema: &DFSchema,
29+
) -> datafusion::common::Result<Expr> {
30+
match cast.r#type.as_ref() {
31+
Some(output_type) => {
32+
let input_expr = Box::new(
33+
consumer
34+
.consume_expression(
35+
cast.input.as_ref().unwrap().as_ref(),
36+
input_schema,
37+
)
38+
.await?,
39+
);
40+
let data_type = from_substrait_type_without_names(consumer, output_type)?;
41+
if cast.failure_behavior() == ReturnNull {
42+
Ok(Expr::TryCast(TryCast::new(input_expr, data_type)))
43+
} else {
44+
Ok(Expr::Cast(Cast::new(input_expr, data_type)))
45+
}
46+
}
47+
None => substrait_err!("Cast expression without output type is not allowed"),
48+
}
49+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use super::utils::requalify_sides_if_needed;
19+
use super::SubstraitConsumer;
20+
use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder};
21+
use substrait::proto::CrossRel;
22+
23+
pub async fn from_cross_rel(
24+
consumer: &impl SubstraitConsumer,
25+
cross: &CrossRel,
26+
) -> datafusion::common::Result<LogicalPlan> {
27+
let left = LogicalPlanBuilder::from(
28+
consumer.consume_rel(cross.left.as_ref().unwrap()).await?,
29+
);
30+
let right = LogicalPlanBuilder::from(
31+
consumer.consume_rel(cross.right.as_ref().unwrap()).await?,
32+
);
33+
let (left, right) = requalify_sides_if_needed(left, right)?;
34+
left.cross_join(right.build()?)?.build()
35+
}

0 commit comments

Comments
 (0)