Skip to content

Commit a325825

Browse files
committed
add reverse enum
Signed-off-by: jayzhan211 <[email protected]>
1 parent deebda7 commit a325825

File tree

6 files changed

+149
-6
lines changed

6 files changed

+149
-6
lines changed

datafusion-cli/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/core/tests/user_defined/user_defined_aggregates.rs

Lines changed: 93 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,18 @@
1818
//! This module contains end to end demonstrations of creating
1919
//! user defined aggregate functions
2020
21+
use std::fmt::Debug;
22+
2123
use arrow::{array::AsArray, datatypes::Fields};
2224
use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray};
2325
use arrow_schema::Schema;
26+
use datafusion_physical_plan::udaf::create_aggregate_expr;
27+
use sqlparser::ast::NullTreatment;
2428
use std::sync::{
2529
atomic::{AtomicBool, Ordering},
2630
Arc,
2731
};
2832

29-
use datafusion::datasource::MemTable;
30-
use datafusion::test_util::plan_and_collect;
3133
use datafusion::{
3234
arrow::{
3335
array::{ArrayRef, Float64Array, TimestampNanosecondArray},
@@ -43,10 +45,11 @@ use datafusion::{
4345
prelude::SessionContext,
4446
scalar::ScalarValue,
4547
};
48+
use datafusion::{datasource::MemTable, test_util::plan_and_collect};
4649
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
4750
use datafusion_expr::{
48-
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
49-
SimpleAggregateUDF,
51+
create_udaf, expr::AggregateFunction, function::AccumulatorArgs, AggregateUDFImpl,
52+
GroupsAccumulator, ReversedExpr, SimpleAggregateUDF,
5053
};
5154
use datafusion_physical_expr::expressions::AvgAccumulator;
5255

@@ -795,3 +798,89 @@ impl GroupsAccumulator for TestGroupsAccumulator {
795798
std::mem::size_of::<u64>()
796799
}
797800
}
801+
802+
#[derive(Clone)]
803+
struct TestReverseUDAF {
804+
signature: Signature,
805+
// accumulator: AccumulatorFactoryFunction,
806+
// state_fields: Vec<Field>,
807+
}
808+
809+
impl Debug for TestReverseUDAF {
810+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
811+
f.debug_struct("TestReverseUDAF")
812+
.field("name", &self.name())
813+
.field("signature", self.signature())
814+
.finish()
815+
}
816+
}
817+
818+
impl AggregateUDFImpl for TestReverseUDAF {
819+
fn as_any(&self) -> &dyn std::any::Any {
820+
self
821+
}
822+
823+
fn name(&self) -> &str {
824+
"test_reverse"
825+
}
826+
827+
fn signature(&self) -> &Signature {
828+
&self.signature
829+
}
830+
831+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
832+
Ok(DataType::Float64)
833+
}
834+
835+
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
836+
todo!("no need")
837+
}
838+
839+
fn state_fields(
840+
&self,
841+
_name: &str,
842+
_value_type: DataType,
843+
_ordering_fields: Vec<Field>,
844+
) -> Result<Vec<Field>> {
845+
Ok(vec![])
846+
}
847+
848+
fn reverse_expr(&self) -> ReversedExpr {
849+
ReversedExpr::Reversed(AggregateFunction::new_udf(
850+
Arc::new(self.clone().into()),
851+
vec![],
852+
false,
853+
None,
854+
None,
855+
Some(NullTreatment::RespectNulls),
856+
))
857+
}
858+
}
859+
860+
/// tests the creation, registration and usage of a UDAF
861+
#[tokio::test]
862+
async fn test_reverse_udaf() -> Result<()> {
863+
let my_reverse = AggregateUDF::from(TestReverseUDAF {
864+
signature: Signature::exact(vec![], Volatility::Immutable),
865+
});
866+
867+
let empty_schema = Schema::empty();
868+
let e = create_aggregate_expr(
869+
&my_reverse,
870+
&[],
871+
&[],
872+
&[],
873+
&empty_schema,
874+
"test_reverse_udaf",
875+
true,
876+
)?;
877+
878+
// TODO: We don't have a nice way to test the change without introducing many other things
879+
// We check with the output string. `ignore nulls` is expeceted to be false.
880+
let res = e.reverse_expr();
881+
let res_str = format!("{:?}", res.unwrap());
882+
883+
assert_eq!(&res_str, "AggregateFunctionExpr { fun: AggregateUDF { inner: TestReverseUDAF { name: \"test_reverse\", signature: Signature { type_signature: Exact([]), volatility: Immutable } } }, args: [], data_type: Float64, name: \"test_reverse_udaf\", schema: Schema { fields: [], metadata: {} }, sort_exprs: [], ordering_req: [], ignore_nulls: false, ordering_fields: [] }");
884+
885+
Ok(())
886+
}

