@@ -22,11 +22,16 @@ use arrow_array::{
22
22
} ;
23
23
use arrow_buffer:: BooleanBufferBuilder ;
24
24
use arrow_schema:: { DataType , Schema } ;
25
- use datafusion_common:: ScalarValue ;
25
+ use datafusion_common:: { Result as DFResult , ScalarValue } ;
26
26
use datafusion_expr:: ColumnarValue ;
27
- use datafusion_physical_expr_common:: physical_expr:: { down_cast_any_ref , PhysicalExpr } ;
27
+ use datafusion_physical_expr_common:: physical_expr:: PhysicalExpr ;
28
28
use regex:: Regex ;
29
- use std:: { any:: Any , hash:: Hash , sync:: Arc } ;
29
+ use std:: {
30
+ any:: Any ,
31
+ fmt:: { Debug , Display , Formatter , Result as FmtResult } ,
32
+ hash:: Hash ,
33
+ sync:: Arc ,
34
+ } ;
30
35
31
36
/// ScalarRegexMatchExpr
32
37
/// Only used when evaluating regexp matching with literal pattern.
@@ -133,9 +138,7 @@ impl ScalarRegexMatchExpr {
133
138
( true , true ) => "NOT IMATCH" ,
134
139
}
135
140
}
136
- }
137
141
138
- impl ScalarRegexMatchExpr {
139
142
/// Evaluate the scalar regex match expression match array value
140
143
fn evaluate_array (
141
144
& self ,
@@ -200,16 +203,9 @@ impl ScalarRegexMatchExpr {
200
203
}
201
204
}
202
205
203
- impl std:: hash:: Hash for ScalarRegexMatchExpr {
204
- fn hash < H : std:: hash:: Hasher > ( & self , state : & mut H ) {
205
- self . negated . hash ( state) ;
206
- self . case_insensitive . hash ( state) ;
207
- self . expr . hash ( state) ;
208
- self . pattern . hash ( state) ;
209
- }
210
- }
206
+ impl Eq for ScalarRegexMatchExpr { }
211
207
212
- impl std :: cmp :: PartialEq for ScalarRegexMatchExpr {
208
+ impl PartialEq for ScalarRegexMatchExpr {
213
209
fn eq ( & self , other : & Self ) -> bool {
214
210
self . negated . eq ( & other. negated )
215
211
&& self . case_insensitive . eq ( & self . case_insensitive )
@@ -218,8 +214,17 @@ impl std::cmp::PartialEq for ScalarRegexMatchExpr {
218
214
}
219
215
}
220
216
221
- impl std:: fmt:: Debug for ScalarRegexMatchExpr {
222
- fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
217
+ impl Hash for ScalarRegexMatchExpr {
218
+ fn hash < H : std:: hash:: Hasher > ( & self , state : & mut H ) {
219
+ self . negated . hash ( state) ;
220
+ self . case_insensitive . hash ( state) ;
221
+ self . expr . hash ( state) ;
222
+ self . pattern . hash ( state) ;
223
+ }
224
+ }
225
+
226
+ impl Debug for ScalarRegexMatchExpr {
227
+ fn fmt ( & self , f : & mut Formatter < ' _ > ) -> FmtResult {
223
228
f. debug_struct ( "ScalarRegexMatchExpr" )
224
229
. field ( "negated" , & self . negated )
225
230
. field ( "case_insensitive" , & self . case_insensitive )
@@ -229,35 +234,26 @@ impl std::fmt::Debug for ScalarRegexMatchExpr {
229
234
}
230
235
}
231
236
232
- impl std :: fmt :: Display for ScalarRegexMatchExpr {
233
- fn fmt ( & self , f : & mut std :: fmt :: Formatter ) -> std :: fmt :: Result {
237
+ impl Display for ScalarRegexMatchExpr {
238
+ fn fmt ( & self , f : & mut Formatter ) -> FmtResult {
234
239
write ! ( f, "{} {} {}" , self . expr, self . op_name( ) , self . pattern)
235
240
}
236
241
}
237
242
238
243
impl PhysicalExpr for ScalarRegexMatchExpr {
239
- fn as_any ( & self ) -> & dyn std :: any :: Any {
244
+ fn as_any ( & self ) -> & dyn Any {
240
245
self
241
246
}
242
247
243
- fn data_type (
244
- & self ,
245
- _: & arrow_schema:: Schema ,
246
- ) -> datafusion_common:: Result < arrow_schema:: DataType > {
248
+ fn data_type ( & self , _: & Schema ) -> DFResult < DataType > {
247
249
Ok ( DataType :: Boolean )
248
250
}
249
251
250
- fn nullable (
251
- & self ,
252
- input_schema : & arrow_schema:: Schema ,
253
- ) -> datafusion_common:: Result < bool > {
252
+ fn nullable ( & self , input_schema : & Schema ) -> DFResult < bool > {
254
253
Ok ( self . expr . nullable ( input_schema) ? || self . pattern . nullable ( input_schema) ?)
255
254
}
256
255
257
- fn evaluate (
258
- & self ,
259
- batch : & arrow_array:: RecordBatch ,
260
- ) -> datafusion_common:: Result < ColumnarValue > {
256
+ fn evaluate ( & self , batch : & arrow_array:: RecordBatch ) -> DFResult < ColumnarValue > {
261
257
self . expr
262
258
. evaluate ( batch)
263
259
. and_then ( |lhs| {
@@ -274,14 +270,14 @@ impl PhysicalExpr for ScalarRegexMatchExpr {
274
270
. map ( ColumnarValue :: Array )
275
271
}
276
272
277
- fn children ( & self ) -> Vec < & std :: sync :: Arc < dyn PhysicalExpr > > {
273
+ fn children ( & self ) -> Vec < & Arc < dyn PhysicalExpr > > {
278
274
vec ! [ & self . expr, & self . pattern]
279
275
}
280
276
281
277
fn with_new_children (
282
- self : std :: sync :: Arc < Self > ,
283
- children : Vec < std :: sync :: Arc < dyn PhysicalExpr > > ,
284
- ) -> datafusion_common :: Result < std :: sync :: Arc < dyn PhysicalExpr > > {
278
+ self : Arc < Self > ,
279
+ children : Vec < Arc < dyn PhysicalExpr > > ,
280
+ ) -> DFResult < Arc < dyn PhysicalExpr > > {
285
281
Ok ( Arc :: new ( ScalarRegexMatchExpr :: new (
286
282
self . negated ,
287
283
self . case_insensitive ,
@@ -290,18 +286,24 @@ impl PhysicalExpr for ScalarRegexMatchExpr {
290
286
) ) )
291
287
}
292
288
293
- fn dyn_hash ( & self , state : & mut dyn std:: hash:: Hasher ) {
294
- let mut s = state;
295
- self . hash ( & mut s) ;
296
- }
297
- }
298
-
299
- impl PartialEq < dyn Any > for ScalarRegexMatchExpr {
300
- fn eq ( & self , other : & dyn Any ) -> bool {
301
- down_cast_any_ref ( other)
302
- . downcast_ref :: < Self > ( )
303
- . map ( |x| self == x)
304
- . unwrap_or ( false )
289
+ fn evaluate_selection (
290
+ & self ,
291
+ batch : & arrow_array:: RecordBatch ,
292
+ selection : & BooleanArray ,
293
+ ) -> DFResult < ColumnarValue > {
294
+ let tmp_batch = arrow:: compute:: filter_record_batch ( batch, selection) ?;
295
+
296
+ let tmp_result = self . evaluate ( & tmp_batch) ?;
297
+
298
+ if batch. num_rows ( ) == tmp_batch. num_rows ( ) {
299
+ // All values from the `selection` filter are true.
300
+ Ok ( tmp_result)
301
+ } else if let ColumnarValue :: Array ( a) = tmp_result {
302
+ datafusion_physical_expr_common:: utils:: scatter ( selection, a. as_ref ( ) )
303
+ . map ( ColumnarValue :: Array )
304
+ } else {
305
+ Ok ( tmp_result)
306
+ }
305
307
}
306
308
}
307
309
@@ -310,7 +312,7 @@ fn array_regexp_match(
310
312
array : & dyn ArrayAccessor < Item = & str > ,
311
313
regex : & Regex ,
312
314
negated : bool ,
313
- ) -> datafusion_common :: Result < ColumnarValue > {
315
+ ) -> DFResult < ColumnarValue > {
314
316
let null_bit_buffer = array. nulls ( ) . map ( |x| x. inner ( ) . sliced ( ) ) ;
315
317
let mut buffer_builder = BooleanBufferBuilder :: new ( array. len ( ) ) ;
316
318
@@ -359,7 +361,7 @@ pub fn scalar_regex_match(
359
361
expr : Arc < dyn PhysicalExpr > ,
360
362
pattern : Arc < dyn PhysicalExpr > ,
361
363
input_schema : & Schema ,
362
- ) -> datafusion_common :: Result < Arc < dyn PhysicalExpr > > {
364
+ ) -> DFResult < Arc < dyn PhysicalExpr > > {
363
365
let valid_data_type = |data_type : & DataType | {
364
366
if !matches ! (
365
367
data_type,
0 commit comments