Skip to content

Commit 1fd74c8

Browse files
committed
Specialize SUM and AVG (apache#6842)
1 parent 870857a commit 1fd74c8

File tree

7 files changed

+244
-318
lines changed

7 files changed

+244
-318
lines changed

datafusion/core/src/execution/context.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2452,12 +2452,7 @@ mod tests {
24522452
vec![DataType::Float64],
24532453
Arc::new(DataType::Float64),
24542454
Volatility::Immutable,
2455-
Arc::new(|_| {
2456-
Ok(Box::new(AvgAccumulator::try_new(
2457-
&DataType::Float64,
2458-
&DataType::Float64,
2459-
)?))
2460-
}),
2455+
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
24612456
Arc::new(vec![DataType::UInt64, DataType::Float64]),
24622457
);
24632458

datafusion/core/tests/sql/udf.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,7 @@ async fn simple_udaf() -> Result<()> {
237237
vec![DataType::Float64],
238238
Arc::new(DataType::Float64),
239239
Volatility::Immutable,
240-
Arc::new(|_| {
241-
Ok(Box::new(AvgAccumulator::try_new(
242-
&DataType::Float64,
243-
&DataType::Float64,
244-
)?))
245-
}),
240+
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
246241
Arc::new(vec![DataType::UInt64, DataType::Float64]),
247242
);
248243

datafusion/optimizer/src/analyzer/type_coercion.rs

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -906,12 +906,7 @@ mod test {
906906
vec![DataType::Float64],
907907
Arc::new(DataType::Float64),
908908
Volatility::Immutable,
909-
Arc::new(|_| {
910-
Ok(Box::new(AvgAccumulator::try_new(
911-
&DataType::Float64,
912-
&DataType::Float64,
913-
)?))
914-
}),
909+
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
915910
Arc::new(vec![DataType::UInt64, DataType::Float64]),
916911
);
917912
let udaf = Expr::AggregateUDF(expr::AggregateUDF::new(
@@ -932,12 +927,8 @@ mod test {
932927
Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
933928
let state_type: StateTypeFunction =
934929
Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64])));
935-
let accumulator: AccumulatorFactoryFunction = Arc::new(|_| {
936-
Ok(Box::new(AvgAccumulator::try_new(
937-
&DataType::Float64,
938-
&DataType::Float64,
939-
)?))
940-
});
930+
let accumulator: AccumulatorFactoryFunction =
931+
Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
941932
let my_avg = AggregateUDF::new(
942933
"MY_AVG",
943934
&Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),

datafusion/physical-expr/src/aggregate/average.rs

Lines changed: 124 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,13 @@ use arrow::array::{AsArray, PrimitiveBuilder};
2121
use log::debug;
2222

2323
use std::any::Any;
24-
use std::convert::TryFrom;
2524
use std::sync::Arc;
2625

2726
use crate::aggregate::groups_accumulator::accumulate::NullState;
28-
use crate::aggregate::sum;
29-
use crate::aggregate::sum::sum_batch;
30-
use crate::aggregate::utils::calculate_result_decimal_for_avg;
3127
use crate::aggregate::utils::down_cast_any_ref;
3228
use crate::expressions::format_state_name;
3329
use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr};
34-
use arrow::compute;
30+
use arrow::compute::sum;
3531
use arrow::datatypes::{DataType, Decimal128Type, Float64Type, UInt64Type};
3632
use arrow::{
3733
array::{ArrayRef, UInt64Array},
@@ -40,9 +36,7 @@ use arrow::{
4036
use arrow_array::{
4137
Array, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, PrimitiveArray,
4238
};
43-
use datafusion_common::{
44-
downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue,
45-
};
39+
use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue};
4640
use datafusion_expr::type_coercion::aggregates::avg_return_type;
4741
use datafusion_expr::Accumulator;
4842

@@ -93,11 +87,27 @@ impl AggregateExpr for Avg {
9387
}
9488

9589
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
96-
Ok(Box::new(AvgAccumulator::try_new(
97-
// avg is f64 or decimal
98-
&self.input_data_type,
99-
&self.result_data_type,
100-
)?))
90+
use DataType::*;
91+
// instantiate a specialized accumulator based on the input and result types
92+
match (&self.input_data_type, &self.result_data_type) {
93+
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
94+
(
95+
Decimal128(sum_precision, sum_scale),
96+
Decimal128(target_precision, target_scale),
97+
) => Ok(Box::new(DecimalAvgAccumulator {
98+
sum: None,
99+
count: 0,
100+
sum_scale: *sum_scale,
101+
sum_precision: *sum_precision,
102+
target_precision: *target_precision,
103+
target_scale: *target_scale,
104+
})),
105+
_ => not_impl_err!(
106+
"AvgGroupsAccumulator for ({} --> {})",
107+
self.input_data_type,
108+
self.result_data_type
109+
),
110+
}
101111
}
102112

