Skip to content

Commit cd7d998

Browse files
author
Cheng-Yuan-Lai
committed
feat: enhance function argument diagnostics with detailed error handling
1 parent dff5d97 commit cd7d998

File tree

2 files changed

+201
-4
lines changed

2 files changed

+201
-4
lines changed

datafusion/expr/src/test/function_stub.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use arrow::datatypes::{
2525
DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
2626
};
2727

28+
use datafusion_common::Spans;
2829
use datafusion_common::{
2930
exec_err, not_impl_err, utils::take_function_args_with_diag, Result,
3031
};
@@ -97,14 +98,36 @@ pub fn avg(expr: Expr) -> Expr {
9798
#[derive(Debug)]
9899
pub struct Sum {
99100
signature: Signature,
101+
/// Original source code location, if known
102+
pub spans: Spans,
100103
}
101104

102105
impl Sum {
103106
pub fn new() -> Self {
104107
Self {
105108
signature: Signature::user_defined(Immutable),
109+
spans: Spans::new(),
106110
}
107111
}
112+
113+
/// Returns a reference to the set of locations in the SQL query where this
114+
/// column appears, if known.
115+
pub fn spans(&self) -> &Spans {
116+
&self.spans
117+
}
118+
119+
/// Returns a mutable reference to the set of locations in the SQL query
120+
/// where this column appears, if known.
121+
pub fn spans_mut(&mut self) -> &mut Spans {
122+
&mut self.spans
123+
}
124+
125+
/// Replaces the set of locations in the SQL query where this column
126+
/// appears, if known.
127+
pub fn with_spans(mut self, spans: Spans) -> Self {
128+
self.spans = spans;
129+
self
130+
}
108131
}
109132

110133
impl Default for Sum {
@@ -127,7 +150,9 @@ impl AggregateUDFImpl for Sum {
127150
}
128151

129152
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
130-
let [array] = take_function_args_with_diag(self.name(), arg_types, None)?;
153+
dbg!(self.spans());
154+
let [array] =
155+
take_function_args_with_diag(self.name(), arg_types, self.spans().first())?;
131156

132157
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
133158
// smallint, int, bigint, real, double precision, decimal, or interval.

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 175 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@ use arrow::{
2121
compute::can_cast_types,
2222
datatypes::{DataType, Field, TimeUnit},
2323
};
24+
use datafusion_common::{Diagnostic, Span};
2425
use datafusion_common::types::LogicalType;
2526
use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion};
2627
use datafusion_common::{
27-
exec_err, internal_datafusion_err, internal_err, plan_err, types::NativeType,
28+
exec_err, internal_datafusion_err, internal_err, plan_err, plan_datafusion_err,types::NativeType,
2829
utils::list_ndims, Result,
2930
};
3031
use datafusion_expr_common::signature::ArrayFunctionArgument;
@@ -249,6 +250,177 @@ fn try_coerce_types(
249250
)
250251
}
251252

