Skip to content

Commit 3d10036

Browse files
lewiszlw authored and findepi committed
Convert StringAgg to UDAF (apache#10945)
* Convert StringAgg to UDAF
* generate proto code
* Fix bug
* Fix
* Add license
* Add doc
* Fix clippy
* Remove aliases field
* Add StringAgg proto test
* Add roundtrip_expr_api test
1 parent 85c8f58 commit 3d10036

File tree

17 files changed

+192
-321
lines changed

17 files changed

+192
-321
lines changed

datafusion/expr/src/aggregate_function.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@ pub enum AggregateFunction {
5151
BoolAnd,
5252
/// Bool Or
5353
BoolOr,
54-
/// String aggregation
55-
StringAgg,
5654
}
5755

5856
impl AggregateFunction {
@@ -68,7 +66,6 @@ impl AggregateFunction {
6866
Grouping => "GROUPING",
6967
BoolAnd => "BOOL_AND",
7068
BoolOr => "BOOL_OR",
71-
StringAgg => "STRING_AGG",
7269
}
7370
}
7471
}
@@ -92,7 +89,6 @@ impl FromStr for AggregateFunction {
9289
"min" => AggregateFunction::Min,
9390
"array_agg" => AggregateFunction::ArrayAgg,
9491
"nth_value" => AggregateFunction::NthValue,
95-
"string_agg" => AggregateFunction::StringAgg,
9692
// statistical
9793
"corr" => AggregateFunction::Correlation,
9894
// other
@@ -146,7 +142,6 @@ impl AggregateFunction {
146142
)))),
147143
AggregateFunction::Grouping => Ok(DataType::Int32),
148144
AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
149-
AggregateFunction::StringAgg => Ok(DataType::LargeUtf8),
150145
}
151146
}
152147
}
@@ -195,9 +190,6 @@ impl AggregateFunction {
195190
AggregateFunction::Correlation => {
196191
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
197192
}
198-
AggregateFunction::StringAgg => {
199-
Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable)
200-
}
201193
}
202194
}
203195
}

datafusion/expr/src/type_coercion/aggregates.rs

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -145,23 +145,6 @@ pub fn coerce_types(
145145
}
146146
AggregateFunction::NthValue => Ok(input_types.to_vec()),
147147
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
148-
AggregateFunction::StringAgg => {
149-
if !is_string_agg_supported_arg_type(&input_types[0]) {
150-
return plan_err!(
151-
"The function {:?} does not support inputs of type {:?}",
152-
agg_fun,
153-
input_types[0]
154-
);
155-
}
156-
if !is_string_agg_supported_arg_type(&input_types[1]) {
157-
return plan_err!(
158-
"The function {:?} does not support inputs of type {:?}",
159-
agg_fun,
160-
input_types[1]
161-
);
162-
}
163-
Ok(vec![LargeUtf8, input_types[1].clone()])
164-
}
165148
}
166149
}
167150

@@ -391,15 +374,6 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
391374
arg_type.is_integer()
392375
}
393376

