
Commit ea9144e

korowa and metesynnada authored
fix: inconsistent scalar types in DistinctArrayAggAccumulator state (#7385)
* fix: inconsistent types in array_agg_distinct merge_batch
* Apply suggestions from code review
* filtering NULLs & validating sqllogictest output

Co-authored-by: Metehan Yıldırım <[email protected]>
1 parent 3f8c512 commit ea9144e

2 files changed: +178, -33 lines
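Before the per-file diffs, a note on what "inconsistent scalar types" means here: `update_batch` fills the accumulator's `HashSet<ScalarValue>` with element-typed scalars, while the state exchanged between partial accumulators is a single list array whose rows are list scalars. The old `merge_batch` inserted those list scalars into the set as-is, so the set could mix element and list scalars. The snippet below is a minimal illustrative sketch, not code from this commit; it assumes the `ScalarValue::List(Option<Vec<ScalarValue>>, _)` representation and the `ScalarValue::new_list` constructor that appear in the diff.

use std::collections::HashSet;

use arrow::datatypes::DataType;
use datafusion_common::ScalarValue;

fn main() {
    // The accumulator tracks distinct values in a HashSet<ScalarValue>.
    let mut values: HashSet<ScalarValue> = HashSet::new();

    // update_batch inserts element-typed scalars, e.g. Int32 ...
    values.insert(ScalarValue::Int32(Some(1)));

    // ... but the old merge_batch inserted each partial-state scalar as-is,
    // and that scalar is a *list* of elements rather than a single element:
    values.insert(ScalarValue::new_list(
        Some(vec![ScalarValue::Int32(Some(2))]),
        DataType::Int32,
    ));

    // The set now mixes Int32 and List scalars; the fix instead extends the
    // set with the list's elements (and errors on non-list state).
    for v in &values {
        println!("{:?}", v.data_type());
    }
}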

datafusion/physical-expr/src/aggregate/array_agg_distinct.rs

Lines changed: 155 additions & 27 deletions
@@ -28,8 +28,7 @@ use std::collections::HashSet;
 use crate::aggregate::utils::down_cast_any_ref;
 use crate::expressions::format_state_name;
 use crate::{AggregateExpr, PhysicalExpr};
-use datafusion_common::Result;
-use datafusion_common::ScalarValue;
+use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
 use datafusion_expr::Accumulator;

 /// Expression for a ARRAY_AGG(DISTINCT) aggregation.
@@ -135,23 +134,36 @@ impl Accumulator for DistinctArrayAggAccumulator {
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
         assert_eq!(values.len(), 1, "batch input should only include 1 column!");

-        let arr = &values[0];
-        for i in 0..arr.len() {
-            self.values.insert(ScalarValue::try_from_array(arr, i)?);
-        }
-        Ok(())
+        let array = &values[0];
+        (0..array.len()).try_for_each(|i| {
+            if !array.is_null(i) {
+                self.values.insert(ScalarValue::try_from_array(array, i)?);
+            }
+            Ok(())
+        })
     }

     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
         if states.is_empty() {
             return Ok(());
         }

-        for array in states {
-            for j in 0..array.len() {
-                self.values.insert(ScalarValue::try_from_array(array, j)?);
+        assert_eq!(
+            states.len(),
+            1,
+            "array_agg_distinct states must contain single array"
+        );
+
+        let array = &states[0];
+        (0..array.len()).try_for_each(|i| {
+            let scalar = ScalarValue::try_from_array(array, i)?;
+            if let ScalarValue::List(Some(values), _) = scalar {
+                self.values.extend(values);
+                Ok(())
+            } else {
+                internal_err!("array_agg_distinct state must be list")
             }
-        }
+        })?;

         Ok(())
     }
@@ -174,12 +186,35 @@ impl Accumulator for DistinctArrayAggAccumulator {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::aggregate::utils::get_accum_scalar_values_as_arrays;
     use crate::expressions::col;
     use crate::expressions::tests::aggregate;
     use arrow::array::{ArrayRef, Int32Array};
     use arrow::datatypes::{DataType, Schema};
     use arrow::record_batch::RecordBatch;

+    fn compare_list_contents(expected: ScalarValue, actual: ScalarValue) -> Result<()> {
+        match (expected, actual) {
+            (ScalarValue::List(Some(mut e), _), ScalarValue::List(Some(mut a), _)) => {
+                // workaround lack of Ord of ScalarValue
+                let cmp = |a: &ScalarValue, b: &ScalarValue| {
+                    a.partial_cmp(b).expect("Can compare ScalarValues")
+                };
+
+                e.sort_by(cmp);
+                a.sort_by(cmp);
+                // Check that the inputs are the same
+                assert_eq!(e, a);
+            }
+            _ => {
+                return Err(DataFusionError::Internal(
+                    "Expected scalar lists as inputs".to_string(),
+                ));
+            }
+        }
+        Ok(())
+    }
+
     fn check_distinct_array_agg(
         input: ArrayRef,
         expected: ScalarValue,
@@ -195,24 +230,34 @@ mod tests {
         ));
         let actual = aggregate(&batch, agg)?;

-        match (expected, actual) {
-            (ScalarValue::List(Some(mut e), _), ScalarValue::List(Some(mut a), _)) => {
-                // workaround lack of Ord of ScalarValue
-                let cmp = |a: &ScalarValue, b: &ScalarValue| {
-                    a.partial_cmp(b).expect("Can compare ScalarValues")
-                };
+        compare_list_contents(expected, actual)
+    }

-                e.sort_by(cmp);
-                a.sort_by(cmp);
-                // Check that the inputs are the same
-                assert_eq!(e, a);
-            }
-            _ => {
-                unreachable!()
-            }
-        }
+    fn check_merge_distinct_array_agg(
+        input1: ArrayRef,
+        input2: ArrayRef,
+        expected: ScalarValue,
+        datatype: DataType,
+    ) -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]);
+        let agg = Arc::new(DistinctArrayAgg::new(
+            col("a", &schema)?,
+            "bla".to_string(),
+            datatype,
+        ));

-        Ok(())
+        let mut accum1 = agg.create_accumulator()?;
+        let mut accum2 = agg.create_accumulator()?;
+
+        accum1.update_batch(&[input1])?;
+        accum2.update_batch(&[input2])?;
+
+        let state = get_accum_scalar_values_as_arrays(accum2.as_ref())?;
+        accum1.merge_batch(&state)?;
+
+        let actual = accum1.evaluate()?;
+
+        compare_list_contents(expected, actual)
     }

     #[test]
@@ -233,6 +278,27 @@ mod tests {
         check_distinct_array_agg(col, out, DataType::Int32)
     }

+    #[test]
+    fn merge_distinct_array_agg_i32() -> Result<()> {
+        let col1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2]));
+        let col2: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 7, 8, 4]));
+
+        let out = ScalarValue::new_list(
+            Some(vec![
+                ScalarValue::Int32(Some(1)),
+                ScalarValue::Int32(Some(2)),
+                ScalarValue::Int32(Some(3)),
+                ScalarValue::Int32(Some(4)),
+                ScalarValue::Int32(Some(5)),
+                ScalarValue::Int32(Some(7)),
+                ScalarValue::Int32(Some(8)),
+            ]),
+            DataType::Int32,
+        );
+
+        check_merge_distinct_array_agg(col1, col2, out, DataType::Int32)
+    }
+
     #[test]
     fn distinct_array_agg_nested() -> Result<()> {
         // [[1, 2, 3], [4, 5]]
@@ -296,4 +362,66 @@ mod tests {
             ))),
         )
     }
+
+    #[test]
+    fn merge_distinct_array_agg_nested() -> Result<()> {
+        // [[1, 2], [3, 4]]
+        let l1 = ScalarValue::new_list(
+            Some(vec![
+                ScalarValue::new_list(
+                    Some(vec![ScalarValue::from(1i32), ScalarValue::from(2i32)]),
+                    DataType::Int32,
+                ),
+                ScalarValue::new_list(
+                    Some(vec![ScalarValue::from(3i32), ScalarValue::from(4i32)]),
+                    DataType::Int32,
+                ),
+            ]),
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
+        );
+
+        // [[5]]
+        let l2 = ScalarValue::new_list(
+            Some(vec![ScalarValue::new_list(
+                Some(vec![ScalarValue::from(5i32)]),
+                DataType::Int32,
+            )]),
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
+        );
+
+        // [[6, 7], [8]]
+        let l3 = ScalarValue::new_list(
+            Some(vec![
+                ScalarValue::new_list(
+                    Some(vec![ScalarValue::from(6i32), ScalarValue::from(7i32)]),
+                    DataType::Int32,
+                ),
+                ScalarValue::new_list(
+                    Some(vec![ScalarValue::from(8i32)]),
+                    DataType::Int32,
+                ),
+            ]),
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
+        );
+
+        let expected = ScalarValue::new_list(
+            Some(vec![l1.clone(), l2.clone(), l3.clone()]),
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
+        );
+
+        // Duplicate l1 in the input array and check that it is deduped in the output.
+        let input1 = ScalarValue::iter_to_array(vec![l1.clone(), l2]).unwrap();
+        let input2 = ScalarValue::iter_to_array(vec![l1, l3]).unwrap();
+
+        check_merge_distinct_array_agg(
+            input1,
+            input2,
+            expected,
+            DataType::List(Arc::new(Field::new_list(
+                "item",
+                Field::new("item", DataType::Int32, true),
+                true,
+            ))),
+        )
+    }
 }
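The new `check_merge_distinct_array_agg` helper above feeds one accumulator's state into another via `get_accum_scalar_values_as_arrays`, which is imported from `crate::aggregate::utils` but not shown in this diff. The sketch below is a hypothetical reconstruction of such a helper, included only to illustrate how a `Vec<ScalarValue>` state becomes the `&[ArrayRef]` slice that `merge_batch` consumes; it assumes the contemporary `ScalarValue::to_array_of_size`, which returns an `ArrayRef` directly.

use arrow::array::ArrayRef;
use datafusion_common::Result;
use datafusion_expr::Accumulator;

// Hypothetical stand-in for get_accum_scalar_values_as_arrays: serialize an
// accumulator's state into one single-row array per state column, which is
// the shape another accumulator's merge_batch expects.
fn state_as_arrays(accum: &dyn Accumulator) -> Result<Vec<ArrayRef>> {
    Ok(accum
        .state()?
        .iter()
        .map(|scalar| scalar.to_array_of_size(1))
        .collect())
}

For DistinctArrayAggAccumulator that state is a single list scalar per row, which is why the new merge_batch asserts `states.len() == 1` and unpacks a `ScalarValue::List` for each row.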

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 23 additions & 6 deletions
@@ -1271,14 +1271,31 @@ NULL 4 29 1.260869565217 123 -117 23
 NULL 5 -194 -13.857142857143 118 -101 14
 NULL NULL 781 7.81 125 -117 100

-# TODO this querys output is non determinisitic (the order of the elements
-# differs run to run
+# TODO: array_agg_distinct output is non-determinisitic -- rewrite with array_sort(list_sort)
+# unnest is also not available, so manually unnesting via CROSS JOIN
+# additional count(1) forces array_agg_distinct instead of array_agg over aggregated by c2 data
 #
 # csv_query_array_agg_distinct
-# query T
-# SELECT array_agg(distinct c2) FROM aggregate_test_100
-# ----
-# [4, 2, 3, 5, 1]
+query III
+WITH indices AS (
+    SELECT 1 AS idx UNION ALL
+    SELECT 2 AS idx UNION ALL
+    SELECT 3 AS idx UNION ALL
+    SELECT 4 AS idx UNION ALL
+    SELECT 5 AS idx
+)
+SELECT data.arr[indices.idx] as element, array_length(data.arr) as array_len, dummy
+FROM (
+    SELECT array_agg(distinct c2) as arr, count(1) as dummy FROM aggregate_test_100
+) data
+CROSS JOIN indices
+ORDER BY 1
+----
+1 5 100
+2 5 100
+3 5 100
+4 5 100
+5 5 100

 # aggregate_time_min_and_max
 query TT
