@@ -21,17 +21,13 @@ use arrow::array::{AsArray, PrimitiveBuilder};
21
21
use log:: debug;
22
22
23
23
use std:: any:: Any ;
24
- use std:: convert:: TryFrom ;
25
24
use std:: sync:: Arc ;
26
25
27
26
use crate :: aggregate:: groups_accumulator:: accumulate:: NullState ;
28
- use crate :: aggregate:: sum;
29
- use crate :: aggregate:: sum:: sum_batch;
30
- use crate :: aggregate:: utils:: calculate_result_decimal_for_avg;
31
27
use crate :: aggregate:: utils:: down_cast_any_ref;
32
28
use crate :: expressions:: format_state_name;
33
29
use crate :: { AggregateExpr , GroupsAccumulator , PhysicalExpr } ;
34
- use arrow:: compute;
30
+ use arrow:: compute:: sum ;
35
31
use arrow:: datatypes:: { DataType , Decimal128Type , Float64Type , UInt64Type } ;
36
32
use arrow:: {
37
33
array:: { ArrayRef , UInt64Array } ,
@@ -40,9 +36,7 @@ use arrow::{
40
36
use arrow_array:: {
41
37
Array , ArrowNativeTypeOp , ArrowNumericType , ArrowPrimitiveType , PrimitiveArray ,
42
38
} ;
43
- use datafusion_common:: {
44
- downcast_value, internal_err, not_impl_err, DataFusionError , Result , ScalarValue ,
45
- } ;
39
+ use datafusion_common:: { not_impl_err, DataFusionError , Result , ScalarValue } ;
46
40
use datafusion_expr:: type_coercion:: aggregates:: avg_return_type;
47
41
use datafusion_expr:: Accumulator ;
48
42
@@ -93,11 +87,27 @@ impl AggregateExpr for Avg {
93
87
}
94
88
95
89
fn create_accumulator ( & self ) -> Result < Box < dyn Accumulator > > {
96
- Ok ( Box :: new ( AvgAccumulator :: try_new (
97
- // avg is f64 or decimal
98
- & self . input_data_type ,
99
- & self . result_data_type ,
100
- ) ?) )
90
+ use DataType :: * ;
91
+ // instantiate specialized accumulator based for the type
92
+ match ( & self . input_data_type , & self . result_data_type ) {
93
+ ( Float64 , Float64 ) => Ok ( Box :: < AvgAccumulator > :: default ( ) ) ,
94
+ (
95
+ Decimal128 ( sum_precision, sum_scale) ,
96
+ Decimal128 ( target_precision, target_scale) ,
97
+ ) => Ok ( Box :: new ( DecimalAvgAccumulator {
98
+ sum : None ,
99
+ count : 0 ,
100
+ sum_scale : * sum_scale,
101
+ sum_precision : * sum_precision,
102
+ target_precision : * target_precision,
103
+ target_scale : * target_scale,
104
+ } ) ) ,
105
+ _ => not_impl_err ! (
106
+ "AvgGroupsAccumulator for ({} --> {})" ,
107
+ self . input_data_type,
108
+ self . result_data_type
109
+ ) ,
110
+ }
101
111
}
102
112
103
113
fn state_fields ( & self ) -> Result < Vec < Field > > {
@@ -128,10 +138,7 @@ impl AggregateExpr for Avg {
128
138
}
129
139
130
140
fn create_sliding_accumulator ( & self ) -> Result < Box < dyn Accumulator > > {
131
- Ok ( Box :: new ( AvgAccumulator :: try_new (
132
- & self . input_data_type ,
133
- & self . result_data_type ,
134
- ) ?) )
141
+ self . create_accumulator ( )
135
142
}
136
143
137
144
fn groups_accumulator_supported ( & self ) -> bool {
@@ -195,91 +202,141 @@ impl PartialEq<dyn Any> for Avg {
195
202
}
196
203
197
204
/// An accumulator to compute the average
198
- #[ derive( Debug ) ]
205
+ #[ derive( Debug , Default ) ]
199
206
pub struct AvgAccumulator {
200
- // sum is used for null
201
- sum : ScalarValue ,
202
- return_data_type : DataType ,
207
+ sum : Option < f64 > ,
203
208
count : u64 ,
204
209
}
205
210
206
- impl AvgAccumulator {
207
- /// Creates a new `AvgAccumulator`
208
- pub fn try_new ( datatype : & DataType , return_data_type : & DataType ) -> Result < Self > {
209
- Ok ( Self {
210
- sum : ScalarValue :: try_from ( datatype) ?,
211
- return_data_type : return_data_type. clone ( ) ,
212
- count : 0 ,
213
- } )
211
+ impl Accumulator for AvgAccumulator {
212
+ fn state ( & self ) -> Result < Vec < ScalarValue > > {
213
+ Ok ( vec ! [
214
+ ScalarValue :: from( self . count) ,
215
+ ScalarValue :: Float64 ( self . sum) ,
216
+ ] )
217
+ }
218
+
219
+ fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
220
+ let values = values[ 0 ] . as_primitive :: < Float64Type > ( ) ;
221
+ self . count += ( values. len ( ) - values. null_count ( ) ) as u64 ;
222
+ if let Some ( x) = sum ( values) {
223
+ let v = self . sum . get_or_insert ( 0. ) ;
224
+ * v += x;
225
+ }
226
+ Ok ( ( ) )
227
+ }
228
+
229
+ fn retract_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
230
+ let values = values[ 0 ] . as_primitive :: < Float64Type > ( ) ;
231
+ self . count -= ( values. len ( ) - values. null_count ( ) ) as u64 ;
232
+ if let Some ( x) = sum ( values) {
233
+ self . sum = Some ( self . sum . unwrap ( ) - x) ;
234
+ }
235
+ Ok ( ( ) )
236
+ }
237
+
238
+ fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
239
+ // counts are summed
240
+ self . count += sum ( states[ 0 ] . as_primitive :: < UInt64Type > ( ) ) . unwrap_or_default ( ) ;
241
+
242
+ // sums are summed
243
+ if let Some ( x) = sum ( states[ 1 ] . as_primitive :: < Float64Type > ( ) ) {
244
+ let v = self . sum . get_or_insert ( 0. ) ;
245
+ * v += x;
246
+ }
247
+ Ok ( ( ) )
248
+ }
249
+
250
+ fn evaluate ( & self ) -> Result < ScalarValue > {
251
+ Ok ( ScalarValue :: Float64 (
252
+ self . sum . map ( |f| f / self . count as f64 ) ,
253
+ ) )
254
+ }
255
+ fn supports_retract_batch ( & self ) -> bool {
256
+ true
257
+ }
258
+
259
+ fn size ( & self ) -> usize {
260
+ std:: mem:: size_of_val ( self )
214
261
}
215
262
}
216
263
217
- impl Accumulator for AvgAccumulator {
264
+ /// An accumulator to compute the average for decimals
265
+ #[ derive( Debug ) ]
266
+ struct DecimalAvgAccumulator {
267
+ sum : Option < i128 > ,
268
+ count : u64 ,
269
+ sum_scale : i8 ,
270
+ sum_precision : u8 ,
271
+ target_precision : u8 ,
272
+ target_scale : i8 ,
273
+ }
274
+
275
+ impl Accumulator for DecimalAvgAccumulator {
218
276
fn state ( & self ) -> Result < Vec < ScalarValue > > {
219
- Ok ( vec ! [ ScalarValue :: from( self . count) , self . sum. clone( ) ] )
277
+ Ok ( vec ! [
278
+ ScalarValue :: from( self . count) ,
279
+ ScalarValue :: Decimal128 ( self . sum, self . sum_precision, self . sum_scale) ,
280
+ ] )
220
281
}
221
282
222
283
fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
223
- let values = & values[ 0 ] ;
284
+ let values = values[ 0 ] . as_primitive :: < Decimal128Type > ( ) ;
224
285
225
286
self . count += ( values. len ( ) - values. null_count ( ) ) as u64 ;
226
- self . sum = self . sum . add ( & sum:: sum_batch ( values) ?) ?;
287
+ if let Some ( x) = sum ( values) {
288
+ let v = self . sum . get_or_insert ( 0 ) ;
289
+ * v += x;
290
+ }
227
291
Ok ( ( ) )
228
292
}
229
293
230
294
fn retract_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
231
- let values = & values[ 0 ] ;
295
+ let values = values[ 0 ] . as_primitive :: < Decimal128Type > ( ) ;
232
296
self . count -= ( values. len ( ) - values. null_count ( ) ) as u64 ;
233
- let delta = sum_batch ( values) ?;
234
- self . sum = self . sum . sub ( & delta) ?;
297
+ if let Some ( x) = sum ( values) {
298
+ self . sum = Some ( self . sum . unwrap ( ) - x) ;
299
+ }
235
300
Ok ( ( ) )
236
301
}
237
302
238
303
fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
239
- let counts = downcast_value ! ( states[ 0 ] , UInt64Array ) ;
240
304
// counts are summed
241
- self . count += compute :: sum ( counts ) . unwrap_or ( 0 ) ;
305
+ self . count += sum ( states [ 0 ] . as_primitive :: < UInt64Type > ( ) ) . unwrap_or_default ( ) ;
242
306
243
307
// sums are summed
244
- self . sum = self . sum . add ( & sum:: sum_batch ( & states[ 1 ] ) ?) ?;
308
+ if let Some ( x) = sum ( states[ 1 ] . as_primitive :: < Decimal128Type > ( ) ) {
309
+ let v = self . sum . get_or_insert ( 0 ) ;
310
+ * v += x;
311
+ }
245
312
Ok ( ( ) )
246
313
}
247
314
248
315
fn evaluate ( & self ) -> Result < ScalarValue > {
249
- match self . sum {
250
- ScalarValue :: Float64 ( e) => {
251
- Ok ( ScalarValue :: Float64 ( e. map ( |f| f / self . count as f64 ) ) )
252
- }
253
- ScalarValue :: Decimal128 ( value, _, scale) => {
254
- match value {
255
- None => match & self . return_data_type {
256
- DataType :: Decimal128 ( p, s) => {
257
- Ok ( ScalarValue :: Decimal128 ( None , * p, * s) )
258
- }
259
- other => internal_err ! (
260
- "Error returned data type in AvgAccumulator {other:?}"
261
- ) ,
262
- } ,
263
- Some ( value) => {
264
- // now the sum_type and return type is not the same, need to convert the sum type to return type
265
- calculate_result_decimal_for_avg (
266
- value,
267
- self . count as i128 ,
268
- scale,
269
- & self . return_data_type ,
270
- )
271
- }
272
- }
273
- }
274
- _ => internal_err ! ( "Sum should be f64 or decimal128 on average" ) ,
275
- }
316
+ let v = self
317
+ . sum
318
+ . map ( |v| {
319
+ Decimal128Averager :: try_new (
320
+ self . sum_scale ,
321
+ self . target_precision ,
322
+ self . target_scale ,
323
+ ) ?
324
+ . avg ( v, self . count as _ )
325
+ } )
326
+ . transpose ( ) ?;
327
+
328
+ Ok ( ScalarValue :: Decimal128 (
329
+ v,
330
+ self . target_precision ,
331
+ self . target_scale ,
332
+ ) )
276
333
}
277
334
fn supports_retract_batch ( & self ) -> bool {
278
335
true
279
336
}
280
337
281
338
fn size ( & self ) -> usize {
282
- std:: mem:: size_of_val ( self ) - std :: mem :: size_of_val ( & self . sum ) + self . sum . size ( )
339
+ std:: mem:: size_of_val ( self )
283
340
}
284
341
}
285
342
0 commit comments