Skip to content

Commit 854ed60

Browse files
committed
fix: coalesce schema issues
closes #12307
1 parent 26c8004 commit 854ed60

File tree

14 files changed

+335
-128
lines changed

14 files changed

+335
-128
lines changed

datafusion/core/src/dataframe/mod.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2029,6 +2029,43 @@ mod tests {
20292029
Ok(())
20302030
}
20312031

2032+
#[tokio::test]
2033+
async fn test_coalesce_schema() -> Result<()> {
2034+
let ctx = SessionContext::new();
2035+
2036+
let query = r#"SELECT COALESCE(null, 5)"#;
2037+
2038+
let result = ctx.sql(query).await?;
2039+
assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
2040+
Ok(())
2041+
}
2042+
2043+
#[tokio::test]
2044+
async fn test_coalesce_from_values_schema() -> Result<()> {
2045+
let ctx = SessionContext::new();
2046+
2047+
let query = r#"SELECT COALESCE(column1, column2) FROM VALUES (null, 1.2)"#;
2048+
2049+
let result = ctx.sql(query).await?;
2050+
assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
2051+
Ok(())
2052+
}
2053+
2054+
#[tokio::test]
2055+
async fn test_coalesce_from_values_schema_multiple_rows() -> Result<()> {
2056+
let ctx = SessionContext::new();
2057+
2058+
let query = r#"SELECT COALESCE(column1, column2)
2059+
FROM VALUES
2060+
(null, 1.2),
2061+
(1.1, null),
2062+
(2, 5);"#;
2063+
2064+
let result = ctx.sql(query).await?;
2065+
assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
2066+
Ok(())
2067+
}
2068+
20322069
#[tokio::test]
20332070
async fn test_array_agg_schema() -> Result<()> {
20342071
let ctx = SessionContext::new();

datafusion/expr/src/expr_schema.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -151,21 +151,22 @@ impl ExprSchemable for Expr {
151151
.collect::<Result<Vec<_>>>()?;
152152

153153
// verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
154-
data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| {
155-
plan_datafusion_err!(
156-
"{} {}",
157-
err,
158-
utils::generate_signature_error_msg(
159-
func.name(),
160-
func.signature().clone(),
161-
&arg_data_types,
154+
let new_data_types = data_types_with_scalar_udf(&arg_data_types, func)
155+
.map_err(|err| {
156+
plan_datafusion_err!(
157+
"{} {}",
158+
err,
159+
utils::generate_signature_error_msg(
160+
func.name(),
161+
func.signature().clone(),
162+
&arg_data_types,
163+
)
162164
)
163-
)
164-
})?;
165+
})?;
165166

166167
// perform additional function arguments validation (due to limited
167168
// expressiveness of `TypeSignature`), then infer return type
168-
Ok(func.return_type_from_exprs(args, schema, &arg_data_types)?)
169+
Ok(func.return_type_from_exprs(args, schema, &new_data_types)?)
169170
}
170171
Expr::WindowFunction(window_function) => self
171172
.data_type_and_nullable_with_window_function(schema, window_function)

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,9 @@ impl LogicalPlanBuilder {
216216
common_type = Some(data_type);
217217
}
218218
}
219-
field_types.push(common_type.unwrap_or(DataType::Utf8));
219+
// assuming common_type was not set, and no error, therefore the type should be NULL
220+
// since the code loop skips NULL
221+
field_types.push(common_type.unwrap_or(DataType::Null));
220222
}
221223
// wrap cast if data type is not same as common type.
222224
for row in &mut values {

datafusion/functions/src/core/coalesce.rs

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ use arrow::array::{new_null_array, BooleanArray};
2121
use arrow::compute::kernels::zip::zip;
2222
use arrow::compute::{and, is_not_null, is_null};
2323
use arrow::datatypes::DataType;
24-
2524
use datafusion_common::{exec_err, ExprSchema, Result};
2625
use datafusion_expr::type_coercion::binary::type_union_resolution;
2726
use datafusion_expr::{ColumnarValue, Expr, ExprSchemable};
2827
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
28+
use itertools::Itertools;
2929

3030
#[derive(Debug)]
3131
pub struct CoalesceFunc {
@@ -60,12 +60,16 @@ impl ScalarUDFImpl for CoalesceFunc {
6060
}
6161

6262
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
63-
Ok(arg_types[0].clone())
63+
Ok(arg_types
64+
.iter()
65+
.find_or_first(|d| !d.is_null())
66+
.unwrap()
67+
.clone())
6468
}
6569

66-
// If all the element in coalesce is non-null, the result is non-null
70+
// If any the arguments in coalesce is non-null, the result is non-null
6771
fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
68-
args.iter().any(|e| e.nullable(schema).ok().unwrap_or(true))
72+
args.iter().all(|e| e.nullable(schema).ok().unwrap_or(true))
6973
}
7074

