From 8d07d08b88f83074743adcafe4e67fd749de28b4 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sat, 15 Feb 2025 21:39:36 -0300 Subject: [PATCH 1/3] feat: add union_tag scalar function --- datafusion/functions/src/core/mod.rs | 7 + datafusion/functions/src/core/union_tag.rs | 223 ++++++++++++++++++ .../test_files/union_function.slt | 15 ++ .../source/user-guide/sql/scalar_functions.md | 28 +++ 4 files changed, 273 insertions(+) create mode 100644 datafusion/functions/src/core/union_tag.rs diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 425ce78decbe..45d848751e94 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -35,6 +35,7 @@ pub mod nvl2; pub mod planner; pub mod r#struct; pub mod union_extract; +pub mod union_tag; pub mod version; // create UDFs @@ -50,6 +51,7 @@ make_udf_function!(coalesce::CoalesceFunc, coalesce); make_udf_function!(greatest::GreatestFunc, greatest); make_udf_function!(least::LeastFunc, least); make_udf_function!(union_extract::UnionExtractFun, union_extract); +make_udf_function!(union_tag::UnionTagFunc, union_tag); make_udf_function!(version::VersionFunc, version); pub mod expr_fn { @@ -95,6 +97,10 @@ pub mod expr_fn { least, "Returns `least(args...)`, which evaluates to the smallest value in the list of expressions or NULL if all the expressions are NULL", args, + ),( + union_tag, + "Returns the name of the currently selected field in the union", + arg1 )); #[doc = "Returns the value of the field with the given name from the struct"] @@ -129,6 +135,7 @@ pub fn functions() -> Vec> { greatest(), least(), union_extract(), + union_tag(), version(), r#struct(), ] diff --git a/datafusion/functions/src/core/union_tag.rs b/datafusion/functions/src/core/union_tag.rs new file mode 100644 index 000000000000..4e605b6031c7 --- /dev/null +++ b/datafusion/functions/src/core/union_tag.rs @@ -0,0 +1,223 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, AsArray, DictionaryArray, Int8Array, StringArray}; +use arrow::datatypes::DataType; +use datafusion_common::utils::take_function_args; +use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; +use datafusion_doc::Documentation; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +#[user_doc( + doc_section(label = "Union Functions"), + description = "Returns the name of the currently selected field in the union", + syntax_example = "union_tag(union_expression)", + sql_example = r#"```sql +❯ select union_column, union_tag(union_column) from table_with_union; ++--------------+-------------------------+ +| union_column | union_tag(union_column) | ++--------------+-------------------------+ +| {a=1} | a | +| {b=3.0} | b | +| {a=4} | a | +| {b=} | b | +| {a=} | a | ++--------------+-------------------------+ +```"#, + standard_argument(name = "union", prefix = "Union") +)] +#[derive(Debug)] +pub struct UnionTagFunc { + signature: Signature, +} + +impl Default for UnionTagFunc { + fn default() -> Self { + Self::new() + } +} + +impl UnionTagFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for UnionTagFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "union_tag" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Utf8), + )) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [union_] = take_function_args("union_tag", args.args)?; + + match union_ { + ColumnarValue::Array(array) + if matches!(array.data_type(), DataType::Union(_, _)) => + { + let union_array = array.as_union(); + + let keys = Int8Array::try_new(union_array.type_ids().clone(), None)?; + + let fields = match union_array.data_type() { + DataType::Union(fields, _) => fields, + _ => unreachable!(), + }; + + // Union fields type IDs only constraints are being unique and in the 0..128 range: + // They may not start at 0, be sequential, or even contiguous. + // Therefore, we allocate a values vector with a length equal to the highest type ID plus one, + // ensuring that each field's name can be placed at the index corresponding to its type ID. + let values_len = fields + .iter() + .map(|(type_id, _)| type_id + 1) + .max() + .unwrap_or_default() as usize; + + let mut values = vec![""; values_len]; + + for (type_id, field) in fields.iter() { + values[type_id as usize] = field.name().as_str() + } + + let values = Arc::new(StringArray::from(values)); + + // SAFETY: union type_ids are validated to not be smaller than zero. + // values len is the union biggest type id plus one. + // keys is built from the union type_ids, which contains only valid type ids + // therefore, `keys[i] >= values.len() || keys[i] < 0` never occurs + let dict = unsafe { DictionaryArray::new_unchecked(keys, values) }; + + Ok(ColumnarValue::Array(Arc::new(dict))) + } + ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => match value { + Some((value_type_id, _)) => fields + .iter() + .find(|(type_id, _)| value_type_id == *type_id) + .map(|(_, field)| { + ColumnarValue::Scalar(ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(field.name().as_str().into()), + )) + }) + .ok_or_else(|| { + exec_datafusion_err!( + "union_tag: union scalar with unknow type_id {value_type_id}" + ) + }), + None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + args.return_type, + )?)), + }, + v => exec_err!("union_tag only support unions, got {:?}", v.data_type()), + } + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +#[cfg(test)] +mod tests { + use super::UnionTagFunc; + use arrow::datatypes::{DataType, Field, UnionFields, UnionMode}; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use std::sync::Arc; + + // when it becomes possible to construct union scalars in SQL, this should go to sqllogictests + #[test] + fn union_scalar() { + let fields = [(0, Arc::new(Field::new("a", DataType::UInt32, false)))] + .into_iter() + .collect(); + + let scalar = ScalarValue::Union( + Some((0, Box::new(ScalarValue::UInt32(Some(0))))), + fields, + UnionMode::Dense, + ); + + let result = UnionTagFunc::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(scalar)], + number_rows: 1, + return_type: &DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Utf8), + ), + }) + .unwrap(); + + assert_scalar( + result, + ScalarValue::Dictionary(Box::new(DataType::Int8), Box::new("a".into())), + ); + } + + #[test] + fn union_scalar_empty() { + let scalar = ScalarValue::Union(None, UnionFields::empty(), UnionMode::Dense); + + let result = UnionTagFunc::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(scalar)], + number_rows: 1, + return_type: &DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Utf8), + ), + }) + .unwrap(); + + assert_scalar( + result, + ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(ScalarValue::Utf8(None)), + ), + ); + } + + fn assert_scalar(value: ColumnarValue, expected: ScalarValue) { + match value { + ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"), + ColumnarValue::Scalar(scalar) => assert_eq!(scalar, expected), + } + } +} diff --git a/datafusion/sqllogictest/test_files/union_function.slt b/datafusion/sqllogictest/test_files/union_function.slt index 9c70b1011f58..a67992a99683 100644 --- a/datafusion/sqllogictest/test_files/union_function.slt +++ b/datafusion/sqllogictest/test_files/union_function.slt @@ -45,3 +45,18 @@ select union_extract(union_column, 1) from union_table; query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 3 select union_extract(union_column, 'a', 'b') from union_table; + +query ?T +select union_column, union_tag(union_column) from union_table; +---- +{int=1} int +{int=2} int + +query error DataFusion error: Error during planning: 'union_tag' does not support zero arguments +select union_tag() from union_table; + +query error DataFusion error: Error during planning: The function 'union_tag' expected 1 arguments but received 2 +select union_tag(union_column, 'int') from union_table; + +query error DataFusion error: Execution error: union_tag only support unions, got Utf8 +select union_tag('int') from union_table; diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index fb4043c33efc..9becd58ba376 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -4344,6 +4344,7 @@ sha512(expression) Functions to work with the union data type, also know as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator - [union_extract](#union_extract) +- [union_tag](#union_tag) ### `union_extract` @@ -4373,6 +4374,33 @@ union_extract(union, field_name) +--------------+----------------------------------+----------------------------------+ ``` +### `union_tag` + +Returns the name of the currently selected field in the union + +```sql +union_tag(union_expression) +``` + +#### Arguments + +- **union**: Union expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +❯ select union_column, union_tag(union_column) from table_with_union; ++--------------+-------------------------+ +| union_column | union_tag(union_column) | ++--------------+-------------------------+ +| {a=1} | a | +| {b=3.0} | b | +| {a=4} | a | +| {b=} | b | +| {a=} | a | ++--------------+-------------------------+ +``` + ## Other Functions - [arrow_cast](#arrow_cast) From 47e785a2ef8cf3a57fe19bcf34e6d13f34981bf0 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 29 Apr 2025 18:34:22 -0400 Subject: [PATCH 2/3] update for new api --- datafusion/functions/src/core/union_tag.rs | 20 ++++++++++--------- datafusion/sqllogictest/src/test_context.rs | 20 ++++++++++++++++--- .../test_files/union_function.slt | 3 +++ 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/datafusion/functions/src/core/union_tag.rs b/datafusion/functions/src/core/union_tag.rs index 4e605b6031c7..2997313f9efe 100644 --- a/datafusion/functions/src/core/union_tag.rs +++ b/datafusion/functions/src/core/union_tag.rs @@ -140,7 +140,7 @@ impl ScalarUDFImpl for UnionTagFunc { ) }), None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( - args.return_type, + args.return_field.data_type(), )?)), }, v => exec_err!("union_tag only support unions, got {:?}", v.data_type()), @@ -173,14 +173,15 @@ mod tests { UnionMode::Dense, ); + let return_type = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); + let result = UnionTagFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(scalar)], number_rows: 1, - return_type: &DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Utf8), - ), + return_field: &Field::new("res", return_type, true), + arg_fields: vec![], }) .unwrap(); @@ -194,14 +195,15 @@ mod tests { fn union_scalar_empty() { let scalar = ScalarValue::Union(None, UnionFields::empty(), UnionMode::Dense); + let return_type = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); + let result = UnionTagFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(scalar)], number_rows: 1, - return_type: &DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Utf8), - ), + return_field: &Field::new("res", return_type, true), + arg_fields: vec![], }) .unwrap(); diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index ce819f186454..6261e6e47fcb 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -410,10 +410,24 @@ fn create_example_udf() -> ScalarUDF { fn register_union_table(ctx: &SessionContext) { let union = UnionArray::try_new( - UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]), - ScalarBuffer::from(vec![3, 3]), + UnionFields::new( + // typeids: 3 for int, 1 for string + vec![3, 1], + vec![ + Field::new("int", DataType::Int32, false), + Field::new("string", DataType::Utf8, false), + ], + ), + ScalarBuffer::from(vec![3, 1, 3]), None, - vec![Arc::new(Int32Array::from(vec![1, 2]))], + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![ + Some("foo"), + Some("bar"), + Some("baz"), + ])), + ], ) .unwrap(); diff --git a/datafusion/sqllogictest/test_files/union_function.slt b/datafusion/sqllogictest/test_files/union_function.slt index a67992a99683..8906b6093fbb 100644 --- a/datafusion/sqllogictest/test_files/union_function.slt +++ b/datafusion/sqllogictest/test_files/union_function.slt @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. +# Note: union_table is registered via Rust code in the sqllogictest test harness +# because there is no way to create a union type in SQL today + ########## ## UNION DataType Tests ########## From 680e7fa42edc6a959dd1e8d057cc858ebdf750fc Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 29 Apr 2025 18:37:10 -0400 Subject: [PATCH 3/3] Add test for second field type --- datafusion/sqllogictest/test_files/union_function.slt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/datafusion/sqllogictest/test_files/union_function.slt b/datafusion/sqllogictest/test_files/union_function.slt index 8906b6093fbb..74616490ab70 100644 --- a/datafusion/sqllogictest/test_files/union_function.slt +++ b/datafusion/sqllogictest/test_files/union_function.slt @@ -26,7 +26,8 @@ query ?I select union_column, union_extract(union_column, 'int') from union_table; ---- {int=1} 1 -{int=2} 2 +{string=bar} NULL +{int=3} 3 query error DataFusion error: Execution error: field bool not found on union select union_extract(union_column, 'bool') from union_table; @@ -53,7 +54,8 @@ query ?T select union_column, union_tag(union_column) from union_table; ---- {int=1} int -{int=2} int +{string=bar} string +{int=3} int query error DataFusion error: Error during planning: 'union_tag' does not support zero arguments select union_tag() from union_table;