datafusion/expr/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ pub use signature::{
8181
TIMEZONE_WILDCARD,
8282
};
8383
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
84-
pub use udaf::{AggregateUDF, AggregateUDFImpl};
84+
pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedExpr};
8585
pub use udf::{ScalarUDF, ScalarUDFImpl};
8686
pub use udwf::{WindowUDF, WindowUDFImpl};
8787
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};

datafusion/expr/src/udaf.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
//! [`AggregateUDF`]: User Defined Aggregate Functions
1919
20+
use crate::expr::AggregateFunction;
2021
use crate::function::AccumulatorArgs;
2122
use crate::groups_accumulator::GroupsAccumulator;
2223
use crate::utils::format_state_name;
@@ -195,6 +196,11 @@ impl AggregateUDF {
195196
pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
196197
self.inner.create_groups_accumulator()
197198
}
199+
200+
/// See [`AggregateUDFImpl::reverse_expr`] for more details.
201+
pub fn reverse_expr(&self) -> ReversedExpr {
202+
self.inner.reverse_expr()
203+
}
198204
}
199205

200206
impl<F> From<F> for AggregateUDF
@@ -354,6 +360,24 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
354360
fn aliases(&self) -> &[String] {
355361
&[]
356362
}
363+
364+
/// Construct an expression that calculates the aggregate in reverse.
365+
/// Typically the "reverse" expression is itself (e.g. SUM, COUNT).
366+
/// For aggregates that do not support calculation in reverse,
367+
/// returns None (which is the default value).
368+
fn reverse_expr(&self) -> ReversedExpr {
369+
ReversedExpr::NotSupported
370+
}
371+
}
372+
373+
#[derive(Debug)]
374+
pub enum ReversedExpr {
375+
/// The expression is the same as the original expression, like SUM, COUNT
376+
Identical,
377+
/// The expression does not support reverse calculation, like ArrayAgg
378+
NotSupported,
379+
/// The expression is different from the original expression
380+
Reversed(AggregateFunction),
357381
}
358382

359383
/// AggregateUDF that adds an alias to the underlying function. It is better to

datafusion/physical-expr-common/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,4 @@ path = "src/lib.rs"
3939
arrow = { workspace = true }
4040
datafusion-common = { workspace = true, default-features = true }
4141
datafusion-expr = { workspace = true }
42+
sqlparser = { workspace = true }

datafusion/physical-expr-common/src/aggregate/mod.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,13 @@ pub mod utils;
1919

2020
use arrow::datatypes::{DataType, Field, Schema};
2121
use datafusion_common::{not_impl_err, Result};
22+
use datafusion_expr::expr::AggregateFunction;
2223
use datafusion_expr::type_coercion::aggregates::check_arg_count;
24+
use datafusion_expr::ReversedExpr;
2325
use datafusion_expr::{
2426
function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator,
2527
};
28+
use sqlparser::ast::NullTreatment;
2629
use std::fmt::Debug;
2730
use std::{any::Any, sync::Arc};
2831

@@ -147,7 +150,7 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq<dyn Any> {
147150
}
148151

149152
/// Physical aggregate expression of a UDAF.
150-
#[derive(Debug)]
153+
#[derive(Debug, Clone)]
151154
pub struct AggregateFunctionExpr {
152155
fun: AggregateUDF,
153156
args: Vec<Arc<dyn PhysicalExpr>>,
@@ -273,6 +276,31 @@ impl AggregateExpr for AggregateFunctionExpr {
273276
fn order_bys(&self) -> Option<&[PhysicalSortExpr]> {
274277
(!self.ordering_req.is_empty()).then_some(&self.ordering_req)
275278
}
279+
280+
fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
281+
match self.fun.reverse_expr() {
282+
ReversedExpr::NotSupported => None,
283+
ReversedExpr::Identical => Some(Arc::new(self.clone())),
284+
ReversedExpr::Reversed(AggregateFunction {
285+
func_def: _,
286+
args: _,
287+
distinct: _,
288+
filter: _,
289+
order_by: _,
290+
null_treatment,
291+
}) => {
292+
let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls)
293+
== NullTreatment::IgnoreNulls;
294+
295+
// TODO: Do the actual conversion from logical expr
296+
// for other fields
297+
let mut expr = self.clone();
298+
expr.ignore_nulls = ignore_nulls;
299+
300+
Some(Arc::new(expr))
301+
}
302+
}
303+
}
276304
}
277305

278306
impl PartialEq<dyn Any> for AggregateFunctionExpr {

0 commit comments

Comments
 (0)