Skip to content

Commit 600b815

Browse files
authored
Minor: Signature check for UDAF (#10147)
* add sig for udaf

Signed-off-by: jayzhan211 <[email protected]>

* fix test

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
1 parent 0e21281 commit 600b815

File tree

4 files changed

+30
-15
lines changed

4 files changed

+30
-15
lines changed

datafusion/expr/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ pub use logical_plan::*;
7979
pub use operator::Operator;
8080
pub use partition_evaluator::PartitionEvaluator;
8181
pub use signature::{
82-
FuncMonotonicity, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD,
82+
ArrayFunctionSignature, FuncMonotonicity, Signature, TypeSignature, Volatility,
83+
TIMEZONE_WILDCARD,
8384
};
8485
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
8586
pub use udaf::{AggregateUDF, AggregateUDFImpl};

datafusion/expr/src/type_coercion/aggregates.rs

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ pub fn coerce_types(
9393
) -> Result<Vec<DataType>> {
9494
use DataType::*;
9595
// Validate input_types matches (at least one of) the func signature.
96-
check_arg_count(agg_fun, input_types, &signature.type_signature)?;
96+
check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?;
9797

9898
match agg_fun {
9999
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
@@ -323,17 +323,16 @@ pub fn coerce_types(
323323
/// This method DOES NOT validate the argument types - only that (at least one,
324324
/// in the case of [`TypeSignature::OneOf`]) signature matches the desired
325325
/// number of input types.
326-
fn check_arg_count(
327-
agg_fun: &AggregateFunction,
326+
pub fn check_arg_count(
327+
func_name: &str,
328328
input_types: &[DataType],
329329
signature: &TypeSignature,
330330
) -> Result<()> {
331331
match signature {
332332
TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
333333
if input_types.len() != *agg_count {
334334
return plan_err!(
335-
"The function {:?} expects {:?} arguments, but {:?} were provided",
336-
agg_fun,
335+
"The function {func_name} expects {:?} arguments, but {:?} were provided",
337336
agg_count,
338337
input_types.len()
339338
);
@@ -342,8 +341,7 @@ fn check_arg_count(
342341
TypeSignature::Exact(types) => {
343342
if types.len() != input_types.len() {
344343
return plan_err!(
345-
"The function {:?} expects {:?} arguments, but {:?} were provided",
346-
agg_fun,
344+
"The function {func_name} expects {:?} arguments, but {:?} were provided",
347345
types.len(),
348346
input_types.len()
349347
);
@@ -352,19 +350,18 @@ fn check_arg_count(
352350
TypeSignature::OneOf(variants) => {
353351
let ok = variants
354352
.iter()
355-
.any(|v| check_arg_count(agg_fun, input_types, v).is_ok());
353+
.any(|v| check_arg_count(func_name, input_types, v).is_ok());
356354
if !ok {
357355
return plan_err!(
358-
"The function {:?} does not accept {:?} function arguments.",
359-
agg_fun,
356+
"The function {func_name} does not accept {:?} function arguments.",
360357
input_types.len()
361358
);
362359
}
363360
}
364361
TypeSignature::VariadicAny => {
365362
if input_types.is_empty() {
366363
return plan_err!(
367-
"The function {agg_fun:?} expects at least one argument"
364+
"The function {func_name} expects at least one argument"
368365
);
369366
}
370367
}
@@ -594,7 +591,7 @@ mod tests {
594591
let input_types = vec![DataType::Int64, DataType::Int32];
595592
let signature = fun.signature();
596593
let result = coerce_types(&fun, &input_types, &signature);
597-
assert_eq!("Error during planning: The function Min expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace());
594+
assert_eq!("Error during planning: The function MIN expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace());
598595

599596
// test input args is invalid data type for sum or avg
600597
let fun = AggregateFunction::Sum;

datafusion/functions-aggregate/src/first_last.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ use datafusion_common::{
2727
use datafusion_expr::function::AccumulatorArgs;
2828
use datafusion_expr::type_coercion::aggregates::NUMERICS;
2929
use datafusion_expr::utils::format_state_name;
30-
use datafusion_expr::{Accumulator, AggregateUDFImpl, Expr, Signature, Volatility};
30+
use datafusion_expr::{
31+
Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature,
32+
TypeSignature, Volatility,
33+
};
3134
use datafusion_physical_expr_common::aggregate::utils::{
3235
down_cast_any_ref, get_sort_options, ordering_fields,
3336
};
@@ -73,7 +76,14 @@ impl FirstValue {
7376
pub fn new() -> Self {
7477
Self {
7578
aliases: vec![String::from("FIRST_VALUE")],
76-
signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
79+
signature: Signature::one_of(
80+
vec![
81+
// TODO: we can introduce more strict signature that only numeric of array types are allowed
82+
TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
83+
TypeSignature::Uniform(1, NUMERICS.to_vec()),
84+
],
85+
Volatility::Immutable,
86+
),
7787
}
7888
}
7989
}

datafusion/physical-expr-common/src/aggregate/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ pub mod utils;
1919

2020
use arrow::datatypes::{DataType, Field, Schema};
2121
use datafusion_common::{not_impl_err, Result};
22+
use datafusion_expr::type_coercion::aggregates::check_arg_count;
2223
use datafusion_expr::{
2324
function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator,
2425
};
@@ -46,6 +47,12 @@ pub fn create_aggregate_expr(
4647
.map(|arg| arg.data_type(schema))
4748
.collect::<Result<Vec<_>>>()?;
4849

50+
check_arg_count(
51+
fun.name(),
52+
&input_exprs_types,
53+
&fun.signature().type_signature,
54+
)?;
55+
4956
let ordering_types = ordering_req
5057
.iter()
5158
.map(|e| e.expr.data_type(schema))

0 commit comments

Comments
 (0)