394-
/// Return `true` if `arg_type` is of a [`DataType`] that the
395-
/// [`AggregateFunction::StringAgg`] aggregation can operate on.
396-
pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool {
397-
matches!(
398-
arg_type,
399-
DataType::Utf8 | DataType::LargeUtf8 | DataType::Null
400-
)
401-
}
402-
403377
#[cfg(test)]
404378
mod tests {
405379
use super::*;

datafusion/functions-aggregate/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ pub mod approx_median;
7070
pub mod approx_percentile_cont;
7171
pub mod approx_percentile_cont_with_weight;
7272
pub mod bit_and_or_xor;
73+
pub mod string_agg;
7374

7475
use crate::approx_percentile_cont::approx_percentile_cont_udaf;
7576
use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf;
@@ -138,6 +139,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
138139
approx_distinct::approx_distinct_udaf(),
139140
approx_percentile_cont_udaf(),
140141
approx_percentile_cont_with_weight_udaf(),
142+
string_agg::string_agg_udaf(),
141143
bit_and_or_xor::bit_and_udaf(),
142144
bit_and_or_xor::bit_or_udaf(),
143145
bit_and_or_xor::bit_xor_udaf(),
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the `string_agg` function
19+
20+
use arrow::array::ArrayRef;
21+
use arrow_schema::DataType;
22+
use datafusion_common::cast::as_generic_string_array;
23+
use datafusion_common::Result;
24+
use datafusion_common::{not_impl_err, ScalarValue};
25+
use datafusion_expr::function::AccumulatorArgs;
26+
use datafusion_expr::{
27+
Accumulator, AggregateUDFImpl, Expr, Signature, TypeSignature, Volatility,
28+
};
29+
use std::any::Any;
30+
31+
// Generates the public entry points for `string_agg` from the `StringAgg` type.
// `string_agg_udaf` is the accessor that the crate's
// `all_default_aggregate_functions` list registers; `string_agg` is presumably
// an expression builder taking `expr` and `delimiter` — confirm the exact
// generated signatures against the macro definition.
make_udaf_expr_and_func!(
    StringAgg,
    string_agg,
    expr delimiter,
    "Concatenates the values of string expressions and places separator values between them",
    string_agg_udaf
);
38+
39+
/// STRING_AGG aggregate expression
40+
#[derive(Debug)]
41+
pub struct StringAgg {
42+
signature: Signature,
43+
}
44+
45+
impl StringAgg {
46+
/// Create a new StringAgg aggregate function
47+
pub fn new() -> Self {
48+
Self {
49+
signature: Signature::one_of(
50+
vec![
51+
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
52+
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
53+
TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]),
54+
],
55+
Volatility::Immutable,
56+
),
57+
}
58+
}
59+
}
60+
61+
impl Default for StringAgg {
62+
fn default() -> Self {
63+
Self::new()
64+
}
65+
}
66+
67+
impl AggregateUDFImpl for StringAgg {
68+
fn as_any(&self) -> &dyn Any {
69+
self
70+
}
71+
72+
fn name(&self) -> &str {
73+
"string_agg"
74+
}
75+
76+
fn signature(&self) -> &Signature {
77+
&self.signature
78+
}
79+
80+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
81+
Ok(DataType::LargeUtf8)
82+
}
83+
84+
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
85+
match &acc_args.input_exprs[1] {
86+
Expr::Literal(ScalarValue::Utf8(Some(delimiter)))
87+
| Expr::Literal(ScalarValue::LargeUtf8(Some(delimiter))) => {
88+
Ok(Box::new(StringAggAccumulator::new(delimiter)))
89+
}
90+
Expr::Literal(ScalarValue::Utf8(None))
91+
| Expr::Literal(ScalarValue::LargeUtf8(None))
92+
| Expr::Literal(ScalarValue::Null) => {
93+
Ok(Box::new(StringAggAccumulator::new("")))
94+
}
95+
_ => not_impl_err!(
96+
"StringAgg not supported for delimiter {}",
97+
&acc_args.input_exprs[1]
98+
),
99+
}
100+
}
101+
}
102+
103+
/// Running state for one `string_agg` group.
#[derive(Debug)]
pub(crate) struct StringAggAccumulator {
    /// Concatenation built so far; `None` until the first non-null input
    /// value has been seen.
    values: Option<String>,
    /// Text placed between consecutive values.
    delimiter: String,
}
108+
109+
impl StringAggAccumulator {
110+
pub fn new(delimiter: &str) -> Self {
111+
Self {
112+
values: None,
113+
delimiter: delimiter.to_string(),
114+
}
115+
}
116+
}
117+
118+
impl Accumulator for StringAggAccumulator {
119+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
120+
let string_array: Vec<_> = as_generic_string_array::<i64>(&values[0])?
121+
.iter()
122+
.filter_map(|v| v.as_ref().map(ToString::to_string))
123+
.collect();
124+
if !string_array.is_empty() {
125+
let s = string_array.join(self.delimiter.as_str());
126+
let v = self.values.get_or_insert("".to_string());
127+
if !v.is_empty() {
128+
v.push_str(self.delimiter.as_str());
129+
}
130+
v.push_str(s.as_str());
131+
}
132+
Ok(())
133+
}
134+
135+
fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
136+
self.update_batch(values)?;
137+
Ok(())
138+
}
139+
140+
fn state(&mut self) -> Result<Vec<ScalarValue>> {
141+
Ok(vec![self.evaluate()?])
142+
}
143+
144+
fn evaluate(&mut self) -> Result<ScalarValue> {
145+
Ok(ScalarValue::LargeUtf8(self.values.clone()))
146+
}
147+
148+
fn size(&self) -> usize {
149+
std::mem::size_of_val(self)
150+
+ self.values.as_ref().map(|v| v.capacity()).unwrap_or(0)
151+
+ self.delimiter.capacity()
152+
}
153+
}

datafusion/physical-expr/src/aggregate/build_in.rs

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -155,22 +155,6 @@ pub fn create_aggregate_expr(
155155
ordering_req.to_vec(),
156156
))
157157
}
158-
(AggregateFunction::StringAgg, false) => {
159-
if !ordering_req.is_empty() {
160-
return not_impl_err!(
161-
"STRING_AGG(ORDER BY a ASC) order-sensitive aggregations are not available"
162-
);
163-
}
164-
Arc::new(expressions::StringAgg::new(
165-
input_phy_exprs[0].clone(),
166-
input_phy_exprs[1].clone(),
167-
name,
168-
data_type,
169-
))
170-
}
171-
(AggregateFunction::StringAgg, true) => {
172-
return not_impl_err!("STRING_AGG(DISTINCT) aggregations are not available");
173-
}
174158
})
175159
}
176160

datafusion/physical-expr/src/aggregate/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ pub(crate) mod correlation;
2626
pub(crate) mod covariance;
2727
pub(crate) mod grouping;
2828
pub(crate) mod nth_value;
29-
pub(crate) mod string_agg;
3029
#[macro_use]
3130
pub(crate) mod min_max;
3231
pub(crate) mod groups_accumulator;

0 commit comments

Comments
 (0)