Skip to content

Commit 720ed82

Browse files
committed
cleanup return_type()
1 parent 0506a5c commit 720ed82

File tree

4 files changed

+42
-41
lines changed

4 files changed

+42
-41
lines changed

datafusion/expr/src/built_in_function.rs

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,12 @@ use crate::signature::TIMEZONE_WILDCARD;
2828
use crate::type_coercion::binary::get_wider_type;
2929
use crate::type_coercion::functions::data_types;
3030
use crate::{
31-
conditional_expressions, struct_expressions, utils, FuncMonotonicity, Signature,
31+
conditional_expressions, struct_expressions, FuncMonotonicity, Signature,
3232
TypeSignature, Volatility,
3333
};
3434

3535
use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit};
36-
use datafusion_common::{
37-
internal_err, plan_datafusion_err, plan_err, DataFusionError, Result,
38-
};
36+
use datafusion_common::{internal_err, plan_err, DataFusionError, Result};
3937

4038
use strum::IntoEnumIterator;
4139
use strum_macros::EnumIter;
@@ -483,38 +481,20 @@ impl BuiltinScalarFunction {
483481
}
484482

485483
/// Returns the output [`DataType`] of this function
484+
///
485+
/// This method should be invoked only after `input_expr_types` have been validated
486+
/// against the function's `TypeSignature` using `type_coercion::functions::data_types()`.
487+
///
488+
/// This method will:
489+
/// 1. Perform additional checks on `input_expr_types` that are beyond the scope of `TypeSignature` validation.
490+
/// 2. Deduce the output `DataType` based on the provided `input_expr_types`.
486491
pub fn return_type(self, input_expr_types: &[DataType]) -> Result<DataType> {
487492
use DataType::*;
488493
use TimeUnit::*;
489494

490495
// Note that this function *must* return the same type that the respective physical expression returns
491496
// or the execution panics.
492497

493-
if input_expr_types.is_empty()
494-
&& !self.signature().type_signature.supports_zero_argument()
495-
{
496-
return plan_err!(
497-
"{}",
498-
utils::generate_signature_error_msg(
499-
&format!("{self}"),
500-
self.signature(),
501-
input_expr_types
502-
)
503-
);
504-
}
505-
506-
// verify that this is a valid set of data types for this function
507-
data_types(input_expr_types, &self.signature()).map_err(|_| {
508-
plan_datafusion_err!(
509-
"{}",
510-
utils::generate_signature_error_msg(
511-
&format!("{self}"),
512-
self.signature(),
513-
input_expr_types,
514-
)
515-
)
516-
})?;
517-
518498
// the return type of the built in function.
519499
// Some built-in functions' return type depends on the incoming type.
520500
match self {

datafusion/expr/src/expr_schema.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ use crate::expr::{
2323
};
2424
use crate::field_util::GetFieldAccessSchema;
2525
use crate::type_coercion::binary::get_result_type;
26-
use crate::{LogicalPlan, Projection, Subquery};
26+
use crate::type_coercion::functions::data_types;
27+
use crate::{utils, LogicalPlan, Projection, Subquery};
2728
use arrow::compute::can_cast_types;
2829
use arrow::datatypes::{DataType, Field};
2930
use datafusion_common::{
@@ -89,12 +90,24 @@ impl ExprSchemable for Expr {
8990
Ok((fun.return_type)(&data_types)?.as_ref().clone())
9091
}
9192
Expr::ScalarFunction(ScalarFunction { fun, args }) => {
92-
let data_types = args
93+
let arg_data_types = args
9394
.iter()
9495
.map(|e| e.get_type(schema))
9596
.collect::<Result<Vec<_>>>()?;
9697

97-
fun.return_type(&data_types)
98+
// verify that input data types is consistent with function's `TypeSignature`
99+
data_types(&arg_data_types, &fun.signature()).map_err(|_| {
100+
plan_datafusion_err!(
101+
"{}",
102+
utils::generate_signature_error_msg(
103+
&format!("{fun}"),
104+
fun.signature(),
105+
&arg_data_types,
106+
)
107+
)
108+
})?;
109+
110+
fun.return_type(&arg_data_types)
98111
}
99112
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
100113
let data_types = args

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,17 @@ pub fn data_types(
3535
signature: &Signature,
3636
) -> Result<Vec<DataType>> {
3737
if current_types.is_empty() {
38-
return Ok(vec![]);
38+
if signature.type_signature.supports_zero_argument() {
39+
return Ok(vec![]);
40+
} else {
41+
return plan_err!(
42+
"Coercion from {:?} to the signature {:?} failed.",
43+
current_types,
44+
&signature.type_signature
45+
);
46+
}
3947
}
48+
4049
let valid_types = get_valid_types(&signature.type_signature, current_types)?;
4150

4251
if valid_types

datafusion/physical-expr/src/functions.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ use arrow::{
4747
use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
4848
pub use datafusion_expr::FuncMonotonicity;
4949
use datafusion_expr::{
50-
BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation,
50+
type_coercion::functions::data_types, BuiltinScalarFunction, ColumnarValue,
51+
ScalarFunctionImplementation,
5152
};
5253
use std::ops::Neg;
5354
use std::sync::Arc;
@@ -65,6 +66,9 @@ pub fn create_physical_expr(
6566
.map(|e| e.data_type(input_schema))
6667
.collect::<Result<Vec<_>>>()?;
6768

69+
// verify that input data types is consistent with function's `TypeSignature`
70+
data_types(&input_expr_types, &fun.signature())?;
71+
6872
let data_type = fun.return_type(&input_expr_types)?;
6973

7074
let fun_expr: ScalarFunctionImplementation = match fun {
@@ -2952,13 +2956,8 @@ mod tests {
29522956
"Builtin scalar function {fun} does not support empty arguments"
29532957
);
29542958
}
2955-
Err(DataFusionError::Plan(err)) => {
2956-
if !err
2957-
.contains("No function matches the given name and argument types")
2958-
{
2959-
return plan_err!(
2960-
"Builtin scalar function {fun} didn't got the right error message with empty arguments");
2961-
}
2959+
Err(DataFusionError::Plan(_)) => {
2960+
// Continue the loop
29622961
}
29632962
Err(..) => {
29642963
return internal_err!(

0 commit comments

Comments
 (0)