Skip to content

Commit 824bb66

Browse files
waynexia and alamb authored
feat: support UDAF in substrait producer/consumer (#8119)
* feat: support UDAF in substrait producer/consumer
  Signed-off-by: Ruihang Xia <[email protected]>
* Update datafusion/substrait/src/logical_plan/consumer.rs
  Co-authored-by: Andrew Lamb <[email protected]>
* remove redundant to_lowercase
  Signed-off-by: Ruihang Xia <[email protected]>

---------

Signed-off-by: Ruihang Xia <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
1 parent f67c20f commit 824bb66

File tree

3 files changed

+125
-25
lines changed

3 files changed

+125
-25
lines changed

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use async_recursion::async_recursion;
1919
use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
2020
use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef};
2121

22+
use datafusion::execution::FunctionRegistry;
2223
use datafusion::logical_expr::{
2324
aggregate_function, window_function::find_df_window_func, BinaryExpr,
2425
BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
@@ -365,6 +366,7 @@ pub async fn from_substrait_rel(
365366
_ => false,
366367
};
367368
from_substrait_agg_func(
369+
ctx,
368370
f,
369371
input.schema(),
370372
extensions,
@@ -660,6 +662,7 @@ pub async fn from_substriat_func_args(
660662

661663
/// Convert Substrait AggregateFunction to DataFusion Expr
662664
pub async fn from_substrait_agg_func(
665+
ctx: &SessionContext,
663666
f: &AggregateFunction,
664667
input_schema: &DFSchema,
665668
extensions: &HashMap<u32, &String>,
@@ -680,23 +683,37 @@ pub async fn from_substrait_agg_func(
680683
args.push(arg_expr?.as_ref().clone());
681684
}
682685

683-
let fun = match extensions.get(&f.function_reference) {
684-
Some(function_name) => {
685-
aggregate_function::AggregateFunction::from_str(function_name)
686-
}
687-
None => not_impl_err!(
688-
"Aggregated function not found: function anchor = {:?}",
686+
let Some(function_name) = extensions.get(&f.function_reference) else {
687+
return plan_err!(
688+
"Aggregate function not registered: function anchor = {:?}",
689689
f.function_reference
690-
),
690+
);
691691
};
692692

693-
Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction {
694-
fun: fun.unwrap(),
695-
args,
696-
distinct,
697-
filter,
698-
order_by,
699-
})))
693+
// try udaf first, then built-in aggr fn.
694+
if let Ok(fun) = ctx.udaf(function_name) {
695+
Ok(Arc::new(Expr::AggregateUDF(expr::AggregateUDF {
696+
fun,
697+
args,
698+
filter,
699+
order_by,
700+
})))
701+
} else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name)
702+
{
703+
Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction {
704+
fun,
705+
args,
706+
distinct,
707+
filter,
708+
order_by,
709+
})))
710+
} else {
711+
not_impl_err!(
712+
"Aggregated function {} is not supported: function anchor = {:?}",
713+
function_name,
714+
f.function_reference
715+
)
716+
}
700717
}
701718

702719
/// Convert Substrait Rex to DataFusion Expr

