Skip to content

Commit 0e50d42

Browse files
committed
add simplify method for aggregate function
1 parent 97148bd commit 0e50d42

File tree

3 files changed

+364
-1
lines changed

3 files changed

+364
-1
lines changed
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow_schema::{Field, Schema};
19+
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
20+
21+
use std::{any::Any, sync::Arc};
22+
23+
use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
24+
use datafusion::error::Result;
25+
use datafusion::{assert_batches_eq, prelude::*};
26+
use datafusion_common::cast::as_float64_array;
27+
use datafusion_expr::{
28+
expr::{AggregateFunction, AggregateFunctionDefinition},
29+
function::AccumulatorArgs,
30+
simplify::ExprSimplifyResult,
31+
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
32+
};
33+
34+
/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user
35+
/// defined aggregate function with a different expression which is defined in the `simplify` method.
36+
37+
#[derive(Debug, Clone)]
38+
struct BetterAvgUdaf {
39+
signature: Signature,
40+
}
41+
42+
impl BetterAvgUdaf {
43+
/// Create a new instance of the GeoMeanUdaf struct
44+
fn new() -> Self {
45+
Self {
46+
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
47+
}
48+
}
49+
}
50+
51+
impl AggregateUDFImpl for BetterAvgUdaf {
52+
fn as_any(&self) -> &dyn Any {
53+
self
54+
}
55+
56+
fn name(&self) -> &str {
57+
"better_avg"
58+
}
59+
60+
fn signature(&self) -> &Signature {
61+
&self.signature
62+
}
63+
64+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
65+
Ok(DataType::Float64)
66+
}
67+
68+
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
69+
unimplemented!("should not be invoked")
70+
}
71+
72+
fn state_fields(
73+
&self,
74+
_name: &str,
75+
_value_type: DataType,
76+
_ordering_fields: Vec<arrow_schema::Field>,
77+
) -> Result<Vec<arrow_schema::Field>> {
78+
unimplemented!("should not be invoked")
79+
}
80+
81+
fn groups_accumulator_supported(&self) -> bool {
82+
true
83+
}
84+
85+
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
86+
unimplemented!("should not get here");
87+
}
88+
// we override method, to return new expression which would substitute
89+
// user defined function call
90+
fn simplify(
91+
&self,
92+
args: Vec<Expr>,
93+
_distinct: &bool,
94+
_filter: &Option<Box<Expr>>,
95+
_order_by: &Option<Vec<Expr>>,
96+
_null_treatment: &Option<datafusion_sql::sqlparser::ast::NullTreatment>,
97+
_info: &dyn SimplifyInfo,
98+
) -> Result<ExprSimplifyResult> {
99+
// as an example for this functionality we replace UDF function
100+
// with build-in aggregate function to illustrate the use
101+
let expr = Expr::AggregateFunction(AggregateFunction {
102+
func_def: AggregateFunctionDefinition::BuiltIn(
103+
// yes it is the same Avg, `BetterAvgUdaf` was just a
104+
// marketing pitch :)
105+
datafusion_expr::aggregate_function::AggregateFunction::Avg,
106+
),
107+
args,
108+
distinct: false,
109+
filter: None,
110+
order_by: None,
111+
null_treatment: None,
112+
});
113+
114+
Ok(ExprSimplifyResult::Simplified(expr))
115+
}
116+
}
117+
118+
// create local session context with an in-memory table
119+
fn create_context() -> Result<SessionContext> {
120+
use datafusion::datasource::MemTable;
121+
// define a schema.
122+
let schema = Arc::new(Schema::new(vec![
123+
Field::new("a", DataType::Float32, false),
124+
Field::new("b", DataType::Float32, false),
125+
]));
126+
127+
// define data in two partitions
128+
let batch1 = RecordBatch::try_new(
129+
schema.clone(),
130+
vec![
131+
Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
132+
Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
133+
],
134+
)?;
135+
let batch2 = RecordBatch::try_new(
136+
schema.clone(),
137+
vec![
138+
Arc::new(Float32Array::from(vec![16.0])),
139+
Arc::new(Float32Array::from(vec![2.0])),
140+
],
141+
)?;
142+
143+
let ctx = SessionContext::new();
144+
145+
// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
146+
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
147+
ctx.register_table("t", Arc::new(provider))?;
148+
Ok(ctx)
149+
}
150+
151+
#[tokio::main]
152+
async fn main() -> Result<()> {
153+
let ctx = create_context()?;
154+
155+
let better_avg = AggregateUDF::from(BetterAvgUdaf::new());
156+
ctx.register_udaf(better_avg.clone());
157+
158+
let result = ctx
159+
.sql("SELECT better_avg(a) FROM t group by b")
160+
.await?
161+
.collect()
162+
.await?;
163+
let expected = vec![
164+
"+-----------------+",
165+
"| better_avg(t.a) |",
166+
"+-----------------+",
167+
"| 7.5 |",
168+
"+-----------------+",
169+
];
170+
171+
assert_batches_eq!(expected, &result);
172+
173+
let df = ctx.table("t").await?;
174+
let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?;
175+
176+
let results = df.collect().await?;
177+
let result = as_float64_array(results[0].column(0))?;
178+
179+
assert!((result.value(0) - 7.5).abs() < f64::EPSILON);
180+
println!("The average of [2,4,8,16] is {}", result.value(0));
181+
182+
Ok(())
183+
}

