Skip to content

Commit 026a2b1

Browse files
rluvaton and alamb authored
fix: correctly specify the nullability of map_values return type (#15901)
Co-authored-by: Andrew Lamb <[email protected]>
1 parent e1cc80c commit 026a2b1

File tree

1 file changed

+148
-10
lines changed

1 file changed

+148
-10
lines changed

datafusion/functions-nested/src/map_values.rs

Lines changed: 148 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -19,15 +19,16 @@
1919
2020
use crate::utils::{get_map_entry_field, make_scalar_function};
2121
use arrow::array::{Array, ArrayRef, ListArray};
22-
use arrow::datatypes::{DataType, Field};
22+
use arrow::datatypes::{DataType, Field, FieldRef};
2323
use datafusion_common::utils::take_function_args;
24-
use datafusion_common::{cast::as_map_array, exec_err, Result};
24+
use datafusion_common::{cast::as_map_array, exec_err, internal_err, Result};
2525
use datafusion_expr::{
2626
ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature,
2727
TypeSignature, Volatility,
2828
};
2929
use datafusion_macros::user_doc;
3030
use std::any::Any;
31+
use std::ops::Deref;
3132
use std::sync::Arc;
3233

3334
make_udf_expr_and_func!(
@@ -91,13 +92,22 @@ impl ScalarUDFImpl for MapValuesFunc {
9192
&self.signature
9293
}
9394

94-
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
95-
let [map_type] = take_function_args(self.name(), arg_types)?;
96-
let map_fields = get_map_entry_field(map_type)?;
97-
Ok(DataType::List(Arc::new(Field::new_list_field(
98-
map_fields.last().unwrap().data_type().clone(),
99-
true,
100-
))))
95+
/// Not supported for this function: the body unconditionally evaluates to
/// `internal_err!(...)`, and the message directs callers to
/// `return_field_from_args` instead. The argument types are ignored
/// (hence the `_arg_types` name).
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
96+
internal_err!("return_field_from_args should be used instead")
97+
}
98+
99+
/// Builds the output `Field` of `map_values` from the argument fields.
///
/// The returned field:
/// - is named after the function (`self.name()`),
/// - has type `DataType::List` whose item field is derived from the map's
///   values field by `get_map_values_field_as_list_field`, and
/// - is nullable when any field in `args.arg_fields` is nullable.
///
/// Errors from `take_function_args` and
/// `get_map_values_field_as_list_field` are propagated with `?`.
fn return_field_from_args(
100+
&self,
101+
args: datafusion_expr::ReturnFieldArgs,
102+
) -> Result<Field> {
103+
// Destructures into exactly one argument field (the map); presumably
// `take_function_args` errors on any other argument count — confirm.
let [map_type] = take_function_args(self.name(), args.arg_fields)?;
104+
105+
Ok(Field::new(
106+
self.name(),
107+
DataType::List(get_map_values_field_as_list_field(map_type.data_type())?),
108+
// Nullable if the map is nullable
// NOTE(review): with the single argument destructured above, `any(..)`
// over all argument fields looks equivalent to `map_type.is_nullable()`.
109+
args.arg_fields.iter().any(|x| x.is_nullable()),
110+
))
101111
}
102112

103113
fn invoke_with_args(
@@ -121,9 +131,137 @@ fn map_values_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
121131
};
122132

123133
Ok(Arc::new(ListArray::new(
124-
Arc::new(Field::new_list_field(map_array.value_type().clone(), true)),
134+
get_map_values_field_as_list_field(map_arg.data_type())?,
125135
map_array.offsets().clone(),
126136
Arc::clone(map_array.values()),
127137
map_array.nulls().cloned(),
128138
)))
129139
}
140+
141+
/// Returns the values field of `map_type`, renamed for use as a list item field.
///
/// Takes the last field yielded by `get_map_entry_field(map_type)`, clones
/// it, and renames the clone to `Field::LIST_FIELD_DEFAULT_NAME`. Only the
/// name is changed here; everything else comes from the cloned field.
///
/// Errors from `get_map_entry_field` are propagated with `?`.
fn get_map_values_field_as_list_field(map_type: &DataType) -> Result<FieldRef> {
142+
let map_fields = get_map_entry_field(map_type)?;
143+
144+
// NOTE(review): `.unwrap()` panics if `map_fields` is empty — presumably a
// map entry always carries key and value fields; confirm against
// `get_map_entry_field`.
let values_field = map_fields
145+
.last()
146+
.unwrap()
147+
.deref()
148+
.clone()
149+
.with_name(Field::LIST_FIELD_DEFAULT_NAME);
150+
151+
Ok(Arc::new(values_field))
152+
}
153+
154+
// Unit tests for the field (type + nullability) reported by
// `MapValuesFunc::return_field_from_args`.
#[cfg(test)]
155+
mod tests {
156+
use crate::map_values::MapValuesFunc;
157+
use arrow::datatypes::{DataType, Field};
158+
use datafusion_common::ScalarValue;
159+
use datafusion_expr::ScalarUDFImpl;
160+
use std::sync::Arc;
161+
162+
// Checks all eight combinations of map / key / value nullability against
// the expected nullability of the returned list and of its items.
#[test]
163+
fn return_type_field() {
164+
// Builds the input map field: Utf8 keys, LargeUtf8 values, with the
// requested nullability for the map itself, its keys and its values.
fn get_map_field(
165+
is_map_nullable: bool,
166+
is_keys_nullable: bool,
167+
is_values_nullable: bool,
168+
) -> Field {
169+
Field::new_map(
170+
"something",
171+
"entries",
172+
Arc::new(Field::new("keys", DataType::Utf8, is_keys_nullable)),
173+
Arc::new(Field::new(
174+
"values",
175+
DataType::LargeUtf8,
176+
is_values_nullable,
177+
)),
178+
// presumably the `sorted` flag of `Field::new_map` — confirm
false,
179+
is_map_nullable,
180+
)
181+
}
182+
183+
// Builds the expected output: a list field called `name` whose item
// field is created with `Field::new_list_field`.
fn get_list_field(
184+
name: &str,
185+
is_list_nullable: bool,
186+
list_item_type: DataType,
187+
is_list_items_nullable: bool,
188+
) -> Field {
189+
Field::new_list(
190+
name,
191+
Arc::new(Field::new_list_field(
192+
list_item_type,
193+
is_list_items_nullable,
194+
)),
195+
is_list_nullable,
196+
)
197+
}
198+
199+
// Calls `return_field_from_args` with `field` as the only argument
// field and a single `None` scalar argument; panics on error.
fn get_return_field(field: Field) -> Field {
200+
let func = MapValuesFunc::new();
201+
let args = datafusion_expr::ReturnFieldArgs {
202+
arg_fields: &[field],
203+
scalar_arguments: &[None::<&ScalarValue>],
204+
};
205+
206+
func.return_field_from_args(args).unwrap()
207+
}
208+
209+
// Test cases:
210+
//
211+
// | Input Map || Expected Output |
212+
// | ------------------------------------------------------ || ----------------------------------------------------- |
213+
// | map nullable | map keys nullable | map values nullable || expected list nullable | expected list items nullable |
214+
// | ------------ | ----------------- | ------------------- || ---------------------- | ---------------------------- |
215+
// | false | false | false || false | false |
216+
// | false | false | true || false | true |
217+
// | false | true | false || false | false |
218+
// | false | true | true || false | true |
219+
// | true | false | false || true | false |
220+
// | true | false | true || true | true |
221+
// | true | true | false || true | false |
222+
// | true | true | true || true | true |
223+
//
224+
// ---------------
225+
// We added the key nullability to show that it does not affect the nullability of the list or the list items.
226+
227+
assert_eq!(
228+
get_return_field(get_map_field(false, false, false)),
229+
get_list_field("map_values", false, DataType::LargeUtf8, false)
230+
);
231+
232+
assert_eq!(
233+
get_return_field(get_map_field(false, false, true)),
234+
get_list_field("map_values", false, DataType::LargeUtf8, true)
235+
);
236+
237+
assert_eq!(
238+
get_return_field(get_map_field(false, true, false)),
239+
get_list_field("map_values", false, DataType::LargeUtf8, false)
240+
);
241+
242+
assert_eq!(
243+
get_return_field(get_map_field(false, true, true)),
244+
get_list_field("map_values", false, DataType::LargeUtf8, true)
245+
);
246+
247+
assert_eq!(
248+
get_return_field(get_map_field(true, false, false)),
249+
get_list_field("map_values", true, DataType::LargeUtf8, false)
250+
);
251+
252+
assert_eq!(
253+
get_return_field(get_map_field(true, false, true)),
254+
get_list_field("map_values", true, DataType::LargeUtf8, true)
255+
);
256+
257+
assert_eq!(
258+
get_return_field(get_map_field(true, true, false)),
259+
get_list_field("map_values", true, DataType::LargeUtf8, false)
260+
);
261+
262+
assert_eq!(
263+
get_return_field(get_map_field(true, true, true)),
264+
get_list_field("map_values", true, DataType::LargeUtf8, true)
265+
);
266+
}
267+
}

0 commit comments

Comments
 (0)