253+
pub fn check_function_length_with_diag(
254+
function_name: &str,
255+
signature: &TypeSignature,
256+
current_types: &[DataType],
257+
function_call_site: Option<Span>,
258+
) -> Result<()> {
259+
// Special handling for zero arguments
260+
if current_types.is_empty() {
261+
if signature.supports_zero_argument() {
262+
return Ok(());
263+
} else if signature.used_to_support_zero_arguments() {
264+
// Special error to help during upgrade
265+
let base_error = plan_datafusion_err!(
266+
"'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
267+
function_name
268+
);
269+
let mut diagnostic = Diagnostic::new_error(
270+
format!("Zero arguments not supported for {} function", function_name),
271+
function_call_site,
272+
);
273+
diagnostic.add_help(
274+
"Use TypeSignature::Nullary for functions that take no arguments",
275+
None,
276+
);
277+
return Err(base_error.with_diagnostic(diagnostic));
278+
} else {
279+
let base_error = plan_datafusion_err!(
280+
"'{}' does not support zero arguments",
281+
function_name
282+
);
283+
let mut diagnostic = Diagnostic::new_error(
284+
format!("Zero arguments not supported for {} function", function_name),
285+
function_call_site,
286+
);
287+
diagnostic.add_note(
288+
format!("Function {} requires at least one argument", function_name),
289+
None,
290+
);
291+
return Err(base_error.with_diagnostic(diagnostic));
292+
}
293+
}
294+
295+
// Helper closure to create and return an error with diagnostic information
296+
let create_error = |expected: &str, got: usize| {
297+
let base_error = plan_datafusion_err!(
298+
"Function '{}' {}, got {}",
299+
function_name,
300+
expected,
301+
got
302+
);
303+
304+
let mut diagnostic = Diagnostic::new_error(
305+
format!(
306+
"Wrong number of arguments for {} function call",
307+
function_name
308+
),
309+
function_call_site,
310+
);
311+
diagnostic.add_note(
312+
format!(
313+
"Function {} {}, but {} {} provided",
314+
function_name,
315+
expected,
316+
got,
317+
if got == 1 { "was" } else { "were" }
318+
),
319+
None,
320+
);
321+
322+
Err(base_error.with_diagnostic(diagnostic))
323+
};
324+
325+
match signature {
326+
TypeSignature::Uniform(num, _) |
327+
TypeSignature::Numeric(num) |
328+
TypeSignature::String(num) |
329+
TypeSignature::Comparable(num) |
330+
TypeSignature::Any(num) => {
331+
return create_error(&format!("expects {} arguments", num), current_types.len());
332+
},
333+
TypeSignature::Exact(types) => {
334+
if current_types.len() != types.len() {
335+
return create_error(&format!("expects {} arguments", types.len()), current_types.len());
336+
}
337+
},
338+
TypeSignature::Coercible(types) => {
339+
if current_types.len() != types.len() {
340+
return create_error(&format!("expects {} arguments", types.len()), current_types.len());
341+
}
342+
},
343+
TypeSignature::Nullary => {
344+
if !current_types.is_empty() {
345+
return create_error("expects zero arguments", current_types.len());
346+
}
347+
},
348+
TypeSignature::ArraySignature(array_signature) => {
349+
match array_signature {
350+
ArrayFunctionSignature::Array { arguments, .. } => {
351+
if current_types.len() != arguments.len() {
352+
return create_error(&format!("expects {} arguments", arguments.len()), current_types.len());
353+
}
354+
},
355+
ArrayFunctionSignature::RecursiveArray => {
356+
if current_types.len() != 1 {
357+
return create_error("expects exactly one array argument", current_types.len());
358+
}
359+
},
360+
ArrayFunctionSignature::MapArray => {
361+
if current_types.len() != 1 {
362+
return create_error("expects exactly one map argument", current_types.len());
363+
}
364+
},
365+
}
366+
},
367+
TypeSignature::OneOf(signatures) => {
368+
// For OneOf, we'll consider it valid if it matches ANY of the signatures
369+
// We'll collect all errors to provide better diagnostics if nothing matches
370+
let mut all_errors = Vec::new();
371+
372+
for sig in signatures {
373+
match check_function_length_with_diag(function_name, sig, current_types, function_call_site) {
374+
Ok(()) => return Ok(()), // If any signature matches, return immediately
375+
Err(e) => all_errors.push(e),
376+
}
377+
}
378+
379+
// If we're here, none of the signatures matched
380+
if !all_errors.is_empty() {
381+
let error_messages = all_errors
382+
.iter()
383+
.map(|e| e.to_string())
384+
.collect::<Vec<_>>()
385+
.join("; ");
386+
387+
let base_error = plan_datafusion_err!(
388+
"Function '{}' has no matching signature for {} arguments. Errors: {}",
389+
function_name,
390+
current_types.len(),
391+
error_messages
392+
);
393+
394+
let mut diagnostic = Diagnostic::new_error(
395+
format!(
396+
"No matching signature for {} function with {} arguments",
397+
function_name,
398+
current_types.len()
399+
),
400+
function_call_site,
401+
);
402+
403+
diagnostic.add_note(
404+
format!("The function {} has multiple possible signatures", function_name),
405+
None,
406+
);
407+
408+
return Err(base_error.with_diagnostic(diagnostic));
409+
}
410+
},
411+
// Signatures that accept variable numbers of arguments or are handled specially
412+
TypeSignature::Variadic(_) | TypeSignature::VariadicAny | TypeSignature::UserDefined => {
413+
// These cases are implicitly valid for any non-zero number of arguments:
414+
// - Variadic: accepts one or more arguments of specified types
415+
// - VariadicAny: accepts one or more arguments of any type
416+
// - UserDefined: custom validation handled by the UDF itself
417+
}
418+
419+
}
420+
421+
Ok(())
422+
}
423+
252424
fn get_valid_types_with_scalar_udf(
253425
signature: &TypeSignature,
254426
current_types: &[DataType],
@@ -267,6 +439,7 @@ fn get_valid_types_with_scalar_udf(
267439
let mut res = vec![];
268440
let mut errors = vec![];
269441
for sig in signatures {
442+
270443
match get_valid_types_with_scalar_udf(sig, current_types, func) {
271444
Ok(valid_types) => {
272445
res.extend(valid_types);
@@ -615,8 +788,7 @@ fn get_valid_types(
615788
}
616789
}
617790
TypeSignature::Coercible(param_types) => {
618-
function_length_check(function_name, current_types.len(), param_types.len())?;
619-
791+
620792
let mut new_types = Vec::with_capacity(current_types.len());
621793
for (current_type, param) in current_types.iter().zip(param_types.iter()) {
622794
let current_native_type: NativeType = current_type.into();

0 commit comments

Comments
 (0)