7175
/// coalesce evaluates to the first value which is not NULL
@@ -154,4 +158,22 @@ mod test {
154158
.unwrap();
155159
assert_eq!(return_type, DataType::Date32);
156160
}
161+
162+
#[test]
163+
fn test_coalesce_return_types_with_nulls_first() {
164+
let coalesce = core::coalesce::CoalesceFunc::new();
165+
let return_type = coalesce
166+
.return_type(&[DataType::Null, DataType::Date32])
167+
.unwrap();
168+
assert_eq!(return_type, DataType::Date32);
169+
}
170+
171+
#[test]
172+
fn test_coalesce_return_types_with_nulls_last() {
173+
let coalesce = core::coalesce::CoalesceFunc::new();
174+
let return_type = coalesce
175+
.return_type(&[DataType::Int64, DataType::Null])
176+
.unwrap();
177+
assert_eq!(return_type, DataType::Int64);
178+
}
157179
}

datafusion/functions/src/datetime/to_local_time.rs

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,16 @@ use std::sync::Arc;
2222
use arrow::array::timezone::Tz;
2323
use arrow::array::{Array, ArrayRef, PrimitiveBuilder};
2424
use arrow::datatypes::DataType::Timestamp;
25+
use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
2526
use arrow::datatypes::{
2627
ArrowTimestampType, DataType, TimestampMicrosecondType, TimestampMillisecondType,
2728
TimestampNanosecondType, TimestampSecondType,
2829
};
29-
use arrow::datatypes::{
30-
TimeUnit,
31-
TimeUnit::{Microsecond, Millisecond, Nanosecond, Second},
32-
};
3330

3431
use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc};
3532
use datafusion_common::cast::as_primitive_array;
36-
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
37-
use datafusion_expr::TypeSignature::Exact;
38-
use datafusion_expr::{
39-
ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD,
40-
};
33+
use datafusion_common::{exec_err, plan_err, DataFusionError, Result, ScalarValue};
34+
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
4135

4236
/// A UDF function that converts a timezone-aware timestamp to local time (with no offset or
4337
/// timezone information). In other words, this function strips off the timezone from the timestamp,
@@ -55,20 +49,8 @@ impl Default for ToLocalTimeFunc {
5549

5650
impl ToLocalTimeFunc {
5751
pub fn new() -> Self {
58-
let base_sig = |array_type: TimeUnit| {
59-
[
60-
Exact(vec![Timestamp(array_type, None)]),
61-
Exact(vec![Timestamp(array_type, Some(TIMEZONE_WILDCARD.into()))]),
62-
]
63-
};
64-
65-
let full_sig = [Nanosecond, Microsecond, Millisecond, Second]
66-
.into_iter()
67-
.flat_map(base_sig)
68-
.collect::<Vec<_>>();
69-
7052
Self {
71-
signature: Signature::one_of(full_sig, Volatility::Immutable),
53+
signature: Signature::user_defined(Volatility::Immutable),
7254
}
7355
}
7456

@@ -328,13 +310,10 @@ impl ScalarUDFImpl for ToLocalTimeFunc {
328310
}
329311

330312
match &arg_types[0] {
331-
Timestamp(Nanosecond, _) => Ok(Timestamp(Nanosecond, None)),
332-
Timestamp(Microsecond, _) => Ok(Timestamp(Microsecond, None)),
333-
Timestamp(Millisecond, _) => Ok(Timestamp(Millisecond, None)),
334-
Timestamp(Second, _) => Ok(Timestamp(Second, None)),
313+
Timestamp(timeunit, _) => Ok(Timestamp(*timeunit, None)),
335314
_ => exec_err!(
336315
"The to_local_time function can only accept timestamp as the arg, got {:?}", arg_types[0]
337-
),
316+
)
338317
}
339318
}
340319

@@ -348,6 +327,30 @@ impl ScalarUDFImpl for ToLocalTimeFunc {
348327

349328
self.to_local_time(args)
350329
}
330+
331+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
332+
if arg_types.len() != 1 {
333+
return plan_err!(
334+
"to_local_time function requires 1 argument, got {:?}",
335+
arg_types.len()
336+
);
337+
}
338+
339+
let first_arg = arg_types[0].clone();
340+
match &first_arg {
341+
Timestamp(Nanosecond, timezone) => {
342+
Ok(vec![Timestamp(Nanosecond, timezone.clone())])
343+
}
344+
Timestamp(Microsecond, timezone) => {
345+
Ok(vec![Timestamp(Microsecond, timezone.clone())])
346+
}
347+
Timestamp(Millisecond, timezone) => {
348+
Ok(vec![Timestamp(Millisecond, timezone.clone())])
349+
}
350+
Timestamp(Second, timezone) => Ok(vec![Timestamp(Second, timezone.clone())]),
351+
_ => plan_err!("The to_local_time function can only accept Timestamp as the arg got {first_arg}"),
352+
}
353+
}
351354
}
352355

