diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index fd65c3352bbc..9cda726db719 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -423,11 +423,11 @@ impl AggregateUDFImpl for SimplifiedGeoMeanUdaf { // In real-world scenarios, you might create UDFs from built-in expressions. Ok(Expr::AggregateFunction(AggregateFunction::new_udf( Arc::new(AggregateUDF::from(GeoMeanUdaf::new())), - aggregate_function.args, - aggregate_function.distinct, - aggregate_function.filter, - aggregate_function.order_by, - aggregate_function.null_treatment, + aggregate_function.params.args, + aggregate_function.params.distinct, + aggregate_function.params.filter, + aggregate_function.params.order_by, + aggregate_function.params.null_treatment, ))) }; Some(Box::new(simplify)) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 2303574e88af..bce1aab16e5e 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -70,7 +70,8 @@ use datafusion_common::{ }; use datafusion_expr::dml::{CopyTo, InsertOp}; use datafusion_expr::expr::{ - physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction, + physical_name, AggregateFunction, AggregateFunctionParams, Alias, GroupingSet, + WindowFunction, }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; @@ -1579,11 +1580,14 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( match e { Expr::AggregateFunction(AggregateFunction { func, - distinct, - args, - filter, - order_by, - null_treatment, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + }, }) => { let name = if let Some(name) = name { name diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs index a521902389dc..a17bb5eec8a3 100644 --- a/datafusion/core/tests/execution/logical_plan.rs +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -20,7 +20,7 @@ use arrow::datatypes::{DataType, Field}; use datafusion::execution::session_state::SessionStateBuilder; use datafusion_common::{Column, DFSchema, Result, ScalarValue, Spans}; use datafusion_execution::TaskContext; -use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams}; use datafusion_expr::logical_plan::{LogicalPlan, Values}; use datafusion_expr::{Aggregate, AggregateUDF, Expr}; use datafusion_functions_aggregate::count::Count; @@ -60,11 +60,13 @@ async fn count_only_nulls() -> Result<()> { vec![], vec![Expr::AggregateFunction(AggregateFunction { func: Arc::new(AggregateUDF::new_from_impl(Count::new())), - args: vec![input_col_ref], - distinct: false, - filter: None, - order_by: None, - null_treatment: None, + params: AggregateFunctionParams { + args: vec![input_col_ref], + distinct: false, + filter: None, + order_by: None, + null_treatment: None, + }, })], )?); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 305519a1f4b4..84ff36a9317d 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -696,7 +696,11 @@ impl<'a> TreeNodeContainer<'a, Expr> for Sort { pub struct AggregateFunction { /// Name of the function pub func: Arc, - /// List of expressions to feed to the functions as arguments + pub params: AggregateFunctionParams, +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct AggregateFunctionParams { pub args: Vec, /// Whether this is a DISTINCT aggregation or not pub distinct: bool, @@ -719,11 +723,13 @@ impl AggregateFunction { ) -> Self { Self { func, - args, - distinct, - filter, - order_by, - null_treatment, + params: AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + }, } } } @@ -1864,19 +1870,25 @@ impl NormalizeEq for Expr { ( Expr::AggregateFunction(AggregateFunction { func: self_func, - args: self_args, - distinct: self_distinct, - filter: self_filter, - order_by: self_order_by, - null_treatment: self_null_treatment, + params: + AggregateFunctionParams { + args: self_args, + distinct: self_distinct, + filter: self_filter, + order_by: self_order_by, + null_treatment: self_null_treatment, + }, }), Expr::AggregateFunction(AggregateFunction { func: other_func, - args: other_args, - distinct: other_distinct, - filter: other_filter, - order_by: other_order_by, - null_treatment: other_null_treatment, + params: + AggregateFunctionParams { + args: other_args, + distinct: other_distinct, + filter: other_filter, + order_by: other_order_by, + null_treatment: other_null_treatment, + }, }), ) => { self_func.name() == other_func.name() @@ -2154,11 +2166,14 @@ impl HashNode for Expr { } Expr::AggregateFunction(AggregateFunction { func, - args: _args, - distinct, - filter: _filter, - order_by: _order_by, - null_treatment, + params: + AggregateFunctionParams { + args: _args, + distinct, + filter: _, + order_by: _, + null_treatment, + }, }) => { func.hash(state); distinct.hash(state); @@ -2264,35 +2279,15 @@ impl Display for SchemaDisplay<'_> { | Expr::Placeholder(_) | Expr::Wildcard { .. } => write!(f, "{}", self.0), - Expr::AggregateFunction(AggregateFunction { - func, - args, - distinct, - filter, - order_by, - null_treatment, - }) => { - write!( - f, - "{}({}{})", - func.name(), - if *distinct { "DISTINCT " } else { "" }, - schema_name_from_exprs_comma_separated_without_space(args)? - )?; - - if let Some(null_treatment) = null_treatment { - write!(f, " {}", null_treatment)?; + Expr::AggregateFunction(AggregateFunction { func, params }) => { + match func.schema_name(params) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!(f, "got error from schema_name {}", e) + } } - - if let Some(filter) = filter { - write!(f, " FILTER (WHERE {filter})")?; - }; - - if let Some(order_by) = order_by { - write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; - }; - - Ok(()) } // Expr is not shown since it is aliased Expr::Alias(Alias { @@ -2653,26 +2648,15 @@ impl Display for Expr { )?; Ok(()) } - Expr::AggregateFunction(AggregateFunction { - func, - distinct, - ref args, - filter, - order_by, - null_treatment, - .. - }) => { - fmt_function(f, func.name(), *distinct, args, true)?; - if let Some(nt) = null_treatment { - write!(f, " {}", nt)?; - } - if let Some(fe) = filter { - write!(f, " FILTER (WHERE {fe})")?; - } - if let Some(ob) = order_by { - write!(f, " ORDER BY [{}]", expr_vec_fmt!(ob))?; + Expr::AggregateFunction(AggregateFunction { func, params }) => { + match func.display_name(params) { + Ok(name) => { + write!(f, "{}", name) + } + Err(e) => { + write!(f, "got error from display_name {}", e) + } } - Ok(()) } Expr::Between(Between { expr, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index a2de5e7b259f..91d2b379af60 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -830,10 +830,10 @@ impl ExprFuncBuilder { let fun_expr = match fun { ExprFuncKind::Aggregate(mut udaf) => { - udaf.order_by = order_by; - udaf.filter = filter.map(Box::new); - udaf.distinct = distinct; - udaf.null_treatment = null_treatment; + udaf.params.order_by = order_by; + udaf.params.filter = filter.map(Box::new); + udaf.params.distinct = distinct; + udaf.params.null_treatment = null_treatment; Expr::AggregateFunction(udaf) } ExprFuncKind::Window(mut udwf) => { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 49791427131f..becb7c14397d 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,8 +17,8 @@ use super::{Between, Expr, Like}; use crate::expr::{ - AggregateFunction, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder, - ScalarFunction, TryCast, Unnest, WindowFunction, + AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, + InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, }; use crate::type_coercion::functions::{ data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf, @@ -153,7 +153,10 @@ impl ExprSchemable for Expr { Expr::WindowFunction(window_function) => self .data_type_and_nullable_with_window_function(schema, window_function) .map(|(return_type, _)| return_type), - Expr::AggregateFunction(AggregateFunction { func, args, .. }) => { + Expr::AggregateFunction(AggregateFunction { + func, + params: AggregateFunctionParams { args, .. }, + }) => { let data_types = args .iter() .map(|e| e.get_type(schema)) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index eacace5ed046..7801d564135e 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -18,8 +18,9 @@ //! Tree node implementation for Logical Expressions use crate::expr::{ - AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList, - InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, + AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast, + GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, + WindowFunction, }; use crate::{Expr, ExprFunctionExt}; @@ -87,7 +88,7 @@ impl TreeNode for Expr { }) => (expr, low, high).apply_ref_elements(f), Expr::Case(Case { expr, when_then_expr, else_expr }) => (expr, when_then_expr, else_expr).apply_ref_elements(f), - Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => + Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) => (args, filter, order_by).apply_ref_elements(f), Expr::WindowFunction(WindowFunction { args, @@ -241,12 +242,15 @@ impl TreeNode for Expr { }, ), Expr::AggregateFunction(AggregateFunction { - args, func, - distinct, - filter, - order_by, - null_treatment, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + }, }) => (args, filter, order_by).map_elements(f)?.map_data( |(new_args, new_filter, new_order_by)| { Ok(Expr::AggregateFunction(AggregateFunction::new_udf( diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 7ffc6623ea92..bf8f34f949e0 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::cmp::Ordering; -use std::fmt::{self, Debug, Formatter}; +use std::fmt::{self, Debug, Formatter, Write}; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; use std::vec; @@ -29,7 +29,10 @@ use arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use crate::expr::AggregateFunction; +use crate::expr::{ + schema_name_from_exprs_comma_separated_without_space, schema_name_from_sorts, + AggregateFunction, AggregateFunctionParams, +}; use crate::function::{ AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs, }; @@ -165,6 +168,16 @@ impl AggregateUDF { self.inner.name() } + /// See [`AggregateUDFImpl::schema_name`] for more details. + pub fn schema_name(&self, params: &AggregateFunctionParams) -> Result { + self.inner.schema_name(params) + } + + /// See [`AggregateUDFImpl::display_name`] for more details. + pub fn display_name(&self, params: &AggregateFunctionParams) -> Result { + self.inner.display_name(params) + } + pub fn is_nullable(&self) -> bool { self.inner.is_nullable() } @@ -382,6 +395,93 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// Returns this function's name fn name(&self) -> &str; + /// Returns the name of the column this expression would create + /// + /// See [`Expr::schema_name`] for details + /// + /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [..] + fn schema_name(&self, params: &AggregateFunctionParams) -> Result { + let AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + } = params; + + let mut schema_name = String::new(); + + schema_name.write_fmt(format_args!( + "{}({}{})", + self.name(), + if *distinct { "DISTINCT " } else { "" }, + schema_name_from_exprs_comma_separated_without_space(args)? + ))?; + + if let Some(null_treatment) = null_treatment { + schema_name.write_fmt(format_args!(" {}", null_treatment))?; + } + + if let Some(filter) = filter { + schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?; + }; + + if let Some(order_by) = order_by { + schema_name.write_fmt(format_args!( + " ORDER BY [{}]", + schema_name_from_sorts(order_by)? + ))?; + }; + + Ok(schema_name) + } + + /// Returns the user-defined display name of function, given the arguments + /// + /// This can be used to customize the output column name generated by this + /// function. + /// + /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [filter] [order_by [..]]` + fn display_name(&self, params: &AggregateFunctionParams) -> Result { + let AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + } = params; + + let mut schema_name = String::new(); + + schema_name.write_fmt(format_args!( + "{}({}{})", + self.name(), + if *distinct { "DISTINCT " } else { "" }, + args.iter() + .map(|arg| format!("{arg}")) + .collect::>() + .join(", ") + ))?; + + if let Some(nt) = null_treatment { + schema_name.write_fmt(format_args!(" {}", nt))?; + } + if let Some(fe) = filter { + schema_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?; + } + if let Some(ob) = order_by { + schema_name.write_fmt(format_args!( + " ORDER BY [{}]", + ob.iter() + .map(|o| format!("{o}")) + .collect::>() + .join(", ") + ))?; + } + + Ok(schema_name) + } + /// Returns the function's [`Signature`] for information about what input /// types are accepted and the function's Volatility. fn signature(&self) -> &Signature; diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index 5ca51ac20f1e..d55176a42c9a 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -17,8 +17,11 @@ //! SQL planning extensions like [`NestedFunctionPlanner`] and [`FieldAccessPlanner`] +use std::sync::Arc; + use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result}; -use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams, ScalarFunction}; +use datafusion_expr::AggregateUDF; use datafusion_expr::{ planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, sqlparser, Expr, ExprSchemable, GetFieldAccess, @@ -150,22 +153,26 @@ impl ExprPlanner for FieldAccessPlanner { GetFieldAccess::ListIndex { key: index } => { match expr { // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) - Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => { - Ok(PlannerResult::Planned(Expr::AggregateFunction( - datafusion_expr::expr::AggregateFunction::new_udf( - nth_value_udaf(), - agg_func - .args - .into_iter() - .chain(std::iter::once(*index)) - .collect(), - agg_func.distinct, - agg_func.filter, - agg_func.order_by, - agg_func.null_treatment, - ), - ))) - } + Expr::AggregateFunction(AggregateFunction { + func, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + }, + }) if is_array_agg(&func) => Ok(PlannerResult::Planned( + Expr::AggregateFunction(AggregateFunction::new_udf( + nth_value_udaf(), + args.into_iter().chain(std::iter::once(*index)).collect(), + distinct, + filter, + order_by, + null_treatment, + )), + )), _ => Ok(PlannerResult::Planned(array_element(expr, *index))), } } @@ -184,6 +191,6 @@ impl ExprPlanner for FieldAccessPlanner { } } -fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - agg_func.func.name() == "array_agg" +fn is_array_agg(func: &Arc) -> bool { + func.name() == "array_agg" } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 95b6f9dc764f..7e73474cf6f5 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -21,7 +21,7 @@ use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; -use datafusion_expr::expr::{AggregateFunction, WindowFunction}; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams, WindowFunction}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; @@ -55,8 +55,7 @@ fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { matches!(aggregate_function, AggregateFunction { func, - args, - .. + params: AggregateFunctionParams { args, .. }, } if func.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty())) } @@ -81,7 +80,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { Expr::AggregateFunction(mut aggregate_function) if is_count_star_aggregate(&aggregate_function) => { - aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)]; + aggregate_function.params.args = vec![lit(COUNT_STAR_EXPANSION)]; Ok(Transformed::yes(Expr::AggregateFunction( aggregate_function, ))) diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index 16ebb8cd3972..f8a818563609 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -163,6 +163,7 @@ fn validate_args( group_by_expr: &HashMap<&Expr, usize>, ) -> Result<()> { let expr_not_in_group_by = function + .params .args .iter() .find(|expr| !group_by_expr.contains_key(expr)); @@ -183,7 +184,7 @@ fn grouping_function_on_id( is_grouping_set: bool, ) -> Result { validate_args(function, group_by_expr)?; - let args = &function.args; + let args = &function.params.args; // Postgres allows grouping function for group by without grouping sets, the result is then // always 0 diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 85fc9b31bcdd..fd20c9fa5409 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -33,8 +33,8 @@ use datafusion_common::{ DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_expr::expr::{ - self, Alias, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, - ScalarFunction, Sort, WindowFunction, + self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList, + InSubquery, Like, ScalarFunction, Sort, WindowFunction, }; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; @@ -506,11 +506,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { } Expr::AggregateFunction(expr::AggregateFunction { func, - args, - distinct, - filter, - order_by, - null_treatment, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + }, }) => { let new_expr = coerce_arguments_for_signature_with_aggregate_udf( args, diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index c8f3a4bc7859..191377fc2759 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -26,6 +26,7 @@ use datafusion_common::{ internal_err, tree_node::Transformed, DataFusionError, HashSet, Result, }; use datafusion_expr::builder::project; +use datafusion_expr::expr::AggregateFunctionParams; use datafusion_expr::{ col, expr::AggregateFunction, @@ -68,11 +69,14 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { func, - distinct, - args, - filter, - order_by, - null_treatment: _, + params: + AggregateFunctionParams { + distinct, + args, + filter, + order_by, + null_treatment: _, + }, }) = expr { if filter.is_some() || order_by.is_some() { @@ -179,9 +183,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { .map(|aggr_expr| match aggr_expr { Expr::AggregateFunction(AggregateFunction { func, - mut args, - distinct, - .. + params: AggregateFunctionParams { mut args, distinct, .. } }) => { if distinct { if args.len() != 1 { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6d1d4f30610c..228437271694 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -22,8 +22,8 @@ use datafusion_common::{TableReference, UnnestOptions}; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{ - self, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like, Placeholder, - ScalarFunction, Unnest, + self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, + Like, Placeholder, ScalarFunction, Unnest, }; use datafusion_expr::WriteOp; use datafusion_expr::{ @@ -348,11 +348,14 @@ pub fn serialize_expr( } Expr::AggregateFunction(expr::AggregateFunction { ref func, - ref args, - ref distinct, - ref filter, - ref order_by, - null_treatment: _, + params: + AggregateFunctionParams { + ref args, + ref distinct, + ref filter, + ref order_by, + null_treatment: _, + }, }) => { let mut buf = Vec::new(); let _ = codec.try_encode_udaf(func, &mut buf); diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 72618c2b6ab4..6491671d84a5 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_expr::expr::Unnest; +use datafusion_expr::expr::{AggregateFunctionParams, Unnest}; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ self, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName, @@ -284,9 +284,15 @@ impl Unparser<'_> { }), Expr::AggregateFunction(agg) => { let func_name = agg.func.name(); + let AggregateFunctionParams { + distinct, + args, + filter, + .. + } = &agg.params; - let args = self.function_args_to_sql(&agg.args)?; - let filter = match &agg.filter { + let args = self.function_args_to_sql(args)?; + let filter = match filter { Some(filter) => Some(Box::new(self.expr_to_sql_inner(filter)?)), None => None, }; @@ -297,8 +303,7 @@ impl Unparser<'_> { span: Span::empty(), }]), args: ast::FunctionArguments::List(ast::FunctionArgumentList { - duplicate_treatment: agg - .distinct + duplicate_treatment: distinct .then_some(ast::DuplicateTreatment::Distinct), args, clauses: vec![], diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 42c226174932..d795a869568b 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -52,7 +52,8 @@ use datafusion::common::{ use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::SessionState; use datafusion::logical_expr::expr::{ - Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, WindowFunction, + AggregateFunctionParams, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, + InSubquery, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; @@ -1208,11 +1209,14 @@ pub fn from_aggregate_function( ) -> Result { let expr::AggregateFunction { func, - args, - distinct, - filter, - order_by, - null_treatment: _null_treatment, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment: _null_treatment, + }, } = agg_fn; let sorts = if let Some(order_by) = order_by { order_by