diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs
index c123ebb22ecb..e358ec1c142b 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -243,6 +243,7 @@ impl SessionState {
                 feature = "unicode_expressions"
             ))]
             Arc::new(functions::planner::UserDefinedFunctionPlanner),
+            Arc::new(functions_aggregate::planner::AggregateUDFPlanner),
         ];
 
         let mut new_self = SessionState {
diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs
index aeb8ed8372b7..ac5c5a172c91 100644
--- a/datafusion/expr/src/planner.rs
+++ b/datafusion/expr/src/planner.rs
@@ -24,6 +24,7 @@ use datafusion_common::{
     config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema,
     Result, TableReference,
 };
+use sqlparser::ast::NullTreatment;
 
 use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF};
 
@@ -161,6 +162,28 @@ pub trait ExprPlanner: Send + Sync {
     ) -> Result<PlannerResult<Vec<Expr>>> {
         Ok(PlannerResult::Original(args))
     }
+
+    /// Plans a `RawAggregateUDF` based on the given input expressions.
+    ///
+    /// Returns a `PlannerResult` containing either the planned aggregate function or the original
+    /// input expressions if planning is not possible.
+    fn plan_aggregate_udf(
+        &self,
+        aggregate_function: RawAggregateUDF,
+    ) -> Result<PlannerResult<RawAggregateUDF>> {
+        Ok(PlannerResult::Original(aggregate_function))
+    }
+}
+
+/// An `AggregateUDF` to be planned.
+#[derive(Debug, Clone)]
+pub struct RawAggregateUDF {
+    pub udf: Arc<AggregateUDF>,
+    pub args: Vec<Expr>,
+    pub distinct: bool,
+    pub filter: Option<Box<Expr>>,
+    pub order_by: Option<Vec<Expr>>,
+    pub null_treatment: Option<NullTreatment>,
 }
 
 /// An operator with two arguments to plan
diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs
index 6ae2dfb3697c..3a1bd6bc70bc 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -62,6 +62,7 @@ pub mod covariance;
 pub mod first_last;
 pub mod hyperloglog;
 pub mod median;
+pub mod planner;
 pub mod regr;
 pub mod stddev;
 pub mod sum;
diff --git a/datafusion/functions-aggregate/src/planner.rs b/datafusion/functions-aggregate/src/planner.rs
new file mode 100644
index 000000000000..56f12d08ad20
--- /dev/null
+++ b/datafusion/functions-aggregate/src/planner.rs
@@ -0,0 +1,29 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawAggregateUDF};
+
+pub struct AggregateUDFPlanner;
+
+impl ExprPlanner for AggregateUDFPlanner {
+    fn plan_aggregate_udf(
+        &self,
+        aggregate_function: RawAggregateUDF,
+    ) -> datafusion_common::Result<PlannerResult<RawAggregateUDF>> {
+        Ok(PlannerResult::Original(aggregate_function))
+    }
+}
diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs
index d9ddf57eb192..67652e459906 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -21,6 +21,7 @@ use datafusion_common::{
     internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema,
     Dependency, Result,
 };
+use datafusion_expr::planner::{PlannerResult, RawAggregateUDF};
 use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
 use datafusion_expr::{
     expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition,
@@ -335,7 +336,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             }
         } else {
             // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function
-            if let Some(fm) = self.context_provider.get_aggregate_meta(&name) {
+            if let Some(udf) = self.context_provider.get_aggregate_meta(&name) {
                 let order_by = self.order_by_to_sort_expr(
                     &order_by,
                     schema,
@@ -349,13 +350,31 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     .map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context))
                     .transpose()?
                     .map(Box::new);
-                return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
-                    fm,
+
+                let raw_aggregate_function = RawAggregateUDF {
+                    udf,
                     args,
                     distinct,
                     filter,
                     order_by,
                     null_treatment,
+                };
+
+                for planner in self.planners.iter() {
+                    if let PlannerResult::Planned(aggregate_function) =
+                        planner.plan_aggregate_udf(raw_aggregate_function.clone())?
+                    {
+                        return Ok(aggregate_function);
+                    }
+                }
+
+                return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
+                    raw_aggregate_function.udf,
+                    raw_aggregate_function.args,
+                    raw_aggregate_function.distinct,
+                    raw_aggregate_function.filter,
+                    raw_aggregate_function.order_by,
+                    raw_aggregate_function.null_treatment,
                 )));
             }
 