353356
#[cfg(test)]

datafusion/functions/src/encoding/inner.rs

Lines changed: 58 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ use datafusion_expr::ColumnarValue;
3232
use std::sync::Arc;
3333
use std::{fmt, str::FromStr};
3434

35-
use datafusion_expr::TypeSignature::*;
3635
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
3736
use std::any::Any;
3837

@@ -49,17 +48,8 @@ impl Default for EncodeFunc {
4948

5049
impl EncodeFunc {
5150
pub fn new() -> Self {
52-
use DataType::*;
5351
Self {
54-
signature: Signature::one_of(
55-
vec![
56-
Exact(vec![Utf8, Utf8]),
57-
Exact(vec![LargeUtf8, Utf8]),
58-
Exact(vec![Binary, Utf8]),
59-
Exact(vec![LargeBinary, Utf8]),
60-
],
61-
Volatility::Immutable,
62-
),
52+
signature: Signature::user_defined(Volatility::Immutable),
6353
}
6454
}
6555
}
@@ -77,23 +67,39 @@ impl ScalarUDFImpl for EncodeFunc {
7767
}
7868

7969
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
80-
use DataType::*;
81-
82-
Ok(match arg_types[0] {
83-
Utf8 => Utf8,
84-
LargeUtf8 => LargeUtf8,
85-
Binary => Utf8,
86-
LargeBinary => LargeUtf8,
87-
Null => Null,
88-
_ => {
89-
return plan_err!("The encode function can only accept utf8 or binary.");
90-
}
91-
})
70+
Ok(arg_types[0].to_owned())
9271
}
9372

9473
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
9574
encode(args)
9675
}
76+
77+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
78+
if arg_types.len() != 2 {
79+
return plan_err!(
80+
"{} expects to get 2 arguments, but got {}",
81+
self.name(),
82+
arg_types.len()
83+
);
84+
}
85+
86+
if arg_types[1] != DataType::Utf8 {
87+
return Err(DataFusionError::Plan("2nd argument should be Utf8".into()));
88+
}
89+
90+
match arg_types[0] {
91+
DataType::Utf8 | DataType::Binary | DataType::Null => {
92+
Ok(vec![DataType::Utf8; 2])
93+
}
94+
DataType::LargeUtf8 | DataType::LargeBinary => {
95+
Ok(vec![DataType::LargeUtf8, DataType::Utf8])
96+
}
97+
_ => plan_err!(
98+
"1st argument should be Utf8 or Binary or Null, got {:?}",
99+
arg_types[0]
100+
),
101+
}
102+
}
97103
}
98104

99105
#[derive(Debug)]
@@ -109,17 +115,8 @@ impl Default for DecodeFunc {
109115

110116
impl DecodeFunc {
111117
pub fn new() -> Self {
112-
use DataType::*;
113118
Self {
114-
signature: Signature::one_of(
115-
vec![
116-
Exact(vec![Utf8, Utf8]),
117-
Exact(vec![LargeUtf8, Utf8]),
118-
Exact(vec![Binary, Utf8]),
119-
Exact(vec![LargeBinary, Utf8]),
120-
],
121-
Volatility::Immutable,
122-
),
119+
signature: Signature::user_defined(Volatility::Immutable),
123120
}
124121
}
125122
}
@@ -137,23 +134,39 @@ impl ScalarUDFImpl for DecodeFunc {
137134
}
138135

139136
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
140-
use DataType::*;
141-
142-
Ok(match arg_types[0] {
143-
Utf8 => Binary,
144-
LargeUtf8 => LargeBinary,
145-
Binary => Binary,
146-
LargeBinary => LargeBinary,
147-
Null => Null,
148-
_ => {
149-
return plan_err!("The decode function can only accept utf8 or binary.");
150-
}
151-
})
137+
Ok(arg_types[0].to_owned())
152138
}
153139

154140
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
155141
decode(args)
156142
}
143+
144+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
145+
if arg_types.len() != 2 {
146+
return plan_err!(
147+
"{} expects to get 2 arguments, but got {}",
148+
self.name(),
149+
arg_types.len()
150+
);
151+
}
152+
153+
if arg_types[1] != DataType::Utf8 {
154+
return plan_err!("2nd argument should be Utf8");
155+
}
156+
157+
match arg_types[0] {
158+
DataType::Utf8 | DataType::Binary | DataType::Null => {
159+
Ok(vec![DataType::Binary, DataType::Utf8])
160+
}
161+
DataType::LargeUtf8 | DataType::LargeBinary => {
162+
Ok(vec![DataType::LargeBinary, DataType::Utf8])
163+
}
164+
_ => plan_err!(
165+
"1st argument should be Utf8 or Binary or Null, got {:?}",
166+
arg_types[0]
167+
),
168+
}
169+
}
157170
}
158171

159172
#[derive(Debug, Copy, Clone)]

0 commit comments

Comments
 (0)