103113
fn state_fields(&self) -> Result<Vec<Field>> {
@@ -128,10 +138,7 @@ impl AggregateExpr for Avg {
128138
}
129139

130140
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
131-
Ok(Box::new(AvgAccumulator::try_new(
132-
&self.input_data_type,
133-
&self.result_data_type,
134-
)?))
141+
self.create_accumulator()
135142
}
136143

137144
fn groups_accumulator_supported(&self) -> bool {
@@ -195,91 +202,141 @@ impl PartialEq<dyn Any> for Avg {
195202
}
196203

197204
/// An accumulator to compute the average
198-
#[derive(Debug)]
205+
#[derive(Debug, Default)]
199206
pub struct AvgAccumulator {
200-
// sum is used for null
201-
sum: ScalarValue,
202-
return_data_type: DataType,
207+
sum: Option<f64>,
203208
count: u64,
204209
}
205210

206-
impl AvgAccumulator {
207-
/// Creates a new `AvgAccumulator`
208-
pub fn try_new(datatype: &DataType, return_data_type: &DataType) -> Result<Self> {
209-
Ok(Self {
210-
sum: ScalarValue::try_from(datatype)?,
211-
return_data_type: return_data_type.clone(),
212-
count: 0,
213-
})
211+
impl Accumulator for AvgAccumulator {
212+
fn state(&self) -> Result<Vec<ScalarValue>> {
213+
Ok(vec![
214+
ScalarValue::from(self.count),
215+
ScalarValue::Float64(self.sum),
216+
])
217+
}
218+
219+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
220+
let values = values[0].as_primitive::<Float64Type>();
221+
self.count += (values.len() - values.null_count()) as u64;
222+
if let Some(x) = sum(values) {
223+
let v = self.sum.get_or_insert(0.);
224+
*v += x;
225+
}
226+
Ok(())
227+
}
228+
229+
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
230+
let values = values[0].as_primitive::<Float64Type>();
231+
self.count -= (values.len() - values.null_count()) as u64;
232+
if let Some(x) = sum(values) {
233+
self.sum = Some(self.sum.unwrap() - x);
234+
}
235+
Ok(())
236+
}
237+
238+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
239+
// counts are summed
240+
self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
241+
242+
// sums are summed
243+
if let Some(x) = sum(states[1].as_primitive::<Float64Type>()) {
244+
let v = self.sum.get_or_insert(0.);
245+
*v += x;
246+
}
247+
Ok(())
248+
}
249+
250+
fn evaluate(&self) -> Result<ScalarValue> {
251+
Ok(ScalarValue::Float64(
252+
self.sum.map(|f| f / self.count as f64),
253+
))
254+
}
255+
fn supports_retract_batch(&self) -> bool {
256+
true
257+
}
258+
259+
fn size(&self) -> usize {
260+
std::mem::size_of_val(self)
214261
}
215262
}
216263