datafusion/expr/src/udaf.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919
2020
use crate::function::AccumulatorArgs;
2121
use crate::groups_accumulator::GroupsAccumulator;
22+
use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
2223
use crate::utils::format_state_name;
2324
use crate::{Accumulator, Expr};
2425
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
2526
use arrow::datatypes::{DataType, Field};
2627
use datafusion_common::{not_impl_err, Result};
28+
use sqlparser::ast::NullTreatment;
2729
use std::any::Any;
2830
use std::fmt::{self, Debug, Formatter};
2931
use std::sync::Arc;
@@ -195,6 +197,21 @@ impl AggregateUDF {
195197
pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
196198
self.inner.create_groups_accumulator()
197199
}
200+
/// Do the function rewrite
201+
///
202+
/// See [`AggregateUDFImpl::simplify`] for more details.
203+
pub fn simplify(
204+
&self,
205+
args: Vec<Expr>,
206+
distinct: &bool,
207+
filter: &Option<Box<Expr>>,
208+
order_by: &Option<Vec<Expr>>,
209+
null_treatment: &Option<NullTreatment>,
210+
info: &dyn SimplifyInfo,
211+
) -> Result<ExprSimplifyResult> {
212+
self.inner
213+
.simplify(args, distinct, filter, order_by, null_treatment, info)
214+
}
198215
}
199216

200217
impl<F> From<F> for AggregateUDF
@@ -354,6 +371,37 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
354371
fn aliases(&self) -> &[String] {
355372
&[]
356373
}
374+
375+
/// Optionally apply per-UDF simplification / rewrite rules.
376+
///
377+
/// This can be used to apply function specific simplification rules during
378+
/// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
379+
/// implementation does nothing.
380+
///
381+
/// Note that DataFusion handles simplifying arguments and "constant
382+
/// folding" (replacing a function call with constant arguments such as
383+
/// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
384+
/// optimizations manually for specific UDFs.
385+
///
386+
/// # Arguments
387+
/// * 'args': The arguments of the function
388+
/// * 'schema': The schema of the function
389+
///
390+
/// # Returns
391+
/// [`ExprSimplifyResult`] indicating the result of the simplification NOTE
392+
/// if the function cannot be simplified, the arguments *MUST* be returned
393+
/// unmodified
394+
fn simplify(
395+
&self,
396+
args: Vec<Expr>,
397+
_distinct: &bool,
398+
_filter: &Option<Box<Expr>>,
399+
_order_by: &Option<Vec<Expr>>,
400+
_null_treatment: &Option<NullTreatment>,
401+
_info: &dyn SimplifyInfo,
402+
) -> Result<ExprSimplifyResult> {
403+
Ok(ExprSimplifyResult::Original(args))
404+
}
357405
}
358406

359407
/// AggregateUDF that adds an alias to the underlying function. It is better to

0 commit comments

Comments
 (0)