diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 9549cfeeb3b8..16fc1e979e28 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1570,6 +1570,7 @@ dependencies = [ "log", "paste", "petgraph", + "regex", ] [[package]] diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index d1de63a1e8fc..3f8c293653c5 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -53,6 +53,7 @@ itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "^1.0" petgraph = "0.6.2" +regex = { workspace = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } @@ -72,3 +73,7 @@ name = "case_when" [[bench]] harness = false name = "is_null" + +[[bench]] +harness = false +name = "scalar_regex_match" diff --git a/datafusion/physical-expr/benches/scalar_regex_match.rs b/datafusion/physical-expr/benches/scalar_regex_match.rs new file mode 100644 index 000000000000..0526d48daf46 --- /dev/null +++ b/datafusion/physical-expr/benches/scalar_regex_match.rs @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
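The `[[bench]]` entry added to Cargo.toml above registers the new criterion benchmark target; `harness = false` is what lets criterion supply its own `main`. For orientation, a minimal sketch of the harness shape the new bench file follows (hypothetical names, not part of the patch); it would typically be run with `cargo bench -p datafusion-physical-expr --bench scalar_regex_match`.

```rust
// Minimal criterion harness shape, for orientation only (hypothetical names,
// not part of the patch). The real benchmark body follows in the diff below.
use criterion::{criterion_group, criterion_main, Criterion};

fn demo_benchmark(c: &mut Criterion) {
    let mut group = c.benchmark_group("demo_group");
    group.bench_function("noop", |b| b.iter(|| std::hint::black_box(1 + 1)));
    group.finish();
}

criterion_group!(benches, demo_benchmark);
criterion_main!(benches);
```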
+ +use std::sync::Arc; +use std::time::Duration; + +use arrow_array::{RecordBatch, StringArray}; +use arrow_schema::{DataType, Field, Schema}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_expr_common::operator::Operator; +use datafusion_physical_expr::{ + expressions::{binary, col, lit, scalar_regex_match}, + PhysicalExpr, +}; +use rand::{ + distributions::{Alphanumeric, DistString}, + rngs::StdRng, + SeedableRng, +}; + +/// make a record batch with one column and n rows +/// this record batch is single string column is used for +/// scalar regex match benchmarks +fn make_record_batch( + batch_iter: usize, + batch_size: usize, + string_len: usize, + matched_str: &[&str], + schema: &Schema, +) -> Vec { + let mut rng = StdRng::seed_from_u64(12345); + let mut batches = vec![]; + for _ in 0..batch_iter { + let mut array = (0..batch_size) + .map(|_| Some(Alphanumeric.sample_string(&mut rng, string_len))) + .collect::>(); + for v in matched_str { + array.push(Some(v.to_string())); + } + let array = StringArray::from(array); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]) + .unwrap(); + batches.push(batch); + } + batches +} + +/// initialize benchmark data and pattern literals +#[allow(clippy::type_complexity)] +fn init_benchmark() -> ( + Vec<(usize, usize, Vec)>, + Schema, + Arc, + Vec<(String, Arc)>, +) { + // make common schema + let column = "s"; + let schema = Schema::new(vec![Field::new(column, DataType::Utf8, true)]); + + // make test record batch + let batch_data = vec![ + // (20, 10_usize, make_record_batch(20, 10, 100, schema.clone())), + // (20, 100_usize, make_record_batch(20, 100, 100, schema.clone())), + // (20, 1000_usize, make_record_batch(20, 1000, 100, schema.clone())), + ( + 128_usize, + 4096_usize, + make_record_batch( + 128, + 4096, + 100, + &[ + "example@email.com", + "http://example.com", + "123.4.5.6", + "1236787788", + "55555", + ], + &schema, + ), + ), + ]; + + // string column + let string_col = col(column, &schema).unwrap(); + + // some pattern literal + let pattern_lit = vec![ + ( + "email".to_string(), + lit(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"), + ), + ( + "url".to_string(), + lit(r"^(https?|ftp)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]$"), + ), + ( + "ip".to_string(), + lit( + r"^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$", + ), + ), + ( + "phone".to_string(), + lit(r"^(\+\d{1,2}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}$"), + ), + ("zip_code".to_string(), lit(r"^\d{5}(?:[-\s]\d{4})?$")), + ]; + (batch_data, schema, string_col, pattern_lit) +} + +fn regex_match_benchmark(c: &mut Criterion) { + let (batch_data, schema, string_col, pattern_lit) = init_benchmark(); + for (name, regexp_lit) in pattern_lit.iter() { + for (batch_iter, batch_size, batches) in batch_data.iter() { + let group_name = format!( + "regex_{}_batch_iter_{}_batch_size_{}", + name, batch_iter, batch_size + ); + let mut group = c.benchmark_group(group_name.as_str()); + + group.sample_size(50).measurement_time(Duration::new(30, 0)); + + // binary expr match benchmarks + group.bench_function("binary_expr_match", |b| { + b.iter(|| { + let expr = binary( + string_col.clone(), + Operator::RegexMatch, + regexp_lit.clone(), + &schema, + ) + .unwrap(); + for batch in batches.iter() { + expr.evaluate(batch).unwrap(); + } + }); + }); + // scalar regex match benchmarks + group.bench_function("scalar_regex_match", |b| { + b.iter(|| { + let expr = scalar_regex_match( + false, + 
false,
+                        string_col.clone(),
+                        regexp_lit.clone(),
+                        &schema,
+                    )
+                    .unwrap();
+                    for batch in batches.iter() {
+                        expr.evaluate(batch).unwrap();
+                    }
+                });
+            });
+            group.finish();
+        }
+    }
+}
+
+criterion_group!(benches, regex_match_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index f00b49f50314..462236737074 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -30,6 +30,7 @@ mod literal;
 mod negative;
 mod no_op;
 mod not;
+mod scalar_regex_match;
 mod try_cast;
 mod unknown_column;
@@ -50,5 +51,6 @@ pub use literal::{lit, Literal};
 pub use negative::{negative, NegativeExpr};
 pub use no_op::NoOp;
 pub use not::{not, NotExpr};
+pub use scalar_regex_match::{scalar_regex_match, ScalarRegexMatchExpr};
 pub use try_cast::{try_cast, TryCastExpr};
 pub use unknown_column::UnKnownColumn;
diff --git a/datafusion/physical-expr/src/expressions/scalar_regex_match.rs b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs
new file mode 100644
index 000000000000..1f7be76a95b2
--- /dev/null
+++ b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs
@@ -0,0 +1,646 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::Literal;
+use arrow_array::{
+    Array, ArrayAccessor, BooleanArray, LargeStringArray, StringArray, StringViewArray,
+};
+use arrow_buffer::{bit_util, BooleanBuffer, MutableBuffer};
+use arrow_schema::{DataType, Schema};
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::ColumnarValue;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+use regex::Regex;
+use std::{
+    any::Any,
+    fmt::{Debug, Display, Formatter, Result as FmtResult},
+    hash::Hash,
+    sync::Arc,
+};
+
+/// ScalarRegexMatchExpr
+/// Used only when evaluating a regexp match against a literal pattern.
+/// Example regex expressions: c1 ~ '^a' / c1 !~ '^a' / c1 ~* '^a' / c1 !~* '^a'.
+/// The literal regexp pattern is compiled once and cached for reuse during execution.
+/// This avoids recompiling the pattern for every evaluation and speeds up execution.
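The comment above is the core of this change: because the pattern is a query literal, it can be compiled once when the physical expression is built rather than once per batch inside the generic binary regex kernel. A standalone sketch of that difference, using only the `regex` crate (illustrative only, not DataFusion API):

```rust
// Standalone illustration of the caching idea (not DataFusion API):
// compiling the pattern once and reusing it across batches avoids paying
// Regex::new() for every batch of input.
use regex::Regex;

fn main() {
    let batches: Vec<Vec<&str>> = vec![vec!["abc", "ba"], vec!["ab", "ca"]];

    // Per-batch compilation (roughly what a non-cached kernel does).
    for batch in &batches {
        let re = Regex::new("^a").unwrap(); // recompiled for every batch
        let _hits: Vec<bool> = batch.iter().map(|s| re.is_match(s)).collect();
    }

    // Compile once, reuse for every batch (what ScalarRegexMatchExpr caches).
    let re = Regex::new("^a").unwrap();
    for batch in &batches {
        let _hits: Vec<bool> = batch.iter().map(|s| re.is_match(s)).collect();
    }
}
```

The `compile()` call in the constructor below plays the role of the one-time `Regex::new`.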
+#[derive(Clone)] +pub struct ScalarRegexMatchExpr { + negated: bool, + case_insensitive: bool, + expr: Arc, + pattern: Arc, + compiled: Option, +} + +impl ScalarRegexMatchExpr { + pub fn new( + negated: bool, + case_insensitive: bool, + expr: Arc, + pattern: Arc, + ) -> Self { + let mut res = Self { + negated, + case_insensitive, + expr, + pattern, + compiled: None, + }; + res.compile().unwrap(); + res + } + + /// Is negated + pub fn negated(&self) -> bool { + self.negated + } + + /// Is case insensitive + pub fn case_insensitive(&self) -> bool { + self.case_insensitive + } + + /// Input expression + pub fn expr(&self) -> &Arc { + &self.expr + } + + /// Pattern expression + pub fn pattern(&self) -> &Arc { + &self.pattern + } + + /// Compile regex pattern + fn compile(&mut self) -> datafusion_common::Result<()> { + let scalar_pattern = + self.pattern + .as_any() + .downcast_ref::() + .and_then(|pattern| match pattern.value() { + ScalarValue::Null + | ScalarValue::Utf8(None) + | ScalarValue::Utf8View(None) + | ScalarValue::LargeUtf8(None) => Some(None), + ScalarValue::Utf8(Some(pattern)) + | ScalarValue::Utf8View(Some(pattern)) + | ScalarValue::LargeUtf8(Some(pattern)) => { + let mut pattern = pattern.to_string(); + if self.case_insensitive { + pattern = format!("(?i){}", pattern); + } + Some(Some(pattern)) + } + _ => None, + }); + match scalar_pattern { + Some(Some(scalar_pattern)) => Regex::new(scalar_pattern.as_str()) + .map(|compiled| { + self.compiled = Some(compiled); + }) + .map_err(|err| { + datafusion_common::DataFusionError::Internal(format!( + "Failed to compile regex: {}", + err + )) + }), + Some(None) => { + self.compiled = None; + Ok(()) + } + None => Err(datafusion_common::DataFusionError::Internal(format!( + "Regex pattern({}) isn't literal string", + self.pattern + ))), + } + } + + /// Operator name + fn op_name(&self) -> &str { + match (self.negated, self.case_insensitive) { + (false, false) => "MATCH", + (true, false) => "NOT MATCH", + (false, true) => "IMATCH", + (true, true) => "NOT IMATCH", + } + } + + /// Evaluate the scalar regex match expression match array value + fn evaluate_array( + &self, + array: &Arc, + ) -> datafusion_common::Result { + /// downcast_string_array downcast a [`ArrayRef`] to specific array type + /// example: [`StringArray`], [`LargeStringArray`], [`StringViewArray`] + macro_rules! 
downcast_string_array { + ($ARRAY:expr, $ARRAY_TYPE:ident, $ERR_MSG:expr) => { + &($ARRAY + .as_any() + .downcast_ref::<$ARRAY_TYPE>() + .expect($ERR_MSG)) + }; + } + match array.data_type() { + DataType::Null => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))) + }, + DataType::Utf8 => array_regexp_match( + downcast_string_array!(array, StringArray, "Failed to downcast StringArray"), + self.compiled.as_ref().unwrap(), + self.negated, + ), + DataType::Utf8View => array_regexp_match( + downcast_string_array!(array, StringViewArray, "Failed to downcast StringViewArray"), + self.compiled.as_ref().unwrap(), + self.negated, + ), + DataType::LargeUtf8 => array_regexp_match( + downcast_string_array!(array, LargeStringArray, "Failed to downcast LargeStringArray"), + self.compiled.as_ref().unwrap(), + self.negated, + ), + other=> datafusion_common::internal_err!( + "Data type {:?} not supported for ScalarRegexMatchExpr, expect Utf8|Utf8View|LargeUtf8", other + ), + } + } + + /// Evaluate the scalar regex match expression match scalar value + fn evaluate_scalar( + &self, + scalar: &ScalarValue, + ) -> datafusion_common::Result { + match scalar { + ScalarValue::Null + | ScalarValue::Utf8(None) + | ScalarValue::Utf8View(None) + | ScalarValue::LargeUtf8(None) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))), + ScalarValue::Utf8(Some(scalar)) + | ScalarValue::Utf8View(Some(scalar)) + | ScalarValue::LargeUtf8(Some(scalar)) => { + let mut result = self.compiled.as_ref().unwrap().is_match(scalar); + if self.negated { + result = !result; + } + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(result)))) + }, + other=> datafusion_common::internal_err!( + "Data type {:?} not supported for ScalarRegexMatchExpr, expect Utf8|Utf8View|LargeUtf8", other + ), + } + } +} + +impl Eq for ScalarRegexMatchExpr {} + +impl PartialEq for ScalarRegexMatchExpr { + fn eq(&self, other: &Self) -> bool { + self.negated.eq(&other.negated) + && self.case_insensitive.eq(&self.case_insensitive) + && self.expr.eq(&other.expr) + && self.pattern.eq(&other.pattern) + } +} + +impl Hash for ScalarRegexMatchExpr { + fn hash(&self, state: &mut H) { + self.negated.hash(state); + self.case_insensitive.hash(state); + self.expr.hash(state); + self.pattern.hash(state); + } +} + +impl Debug for ScalarRegexMatchExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + f.debug_struct("ScalarRegexMatchExpr") + .field("negated", &self.negated) + .field("case_insensitive", &self.case_insensitive) + .field("expr", &self.expr) + .field("pattern", &self.pattern) + .finish() + } +} + +impl Display for ScalarRegexMatchExpr { + fn fmt(&self, f: &mut Formatter) -> FmtResult { + write!(f, "{} {} {}", self.expr, self.op_name(), self.pattern) + } +} + +impl PhysicalExpr for ScalarRegexMatchExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _: &Schema) -> DFResult { + Ok(DataType::Boolean) + } + + fn nullable(&self, input_schema: &Schema) -> DFResult { + Ok(self.expr.nullable(input_schema)? || self.pattern.nullable(input_schema)?) 
+ } + + fn evaluate(&self, batch: &arrow_array::RecordBatch) -> DFResult { + self.expr + .evaluate(batch) + .and_then(|lhs| { + if self.compiled.is_some() { + match &lhs { + ColumnarValue::Array(array) => self.evaluate_array(array), + ColumnarValue::Scalar(scalar) => self.evaluate_scalar(scalar), + } + } else { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))) + } + }) + .and_then(|result| result.into_array(batch.num_rows())) + .map(ColumnarValue::Array) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.expr, &self.pattern] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DFResult> { + Ok(Arc::new(ScalarRegexMatchExpr::new( + self.negated, + self.case_insensitive, + Arc::clone(&children[0]), + Arc::clone(&children[1]), + ))) + } +} + +/// It is used for scalar regexp matching and copy from arrow-rs +fn array_regexp_match( + array: &dyn ArrayAccessor, + regex: &Regex, + negated: bool, +) -> DFResult { + let null_buffer = array.logical_nulls(); + let bool_buffer = if regex.as_str().is_empty() { + BooleanBuffer::new_set(array.len()) + } else { + let mut mutable_buffer = MutableBuffer::new(0); + for i in 0..array.len() { + let value = unsafe { array.value_unchecked(i) }; + if i % 8 == 0 { + mutable_buffer.push(0u8); + } + if regex.is_match(value) { + unsafe { bit_util::set_bit_raw(mutable_buffer.as_mut_ptr(), i) }; + } + } + BooleanBuffer::new(mutable_buffer.into(), 0, array.len()) + }; + + let bool_array = BooleanArray::new(bool_buffer, null_buffer); + let bool_array = if negated { + arrow::compute::kernels::boolean::not(&bool_array) + } else { + Ok(bool_array) + }; + + bool_array + .map_err(|err| { + datafusion_common::DataFusionError::Execution(format!( + "Failed to evaluate regex: {}", + err + )) + }) + .map(|bool_array| ColumnarValue::Array(Arc::new(bool_array))) +} + +/// Create a scalar regex match expression, erroring if the argument types are not compatible. 
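A note on `array_regexp_match` above: ignoring the empty-pattern fast path and the manual bit packing it adapts from arrow-rs for speed, it computes the same thing as this naive per-row sketch (assumes the `arrow_array` and `regex` crates). Nulls stay null, and `negated` flips the match result. The `scalar_regex_match` constructor below then validates the argument types before building the expression.

```rust
// Simplified sketch of what array_regexp_match computes (assumes the
// arrow_array and regex crates; ignores the empty-pattern fast path and the
// manual bit packing the real kernel uses for speed).
use arrow_array::{BooleanArray, StringArray};
use regex::Regex;

fn naive_regexp_match(array: &StringArray, regex: &Regex, negated: bool) -> BooleanArray {
    array
        .iter()
        .map(|v| v.map(|s| regex.is_match(s) != negated)) // None stays None (null)
        .collect()
}

fn main() {
    let input = StringArray::from(vec![Some("abc"), None, Some("bbb")]);
    let re = Regex::new("^a").unwrap();
    let out = naive_regexp_match(&input, &re, false);
    assert_eq!(out, BooleanArray::from(vec![Some(true), None, Some(false)]));
}
```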
+pub fn scalar_regex_match( + negated: bool, + case_insensitive: bool, + expr: Arc, + pattern: Arc, + input_schema: &Schema, +) -> DFResult> { + let valid_data_type = |data_type: &DataType| { + if !matches!( + data_type, + DataType::Null | DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 + ) { + return datafusion_common::internal_err!( + "The type {data_type} not supported for scalar_regex_match, expect Null|Utf8|Utf8View|LargeUtf8" + ); + } + Ok(()) + }; + + for arg_expr in [&expr, &pattern] { + arg_expr + .data_type(input_schema) + .and_then(|data_type| valid_data_type(&data_type))?; + } + + Ok(Arc::new(ScalarRegexMatchExpr::new( + negated, + case_insensitive, + expr, + pattern, + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::{col, lit}; + use arrow::record_batch::RecordBatch; + use arrow_array::ArrayRef; + use arrow_array::NullArray; + use arrow_schema::Field; + use arrow_schema::Schema; + use rstest::rstest; + use std::sync::Arc; + + fn test_schema(typ: DataType) -> Schema { + Schema::new(vec![Field::new("c1", typ, false)]) + } + + #[rstest( + negated, case_insensitive, typ, a_vec, b_lit, c_vec, + case( + false, false, DataType::Utf8, + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), + "^a", + Arc::new(BooleanArray::from(vec![true, false, false, false, false, true, false, false, false, false])), + ), + case( + false, true, DataType::Utf8, + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), + "^a", + Arc::new(BooleanArray::from(vec![true, false, true, false, false, true, false, true, false, false])), + ), + case( + true, false, DataType::Utf8, + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), + "^a", + Arc::new(BooleanArray::from(vec![false, true, true, true, true, false, true, true, true, true])), + ), + case( + true, true, DataType::Utf8, + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), + "^a", + Arc::new(BooleanArray::from(vec![false, true, false, true, true, false, true, false, true, true])), + ), + case( + true, true, DataType::Utf8, + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8(None), + Arc::new(BooleanArray::from(vec![None, None, None, None, None, None, None, None, None, None])), + ), + case( + false, false, DataType::LargeUtf8, + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![true, false, false, false, false, true, false, false, false, false])), + ), + case( + false, true, DataType::LargeUtf8, + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![true, false, true, false, false, true, false, true, false, false])), + ), + case( + true, false, DataType::LargeUtf8, + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![false, true, true, true, true, false, true, true, true, true])), + ), + case( + true, true, DataType::LargeUtf8, + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", 
"cba"])), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![false, true, false, true, true, false, true, false, true, true])), + ), + case( + true, true, DataType::LargeUtf8, + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::LargeUtf8(None), + Arc::new(BooleanArray::from(vec![None, None, None, None, None, None, None, None, None, None])), + ), + case( + false, false, DataType::Utf8View, + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8View(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![true, false, false, false, false, true, false, false, false, false])), + ), + case( + false, true, DataType::Utf8View, + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8View(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![true, false, true, false, false, true, false, true, false, false])), + ), + case( + true, false, DataType::Utf8View, + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8View(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![false, true, true, true, true, false, true, true, true, true])), + ), + case( + true, true, DataType::Utf8View, + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8View(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![false, true, false, true, true, false, true, false, true, true])), + ), + case( + true, true, DataType::Utf8View, + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba", "abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8View(None), + Arc::new(BooleanArray::from(vec![None, None, None, None, None, None, None, None, None, None])), + ), + case( + true, true, DataType::Null, + Arc::new(NullArray::new(10)), + ScalarValue::Utf8View(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![None, None, None, None, None, None, None, None, None, None])), + ), + )] + fn test_scalar_regex_match_array( + negated: bool, + case_insensitive: bool, + typ: DataType, + a_vec: ArrayRef, + b_lit: impl datafusion_expr::Literal, + c_vec: ArrayRef, + ) { + let schema = test_schema(typ); + let left = col("c1", &schema).unwrap(); + let right = lit(b_lit); + + // verify that we can construct the expression + let expression = + scalar_regex_match(negated, case_insensitive, left, right, &schema).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a_vec]).unwrap(); + + // verify that the expression's type is correct + assert_eq!(expression.data_type(&schema).unwrap(), DataType::Boolean); + + // compute + let result = expression + .evaluate(&batch) + .expect("Error evaluating expression"); + + if let ColumnarValue::Array(array) = result { + let array = array + .as_any() + .downcast_ref::() + .expect("failed to downcast to BooleanArray"); + + let c_vec = c_vec + .as_any() + .downcast_ref::() + .expect("failed to downcast to BooleanArray"); + // verify that the result is correct + assert_eq!(array, c_vec); + } else { + panic!("result was not an array"); + } + } + + #[rstest( + negated, case_insensitive, typ, a_lit, b_lit, flag, + case( + false, false, DataType::Utf8, "abc", "^a", Some(true), + ), + case( + false, true, DataType::Utf8, "Abc", "^a", Some(true), + ), + case( + true, false, 
DataType::Utf8, "abc", "^a", Some(false), + ), + case( + true, true, DataType::Utf8, "Abc", "^a", Some(false), + ), + case( + true, true, DataType::Utf8, + ScalarValue::Utf8(Some("Abc".to_string())), + ScalarValue::Utf8(None), + None, + ), + case( + false, false, DataType::Utf8, + ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Some(true), + ), + case( + false, true, DataType::Utf8, + ScalarValue::Utf8(Some("Abc".to_string())), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Some(true), + ), + case( + true, false, DataType::Utf8, + ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Some(false), + ), + case( + true, true, DataType::Utf8, + ScalarValue::Utf8(Some("Abc".to_string())), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Some(false), + ), + case( + true, true, DataType::Utf8, + ScalarValue::Utf8(Some("Abc".to_string())), + ScalarValue::LargeUtf8(None), + None, + ), + )] + fn test_scalar_regex_match_scalar( + negated: bool, + case_insensitive: bool, + typ: DataType, + a_lit: impl datafusion_expr::Literal, + b_lit: impl datafusion_expr::Literal, + flag: Option, + ) { + let left = lit(a_lit); + let right = lit(b_lit); + let schema = test_schema(typ); + let expression = + scalar_regex_match(negated, case_insensitive, left, right, &schema).unwrap(); + let num_rows: usize = 3; + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(StringArray::from([""].repeat(num_rows)))], + ) + .unwrap(); + + // verify that the expression's type is correct + assert_eq!(expression.data_type(&schema).unwrap(), DataType::Boolean); + + // compute + let result = expression + .evaluate(&batch) + .expect("Error evaluating expression"); + + if let ColumnarValue::Array(array) = result { + let array = array + .as_any() + .downcast_ref::() + .expect("failed to downcast to BooleanArray"); + + // verify that the result is correct + let c_vec = [flag].repeat(batch.num_rows()); + assert_eq!(array, &BooleanArray::from(c_vec)); + } else { + panic!("result was not an array"); + } + } + + #[rstest( + expr, pattern, + case( + col("c1", &test_schema(DataType::Utf8)).unwrap(), + lit(1), + ), + case( + lit(1), + col("c1", &test_schema(DataType::Utf8)).unwrap(), + ), + )] + #[should_panic] + fn test_scalar_regex_match_panic( + expr: Arc, + pattern: Arc, + ) { + let _ = + scalar_regex_match(false, false, expr, pattern, &test_schema(DataType::Utf8)) + .unwrap(); + } + + #[rstest( + pattern, + case(col("c1", &test_schema(DataType::Utf8)).unwrap()), // not literal + case(lit(1)), // not literal string + case(lit("\\x{202e")), // wrong regex pattern + )] + #[should_panic] + fn test_scalar_regex_match_compile_error(pattern: Arc) { + let _ = ScalarRegexMatchExpr::new(false, false, lit("a"), pattern); + } +} diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 906ca9fd1093..9afb3cc36bb1 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,6 +17,7 @@ use std::sync::Arc; +use crate::expressions::scalar_regex_match; use crate::scalar_function; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, @@ -191,7 +192,32 @@ pub fn create_physical_expr( // // There should be no coercion during physical // planning. 
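The planner hunk below is where the new expression gets wired in: the four regex operators are redirected to `scalar_regex_match` only when the right-hand side is a string (or null) literal; any other pattern expression keeps the existing `binary` path. A hedged sketch of that predicate in isolation (`uses_scalar_regex_match` is a hypothetical helper, not part of the patch):

```rust
// Hypothetical helper mirroring the dispatch rule in the planner hunk below:
// only ~ / !~ / ~* / !~* with a string (or null) literal pattern take the
// cached ScalarRegexMatchExpr path.
use datafusion_common::ScalarValue;
use datafusion_expr::{lit, Expr, Operator};

fn uses_scalar_regex_match(op: &Operator, rhs: &Expr) -> bool {
    let literal_pattern = matches!(
        rhs,
        Expr::Literal(
            ScalarValue::Null
                | ScalarValue::Utf8(_)
                | ScalarValue::Utf8View(_)
                | ScalarValue::LargeUtf8(_)
        )
    );
    let regex_op = matches!(
        op,
        Operator::RegexMatch
            | Operator::RegexNotMatch
            | Operator::RegexIMatch
            | Operator::RegexNotIMatch
    );
    literal_pattern && regex_op
}

fn main() {
    // `c1 ~ '^a'` style: literal pattern, takes the new path.
    assert!(uses_scalar_regex_match(&Operator::RegexMatch, &lit("^a")));
    // Non-literal pattern (another column): stays on the BinaryExpr path.
    let column = Expr::Column(datafusion_common::Column::from_name("c2"));
    assert!(!uses_scalar_regex_match(&Operator::RegexMatch, &column));
}
```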
- binary(lhs, *op, rhs, input_schema) + if let Expr::Literal( + ScalarValue::Null + | ScalarValue::Utf8(_) + | ScalarValue::Utf8View(_) + | ScalarValue::LargeUtf8(_), + ) = right.as_ref() + { + // handle literal regexp pattern case to `ScalarRegexMatchExpr` + match *op { + Operator::RegexMatch => { + scalar_regex_match(false, false, lhs, rhs, input_schema) + } + Operator::RegexNotMatch => { + scalar_regex_match(true, false, lhs, rhs, input_schema) + } + Operator::RegexIMatch => { + scalar_regex_match(false, true, lhs, rhs, input_schema) + } + Operator::RegexNotIMatch => { + scalar_regex_match(true, true, lhs, rhs, input_schema) + } + _ => binary(lhs, *op, rhs, input_schema), + } + } else { + binary(lhs, *op, rhs, input_schema) + } } Expr::Like(Like { negated, @@ -232,6 +258,23 @@ pub fn create_physical_expr( create_physical_expr(expr, input_dfschema, execution_props)?; let physical_pattern = create_physical_expr(pattern, input_dfschema, execution_props)?; + + if let Expr::Literal( + ScalarValue::Null + | ScalarValue::Utf8(_) + | ScalarValue::Utf8View(_) + | ScalarValue::LargeUtf8(_), + ) = pattern.as_ref() + { + // handle literal regexp pattern case to `ScalarRegexMatchExpr` + return scalar_regex_match( + *negated, + *case_insensitive, + physical_expr, + physical_pattern, + input_schema, + ); + } similar_to(*negated, *case_insensitive, physical_expr, physical_pattern) } Expr::Case(case) => { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 416227f70de9..509dadf3335b 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -830,6 +830,8 @@ message PhysicalExprNode { PhysicalExtensionExprNode extension = 19; UnknownColumn unknown_column = 20; + + PhysicalScalarRegexMatchExprNode scalar_regex_match_expr = 21; } } @@ -944,6 +946,13 @@ message PhysicalExtensionExprNode { repeated PhysicalExprNode inputs = 2; } +message PhysicalScalarRegexMatchExprNode { + bool negated = 1; + bool case_insensitive = 2; + PhysicalExprNode expr = 3; + PhysicalExprNode pattern = 4; +} + message FilterExecNode { PhysicalPlanNode input = 1; PhysicalExprNode expr = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index cffb63018676..3b4813e2cb97 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -13994,6 +13994,9 @@ impl serde::Serialize for PhysicalExprNode { physical_expr_node::ExprType::UnknownColumn(v) => { struct_ser.serialize_field("unknownColumn", v)?; } + physical_expr_node::ExprType::ScalarRegexMatchExpr(v) => { + struct_ser.serialize_field("scalarRegexMatchExpr", v)?; + } } } struct_ser.end() @@ -14036,6 +14039,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "extension", "unknown_column", "unknownColumn", + "scalar_regex_match_expr", + "scalarRegexMatchExpr", ]; #[allow(clippy::enum_variant_names)] @@ -14058,6 +14063,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { LikeExpr, Extension, UnknownColumn, + ScalarRegexMatchExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -14097,6 +14103,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "likeExpr" | "like_expr" => Ok(GeneratedField::LikeExpr), "extension" => Ok(GeneratedField::Extension), "unknownColumn" | "unknown_column" => Ok(GeneratedField::UnknownColumn), + "scalarRegexMatchExpr" | "scalar_regex_match_expr" => Ok(GeneratedField::ScalarRegexMatchExpr), _ 
=> Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -14243,6 +14250,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { return Err(serde::de::Error::duplicate_field("unknownColumn")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::UnknownColumn) +; + } + GeneratedField::ScalarRegexMatchExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarRegexMatchExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::ScalarRegexMatchExpr) ; } } @@ -15700,6 +15714,149 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { deserializer.deserialize_struct("datafusion.PhysicalPlanNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PhysicalScalarRegexMatchExprNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.negated { + len += 1; + } + if self.case_insensitive { + len += 1; + } + if self.expr.is_some() { + len += 1; + } + if self.pattern.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalScalarRegexMatchExprNode", len)?; + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; + } + if self.case_insensitive { + struct_ser.serialize_field("caseInsensitive", &self.case_insensitive)?; + } + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + if let Some(v) = self.pattern.as_ref() { + struct_ser.serialize_field("pattern", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PhysicalScalarRegexMatchExprNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "negated", + "case_insensitive", + "caseInsensitive", + "expr", + "pattern", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Negated, + CaseInsensitive, + Expr, + Pattern, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "negated" => Ok(GeneratedField::Negated), + "caseInsensitive" | "case_insensitive" => Ok(GeneratedField::CaseInsensitive), + "expr" => Ok(GeneratedField::Expr), + "pattern" => Ok(GeneratedField::Pattern), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PhysicalScalarRegexMatchExprNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PhysicalScalarRegexMatchExprNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut negated__ = None; + let mut case_insensitive__ = None; + let mut expr__ = None; + let mut pattern__ = None; + while let Some(k) = 
map_.next_key()? { + match k { + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); + } + negated__ = Some(map_.next_value()?); + } + GeneratedField::CaseInsensitive => { + if case_insensitive__.is_some() { + return Err(serde::de::Error::duplicate_field("caseInsensitive")); + } + case_insensitive__ = Some(map_.next_value()?); + } + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = map_.next_value()?; + } + GeneratedField::Pattern => { + if pattern__.is_some() { + return Err(serde::de::Error::duplicate_field("pattern")); + } + pattern__ = map_.next_value()?; + } + } + } + Ok(PhysicalScalarRegexMatchExprNode { + negated: negated__.unwrap_or_default(), + case_insensitive: case_insensitive__.unwrap_or_default(), + expr: expr__, + pattern: pattern__, + }) + } + } + deserializer.deserialize_struct("datafusion.PhysicalScalarRegexMatchExprNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PhysicalScalarUdfNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d2fda5dc8892..894c35186b22 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1165,7 +1165,7 @@ pub struct PhysicalExtensionNode { pub struct PhysicalExprNode { #[prost( oneof = "physical_expr_node::ExprType", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21" )] pub expr_type: ::core::option::Option, } @@ -1216,6 +1216,10 @@ pub mod physical_expr_node { Extension(super::PhysicalExtensionExprNode), #[prost(message, tag = "20")] UnknownColumn(super::UnknownColumn), + #[prost(message, tag = "21")] + ScalarRegexMatchExpr( + ::prost::alloc::boxed::Box, + ), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1398,6 +1402,17 @@ pub struct PhysicalExtensionExprNode { pub inputs: ::prost::alloc::vec::Vec, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct PhysicalScalarRegexMatchExprNode { + #[prost(bool, tag = "1")] + pub negated: bool, + #[prost(bool, tag = "2")] + pub case_insensitive: bool, + #[prost(message, optional, boxed, tag = "3")] + pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "4")] + pub pattern: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct FilterExecNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index d1fe48cfec74..d383dd3f6dea 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -38,7 +38,7 @@ use datafusion::logical_expr::WindowFunctionDefinition; use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, - Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, + Literal, NegativeExpr, NotExpr, ScalarRegexMatchExpr, TryCastExpr, UnKnownColumn, }; use datafusion::physical_plan::windows::{create_window_expr, schema_add_window_field}; use 
datafusion::physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; @@ -394,6 +394,26 @@ pub fn parse_physical_expr( .collect::>()?; (codec.try_decode_expr(extension.expr.as_slice(), &inputs)?) as _ } + ExprType::ScalarRegexMatchExpr(scalar_match_expr) => { + Arc::new(ScalarRegexMatchExpr::new( + scalar_match_expr.negated, + scalar_match_expr.case_insensitive, + parse_required_physical_expr( + scalar_match_expr.expr.as_deref(), + registry, + "expr", + input_schema, + codec, + )?, + parse_required_physical_expr( + scalar_match_expr.pattern.as_deref(), + registry, + "pattern", + input_schema, + codec, + )?, + )) + } }; Ok(pexpr) diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 000d0521def3..140321b8b74d 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,7 +23,7 @@ use datafusion::physical_expr::window::{SlidingAggregateWindowExpr, StandardWind use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, - Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, + Literal, NegativeExpr, NotExpr, ScalarRegexMatchExpr, TryCastExpr, UnKnownColumn, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{PlainAggregateWindowExpr, WindowUDFExpr}; @@ -365,6 +365,25 @@ pub fn serialize_physical_expr( }, ))), }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some( + protobuf::physical_expr_node::ExprType::ScalarRegexMatchExpr(Box::new( + protobuf::PhysicalScalarRegexMatchExprNode { + negated: expr.negated(), + case_insensitive: expr.case_insensitive(), + expr: Some(Box::new(serialize_physical_expr( + expr.expr(), + codec, + )?)), + pattern: Some(Box::new(serialize_physical_expr( + expr.pattern(), + codec, + )?)), + }, + )), + ), + }) } else { let mut buf: Vec = vec![]; match codec.try_encode_expr(value, &mut buf) {
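The remaining hunks are the proto plumbing: a new `PhysicalScalarRegexMatchExprNode` message (tag 21 in `PhysicalExprNode`), its pbjson serde, and the from/to-proto conversions. A hedged sketch of the new message round-tripping at the prost level (this assumes the generated type is exported under `datafusion_proto::protobuf`, as the other generated messages are):

```rust
// Hedged sketch: the new message round-trips through prost like any other
// PhysicalExprNode variant. The expr/pattern fields are left as None here;
// in the real code path they carry serialized child PhysicalExprNodes.
use datafusion_proto::protobuf::PhysicalScalarRegexMatchExprNode;
use prost::Message;

fn main() {
    let node = PhysicalScalarRegexMatchExprNode {
        negated: false,
        case_insensitive: true,
        expr: None,    // normally Some(Box<PhysicalExprNode>) for the column
        pattern: None, // normally Some(Box<PhysicalExprNode>) for the literal
    };
    let bytes = node.encode_to_vec();
    let decoded = PhysicalScalarRegexMatchExprNode::decode(bytes.as_slice()).unwrap();
    assert_eq!(node, decoded);
}
```

In the actual code path this happens inside `serialize_physical_expr` and `parse_physical_expr`, which the last two hunks extend.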