19
19
20
20
use crate :: utils:: { get_map_entry_field, make_scalar_function} ;
21
21
use arrow:: array:: { Array , ArrayRef , ListArray } ;
22
- use arrow:: datatypes:: { DataType , Field } ;
22
+ use arrow:: datatypes:: { DataType , Field , FieldRef } ;
23
23
use datafusion_common:: utils:: take_function_args;
24
- use datafusion_common:: { cast:: as_map_array, exec_err, Result } ;
24
+ use datafusion_common:: { cast:: as_map_array, exec_err, internal_err , Result } ;
25
25
use datafusion_expr:: {
26
26
ArrayFunctionSignature , ColumnarValue , Documentation , ScalarUDFImpl , Signature ,
27
27
TypeSignature , Volatility ,
28
28
} ;
29
29
use datafusion_macros:: user_doc;
30
30
use std:: any:: Any ;
31
+ use std:: ops:: Deref ;
31
32
use std:: sync:: Arc ;
32
33
33
34
make_udf_expr_and_func ! (
@@ -91,13 +92,22 @@ impl ScalarUDFImpl for MapValuesFunc {
91
92
& self . signature
92
93
}
93
94
94
- fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
95
- let [ map_type] = take_function_args ( self . name ( ) , arg_types) ?;
96
- let map_fields = get_map_entry_field ( map_type) ?;
97
- Ok ( DataType :: List ( Arc :: new ( Field :: new_list_field (
98
- map_fields. last ( ) . unwrap ( ) . data_type ( ) . clone ( ) ,
99
- true ,
100
- ) ) ) )
95
+ fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
96
+ internal_err ! ( "return_field_from_args should be used instead" )
97
+ }
98
+
99
+ fn return_field_from_args (
100
+ & self ,
101
+ args : datafusion_expr:: ReturnFieldArgs ,
102
+ ) -> Result < Field > {
103
+ let [ map_type] = take_function_args ( self . name ( ) , args. arg_fields ) ?;
104
+
105
+ Ok ( Field :: new (
106
+ self . name ( ) ,
107
+ DataType :: List ( get_map_values_field_as_list_field ( map_type. data_type ( ) ) ?) ,
108
+ // Nullable if the map is nullable
109
+ args. arg_fields . iter ( ) . any ( |x| x. is_nullable ( ) ) ,
110
+ ) )
101
111
}
102
112
103
113
fn invoke_with_args (
@@ -121,9 +131,137 @@ fn map_values_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
121
131
} ;
122
132
123
133
Ok ( Arc :: new ( ListArray :: new (
124
- Arc :: new ( Field :: new_list_field ( map_array . value_type ( ) . clone ( ) , true ) ) ,
134
+ get_map_values_field_as_list_field ( map_arg . data_type ( ) ) ? ,
125
135
map_array. offsets ( ) . clone ( ) ,
126
136
Arc :: clone ( map_array. values ( ) ) ,
127
137
map_array. nulls ( ) . cloned ( ) ,
128
138
) ) )
129
139
}
140
+
141
+ fn get_map_values_field_as_list_field ( map_type : & DataType ) -> Result < FieldRef > {
142
+ let map_fields = get_map_entry_field ( map_type) ?;
143
+
144
+ let values_field = map_fields
145
+ . last ( )
146
+ . unwrap ( )
147
+ . deref ( )
148
+ . clone ( )
149
+ . with_name ( Field :: LIST_FIELD_DEFAULT_NAME ) ;
150
+
151
+ Ok ( Arc :: new ( values_field) )
152
+ }
153
+
154
#[cfg(test)]
mod tests {
    use crate::map_values::MapValuesFunc;
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::ScalarValue;
    use datafusion_expr::ScalarUDFImpl;
    use std::sync::Arc;

    /// Builds a `Utf8 -> LargeUtf8` map field with the given nullability flags.
    fn map_field(map_nullable: bool, keys_nullable: bool, values_nullable: bool) -> Field {
        let keys = Arc::new(Field::new("keys", DataType::Utf8, keys_nullable));
        let values = Arc::new(Field::new(
            "values",
            DataType::LargeUtf8,
            values_nullable,
        ));

        Field::new_map("something", "entries", keys, values, false, map_nullable)
    }

    /// Builds the expected list field with the given list and item nullability.
    fn list_field(
        name: &str,
        list_nullable: bool,
        item_type: DataType,
        items_nullable: bool,
    ) -> Field {
        let item = Arc::new(Field::new_list_field(item_type, items_nullable));
        Field::new_list(name, item, list_nullable)
    }

    /// Runs `MapValuesFunc::return_field_from_args` on a single argument field.
    fn return_field(field: Field) -> Field {
        let args = datafusion_expr::ReturnFieldArgs {
            arg_fields: &[field],
            scalar_arguments: &[None::<&ScalarValue>],
        };

        MapValuesFunc::new().return_field_from_args(args).unwrap()
    }

    #[test]
    fn return_type_field() {
        // Every combination of (map, keys, values) nullability is checked:
        //
        // * the list is nullable exactly when the map is nullable;
        // * the list items are nullable exactly when the map values are;
        // * key nullability is varied only to show it affects neither.
        for map_nullable in [false, true] {
            for keys_nullable in [false, true] {
                for values_nullable in [false, true] {
                    assert_eq!(
                        return_field(map_field(
                            map_nullable,
                            keys_nullable,
                            values_nullable
                        )),
                        list_field(
                            "map_values",
                            map_nullable,
                            DataType::LargeUtf8,
                            values_nullable
                        )
                    );
                }
            }
        }
    }
}
0 commit comments