@@ -28,8 +28,7 @@ use std::collections::HashSet;
28
28
use crate :: aggregate:: utils:: down_cast_any_ref;
29
29
use crate :: expressions:: format_state_name;
30
30
use crate :: { AggregateExpr , PhysicalExpr } ;
31
- use datafusion_common:: Result ;
32
- use datafusion_common:: ScalarValue ;
31
+ use datafusion_common:: { internal_err, DataFusionError , Result , ScalarValue } ;
33
32
use datafusion_expr:: Accumulator ;
34
33
35
34
/// Expression for a ARRAY_AGG(DISTINCT) aggregation.
@@ -135,23 +134,36 @@ impl Accumulator for DistinctArrayAggAccumulator {
135
134
fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
136
135
assert_eq ! ( values. len( ) , 1 , "batch input should only include 1 column!" ) ;
137
136
138
- let arr = & values[ 0 ] ;
139
- for i in 0 ..arr. len ( ) {
140
- self . values . insert ( ScalarValue :: try_from_array ( arr, i) ?) ;
141
- }
142
- Ok ( ( ) )
137
+ let array = & values[ 0 ] ;
138
+ ( 0 ..array. len ( ) ) . try_for_each ( |i| {
139
+ if !array. is_null ( i) {
140
+ self . values . insert ( ScalarValue :: try_from_array ( array, i) ?) ;
141
+ }
142
+ Ok ( ( ) )
143
+ } )
143
144
}
144
145
145
146
fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
146
147
if states. is_empty ( ) {
147
148
return Ok ( ( ) ) ;
148
149
}
149
150
150
- for array in states {
151
- for j in 0 ..array. len ( ) {
152
- self . values . insert ( ScalarValue :: try_from_array ( array, j) ?) ;
151
+ assert_eq ! (
152
+ states. len( ) ,
153
+ 1 ,
154
+ "array_agg_distinct states must contain single array"
155
+ ) ;
156
+
157
+ let array = & states[ 0 ] ;
158
+ ( 0 ..array. len ( ) ) . try_for_each ( |i| {
159
+ let scalar = ScalarValue :: try_from_array ( array, i) ?;
160
+ if let ScalarValue :: List ( Some ( values) , _) = scalar {
161
+ self . values . extend ( values) ;
162
+ Ok ( ( ) )
163
+ } else {
164
+ internal_err ! ( "array_agg_distinct state must be list" )
153
165
}
154
- }
166
+ } ) ? ;
155
167
156
168
Ok ( ( ) )
157
169
}
@@ -174,12 +186,35 @@ impl Accumulator for DistinctArrayAggAccumulator {
174
186
#[ cfg( test) ]
175
187
mod tests {
176
188
use super :: * ;
189
+ use crate :: aggregate:: utils:: get_accum_scalar_values_as_arrays;
177
190
use crate :: expressions:: col;
178
191
use crate :: expressions:: tests:: aggregate;
179
192
use arrow:: array:: { ArrayRef , Int32Array } ;
180
193
use arrow:: datatypes:: { DataType , Schema } ;
181
194
use arrow:: record_batch:: RecordBatch ;
182
195
196
+ fn compare_list_contents ( expected : ScalarValue , actual : ScalarValue ) -> Result < ( ) > {
197
+ match ( expected, actual) {
198
+ ( ScalarValue :: List ( Some ( mut e) , _) , ScalarValue :: List ( Some ( mut a) , _) ) => {
199
+ // workaround lack of Ord of ScalarValue
200
+ let cmp = |a : & ScalarValue , b : & ScalarValue | {
201
+ a. partial_cmp ( b) . expect ( "Can compare ScalarValues" )
202
+ } ;
203
+
204
+ e. sort_by ( cmp) ;
205
+ a. sort_by ( cmp) ;
206
+ // Check that the inputs are the same
207
+ assert_eq ! ( e, a) ;
208
+ }
209
+ _ => {
210
+ return Err ( DataFusionError :: Internal (
211
+ "Expected scalar lists as inputs" . to_string ( ) ,
212
+ ) ) ;
213
+ }
214
+ }
215
+ Ok ( ( ) )
216
+ }
217
+
183
218
fn check_distinct_array_agg (
184
219
input : ArrayRef ,
185
220
expected : ScalarValue ,
@@ -195,24 +230,34 @@ mod tests {
195
230
) ) ;
196
231
let actual = aggregate ( & batch, agg) ?;
197
232
198
- match ( expected, actual) {
199
- ( ScalarValue :: List ( Some ( mut e) , _) , ScalarValue :: List ( Some ( mut a) , _) ) => {
200
- // workaround lack of Ord of ScalarValue
201
- let cmp = |a : & ScalarValue , b : & ScalarValue | {
202
- a. partial_cmp ( b) . expect ( "Can compare ScalarValues" )
203
- } ;
233
+ compare_list_contents ( expected, actual)
234
+ }
204
235
205
- e. sort_by ( cmp) ;
206
- a. sort_by ( cmp) ;
207
- // Check that the inputs are the same
208
- assert_eq ! ( e, a) ;
209
- }
210
- _ => {
211
- unreachable ! ( )
212
- }
213
- }
236
+ fn check_merge_distinct_array_agg (
237
+ input1 : ArrayRef ,
238
+ input2 : ArrayRef ,
239
+ expected : ScalarValue ,
240
+ datatype : DataType ,
241
+ ) -> Result < ( ) > {
242
+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , datatype. clone( ) , false ) ] ) ;
243
+ let agg = Arc :: new ( DistinctArrayAgg :: new (
244
+ col ( "a" , & schema) ?,
245
+ "bla" . to_string ( ) ,
246
+ datatype,
247
+ ) ) ;
214
248
215
- Ok ( ( ) )
249
+ let mut accum1 = agg. create_accumulator ( ) ?;
250
+ let mut accum2 = agg. create_accumulator ( ) ?;
251
+
252
+ accum1. update_batch ( & [ input1] ) ?;
253
+ accum2. update_batch ( & [ input2] ) ?;
254
+
255
+ let state = get_accum_scalar_values_as_arrays ( accum2. as_ref ( ) ) ?;
256
+ accum1. merge_batch ( & state) ?;
257
+
258
+ let actual = accum1. evaluate ( ) ?;
259
+
260
+ compare_list_contents ( expected, actual)
216
261
}
217
262
218
263
#[ test]
@@ -233,6 +278,27 @@ mod tests {
233
278
check_distinct_array_agg ( col, out, DataType :: Int32 )
234
279
}
235
280
281
+ #[ test]
282
+ fn merge_distinct_array_agg_i32 ( ) -> Result < ( ) > {
283
+ let col1: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ 1 , 2 , 7 , 4 , 5 , 2 ] ) ) ;
284
+ let col2: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ 1 , 3 , 7 , 8 , 4 ] ) ) ;
285
+
286
+ let out = ScalarValue :: new_list (
287
+ Some ( vec ! [
288
+ ScalarValue :: Int32 ( Some ( 1 ) ) ,
289
+ ScalarValue :: Int32 ( Some ( 2 ) ) ,
290
+ ScalarValue :: Int32 ( Some ( 3 ) ) ,
291
+ ScalarValue :: Int32 ( Some ( 4 ) ) ,
292
+ ScalarValue :: Int32 ( Some ( 5 ) ) ,
293
+ ScalarValue :: Int32 ( Some ( 7 ) ) ,
294
+ ScalarValue :: Int32 ( Some ( 8 ) ) ,
295
+ ] ) ,
296
+ DataType :: Int32 ,
297
+ ) ;
298
+
299
+ check_merge_distinct_array_agg ( col1, col2, out, DataType :: Int32 )
300
+ }
301
+
236
302
#[ test]
237
303
fn distinct_array_agg_nested ( ) -> Result < ( ) > {
238
304
// [[1, 2, 3], [4, 5]]
@@ -296,4 +362,66 @@ mod tests {
296
362
) ) ) ,
297
363
)
298
364
}
365
+
366
+ #[ test]
367
+ fn merge_distinct_array_agg_nested ( ) -> Result < ( ) > {
368
+ // [[1, 2], [3, 4]]
369
+ let l1 = ScalarValue :: new_list (
370
+ Some ( vec ! [
371
+ ScalarValue :: new_list(
372
+ Some ( vec![ ScalarValue :: from( 1i32 ) , ScalarValue :: from( 2i32 ) ] ) ,
373
+ DataType :: Int32 ,
374
+ ) ,
375
+ ScalarValue :: new_list(
376
+ Some ( vec![ ScalarValue :: from( 3i32 ) , ScalarValue :: from( 4i32 ) ] ) ,
377
+ DataType :: Int32 ,
378
+ ) ,
379
+ ] ) ,
380
+ DataType :: List ( Arc :: new ( Field :: new ( "item" , DataType :: Int32 , true ) ) ) ,
381
+ ) ;
382
+
383
+ // [[5]]
384
+ let l2 = ScalarValue :: new_list (
385
+ Some ( vec ! [ ScalarValue :: new_list(
386
+ Some ( vec![ ScalarValue :: from( 5i32 ) ] ) ,
387
+ DataType :: Int32 ,
388
+ ) ] ) ,
389
+ DataType :: List ( Arc :: new ( Field :: new ( "item" , DataType :: Int32 , true ) ) ) ,
390
+ ) ;
391
+
392
+ // [[6, 7], [8]]
393
+ let l3 = ScalarValue :: new_list (
394
+ Some ( vec ! [
395
+ ScalarValue :: new_list(
396
+ Some ( vec![ ScalarValue :: from( 6i32 ) , ScalarValue :: from( 7i32 ) ] ) ,
397
+ DataType :: Int32 ,
398
+ ) ,
399
+ ScalarValue :: new_list(
400
+ Some ( vec![ ScalarValue :: from( 8i32 ) ] ) ,
401
+ DataType :: Int32 ,
402
+ ) ,
403
+ ] ) ,
404
+ DataType :: List ( Arc :: new ( Field :: new ( "item" , DataType :: Int32 , true ) ) ) ,
405
+ ) ;
406
+
407
+ let expected = ScalarValue :: new_list (
408
+ Some ( vec ! [ l1. clone( ) , l2. clone( ) , l3. clone( ) ] ) ,
409
+ DataType :: List ( Arc :: new ( Field :: new ( "item" , DataType :: Int32 , true ) ) ) ,
410
+ ) ;
411
+
412
+ // Duplicate l1 in the input array and check that it is deduped in the output.
413
+ let input1 = ScalarValue :: iter_to_array ( vec ! [ l1. clone( ) , l2] ) . unwrap ( ) ;
414
+ let input2 = ScalarValue :: iter_to_array ( vec ! [ l1, l3] ) . unwrap ( ) ;
415
+
416
+ check_merge_distinct_array_agg (
417
+ input1,
418
+ input2,
419
+ expected,
420
+ DataType :: List ( Arc :: new ( Field :: new_list (
421
+ "item" ,
422
+ Field :: new ( "item" , DataType :: Int32 , true ) ,
423
+ true ,
424
+ ) ) ) ,
425
+ )
426
+ }
299
427
}
0 commit comments