217-
impl Accumulator for AvgAccumulator {
264+
/// An accumulator to compute the average for decimals
265+
#[derive(Debug)]
266+
struct DecimalAvgAccumulator {
267+
sum: Option<i128>,
268+
count: u64,
269+
sum_scale: i8,
270+
sum_precision: u8,
271+
target_precision: u8,
272+
target_scale: i8,
273+
}
274+
275+
impl Accumulator for DecimalAvgAccumulator {
218276
fn state(&self) -> Result<Vec<ScalarValue>> {
219-
Ok(vec![ScalarValue::from(self.count), self.sum.clone()])
277+
Ok(vec![
278+
ScalarValue::from(self.count),
279+
ScalarValue::Decimal128(self.sum, self.sum_precision, self.sum_scale),
280+
])
220281
}
221282

222283
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
223-
let values = &values[0];
284+
let values = values[0].as_primitive::<Decimal128Type>();
224285

225286
self.count += (values.len() - values.null_count()) as u64;
226-
self.sum = self.sum.add(&sum::sum_batch(values)?)?;
287+
if let Some(x) = sum(values) {
288+
let v = self.sum.get_or_insert(0);
289+
*v += x;
290+
}
227291
Ok(())
228292
}
229293

230294
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
231-
let values = &values[0];
295+
let values = values[0].as_primitive::<Decimal128Type>();
232296
self.count -= (values.len() - values.null_count()) as u64;
233-
let delta = sum_batch(values)?;
234-
self.sum = self.sum.sub(&delta)?;
297+
if let Some(x) = sum(values) {
298+
self.sum = Some(self.sum.unwrap() - x);
299+
}
235300
Ok(())
236301
}
237302

238303
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
239-
let counts = downcast_value!(states[0], UInt64Array);
240304
// counts are summed
241-
self.count += compute::sum(counts).unwrap_or(0);
305+
self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
242306

243307
// sums are summed
244-
self.sum = self.sum.add(&sum::sum_batch(&states[1])?)?;
308+
if let Some(x) = sum(states[1].as_primitive::<Decimal128Type>()) {
309+
let v = self.sum.get_or_insert(0);
310+
*v += x;
311+
}
245312
Ok(())
246313
}
247314

248315
fn evaluate(&self) -> Result<ScalarValue> {
249-
match self.sum {
250-
ScalarValue::Float64(e) => {
251-
Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64)))
252-
}
253-
ScalarValue::Decimal128(value, _, scale) => {
254-
match value {
255-
None => match &self.return_data_type {
256-
DataType::Decimal128(p, s) => {
257-
Ok(ScalarValue::Decimal128(None, *p, *s))
258-
}
259-
other => internal_err!(
260-
"Error returned data type in AvgAccumulator {other:?}"
261-
),
262-
},
263-
Some(value) => {
264-
// now the sum_type and return type is not the same, need to convert the sum type to return type
265-
calculate_result_decimal_for_avg(
266-
value,
267-
self.count as i128,
268-
scale,
269-
&self.return_data_type,
270-
)
271-
}
272-
}
273-
}
274-
_ => internal_err!("Sum should be f64 or decimal128 on average"),
275-
}
316+
let v = self
317+
.sum
318+
.map(|v| {
319+
Decimal128Averager::try_new(
320+
self.sum_scale,
321+
self.target_precision,
322+
self.target_scale,
323+
)?
324+
.avg(v, self.count as _)
325+
})
326+
.transpose()?;
327+
328+
Ok(ScalarValue::Decimal128(
329+
v,
330+
self.target_precision,
331+
self.target_scale,
332+
))
276333
}
277334
fn supports_retract_batch(&self) -> bool {
278335
true
279336
}
280337

281338
fn size(&self) -> usize {
282-
std::mem::size_of_val(self) - std::mem::size_of_val(&self.sum) + self.sum.size()
339+
std::mem::size_of_val(self)
283340
}
284341
}
285342

0 commit comments

Comments
 (0)