From 8ea5da59af8fa5a958bd94a9c2c9dd9a42f95776 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sat, 14 Dec 2024 19:41:27 +0200 Subject: [PATCH 01/14] start adding least fn --- datafusion/functions/src/core/least.rs | 272 +++++++++++++++++++++++++ datafusion/functions/src/core/mod.rs | 7 + 2 files changed, 279 insertions(+) create mode 100644 datafusion/functions/src/core/least.rs diff --git a/datafusion/functions/src/core/least.rs b/datafusion/functions/src/core/least.rs new file mode 100644 index 000000000000..3ea2eadf22ee --- /dev/null +++ b/datafusion/functions/src/core/least.rs @@ -0,0 +1,272 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{make_comparator, Array, ArrayRef, BooleanArray}; +use arrow::compute::kernels::cmp; +use arrow::compute::kernels::zip::zip; +use arrow::compute::SortOptions; +use arrow::datatypes::DataType; +use arrow_buffer::BooleanBuffer; +use datafusion_common::{exec_err, plan_err, Result, ScalarValue}; +use datafusion_doc::Documentation; +use datafusion_expr::binary::type_union_resolution; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::{Arc, OnceLock}; + +const SORT_OPTIONS: SortOptions = SortOptions { + // We want greatest first + descending: false, + + // NULL will be less than any other value + nulls_first: true, +}; + +#[derive(Debug)] +pub struct GreatestFunc { + signature: Signature, +} + +impl Default for GreatestFunc { + fn default() -> Self { + GreatestFunc::new() + } +} + +impl GreatestFunc { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +fn get_logical_null_count(arr: &dyn Array) -> usize { + arr.logical_nulls() + .map(|n| n.null_count()) + .unwrap_or_default() +} + +/// Return boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array +/// Nulls are always considered smaller than any other value +fn get_larger(lhs: &dyn Array, rhs: &dyn Array) -> Result { + // Fast path: + // If both arrays are not nested, have the same length and no nulls, we can use the faster vectorised kernel + // - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined. 
+ // - both array does not have any nulls: cmp::gt_eq will return null if any of the input is null while we want to return false in that case + if !lhs.data_type().is_nested() + && get_logical_null_count(lhs) == 0 + && get_logical_null_count(rhs) == 0 + { + return cmp::gt_eq(&lhs, &rhs).map_err(|e| e.into()); + } + + let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?; + + if lhs.len() != rhs.len() { + return exec_err!( + "All arrays should have the same length for greatest comparison" + ); + } + + let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_ge()); + + // No nulls as we only want to keep the values that are larger, its either true or false + Ok(BooleanArray::new(values, None)) +} + +/// Return array where the largest value at each index is kept +fn keep_larger(lhs: ArrayRef, rhs: ArrayRef) -> Result { + // True for values that we should keep from the left array + let keep_lhs = get_larger(lhs.as_ref(), rhs.as_ref())?; + + let larger = zip(&keep_lhs, &lhs, &rhs)?; + + Ok(larger) +} + +fn keep_larger_scalar<'a>( + lhs: &'a ScalarValue, + rhs: &'a ScalarValue, +) -> Result<&'a ScalarValue> { + if !lhs.data_type().is_nested() { + return if lhs >= rhs { Ok(lhs) } else { Ok(rhs) }; + } + + // If complex type we can't compare directly as we want null values to be smaller + let cmp = make_comparator( + lhs.to_array()?.as_ref(), + rhs.to_array()?.as_ref(), + SORT_OPTIONS, + )?; + + if cmp(0, 0).is_ge() { + Ok(lhs) + } else { + Ok(rhs) + } +} + +fn find_coerced_type(data_types: &[DataType]) -> Result { + if data_types.is_empty() { + plan_err!("greatest was called without any arguments. 
It requires at least 1.") + } else if let Some(coerced_type) = type_union_resolution(data_types) { + Ok(coerced_type) + } else { + plan_err!("Cannot find a common type for arguments") + } +} + +impl ScalarUDFImpl for GreatestFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "greatest" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + if args.is_empty() { + return exec_err!( + "greatest was called with no arguments. It requires at least 1." + ); + } + + // Some engines (e.g. SQL Server) allow greatest with single arg, it's a noop + if args.len() == 1 { + return Ok(args[0].clone()); + } + + // Split to scalars and arrays for later optimization + let (scalars, arrays): (Vec<_>, Vec<_>) = args.iter().partition(|x| match x { + ColumnarValue::Scalar(_) => true, + ColumnarValue::Array(_) => false, + }); + + let mut arrays_iter = arrays.iter().map(|x| match x { + ColumnarValue::Array(a) => a, + _ => unreachable!(), + }); + + let first_array = arrays_iter.next(); + + let mut largest: ArrayRef; + + // Optimization: merge all scalars into one to avoid recomputing + if !scalars.is_empty() { + let mut scalars_iter = scalars.iter().map(|x| match x { + ColumnarValue::Scalar(s) => s, + _ => unreachable!(), + }); + + // We have at least one scalar + let mut largest_scalar = scalars_iter.next().unwrap(); + + for scalar in scalars_iter { + largest_scalar = keep_larger_scalar(largest_scalar, scalar)?; + } + + // If we only have scalars, return the largest one + if arrays.is_empty() { + return Ok(ColumnarValue::Scalar(largest_scalar.clone())); + } + + // We have at least one array + let first_array = first_array.unwrap(); + + // Start with the largest value + largest = keep_larger( + Arc::clone(first_array), + largest_scalar.to_array_of_size(first_array.len())?, + )?; + } else { + // If we 
only have arrays, start with the first array + // (We must have at least one array) + largest = Arc::clone(first_array.unwrap()); + } + + for array in arrays_iter { + largest = keep_larger(Arc::clone(array), largest)?; + } + + Ok(ColumnarValue::Array(largest)) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let coerced_type = find_coerced_type(arg_types)?; + + Ok(vec![coerced_type; arg_types.len()]) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_greatest_doc()) + } +} +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_greatest_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder( + DOC_SECTION_CONDITIONAL, + "Returns the greatest value in a list of expressions. Returns _null_ if all expressions are _null_.", + "greatest(expression1[, ..., expression_n])") + .with_sql_example(r#"```sql +> select greatest(4, 7, 5); ++---------------------------+ +| greatest(4,7,5) | ++---------------------------+ +| 7 | ++---------------------------+ +```"#, + ) + .with_argument( + "expression1, expression_n", + "Expressions to compare and return the greatest value.. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary." 
+ ) + .build() + }) +} + +#[cfg(test)] +mod test { + use crate::core; + use arrow::datatypes::DataType; + use datafusion_expr::ScalarUDFImpl; + + #[test] + fn test_greatest_return_types_without_common_supertype_in_arg_type() { + let greatest = core::greatest::GreatestFunc::new(); + let return_type = greatest + .coerce_types(&[DataType::Decimal128(10, 3), DataType::Decimal128(10, 4)]) + .unwrap(); + assert_eq!( + return_type, + vec![DataType::Decimal128(11, 4), DataType::Decimal128(11, 4)] + ); + } +} diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index bd8305cd56d8..12a55590ce97 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -33,6 +33,7 @@ pub mod nvl2; pub mod planner; pub mod r#struct; pub mod version; +pub mod least; // create UDFs make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast); @@ -45,6 +46,7 @@ make_udf_function!(named_struct::NamedStructFunc, named_struct); make_udf_function!(getfield::GetFieldFunc, get_field); make_udf_function!(coalesce::CoalesceFunc, coalesce); make_udf_function!(greatest::GreatestFunc, greatest); +make_udf_function!(least::GreatestFunc, least); make_udf_function!(version::VersionFunc, version); pub mod expr_fn { @@ -86,6 +88,10 @@ pub mod expr_fn { greatest, "Returns `greatest(args...)`, which evaluates to the greatest value in the list of expressions or NULL if all the expressions are NULL", args, + ),( + least, + "Returns `least(args...)`, which evaluates to the smallest value in the list of expressions or NULL if all the expressions are NULL", + args, )); #[doc = "Returns the value of the field with the given name from the struct"] @@ -113,6 +119,7 @@ pub fn functions() -> Vec> { get_field(), coalesce(), greatest(), + least(), version(), r#struct(), ] From 3c6e133c35ad015f5e5c6cc1a7ef334b97b20e5d Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sat, 14 Dec 2024 19:46:52 +0200 Subject: [PATCH 
02/14] feat(function): add least function --- datafusion/functions/src/core/greatest.rs | 2 +- datafusion/functions/src/core/least.rs | 106 +++++++++++----------- datafusion/functions/src/core/mod.rs | 2 +- 3 files changed, 55 insertions(+), 55 deletions(-) diff --git a/datafusion/functions/src/core/greatest.rs b/datafusion/functions/src/core/greatest.rs index 3ea2eadf22ee..54248c311149 100644 --- a/datafusion/functions/src/core/greatest.rs +++ b/datafusion/functions/src/core/greatest.rs @@ -177,7 +177,7 @@ impl ScalarUDFImpl for GreatestFunc { let mut largest: ArrayRef; - // Optimization: merge all scalars into one to avoid recomputing + // Optimization: merge all scalars into one to avoid recomputing (constant folding) if !scalars.is_empty() { let mut scalars_iter = scalars.iter().map(|x| match x { ColumnarValue::Scalar(s) => s, diff --git a/datafusion/functions/src/core/least.rs b/datafusion/functions/src/core/least.rs index 3ea2eadf22ee..4cb8a92a5c76 100644 --- a/datafusion/functions/src/core/least.rs +++ b/datafusion/functions/src/core/least.rs @@ -31,25 +31,25 @@ use std::any::Any; use std::sync::{Arc, OnceLock}; const SORT_OPTIONS: SortOptions = SortOptions { - // We want greatest first - descending: false, + // We want least first + descending: true, - // NULL will be less than any other value - nulls_first: true, + // NULL will be greater than any other value + nulls_first: false, }; #[derive(Debug)] -pub struct GreatestFunc { +pub struct LeastFunc { signature: Signature, } -impl Default for GreatestFunc { +impl Default for LeastFunc { fn default() -> Self { - GreatestFunc::new() + LeastFunc::new() } } -impl GreatestFunc { +impl LeastFunc { pub fn new() -> Self { Self { signature: Signature::user_defined(Volatility::Immutable), @@ -63,60 +63,60 @@ fn get_logical_null_count(arr: &dyn Array) -> usize { .unwrap_or_default() } -/// Return boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array -/// Nulls are always considered 
smaller than any other value -fn get_larger(lhs: &dyn Array, rhs: &dyn Array) -> Result { +/// Return boolean array where `arr[i] = lhs[i] <= rhs[i]` for all i, where `arr` is the result array +/// Nulls are always considered larger than any other value +fn get_smallest(lhs: &dyn Array, rhs: &dyn Array) -> Result { // Fast path: // If both arrays are not nested, have the same length and no nulls, we can use the faster vectorised kernel // - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined. - // - both array does not have any nulls: cmp::gt_eq will return null if any of the input is null while we want to return false in that case + // - both array does not have any nulls: cmp::lt_eq will return null if any of the input is null while we want to return false in that case if !lhs.data_type().is_nested() && get_logical_null_count(lhs) == 0 && get_logical_null_count(rhs) == 0 { - return cmp::gt_eq(&lhs, &rhs).map_err(|e| e.into()); + return cmp::lt_eq(&lhs, &rhs).map_err(|e| e.into()); } let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?; if lhs.len() != rhs.len() { return exec_err!( - "All arrays should have the same length for greatest comparison" + "All arrays should have the same length for least comparison" ); } - let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_ge()); + let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_le()); - // No nulls as we only want to keep the values that are larger, its either true or false + // No nulls as we only want to keep the values that are smaller, its either true or false Ok(BooleanArray::new(values, None)) } -/// Return array where the largest value at each index is kept -fn keep_larger(lhs: ArrayRef, rhs: ArrayRef) -> Result { +/// Return array where the smallest value at each index is kept +fn keep_smallest(lhs: ArrayRef, rhs: ArrayRef) -> Result { // True for values that we should keep from the left array - let 
keep_lhs = get_larger(lhs.as_ref(), rhs.as_ref())?; + let keep_lhs = get_smallest(lhs.as_ref(), rhs.as_ref())?; - let larger = zip(&keep_lhs, &lhs, &rhs)?; + let smaller = zip(&keep_lhs, &lhs, &rhs)?; - Ok(larger) + Ok(smaller) } -fn keep_larger_scalar<'a>( +fn keep_smaller_scalar<'a>( lhs: &'a ScalarValue, rhs: &'a ScalarValue, ) -> Result<&'a ScalarValue> { if !lhs.data_type().is_nested() { - return if lhs >= rhs { Ok(lhs) } else { Ok(rhs) }; + return if lhs <= rhs { Ok(lhs) } else { Ok(rhs) }; } - // If complex type we can't compare directly as we want null values to be smaller + // If complex type we can't compare directly as we want null values to be larger let cmp = make_comparator( lhs.to_array()?.as_ref(), rhs.to_array()?.as_ref(), SORT_OPTIONS, )?; - if cmp(0, 0).is_ge() { + if cmp(0, 0).is_le() { Ok(lhs) } else { Ok(rhs) @@ -125,7 +125,7 @@ fn keep_larger_scalar<'a>( fn find_coerced_type(data_types: &[DataType]) -> Result { if data_types.is_empty() { - plan_err!("greatest was called without any arguments. It requires at least 1.") + plan_err!("least was called without any arguments. It requires at least 1.") } else if let Some(coerced_type) = type_union_resolution(data_types) { Ok(coerced_type) } else { @@ -133,7 +133,7 @@ fn find_coerced_type(data_types: &[DataType]) -> Result { } } -impl ScalarUDFImpl for GreatestFunc { +impl ScalarUDFImpl for LeastFunc { fn as_any(&self) -> &dyn Any { self } @@ -153,11 +153,11 @@ impl ScalarUDFImpl for GreatestFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { if args.is_empty() { return exec_err!( - "greatest was called with no arguments. It requires at least 1." + "least was called with no arguments. It requires at least 1." ); } - // Some engines (e.g. SQL Server) allow greatest with single arg, it's a noop + // Some engines (e.g. 
SQL Server) allow least with single arg, it's a noop if args.len() == 1 { return Ok(args[0].clone()); } @@ -175,9 +175,9 @@ impl ScalarUDFImpl for GreatestFunc { let first_array = arrays_iter.next(); - let mut largest: ArrayRef; + let mut smallest: ArrayRef; - // Optimization: merge all scalars into one to avoid recomputing + // Optimization: merge all scalars into one to avoid recomputing (constant folding) if !scalars.is_empty() { let mut scalars_iter = scalars.iter().map(|x| match x { ColumnarValue::Scalar(s) => s, @@ -185,36 +185,36 @@ impl ScalarUDFImpl for GreatestFunc { }); // We have at least one scalar - let mut largest_scalar = scalars_iter.next().unwrap(); + let mut smallest_scalar = scalars_iter.next().unwrap(); for scalar in scalars_iter { - largest_scalar = keep_larger_scalar(largest_scalar, scalar)?; + smallest_scalar = keep_smaller_scalar(smallest_scalar, scalar)?; } - // If we only have scalars, return the largest one + // If we only have scalars, return the smaller one if arrays.is_empty() { - return Ok(ColumnarValue::Scalar(largest_scalar.clone())); + return Ok(ColumnarValue::Scalar(smallest_scalar.clone())); } // We have at least one array let first_array = first_array.unwrap(); - // Start with the largest value - largest = keep_larger( + // Start with the smaller value + smallest = keep_smallest( Arc::clone(first_array), - largest_scalar.to_array_of_size(first_array.len())?, + smallest_scalar.to_array_of_size(first_array.len())?, )?; } else { // If we only have arrays, start with the first array // (We must have at least one array) - largest = Arc::clone(first_array.unwrap()); + smallest = Arc::clone(first_array.unwrap()); } for array in arrays_iter { - largest = keep_larger(Arc::clone(array), largest)?; + smallest = keep_smallest(Arc::clone(array), smallest)?; } - Ok(ColumnarValue::Array(largest)) + Ok(ColumnarValue::Array(smallest)) } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { @@ -224,29 +224,29 @@ impl ScalarUDFImpl for 
GreatestFunc { } fn documentation(&self) -> Option<&Documentation> { - Some(get_greatest_doc()) + Some(get_smallest_doc()) } } static DOCUMENTATION: OnceLock = OnceLock::new(); -fn get_greatest_doc() -> &'static Documentation { +fn get_smallest_doc() -> &'static Documentation { DOCUMENTATION.get_or_init(|| { Documentation::builder( DOC_SECTION_CONDITIONAL, - "Returns the greatest value in a list of expressions. Returns _null_ if all expressions are _null_.", - "greatest(expression1[, ..., expression_n])") + "Returns the smallest value in a list of expressions. Returns _null_ if all expressions are _null_.", + "least(expression1[, ..., expression_n])") .with_sql_example(r#"```sql -> select greatest(4, 7, 5); +> select least(4, 7, 5); +---------------------------+ -| greatest(4,7,5) | +| least(4,7,5) | +---------------------------+ -| 7 | +| 4 | +---------------------------+ ```"#, ) .with_argument( "expression1, expression_n", - "Expressions to compare and return the greatest value.. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary." + "Expressions to compare and return the smallest value.. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary." 
) .build() }) @@ -254,14 +254,14 @@ fn get_greatest_doc() -> &'static Documentation { #[cfg(test)] mod test { - use crate::core; use arrow::datatypes::DataType; use datafusion_expr::ScalarUDFImpl; + use crate::core::least::LeastFunc; #[test] - fn test_greatest_return_types_without_common_supertype_in_arg_type() { - let greatest = core::greatest::GreatestFunc::new(); - let return_type = greatest + fn test_least_return_types_without_common_supertype_in_arg_type() { + let least = LeastFunc::new(); + let return_type = least .coerce_types(&[DataType::Decimal128(10, 3), DataType::Decimal128(10, 4)]) .unwrap(); assert_eq!( diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 12a55590ce97..8c84a003ff33 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -46,7 +46,7 @@ make_udf_function!(named_struct::NamedStructFunc, named_struct); make_udf_function!(getfield::GetFieldFunc, get_field); make_udf_function!(coalesce::CoalesceFunc, coalesce); make_udf_function!(greatest::GreatestFunc, greatest); -make_udf_function!(least::GreatestFunc, least); +make_udf_function!(least::LeastFunc, least); make_udf_function!(version::VersionFunc, version); pub mod expr_fn { From 991b17a64e62fdd61b2af3b57a77192c7cc8e990 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sat, 14 Dec 2024 19:59:46 +0200 Subject: [PATCH 03/14] update function name --- datafusion/functions/src/core/least.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions/src/core/least.rs b/datafusion/functions/src/core/least.rs index 4cb8a92a5c76..3d1ae18f2c0c 100644 --- a/datafusion/functions/src/core/least.rs +++ b/datafusion/functions/src/core/least.rs @@ -139,7 +139,7 @@ impl ScalarUDFImpl for LeastFunc { } fn name(&self) -> &str { - "greatest" + "least" } fn signature(&self) -> &Signature { From e1bf7ae6fe2eb09d0ae0ef8c41dace855997d4a3 Mon Sep 17 00:00:00 2001 
From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sat, 14 Dec 2024 20:17:57 +0200 Subject: [PATCH 04/14] fix scalar smaller function --- datafusion/functions/src/core/least.rs | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/datafusion/functions/src/core/least.rs b/datafusion/functions/src/core/least.rs index 3d1ae18f2c0c..ed1d719476a1 100644 --- a/datafusion/functions/src/core/least.rs +++ b/datafusion/functions/src/core/least.rs @@ -32,7 +32,7 @@ use std::sync::{Arc, OnceLock}; const SORT_OPTIONS: SortOptions = SortOptions { // We want least first - descending: true, + descending: false, // NULL will be greater than any other value nulls_first: false, @@ -105,10 +105,23 @@ fn keep_smaller_scalar<'a>( lhs: &'a ScalarValue, rhs: &'a ScalarValue, ) -> Result<&'a ScalarValue> { + // Manual checking for nulls as: + // 1. If we're going to use <=, in Rust None is smaller than Some(T), which we don't want + // 2. And we can't use make_comparator as it has no natural order (Arrow error) + if lhs.is_null() { + return Ok(rhs); + } + + if rhs.is_null() { + return Ok(lhs); + } + if !lhs.data_type().is_nested() { return if lhs <= rhs { Ok(lhs) } else { Ok(rhs) }; } + // Not using <= as in Rust None is smaller than Some(T) + // If complex type we can't compare directly as we want null values to be larger let cmp = make_comparator( lhs.to_array()?.as_ref(), @@ -162,7 +175,7 @@ impl ScalarUDFImpl for LeastFunc { return Ok(args[0].clone()); } - // Split to scalars and arrays for later optimization + // Split to scalars and arrays for later optimization (constant folding) let (scalars, arrays): (Vec<_>, Vec<_>) = args.iter().partition(|x| match x { ColumnarValue::Scalar(_) => true, ColumnarValue::Array(_) => false, From d96f5461e38e7d59b92b6c58dede12b67329dcdc Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sat, 14 Dec 2024 20:18:11 +0200 Subject: [PATCH 05/14] add tests --- 
.../sqllogictest/test_files/functions.slt | 198 ++++++++++++++++++ 1 file changed, 198 insertions(+) diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 4b770a19fe20..4213de0235e4 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -955,3 +955,201 @@ Infinity statement ok drop table t1 + +# test for least +statement ok +CREATE TABLE t1 (a int, b int, c int) as VALUES +(4, NULL, NULL), +(1, 2, 3), +(3, 1, 2), +(1, NULL, -1), +(NULL, NULL, NULL), +(3, 0, -1); + +query I +SELECT least(a, b, c) FROM t1 +---- +4 +1 +1 +-1 +NULL +-1 + +statement ok +drop table t1 + +query I +SELECT least(1) +---- +1 + +query I +SELECT least(1, 2) +---- +1 + +query I +SELECT least(3, 1) +---- +1 + +query ? +SELECT least(NULL) +---- +NULL + +query I +SELECT least(1, NULL, -1) +---- +-1 + +query I +SELECT least((3), (0), (-1)); +---- +-1 + +query ? +SELECT least([4, 3], [4, 2], [4, 4]); +---- +[4, 2] + +query ? +SELECT least([2, 3], [1, 4], [5, 0]); +---- +[1, 4] + +query I +SELECT least(1::int, 2::text) +---- +1 + +query R +SELECT least(-1, 1, 2.3, 123456789, 3 + 5, -(-4)) +---- +-1 + +query R +SELECT least(-1.123, 1.21313, 2.3, 123456789.321, 3 + 5.3213, -(-4.3213), abs(-9)) +---- +-1.123 + +query R +SELECT least(-1, 1, 2.3, 123456789, 3 + 5, -(-4), abs(-9.0)) +---- +-1 + + +query error least does not support zero arguments +SELECT least() + +query I +SELECT least(4, 5, 7, 1, 2) +---- +1 + +query I +SELECT least(4, NULL, 7, 1, 2) +---- +1 + +query I +SELECT least(NULL, NULL, 7, NULL, 2) +---- +2 + +query I +SELECT least(NULL, NULL, NULL, NULL, 2) +---- +2 + +query I +SELECT least(2, NULL, NULL, NULL, NULL) +---- +2 + +query ? 
+SELECT least(NULL, NULL, NULL) +---- +NULL + +query I +SELECT least(2, '4') +---- +2 + +query T +SELECT least('foo', 'bar', 'foobar') +---- +bar + +query R +SELECT least(1, 1.2) +---- +1 + +statement ok +CREATE TABLE foo (a int) + +statement ok +INSERT INTO foo (a) VALUES (1) + +# Test homogenous functions that can't be constant folded. +query I +SELECT least(NULL, a, 5, NULL) FROM foo +---- +1 + +query I +SELECT least(NULL, NULL, NULL, a, -1) FROM foo +---- +-1 + +statement ok +drop table foo + +query R +select least(arrow_cast('NAN','Float64'), arrow_cast('NAN','Float64')) +---- +NaN + +query R +select least(arrow_cast('NAN','Float64'), arrow_cast('NAN','Float32')) +---- +NaN + +query R +select least(arrow_cast('NAN','Float64'), '+Inf'::Double) +---- +Infinity + +query R +select least(arrow_cast('NAN','Float64'), NULL) +---- +NaN + +query R +select least(NULL, '+Inf'::Double) +---- +Infinity + +query R +select least(NULL, '-Inf'::Double) +---- +-Infinity + +statement ok +CREATE TABLE t1 (a double, b double, c double) as VALUES +(1, arrow_cast('NAN', 'Float64'), '+Inf'::Double), +(NULL, arrow_cast('NAN','Float64'), '+Inf'::Double), +(1, '+Inf'::Double, NULL); + +query R +SELECT least(a, b, c) FROM t1 +---- +1 +Infinity +1 + +statement ok +drop table t1 From bfa1abc323717f1f55e39d839295ae629706838b Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sat, 14 Dec 2024 20:19:38 +0200 Subject: [PATCH 06/14] run Clippy and Fmt --- datafusion/functions/src/core/least.rs | 6 ++---- datafusion/functions/src/core/mod.rs | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/datafusion/functions/src/core/least.rs b/datafusion/functions/src/core/least.rs index ed1d719476a1..f17eb0462234 100644 --- a/datafusion/functions/src/core/least.rs +++ b/datafusion/functions/src/core/least.rs @@ -80,9 +80,7 @@ fn get_smallest(lhs: &dyn Array, rhs: &dyn Array) -> Result { let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?; if 
lhs.len() != rhs.len() { - return exec_err!( - "All arrays should have the same length for least comparison" - ); + return exec_err!("All arrays should have the same length for least comparison"); } let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_le()); @@ -267,9 +265,9 @@ fn get_smallest_doc() -> &'static Documentation { #[cfg(test)] mod test { + use crate::core::least::LeastFunc; use arrow::datatypes::DataType; use datafusion_expr::ScalarUDFImpl; - use crate::core::least::LeastFunc; #[test] fn test_least_return_types_without_common_supertype_in_arg_type() { diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 8c84a003ff33..ee4fc22ff48b 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -26,6 +26,7 @@ pub mod coalesce; pub mod expr_ext; pub mod getfield; pub mod greatest; +pub mod least; pub mod named_struct; pub mod nullif; pub mod nvl; @@ -33,7 +34,6 @@ pub mod nvl2; pub mod planner; pub mod r#struct; pub mod version; -pub mod least; // create UDFs make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast); From 25af65dfb314ea79c4677745e0fc1e9f306baeb8 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sat, 14 Dec 2024 20:28:59 +0200 Subject: [PATCH 07/14] Generated docs using `./dev/update_function_docs.sh` --- .../source/user-guide/sql/scalar_functions.md | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 4e74cfc54ae5..426835ed819d 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -549,6 +549,7 @@ trunc(numeric_expression[, decimal_places]) - [coalesce](#coalesce) - [greatest](#greatest) - [ifnull](#ifnull) +- [least](#least) - [nullif](#nullif) - [nvl](#nvl) - [nvl2](#nvl2) @@ -603,6 +604,29 @@ greatest(expression1[, ..., 
expression_n]) _Alias of [nvl](#nvl)._ +### `least` + +Returns the smallest value in a list of expressions. Returns _null_ if all expressions are _null_. + +``` +least(expression1[, ..., expression_n]) +``` + +#### Arguments + +- **expression1, expression_n**: Expressions to compare and return the smallest value.. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary. + +#### Example + +```sql +> select least(4, 7, 5); ++---------------------------+ +| least(4,7,5) | ++---------------------------+ +| 4 | ++---------------------------+ +``` + ### `nullif` Returns _null_ if _expression1_ equals _expression2_; otherwise it returns _expression1_. From 3ad250c562d1ed1ee641b8e13be094cec3eafb8d Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sat, 14 Dec 2024 20:36:49 +0200 Subject: [PATCH 08/14] add comment why `descending: false` --- datafusion/functions/src/core/least.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions/src/core/least.rs b/datafusion/functions/src/core/least.rs index f17eb0462234..aa207247d050 100644 --- a/datafusion/functions/src/core/least.rs +++ b/datafusion/functions/src/core/least.rs @@ -31,7 +31,7 @@ use std::any::Any; use std::sync::{Arc, OnceLock}; const SORT_OPTIONS: SortOptions = SortOptions { - // We want least first + // Decreasing here as we will use lower than or equal to find the smallest value descending: false, // NULL will be greater than any other value From 412dec6677029ceb5ae16d85a4cb63ba31a825c1 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sat, 14 Dec 2024 20:40:10 +0200 Subject: [PATCH 09/14] update comment --- datafusion/functions/src/core/least.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions/src/core/least.rs b/datafusion/functions/src/core/least.rs index aa207247d050..6b6e99cc9dca 
100644 --- a/datafusion/functions/src/core/least.rs +++ b/datafusion/functions/src/core/least.rs @@ -31,7 +31,7 @@ use std::any::Any; use std::sync::{Arc, OnceLock}; const SORT_OPTIONS: SortOptions = SortOptions { - // Decreasing here as we will use lower than or equal to find the smallest value + // Having the smallest result first descending: false, // NULL will be greater than any other value From 13862e420dfac124ccba62d6a6ab7e3424e70c43 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 15 Dec 2024 01:55:48 +0200 Subject: [PATCH 10/14] Update least.rs Co-authored-by: Bruce Ritchie --- datafusion/functions/src/core/least.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions/src/core/least.rs b/datafusion/functions/src/core/least.rs index 6b6e99cc9dca..ebba6ee38c5e 100644 --- a/datafusion/functions/src/core/least.rs +++ b/datafusion/functions/src/core/least.rs @@ -257,7 +257,7 @@ fn get_smallest_doc() -> &'static Documentation { ) .with_argument( "expression1, expression_n", - "Expressions to compare and return the smallest value.. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary." + "Expressions to compare and return the smallest value. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary." 
) .build() }) From 513348aed1869edce254236d1e524b36f2f20a1d Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 15 Dec 2024 13:22:59 +0200 Subject: [PATCH 11/14] Update scalar_functions.md --- docs/source/user-guide/sql/scalar_functions.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 426835ed819d..9111b550d7b9 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -549,7 +549,7 @@ trunc(numeric_expression[, decimal_places]) - [coalesce](#coalesce) - [greatest](#greatest) - [ifnull](#ifnull) -- [least](#least) +- [](#) - [nullif](#nullif) - [nvl](#nvl) - [nvl2](#nvl2) @@ -604,7 +604,7 @@ greatest(expression1[, ..., expression_n]) _Alias of [nvl](#nvl)._ -### `least` +### `` Returns the smallest value in a list of expressions. Returns _null_ if all expressions are _null_. @@ -614,7 +614,7 @@ least(expression1[, ..., expression_n]) #### Arguments -- **expression1, expression_n**: Expressions to compare and return the smallest value.. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary. +- **expression1, expression_n**: Expressions to compare and return the smallest value. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary. 
#### Example From a28d425e7d4f6e89aec0cbcbcbea15ac9eba1b89 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 19 Dec 2024 07:38:41 -0500 Subject: [PATCH 12/14] run ./dev/update_function_docs.sh to update docs --- docs/source/user-guide/sql/scalar_functions.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 9111b550d7b9..2e4147f96e0f 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -549,7 +549,7 @@ trunc(numeric_expression[, decimal_places]) - [coalesce](#coalesce) - [greatest](#greatest) - [ifnull](#ifnull) -- [](#) +- [least](#least) - [nullif](#nullif) - [nvl](#nvl) - [nvl2](#nvl2) @@ -604,7 +604,7 @@ greatest(expression1[, ..., expression_n]) _Alias of [nvl](#nvl)._ -### `` +### `least` Returns the smallest value in a list of expressions. Returns _null_ if all expressions are _null_. From b5e2ac0704429abdcc2a85ef905704625a9ac4c0 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Fri, 20 Dec 2024 15:45:33 +0200 Subject: [PATCH 13/14] merge greatest and least implementation to one --- datafusion/functions/src/core/greatest.rs | 183 +++++----------- .../src/core/greatest_least_utils.rs | 116 ++++++++++ datafusion/functions/src/core/least.rs | 203 ++++++------------ datafusion/functions/src/core/mod.rs | 1 + 4 files changed, 226 insertions(+), 277 deletions(-) create mode 100644 datafusion/functions/src/core/greatest_least_utils.rs diff --git a/datafusion/functions/src/core/greatest.rs b/datafusion/functions/src/core/greatest.rs index 54248c311149..e91ec2b0c4d8 100644 --- a/datafusion/functions/src/core/greatest.rs +++ b/datafusion/functions/src/core/greatest.rs @@ -15,20 +15,19 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::array::{make_comparator, Array, ArrayRef, BooleanArray}; +use crate::core::greatest_least_utils::GreatestLeastOperator; +use arrow::array::{make_comparator, Array, BooleanArray}; use arrow::compute::kernels::cmp; -use arrow::compute::kernels::zip::zip; use arrow::compute::SortOptions; use arrow::datatypes::DataType; use arrow_buffer::BooleanBuffer; -use datafusion_common::{exec_err, plan_err, Result, ScalarValue}; +use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_doc::Documentation; -use datafusion_expr::binary::type_union_resolution; use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::OnceLock; const SORT_OPTIONS: SortOptions = SortOptions { // We want greatest first @@ -57,79 +56,57 @@ impl GreatestFunc { } } -fn get_logical_null_count(arr: &dyn Array) -> usize { - arr.logical_nulls() - .map(|n| n.null_count()) - .unwrap_or_default() -} +impl GreatestLeastOperator for GreatestFunc { + const NAME: &'static str = "greatest"; -/// Return boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array -/// Nulls are always considered smaller than any other value -fn get_larger(lhs: &dyn Array, rhs: &dyn Array) -> Result { - // Fast path: - // If both arrays are not nested, have the same length and no nulls, we can use the faster vectorised kernel - // - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined. 
- // - both array does not have any nulls: cmp::gt_eq will return null if any of the input is null while we want to return false in that case - if !lhs.data_type().is_nested() - && get_logical_null_count(lhs) == 0 - && get_logical_null_count(rhs) == 0 - { - return cmp::gt_eq(&lhs, &rhs).map_err(|e| e.into()); - } + fn keep_scalar<'a>( + lhs: &'a ScalarValue, + rhs: &'a ScalarValue, + ) -> Result<&'a ScalarValue> { + if !lhs.data_type().is_nested() { + return if lhs >= rhs { Ok(lhs) } else { Ok(rhs) }; + } - let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?; + // If complex type we can't compare directly as we want null values to be smaller + let cmp = make_comparator( + lhs.to_array()?.as_ref(), + rhs.to_array()?.as_ref(), + SORT_OPTIONS, + )?; - if lhs.len() != rhs.len() { - return exec_err!( - "All arrays should have the same length for greatest comparison" - ); + if cmp(0, 0).is_ge() { + Ok(lhs) + } else { + Ok(rhs) + } } - let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_ge()); - - // No nulls as we only want to keep the values that are larger, its either true or false - Ok(BooleanArray::new(values, None)) -} - -/// Return array where the largest value at each index is kept -fn keep_larger(lhs: ArrayRef, rhs: ArrayRef) -> Result { - // True for values that we should keep from the left array - let keep_lhs = get_larger(lhs.as_ref(), rhs.as_ref())?; - - let larger = zip(&keep_lhs, &lhs, &rhs)?; + /// Return boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array + /// Nulls are always considered smaller than any other value + fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result { + // Fast path: + // If both arrays are not nested, have the same length and no nulls, we can use the faster vectorised kernel + // - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined. 
+ // - both array does not have any nulls: cmp::gt_eq will return null if any of the input is null while we want to return false in that case + if !lhs.data_type().is_nested() + && lhs.logical_null_count() == 0 + && rhs.logical_null_count() == 0 + { + return cmp::gt_eq(&lhs, &rhs).map_err(|e| e.into()); + } - Ok(larger) -} + let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?; -fn keep_larger_scalar<'a>( - lhs: &'a ScalarValue, - rhs: &'a ScalarValue, -) -> Result<&'a ScalarValue> { - if !lhs.data_type().is_nested() { - return if lhs >= rhs { Ok(lhs) } else { Ok(rhs) }; - } - - // If complex type we can't compare directly as we want null values to be smaller - let cmp = make_comparator( - lhs.to_array()?.as_ref(), - rhs.to_array()?.as_ref(), - SORT_OPTIONS, - )?; + if lhs.len() != rhs.len() { + return internal_err!( + "All arrays should have the same length for greatest comparison" + ); + } - if cmp(0, 0).is_ge() { - Ok(lhs) - } else { - Ok(rhs) - } -} + let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_ge()); -fn find_coerced_type(data_types: &[DataType]) -> Result { - if data_types.is_empty() { - plan_err!("greatest was called without any arguments. It requires at least 1.") - } else if let Some(coerced_type) = type_union_resolution(data_types) { - Ok(coerced_type) - } else { - plan_err!("Cannot find a common type for arguments") + // No nulls as we only want to keep the values that are larger, its either true or false + Ok(BooleanArray::new(values, None)) } } @@ -151,74 +128,12 @@ impl ScalarUDFImpl for GreatestFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - if args.is_empty() { - return exec_err!( - "greatest was called with no arguments. It requires at least 1." - ); - } - - // Some engines (e.g. 
SQL Server) allow greatest with single arg, it's a noop - if args.len() == 1 { - return Ok(args[0].clone()); - } - - // Split to scalars and arrays for later optimization - let (scalars, arrays): (Vec<_>, Vec<_>) = args.iter().partition(|x| match x { - ColumnarValue::Scalar(_) => true, - ColumnarValue::Array(_) => false, - }); - - let mut arrays_iter = arrays.iter().map(|x| match x { - ColumnarValue::Array(a) => a, - _ => unreachable!(), - }); - - let first_array = arrays_iter.next(); - - let mut largest: ArrayRef; - - // Optimization: merge all scalars into one to avoid recomputing (constant folding) - if !scalars.is_empty() { - let mut scalars_iter = scalars.iter().map(|x| match x { - ColumnarValue::Scalar(s) => s, - _ => unreachable!(), - }); - - // We have at least one scalar - let mut largest_scalar = scalars_iter.next().unwrap(); - - for scalar in scalars_iter { - largest_scalar = keep_larger_scalar(largest_scalar, scalar)?; - } - - // If we only have scalars, return the largest one - if arrays.is_empty() { - return Ok(ColumnarValue::Scalar(largest_scalar.clone())); - } - - // We have at least one array - let first_array = first_array.unwrap(); - - // Start with the largest value - largest = keep_larger( - Arc::clone(first_array), - largest_scalar.to_array_of_size(first_array.len())?, - )?; - } else { - // If we only have arrays, start with the first array - // (We must have at least one array) - largest = Arc::clone(first_array.unwrap()); - } - - for array in arrays_iter { - largest = keep_larger(Arc::clone(array), largest)?; - } - - Ok(ColumnarValue::Array(largest)) + super::greatest_least_utils::execute_conditional::<Self>(args) } fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { - let coerced_type = find_coerced_type(arg_types)?; + let coerced_type = + super::greatest_least_utils::find_coerced_type::<Self>(arg_types)?; Ok(vec![coerced_type; arg_types.len()]) } diff --git a/datafusion/functions/src/core/greatest_least_utils.rs
b/datafusion/functions/src/core/greatest_least_utils.rs new file mode 100644 index 000000000000..c051f425adbd --- /dev/null +++ b/datafusion/functions/src/core/greatest_least_utils.rs @@ -0,0 +1,116 @@ +use arrow::array::{Array, ArrayRef, BooleanArray}; +use arrow::compute::kernels::zip::zip; +use arrow::datatypes::DataType; +use datafusion_common::{internal_err, plan_err, Result, ScalarValue}; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::type_coercion::binary::type_union_resolution; +use std::sync::Arc; + +pub(super) trait GreatestLeastOperator { + const NAME: &'static str; + + fn keep_scalar<'a>( + lhs: &'a ScalarValue, + rhs: &'a ScalarValue, + ) -> Result<&'a ScalarValue>; + + /// Return array with true for values that we should keep from the lhs array + fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray>; +} + +fn keep_array<Op: GreatestLeastOperator>( + lhs: ArrayRef, + rhs: ArrayRef, +) -> Result<ArrayRef> { + // True for values that we should keep from the left array + let keep_lhs = Op::get_indexes_to_keep(lhs.as_ref(), rhs.as_ref())?; + + let result = zip(&keep_lhs, &lhs, &rhs)?; + + Ok(result) +} + +pub(super) fn execute_conditional<Op: GreatestLeastOperator>( + args: &[ColumnarValue], +) -> Result<ColumnarValue> { + if args.is_empty() { + return internal_err!( + "{} was called with no arguments. It requires at least 1.", + Op::NAME + ); + } + + // Some engines (e.g.
SQL Server) allow greatest/least with single arg, it's a noop + if args.len() == 1 { + return Ok(args[0].clone()); + } + + // Split to scalars and arrays for later optimization + let (scalars, arrays): (Vec<_>, Vec<_>) = args.iter().partition(|x| match x { + ColumnarValue::Scalar(_) => true, + ColumnarValue::Array(_) => false, + }); + + let mut arrays_iter = arrays.iter().map(|x| match x { + ColumnarValue::Array(a) => a, + _ => unreachable!(), + }); + + let first_array = arrays_iter.next(); + + let mut result: ArrayRef; + + // Optimization: merge all scalars into one to avoid recomputing (constant folding) + if !scalars.is_empty() { + let mut scalars_iter = scalars.iter().map(|x| match x { + ColumnarValue::Scalar(s) => s, + _ => unreachable!(), + }); + + // We have at least one scalar + let mut result_scalar = scalars_iter.next().unwrap(); + + for scalar in scalars_iter { + result_scalar = Op::keep_scalar(result_scalar, scalar)?; + } + + // If we only have scalars, return the one that we should keep (largest/least) + if arrays.is_empty() { + return Ok(ColumnarValue::Scalar(result_scalar.clone())); + } + + // We have at least one array + let first_array = first_array.unwrap(); + + // Start with the result value + result = keep_array::<Op>( + Arc::clone(first_array), + result_scalar.to_array_of_size(first_array.len())?, + )?; + } else { + // If we only have arrays, start with the first array + // (We must have at least one array) + result = Arc::clone(first_array.unwrap()); + } + + for array in arrays_iter { + result = keep_array::<Op>(Arc::clone(array), result)?; + } + + Ok(ColumnarValue::Array(result)) +} + +pub(super) fn find_coerced_type<Op: GreatestLeastOperator>( + data_types: &[DataType], +) -> Result<DataType> { + if data_types.is_empty() { + plan_err!( + "{} was called without any arguments.
It requires at least 1.", + Op::NAME + ) + } else if let Some(coerced_type) = type_union_resolution(data_types) { + Ok(coerced_type) + } else { + plan_err!("Cannot find a common type for arguments") + } +} diff --git a/datafusion/functions/src/core/least.rs b/datafusion/functions/src/core/least.rs index ebba6ee38c5e..b9ea65cdb732 100644 --- a/datafusion/functions/src/core/least.rs +++ b/datafusion/functions/src/core/least.rs @@ -15,20 +15,19 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{make_comparator, Array, ArrayRef, BooleanArray}; +use crate::core::greatest_least_utils::GreatestLeastOperator; +use arrow::array::{make_comparator, Array, BooleanArray}; use arrow::compute::kernels::cmp; -use arrow::compute::kernels::zip::zip; use arrow::compute::SortOptions; use arrow::datatypes::DataType; use arrow_buffer::BooleanBuffer; -use datafusion_common::{exec_err, plan_err, Result, ScalarValue}; +use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_doc::Documentation; -use datafusion_expr::binary::type_union_resolution; use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::OnceLock; const SORT_OPTIONS: SortOptions = SortOptions { // Having the smallest result first @@ -57,90 +56,70 @@ impl LeastFunc { } } -fn get_logical_null_count(arr: &dyn Array) -> usize { - arr.logical_nulls() - .map(|n| n.null_count()) - .unwrap_or_default() -} - -/// Return boolean array where `arr[i] = lhs[i] <= rhs[i]` for all i, where `arr` is the result array -/// Nulls are always considered larger than any other value -fn get_smallest(lhs: &dyn Array, rhs: &dyn Array) -> Result { - // Fast path: - // If both arrays are not nested, have the same length and no nulls, we can use the faster vectorised kernel - // - If both arrays are not nested: 
Nested types, such as lists, are not supported as the null semantics are not well-defined. - // - both array does not have any nulls: cmp::lt_eq will return null if any of the input is null while we want to return false in that case - if !lhs.data_type().is_nested() - && get_logical_null_count(lhs) == 0 - && get_logical_null_count(rhs) == 0 - { - return cmp::lt_eq(&lhs, &rhs).map_err(|e| e.into()); - } - - let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?; - - if lhs.len() != rhs.len() { - return exec_err!("All arrays should have the same length for least comparison"); - } +impl GreatestLeastOperator for LeastFunc { + const NAME: &'static str = "least"; + + fn keep_scalar<'a>( + lhs: &'a ScalarValue, + rhs: &'a ScalarValue, + ) -> Result<&'a ScalarValue> { + // Manual checking for nulls as: + // 1. If we're going to use <=, in Rust None is smaller than Some(T), which we don't want + // 2. And we can't use make_comparator as it has no natural order (Arrow error) + if lhs.is_null() { + return Ok(rhs); + } - let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_le()); + if rhs.is_null() { + return Ok(lhs); + } - // No nulls as we only want to keep the values that are smaller, its either true or false - Ok(BooleanArray::new(values, None)) -} + if !lhs.data_type().is_nested() { + return if lhs <= rhs { Ok(lhs) } else { Ok(rhs) }; + } -/// Return array where the smallest value at each index is kept -fn keep_smallest(lhs: ArrayRef, rhs: ArrayRef) -> Result { - // True for values that we should keep from the left array - let keep_lhs = get_smallest(lhs.as_ref(), rhs.as_ref())?; + // Not using <= as in Rust None is smaller than Some(T) - let smaller = zip(&keep_lhs, &lhs, &rhs)?; + // If complex type we can't compare directly as we want null values to be larger + let cmp = make_comparator( + lhs.to_array()?.as_ref(), + rhs.to_array()?.as_ref(), + SORT_OPTIONS, + )?; - Ok(smaller) -} - -fn keep_smaller_scalar<'a>( - lhs: &'a ScalarValue, - rhs: &'a 
ScalarValue, -) -> Result<&'a ScalarValue> { - // Manual checking for nulls as: - // 1. If we're going to use <=, in Rust None is smaller than Some(T), which we don't want - // 2. And we can't use make_comparator as it has no natural order (Arrow error) - if lhs.is_null() { - return Ok(rhs); + if cmp(0, 0).is_le() { + Ok(lhs) + } else { + Ok(rhs) + } } - if rhs.is_null() { - return Ok(lhs); - } + /// Return boolean array where `arr[i] = lhs[i] <= rhs[i]` for all i, where `arr` is the result array + /// Nulls are always considered larger than any other value + fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result { + // Fast path: + // If both arrays are not nested, have the same length and no nulls, we can use the faster vectorised kernel + // - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined. + // - both array does not have any nulls: cmp::lt_eq will return null if any of the input is null while we want to return false in that case + if !lhs.data_type().is_nested() + && lhs.logical_null_count() == 0 + && rhs.logical_null_count() == 0 + { + return cmp::lt_eq(&lhs, &rhs).map_err(|e| e.into()); + } - if !lhs.data_type().is_nested() { - return if lhs <= rhs { Ok(lhs) } else { Ok(rhs) }; - } + let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?; - // Not using <= as in Rust None is smaller than Some(T) + if lhs.len() != rhs.len() { + return internal_err!( + "All arrays should have the same length for least comparison" + ); + } - // If complex type we can't compare directly as we want null values to be larger - let cmp = make_comparator( - lhs.to_array()?.as_ref(), - rhs.to_array()?.as_ref(), - SORT_OPTIONS, - )?; + let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_le()); - if cmp(0, 0).is_le() { - Ok(lhs) - } else { - Ok(rhs) - } -} - -fn find_coerced_type(data_types: &[DataType]) -> Result { - if data_types.is_empty() { - plan_err!("least was called without any 
arguments. It requires at least 1.") - } else if let Some(coerced_type) = type_union_resolution(data_types) { - Ok(coerced_type) - } else { - plan_err!("Cannot find a common type for arguments") + // No nulls as we only want to keep the values that are smaller, its either true or false + Ok(BooleanArray::new(values, None)) } } @@ -162,74 +141,12 @@ impl ScalarUDFImpl for LeastFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - if args.is_empty() { - return exec_err!( - "least was called with no arguments. It requires at least 1." - ); - } - - // Some engines (e.g. SQL Server) allow least with single arg, it's a noop - if args.len() == 1 { - return Ok(args[0].clone()); - } - - // Split to scalars and arrays for later optimization (constant folding) - let (scalars, arrays): (Vec<_>, Vec<_>) = args.iter().partition(|x| match x { - ColumnarValue::Scalar(_) => true, - ColumnarValue::Array(_) => false, - }); - - let mut arrays_iter = arrays.iter().map(|x| match x { - ColumnarValue::Array(a) => a, - _ => unreachable!(), - }); - - let first_array = arrays_iter.next(); - - let mut smallest: ArrayRef; - - // Optimization: merge all scalars into one to avoid recomputing (constant folding) - if !scalars.is_empty() { - let mut scalars_iter = scalars.iter().map(|x| match x { - ColumnarValue::Scalar(s) => s, - _ => unreachable!(), - }); - - // We have at least one scalar - let mut smallest_scalar = scalars_iter.next().unwrap(); - - for scalar in scalars_iter { - smallest_scalar = keep_smaller_scalar(smallest_scalar, scalar)?; - } - - // If we only have scalars, return the smaller one - if arrays.is_empty() { - return Ok(ColumnarValue::Scalar(smallest_scalar.clone())); - } - - // We have at least one array - let first_array = first_array.unwrap(); - - // Start with the smaller value - smallest = keep_smallest( - Arc::clone(first_array), - smallest_scalar.to_array_of_size(first_array.len())?, - )?; - } else { - // If we only have arrays, start with the first array - // 
(We must have at least one array) - smallest = Arc::clone(first_array.unwrap()); - } - - for array in arrays_iter { - smallest = keep_smallest(Arc::clone(array), smallest)?; - } - - Ok(ColumnarValue::Array(smallest)) + super::greatest_least_utils::execute_conditional::<Self>(args) } fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { - let coerced_type = find_coerced_type(arg_types)?; + let coerced_type = + super::greatest_least_utils::find_coerced_type::<Self>(arg_types)?; Ok(vec![coerced_type; arg_types.len()]) } diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index ee4fc22ff48b..ba8255d2e472 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -26,6 +26,7 @@ pub mod coalesce; pub mod expr_ext; pub mod getfield; pub mod greatest; +mod greatest_least_utils; pub mod least; pub mod named_struct; pub mod nullif; From 881c54ed3388616b6866b6698fde94e206e01a40 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 20 Dec 2024 09:42:36 -0500 Subject: [PATCH 14/14] add header --- .../functions/src/core/greatest_least_utils.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/datafusion/functions/src/core/greatest_least_utils.rs b/datafusion/functions/src/core/greatest_least_utils.rs index c051f425adbd..46b3645e703a 100644 --- a/datafusion/functions/src/core/greatest_least_utils.rs +++ b/datafusion/functions/src/core/greatest_least_utils.rs @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + use arrow::array::{Array, ArrayRef, BooleanArray}; use arrow::compute::kernels::zip::zip; use arrow::datatypes::DataType;