datafusion/substrait/src/logical_plan/producer.rs

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -588,8 +588,7 @@ pub fn to_substrait_agg_measure(
588588
for arg in args {
589589
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
590590
}
591-
let function_name = fun.to_string().to_lowercase();
592-
let function_anchor = _register_function(function_name, extension_info);
591+
let function_anchor = _register_function(fun.to_string(), extension_info);
593592
Ok(Measure {
594593
measure: Some(AggregateFunction {
595594
function_reference: function_anchor,
@@ -610,6 +609,34 @@ pub fn to_substrait_agg_measure(
610609
}
611610
})
612611
}
612+
Expr::AggregateUDF(expr::AggregateUDF{ fun, args, filter, order_by }) =>{
613+
let sorts = if let Some(order_by) = order_by {
614+
order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::<Result<Vec<_>>>()?
615+
} else {
616+
vec![]
617+
};
618+
let mut arguments: Vec<FunctionArgument> = vec![];
619+
for arg in args {
620+
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
621+
}
622+
let function_anchor = _register_function(fun.name.clone(), extension_info);
623+
Ok(Measure {
624+
measure: Some(AggregateFunction {
625+
function_reference: function_anchor,
626+
arguments,
627+
sorts,
628+
output_type: None,
629+
invocation: AggregationInvocation::All as i32,
630+
phase: AggregationPhase::Unspecified as i32,
631+
args: vec![],
632+
options: vec![],
633+
}),
634+
filter: match filter {
635+
Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?),
636+
None => None
637+
}
638+
})
639+
},
613640
Expr::Alias(Alias{expr,..})=> {
614641
to_substrait_agg_measure(expr, schema, extension_info)
615642
}
@@ -703,8 +730,8 @@ pub fn make_binary_op_scalar_func(
703730
HashMap<String, u32>,
704731
),
705732
) -> Expression {
706-
let function_name = operator_to_name(op).to_string().to_lowercase();
707-
let function_anchor = _register_function(function_name, extension_info);
733+
let function_anchor =
734+
_register_function(operator_to_name(op).to_string(), extension_info);
708735
Expression {
709736
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
710737
function_reference: function_anchor,
@@ -807,8 +834,7 @@ pub fn to_substrait_rex(
807834
)?)),
808835
});
809836
}
810-
let function_name = fun.to_string().to_lowercase();
811-
let function_anchor = _register_function(function_name, extension_info);
837+
let function_anchor = _register_function(fun.to_string(), extension_info);
812838
Ok(Expression {
813839
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
814840
function_reference: function_anchor,
@@ -973,8 +999,7 @@ pub fn to_substrait_rex(
973999
window_frame,
9741000
}) => {
9751001
// function reference
976-
let function_name = fun.to_string().to_lowercase();
977-
let function_anchor = _register_function(function_name, extension_info);
1002+
let function_anchor = _register_function(fun.to_string(), extension_info);
9781003
// arguments
9791004
let mut arguments: Vec<FunctionArgument> = vec![];
9801005
for arg in args {

datafusion/substrait/tests/cases/roundtrip_logical_plan.rs

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use datafusion::arrow::array::ArrayRef;
19+
use datafusion::physical_plan::Accumulator;
20+
use datafusion::scalar::ScalarValue;
1821
use datafusion_substrait::logical_plan::{
1922
consumer::from_substrait_plan, producer::to_substrait_plan,
2023
};
@@ -28,7 +31,9 @@ use datafusion::error::{DataFusionError, Result};
2831
use datafusion::execution::context::SessionState;
2932
use datafusion::execution::registry::SerializerRegistry;
3033
use datafusion::execution::runtime_env::RuntimeEnv;
31-
use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode};
34+
use datafusion::logical_expr::{
35+
Extension, LogicalPlan, UserDefinedLogicalNode, Volatility,
36+
};
3237
use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST;
3338
use datafusion::prelude::*;
3439

@@ -636,6 +641,56 @@ async fn extension_logical_plan() -> Result<()> {
636641
Ok(())
637642
}
638643

644+
#[tokio::test]
645+
async fn roundtrip_aggregate_udf() -> Result<()> {
646+
#[derive(Debug)]
647+
struct Dummy {}
648+
649+
impl Accumulator for Dummy {
650+
fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
651+
Ok(vec![])
652+
}
653+
654+
fn update_batch(
655+
&mut self,
656+
_values: &[ArrayRef],
657+
) -> datafusion::error::Result<()> {
658+
Ok(())
659+
}
660+
661+
fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> {
662+
Ok(())
663+
}
664+
665+
fn evaluate(&self) -> datafusion::error::Result<ScalarValue> {
666+
Ok(ScalarValue::Float64(None))
667+
}
668+
669+
fn size(&self) -> usize {
670+
std::mem::size_of_val(self)
671+
}
672+
}
673+
674+
let dummy_agg = create_udaf(
675+
// the name; used to represent it in plan descriptions and in the registry, to use in SQL.
676+
"dummy_agg",
677+
// the input type; DataFusion guarantees that the first entry of `values` in `update` has this type.
678+
vec![DataType::Int64],
679+
// the return type; DataFusion expects this to match the type returned by `evaluate`.
680+
Arc::new(DataType::Int64),
681+
Volatility::Immutable,
682+
// This is the accumulator factory; DataFusion uses it to create new accumulators.
683+
Arc::new(|_| Ok(Box::new(Dummy {}))),
684+
// This is the description of the state. `state()` must match the types here.
685+
Arc::new(vec![DataType::Float64, DataType::UInt32]),
686+
);
687+
688+
let ctx = create_context().await?;
689+
ctx.register_udaf(dummy_agg);
690+
691+
roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await
692+
}
693+
639694
fn check_post_join_filters(rel: &Rel) -> Result<()> {
640695
// search for target_rel and field value in proto
641696
match &rel.rel_type {
@@ -772,8 +827,7 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> {
772827
Ok(())
773828
}
774829

775-
async fn roundtrip(sql: &str) -> Result<()> {
776-
let ctx = create_context().await?;
830+
async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> {
777831
let df = ctx.sql(sql).await?;
778832
let plan = df.into_optimized_plan()?;
779833
let proto = to_substrait_plan(&plan, &ctx)?;
@@ -789,6 +843,10 @@ async fn roundtrip(sql: &str) -> Result<()> {
789843
Ok(())
790844
}
791845

846+
async fn roundtrip(sql: &str) -> Result<()> {
847+
roundtrip_with_ctx(sql, create_context().await?).await
848+
}
849+
792850
async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> {
793851
let ctx = create_context().await?;
794852
let df = ctx.sql(sql).await?;

0 commit comments

Comments
 (0)