diff --git a/crates/df-builtins/src/insert.rs b/crates/df-builtins/src/insert.rs new file mode 100644 index 00000000..1bdaff68 --- /dev/null +++ b/crates/df-builtins/src/insert.rs @@ -0,0 +1,473 @@ +use datafusion::arrow::array::{BinaryBuilder, StringBuilder}; +use datafusion::arrow::datatypes::DataType; +use datafusion::error::Result as DFResult; +use datafusion_common::cast::{as_binary_array, as_int64_array, as_string_array}; +use datafusion_common::{exec_err, plan_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use std::any::Any; +use std::sync::Arc; + +// insert SQL function +// Replaces a substring of the specified length, starting at the specified position, with a +// new string or binary value. +// This function should not be confused with the INSERT DML command. +// Syntax: INSERT( , , , ) +// Arguments: +// +// The string or BINARY expression for which you want to insert/replace characters. +// +// The offset at which to start inserting characters. This is 1-based, not 0-based. +// In other words, the first character in the string is considered to be at position 1, +// not position 0. For example, to insert at the beginning of the string, set to 1. +// Valid values are between 1 and one more than the length of the string (inclusive). +// Setting to one more than the length of the string makes the operation equivalent +// to an append. (This also requires that the parameter be 0 because you should not +// try to delete any characters past the last character.) +// +// The number of characters (starting at ) that you want to replace. Valid values range +// from 0 to the number of characters between and the end of the string. If this is 0, +// it means add the new characters without deleting any existing characters. +// +// The string to insert into the base_expr. If this string is empty, and if len is greater than +// zero, then effectively the operation becomes a delete (some characters are deleted, and none are added). +// Usage notes: +// The and should be the same data type; either both should be string +// (e.g. VARCHAR) or both should be binary. +// If any of the arguments are NULL, the returned value is NULL. +// Note: `insert` returns +// Returns a string or BINARY that is equivalent to making a copy of , deleting characters +// starting at , and then inserting at . +// Note that the original input base_expr is not changed; the function returns a separate (modified) copy. +#[derive(Debug)] +pub struct Insert { + signature: Signature, +} + +impl Default for Insert { + fn default() -> Self { + Self::new() + } +} + +impl Insert { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Int64, + DataType::Int64, + DataType::Utf8, + ]), + TypeSignature::Exact(vec![ + DataType::Binary, + DataType::Int64, + DataType::Int64, + DataType::Binary, + ]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for Insert { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &'static str { + "insert" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> DFResult { + match (arg_types[0].clone(), arg_types[3].clone()) { + (DataType::Utf8, DataType::Utf8) => Ok(DataType::Utf8), + (DataType::Binary, DataType::Binary) => Ok(DataType::Binary), + _ => plan_err!("Base expression and insert expression should be the same data type"), + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { + let ScalarFunctionArgs { args, .. } = args; + + let base_arr = match &args[0] { + ColumnarValue::Array(arr) => arr, + ColumnarValue::Scalar(v) => &v.to_array()?, + }; + + let pos_arr = match &args[1] { + ColumnarValue::Array(arr) => arr, + ColumnarValue::Scalar(v) => &v.to_array()?, + }; + + let len_arr = match &args[2] { + ColumnarValue::Array(arr) => arr, + ColumnarValue::Scalar(v) => &v.to_array()?, + }; + + let insert_arr = match &args[3] { + ColumnarValue::Array(arr) => arr, + ColumnarValue::Scalar(v) => &v.to_array()?, + }; + + let res_arr: Arc = match base_arr.data_type() { + DataType::Utf8 => { + let base_strs = as_string_array(&base_arr)?; + let pos_i64s = as_int64_array(&pos_arr)?; + let len_i64s = as_int64_array(&len_arr)?; + let insert_strs = as_string_array(&insert_arr)?; + + let mut res = StringBuilder::new(); + + let zipped = super::macros::izip!( + base_strs.iter(), + pos_i64s.iter(), + len_i64s.iter(), + insert_strs.iter() + ); + + for (base, pos, len, insert) in zipped { + if base.is_none() || pos.is_none() || len.is_none() || insert.is_none() { + res.append_null(); + continue; + } + + let inserted_string = insert_string(base, pos, len, insert); + match inserted_string { + Ok(v) => res.append_value(v), + Err(e) => return Err(e), + }; + } + Arc::new(res.finish()) + } + DataType::Binary => { + let base_bin = as_binary_array(&base_arr)?; + let pos_i64s = as_int64_array(&pos_arr)?; + let len_i64s = as_int64_array(&len_arr)?; + let insert_bin = as_binary_array(&insert_arr)?; + + let mut res = BinaryBuilder::new(); + + let zipped = super::macros::izip!( + base_bin.iter(), + pos_i64s.iter(), + len_i64s.iter(), + insert_bin.iter() + ); + + for (base, pos, len, insert) in zipped { + if base.is_none() || pos.is_none() || len.is_none() || insert.is_none() { + res.append_null(); + continue; + } + + let inserted_binary = insert_binary(base, pos, len, insert); + match inserted_binary { + Ok(v) => res.append_value(v), + Err(e) => return Err(e), + }; + } + + Arc::new(res.finish()) + } + _ => { + return exec_err!("Invalid datatype"); + } + }; + + Ok(ColumnarValue::Array(res_arr)) + } +} + +#[allow(clippy::unwrap_used)] +fn insert_string( + base: Option<&str>, + pos: Option, + len: Option, + insert: Option<&str>, +) -> Result { + let base = base.unwrap(); + let base_len: i64 = base.len().try_into().unwrap(); + let pos = pos.unwrap(); + let len = len.unwrap(); + let insert = insert.unwrap(); + + if pos < 1 || pos > base_len + 1 { + return exec_err!( + "Valid values for position are between 1 and one more than length of string" + ); + } + + if len < 0 || len > (base_len - pos + 1) { + return exec_err!( + "Valid values for length range from 0 to the number of characters from position to end of string" + ); + } + + let mut chs = base.chars(); + let left_part = chs + .by_ref() + .take((pos - 1).try_into().unwrap()) + .collect::(); + let right_part = chs.skip(len.try_into().unwrap()).collect::(); + + Ok(format!("{left_part}{insert}{right_part}")) +} + +#[allow(clippy::unwrap_used)] +fn insert_binary( + base: Option<&[u8]>, + pos: Option, + len: Option, + insert: Option<&[u8]>, +) -> Result, datafusion_common::DataFusionError> { + let base = base.unwrap(); + let pos = pos.unwrap(); + let len = len.unwrap(); + let insert = insert.unwrap(); + + let base_len: i64 = base.len().try_into().unwrap(); + + if pos < 1 || pos > base_len + 1 { + return exec_err!( + "Valid values for position are between 1 and one more than length of binary" + ); + } + + if len < 0 || len > (base_len - pos + 1) { + return exec_err!( + "Valid values for length range from 0 to the number of bytes from position to end of binary" + ); + } + + let pos: usize = (pos - 1).try_into().unwrap(); + let len: usize = len.try_into().unwrap(); + + let mut result = vec![]; + result.extend_from_slice(&base[..pos]); + result.extend_from_slice(insert); + result.extend_from_slice(&base[pos + len..]); + + Ok(result) +} + +super::macros::make_udf_function!(Insert); + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::prelude::SessionContext; + use datafusion_common::{DataFusionError, assert_batches_eq}; + use datafusion_expr::ScalarUDF; + + #[tokio::test] + async fn test_it_works() -> DFResult<()> { + let ctx = SessionContext::new(); + ctx.register_udf(ScalarUDF::from(Insert::new())); + + let q = "SELECT INSERT('abc', 1, 2, 'Z') as STR;"; + let result = ctx.sql(q).await?.collect().await?; + + assert_batches_eq!( + &["+-----+", "| str |", "+-----+", "| Zc |", "+-----+",], + &result + ); + + let q = "SELECT INSERT('abcdef', 3, 2, 'zzz') as STR;"; + let result = ctx.sql(q).await?.collect().await?; + + assert_batches_eq!( + &[ + "+---------+", + "| str |", + "+---------+", + "| abzzzef |", + "+---------+", + ], + &result + ); + + let q = "SELECT INSERT('abc', 2, 1, '') as STR;"; + let result = ctx.sql(q).await?.collect().await?; + + assert_batches_eq!( + &["+-----+", "| str |", "+-----+", "| ac |", "+-----+",], + &result + ); + + let q = "SELECT INSERT('abc', 4, 0, 'Z') as STR;"; + let result = ctx.sql(q).await?.collect().await?; + + assert_batches_eq!( + &["+------+", "| str |", "+------+", "| abcZ |", "+------+",], + &result + ); + + let q = "SELECT INSERT(NULL, 1, 2, 'Z') as STR;"; + let result = ctx.sql(q).await?.collect().await?; + + assert_batches_eq!( + &["+-----+", "| str |", "+-----+", "| |", "+-----+",], + &result + ); + + let q = "SELECT INSERT('abc', NULL, 2, 'Z') as STR;"; + let result = ctx.sql(q).await?.collect().await?; + + assert_batches_eq!( + &["+-----+", "| str |", "+-----+", "| |", "+-----+",], + &result + ); + + let q = "SELECT INSERT('abc', 1, NULL, 'Z') as STR;"; + let result = ctx.sql(q).await?.collect().await?; + + assert_batches_eq!( + &["+-----+", "| str |", "+-----+", "| |", "+-----+",], + &result + ); + + let q = "SELECT INSERT('abc', 1, 2, NULL) as STR;"; + let result = ctx.sql(q).await?.collect().await?; + + assert_batches_eq!( + &["+-----+", "| str |", "+-----+", "| |", "+-----+",], + &result + ); + + // Binary data test + let q = "SELECT INSERT(X'1234567890ABCDEF', 4, 2, X'FFFF') as STR;"; + let result = ctx.sql(q).await?.collect().await?; + + assert_batches_eq!( + &[ + "+------------------+", + "| str |", + "+------------------+", + "| 123456ffffabcdef |", + "+------------------+", + ], + &result + ); + + Ok(()) + } + + #[tokio::test] + async fn test_invalid_pos_fails() { + let ctx = SessionContext::new(); + ctx.register_udf(ScalarUDF::from(Insert::new())); + + // pos less than 1 + let q = "SELECT INSERT('abc', 0, 2, 'Z') as STR;"; + let result = ctx.sql(q).await; + + if let Ok(df) = result { + let result = df.collect().await; + + match result { + Err(e) => { + assert!( + matches!(e, DataFusionError::Execution(_)), + "Expected Execution error, got: {e}", + ); + } + Ok(_) => panic!("Expected error but query succeeded"), + } + } + + // pos beyond one more than the length of the string + let q = "SELECT INSERT('abc', 5, 2, 'Z') as STR;"; + let result = ctx.sql(q).await; + + if let Ok(df) = result { + let result = df.collect().await; + + match result { + Err(e) => { + assert!( + matches!(e, DataFusionError::Execution(_)), + "Expected Execution error, got: {e}", + ); + } + Ok(_) => panic!("Expected error but query succeeded"), + } + } + + // Setting pos to one more than the length of the string + // means that len must be 0 + let q = "SELECT INSERT('abc', 4, 1, 'Z') as STR;"; + let result = ctx.sql(q).await; + + if let Ok(df) = result { + let result = df.collect().await; + + match result { + Err(e) => { + assert!( + matches!(e, DataFusionError::Execution(_)), + "Expected Execution error, got: {e}", + ); + } + Ok(_) => panic!("Expected error but query succeeded"), + } + } + } + + #[tokio::test] + async fn test_invalid_len_fails() { + let ctx = SessionContext::new(); + ctx.register_udf(ScalarUDF::from(Insert::new())); + + // Length is outside valid range from 0 to number of characters + // between pos and end of the string + let q = "SELECT INSERT('abc', 1, 50, 'Z') as STR;"; + let result = ctx.sql(q).await; + + if let Ok(df) = result { + let result = df.collect().await; + + match result { + Err(e) => { + assert!( + matches!(e, DataFusionError::Execution(_)), + "Expected Execution error, got: {e}", + ); + } + Ok(_) => panic!("Expected error but query succeeded"), + } + } + } + + #[tokio::test] + async fn test_mixed_arg_types_fails() { + let ctx = SessionContext::new(); + ctx.register_udf(ScalarUDF::from(Insert::new())); + + let q = "SELECT INSERT(X'1234567890ABCDEF', 4, 2, 'abc') as STR;"; + let result = ctx.sql(q).await; + + if let Ok(df) = result { + let result = df.collect().await; + + match result { + Err(e) => { + assert!( + matches!(e, DataFusionError::ArrowError(..)), + "Expected Arrow error, got: {e}", + ); + } + Ok(_) => panic!("Expected error but query succeeded"), + } + } + } +} diff --git a/crates/df-builtins/src/lib.rs b/crates/df-builtins/src/lib.rs index 0720589d..705bea17 100644 --- a/crates/df-builtins/src/lib.rs +++ b/crates/df-builtins/src/lib.rs @@ -9,6 +9,8 @@ use datafusion::arrow::array::{ use datafusion::arrow::datatypes::DataType; use datafusion::{common::Result, execution::FunctionRegistry, logical_expr::ScalarUDF}; use datafusion_common::DataFusionError; +#[doc(hidden)] +pub use std::iter as __std_iter; use std::sync::Arc; pub(crate) mod aggregate; @@ -24,6 +26,7 @@ mod boolor; mod boolxor; mod equal_null; mod iff; +mod insert; mod is_array; mod is_object; mod nullifzero; @@ -53,6 +56,7 @@ pub fn register_udfs(registry: &mut dyn FunctionRegistry) -> Result<()> { is_array::get_udf(), array_flatten::get_udf(), array_to_string::get_udf(), + insert::get_udf(), Arc::new(ScalarUDF::from(ToBooleanFunc::new(false))), Arc::new(ScalarUDF::from(ToBooleanFunc::new(true))), Arc::new(ScalarUDF::from(ToTimeFunc::new(false))), @@ -67,6 +71,48 @@ pub fn register_udfs(registry: &mut dyn FunctionRegistry) -> Result<()> { } mod macros { + // Adopted from itertools: https://docs.rs/itertools/latest/src/itertools/lib.rs.html#321-360 + macro_rules! izip { + // @closure creates a tuple-flattening closure for .map() call. usage: + // @closure partial_pattern => partial_tuple , rest , of , iterators + // eg. izip!( @closure ((a, b), c) => (a, b, c) , dd , ee ) + ( @closure $p:pat => $tup:expr ) => { + |$p| $tup + }; + + // The "b" identifier is a different identifier on each recursion level thanks to hygiene. + ( @closure $p:pat => ( $($tup:tt)* ) , $_iter:expr $( , $tail:expr )* ) => { + $crate::macros::izip!(@closure ($p, b) => ( $($tup)*, b ) $( , $tail )*) + }; + + // unary + ($first:expr $(,)*) => { + $crate::__std_iter::IntoIterator::into_iter($first) + }; + + // binary + ($first:expr, $second:expr $(,)*) => { + $crate::__std_iter::Iterator::zip( + $crate::__std_iter::IntoIterator::into_iter($first), + $second, + ) + }; + + // n-ary where n > 2 + ( $first:expr $( , $rest:expr )* $(,)* ) => { + { + let iter = $crate::__std_iter::IntoIterator::into_iter($first); + $( + let iter = $crate::__std_iter::Iterator::zip(iter, $rest); + )* + $crate::__std_iter::Iterator::map( + iter, + $crate::macros::izip!(@closure a => (a) $( , $rest )*) + ) + } + }; + } + macro_rules! make_udf_function { ($udf_type:ty) => { paste::paste! { @@ -86,6 +132,7 @@ mod macros { } } + pub(crate) use izip; pub(crate) use make_udf_function; }