diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index f82e4bfa1a89..bd313aa9ed48 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ -19,15 +19,16 @@ use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow::array::{Array, ArrayRef, ListArray}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::utils::take_function_args; -use datafusion_common::{cast::as_map_array, exec_err, Result}; +use datafusion_common::{cast::as_map_array, exec_err, internal_err, Result}; use datafusion_expr::{ ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; +use std::ops::Deref; use std::sync::Arc; make_udf_expr_and_func!( @@ -91,13 +92,22 @@ impl ScalarUDFImpl for MapValuesFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - let [map_type] = take_function_args(self.name(), arg_types)?; - let map_fields = get_map_entry_field(map_type)?; - Ok(DataType::List(Arc::new(Field::new_list_field( - map_fields.last().unwrap().data_type().clone(), - true, - )))) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + let [map_type] = take_function_args(self.name(), args.arg_fields)?; + + Ok(Field::new( + self.name(), + DataType::List(get_map_values_field_as_list_field(map_type.data_type())?), + // Nullable if the map is nullable + args.arg_fields.iter().any(|x| x.is_nullable()), + )) } fn invoke_with_args( @@ -121,9 +131,137 @@ fn map_values_inner(args: &[ArrayRef]) -> Result { }; Ok(Arc::new(ListArray::new( - Arc::new(Field::new_list_field(map_array.value_type().clone(), true)), + get_map_values_field_as_list_field(map_arg.data_type())?, map_array.offsets().clone(), Arc::clone(map_array.values()), map_array.nulls().cloned(), ))) } + +fn get_map_values_field_as_list_field(map_type: &DataType) -> Result { + let map_fields = get_map_entry_field(map_type)?; + + let values_field = map_fields + .last() + .unwrap() + .deref() + .clone() + .with_name(Field::LIST_FIELD_DEFAULT_NAME); + + Ok(Arc::new(values_field)) +} + +#[cfg(test)] +mod tests { + use crate::map_values::MapValuesFunc; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::ScalarValue; + use datafusion_expr::ScalarUDFImpl; + use std::sync::Arc; + + #[test] + fn return_type_field() { + fn get_map_field( + is_map_nullable: bool, + is_keys_nullable: bool, + is_values_nullable: bool, + ) -> Field { + Field::new_map( + "something", + "entries", + Arc::new(Field::new("keys", DataType::Utf8, is_keys_nullable)), + Arc::new(Field::new( + "values", + DataType::LargeUtf8, + is_values_nullable, + )), + false, + is_map_nullable, + ) + } + + fn get_list_field( + name: &str, + is_list_nullable: bool, + list_item_type: DataType, + is_list_items_nullable: bool, + ) -> Field { + Field::new_list( + name, + Arc::new(Field::new_list_field( + list_item_type, + is_list_items_nullable, + )), + is_list_nullable, + ) + } + + fn get_return_field(field: Field) -> Field { + let func = MapValuesFunc::new(); + let args = datafusion_expr::ReturnFieldArgs { + arg_fields: &[field], + scalar_arguments: &[None::<&ScalarValue>], + }; + + func.return_field_from_args(args).unwrap() + } + + // Test cases: + // + // | Input Map || Expected Output | + // | ------------------------------------------------------ || ----------------------------------------------------- | + // | map nullable | map keys nullable | map values nullable || expected list nullable | expected list items nullable | + // | ------------ | ----------------- | ------------------- || ---------------------- | ---------------------------- | + // | false | false | false || false | false | + // | false | false | true || false | true | + // | false | true | false || false | false | + // | false | true | true || false | true | + // | true | false | false || true | false | + // | true | false | true || true | true | + // | true | true | false || true | false | + // | true | true | true || true | true | + // + // --------------- + // We added the key nullability to show that it does not affect the nullability of the list or the list items. + + assert_eq!( + get_return_field(get_map_field(false, false, false)), + get_list_field("map_values", false, DataType::LargeUtf8, false) + ); + + assert_eq!( + get_return_field(get_map_field(false, false, true)), + get_list_field("map_values", false, DataType::LargeUtf8, true) + ); + + assert_eq!( + get_return_field(get_map_field(false, true, false)), + get_list_field("map_values", false, DataType::LargeUtf8, false) + ); + + assert_eq!( + get_return_field(get_map_field(false, true, true)), + get_list_field("map_values", false, DataType::LargeUtf8, true) + ); + + assert_eq!( + get_return_field(get_map_field(true, false, false)), + get_list_field("map_values", true, DataType::LargeUtf8, false) + ); + + assert_eq!( + get_return_field(get_map_field(true, false, true)), + get_list_field("map_values", true, DataType::LargeUtf8, true) + ); + + assert_eq!( + get_return_field(get_map_field(true, true, false)), + get_list_field("map_values", true, DataType::LargeUtf8, false) + ); + + assert_eq!( + get_return_field(get_map_field(true, true, true)), + get_list_field("map_values", true, DataType::LargeUtf8, true) + ); + } +}