From 18ab0d4c4e2b3087c34d0df957e9eed6e5a762c5 Mon Sep 17 00:00:00 2001
From: Maxim Bogdanov
Date: Fri, 16 May 2025 17:35:00 +0200
Subject: [PATCH 1/2] wip

---
 crates/core-executor/src/session.rs            |    2 +
 crates/df-builtins/Cargo.toml                  |    1 +
 crates/df-builtins/src/aggregate/mod.rs        |    1 +
 .../df-builtins/src/aggregate/object_agg.rs    |  539 +++++++
 crates/df-builtins/src/lib.rs                  |   49 +-
 crates/df-builtins/src/table/mod.rs            |    2 +-
 .../df-builtins/src/variant/array_append.rs    |  252 ++++
 crates/df-builtins/src/variant/array_cat.rs    |  248 ++++
 .../df-builtins/src/variant/array_compact.rs   |  190 +++
 .../src/variant/array_construct.rs             |  207 +++
 .../df-builtins/src/variant/array_contains.rs  |  262 ++++
 .../df-builtins/src/variant/array_distinct.rs  |  194 +++
 .../df-builtins/src/variant/array_except.rs    |  269 ++++
 .../src/{ => variant}/array_flatten.rs         |    3 +-
 .../src/variant/array_generate_range.rs        |  210 +++
 .../df-builtins/src/variant/array_insert.rs    |  283 ++++
 .../src/variant/array_intersection.rs          |  246 ++++
 crates/df-builtins/src/variant/array_max.rs    |  249 ++++
 crates/df-builtins/src/variant/array_min.rs    |  249 ++++
 .../df-builtins/src/variant/array_position.rs  |  216 +++
 .../df-builtins/src/variant/array_prepend.rs   |  207 +++
 .../df-builtins/src/variant/array_remove.rs    |  275 ++++
 .../src/variant/array_remove_at.rs             |  292 ++++
 .../df-builtins/src/variant/array_reverse.rs   |  225 +++
 crates/df-builtins/src/variant/array_size.rs   |  189 +++
 crates/df-builtins/src/variant/array_slice.rs  |  275 ++++
 crates/df-builtins/src/variant/array_sort.rs   |  291 ++++
 .../src/{ => variant}/array_to_string.rs       |    3 +-
 .../df-builtins/src/variant/arrays_overlap.rs  |  226 +++
 .../src/variant/arrays_to_object.rs            |  244 ++++
 crates/df-builtins/src/variant/arrays_zip.rs   |  253 ++++
 crates/df-builtins/src/variant/json.rs         | 1281 +++++++++++++++++
 crates/df-builtins/src/variant/mod.rs          |   75 +
 .../src/variant/object_construct.rs            |  292 ++++
 .../df-builtins/src/variant/object_delete.rs   |  235 +++
 .../df-builtins/src/variant/object_insert.rs   |  333 +++++
 crates/df-builtins/src/variant/object_pick.rs  |  234 +++
 .../src/variant/variant_element.rs             |  377 +++++
 crates/df-builtins/src/visitors/mod.rs         |    1 +
 .../src/visitors/variant/array_agg.rs          |  122 ++
 .../src/visitors/variant/array_construct.rs    |  131 ++
 .../variant/array_construct_compact.rs         |  111 ++
 .../df-builtins/src/visitors/variant/mod.rs    |   15 +
 .../src/visitors/variant/type_rewrite.rs       |  108 ++
 .../src/visitors/variant/variant_element.rs    |   69 +
 45 files changed, 9485 insertions(+), 51 deletions(-)
 create mode 100644 crates/df-builtins/src/aggregate/object_agg.rs
 create mode 100644 crates/df-builtins/src/variant/array_append.rs
 create mode 100644 crates/df-builtins/src/variant/array_cat.rs
 create mode 100644 crates/df-builtins/src/variant/array_compact.rs
 create mode 100644 crates/df-builtins/src/variant/array_construct.rs
 create mode 100644 crates/df-builtins/src/variant/array_contains.rs
 create mode 100644 crates/df-builtins/src/variant/array_distinct.rs
 create mode 100644 crates/df-builtins/src/variant/array_except.rs
 rename crates/df-builtins/src/{ => variant}/array_flatten.rs (99%)
 create mode 100644 crates/df-builtins/src/variant/array_generate_range.rs
 create mode 100644 crates/df-builtins/src/variant/array_insert.rs
 create mode 100644 crates/df-builtins/src/variant/array_intersection.rs
 create mode 100644 crates/df-builtins/src/variant/array_max.rs
 create mode 100644 crates/df-builtins/src/variant/array_min.rs
 create mode 100644 crates/df-builtins/src/variant/array_position.rs
 create mode 100644 crates/df-builtins/src/variant/array_prepend.rs
 create mode 100644 crates/df-builtins/src/variant/array_remove.rs
 create mode 100644 crates/df-builtins/src/variant/array_remove_at.rs
 create mode 100644 crates/df-builtins/src/variant/array_reverse.rs
 create mode 100644 crates/df-builtins/src/variant/array_size.rs
 create mode 100644 crates/df-builtins/src/variant/array_slice.rs
 create mode 100644 crates/df-builtins/src/variant/array_sort.rs
 rename crates/df-builtins/src/{ => variant}/array_to_string.rs (99%)
 create mode 100644 crates/df-builtins/src/variant/arrays_overlap.rs
 create mode 100644 crates/df-builtins/src/variant/arrays_to_object.rs
 create mode 100644 crates/df-builtins/src/variant/arrays_zip.rs
 create mode 100644 crates/df-builtins/src/variant/json.rs
 create mode 100644 crates/df-builtins/src/variant/mod.rs
 create mode 100644 crates/df-builtins/src/variant/object_construct.rs
 create mode 100644 crates/df-builtins/src/variant/object_delete.rs
 create mode 100644 crates/df-builtins/src/variant/object_insert.rs
 create mode 100644 crates/df-builtins/src/variant/object_pick.rs
 create mode 100644 crates/df-builtins/src/variant/variant_element.rs
 create mode 100644 crates/df-builtins/src/visitors/mod.rs
 create mode 100644 crates/df-builtins/src/visitors/variant/array_agg.rs
 create mode 100644 crates/df-builtins/src/visitors/variant/array_construct.rs
 create mode 100644 crates/df-builtins/src/visitors/variant/array_construct_compact.rs
 create mode 100644 crates/df-builtins/src/visitors/variant/mod.rs
 create mode 100644 crates/df-builtins/src/visitors/variant/type_rewrite.rs
 create mode 100644 crates/df-builtins/src/visitors/variant/variant_element.rs

diff --git a/crates/core-executor/src/session.rs b/crates/core-executor/src/session.rs
index a17b3e66..29e9bdda 100644
--- a/crates/core-executor/src/session.rs
+++ b/crates/core-executor/src/session.rs
@@ -36,6 +36,7 @@ use std::any::Any;
 use std::collections::HashMap;
 use std::env;
 use std::sync::Arc;
+use df_builtins::table::register_udtfs;
 
 pub struct UserSession {
     pub metastore: Arc,
@@ -77,6 +78,7 @@ impl UserSession {
         let mut ctx = SessionContext::new_with_state(state);
         register_udfs(&mut ctx).context(ex_error::RegisterUDFSnafu)?;
         register_udafs(&mut ctx).context(ex_error::RegisterUDAFSnafu)?;
+        register_udtfs(&mut ctx);
         register_json_udfs(&mut ctx).context(ex_error::RegisterUDFSnafu)?;
         //register_geo_native(&ctx);
         //register_geo_udfs(&ctx);
diff --git a/crates/df-builtins/Cargo.toml b/crates/df-builtins/Cargo.toml
index b2093695..c4dc7b65 100644
--- a/crates/df-builtins/Cargo.toml
+++ b/crates/df-builtins/Cargo.toml
@@ -15,6 +15,7 @@ datafusion-physical-plan = { workspace = true }
 paste = "1"
 serde = { workspace = true }
 serde_json = { workspace = true }
+jsonpath_lib = "0.3.0"
 ahash = { version = "0.8", default-features = false, features = [
     "runtime-rng",
 ] }
diff --git a/crates/df-builtins/src/aggregate/mod.rs b/crates/df-builtins/src/aggregate/mod.rs
index a36280ad..a8322759 100644
--- a/crates/df-builtins/src/aggregate/mod.rs
+++ b/crates/df-builtins/src/aggregate/mod.rs
@@ -24,6 +24,7 @@ pub mod array_unique_agg;
 pub mod booland_agg;
 pub mod boolor_agg;
 pub mod boolxor_agg;
+pub mod object_agg;
 pub mod percentile_cont;
 
 pub fn register_udafs(registry: &mut dyn FunctionRegistry) -> datafusion_common::Result<()> {
diff --git a/crates/df-builtins/src/aggregate/object_agg.rs b/crates/df-builtins/src/aggregate/object_agg.rs
new file mode 100644
index 00000000..2458c163
--- /dev/null
+++ b/crates/df-builtins/src/aggregate/object_agg.rs
@@ -0,0 +1,539 @@
+use super::macros::make_udaf_function;
+use serde_json::Value as JsonValue;
+use std::any::Any;
+use std::collections::HashSet;
+use std::sync::Arc;
+
+use datafusion::arrow::array::as_list_array;
+use datafusion::arrow::datatypes::{DataType, Field, Fields};
+use datafusion::arrow::array::StringArray;
+use datafusion::arrow::array::{new_empty_array, Array};
+use datafusion::arrow::array::{ArrayRef, StructArray};
+use datafusion::common::ScalarValue;
+
+use datafusion_common::utils::SingleRowListArrayBuilder;
+use datafusion_common::{exec_err, internal_err, DataFusionError, Result};
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion_expr::utils::format_state_name;
+use datafusion_expr::Volatility;
+use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature};
+
+#[derive(Debug, Clone)]
+pub struct ObjectAggUDAF {
+    signature: Signature,
+}
+
+impl Default for ObjectAggUDAF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ObjectAggUDAF {
+    #[must_use]
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::any(2, Volatility::Immutable),
+        }
+    }
+}
+
+impl AggregateUDFImpl for ObjectAggUDAF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "object_agg"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+        // one field: a List<Struct<key, value>>
+        let entry_struct = Field::new_list(
+            format_state_name(args.name, "object_agg"),
+            Field::new_list_field(
+                DataType::Struct(Fields::from(vec![
+                    Field::new("key", DataType::Utf8, false),
+                    Field::new("value", args.input_types[1].clone(), true),
+                ])),
+                true,
+            ),
+            true,
+        );
+        Ok(vec![entry_struct])
+    }
+
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        if acc_args.is_distinct {
+            return exec_err!("OBJECT_AGG does not yet support DISTINCT");
+        }
+
+        if !acc_args.ordering_req.is_empty() {
+            return exec_err!("OBJECT_AGG does not yet support ORDER BY");
+        }
+
+        let key_type = acc_args.exprs[0].data_type(acc_args.schema)?;
+        match key_type {
+            DataType::Utf8 | DataType::LargeUtf8 => {}
+            _ => {
+                return exec_err!("OBJECT_AGG key must be Utf8 or LargeUtf8, got {key_type:?}");
+            }
+        }
+
+        let value_type = acc_args.exprs[1].data_type(acc_args.schema)?;
+
+        Ok(Box::new(ObjectAggAccumulator::try_new(&value_type)))
+    }
+}
+
+#[derive(Debug)]
+struct ObjectAggAccumulator {
+    keys: Vec<ScalarValue>,
+    values: Vec<ScalarValue>,
+    value_type: DataType,
+    keys_seen: HashSet<ScalarValue>,
+}
+
+impl ObjectAggAccumulator {
+    pub fn try_new(value_type: &DataType) -> Self {
+        Self {
+            keys: Vec::new(),
+            values: Vec::new(),
+            value_type: value_type.clone(),
+            keys_seen: HashSet::new(),
+        }
+    }
+}
+
+impl Accumulator for ObjectAggAccumulator {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let Some(key_array) = values[0].as_any().downcast_ref::<StringArray>() else {
+            return internal_err!("Key column must be Utf8 or LargeUtf8");
+        };
+
+        let val_array = &values[1];
+
+        for row in 0..key_array.len() {
+            if key_array.is_null(row) || val_array.is_null(row) {
+                continue;
+            }
+
+            let key_str = ScalarValue::from(key_array.value(row));
+            if !self.keys_seen.insert(key_str.clone()) {
+                return internal_err!("Duplicate keys are not allowed, key found: {}", key_str);
+            }
+
+            let k = ScalarValue::Utf8(Some(key_array.value(row).to_string()));
+            let v = ScalarValue::try_from_array(val_array, row)?;
+            self.keys.push(k);
+            self.values.push(v);
+        }
+
Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + + // states[0] is the ListArray> from your state() + let list = as_list_array(&*states[0]); + + for maybe_struct in list.iter().flatten() { + let Some(array) = maybe_struct.as_any().downcast_ref::() else { + return internal_err!("OBJECT_AGG state is not a StructArray"); + }; + + let key_arr = array + .column_by_name("key") + .ok_or(DataFusionError::Internal("Missing key column".to_string()))? + .as_any() + .downcast_ref::() + .ok_or(DataFusionError::Internal( + "Key column is not a StringArray".to_string(), + ))?; + + let val_arr = array + .column_by_name("value") + .ok_or(DataFusionError::Internal( + "Missing value column".to_string(), + ))?; + + for i in 0..array.len() { + if array.is_null(i) { + continue; + } + let key = ScalarValue::Utf8(Some(key_arr.value(i).to_string())); + if !self.keys_seen.insert(key.clone()) { + return internal_err!( + "Duplicate keys are not allowed, key found: {}", + key_arr.value(i) + ); + } + let val = ScalarValue::try_from_array(val_arr, i)?; + self.keys.push(key); + self.values.push(val); + } + } + + Ok(()) + } + + fn state(&mut self) -> Result> { + let key_array: ArrayRef = if self.keys.is_empty() { + new_empty_array(&DataType::Utf8) + } else { + ScalarValue::iter_to_array(self.keys.clone())? + }; + + let val_array: ArrayRef = if self.values.is_empty() { + new_empty_array(&self.value_type) + } else { + ScalarValue::iter_to_array(self.values.clone())? + }; + + let key_field = Arc::new(Field::new("key", DataType::Utf8, false)); + let val_field = Arc::new(Field::new("value", self.value_type.clone(), true)); + + let entries = StructArray::from(vec![(key_field, key_array), (val_field, val_array)]); + + let map_scalar = SingleRowListArrayBuilder::new(Arc::new(entries)).build_list_scalar(); + + Ok(vec![map_scalar]) + } + + fn evaluate(&mut self) -> Result { + let mut obj = serde_json::Map::with_capacity(self.keys.len()); + for (k_sv, v_sv) in self.keys.iter().zip(self.values.iter()) { + let key = match k_sv { + ScalarValue::Utf8(Some(s)) => s.clone(), + _ => continue, + }; + + let value: JsonValue = match v_sv { + ScalarValue::Utf8(Some(s)) => { + serde_json::from_str(s).map_err(|e| DataFusionError::Internal(e.to_string()))? 
+                }
+                _ => continue,
+            };
+            obj.insert(key, value);
+        }
+
+        let json_text = JsonValue::Object(obj).to_string();
+
+        Ok(ScalarValue::Utf8(Some(json_text)))
+    }
+
+    fn size(&self) -> usize {
+        size_of_val(self) + ScalarValue::size_of_vec(&self.keys) - size_of_val(&self.keys)
+            + ScalarValue::size_of_vec(&self.values)
+            - size_of_val(&self.values)
+            + self.value_type.size()
+            - size_of_val(&self.value_type)
+            + ScalarValue::size_of_hashset(&self.keys_seen)
+            - size_of_val(&self.keys_seen)
+    }
+}
+
+make_udaf_function!(ObjectAggUDAF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::arrow::datatypes::{Field, Schema};
+    use datafusion::physical_expr::LexOrdering;
+    use datafusion_common::{internal_err, Result};
+    use datafusion_physical_plan::expressions::Column;
+    use datafusion_physical_plan::Accumulator;
+    use serde_json::json;
+    use serde_json::{Map as JsonMap, Value as JsonValue};
+    use std::sync::Arc;
+
+    struct ObjectAggAccumulatorBuilder {
+        data_type: DataType,
+        distinct: bool,
+        ordering: LexOrdering,
+        schema: Schema,
+    }
+
+    impl ObjectAggAccumulatorBuilder {
+        fn string() -> Self {
+            Self::new(DataType::Utf8)
+        }
+
+        fn new(value_type: DataType) -> Self {
+            let schema = Schema::new(vec![
+                Field::new("key", DataType::Utf8, false),
+                Field::new("value", value_type.clone(), true),
+            ]);
+
+            Self {
+                data_type: value_type,
+                distinct: Default::default(),
+                ordering: LexOrdering::default(),
+                schema,
+            }
+        }
+
+        fn build(&self) -> Result<Box<dyn Accumulator>> {
+            ObjectAggUDAF::default().accumulator(AccumulatorArgs {
+                return_type: &self.data_type,
+                schema: &self.schema,
+                ignore_nulls: false,
+                ordering_req: &self.ordering,
+                is_reversed: false,
+                name: "",
+                is_distinct: self.distinct,
+                exprs: &[
+                    Arc::new(Column::new("key", 0)),
+                    Arc::new(Column::new("value", 1)),
+                ],
+            })
+        }
+
+        fn build_two(&self) -> Result<(Box<dyn Accumulator>, Box<dyn Accumulator>)> {
+            Ok((self.build()?, self.build()?))
+        }
+    }
+
+    fn data<T, const N: usize>(list: [T; N]) -> ArrayRef
+    where
+        ScalarValue: From<T>,
+    {
+        let values: Vec<_> = list.into_iter().map(ScalarValue::from).collect();
+        let array: ArrayRef = if values.is_empty() {
+            new_empty_array(&DataType::Utf8)
+        } else {
+            ScalarValue::iter_to_array(values).unwrap()
+        };
+        array
+    }
+
+    fn merge(
+        mut acc1: Box<dyn Accumulator>,
+        mut acc2: Box<dyn Accumulator>,
+    ) -> Result<Box<dyn Accumulator>> {
+        let intermediate_state = acc2.state().and_then(|e| {
+            e.iter()
+                .map(|v| v.to_array())
+                .collect::<Result<Vec<_>>>()
+        })?;
+        acc1.merge_batch(&intermediate_state)?;
+        Ok(acc1)
+    }
+
+    #[test]
+    fn basic_object_agg() -> Result<()> {
+        let (mut acc1, mut acc2) = ObjectAggAccumulatorBuilder::string().build_two()?;
+
+        acc1.update_batch(&[data(["a", "b", "c"]), data(["1", "2", "3"])])?;
+
+        acc2.update_batch(&[data(["d", "e"]), data(["4", "5"])])?;
+
+        acc1 = merge(acc1, acc2)?;
+        let ScalarValue::Utf8(Some(json)) = acc1.evaluate()?
+        else {
+            return internal_err!("expected Utf8 JSON");
+        };
+
+        // parse and verify the JSON object
+        let map: JsonMap<String, JsonValue> = serde_json::from_str(&json).unwrap();
+        assert_eq!(map.len(), 5);
+        assert_eq!(map.get("a").unwrap(), &JsonValue::Number(1.into()));
+        assert_eq!(map.get("b").unwrap(), &JsonValue::Number(2.into()));
+        assert_eq!(map.get("c").unwrap(), &JsonValue::Number(3.into()));
+        assert_eq!(map.get("d").unwrap(), &JsonValue::Number(4.into()));
+        assert_eq!(map.get("e").unwrap(), &JsonValue::Number(5.into()));
+        Ok(())
+    }
+
+    #[test]
+    fn duplicate_key_error() -> Result<()> {
+        let (mut acc1, mut acc2) = ObjectAggAccumulatorBuilder::string().build_two()?;
+
+        acc1.update_batch(&[data(["x", "y"]), data(["1", "2"])])?;
+        acc2.update_batch(&[data(["x", "z"]), data(["3", "4"])])?;
+
+        let intermediate = acc2
+            .state()?
+            .iter()
+            .map(|sv| sv.to_array())
+            .collect::<Result<Vec<_>>>()?;
+        let err = acc1.merge_batch(&intermediate).unwrap_err();
+        let msg = format!("{err}");
+        assert!(msg.contains("Duplicate keys"), "got: {msg}");
+        Ok(())
+    }
+
+    /// Null keys are ignored entirely, non-null keys still aggregate
+    #[test]
+    fn null_key_ignored() -> Result<()> {
+        let (mut acc1, mut acc2) = ObjectAggAccumulatorBuilder::string().build_two()?;
+
+        acc1.update_batch(&[data([None, Some("a")]), data([Some("10"), Some("1")])])?;
+        acc2.update_batch(&[data(["b"]), data(["2"])])?;
+
+        acc1 = merge(acc1, acc2)?;
+        let ScalarValue::Utf8(Some(json)) = acc1.evaluate()? else {
+            return internal_err!("expected JSON");
+        };
+        let map: JsonMap<_, _> = serde_json::from_str(&json).unwrap();
+        assert_eq!(map.len(), 2);
+        assert_eq!(map.get("a").unwrap(), &JsonValue::Number(1.into()));
+        assert_eq!(map.get("b").unwrap(), &JsonValue::Number(2.into()));
+        Ok(())
+    }
+
+    #[test]
+    fn null_pairs_are_ignored() -> Result<()> {
+        let (mut acc1, acc2) = ObjectAggAccumulatorBuilder::string().build_two()?;
+
+        acc1.update_batch(&[
+            data([Some("a"), None, Some("b"), Some("c")]),
+            data([Some("1"), Some("2"), None, Some("3")]),
+        ])?;
+
+        acc1 = merge(acc1, acc2)?;
+        let ScalarValue::Utf8(Some(json)) = acc1.evaluate()? else {
+            return internal_err!("expected JSON");
+        };
+        let map: JsonMap<String, JsonValue> = serde_json::from_str(&json).unwrap();
+
+        assert_eq!(map.len(), 2);
+        assert_eq!(map.get("a").unwrap(), &JsonValue::Number(1.into()));
+        assert_eq!(map.get("c").unwrap(), &JsonValue::Number(3.into()));
+        Ok(())
+    }
+
+    /// Completely empty input should produce an empty object
+    #[test]
+    fn empty_input_produces_empty_object() -> Result<()> {
+        let (mut acc1, mut acc2) = ObjectAggAccumulatorBuilder::string().build_two()?;
+
+        // feed no rows into either accumulator
+        acc1.update_batch(&[data::<Option<&str>, 0>([]), data::<Option<&str>, 0>([])])?;
+        acc2.update_batch(&[data::<Option<&str>, 0>([]), data::<Option<&str>, 0>([])])?;
+
+        acc1 = merge(acc1, acc2)?;
+        let ScalarValue::Utf8(Some(json)) = acc1.evaluate()? else {
+            return internal_err!("expected JSON");
+        };
+        assert_eq!(json, "{}");
+        Ok(())
+    }
+
+    /// chain three partitions
+    #[test]
+    fn merge_three_partitions() -> Result<()> {
+        let builder = ObjectAggAccumulatorBuilder::string();
+        let (mut a, mut b) = builder.build_two()?;
+        let mut c = builder.build()?;
+
+        a.update_batch(&[data(["k1"]), data(["1"])])?;
+        b.update_batch(&[data(["k2"]), data(["2"])])?;
+        c.update_batch(&[data(["k3"]), data(["3"])])?;
+
+        let mut ab = merge(a, b)?;
+        ab = merge(ab, c)?;
+
+        let ScalarValue::Utf8(Some(json)) = ab.evaluate()?
+        else {
+            return internal_err!("expected JSON");
+        };
+        let map: JsonMap<_, _> = serde_json::from_str(&json).unwrap();
+        assert_eq!(map.len(), 3);
+        assert_eq!(map.get("k1").unwrap(), &JsonValue::Number(1.into()));
+        assert_eq!(map.get("k2").unwrap(), &JsonValue::Number(2.into()));
+        assert_eq!(map.get("k3").unwrap(), &JsonValue::Number(3.into()));
+        Ok(())
+    }
+
+    #[test]
+    fn raw_json_string_matches_expected() -> Result<()> {
+        let (mut acc1, mut acc2) = ObjectAggAccumulatorBuilder::string().build_two()?;
+
+        // feed a couple of key→value pairs in two partitions
+        acc1.update_batch(&[data(["a", "b"]), data(["1", "2"])])?;
+        acc2.update_batch(&[data(["c"]), data(["3"])])?;
+
+        // merge them
+        acc1 = merge(acc1, acc2)?;
+
+        // evaluate to get the raw JSON
+        let ScalarValue::Utf8(Some(json)) = acc1.evaluate()? else {
+            return internal_err!("expected a Utf8 JSON string");
+        };
+
+        assert_eq!(json, r#"{"a":1,"b":2,"c":3}"#);
+
+        Ok(())
+    }
+
+    #[test]
+    fn object_agg_nested_json_values() -> Result<()> {
+        let builder = {
+            let mut b = ObjectAggAccumulatorBuilder::string();
+            // override schema so we know the second field is Utf8
+            b.schema = Schema::new(vec![
+                Field::new("key", DataType::Utf8, false),
+                Field::new("value", DataType::Utf8, true),
+            ]);
+            b
+        };
+        let (mut acc1, mut acc2) = builder.build_two()?;
+
+        acc1.update_batch(&[data(["a", "b"]), data([r#"{"x":1}"#, r#"["foo","bar"]"#])])?;
+
+        acc2.update_batch(&[data(["c", "d"]), data(["null", "true"])])?;
+
+        acc1 = merge(acc1, acc2)?;
+        let ScalarValue::Utf8(Some(json)) = acc1.evaluate()? else {
+            return internal_err!("expected Utf8 JSON");
+        };
+
+        let map: JsonMap<String, JsonValue> = serde_json::from_str(&json).unwrap();
+        assert_eq!(map.len(), 4);
+
+        assert_eq!(map.get("a").unwrap(), &json!({"x":1}));
+        assert_eq!(map.get("b").unwrap(), &json!(["foo", "bar"]));
+        assert_eq!(map.get("c").unwrap(), &JsonValue::Null);
+        assert_eq!(map.get("d").unwrap(), &JsonValue::Bool(true));
+
+        Ok(())
+    }
+
+    #[test]
+    fn object_agg_json_number_values() -> Result<()> {
+        let builder = {
+            let mut b = ObjectAggAccumulatorBuilder::string();
+            b.schema = Schema::new(vec![
+                Field::new("key", DataType::Utf8, false),
+                Field::new("value", DataType::Utf8, true),
+            ]);
+            b
+        };
+        let (mut acc1, mut acc2) = builder.build_two()?;
+
+        acc1.update_batch(&[data(["x", "y"]), data(["42", "3.14"])])?;
+
+        // partition 2: z→0
+        acc2.update_batch(&[data(["z"]), data(["0"])])?;
+
+        acc1 = merge(acc1, acc2)?;
+        let ScalarValue::Utf8(Some(json)) = acc1.evaluate()?
+        else {
+            return internal_err!("expected Utf8 JSON");
+        };
+
+        assert_eq!(json, r#"{"x":42,"y":3.14,"z":0}"#);
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/lib.rs b/crates/df-builtins/src/lib.rs
index feb53fa3..70307d38 100644
--- a/crates/df-builtins/src/lib.rs
+++ b/crates/df-builtins/src/lib.rs
@@ -9,8 +9,6 @@ use datafusion::arrow::array::{
 use datafusion::arrow::datatypes::DataType;
 use datafusion::{common::Result, execution::FunctionRegistry, logical_expr::ScalarUDF};
 use datafusion_common::DataFusionError;
-#[doc(hidden)]
-pub use std::iter as __std_iter;
 use std::sync::Arc;
 
 pub(crate) mod aggregate;
@@ -26,10 +24,8 @@ mod boolor;
 mod boolxor;
 mod equal_null;
 mod iff;
-mod insert;
 mod is_array;
 mod is_object;
-mod json;
 mod nullifzero;
 mod parse_json;
 pub mod table;
@@ -37,6 +33,7 @@ mod time_from_parts;
 mod timestamp_from_parts;
 mod to_boolean;
 mod to_time;
+mod variant;
 
 pub fn register_udfs(registry: &mut dyn FunctionRegistry) -> Result<()> {
     let functions: Vec<Arc<ScalarUDF>> = vec![
@@ -57,7 +54,6 @@ pub fn register_udfs(registry: &mut dyn FunctionRegistry) -> Result<()> {
         is_array::get_udf(),
         array_flatten::get_udf(),
         array_to_string::get_udf(),
-        insert::get_udf(),
         Arc::new(ScalarUDF::from(ToBooleanFunc::new(false))),
         Arc::new(ScalarUDF::from(ToBooleanFunc::new(true))),
         Arc::new(ScalarUDF::from(ToTimeFunc::new(false))),
@@ -72,48 +68,6 @@ pub fn register_udfs(registry: &mut dyn FunctionRegistry) -> Result<()> {
 }
 
 mod macros {
-    // Adopted from itertools: https://docs.rs/itertools/latest/src/itertools/lib.rs.html#321-360
-    macro_rules! izip {
-        // @closure creates a tuple-flattening closure for .map() call. usage:
-        // @closure partial_pattern => partial_tuple , rest , of , iterators
-        // eg. izip!( @closure ((a, b), c) => (a, b, c) , dd , ee )
-        ( @closure $p:pat => $tup:expr ) => {
-            |$p| $tup
-        };
-
-        // The "b" identifier is a different identifier on each recursion level thanks to hygiene.
-        ( @closure $p:pat => ( $($tup:tt)* ) , $_iter:expr $( , $tail:expr )* ) => {
-            $crate::macros::izip!(@closure ($p, b) => ( $($tup)*, b ) $( , $tail )*)
-        };
-
-        // unary
-        ($first:expr $(,)*) => {
-            $crate::__std_iter::IntoIterator::into_iter($first)
-        };
-
-        // binary
-        ($first:expr, $second:expr $(,)*) => {
-            $crate::__std_iter::Iterator::zip(
-                $crate::__std_iter::IntoIterator::into_iter($first),
-                $second,
-            )
-        };
-
-        // n-ary where n > 2
-        ( $first:expr $( , $rest:expr )* $(,)* ) => {
-            {
-                let iter = $crate::__std_iter::IntoIterator::into_iter($first);
-                $(
-                    let iter = $crate::__std_iter::Iterator::zip(iter, $rest);
-                )*
-                $crate::__std_iter::Iterator::map(
-                    iter,
-                    $crate::macros::izip!(@closure a => (a) $( , $rest )*)
-                )
-            }
-        };
-    }
-
     macro_rules! make_udf_function {
         ($udf_type:ty) => {
             paste::paste! {
@@ -133,7 +87,6 @@ mod macros {
         }
     }
 
-    pub(crate) use izip;
     pub(crate) use make_udf_function;
 }
 
diff --git a/crates/df-builtins/src/table/mod.rs b/crates/df-builtins/src/table/mod.rs
index e6a62af1..43bf0e51 100644
--- a/crates/df-builtins/src/table/mod.rs
+++ b/crates/df-builtins/src/table/mod.rs
@@ -4,6 +4,6 @@ use std::sync::Arc;
 
 pub mod flatten;
 
-pub fn register_table_funcs(ctx: &SessionContext) {
+pub fn register_udtfs(ctx: &SessionContext) {
     ctx.register_udtf("flatten", Arc::new(FlattenTableFunc::new()));
 }
diff --git a/crates/df-builtins/src/variant/array_append.rs b/crates/df-builtins/src/variant/array_append.rs
new file mode 100644
index 00000000..9f757fae
--- /dev/null
+++ b/crates/df-builtins/src/variant/array_append.rs
@@ -0,0 +1,252 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::array::Array;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::datatypes::DataType;
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::{Value, to_string};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub struct ArrayAppendUDF {
+    signature: Signature,
+}
+
+impl ArrayAppendUDF {
+    #[must_use]
+    pub const fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::Any(2),
+                volatility: Volatility::Immutable,
+            },
+        }
+    }
+
+    fn append_element(array_str: impl AsRef<str>, element: &ScalarValue) -> DFResult<String> {
+        let array_str = array_str.as_ref();
+
+        // Parse the input array
+        let mut array_value: Value = serde_json::from_str(array_str).map_err(|e| {
+            datafusion_common::error::DataFusionError::Internal(format!(
+                "Failed to parse array JSON: {e}",
+            ))
+        })?;
+
+        let scalar_value = super::json::encode_array(element.to_array_of_size(1)?)?;
+
+        let scalar_value = if let Value::Array(array) = scalar_value {
+            match array.first() {
+                Some(value) => value.clone(),
+                None => {
+                    return Err(datafusion_common::error::DataFusionError::Internal(
+                        "Expected array for scalar value".to_string(),
+                    ));
+                }
+            }
+        } else {
+            return Err(datafusion_common::error::DataFusionError::Internal(
+                "Expected array for scalar value".to_string(),
+            ));
+        };
+        // Ensure the first argument is an array
+        if let Value::Array(ref mut array) = array_value {
+            array.push(scalar_value);
+
+            // Convert back to JSON string
+            to_string(&array_value).map_err(|e| {
+                datafusion_common::error::DataFusionError::Internal(format!(
+                    "Failed to serialize result: {e}",
+                ))
+            })
+        } else {
+            Err(datafusion_common::error::DataFusionError::Internal(
+                "First argument must be a JSON array".to_string(),
+            ))
+        }
+    }
+}
+
+impl Default for ArrayAppendUDF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ScalarUDFImpl for ArrayAppendUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "array_append"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
+        let array_str = args
+            .first()
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected array argument".to_string(),
+            ))?;
+        let element = args
+            .get(1)
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected element argument".to_string(),
+            ))?;
+
+        match (array_str, element) {
+            (ColumnarValue::Array(array), ColumnarValue::Scalar(element_value)) => {
+                let string_array = array.as_string::<i32>();
+                let mut results = Vec::new();
+
+                for i in 0..string_array.len() {
+                    if string_array.is_null(i) {
+                        results.push(None);
+                    } else {
+                        let array_value = string_array.value(i);
+                        results.push(Some(Self::append_element(array_value, element_value)?));
+                    }
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::StringArray::from(results),
+                )))
+            }
+            (ColumnarValue::Scalar(array_value), ColumnarValue::Scalar(element_value)) => {
+                let ScalarValue::Utf8(Some(array_str)) = array_value else {
+                    return Err(datafusion_common::error::DataFusionError::Internal(
+                        "Expected UTF8 string for array".to_string(),
+                    ));
+                };
+
+                let result = Self::append_element(array_str, element_value)?;
+                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
+            }
+            _ => Err(datafusion_common::error::DataFusionError::Internal(
+                "First argument must be a JSON array string, second argument must be a scalar value".to_string(),
+            )),
+        }
+    }
+}
+
+make_udf_function!(ArrayAppendUDF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use crate::variant::array_construct::ArrayConstructUDF;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+
+    #[tokio::test]
+    async fn test_array_append() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register both UDFs
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArrayAppendUDF::new()));
+
+        // Test appending to numeric array
+        let sql = "SELECT array_append(array_construct(1, 2, 3), 4) as appended";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+-----------+",
+                "| appended  |",
+                "+-----------+",
+                "| [1,2,3,4] |",
+                "+-----------+",
+            ],
+            &result
+        );
+
+        // Test appending to empty array
+        let sql = "SELECT array_append(array_construct(), 1) as empty_append";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+--------------+",
+                "| empty_append |",
+                "+--------------+",
+                "| [1]          |",
+                "+--------------+",
+            ],
+            &result
+        );
+
+        // Test appending string to numeric array
+        let sql = "SELECT array_append(array_construct(1, 2), 'hello') as mixed_append";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------+",
+                "| mixed_append  |",
+                "+---------------+",
+                "| [1,2,\"hello\"] |",
+                "+---------------+",
+            ],
+            &result
+        );
+
+        // Test appending boolean
+        let sql = "SELECT array_append(array_construct(1, 2), true) as bool_append";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+-------------+",
+                "| bool_append |",
+                "+-------------+",
+                "| [1,2,true]  |",
+                "+-------------+",
+            ],
+            &result
+        );
+
+        // Test appending float
+        let sql = "SELECT array_append(array_construct(1, 2), 3.14) as float_append";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+--------------+",
+                "| float_append |",
+                "+--------------+",
+                "| [1,2,3.14]   |",
+                "+--------------+",
+            ],
+            &result
+        );
+
+        // Test appending null
sql = "SELECT array_append(array_construct(1, 2), NULL) as null_append"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+-------------+", + "| null_append |", + "+-------------+", + "| [1,2,null] |", + "+-------------+", + ], + &result + ); + + Ok(()) + } +} diff --git a/crates/df-builtins/src/variant/array_cat.rs b/crates/df-builtins/src/variant/array_cat.rs new file mode 100644 index 00000000..2a3260f6 --- /dev/null +++ b/crates/df-builtins/src/variant/array_cat.rs @@ -0,0 +1,248 @@ +use super::super::macros::make_udf_function; +use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::array::Array; +use datafusion_common::{Result as DFResult, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use serde_json::{from_str, to_string, Value}; +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub struct ArrayCatUDF { + signature: Signature, +} + +impl ArrayCatUDF { + #[must_use] + pub const fn new() -> Self { + Self { + signature: Signature { + type_signature: TypeSignature::Any(2), + volatility: Volatility::Immutable, + }, + } + } + + fn concatenate_arrays(arrays: &[&str]) -> DFResult { + let mut result_array = Vec::new(); + + for array_str in arrays { + // Parse each input array + let array_value: Value = from_str(array_str).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to parse array JSON: {e}", + )) + })?; + + // Ensure each argument is an array + if let Value::Array(array) = array_value { + result_array.extend(array); + } else { + return Err(datafusion_common::error::DataFusionError::Internal( + "All arguments must be JSON arrays".to_string(), + )); + } + } + + // Convert back to JSON string + to_string(&Value::Array(result_array)).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to serialize result: {e}", + )) + }) + } +} + +impl Default for ArrayCatUDF { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for ArrayCatUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "array_cat" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { + let ScalarFunctionArgs { args, .. 
+
+        // Check for exactly two arguments
+        if args.len() != 2 {
+            return Err(datafusion_common::error::DataFusionError::Internal(
+                "array_cat expects exactly two arguments".to_string(),
+            ));
+        }
+
+        match (&args[0], &args[1]) {
+            // Both scalar case
+            (
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(s1))),
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(s2))),
+            ) => {
+                let result = Self::concatenate_arrays(&[s1.as_str(), s2.as_str()])?;
+                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
+            }
+
+            // Scalar + Array case
+            (ColumnarValue::Scalar(ScalarValue::Utf8(Some(s1))), ColumnarValue::Array(array2)) => {
+                let string_array2 = array2.as_string::<i32>();
+                let len = string_array2.len();
+
+                let mut results = Vec::with_capacity(len);
+                for i in 0..len {
+                    if string_array2.is_null(i) {
+                        return Err(datafusion_common::error::DataFusionError::Internal(
+                            "Cannot concatenate arrays with null values".to_string(),
+                        ));
+                    }
+                    let result = Self::concatenate_arrays(&[
+                        s1.as_str(),
+                        string_array2.value(i).to_string().as_str(),
+                    ])?;
+                    results.push(Some(result));
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::StringArray::from(results),
+                )))
+            }
+
+            // Array + Scalar case
+            (ColumnarValue::Array(array1), ColumnarValue::Scalar(ScalarValue::Utf8(Some(s2)))) => {
+                let string_array1 = array1.as_string::<i32>();
+                let len = string_array1.len();
+
+                let mut results = Vec::with_capacity(len);
+                for i in 0..len {
+                    if string_array1.is_null(i) {
+                        return Err(datafusion_common::error::DataFusionError::Internal(
+                            "Cannot concatenate arrays with null values".to_string(),
+                        ));
+                    }
+                    let result = Self::concatenate_arrays(&[
+                        string_array1.value(i).to_string().as_str(),
+                        s2.as_str(),
+                    ])?;
+                    results.push(Some(result));
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::StringArray::from(results),
+                )))
+            }
+
+            // Both array case
+            (ColumnarValue::Array(array1), ColumnarValue::Array(array2)) => {
+                let string_array1 = array1.as_string::<i32>();
+                let string_array2 = array2.as_string::<i32>();
+                let len = string_array1.len();
+
+                let mut results = Vec::with_capacity(len);
+                for i in 0..len {
+                    if string_array1.is_null(i) || string_array2.is_null(i) {
+                        return Err(datafusion_common::error::DataFusionError::Internal(
+                            "Cannot concatenate arrays with null values".to_string(),
+                        ));
+                    }
+                    let result = Self::concatenate_arrays(&[
+                        string_array1.value(i).to_string().as_str(),
+                        string_array2.value(i).to_string().as_str(),
+                    ])?;
+                    results.push(Some(result));
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::StringArray::from(results),
+                )))
+            }
+
+            _ => Err(datafusion_common::error::DataFusionError::Internal(
+                "Arguments must both be either scalar UTF8 strings or arrays".to_string(),
+            )),
+        }
+    }
+}
+
+make_udf_function!(ArrayCatUDF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+    use crate::variant::array_construct::ArrayConstructUDF;
+
+    #[tokio::test]
+    async fn test_array_cat() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register both UDFs
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArrayCatUDF::new()));
+
+        // Test concatenating two arrays
+        let sql = "SELECT array_cat(array_construct(1, 2), array_construct(3, 4)) as concatenated";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
"+--------------+", + "| concatenated |", + "+--------------+", + "| [1,2,3,4] |", + "+--------------+", + ], + &result + ); + + // Test concatenating empty arrays + let sql = "SELECT array_cat(array_construct(), array_construct(1, 2)) as empty_cat"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+-----------+", + "| empty_cat |", + "+-----------+", + "| [1,2] |", + "+-----------+", + ], + &result + ); + + // Test concatenating arrays with different types + let sql = "SELECT array_cat(array_construct(1, 2), array_construct('a', 'b')) as mixed_cat"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------------+", + "| mixed_cat |", + "+---------------+", + "| [1,2,\"a\",\"b\"] |", + "+---------------+", + ], + &result + ); + + Ok(()) + } +} diff --git a/crates/df-builtins/src/variant/array_compact.rs b/crates/df-builtins/src/variant/array_compact.rs new file mode 100644 index 00000000..7b75e49b --- /dev/null +++ b/crates/df-builtins/src/variant/array_compact.rs @@ -0,0 +1,190 @@ +use super::super::macros::make_udf_function; +use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::array::Array; +use datafusion_common::{Result as DFResult, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use serde_json::{to_string, Value}; +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub struct ArrayCompactUDF { + signature: Signature, +} + +impl ArrayCompactUDF { + #[must_use] + pub const fn new() -> Self { + Self { + signature: Signature { + type_signature: TypeSignature::Any(1), + volatility: Volatility::Immutable, + }, + } + } + + fn compact_array(array_str: impl AsRef) -> DFResult { + let array_str = array_str.as_ref(); + + // Parse the input array + let array_value: Value = serde_json::from_str(array_str).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to parse array JSON: {e}", + )) + })?; + + // Ensure the input is an array + if let Value::Array(array) = array_value { + // Filter out null and undefined values + let compacted = array.iter().filter(|&v| !v.is_null() && v != &Value::Null); + + // Create a new array with the filtered values + let compacted_array = Value::Array(compacted.cloned().collect()); + + // Convert back to JSON string + to_string(&compacted_array).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to serialize result: {e}", + )) + }) + } else { + Err(datafusion_common::error::DataFusionError::Internal( + "Input must be a JSON array".to_string(), + )) + } + } +} + +impl Default for ArrayCompactUDF { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for ArrayCompactUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "array_compact" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { + let ScalarFunctionArgs { args, .. 
+        let array_str = args
+            .first()
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected array argument".to_string(),
+            ))?;
+
+        match array_str {
+            ColumnarValue::Array(array) => {
+                let string_array = array.as_string::<i32>();
+                let mut results = Vec::new();
+
+                for i in 0..string_array.len() {
+                    if string_array.is_null(i) {
+                        results.push(None);
+                    } else {
+                        let array_value = string_array.value(i);
+                        results.push(Some(Self::compact_array(array_value)?));
+                    }
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::StringArray::from(results),
+                )))
+            }
+            ColumnarValue::Scalar(array_value) => {
+                let ScalarValue::Utf8(Some(array_str)) = array_value else {
+                    return Err(datafusion_common::error::DataFusionError::Internal(
+                        "Expected UTF8 string for array".to_string(),
+                    ));
+                };
+
+                let result = Self::compact_array(array_str)?;
+                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
+            }
+        }
+    }
+}
+
+make_udf_function!(ArrayCompactUDF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+    use crate::variant::array_construct::ArrayConstructUDF;
+
+    #[tokio::test]
+    async fn test_array_compact() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register both UDFs
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArrayCompactUDF::new()));
+
+        // Test compacting array with null values
+        let sql = "SELECT array_compact(array_construct(1, null, 3, null, 5)) as compacted";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+-----------+",
+                "| compacted |",
+                "+-----------+",
+                "| [1,3,5]   |",
+                "+-----------+",
+            ],
+            &result
+        );
+
+        // Test compacting empty array
+        let sql = "SELECT array_compact(array_construct()) as empty_compact";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------+",
+                "| empty_compact |",
+                "+---------------+",
+                "| []            |",
+                "+---------------+",
+            ],
+            &result
+        );
+
+        // Test compacting array with mixed types
+        let sql =
+            "SELECT array_compact(array_construct(1, 'hello', null, 3.14, null)) as mixed_compact";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+------------------+",
+                "| mixed_compact    |",
+                "+------------------+",
+                "| [1,\"hello\",3.14] |",
+                "+------------------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/variant/array_construct.rs b/crates/df-builtins/src/variant/array_construct.rs
new file mode 100644
index 00000000..206e0584
--- /dev/null
+++ b/crates/df-builtins/src/variant/array_construct.rs
@@ -0,0 +1,207 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::{array::AsArray, datatypes::DataType};
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::Value;
+
+#[derive(Debug, Clone)]
+pub struct ArrayConstructUDF {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl ArrayConstructUDF {
+    #[must_use]
+    pub fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::OneOf(vec![
+                    TypeSignature::VariadicAny,
+                    TypeSignature::Nullary,
+                ]),
+                volatility: Volatility::Immutable,
+            },
+            aliases: vec!["make_array".to_string(), "make_list".to_string()],
+        }
+    }
+}
+
+impl Default for ArrayConstructUDF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ScalarUDFImpl for ArrayConstructUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "array_construct"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let ScalarFunctionArgs {
+            args, number_rows, ..
+        } = args;
+        let mut results = Vec::new();
+
+        // Each argument arrives as a one-row column for scalar inputs; its value
+        // becomes one element of the resulting JSON array.
+        for arg in args {
+            let arg_array = arg.into_array(number_rows)?;
+            for i in 0..arg_array.len() {
+                if arg_array.is_null(i) {
+                    results.push(Value::Null);
+                } else if let Some(str_array) = arg_array.as_string_opt::<i32>() {
+                    for istr in str_array {
+                        match istr {
+                            Some(istr) => {
+                                if let Ok(json_obj) = serde_json::from_str(istr) {
+                                    results.push(json_obj);
+                                } else {
+                                    results.push(Value::String(istr.to_string()));
+                                }
+                            }
+                            None => {
+                                results.push(Value::Null);
+                            }
+                        }
+                    }
+                } else {
+                    let object = super::json::encode_array(arg_array.clone())?;
+                    results.push(object);
+                }
+            }
+        }
+
+        for result in &mut results {
+            if let Value::Array(arr) = result {
+                if arr.len() == 1 {
+                    *result = arr[0].clone();
+                }
+            }
+        }
+
+        let arr = serde_json::Value::Array(results);
+        let json_str = serde_json::to_string(&arr).map_err(|e| {
+            datafusion_common::error::DataFusionError::Internal(format!(
+                "Failed to serialize JSON: {e}",
+            ))
+        })?;
+        Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(json_str))))
+    }
+}
+
+make_udf_function!(ArrayConstructUDF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+    use crate::variant::array_cat::ArrayCatUDF;
+
+    #[tokio::test]
+    async fn test_array_construct() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+
+        // Test basic array construction
+        let sql = "SELECT array_construct(1, 2, 3) as arr1";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------+",
+                "| arr1    |",
+                "+---------+",
+                "| [1,2,3] |",
+                "+---------+"
+            ],
+            &result
+        );
+
+        // Test mixed types
+        let sql = "SELECT array_construct(1, 'hello', 2.5) as arr2";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+-----------------+",
+                "| arr2            |",
+                "+-----------------+",
+                "| [1,\"hello\",2.5] |",
+                "+-----------------+",
+            ],
+            &result
+        );
+
+        // Test empty array
+        let sql = "SELECT array_construct() as arr4";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            ["+------+", "| arr4 |", "+------+", "| []   |", "+------+"],
+            &result
+        );
+
+        // Test with null values
+        let sql = "SELECT array_construct(1, NULL, 3) as arr5";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+------------+",
+                "| arr5       |",
+                "+------------+",
+                "| [1,null,3] |",
+                "+------------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_array_construct_nested() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+
+        // Test basic array construction
+        let sql =
+            "SELECT array_construct(array_construct(1, 2, 3), array_construct(4, 5, 6)) as arr1";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+-------------------+",
+                "| arr1              |",
+                "+-------------------+",
+                "| [[1,2,3],[4,5,6]] |",
+                "+-------------------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/variant/array_contains.rs b/crates/df-builtins/src/variant/array_contains.rs
new file mode 100644
index 00000000..7319aa9d
--- /dev/null
+++ b/crates/df-builtins/src/variant/array_contains.rs
@@ -0,0 +1,262 @@
+use super::super::macros::make_udf_function;
+use super::json::{encode_array, encode_scalar};
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::{from_slice, Value};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub struct ArrayContainsUDF {
+    signature: Signature,
+}
+
+impl ArrayContainsUDF {
+    #[must_use]
+    pub const fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::Any(2),
+                volatility: Volatility::Immutable,
+            },
+        }
+    }
+
+    fn contains_value(search_value: &Value, array_str: Option<&str>) -> DFResult<Option<bool>> {
+        if let Some(array_str) = array_str {
+            // Parse the array
+            let array_value: Value = from_slice(array_str.as_bytes()).map_err(|e| {
+                datafusion_common::error::DataFusionError::Internal(format!(
+                    "Failed to parse array: {e}",
+                ))
+            })?;
+
+            if let Value::Array(array) = array_value {
+                // If search value is null, check if array contains null
+                if search_value.is_null() {
+                    if array.iter().any(|v| v.is_null()) {
+                        return Ok(Some(true));
+                    }
+                    return Ok(None);
+                }
+
+                // For non-null values, compare each array element
+                Ok(Some(array.contains(search_value)))
+            } else {
+                Ok(None)
+            }
+        } else {
+            Ok(None)
+        }
+    }
+}
+
+impl Default for ArrayContainsUDF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ScalarUDFImpl for ArrayContainsUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "array_contains"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
+        let value = args
+            .first()
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected a value argument".to_string(),
+            ))?;
+        let array = args
+            .get(1)
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected an array argument".to_string(),
+            ))?;
+
+        match (value, array) {
+            (ColumnarValue::Array(value_array), ColumnarValue::Array(array_array)) => {
+                let array_strings = array_array.as_string::<i32>();
+                let value_array = encode_array(value_array.clone())?;
+                let mut results = Vec::new();
+                for (search_val, col_val) in value_array
+                    .as_array()
+                    .ok_or(datafusion_common::error::DataFusionError::Internal(
+                        "Expected an array argument".to_string(),
+                    ))?
+                    .iter()
+                    .zip(array_strings)
+                {
+                    results.push(Self::contains_value(search_val, col_val)?);
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::BooleanArray::from(results),
+                )))
+            }
+            (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(array_scalar)) => {
+                let value_scalar = encode_scalar(value_scalar)?;
+                let array_str = match array_scalar {
+                    ScalarValue::Utf8(Some(s)) => s,
+                    ScalarValue::Null | ScalarValue::Utf8(None) => {
+                        return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
+                    }
+                    _ => {
+                        return Err(datafusion_common::error::DataFusionError::Internal(
+                            "Expected UTF8 string for array".to_string(),
+                        ));
+                    }
+                };
+
+                let result = Self::contains_value(&value_scalar, Some(array_str.as_str()))?;
+                Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result)))
+            }
+            _ => Err(datafusion_common::error::DataFusionError::Internal(
+                "Mismatched argument types".to_string(),
+            )),
+        }
+    }
+}
+
+make_udf_function!(ArrayContainsUDF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+    use crate::variant::array_construct::ArrayConstructUDF;
+
+    #[tokio::test]
+    async fn test_array_contains() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register both UDFs
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArrayContainsUDF::new()));
+
+        // Test value exists in array
+        let sql = "SELECT array_contains('hello', array_construct('hello', 'hi')) as contains";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+----------+",
+                "| contains |",
+                "+----------+",
+                "| true     |",
+                "+----------+",
+            ],
+            &result
+        );
+
+        // Test value not in array
+        let sql = "SELECT array_contains('hello', array_construct('hola', 'bonjour')) as contains";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+----------+",
+                "| contains |",
+                "+----------+",
+                "| false    |",
+                "+----------+",
+            ],
+            &result
+        );
+
+        // Test null value
+        let sql = "SELECT array_contains(NULL, array_construct('hola', 'bonjour')) as contains";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+----------+",
+                "| contains |",
+                "+----------+",
+                "|          |",
+                "+----------+",
+            ],
+            &result
+        );
+
+        // Test null in array
+        let sql = "SELECT array_contains(NULL, array_construct('hola', NULL)) as contains";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+----------+",
+                "| contains |",
+                "+----------+",
+                "| true     |",
+                "+----------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_array_contains_with_table() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register both UDFs
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArrayContainsUDF::new()));
+
+        // Create a table with a value column and an array column
+        let sql = "CREATE TABLE test_array_contains AS
+            SELECT
+                CAST(value AS VARCHAR) as value,
+                array_construct('apple', 'banana', 'orange') as fruits
+            FROM (VALUES
+                ('apple'),
+                ('grape'),
+                ('banana'),
+                (NULL)
+            ) as t(value)";
+
+        ctx.sql(sql).await?.collect().await?;
+
+        // Test array_contains with table columns
+        let sql = "SELECT value, fruits, array_contains(value, fruits) as contains
+            FROM test_array_contains";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+--------+-----------------------------+----------+",
+                "| value  | fruits                      | contains |",
+                "+--------+-----------------------------+----------+",
+                "| apple  | [\"apple\",\"banana\",\"orange\"] | true     |",
+                "| grape  | [\"apple\",\"banana\",\"orange\"] | false    |",
+                "| banana | [\"apple\",\"banana\",\"orange\"] | true     |",
+                "|        | [\"apple\",\"banana\",\"orange\"] |          |",
+                "+--------+-----------------------------+----------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/variant/array_distinct.rs b/crates/df-builtins/src/variant/array_distinct.rs
new file mode 100644
index 00000000..3feb4177
--- /dev/null
+++ b/crates/df-builtins/src/variant/array_distinct.rs
@@ -0,0 +1,194 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::array::Array;
+use datafusion_common::types::{logical_binary, logical_string};
+use datafusion_common::{types::NativeType, Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
+    TypeSignatureClass, Volatility,
+};
+use serde_json::{from_slice, to_string, Value};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub struct ArrayDistinctUDF {
+    signature: Signature,
+}
+
+impl ArrayDistinctUDF {
+    #[must_use]
+    pub fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::Coercible(vec![Coercion::new_implicit(
+                    TypeSignatureClass::Native(logical_string()),
+                    vec![TypeSignatureClass::Native(logical_binary())],
+                    NativeType::String,
+                )]),
+                volatility: Volatility::Immutable,
+            },
+        }
+    }
+
+    fn distinct_array(string: impl AsRef<str>) -> DFResult<Option<String>> {
+        let string = string.as_ref();
+        let array_value: Value = from_slice(string.as_bytes()).map_err(|e| {
+            datafusion_common::error::DataFusionError::Internal(format!(
+                "Couldn't parse the JSON string: {e}",
+            ))
+        })?;
+
+        if let Value::Array(array) = array_value {
+            if array.is_empty() {
+                return Ok(Some("[]".to_string()));
+            }
+
+            let mut seen = std::collections::HashSet::new();
+            let mut distinct_values = Vec::new();
+
+            for value in array {
+                if seen.insert(value.clone()) {
+                    distinct_values.push(value);
+                }
+            }
+
+            Ok(Some(to_string(&distinct_values).map_err(|e| {
+                datafusion_common::DataFusionError::Internal(e.to_string())
+            })?))
+        } else {
+            Ok(None)
+        }
+    }
+}
+
+impl Default for ArrayDistinctUDF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ScalarUDFImpl for ArrayDistinctUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "array_distinct"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
} = args;
+        let array_str = args
+            .first()
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected a variant argument".to_string(),
+            ))?;
+        match array_str {
+            ColumnarValue::Array(array) => {
+                let string_array = array.as_string::<i32>();
+                let mut results = Vec::new();
+
+                for i in 0..string_array.len() {
+                    if string_array.is_null(i) {
+                        results.push(None);
+                    } else {
+                        let str_value = string_array.value(i);
+                        results.push(Self::distinct_array(str_value)?);
+                    }
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::StringArray::from(results),
+                )))
+            }
+            ColumnarValue::Scalar(array_value) => {
+                let ScalarValue::Utf8(Some(array_str)) = array_value else {
+                    return Err(datafusion_common::error::DataFusionError::Internal(
+                        "Expected UTF8 string".to_string(),
+                    ));
+                };
+
+                let result = Self::distinct_array(array_str)?;
+                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
+            }
+        }
+    }
+}
+
+make_udf_function!(ArrayDistinctUDF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+    use crate::variant::array_construct::ArrayConstructUDF;
+
+    #[tokio::test]
+    async fn test_array_distinct() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register both UDFs
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArrayDistinctUDF::new()));
+
+        // Test with duplicates
+        let sql =
+            "SELECT array_distinct(array_construct('A', 'A', 'B', NULL, NULL)) as distinct_array";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+----------------+",
+                "| distinct_array |",
+                "+----------------+",
+                "| [\"A\",\"B\",null] |",
+                "+----------------+",
+            ],
+            &result
+        );
+
+        // Test with numbers
+        let sql = "SELECT array_distinct(array_construct(1, 2, 1, 3, 2)) as distinct_nums";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------+",
+                "| distinct_nums |",
+                "+---------------+",
+                "| [1,2,3]       |",
+                "+---------------+",
+            ],
+            &result
+        );
+
+        // Test empty array
+        let sql = "SELECT array_distinct(array_construct()) as empty_array";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+-------------+",
+                "| empty_array |",
+                "+-------------+",
+                "| []          |",
+                "+-------------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/variant/array_except.rs b/crates/df-builtins/src/variant/array_except.rs
new file mode 100644
index 00000000..997c99b2
--- /dev/null
+++ b/crates/df-builtins/src/variant/array_except.rs
@@ -0,0 +1,269 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::{from_slice, Value};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub struct ArrayExceptUDF {
+    signature: Signature,
+}
+
+impl ArrayExceptUDF {
+    #[must_use]
+    pub const fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::Any(2),
+                volatility: Volatility::Immutable,
+            },
+        }
+    }
+
+    fn array_except(array1_str: Option<&str>, array2_str: Option<&str>) -> DFResult<Option<Value>> {
+        if let (Some(arr1), Some(arr2)) = (array1_str, array2_str) {
+            // Parse both arrays
+            let array1_value: Value
= from_slice(arr1.as_bytes()).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to parse first array: {e}", + )) + })?; + + let array2_value: Value = from_slice(arr2.as_bytes()).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to parse second array: {e}", + )) + })?; + + if let (Value::Array(arr1), Value::Array(arr2)) = (array1_value, array2_value) { + // Create a new array with elements from arr1 that are not in arr2 + let result: Vec = arr1 + .into_iter() + .filter(|item| !arr2.contains(item)) + .collect(); + + Ok(Some(Value::Array(result))) + } else { + Ok(None) + } + } else { + Ok(None) + } + } +} + +impl Default for ArrayExceptUDF { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for ArrayExceptUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "array_except" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { + let ScalarFunctionArgs { args, .. } = args; + let array1 = args + .first() + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected first array argument".to_string(), + ))?; + let array2 = args + .get(1) + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected second array argument".to_string(), + ))?; + + match (array1, array2) { + (ColumnarValue::Array(array1_array), ColumnarValue::Array(array2_array)) => { + let array1_strings = array1_array.as_string::(); + let array2_strings = array2_array.as_string::(); + let mut results = Vec::new(); + + for (arr1, arr2) in array1_strings.iter().zip(array2_strings) { + let result = Self::array_except(arr1, arr2)?; + results.push( + result + .map(|v| { + serde_json::to_string(&v).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to serialize result: {e}", + )) + }) + }) + .transpose(), + ); + } + let results: DFResult>> = results.into_iter().collect(); + + Ok(ColumnarValue::Array(Arc::new( + datafusion::arrow::array::StringArray::from(results?), + ))) + } + (ColumnarValue::Scalar(array1_scalar), ColumnarValue::Scalar(array2_scalar)) => { + let array1_str = match array1_scalar { + ScalarValue::Utf8(Some(s)) => s, + ScalarValue::Null | ScalarValue::Utf8(None) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } + _ => { + return Err(datafusion_common::error::DataFusionError::Internal( + "Expected UTF8 string for first array".to_string(), + )); + } + }; + + let array2_str = match array2_scalar { + ScalarValue::Utf8(Some(s)) => s, + ScalarValue::Null | ScalarValue::Utf8(None) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } + _ => { + return Err(datafusion_common::error::DataFusionError::Internal( + "Expected UTF8 string for second array".to_string(), + )); + } + }; + + let result = Self::array_except(Some(array1_str), Some(array2_str))?; + let result = result + .map(|v| { + serde_json::to_string(&v).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to serialize result: {e}", + )) + }) + }) + .transpose()?; + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + } + _ => Err(datafusion_common::error::DataFusionError::Internal( + "Mismatched argument types".to_string(), + )), + } + } +} + +make_udf_function!(ArrayExceptUDF); + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + 
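+    // Sketch of an additional case (not exercised elsewhere in this patch): a NULL
+    // argument should short-circuit through the Utf8(None)/Null branches of
+    // invoke_with_args and yield a NULL result. Assumes the NULL literal reaches the
+    // UDF as a ScalarValue::Null after constant folding.
+    #[tokio::test]
+    async fn test_array_except_null_input() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArrayExceptUDF::new()));
+
+        let sql = "SELECT array_except(NULL, array_construct('A')) as result";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+--------+",
+                "| result |",
+                "+--------+",
+                "|        |",
+                "+--------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+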
use super::*; + use datafusion::assert_batches_eq; + use datafusion::prelude::SessionContext; + use datafusion_expr::ScalarUDF; + use crate::variant::array_construct::ArrayConstructUDF; + + #[tokio::test] + async fn test_array_except() -> DFResult<()> { + let mut ctx = SessionContext::new(); + + // Register both UDFs + ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new())); + ctx.register_udf(ScalarUDF::from(ArrayExceptUDF::new())); + + // Test basic array difference + let sql = + "SELECT array_except(array_construct('A', 'B'), array_construct('B', 'C')) as result"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+--------+", + "| result |", + "+--------+", + "| [\"A\"] |", + "+--------+", + ], + &result + ); + + // Test empty result + let sql = + "SELECT array_except(array_construct('A', 'B'), array_construct('A', 'B')) as result"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+--------+", + "| result |", + "+--------+", + "| [] |", + "+--------+", + ], + &result + ); + + // Test with null values + let sql = "SELECT array_except(array_construct('A', NULL), array_construct('A')) as result"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+--------+", + "| result |", + "+--------+", + "| [null] |", + "+--------+", + ], + &result + ); + + Ok(()) + } + + #[tokio::test] + async fn test_array_except_with_table() -> DFResult<()> { + let mut ctx = SessionContext::new(); + + // Register both UDFs + ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new())); + ctx.register_udf(ScalarUDF::from(ArrayExceptUDF::new())); + + // Create a table with two array columns + let sql = "CREATE TABLE test_array_except AS + SELECT + array_construct('apple', 'banana', 'orange') as fruits1, + array_construct('banana', 'grape', 'apple') as fruits2 + FROM (VALUES (1)) as t(dummy)"; + + ctx.sql(sql).await?.collect().await?; + + // Test array_except with table columns + let sql = "SELECT fruits1, fruits2, array_except(fruits1, fruits2) as result + FROM test_array_except"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+-----------------------------+----------------------------+------------+", + "| fruits1 | fruits2 | result |", + "+-----------------------------+----------------------------+------------+", + "| [\"apple\",\"banana\",\"orange\"] | [\"banana\",\"grape\",\"apple\"] | [\"orange\"] |", + "+-----------------------------+----------------------------+------------+", + ], + &result + ); + + Ok(()) + } +} diff --git a/crates/df-builtins/src/array_flatten.rs b/crates/df-builtins/src/variant/array_flatten.rs similarity index 99% rename from crates/df-builtins/src/array_flatten.rs rename to crates/df-builtins/src/variant/array_flatten.rs index 8405238d..6da74286 100644 --- a/crates/df-builtins/src/array_flatten.rs +++ b/crates/df-builtins/src/variant/array_flatten.rs @@ -8,6 +8,7 @@ use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl}; use serde_json::{Map, Value}; use std::any::Any; use std::sync::Arc; +use crate::macros::make_udf_function; // array_flatten SQL function // Transforms a nested ARRAY (an ARRAY of ARRAYs) into a single, flat ARRAY by combining all inner ARRAYs into one continuous sequence. 
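// Example (illustrative, following the description above):
//   array_flatten('[[1, 2], [3], [4, 5]]') returns '[1,2,3,4,5]'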
@@ -146,7 +147,7 @@ fn flatten(v: &str) -> DFResult> { })?)) } -super::macros::make_udf_function!(ArrayFlattenFunc); +make_udf_function!(ArrayFlattenFunc); #[cfg(test)] mod tests { diff --git a/crates/df-builtins/src/variant/array_generate_range.rs b/crates/df-builtins/src/variant/array_generate_range.rs new file mode 100644 index 00000000..7e728233 --- /dev/null +++ b/crates/df-builtins/src/variant/array_generate_range.rs @@ -0,0 +1,210 @@ +use super::super::macros::make_udf_function; +use datafusion::arrow::datatypes::DataType; +use datafusion_common::{Result as DFResult, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use serde_json::{Value, to_string}; + +#[derive(Debug, Clone)] +pub struct ArrayGenerateRangeUDF { + signature: Signature, + aliases: Vec, +} + +impl ArrayGenerateRangeUDF { + #[must_use] + pub fn new() -> Self { + Self { + signature: Signature { + type_signature: TypeSignature::OneOf(vec![ + TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]), + TypeSignature::Exact(vec![DataType::Int64, DataType::Int64, DataType::UInt64]), + ]), + volatility: Volatility::Immutable, + }, + aliases: vec!["generate_range".to_string()], + } + } +} + +impl Default for ArrayGenerateRangeUDF { + fn default() -> Self { + Self::new() + } +} + +#[allow(clippy::cast_possible_truncation, clippy::as_conversions)] +impl ScalarUDFImpl for ArrayGenerateRangeUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "array_generate_range" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { + let ScalarFunctionArgs { + args, number_rows, .. + } = args; + + if args.len() < 2 || args.len() > 3 { + return Err(datafusion_common::error::DataFusionError::Internal( + "array_generate_range requires 2 or 3 arguments".to_string(), + )); + } + + let mut args = args; + let step = if args.len() == 3 { + args.pop() + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected step argument".to_string(), + ))? + .into_array(number_rows)? + } else { + // Default step is 1 + let default_step = ScalarValue::Int64(Some(1)); + default_step.to_array_of_size(number_rows)? + }; + let stop = args + .pop() + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected stop argument".to_string(), + ))? + .into_array(number_rows)?; + let start = args + .pop() + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected start argument".to_string(), + ))? + .into_array(number_rows)?; + + let mut results = Vec::new(); + + for i in 0..number_rows { + let start_val = if start.is_null(i) { + continue; + } else { + start + .as_any() + .downcast_ref::() + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected start argument to be an Int64Array".to_string(), + ))? + .value(i) + }; + + let stop_val = if stop.is_null(i) { + continue; + } else { + stop.as_any() + .downcast_ref::() + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected stop argument to be an Int64Array".to_string(), + ))? 
+                    .value(i)
+            };
+
+            let step_val = if step.is_null(i) {
+                continue;
+            } else {
+                step.as_any()
+                    .downcast_ref::<datafusion::arrow::array::Int64Array>()
+                    .ok_or(datafusion_common::error::DataFusionError::Internal(
+                        "Expected step argument to be an Int64Array".to_string(),
+                    ))?
+                    .value(i)
+            };
+
+            // `step_by` takes a usize, which panics on a zero step and wraps negative
+            // steps, so walk the range manually; this also supports counting down.
+            if step_val == 0 {
+                return Err(datafusion_common::error::DataFusionError::Internal(
+                    "Step value must not be zero".to_string(),
+                ));
+            }
+            let mut current = start_val;
+            while (step_val > 0 && current < stop_val) || (step_val < 0 && current > stop_val) {
+                results.push(Value::Number(current.into()));
+                current += step_val;
+            }
+        }
+
+        let json_str = to_string(&Value::Array(results)).map_err(|e| {
+            datafusion_common::error::DataFusionError::Internal(format!(
+                "Failed to serialize JSON: {e}",
+            ))
+        })?;
+
+        Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(json_str))))
+    }
+}
+
+make_udf_function!(ArrayGenerateRangeUDF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+
+    #[tokio::test]
+    async fn test_array_generate_range() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        ctx.register_udf(ScalarUDF::from(ArrayGenerateRangeUDF::new()));
+
+        // Test basic range
+        let sql = "SELECT array_generate_range(2, 5) as range1";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------+",
+                "| range1  |",
+                "+---------+",
+                "| [2,3,4] |",
+                "+---------+",
+            ],
+            &result
+        );
+
+        // Test with step
+        let sql = "SELECT array_generate_range(5, 25, 10) as range2";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+--------+",
+                "| range2 |",
+                "+--------+",
+                "| [5,15] |",
+                "+--------+",
+            ],
+            &result
+        );
+
+        // Test empty range
+        let sql = "SELECT array_generate_range(5, 2) as range3";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+--------+",
+                "| range3 |",
+                "+--------+",
+                "| []     |",
+                "+--------+"
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/variant/array_insert.rs b/crates/df-builtins/src/variant/array_insert.rs
new file mode 100644
index 00000000..daf749b2
--- /dev/null
+++ b/crates/df-builtins/src/variant/array_insert.rs
@@ -0,0 +1,283 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::array::Array;
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::{to_string, Value};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub struct ArrayInsertUDF {
+    signature: Signature,
+}
+
+#[allow(
+    clippy::cast_possible_truncation,
+    clippy::as_conversions,
+    clippy::cast_sign_loss,
+    clippy::cast_possible_wrap
+)]
+impl ArrayInsertUDF {
+    #[must_use]
+    pub const fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::Any(3),
+                volatility: Volatility::Immutable,
+            },
+        }
+    }
+
+    fn insert_element(
+        array_str: impl AsRef<str>,
+        pos: i64,
+        element: &ScalarValue,
+    ) -> DFResult<String> {
+        let array_str = array_str.as_ref();
+
+        // Parse the input array
+        let mut array_value: Value = serde_json::from_str(array_str).map_err(|e| {
+            datafusion_common::error::DataFusionError::Internal(format!(
+                "Failed to parse array JSON: {e}",
+            ))
+        })?;
+
+        let scalar_value = super::json::encode_scalar(element)?;
+
+        // Ensure the first argument is an array
+        if let Value::Array(ref mut array) = array_value {
+            // Convert position to usize, handling negative indices
+            let pos = if pos < 0 {
+                (array.len() as i64 + pos).max(0) as
usize + } else { + pos as usize + }; + + // Ensure position is within bounds + if pos > array.len() { + return Err(datafusion_common::error::DataFusionError::Internal( + format!( + "Position {pos} is out of bounds for array of length {}", + array.len() + ), + )); + } + + array.insert(pos, scalar_value); + + // Convert back to JSON string + to_string(&array_value).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to serialize result: {e}", + )) + }) + } else { + Err(datafusion_common::error::DataFusionError::Internal( + "First argument must be a JSON array".to_string(), + )) + } + } +} + +impl Default for ArrayInsertUDF { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for ArrayInsertUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "array_insert" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { + let ScalarFunctionArgs { args, .. } = args; + let array_str = args + .first() + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected array argument".to_string(), + ))?; + let pos = args + .get(1) + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected position argument".to_string(), + ))?; + let element = args + .get(2) + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected element argument".to_string(), + ))?; + + match (array_str, pos, element) { + (ColumnarValue::Array(array), ColumnarValue::Scalar(pos_value), ColumnarValue::Scalar(element_value)) => { + let string_array = array.as_string::(); + let mut results = Vec::new(); + + // Get position as i64 + let ScalarValue::Int64(Some(pos)) = pos_value else { + return Err(datafusion_common::error::DataFusionError::Internal( + "Position must be an integer".to_string() + )) + }; + + for i in 0..string_array.len() { + if string_array.is_null(i) { + results.push(None); + } else { + let array_value = string_array.value(i); + results.push(Some(Self::insert_element(array_value, *pos, element_value)?)); + } + } + + Ok(ColumnarValue::Array(Arc::new(datafusion::arrow::array::StringArray::from(results)))) + } + (ColumnarValue::Scalar(array_value), ColumnarValue::Scalar(pos_value), ColumnarValue::Scalar(element_value)) => { + let ScalarValue::Utf8(Some(array_str)) = array_value else { + return Err(datafusion_common::error::DataFusionError::Internal( + "Expected UTF8 string for array".to_string() + )) + }; + + let ScalarValue::Int64(Some(pos)) = pos_value else { + return Err(datafusion_common::error::DataFusionError::Internal( + "Position must be an integer".to_string() + )) + }; + + let result = Self::insert_element(array_str, *pos, element_value)?; + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))) + } + _ => Err(datafusion_common::error::DataFusionError::Internal( + "First argument must be a JSON array string, second argument must be an integer, third argument must be a scalar value".to_string() + )) + } + } +} + +make_udf_function!(ArrayInsertUDF); + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + use datafusion::assert_batches_eq; + use datafusion::prelude::SessionContext; + use datafusion_expr::ScalarUDF; + use crate::variant::array_construct::ArrayConstructUDF; + + #[tokio::test] + async fn test_array_insert() -> DFResult<()> { + let mut ctx = SessionContext::new(); + + // Register 
both UDFs + ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new())); + ctx.register_udf(ScalarUDF::from(ArrayInsertUDF::new())); + + // Test inserting into numeric array + let sql = "SELECT array_insert(array_construct(0,1,2,3), 2, 'hello') as inserted"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+-------------------+", + "| inserted |", + "+-------------------+", + "| [0,1,\"hello\",2,3] |", + "+-------------------+", + ], + &result + ); + + // Test inserting at the beginning + let sql = "SELECT array_insert(array_construct(1,2,3), 0, 'start') as start_insert"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+-----------------+", + "| start_insert |", + "+-----------------+", + "| [\"start\",1,2,3] |", + "+-----------------+", + ], + &result + ); + + // Test inserting at the end + let sql = "SELECT array_insert(array_construct(1,2,3), 3, 'end') as end_insert"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------------+", + "| end_insert |", + "+---------------+", + "| [1,2,3,\"end\"] |", + "+---------------+", + ], + &result + ); + + // Test inserting number + let sql = "SELECT array_insert(array_construct(1,2,3), 1, 42) as num_insert"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+------------+", + "| num_insert |", + "+------------+", + "| [1,42,2,3] |", + "+------------+", + ], + &result + ); + + // Test inserting boolean + let sql = "SELECT array_insert(array_construct(1,2,3), 1, true) as bool_insert"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+--------------+", + "| bool_insert |", + "+--------------+", + "| [1,true,2,3] |", + "+--------------+", + ], + &result + ); + + // Test inserting null + let sql = "SELECT array_insert(array_construct(1,2,3), 1, NULL) as null_insert"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+--------------+", + "| null_insert |", + "+--------------+", + "| [1,null,2,3] |", + "+--------------+", + ], + &result + ); + + Ok(()) + } +} diff --git a/crates/df-builtins/src/variant/array_intersection.rs b/crates/df-builtins/src/variant/array_intersection.rs new file mode 100644 index 00000000..62b6cd49 --- /dev/null +++ b/crates/df-builtins/src/variant/array_intersection.rs @@ -0,0 +1,246 @@ +use super::super::macros::make_udf_function; +use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::array::cast::AsArray; +use datafusion_common::{Result as DFResult, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use serde_json::{from_slice, Value}; +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub struct ArrayIntersectionUDF { + signature: Signature, +} + +impl ArrayIntersectionUDF { + #[must_use] + pub const fn new() -> Self { + Self { + signature: Signature { + type_signature: TypeSignature::Any(2), + volatility: Volatility::Immutable, + }, + } + } + + fn array_intersection( + array1_str: Option<&str>, + array2_str: Option<&str>, + ) -> DFResult> { + if let (Some(arr1), Some(arr2)) = (array1_str, array2_str) { + // Parse both arrays + let array1_value: Value = from_slice(arr1.as_bytes()).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to parse first array: {e}", + )) + })?; + + let array2_value: Value = from_slice(arr2.as_bytes()).map_err(|e| { + 
datafusion_common::error::DataFusionError::Internal(format!( + "Failed to parse second array: {e}", + )) + })?; + + if let (Value::Array(arr1), Value::Array(arr2)) = (array1_value, array2_value) { + // Create a new array with elements that exist in both arr1 and arr2 + let result: Vec = arr2 + .into_iter() + .filter(|item| arr1.contains(item)) + .collect(); + + Ok(Some(Value::Array(result))) + } else { + Ok(None) + } + } else { + Ok(None) + } + } +} + +impl Default for ArrayIntersectionUDF { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for ArrayIntersectionUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "array_intersection" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { + let ScalarFunctionArgs { args, .. } = args; + let array1 = args + .first() + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected first array argument".to_string(), + ))?; + let array2 = args + .get(1) + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected second array argument".to_string(), + ))?; + + match (array1, array2) { + (ColumnarValue::Array(array1_array), ColumnarValue::Array(array2_array)) => { + let array1_strings = array1_array.as_string::(); + let array2_strings = array2_array.as_string::(); + let mut results = Vec::new(); + + for (arr1, arr2) in array1_strings.iter().zip(array2_strings) { + let result = Self::array_intersection(arr1, arr2)?; + results.push(result.map(|v| serde_json::to_string(&v)).transpose()); + } + + let results: Result>, serde_json::Error> = + results.into_iter().collect(); + let results = results.map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to serialize result: {e}", + )) + })?; + + Ok(ColumnarValue::Array(Arc::new( + datafusion::arrow::array::StringArray::from(results), + ))) + } + (ColumnarValue::Scalar(array1_scalar), ColumnarValue::Scalar(array2_scalar)) => { + let array1_str = match array1_scalar { + ScalarValue::Utf8(Some(s)) => s, + ScalarValue::Null | ScalarValue::Utf8(None) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } + _ => { + return Err(datafusion_common::error::DataFusionError::Internal( + "Expected UTF8 string for first array".to_string(), + )); + } + }; + + let array2_str = match array2_scalar { + ScalarValue::Utf8(Some(s)) => s, + ScalarValue::Null | ScalarValue::Utf8(None) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } + _ => { + return Err(datafusion_common::error::DataFusionError::Internal( + "Expected UTF8 string for second array".to_string(), + )); + } + }; + + let result = Self::array_intersection(Some(array1_str), Some(array2_str))?; + let result = result + .map(|v| serde_json::to_string(&v)) + .transpose() + .map_err(|_e| { + datafusion_common::error::DataFusionError::Internal( + "Failed to serialize result".to_string(), + ) + })?; + + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + } + _ => Err(datafusion_common::error::DataFusionError::Internal( + "Mismatched argument types".to_string(), + )), + } + } +} + +make_udf_function!(ArrayIntersectionUDF); + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + use datafusion::assert_batches_eq; + use datafusion::prelude::SessionContext; + use datafusion_expr::ScalarUDF; + use 
crate::variant::array_construct::ArrayConstructUDF; + + #[tokio::test] + async fn test_array_intersection() -> DFResult<()> { + let mut ctx = SessionContext::new(); + + // Register both UDFs + ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new())); + ctx.register_udf(ScalarUDF::from(ArrayIntersectionUDF::new())); + + // Test basic array intersection + let sql = "SELECT array_intersection(array_construct('A', 'B'), array_construct('B', 'C')) as result1"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------+", + "| result1 |", + "+---------+", + "| [\"B\"] |", + "+---------+", + ], + &result + ); + + // Test empty intersection + let sql = "SELECT array_intersection(array_construct('A', 'B'), array_construct('C', 'D')) as result2"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------+", + "| result2 |", + "+---------+", + "| [] |", + "+---------+", + ], + &result + ); + + // Test with null values + let sql = "SELECT array_intersection(array_construct('A', NULL), array_construct('A', NULL)) as result3"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+------------+", + "| result3 |", + "+------------+", + "| [\"A\",null] |", + "+------------+", + ], + &result + ); + + // Test with duplicate values + let sql = "SELECT array_intersection(array_construct('A', 'B', 'B', 'B', 'C'), array_construct('B', 'B')) as result4"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+-----------+", + "| result4 |", + "+-----------+", + "| [\"B\",\"B\"] |", + "+-----------+", + ], + &result + ); + + Ok(()) + } +} diff --git a/crates/df-builtins/src/variant/array_max.rs b/crates/df-builtins/src/variant/array_max.rs new file mode 100644 index 00000000..b9fa96dc --- /dev/null +++ b/crates/df-builtins/src/variant/array_max.rs @@ -0,0 +1,249 @@ +use super::super::macros::make_udf_function; +use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::array::Array; +use datafusion_common::types::{logical_binary, logical_string, NativeType}; +use datafusion_common::{Result as DFResult, ScalarValue}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, +}; +use serde_json::{from_slice, Value}; +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub struct ArrayMaxUDF { + signature: Signature, +} + +impl ArrayMaxUDF { + #[must_use] + pub fn new() -> Self { + Self { + signature: Signature { + type_signature: TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_string()), + vec![TypeSignatureClass::Native(logical_binary())], + NativeType::String, + )]), + volatility: Volatility::Immutable, + }, + } + } + + fn find_max(string: impl AsRef) -> DFResult> { + let string = string.as_ref(); + let array_value: Value = from_slice(string.as_bytes()).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to parse the JSON string: {e}", + )) + })?; + + if let Value::Array(array) = array_value { + if array.is_empty() { + return Ok(None); + } + + // Try to find the maximum value, handling different types + let mut max_value: Option = None; + let mut max_type: Option<&str> = None; + + for value in array { + match value { + Value::Number(n) if n.is_i64() => { + let num = n.as_i64().ok_or( + datafusion_common::error::DataFusionError::Internal( + "Failed to parse 
number".to_string(), + ), + )?; + let should_update = match max_value.as_ref() { + None => true, + Some(current) => { + if let Ok(current_num) = current.parse::() { + max_type == Some("i64") && num > current_num + } else { + false + } + } + }; + if should_update { + max_value = Some(num.to_string()); + max_type = Some("i64"); + } + } + Value::Number(n) if n.is_f64() => { + let num = n.as_f64().ok_or( + datafusion_common::error::DataFusionError::Internal( + "Failed to parse number".to_string(), + ), + )?; + let should_update = match max_value.as_ref() { + None => true, + Some(current) => { + if let Ok(current_num) = current.parse::() { + max_type == Some("f64") && num > current_num + } else { + false + } + } + }; + if should_update { + max_value = Some(num.to_string()); + max_type = Some("f64"); + } + } + _ => {} + } + } + + Ok(max_value) + } else { + Ok(None) + } + } +} + +impl Default for ArrayMaxUDF { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for ArrayMaxUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "array_max" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { + let ScalarFunctionArgs { args, .. } = args; + let array_str = args + .first() + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected a variant argument".to_string(), + ))?; + match array_str { + ColumnarValue::Array(array) => { + let string_array = array.as_string::(); + let mut results = Vec::new(); + + for i in 0..string_array.len() { + if string_array.is_null(i) { + results.push(None); + } else { + let str_value = string_array.value(i); + results.push(Self::find_max(str_value)?); + } + } + + Ok(ColumnarValue::Array(Arc::new( + datafusion::arrow::array::StringArray::from(results), + ))) + } + ColumnarValue::Scalar(array_value) => { + let ScalarValue::Utf8(Some(array_str)) = array_value else { + return Err(datafusion_common::error::DataFusionError::Internal( + "Expected UTF8 string".to_string(), + )); + }; + + let result = Self::find_max(array_str)?; + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + } + } + } +} + +make_udf_function!(ArrayMaxUDF); + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + use datafusion::assert_batches_eq; + use datafusion::prelude::SessionContext; + use datafusion_expr::ScalarUDF; + use crate::variant::array_construct::ArrayConstructUDF; + + #[tokio::test] + async fn test_array_max() -> DFResult<()> { + let mut ctx = SessionContext::new(); + + // Register both UDFs + ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new())); + ctx.register_udf(ScalarUDF::from(ArrayMaxUDF::new())); + + // Test numeric array + let sql = "SELECT array_max(array_construct(1, 5, 3, 9, 2)) as max_num"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------+", + "| max_num |", + "+---------+", + "| 9 |", + "+---------+", + ], + &result + ); + + // Test mixed types + let sql = "SELECT array_max(array_construct(1, 'hello', 2.5, 10)) as max_mixed"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+-----------+", + "| max_mixed |", + "+-----------+", + "| 10 |", + "+-----------+", + ], + &result + ); + + // Test array of nulls + let sql = "SELECT array_max(array_construct(NULL, NULL, NULL)) as null_max"; + let result = 
ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+----------+", + "| null_max |", + "+----------+", + "| |", + "+----------+" + ], + &result + ); + + // Test empty array + let sql = "SELECT array_max(array_construct()) as empty_max"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+-----------+", + "| empty_max |", + "+-----------+", + "| |", + "+-----------+", + ], + &result + ); + + Ok(()) + } +} diff --git a/crates/df-builtins/src/variant/array_min.rs b/crates/df-builtins/src/variant/array_min.rs new file mode 100644 index 00000000..bc9fb2bf --- /dev/null +++ b/crates/df-builtins/src/variant/array_min.rs @@ -0,0 +1,249 @@ +use super::super::macros::make_udf_function; +use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::array::Array; +use datafusion_common::types::{logical_binary, logical_string, NativeType}; +use datafusion_common::{Result as DFResult, ScalarValue}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, +}; +use serde_json::{from_slice, Value}; +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub struct ArrayMinUDF { + signature: Signature, +} + +impl ArrayMinUDF { + #[must_use] + pub fn new() -> Self { + Self { + signature: Signature { + type_signature: TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_string()), + vec![TypeSignatureClass::Native(logical_binary())], + NativeType::String, + )]), + volatility: Volatility::Immutable, + }, + } + } + + fn find_min(string: impl AsRef) -> DFResult> { + let string = string.as_ref(); + let array_value: Value = from_slice(string.as_bytes()).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to parse the JSON string: {e}", + )) + })?; + + if let Value::Array(array) = array_value { + if array.is_empty() { + return Ok(None); + } + + // Try to find the minimum value, handling different types + let mut min_value: Option = None; + let mut min_type: Option<&str> = None; + + for value in array { + match value { + Value::Number(n) if n.is_i64() => { + let num = n.as_i64().ok_or( + datafusion_common::error::DataFusionError::Internal( + "Failed to parse number".to_string(), + ), + )?; + let should_update = match min_value.as_ref() { + None => true, + Some(current) => { + if let Ok(current_num) = current.parse::() { + min_type == Some("i64") && num < current_num + } else { + false + } + } + }; + if should_update { + min_value = Some(num.to_string()); + min_type = Some("i64"); + } + } + Value::Number(n) if n.is_f64() => { + let num = n.as_f64().ok_or( + datafusion_common::error::DataFusionError::Internal( + "Failed to parse number".to_string(), + ), + )?; + let should_update = match min_value.as_ref() { + None => true, + Some(current) => { + if let Ok(current_num) = current.parse::() { + min_type == Some("f64") && num < current_num + } else { + false + } + } + }; + if should_update { + min_value = Some(num.to_string()); + min_type = Some("f64"); + } + } + _ => {} + } + } + + Ok(min_value) + } else { + Ok(None) + } + } +} + +impl Default for ArrayMinUDF { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for ArrayMinUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "array_min" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> 
DFResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { + let ScalarFunctionArgs { args, .. } = args; + let array_str = args + .first() + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected a variant argument".to_string(), + ))?; + match array_str { + ColumnarValue::Array(array) => { + let string_array = array.as_string::(); + let mut results = Vec::new(); + + for i in 0..string_array.len() { + if string_array.is_null(i) { + results.push(None); + } else { + let str_value = string_array.value(i); + results.push(Self::find_min(str_value)?); + } + } + + Ok(ColumnarValue::Array(Arc::new( + datafusion::arrow::array::StringArray::from(results), + ))) + } + ColumnarValue::Scalar(array_value) => { + let ScalarValue::Utf8(Some(array_str)) = array_value else { + return Err(datafusion_common::error::DataFusionError::Internal( + "Expected UTF8 string".to_string(), + )); + }; + + let result = Self::find_min(array_str)?; + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + } + } + } +} + +make_udf_function!(ArrayMinUDF); + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + use datafusion::assert_batches_eq; + use datafusion::prelude::SessionContext; + use datafusion_expr::ScalarUDF; + use crate::variant::array_construct::ArrayConstructUDF; + + #[tokio::test] + async fn test_array_min() -> DFResult<()> { + let mut ctx = SessionContext::new(); + + // Register both UDFs + ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new())); + ctx.register_udf(ScalarUDF::from(ArrayMinUDF::new())); + + // Test numeric array + let sql = "SELECT array_min(array_construct(1, 5, 3, 9, 2)) as min_num"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------+", + "| min_num |", + "+---------+", + "| 1 |", + "+---------+", + ], + &result + ); + + // Test mixed types + let sql = "SELECT array_min(array_construct(1, 'hello', 2.5, 10)) as min_mixed"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+-----------+", + "| min_mixed |", + "+-----------+", + "| 1 |", + "+-----------+", + ], + &result + ); + + // Test array of nulls + let sql = "SELECT array_min(array_construct(NULL, NULL, NULL)) as null_min"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+----------+", + "| null_min |", + "+----------+", + "| |", + "+----------+" + ], + &result + ); + + // Test empty array + let sql = "SELECT array_min(array_construct()) as empty_min"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+-----------+", + "| empty_min |", + "+-----------+", + "| |", + "+-----------+", + ], + &result + ); + + Ok(()) + } +} diff --git a/crates/df-builtins/src/variant/array_position.rs b/crates/df-builtins/src/variant/array_position.rs new file mode 100644 index 00000000..d2e6b824 --- /dev/null +++ b/crates/df-builtins/src/variant/array_position.rs @@ -0,0 +1,216 @@ +use super::super::macros::make_udf_function; +use super::json::{encode_array, encode_scalar}; +use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::array::Array; +use datafusion_common::{Result as DFResult, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use serde_json::{from_slice, Value}; +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub struct ArrayPositionUDF { + signature: Signature, +} + 
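+// Returns the 0-based index of the first array element equal to the search value, or
+// NULL when no element matches; a NULL search value matches NULL elements (see the
+// tests below).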
+#[allow(clippy::cast_possible_wrap, clippy::as_conversions)] +impl ArrayPositionUDF { + #[must_use] + pub const fn new() -> Self { + Self { + signature: Signature { + type_signature: TypeSignature::Any(2), + volatility: Volatility::Immutable, + }, + } + } + + fn array_position(element: &Value, array: &Value) -> Option { + if let Value::Array(arr) = array { + // Find the position of the element in the array + for (index, item) in arr.iter().enumerate() { + if item == element { + return Some(index as i64); + } + } + } + None + } +} + +impl Default for ArrayPositionUDF { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for ArrayPositionUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "array_position" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::Int64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { + let ScalarFunctionArgs { args, .. } = args; + let element = args + .first() + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected element argument".to_string(), + ))?; + let array = args + .get(1) + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected array argument".to_string(), + ))?; + + match (element, array) { + (ColumnarValue::Array(element_array), ColumnarValue::Array(array_array)) => { + let mut results = Vec::new(); + + // Convert element array to JSON + let element_array = encode_array(element_array.clone())?; + + // Get array_array as string array + let string_array = array_array.as_string::(); + + if let Value::Array(element_array) = element_array { + #[allow(clippy::needless_range_loop)] + for i in 0..string_array.len() { + if string_array.is_null(i) { + results.push(None); + } else { + let array_value: Value = from_slice(string_array.value(i).as_bytes()) + .map_err(|e| { + datafusion_common::error::DataFusionError::Internal(e.to_string()) + })?; + + let result = Self::array_position(&element_array[i], &array_value); + results.push(result); + } + } + } + + Ok(ColumnarValue::Array(Arc::new( + datafusion::arrow::array::Int64Array::from(results), + ))) + } + (ColumnarValue::Scalar(element_scalar), ColumnarValue::Scalar(array_scalar)) => { + let element_scalar = encode_scalar(element_scalar)?; + let array_scalar = match array_scalar { + ScalarValue::Utf8(Some(s)) + | ScalarValue::LargeUtf8(Some(s)) + | ScalarValue::Utf8View(Some(s)) => from_slice(s.as_bytes()).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(e.to_string()) + })?, + _ => { + return Err(datafusion_common::error::DataFusionError::Internal( + "Array argument must be a string type".to_string(), + )); + } + }; + + let result = Self::array_position(&element_scalar, &array_scalar); + Ok(ColumnarValue::Scalar(ScalarValue::Int64(result))) + } + _ => Err(datafusion_common::error::DataFusionError::Internal( + "Mismatched argument types".to_string(), + )), + } + } +} + +make_udf_function!(ArrayPositionUDF); + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + use datafusion::assert_batches_eq; + use datafusion::prelude::SessionContext; + use datafusion_expr::ScalarUDF; + use crate::variant::array_construct::ArrayConstructUDF; + + #[tokio::test] + async fn test_array_position() -> DFResult<()> { + let mut ctx = SessionContext::new(); + + // Register both UDFs + ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new())); + 
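+        // Positions are 0-based: the match in result1 below is reported at index 0.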
ctx.register_udf(ScalarUDF::from(ArrayPositionUDF::new())); + + // Test basic array position + let sql = "SELECT array_position('hello', array_construct('hello', 'hi')) as result1"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------+", + "| result1 |", + "+---------+", + "| 0 |", + "+---------+", + ], + &result + ); + + // Test element not found + let sql = "SELECT array_position('world', array_construct('hello', 'hi')) as result2"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------+", + "| result2 |", + "+---------+", + "| |", + "+---------+", + ], + &result + ); + + // Test with null values + let sql = "SELECT array_position(NULL, array_construct('hello', 'hi')) as result3"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------+", + "| result3 |", + "+---------+", + "| |", + "+---------+", + ], + &result + ); + + // Test searching for NULL in array containing NULL + let sql = "SELECT array_position(NULL, array_construct('hello', NULL, 'hi')) as result4"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------+", + "| result4 |", + "+---------+", + "| 1 |", + "+---------+", + ], + &result + ); + + Ok(()) + } +} diff --git a/crates/df-builtins/src/variant/array_prepend.rs b/crates/df-builtins/src/variant/array_prepend.rs new file mode 100644 index 00000000..907d1aff --- /dev/null +++ b/crates/df-builtins/src/variant/array_prepend.rs @@ -0,0 +1,207 @@ +use super::super::macros::make_udf_function; +use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::array::Array; +use datafusion_common::{Result as DFResult, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use serde_json::{from_slice, to_string, Value}; +use std::sync::Arc; + +use super::json::encode_scalar; + +#[derive(Debug, Clone)] +pub struct ArrayPrependUDF { + signature: Signature, +} + +impl ArrayPrependUDF { + #[must_use] + pub const fn new() -> Self { + Self { + signature: Signature { + type_signature: TypeSignature::Any(2), + volatility: Volatility::Immutable, + }, + } + } + + fn prepend_element(array_value: &Value, element_value: &Value) -> DFResult { + // Ensure the first argument is an array + if let Value::Array(array) = array_value { + // Create new array with element value prepended + let mut new_array = Vec::with_capacity(array.len() + 1); + new_array.push(element_value.clone()); + new_array.extend(array.iter().cloned()); + + // Convert back to JSON string + to_string(&Value::Array(new_array)).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to serialize result: {e}", + )) + }) + } else { + Err(datafusion_common::error::DataFusionError::Internal( + "First argument must be a JSON array".to_string(), + )) + } + } +} + +impl Default for ArrayPrependUDF { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for ArrayPrependUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "array_prepend" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { + let ScalarFunctionArgs { args, .. 
} = args; + let array_str = args + .first() + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected array argument".to_string(), + ))?; + let element = args + .get(1) + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected element argument".to_string(), + ))?; + + match (array_str, element) { + (ColumnarValue::Array(array), ColumnarValue::Scalar(element_value)) => { + let string_array = array.as_string::(); + let element_value = encode_scalar(element_value)?; + let mut results = Vec::new(); + + for i in 0..string_array.len() { + if string_array.is_null(i) { + results.push(None); + } else { + let array_value = from_slice(string_array.value(i).as_bytes()) + .map_err(|e| datafusion_common::error::DataFusionError::Internal(e.to_string()))?; + results.push(Some(Self::prepend_element(&array_value, &element_value)?)); + } + } + + Ok(ColumnarValue::Array(Arc::new(datafusion::arrow::array::StringArray::from(results)))) + } + (ColumnarValue::Scalar(array_value), ColumnarValue::Scalar(element_value)) => { + let ScalarValue::Utf8(Some(array_str)) = array_value else { + return Err(datafusion_common::error::DataFusionError::Internal( + "Expected UTF8 string for array".to_string() + )) + }; + let array_value = from_slice(array_str.as_bytes()) + .map_err(|e| datafusion_common::error::DataFusionError::Internal(e.to_string()))?; + let element_value = encode_scalar(element_value)?; + + let result = Self::prepend_element(&array_value, &element_value)?; + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))) + } + _ => Err(datafusion_common::error::DataFusionError::Internal( + "First argument must be a JSON array string, second argument must be a scalar value".to_string() + )) + } + } +} + +make_udf_function!(ArrayPrependUDF); + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + use datafusion::assert_batches_eq; + use datafusion::prelude::SessionContext; + use datafusion_expr::ScalarUDF; + use crate::variant::array_construct::ArrayConstructUDF; + + #[tokio::test] + async fn test_array_prepend() -> DFResult<()> { + let mut ctx = SessionContext::new(); + + // Register both UDFs + ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new())); + ctx.register_udf(ScalarUDF::from(ArrayPrependUDF::new())); + + // Test prepending string to numeric array + let sql = "SELECT array_prepend(array_construct(0,1,2,3), 'hello') as prepended"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+-------------------+", + "| prepended |", + "+-------------------+", + "| [\"hello\",0,1,2,3] |", + "+-------------------+", + ], + &result + ); + + // Test prepending number + let sql = "SELECT array_prepend(array_construct(1,2,3), 42) as num_prepend"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+-------------+", + "| num_prepend |", + "+-------------+", + "| [42,1,2,3] |", + "+-------------+", + ], + &result + ); + + // Test prepending boolean + let sql = "SELECT array_prepend(array_construct(1,2,3), true) as bool_prepend"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+--------------+", + "| bool_prepend |", + "+--------------+", + "| [true,1,2,3] |", + "+--------------+", + ], + &result + ); + + // Test prepending null + let sql = "SELECT array_prepend(array_construct(1,2,3), NULL) as null_prepend"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+--------------+", + "| null_prepend |", + "+--------------+", + "| 
[null,1,2,3] |", + "+--------------+", + ], + &result + ); + + Ok(()) + } +} diff --git a/crates/df-builtins/src/variant/array_remove.rs b/crates/df-builtins/src/variant/array_remove.rs new file mode 100644 index 00000000..8bac5400 --- /dev/null +++ b/crates/df-builtins/src/variant/array_remove.rs @@ -0,0 +1,275 @@ +use super::super::macros::make_udf_function; +use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::array::Array; +use datafusion_common::{Result as DFResult, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use serde_json::{from_str, to_string, Value}; +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub struct ArrayRemoveUDF { + signature: Signature, +} + +impl ArrayRemoveUDF { + #[must_use] + pub const fn new() -> Self { + Self { + signature: Signature { + type_signature: TypeSignature::Any(2), + volatility: Volatility::Immutable, + }, + } + } + + fn remove_element( + array_value: Value, + element_value: Option, + ) -> DFResult> { + // If element is null, return null + if element_value.is_none() { + return Ok(None); + } + let element_value = + element_value.ok_or(datafusion_common::error::DataFusionError::Internal( + "Element value is null".to_string(), + ))?; + + // Ensure the first argument is an array + if let Value::Array(array) = array_value { + // Filter out elements equal to the specified value + let filtered: Vec = array.into_iter().filter(|x| x != &element_value).collect(); + + // Convert back to JSON string + Ok(Some(to_string(&filtered).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to serialize result: {e}", + )) + })?)) + } else { + Err(datafusion_common::error::DataFusionError::Internal( + "First argument must be a JSON array".to_string(), + )) + } + } +} + +impl Default for ArrayRemoveUDF { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for ArrayRemoveUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "array_remove" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { + let ScalarFunctionArgs { args, .. 
+        } = args;
+        let array_str = args
+            .first()
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected array argument".to_string(),
+            ))?;
+        let element = args
+            .get(1)
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected element argument".to_string(),
+            ))?;
+
+        match (array_str, element) {
+            (ColumnarValue::Array(array), ColumnarValue::Scalar(element_value)) => {
+                let string_array = array.as_string::<i32>();
+                let mut results = Vec::new();
+
+                // Convert element_value to JSON Value once if not null
+                let element_json = if element_value.is_null() {
+                    None
+                } else {
+                    let element_json =
+                        super::json::encode_array(element_value.to_array_of_size(1)?)?;
+                    if let Value::Array(array) = element_json {
+                        match array.first() {
+                            Some(value) => Some(value.clone()),
+                            None => {
+                                return Err(datafusion_common::error::DataFusionError::Internal(
+                                    "Expected array for scalar value".to_string(),
+                                ))
+                            }
+                        }
+                    } else {
+                        return Err(datafusion_common::error::DataFusionError::Internal(
+                            "Expected array for scalar value".to_string(),
+                        ));
+                    }
+                };
+
+                for i in 0..string_array.len() {
+                    if string_array.is_null(i) {
+                        results.push(None);
+                    } else {
+                        let array_str = string_array.value(i);
+                        let array_json: Value = from_str(array_str).map_err(|e| {
+                            datafusion_common::error::DataFusionError::Internal(format!(
+                                "Failed to parse array JSON: {e}"
+                            ))
+                        })?;
+                        results.push(Self::remove_element(array_json, element_json.clone())?);
+                    }
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::StringArray::from(results),
+                )))
+            }
+            (ColumnarValue::Scalar(array_value), ColumnarValue::Scalar(element_value)) => {
+                let ScalarValue::Utf8(Some(array_str)) = array_value else {
+                    return Err(datafusion_common::error::DataFusionError::Internal(
+                        "Expected UTF8 string for array".to_string(),
+                    ));
+                };
+
+                // Parse array string to JSON Value
+                let array_json: Value = from_str(array_str).map_err(|e| {
+                    datafusion_common::error::DataFusionError::Internal(format!(
+                        "Failed to parse array JSON: {e}"
+                    ))
+                })?;
+
+                // Convert element to JSON Value if not null
+                let element_json = if element_value.is_null() {
+                    None
+                } else {
+                    let element_json =
+                        super::json::encode_array(element_value.to_array_of_size(1)?)?;
+                    if let Value::Array(array) = element_json {
+                        match array.first() {
+                            Some(value) => Some(value.clone()),
+                            None => {
+                                return Err(datafusion_common::error::DataFusionError::Internal(
+                                    "Expected array for scalar value".to_string(),
+                                ))
+                            }
+                        }
+                    } else {
+                        return Err(datafusion_common::error::DataFusionError::Internal(
+                            "Expected array for scalar value".to_string(),
+                        ));
+                    }
+                };
+
+                let result = Self::remove_element(array_json, element_json)?;
+                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
+            }
+            _ => Err(datafusion_common::error::DataFusionError::Internal(
+                "First argument must be a JSON array string, second argument must be a scalar value"
+                    .to_string(),
+            )),
+        }
+    }
+}
+
+make_udf_function!(ArrayRemoveUDF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+    use crate::variant::array_construct::ArrayConstructUDF;
+
+    #[tokio::test]
+    async fn test_array_remove() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register both UDFs
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArrayRemoveUDF::new()));
+
+        // Test removing from numeric array
+        let sql = "SELECT array_remove(array_construct(2, 5, 7, 5, 1), 5) as removed";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------+",
+                "| removed |",
+                "+---------+",
+                "| [2,7,1] |",
+                "+---------+",
+            ],
+            &result
+        );
+
+        // Test removing string
+        let sql =
+            "SELECT array_remove(array_construct('a', 'b', 'c', 'b', 'd'), 'b') as str_remove";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------+",
+                "| str_remove    |",
+                "+---------------+",
+                "| [\"a\",\"c\",\"d\"] |",
+                "+---------------+",
+            ],
+            &result
+        );
+
+        // Test removing boolean
+        let sql =
+            "SELECT array_remove(array_construct(true, false, true, false), true) as bool_remove";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------+",
+                "| bool_remove   |",
+                "+---------------+",
+                "| [false,false] |",
+                "+---------------+",
+            ],
+            &result
+        );
+
+        // Test removing non-existent element
+        let sql = "SELECT array_remove(array_construct(1, 2, 3), 4) as no_remove";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+-----------+",
+                "| no_remove |",
+                "+-----------+",
+                "| [1,2,3]   |",
+                "+-----------+",
+            ],
+            &result
+        );
+
+        // Test removing NULL element
+        let sql = "SELECT array_remove(array_construct(1, 2, 3), NULL) as null_remove";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+-------------+",
+                "| null_remove |",
+                "+-------------+",
+                "|             |",
+                "+-------------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/variant/array_remove_at.rs b/crates/df-builtins/src/variant/array_remove_at.rs
new file mode 100644
index 00000000..77caaa96
--- /dev/null
+++ b/crates/df-builtins/src/variant/array_remove_at.rs
@@ -0,0 +1,292 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::array::Array;
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::{from_str, to_string, Value};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub struct ArrayRemoveAtUDF {
+    signature: Signature,
+}
+
+#[allow(
+    clippy::cast_possible_wrap,
+    clippy::as_conversions,
+    clippy::cast_possible_truncation,
+    clippy::cast_sign_loss
+)]
+impl ArrayRemoveAtUDF {
+    #[must_use]
+    pub const fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::Any(2),
+                volatility: Volatility::Immutable,
+            },
+        }
+    }
+
+    fn remove_at_position(array_value: Value, position: i64) -> DFResult<Option<String>> {
+        // Ensure the first argument is an array
+        if let Value::Array(mut array) = array_value {
+            let array_len = array.len() as i64;
+
+            // Convert negative index to positive (e.g., -1 means last element)
+            let actual_pos = if position < 0 {
+                position + array_len
+            } else {
+                position
+            };
+
+            // Check if position is valid
+            if actual_pos < 0 || actual_pos >= array_len {
+                return Ok(None);
+            }
+
+            // Remove element at position
+            array.remove(actual_pos as usize);
+
+            // Convert back to JSON string
+            Ok(Some(to_string(&array).map_err(|e| {
+                datafusion_common::error::DataFusionError::Internal(format!(
+                    "Failed to serialize result: {e}",
+                ))
+            })?))
+        } else {
+            Err(datafusion_common::error::DataFusionError::Internal(
+                "First argument must be a JSON array".to_string(),
+            ))
+        }
+    }
+}
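+
+// Worked example of the normalization above (informal sketch, not a test):
+// `remove_at_position` maps a negative position onto `position + len`, so for a
+// three-element array
+//
+//     remove_at_position(json!([1, 2, 3]), -1) -> Some("[1,2]")  // -1 + 3 = 2
+//     remove_at_position(json!([1, 2, 3]), 3)  -> None           // past the end
+//
+// and any position outside [0, len) after normalization yields SQL NULL rather
+// than an error.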
+
+impl Default for ArrayRemoveAtUDF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ScalarUDFImpl for ArrayRemoveAtUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "array_remove_at"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
+        let array_str = args
+            .first()
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected array argument".to_string(),
+            ))?;
+        let position = args
+            .get(1)
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected position argument".to_string(),
+            ))?;
+
+        match (array_str, position) {
+            (ColumnarValue::Array(array), ColumnarValue::Scalar(position_value)) => {
+                let string_array = array.as_string::<i32>();
+                let mut results = Vec::new();
+
+                // Get position value
+                let position = match position_value {
+                    ScalarValue::Int64(Some(pos)) => *pos,
+                    ScalarValue::Int64(None) | ScalarValue::Null => {
+                        // If position is NULL, return NULL
+                        return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+                    }
+                    _ => {
+                        return Err(datafusion_common::error::DataFusionError::Internal(
+                            "Position must be an integer".to_string(),
+                        ));
+                    }
+                };
+
+                for i in 0..string_array.len() {
+                    if string_array.is_null(i) {
+                        results.push(None);
+                    } else {
+                        let array_str = string_array.value(i);
+                        let array_json: Value = from_str(array_str).map_err(|e| {
+                            datafusion_common::error::DataFusionError::Internal(format!(
+                                "Failed to parse array JSON: {e}",
+                            ))
+                        })?;
+
+                        results.push(Self::remove_at_position(array_json, position)?);
+                    }
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::StringArray::from(results),
+                )))
+            }
+            (ColumnarValue::Scalar(array_value), ColumnarValue::Scalar(position_value)) => {
+                // If either array or position is NULL, return NULL. This check must
+                // run before the UTF8 pattern match below, which rejects non-string
+                // scalars (including NULLs) with an error.
+                if array_value.is_null() || position_value.is_null() {
+                    return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+                }
+
+                let ScalarValue::Utf8(Some(array_str)) = array_value else {
+                    return Err(datafusion_common::error::DataFusionError::Internal(
+                        "Expected UTF8 string for array".to_string(),
+                    ));
+                };
+
+                let position = match position_value {
+                    ScalarValue::Int64(Some(pos)) => *pos,
+                    _ => {
+                        return Err(datafusion_common::error::DataFusionError::Internal(
+                            "Position must be an integer".to_string(),
+                        ));
+                    }
+                };
+
+                // Parse array string to JSON Value
+                let array_json: Value = from_str(array_str).map_err(|e| {
+                    datafusion_common::error::DataFusionError::Internal(format!(
+                        "Failed to parse array JSON: {e}",
+                    ))
+                })?;
+
+                let result = Self::remove_at_position(array_json, position)?;
+                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
+            }
+            _ => Err(datafusion_common::error::DataFusionError::Internal(
+                "First argument must be a JSON array string, second argument must be an integer"
+                    .to_string(),
+            )),
+        }
+    }
+}
+
+make_udf_function!(ArrayRemoveAtUDF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+    use crate::variant::array_construct::ArrayConstructUDF;
+
+    #[tokio::test]
+    async fn test_array_remove_at() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register both UDFs
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArrayRemoveAtUDF::new()));
+
+        // Test removing at position 0
+        let sql = "SELECT array_remove_at(array_construct(2, 5, 7), 0) as removed";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------+",
+                "| removed |",
+                "+---------+",
+                "| [5,7]   |",
+                "+---------+",
+            ],
+            &result
+        );
+
+        // Test removing at last position
+        let sql = "SELECT array_remove_at(array_construct('a', 'b', 'c'), 2) as last_remove";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+-------------+",
+                "| last_remove |",
+                "+-------------+",
+                "| [\"a\",\"b\"]   |",
+                "+-------------+",
+            ],
+            &result
+        );
+
+        // Test removing at middle position
+        let sql = "SELECT array_remove_at(array_construct(true, false, true), 1) as middle_remove";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------+",
+                "| middle_remove |",
+                "+---------------+",
+                "| [true,true]   |",
+                "+---------------+",
+            ],
+            &result
+        );
+
+        // Test removing with negative index
+        let sql = "SELECT array_remove_at(array_construct(1, 2, 3), -1) as neg_remove";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+------------+",
+                "| neg_remove |",
+                "+------------+",
+                "| [1,2]      |",
+                "+------------+",
+            ],
+            &result
+        );
+
+        // Test removing with out of bounds index
+        let sql = "SELECT array_remove_at(array_construct(1, 2, 3), 5) as invalid_remove";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+----------------+",
+                "| invalid_remove |",
+                "+----------------+",
+                "|                |",
+                "+----------------+",
+            ],
+            &result
+        );
+
+        // Test removing with NULL position
+        let sql = "SELECT array_remove_at(array_construct(1, 2, 3), NULL) as null_pos";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+----------+",
+                "| null_pos |",
+                "+----------+",
+                "|          |",
+                "+----------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/variant/array_reverse.rs b/crates/df-builtins/src/variant/array_reverse.rs
new file mode 100644
index 00000000..2e95767c
--- /dev/null
+++ b/crates/df-builtins/src/variant/array_reverse.rs
@@ -0,0 +1,225 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::array::Array;
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::{from_str, to_string, Value};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub struct ArrayReverseUDF {
+    signature: Signature,
+}
+
+impl ArrayReverseUDF {
+    #[must_use]
+    pub const fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::Any(1),
+                volatility: Volatility::Immutable,
+            },
+        }
+    }
+
+    fn reverse_array(array_value: Value) -> DFResult<Option<String>> {
+        // Ensure the argument is an array
+        if let Value::Array(mut array) = array_value {
+            // Reverse the array
+            array.reverse();
+
+            // Convert back to JSON string
+            Ok(Some(to_string(&array).map_err(|e| {
+                datafusion_common::error::DataFusionError::Internal(format!(
+                    "Failed to serialize result: {e}"
+                ))
+            })?))
+        } else {
+            Err(datafusion_common::error::DataFusionError::Internal(
+                "Argument must be a JSON array".to_string(),
+            ))
+        }
+    }
+}
+
+impl Default for ArrayReverseUDF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
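+
+// Usage sketch (informal): only the top level of the JSON array is reversed,
+// since `Vec::reverse` does not recurse into nested values, e.g.
+//
+//     SELECT array_reverse('[1,[2,3],4]')  -- '[4,[2,3],1]'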
+
+impl ScalarUDFImpl for ArrayReverseUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "array_reverse"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
+        let array_arg = args
+            .first()
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected array argument".to_string(),
+            ))?;
+
+        match array_arg {
+            ColumnarValue::Array(array) => {
+                let string_array = array.as_string::<i32>();
+                let mut results = Vec::new();
+
+                for i in 0..string_array.len() {
+                    if string_array.is_null(i) {
+                        results.push(None);
+                    } else {
+                        let array_str = string_array.value(i);
+                        let array_json: Value = from_str(array_str).map_err(|e| {
+                            datafusion_common::error::DataFusionError::Internal(format!(
+                                "Failed to parse array JSON: {e}"
+                            ))
+                        })?;
+
+                        results.push(Self::reverse_array(array_json)?);
+                    }
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::StringArray::from(results),
+                )))
+            }
+            ColumnarValue::Scalar(array_value) => {
+                // If array is NULL, return NULL
+                if array_value.is_null() {
+                    return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+                }
+
+                let ScalarValue::Utf8(Some(array_str)) = array_value else {
+                    return Err(datafusion_common::error::DataFusionError::Internal(
+                        "Expected UTF8 string for array".to_string(),
+                    ));
+                };
+
+                // Parse array string to JSON Value
+                let array_json: Value = from_str(array_str).map_err(|e| {
+                    datafusion_common::error::DataFusionError::Internal(format!(
+                        "Failed to parse array JSON: {e}"
+                    ))
+                })?;
+
+                let result = Self::reverse_array(array_json)?;
+                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
+            }
+        }
+    }
+}
+
+make_udf_function!(ArrayReverseUDF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+    use crate::variant::array_construct::ArrayConstructUDF;
+
+    #[tokio::test]
+    async fn test_array_reverse() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register both UDFs
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArrayReverseUDF::new()));
+
+        // Test basic array reverse
+        let sql = "SELECT array_reverse(array_construct(1, 2, 3, 4)) as reversed";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+-----------+",
+                "| reversed  |",
+                "+-----------+",
+                "| [4,3,2,1] |",
+                "+-----------+",
+            ],
+            &result
+        );
+
+        // Test with strings
+        let sql = "SELECT array_reverse(array_construct('a', 'b', 'c')) as str_reverse";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------+",
+                "| str_reverse   |",
+                "+---------------+",
+                "| [\"c\",\"b\",\"a\"] |",
+                "+---------------+",
+            ],
+            &result
+        );
+
+        // Test with booleans
+        let sql = "SELECT array_reverse(array_construct(true, false, true)) as bool_reverse";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+-------------------+",
+                "| bool_reverse      |",
+                "+-------------------+",
+                "| [true,false,true] |",
+                "+-------------------+",
+            ],
+            &result
+        );
+
+        // Test with empty array
+        let sql = "SELECT array_reverse(array_construct()) as empty_reverse";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------+",
+                "| empty_reverse |",
+                "+---------------+",
+                "| []            |",
+                "+---------------+",
+            ],
+            &result
+        );
+
+        // Test with NULL
+        let sql = "SELECT array_reverse(NULL) as null_reverse";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+--------------+",
+                "| null_reverse |",
+                "+--------------+",
+                "|              |",
+                "+--------------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/variant/array_size.rs b/crates/df-builtins/src/variant/array_size.rs
new file mode 100644
index 00000000..a76df889
--- /dev/null
+++ b/crates/df-builtins/src/variant/array_size.rs
@@ -0,0 +1,189 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::array::Array;
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::{from_str, Value};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub struct ArraySizeUDF {
+    signature: Signature,
+}
+
+#[allow(clippy::as_conversions, clippy::cast_possible_wrap)]
+impl ArraySizeUDF {
+    #[must_use]
+    pub const fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::Any(1),
+                volatility: Volatility::Immutable,
+            },
+        }
+    }
+
+    fn get_array_size(value: Value) -> Option<i64> {
+        match value {
+            Value::Array(array) => Some(array.len() as i64),
+            _ => None, // Return NULL for non-array values
+        }
+    }
+}
+
+impl Default for ArraySizeUDF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ScalarUDFImpl for ArraySizeUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "array_size"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Int64)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
+        let array_arg = args
+            .first()
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected array argument".to_string(),
+            ))?;
+
+        match array_arg {
+            ColumnarValue::Array(array) => {
+                let string_array = array.as_string::<i32>();
+                let mut results = Vec::new();
+
+                for i in 0..string_array.len() {
+                    if string_array.is_null(i) {
+                        results.push(None);
+                    } else {
+                        let array_str = string_array.value(i);
+                        let array_json: Value = from_str(array_str).map_err(|e| {
+                            datafusion_common::error::DataFusionError::Internal(format!(
+                                "Failed to parse array JSON: {e}"
+                            ))
+                        })?;
+
+                        results.push(Self::get_array_size(array_json));
+                    }
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::Int64Array::from(results),
+                )))
+            }
+            ColumnarValue::Scalar(array_value) => match array_value {
+                ScalarValue::Utf8(Some(s)) => {
+                    let array_json: Value = from_str(s).map_err(|e| {
+                        datafusion_common::error::DataFusionError::Internal(format!(
+                            "Failed to parse array JSON: {e}"
+                        ))
+                    })?;
+
+                    let size = Self::get_array_size(array_json);
+                    Ok(ColumnarValue::Scalar(ScalarValue::Int64(size)))
+                }
+                ScalarValue::Utf8(None) | ScalarValue::Null => {
+                    Ok(ColumnarValue::Scalar(ScalarValue::Int64(None)))
+                }
+                _ => Err(datafusion_common::error::DataFusionError::Internal(
+                    "Expected UTF8 string for array".to_string(),
+                )),
+            },
+        }
+    }
+}
+
+make_udf_function!(ArraySizeUDF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+    use crate::variant::array_construct::ArrayConstructUDF;
+
+    #[tokio::test]
+    async fn test_array_size() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register both UDFs
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArraySizeUDF::new()));
+
+        // Test empty array
+        let sql = "SELECT array_size(array_construct()) as empty_size";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+------------+",
+                "| empty_size |",
+                "+------------+",
+                "| 0          |",
+                "+------------+",
+            ],
+            &result
+        );
+
+        // Test array with elements
+        let sql = "SELECT array_size(array_construct(1, 2, 3, 4)) as size";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            ["+------+", "| size |", "+------+", "| 4    |", "+------+",],
+            &result
+        );
+
+        // Test with non-array input
+        let sql = "SELECT array_size('\"not an array\"') as invalid_size";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+--------------+",
+                "| invalid_size |",
+                "+--------------+",
+                "|              |",
+                "+--------------+",
+            ],
+            &result
+        );
+
+        // Test with NULL input
+        let sql = "SELECT array_size(NULL) as null_size";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+-----------+",
+                "| null_size |",
+                "+-----------+",
+                "|           |",
+                "+-----------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/variant/array_slice.rs b/crates/df-builtins/src/variant/array_slice.rs
new file mode 100644
index 00000000..b79adf80
--- /dev/null
+++ b/crates/df-builtins/src/variant/array_slice.rs
@@ -0,0 +1,275 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::array::Array;
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::{from_str, to_string, Value};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub struct ArraySliceUDF {
+    signature: Signature,
+}
+
+#[allow(
+    clippy::as_conversions,
+    clippy::cast_possible_wrap,
+    clippy::cast_possible_truncation,
+    clippy::cast_sign_loss
+)]
+impl ArraySliceUDF {
+    #[must_use]
+    pub const fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::Any(3), // array, from, to
+                volatility: Volatility::Immutable,
+            },
+        }
+    }
+
+    fn slice_array(array_value: Value, from: i64, to: i64) -> DFResult<Option<String>> {
+        // Ensure the first argument is an array
+        if let Value::Array(array) = array_value {
+            let array_len = array.len() as i64;
+
+            // Convert negative indices to positive (e.g., -1 means last element)
+            let actual_from = if from < 0 { from + array_len } else { from };
+
+            let actual_to = if to < 0 { to + array_len } else { to };
+
+            // Check if indices are valid
+            if actual_from < 0
+                || actual_from >= array_len
+                || actual_to < actual_from
+                || actual_to > array_len
+            {
+                return Ok(None);
+            }
+
+            // Extract slice
+            let slice = array[actual_from as usize..actual_to as usize].to_vec();
+
+            // Convert back to JSON string
+            Ok(Some(to_string(&slice).map_err(|e| {
+                datafusion_common::error::DataFusionError::Internal(format!(
+                    "Failed to serialize result: {e}"
+                ))
+            })?))
+        } else {
+            Err(datafusion_common::error::DataFusionError::Internal(
+                "First argument must be a JSON array".to_string(),
+            ))
+        }
+    }
+}
+
+impl Default for ArraySliceUDF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ScalarUDFImpl for ArraySliceUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "array_slice"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
+        let array_str = args
+            .first()
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected array argument".to_string(),
+            ))?;
+        let from = args
+            .get(1)
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected from argument".to_string(),
+            ))?;
+        let to = args
+            .get(2)
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected to argument".to_string(),
+            ))?;
+
+        match (array_str, from, to) {
+            (
+                ColumnarValue::Array(array),
+                ColumnarValue::Scalar(from_value),
+                ColumnarValue::Scalar(to_value),
+            ) => {
+                let string_array = array.as_string::<i32>();
+                let mut results = Vec::new();
+
+                // Get from and to values
+                let from = match from_value {
+                    ScalarValue::Int64(Some(pos)) => *pos,
+                    ScalarValue::Int64(None) | ScalarValue::Null => {
+                        return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+                    }
+                    _ => {
+                        return Err(datafusion_common::error::DataFusionError::Internal(
+                            "From index must be an integer".to_string(),
+                        ))
+                    }
+                };
+
+                let to = match to_value {
+                    ScalarValue::Int64(Some(pos)) => *pos,
+                    ScalarValue::Int64(None) | ScalarValue::Null => {
+                        return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+                    }
+                    _ => {
+                        return Err(datafusion_common::error::DataFusionError::Internal(
+                            "To index must be an integer".to_string(),
+                        ))
+                    }
+                };
+
+                for i in 0..string_array.len() {
+                    if string_array.is_null(i) {
+                        results.push(None);
+                    } else {
+                        let array_str = string_array.value(i);
+                        let array_json: Value = from_str(array_str).map_err(|e| {
+                            datafusion_common::error::DataFusionError::Internal(format!(
+                                "Failed to parse array JSON: {e}"
+                            ))
+                        })?;
+                        results.push(Self::slice_array(array_json, from, to)?);
+                    }
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::StringArray::from(results),
+                )))
+            }
+            (
+                ColumnarValue::Scalar(array_value),
+                ColumnarValue::Scalar(from_value),
+                ColumnarValue::Scalar(to_value),
+            ) => {
+                // If any argument is NULL, return NULL. This check must run before
+                // the UTF8 pattern match below, which rejects non-string scalars
+                // (including NULLs) with an error.
+                if array_value.is_null() || from_value.is_null() || to_value.is_null() {
+                    return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+                }
+
+                let ScalarValue::Utf8(Some(array_str)) = array_value else {
+                    return Err(datafusion_common::error::DataFusionError::Internal(
+                        "Expected UTF8 string for array".to_string(),
+                    ));
+                };
+
+                let from = match from_value {
+                    ScalarValue::Int64(Some(pos)) => *pos,
+                    _ => {
+                        return Err(datafusion_common::error::DataFusionError::Internal(
+                            "From index must be an integer".to_string(),
+                        ))
+                    }
+                };
+
+                let to = match to_value {
+                    ScalarValue::Int64(Some(pos)) => *pos,
+                    _ => {
+                        return Err(datafusion_common::error::DataFusionError::Internal(
+                            "To index must be an integer".to_string(),
+                        ))
+                    }
+                };
+
+                // Parse array string to JSON Value
+                let array_json: Value = from_str(array_str).map_err(|e| {
+                    datafusion_common::error::DataFusionError::Internal(format!(
+                        "Failed to parse array JSON: {e}"
+                    ))
+                })?;
+
+                let result = Self::slice_array(array_json, from, to)?;
+                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
+            }
+            _ => Err(datafusion_common::error::DataFusionError::Internal(
+                "First argument must be a JSON array string, second and third arguments must be integers"
+                    .to_string(),
+            )),
+        }
+    }
+}
+
+make_udf_function!(ArraySliceUDF);
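+
+// Worked example of the bounds check above (informal): `slice_array` first adds
+// the array length to any negative index and then requires
+// 0 <= from < len and from <= to <= len. For ["a","b","c","d"] (len 4):
+//
+//     slice_array(.., -2, -1) -> from = 2, to = 3  -> Some("[\"c\"]")
+//     slice_array(..,  5,  7) -> from out of range -> None (SQL NULL)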
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+    use crate::variant::array_construct::ArrayConstructUDF;
+
+    #[tokio::test]
+    async fn test_array_slice() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register both UDFs
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArraySliceUDF::new()));
+
+        // Test basic slice
+        let sql = "SELECT array_slice(array_construct(0, 1, 2, 3, 4, 5, 6), 0, 2) as slice";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+-------+",
+                "| slice |",
+                "+-------+",
+                "| [0,1] |",
+                "+-------+",
+            ],
+            &result
+        );
+
+        // Test slice with negative indices
+        let sql = "SELECT array_slice(array_construct('a', 'b', 'c', 'd'), -2, -1) as neg_slice";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+-----------+",
+                "| neg_slice |",
+                "+-----------+",
+                "| [\"c\"]     |",
+                "+-----------+",
+            ],
+            &result
+        );
+
+        // Test slice with out of bounds indices
+        let sql = "SELECT array_slice(array_construct(1, 2, 3), 5, 7) as invalid_slice";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------+",
+                "| invalid_slice |",
+                "+---------------+",
+                "|               |",
+                "+---------------+",
+            ],
+            &result
+        );
+
+        // Test slice with NULL indices
+        let sql = "SELECT array_slice(array_construct(1, 2, 3), NULL, 2) as null_slice";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+------------+",
+                "| null_slice |",
+                "+------------+",
+                "|            |",
+                "+------------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/variant/array_sort.rs b/crates/df-builtins/src/variant/array_sort.rs
new file mode 100644
index 00000000..0e9f3b90
--- /dev/null
+++ b/crates/df-builtins/src/variant/array_sort.rs
@@ -0,0 +1,291 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::array::Array;
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::{from_str, to_string, Value};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub struct ArraySortUDF {
+    signature: Signature,
+}
+
+impl ArraySortUDF {
+    #[must_use]
+    pub const fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::VariadicAny,
+                volatility: Volatility::Immutable,
+            },
+        }
+    }
+
+    fn compare_json_values(a: &Value, b: &Value) -> std::cmp::Ordering {
+        match (a, b) {
+            (Value::Null, Value::Null) => std::cmp::Ordering::Equal,
+            (Value::Bool(a), Value::Bool(b)) => a.cmp(b),
+            (Value::Number(a), Value::Number(b)) => {
+                if let (Some(a_f), Some(b_f)) = (a.as_f64(), b.as_f64()) {
+                    a_f.partial_cmp(&b_f).unwrap_or(std::cmp::Ordering::Equal)
+                } else {
+                    std::cmp::Ordering::Equal
+                }
+            }
+            (Value::String(a), Value::String(b)) => a.cmp(b),
+            // Type-based ordering for different types, with nulls always last
+            _ => {
+                let type_order = |v: &Value| {
+                    match v {
+                        Value::Null => 5, // Move nulls to end
+                        Value::Bool(_) => 0,
+                        Value::Number(_) => 1,
+                        Value::String(_) => 2,
+                        Value::Array(_) => 3,
+                        Value::Object(_) => 4,
+                    }
+                };
+                type_order(a).cmp(&type_order(b))
+            }
+        }
+    }
+
+    fn sort_array(
+        array_value: Value,
+        sort_ascending: bool,
+        _nulls_first: bool,
+    ) -> DFResult<Option<String>> {
+        if let Value::Array(array) = array_value {
+            // Convert array elements to a format that can be sorted
+            let mut elements: Vec<Option<Value>> = array
+                .into_iter()
+                .map(|v| match v {
+                    Value::Null => None,
+                    v => Some(v),
+                })
+                .collect();
+
+            // Sort the array, putting nulls last for ascending and first for descending
+            elements.sort_by(|a, b| {
+                match (a, b) {
+                    (None, None) => std::cmp::Ordering::Equal,
+                    (None, Some(_)) => {
+                        if sort_ascending {
+                            std::cmp::Ordering::Greater // Nulls last for ascending
+                        } else {
+                            std::cmp::Ordering::Less // Nulls first for descending
+                        }
+                    }
+                    (Some(_), None) => {
+                        if sort_ascending {
+                            std::cmp::Ordering::Less // Non-nulls before nulls for ascending
+                        } else {
+                            std::cmp::Ordering::Greater // Non-nulls after nulls for descending
+                        }
+                    }
+                    (Some(a_val), Some(b_val)) => {
+                        let cmp = Self::compare_json_values(a_val, b_val);
+                        if sort_ascending {
+                            cmp
+                        } else {
+                            cmp.reverse()
+                        }
+                    }
+                }
+            });
+
+            // Convert back to JSON array
+            let sorted_array: Vec<Value> = elements
+                .into_iter()
+                .map(|opt| opt.unwrap_or(Value::Null))
+                .collect();
+
+            Ok(Some(to_string(&sorted_array).map_err(|e| {
+                datafusion_common::error::DataFusionError::Internal(format!(
+                    "Failed to serialize result: {e}"
+                ))
+            })?))
+        } else {
+            Err(datafusion_common::error::DataFusionError::Internal(
+                "First argument must be a JSON array".to_string(),
+            ))
+        }
+    }
+}
+
+impl Default for ArraySortUDF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ScalarUDFImpl for ArraySortUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "array_sort"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
+
+        // Get array argument
+        let array_arg = args
+            .first()
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected array argument".to_string(),
+            ))?;
+
+        // Get optional sort_ascending argument (default: true)
+        let sort_ascending = args.get(1).is_none_or(|v| match v {
+            ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => *b,
+            _ => true,
+        });
+
+        // Get optional nulls_first argument (default: true)
+        let nulls_first = args.get(2).is_none_or(|v| match v {
+            ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => *b,
+            _ => true,
+        });
+
+        match array_arg {
+            ColumnarValue::Array(array) => {
+                let string_array = array.as_string::<i32>();
+                let mut results = Vec::new();
+
+                for i in 0..string_array.len() {
+                    if string_array.is_null(i) {
+                        results.push(None);
+                    } else {
+                        let array_str = string_array.value(i);
+                        let array_json: Value = from_str(array_str).map_err(|e| {
+                            datafusion_common::error::DataFusionError::Internal(format!(
+                                "Failed to parse array JSON: {e}"
+                            ))
+                        })?;
+
+                        results.push(Self::sort_array(array_json, sort_ascending, nulls_first)?);
+                    }
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::StringArray::from(results),
+                )))
+            }
+            ColumnarValue::Scalar(array_value) => match array_value {
+                ScalarValue::Utf8(Some(array_str)) => {
+                    let array_json: Value = from_str(array_str).map_err(|e| {
+                        datafusion_common::error::DataFusionError::Internal(format!(
+                            "Failed to parse array JSON: {e}"
+                        ))
+                    })?;
+
+                    let result = Self::sort_array(array_json, sort_ascending, nulls_first)?;
+                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
+                }
+                ScalarValue::Utf8(None) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
+                _ => Err(datafusion_common::error::DataFusionError::Internal(
+                    "First argument must be a JSON array string".to_string(),
+                )),
+            },
+        }
+    }
+}
+
+make_udf_function!(ArraySortUDF);
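+
+// Ordering summary (informal): `compare_json_values` compares same-typed values
+// directly (numbers through f64, booleans and strings through `Ord`) and falls
+// back to a fixed type rank (bool < number < string < array < object) across
+// types, while `sort_array` always places JSON nulls at the tail of an ascending
+// sort. Note that the third `nulls_first` argument is parsed by
+// `invoke_with_args` but not yet applied by `sort_array`.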
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+    use crate::variant::array_construct::ArrayConstructUDF;
+
+    #[tokio::test]
+    async fn test_array_sort() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register UDFs
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArraySortUDF::new()));
+
+        // Test basic sorting
+        let sql = "SELECT array_sort(array_construct(20, NULL, 0, NULL, 10)) as sorted";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------------+",
+                "| sorted              |",
+                "+---------------------+",
+                "| [0,10,20,null,null] |",
+                "+---------------------+",
+            ],
+            &result
+        );
+
+        // Test descending order
+        let sql = "SELECT array_sort(array_construct(20, NULL, 0, NULL, 10), false) as desc_sort";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------------+",
+                "| desc_sort           |",
+                "+---------------------+",
+                "| [null,null,20,10,0] |",
+                "+---------------------+",
+            ],
+            &result
+        );
+
+        // Test nulls last
+        let sql =
+            "SELECT array_sort(array_construct(20, NULL, 0, NULL, 10), true, false) as nulls_last";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------------+",
+                "| nulls_last          |",
+                "+---------------------+",
+                "| [0,10,20,null,null] |",
+                "+---------------------+",
+            ],
+            &result
+        );
+
+        // Test with mixed types
+        let sql = "SELECT array_sort(array_construct('a', 'c', 'b')) as str_sort";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------+",
+                "| str_sort      |",
+                "+---------------+",
+                "| [\"a\",\"b\",\"c\"] |",
+                "+---------------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/array_to_string.rs b/crates/df-builtins/src/variant/array_to_string.rs
similarity index 99%
rename from crates/df-builtins/src/array_to_string.rs
rename to crates/df-builtins/src/variant/array_to_string.rs
index 58c755f0..62cc9c20 100644
--- a/crates/df-builtins/src/array_to_string.rs
+++ b/crates/df-builtins/src/variant/array_to_string.rs
@@ -8,6 +8,7 @@ use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl};
 use serde_json::Value;
 use std::any::Any;
 use std::sync::Arc;
+use crate::macros::make_udf_function;
 
 // array_to_string SQL function
 // Converts the input array to a string by first casting each element to a string,
@@ -137,7 +138,7 @@ fn to_string(v: &Value, sep: &str) -> DFResult<String> {
     Ok(res.join(sep))
 }
 
-super::macros::make_udf_function!(ArrayToStringFunc);
+make_udf_function!(ArrayToStringFunc);
 
 #[cfg(test)]
 mod tests {
diff --git a/crates/df-builtins/src/variant/arrays_overlap.rs b/crates/df-builtins/src/variant/arrays_overlap.rs
new file mode 100644
index 00000000..51068a35
--- /dev/null
+++ b/crates/df-builtins/src/variant/arrays_overlap.rs
@@ -0,0 +1,226 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::array::Array;
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::{from_str, Value};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub struct ArraysOverlapUDF {
+    signature: Signature,
+}
+
+impl ArraysOverlapUDF {
+    #[must_use]
+    pub const fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::Any(2),
+                volatility: Volatility::Immutable,
+            },
+        }
+    }
+
+    fn arrays_have_overlap(array1: Value, array2: Value) -> DFResult<Option<bool>> {
+        // Ensure both arguments are arrays
+        if let (Value::Array(arr1), Value::Array(arr2)) = (array1, array2) {
+            // Convert arrays to HashSet for efficient comparison
+            let set1: std::collections::HashSet<String> =
+                arr1.iter().map(|v| v.to_string()).collect();
+
+            // Check if any element from arr2 exists in set1
+            for val in arr2 {
+                if set1.contains(&val.to_string()) {
+                    return Ok(Some(true));
+                }
+            }
+
+            Ok(Some(false))
+        } else {
+            Err(datafusion_common::error::DataFusionError::Internal(
+                "Both arguments must be JSON arrays".to_string(),
+            ))
+        }
+    }
+}
+
+impl Default for ArraysOverlapUDF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ScalarUDFImpl for ArraysOverlapUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "arrays_overlap"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
+        let array1_arg =
+            args.first()
+                .ok_or(datafusion_common::error::DataFusionError::Internal(
+                    "Expected first array argument".to_string(),
+                ))?;
+        let array2_arg = args
+            .get(1)
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected second array argument".to_string(),
+            ))?;
+
+        match (array1_arg, array2_arg) {
+            (ColumnarValue::Array(array1), ColumnarValue::Array(array2)) => {
+                let string_array1 = array1.as_string::<i32>();
+                let string_array2 = array2.as_string::<i32>();
+                let mut results = Vec::new();
+
+                for i in 0..string_array1.len() {
+                    if string_array1.is_null(i) || string_array2.is_null(i) {
+                        results.push(None);
+                    } else {
+                        let array1_str = string_array1.value(i);
+                        let array2_str = string_array2.value(i);
+
+                        let array1_json: Value = from_str(array1_str).map_err(|e| {
+                            datafusion_common::error::DataFusionError::Internal(format!(
+                                "Failed to parse first array JSON: {e}"
+                            ))
+                        })?;
+
+                        let array2_json: Value = from_str(array2_str).map_err(|e| {
+                            datafusion_common::error::DataFusionError::Internal(format!(
+                                "Failed to parse second array JSON: {e}"
+                            ))
+                        })?;
+
+                        results.push(Self::arrays_have_overlap(array1_json, array2_json)?);
+                    }
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::BooleanArray::from(results),
+                )))
+            }
+            (ColumnarValue::Scalar(array1_value), ColumnarValue::Scalar(array2_value)) => {
+                // If either array is NULL, return NULL
+                if array1_value.is_null() || array2_value.is_null() {
+                    return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
+                }
+
+                let ScalarValue::Utf8(Some(array1_str)) = array1_value else {
+                    return Err(datafusion_common::error::DataFusionError::Internal(
+                        "Expected UTF8 string for first array".to_string(),
+                    ));
+                };
+                let ScalarValue::Utf8(Some(array2_str)) = array2_value else {
+                    return Err(datafusion_common::error::DataFusionError::Internal(
+                        "Expected UTF8 string for second array".to_string(),
+                    ));
+                };
+
+                // Parse array strings to JSON Values
+                let array1_json: Value = from_str(array1_str).map_err(|e| {
+                    datafusion_common::error::DataFusionError::Internal(format!(
+                        "Failed to parse first array JSON: {e}"
+                    ))
+                })?;
+
+                let array2_json: Value = from_str(array2_str).map_err(|e| {
+                    datafusion_common::error::DataFusionError::Internal(format!(
+                        "Failed to parse second array JSON: {e}",
+                    ))
+                })?;
+
+                let result = Self::arrays_have_overlap(array1_json, array2_json)?;
+                Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result)))
+            }
+            _ => Err(datafusion_common::error::DataFusionError::Internal(
+                "Both arguments must be JSON array strings".to_string(),
+            )),
+        }
+    }
+}
+
+make_udf_function!(ArraysOverlapUDF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+    use crate::variant::array_construct::ArrayConstructUDF;
+
+    #[tokio::test]
+    async fn test_arrays_overlap() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register both UDFs
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArraysOverlapUDF::new()));
+
+        // Test with string arrays that overlap
+        let sql = "SELECT arrays_overlap(array_construct('hello', 'aloha'), array_construct('hello', 'hi', 'hey')) as overlap";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------+",
+                "| overlap |",
+                "+---------+",
+                "| true    |",
+                "+---------+",
+            ],
+            &result
+        );
+
+        // Test with string arrays that don't overlap
+        let sql = "SELECT arrays_overlap(array_construct('hello', 'aloha'), array_construct('hola', 'bonjour', 'ciao')) as overlap";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------+",
+                "| overlap |",
+                "+---------+",
+                "| false   |",
+                "+---------+",
+            ],
+            &result
+        );
+
+        // Test with NULL values
+        let sql = "SELECT arrays_overlap(NULL, array_construct(1, 2, 3)) as overlap";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------+",
+                "| overlap |",
+                "+---------+",
+                "|         |",
+                "+---------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/variant/arrays_to_object.rs b/crates/df-builtins/src/variant/arrays_to_object.rs
new file mode 100644
index 00000000..c18442ce
--- /dev/null
+++ b/crates/df-builtins/src/variant/arrays_to_object.rs
@@ -0,0 +1,244 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::array::Array;
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::{json, Value};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub struct ArraysToObjectUDF {
+    signature: Signature,
+}
+
+impl ArraysToObjectUDF {
+    #[must_use]
+    pub const fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::Any(2),
+                volatility: Volatility::Immutable,
+            },
+        }
+    }
+
+    fn create_object(keys: &[Option<String>], values: &[Value]) -> Option<String> {
+        if keys.len() != values.len() {
+            return None;
+        }
+
+        let mut obj = serde_json::Map::new();
+
+        for (key_opt, value) in keys.iter().zip(values.iter()) {
+            if let Some(key) = key_opt {
+                obj.insert(key.clone(), value.clone());
+            }
+        }
+
+        Some(json!(obj).to_string())
+    }
+}
+
+impl Default for ArraysToObjectUDF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ScalarUDFImpl for ArraysToObjectUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "arrays_to_object"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
+        let keys_arg = args
+            .first()
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected keys array argument".to_string(),
+            ))?;
+        let values_arg = args
+            .get(1)
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected values array argument".to_string(),
+            ))?;
+
+        match (keys_arg, values_arg) {
+            (ColumnarValue::Array(keys_array), ColumnarValue::Array(values_array)) => {
+                let keys_string_array = keys_array.as_string::<i32>();
+                let values_string_array = values_array.as_string::<i32>();
+                let mut results = Vec::new();
+
+                for i in 0..keys_string_array.len() {
+                    if keys_string_array.is_null(i) || values_string_array.is_null(i) {
+                        results.push(None);
+                        continue;
+                    }
+
+                    // Parse this row's keys and values as JSON arrays and zip them
+                    // positionally, mirroring the scalar branch below.
+                    let keys: Value =
+                        serde_json::from_str(keys_string_array.value(i)).map_err(|e| {
+                            datafusion_common::error::DataFusionError::Internal(format!(
+                                "Failed to parse keys JSON array: {e}"
+                            ))
+                        })?;
+
+                    let values: Value =
+                        serde_json::from_str(values_string_array.value(i)).map_err(|e| {
+                            datafusion_common::error::DataFusionError::Internal(format!(
+                                "Failed to parse values JSON array: {e}"
+                            ))
+                        })?;
+
+                    let (Value::Array(key_array), Value::Array(value_array)) = (keys, values)
+                    else {
+                        return Err(datafusion_common::error::DataFusionError::Internal(
+                            "Both arguments must be JSON arrays".to_string(),
+                        ));
+                    };
+
+                    let keys: Vec<Option<String>> = key_array
+                        .into_iter()
+                        .map(|v| match v {
+                            Value::String(s) => Some(s),
+                            Value::Null => None,
+                            _ => Some(v.to_string()),
+                        })
+                        .collect();
+
+                    results.push(Self::create_object(&keys, &value_array));
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::StringArray::from(results),
+                )))
+            }
+            (ColumnarValue::Scalar(keys_value), ColumnarValue::Scalar(values_value)) => {
+                // Handle NULL inputs
+                if keys_value.is_null() || values_value.is_null() {
+                    return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+                }
+
+                let ScalarValue::Utf8(Some(keys_str)) = keys_value else {
+                    return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+                };
+
+                let ScalarValue::Utf8(Some(values_str)) = values_value else {
+                    return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+                };
+
+                // Parse arrays
+                let keys: Value = serde_json::from_str(keys_str).map_err(|e| {
+                    datafusion_common::error::DataFusionError::Internal(format!(
+                        "Failed to parse keys JSON array: {e}"
+                    ))
+                })?;
+
+                let values: Value = serde_json::from_str(values_str).map_err(|e| {
+                    datafusion_common::error::DataFusionError::Internal(format!(
+                        "Failed to parse values JSON array: {e}"
+                    ))
+                })?;
+
+                if let (Value::Array(key_array), Value::Array(value_array)) = (keys, values) {
+                    let keys: Vec<Option<String>> = key_array
+                        .into_iter()
+                        .map(|v| match v {
+                            Value::String(s) => Some(s),
+                            Value::Null => None,
+                            _ => Some(v.to_string()),
+                        })
+                        .collect();
+
+                    let result = Self::create_object(&keys, &value_array);
+                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
+                } else {
+                    Err(datafusion_common::error::DataFusionError::Internal(
+                        "Both arguments must be JSON arrays".to_string(),
+                    ))
+                }
+            }
+            _ => Err(datafusion_common::error::DataFusionError::Internal(
+                "Arguments must be arrays".to_string(),
+            )),
+        }
+    }
+}
+
+make_udf_function!(ArraysToObjectUDF);
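+
+// Shape sketch (informal): ARRAYS_TO_OBJECT zips the two arrays positionally,
+//
+//     arrays_to_object('["k1",null,"k3"]', '[1,2,3]') -> '{"k1":1,"k3":3}'
+//
+// where entries with a null key are skipped (their paired value is dropped) and
+// a length mismatch between the arrays yields SQL NULL via `create_object`.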
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+    use crate::variant::array_construct::ArrayConstructUDF;
+
+    #[tokio::test]
+    async fn test_arrays_to_object() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register UDFs
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArraysToObjectUDF::new()));
+
+        // Test basic key-value mapping
+        let sql = "SELECT arrays_to_object(array_construct('key1', 'key2', 'key3'), array_construct(1, 2, 3)) as obj";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+------------------------------+",
+                "| obj                          |",
+                "+------------------------------+",
+                "| {\"key1\":1,\"key2\":2,\"key3\":3} |",
+                "+------------------------------+",
+            ],
+            &result
+        );
+
+        // Test with NULL key
+        let sql = "SELECT arrays_to_object(array_construct('key1', NULL, 'key3'), array_construct(1, 2, 3)) as obj_null_key";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------------+",
+                "| obj_null_key        |",
+                "+---------------------+",
+                "| {\"key1\":1,\"key3\":3} |",
+                "+---------------------+",
+            ],
+            &result
+        );
+
+        // Test with NULL value
+        let sql = "SELECT arrays_to_object(array_construct('key1', 'key2', 'key3'), array_construct(1, NULL, 3)) as obj_null_value";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------------------------+",
+                "| obj_null_value                  |",
+                "+---------------------------------+",
+                "| {\"key1\":1,\"key2\":null,\"key3\":3} |",
+                "+---------------------------------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/variant/arrays_zip.rs b/crates/df-builtins/src/variant/arrays_zip.rs
new file mode 100644
index 00000000..97f4cc2e
--- /dev/null
+++ b/crates/df-builtins/src/variant/arrays_zip.rs
@@ -0,0 +1,253 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::array::Array;
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::Value;
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub struct ArraysZipUDF {
+    signature: Signature,
+}
+
+impl ArraysZipUDF {
+    #[must_use]
+    pub const fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::VariadicAny,
+                volatility: Volatility::Immutable,
+            },
+        }
+    }
+
+    fn zip_arrays(arrays: Vec<Value>) -> DFResult<Option<String>> {
+        // If any array is null, return null
+        if arrays.iter().any(|arr| arr.is_null()) {
+            return Ok(None);
+        }
+
+        // Ensure all inputs are arrays
+        let arrays: Vec<Vec<Value>> = arrays
+            .into_iter()
+            .map(|val| match val {
+                Value::Array(arr) => Ok(arr),
+                _ => Err(datafusion_common::error::DataFusionError::Internal(
+                    "All arguments must be arrays".to_string(),
+                )),
+            })
+            .collect::<DFResult<Vec<_>>>()?;
+
+        // Find the maximum length among all arrays
+        let max_len = arrays.iter().map(|arr| arr.len()).max().unwrap_or(0);
+
+        // Create the zipped array
+        let mut result = Vec::with_capacity(max_len);
+        for i in 0..max_len {
+            let mut obj = serde_json::Map::new();
+            for (array_idx, array) in arrays.iter().enumerate() {
+                let key = format!("${}", array_idx + 1);
+                let value = array.get(i).cloned().unwrap_or(Value::Null);
+                obj.insert(key, value);
+            }
+            result.push(Value::Object(obj));
+        }
+
+        Ok(Some(serde_json::to_string(&result).map_err(|e| {
+            datafusion_common::error::DataFusionError::Internal(format!(
+                "Failed to serialize result: {e}"
+            ))
+        })?))
+    }
+}
+
+impl Default for ArraysZipUDF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ScalarUDFImpl for ArraysZipUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "arrays_zip"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
+
+        match args.first() {
+            Some(ColumnarValue::Array(first_array)) => {
+                let string_array = first_array.as_string::<i32>();
+                let mut results = Vec::new();
+
+                for row in 0..string_array.len() {
+                    let mut row_arrays = Vec::new();
+                    let mut has_null = false;
+
+                    // Collect all array values for this row
+                    for arg in &args {
+                        match arg {
+                            ColumnarValue::Array(arr) => {
+                                let arr = arr.as_string::<i32>();
+                                if arr.is_null(row) {
+                                    has_null = true;
+                                    break;
+                                }
+                                let array_json: Value = serde_json::from_str(arr.value(row))
+                                    .map_err(|e| {
+                                        datafusion_common::error::DataFusionError::Internal(
+                                            format!("Failed to parse array JSON: {e}"),
+                                        )
+                                    })?;
+                                row_arrays.push(array_json);
+                            }
+                            ColumnarValue::Scalar(_) => {
+                                return Err(datafusion_common::error::DataFusionError::Internal(
+                                    "All arguments must be arrays".to_string(),
+                                ));
+                            }
+                        }
+                    }
+
+                    if has_null {
+                        results.push(None);
+                    } else {
+                        results.push(Self::zip_arrays(row_arrays)?);
+                    }
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::StringArray::from(results),
+                )))
+            }
+            Some(ColumnarValue::Scalar(_first_value)) => {
+                let mut scalar_arrays = Vec::new();
+
+                // If any scalar is NULL, return NULL
+                for arg in &args {
+                    match arg {
+                        ColumnarValue::Scalar(scalar) => {
+                            if scalar.is_null() {
+                                return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+                            }
+                            if let ScalarValue::Utf8(Some(s)) = scalar {
+                                let array_json: Value = serde_json::from_str(s).map_err(|e| {
+                                    datafusion_common::error::DataFusionError::Internal(format!(
+                                        "Failed to parse array JSON: {e}"
+                                    ))
+                                })?;
+                                scalar_arrays.push(array_json);
+                            } else {
+                                return Err(datafusion_common::error::DataFusionError::Internal(
+                                    "Expected UTF8 string for array".to_string(),
+                                ));
+                            }
+                        }
+                        ColumnarValue::Array(_) => {
+                            return Err(datafusion_common::error::DataFusionError::Internal(
+                                "Mixed scalar and array arguments are not supported".to_string(),
+                            ));
+                        }
+                    }
+                }
+
+                let result = Self::zip_arrays(scalar_arrays)?;
+                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
+            }
+            None => Err(datafusion_common::error::DataFusionError::Internal(
+                "ARRAYS_ZIP requires at least one array argument".to_string(),
+            )),
+        }
+    }
+}
+
+make_udf_function!(ArraysZipUDF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+    use crate::variant::array_construct::ArrayConstructUDF;
+
+    #[tokio::test]
+    async fn test_arrays_zip() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register UDFs
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArraysZipUDF::new()));
+
+        // Test basic zipping of three arrays
+        let sql = "SELECT arrays_zip(
+            array_construct(1, 2, 3),
+            array_construct('first', 'second', 'third'),
+            array_construct('i', 'ii', 'iii')
+        ) as zipped_arrays";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+----------------------------------------------------------------------------------------------------+",
+                "| zipped_arrays                                                                                      |",
+                "+----------------------------------------------------------------------------------------------------+",
+                "| [{\"$1\":1,\"$2\":\"first\",\"$3\":\"i\"},{\"$1\":2,\"$2\":\"second\",\"$3\":\"ii\"},{\"$1\":3,\"$2\":\"third\",\"$3\":\"iii\"}] |",
+                "+----------------------------------------------------------------------------------------------------+",
+            ],
+            &result
+        );
+
+        // Test arrays of different lengths
+        let sql = "SELECT arrays_zip(
+            array_construct(1, 2, 3),
+            array_construct('a', 'b')
+        ) as diff_lengths";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+----------------------------------------------------------+",
+                "| diff_lengths                                             |",
+                "+----------------------------------------------------------+",
+                "| [{\"$1\":1,\"$2\":\"a\"},{\"$1\":2,\"$2\":\"b\"},{\"$1\":3,\"$2\":null}] |",
+                "+----------------------------------------------------------+",
+            ],
+            &result
+        );
+
+        // Test with NULL array
+        let sql = "SELECT arrays_zip(NULL, array_construct(1, 2, 3)) as null_array";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+------------+",
+                "| null_array |",
+                "+------------+",
+                "|            |",
+                "+------------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/variant/json.rs b/crates/df-builtins/src/variant/json.rs
new file mode 100644
index 00000000..408f28ac
--- /dev/null
+++ b/crates/df-builtins/src/variant/json.rs
@@ -0,0 +1,1281 @@
+#![allow(clippy::needless_pass_by_value, clippy::unnecessary_wraps)]
+
+use datafusion::arrow::array::AsArray;
+use datafusion::arrow::array::{
+    Array, ArrayRef, BooleanArray, NullArray, PrimitiveArray, StringArray,
+};
+use datafusion::arrow::datatypes::{
+    DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType,
+    DurationMillisecondType, DurationNanosecondType, DurationSecondType, Float16Type, Float32Type,
+    Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType,
+    IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Time32MillisecondType,
+    Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit,
+    TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
+    TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+};
+use datafusion::arrow::error::ArrowError;
+use base64::engine::Engine;
+use datafusion_common::ScalarValue;
+use datafusion_expr::ColumnarValue;
+use serde_json::{Map, Number, Value as JsonValue};
+use std::sync::Arc;
+
+/// Encodes a Boolean Arrow array into a JSON array
+pub fn encode_boolean_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array
+        .as_ref()
+        .as_any()
+        .downcast_ref::<BooleanArray>()
+        .ok_or_else(|| ArrowError::InvalidArgumentError("Expected boolean array".into()))?;
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            JsonValue::Bool(array.value(i))
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a signed Primitive Arrow array into a JSON array
+pub fn encode_signed_array<T>(array: ArrayRef) -> Result<JsonValue, ArrowError>
+where
+    T: datafusion::arrow::datatypes::ArrowPrimitiveType,
+    T::Native: Into<i64>,
+{
+    let array = array
+        .as_ref()
+        .as_any()
+        .downcast_ref::<PrimitiveArray<T>>()
+        .ok_or_else(|| ArrowError::InvalidArgumentError("Expected primitive array".into()))?;
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            JsonValue::Number(Number::from(array.value(i).into()))
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes an unsigned Primitive Arrow array into a JSON array
+pub fn encode_unsigned_array<T>(array: ArrayRef) -> Result<JsonValue, ArrowError>
+where
+    T: datafusion::arrow::datatypes::ArrowPrimitiveType,
+    T::Native: Into<u64>,
+{
+    let array = array.as_primitive::<T>();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            JsonValue::Number(Number::from(array.value(i).into()))
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Float16 Arrow array into a JSON array
+pub fn encode_float16_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_primitive::<Float16Type>();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            JsonValue::Number(
+                Number::from_f64(array.value(i).to_f64()).ok_or_else(|| {
+                    ArrowError::InvalidArgumentError("Invalid float value".into())
+                })?,
+            )
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Float32 Arrow array into a JSON array
+pub fn encode_float32_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_primitive::<Float32Type>();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            JsonValue::Number(
+                Number::from_f64(f64::from(array.value(i))).ok_or_else(|| {
+                    ArrowError::InvalidArgumentError("Invalid float value".into())
+                })?,
+            )
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Float64 Arrow array into a JSON array
+pub fn encode_float64_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_primitive::<Float64Type>();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            JsonValue::Number(
+                Number::from_f64(array.value(i)).ok_or_else(|| {
+                    ArrowError::InvalidArgumentError("Invalid float value".into())
+                })?,
+            )
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Utf8 Arrow array into a JSON array
+pub fn encode_utf8_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_string::<i32>();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            JsonValue::String(array.value(i).to_string())
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a LargeUtf8 Arrow array into a JSON array
+pub fn encode_large_utf8_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_string::<i64>();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            JsonValue::String(array.value(i).to_string())
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Binary Arrow array into a JSON array
+pub fn encode_binary_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_binary::<i32>();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            JsonValue::String(base64::engine::general_purpose::STANDARD.encode(array.value(i)))
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a LargeBinary Arrow array into a JSON array
+pub fn encode_large_binary_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_binary::<i64>();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            JsonValue::String(base64::engine::general_purpose::STANDARD.encode(array.value(i)))
+        });
+    }
+    Ok(JsonValue::Array(values))
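+    // (Illustrative note, not from the original source: JSON has no binary
+    // type, so both binary encoders above emit standard padded base64 strings;
+    // for example, the bytes [1, 2, 3] would encode as "AQID".)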
+}
+
+/// Encodes a Dictionary Arrow array into a JSON array
+pub fn encode_dictionary_array<K>(array: ArrayRef) -> Result<JsonValue, ArrowError>
+where
+    K: datafusion::arrow::datatypes::ArrowDictionaryKeyType,
+{
+    let array = array.as_dictionary::<K>();
+    let keys = array.keys().clone();
+    let values = array.values().clone();
+
+    // Encode dictionary keys and values
+    let key_values = encode_primitive_array(Arc::new(keys))?;
+    let value_values = encode_array(values)?;
+
+    // Build dictionary structure
+    let mut result = Map::new();
+    result.insert("keys".to_string(), key_values);
+    result.insert("values".to_string(), value_values);
+    Ok(JsonValue::Object(result))
+}
+
+/// Encodes a Timestamp Arrow array into a JSON array
+pub fn encode_timestamp_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let data_type = array.data_type();
+    let DataType::Timestamp(unit, _) = data_type else {
+        return Err(ArrowError::InvalidArgumentError(
+            "Expected timestamp array".into(),
+        ));
+    };
+
+    let mut values = Vec::with_capacity(array.len());
+
+    match unit {
+        TimeUnit::Second => {
+            let array = array.as_primitive::<TimestampSecondType>();
+            for i in 0..array.len() {
+                values.push(if array.is_null(i) {
+                    JsonValue::Null
+                } else {
+                    JsonValue::Number(Number::from(array.value(i)))
+                });
+            }
+        }
+        TimeUnit::Millisecond => {
+            let array = array.as_primitive::<TimestampMillisecondType>();
+            for i in 0..array.len() {
+                values.push(if array.is_null(i) {
+                    JsonValue::Null
+                } else {
+                    JsonValue::Number(Number::from(array.value(i)))
+                });
+            }
+        }
+        TimeUnit::Microsecond => {
+            let array = array.as_primitive::<TimestampMicrosecondType>();
+            for i in 0..array.len() {
+                values.push(if array.is_null(i) {
+                    JsonValue::Null
+                } else {
+                    JsonValue::Number(Number::from(array.value(i)))
+                });
+            }
+        }
+        TimeUnit::Nanosecond => {
+            let array = array.as_primitive::<TimestampNanosecondType>();
+            for i in 0..array.len() {
+                values.push(if array.is_null(i) {
+                    JsonValue::Null
+                } else {
+                    JsonValue::Number(Number::from(array.value(i)))
+                });
+            }
+        }
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Date32 Arrow array into a JSON array
+pub fn encode_date32_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_primitive::<Date32Type>();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            JsonValue::Number(Number::from(i64::from(array.value(i))))
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Date64 Arrow array into a JSON array
+pub fn encode_date64_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_primitive::<Date64Type>();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            JsonValue::Number(Number::from(array.value(i)))
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Time32 Arrow array into a JSON array
+pub fn encode_time32_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let mut values = Vec::with_capacity(array.len());
+
+    match array.data_type() {
+        DataType::Time32(unit) => match unit {
+            TimeUnit::Second => {
+                let array = array.as_primitive::<Time32SecondType>();
+                for i in 0..array.len() {
+                    values.push(if array.is_null(i) {
+                        JsonValue::Null
+                    } else {
+                        JsonValue::Number(Number::from(i64::from(array.value(i))))
+                    });
+                }
+            }
+            TimeUnit::Millisecond => {
+                let array = array.as_primitive::<Time32MillisecondType>();
+                for i in 0..array.len() {
+                    values.push(if array.is_null(i) {
+                        JsonValue::Null
+                    } else {
+                        JsonValue::Number(Number::from(i64::from(array.value(i))))
+                    });
+                }
+            }
+            _ => {
+                return Err(ArrowError::InvalidArgumentError(format!(
+                    "Time32 arrays only support Second and Millisecond units, got {unit:?}",
+                )));
+            }
+        },
+        _ => {
+            return Err(ArrowError::InvalidArgumentError(format!(
+                "Expected Time32 array, got {:?}",
+                array.data_type()
+            )));
+        }
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Time64 Arrow array into a JSON array
+pub fn encode_time64_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let mut values = Vec::with_capacity(array.len());
+
+    match array.data_type() {
+        DataType::Time64(unit) => match unit {
+            TimeUnit::Microsecond => {
+                let array = array.as_primitive::<Time64MicrosecondType>();
+                for i in 0..array.len() {
+                    values.push(if array.is_null(i) {
+                        JsonValue::Null
+                    } else {
+                        JsonValue::Number(Number::from(array.value(i)))
+                    });
+                }
+            }
+            TimeUnit::Nanosecond => {
+                let array = array.as_primitive::<Time64NanosecondType>();
+                for i in 0..array.len() {
+                    values.push(if array.is_null(i) {
+                        JsonValue::Null
+                    } else {
+                        JsonValue::Number(Number::from(array.value(i)))
+                    });
+                }
+            }
+            _ => {
+                return Err(ArrowError::InvalidArgumentError(format!(
+                    "Time64 arrays only support Microsecond and Nanosecond units, got {unit:?}",
+                )));
+            }
+        },
+        _ => {
+            return Err(ArrowError::InvalidArgumentError(format!(
+                "Expected Time64 array, got {:?}",
+                array.data_type()
+            )));
+        }
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Duration Arrow array into a JSON array
+pub fn encode_duration_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let mut values = Vec::with_capacity(array.len());
+
+    match array.data_type() {
+        DataType::Duration(unit) => match unit {
+            TimeUnit::Second => {
+                let array = array.as_primitive::<DurationSecondType>();
+                for i in 0..array.len() {
+                    values.push(if array.is_null(i) {
+                        JsonValue::Null
+                    } else {
+                        JsonValue::Number(Number::from(array.value(i)))
+                    });
+                }
+            }
+            TimeUnit::Millisecond => {
+                let array = array.as_primitive::<DurationMillisecondType>();
+                for i in 0..array.len() {
+                    values.push(if array.is_null(i) {
+                        JsonValue::Null
+                    } else {
+                        JsonValue::Number(Number::from(array.value(i)))
+                    });
+                }
+            }
+            TimeUnit::Microsecond => {
+                let array = array.as_primitive::<DurationMicrosecondType>();
+                for i in 0..array.len() {
+                    values.push(if array.is_null(i) {
+                        JsonValue::Null
+                    } else {
+                        JsonValue::Number(Number::from(array.value(i)))
+                    });
+                }
+            }
+            TimeUnit::Nanosecond => {
+                let array = array.as_primitive::<DurationNanosecondType>();
+                for i in 0..array.len() {
+                    values.push(if array.is_null(i) {
+                        JsonValue::Null
+                    } else {
+                        JsonValue::Number(Number::from(array.value(i)))
+                    });
+                }
+            }
+        },
+        _ => {
+            return Err(ArrowError::InvalidArgumentError(format!(
+                "Expected Duration array, got {:?}",
+                array.data_type()
+            )));
+        }
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a primitive Arrow array into a JSON array by dispatching to the appropriate encoder function
+pub fn encode_primitive_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    match array.data_type() {
+        DataType::Int8 => encode_signed_array::<Int8Type>(array),
+        DataType::Int16 => encode_signed_array::<Int16Type>(array),
+        DataType::Int32 => encode_signed_array::<Int32Type>(array),
+        DataType::Int64 => encode_signed_array::<Int64Type>(array),
+        DataType::UInt8 => encode_unsigned_array::<UInt8Type>(array),
+        DataType::UInt16 => encode_unsigned_array::<UInt16Type>(array),
+        DataType::UInt32 => encode_unsigned_array::<UInt32Type>(array),
+        DataType::UInt64 => encode_unsigned_array::<UInt64Type>(array),
+        DataType::Float16 => encode_float16_array(array),
+        DataType::Float32 => encode_float32_array(array),
+        DataType::Float64 => encode_float64_array(array),
+        DataType::Timestamp(_, _) => encode_timestamp_array(array),
+        DataType::Date32 => encode_date32_array(array),
+        DataType::Date64 => encode_date64_array(array),
+        DataType::Time32(_) => encode_time32_array(array),
+        DataType::Time64(_) => encode_time64_array(array),
+        DataType::Duration(_) => encode_duration_array(array),
+        _ => Err(ArrowError::InvalidArgumentError(format!(
+            "Unsupported primitive type: {:?}",
+            array.data_type()
+        ))),
+    }
+}
+
+/// Encodes an Interval Arrow array into a JSON array
+pub fn encode_interval_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let mut values = Vec::with_capacity(array.len());
+
+    match array.data_type() {
+        DataType::Interval(unit) => match unit {
+            IntervalUnit::YearMonth => {
+                let array = array.as_primitive::<IntervalYearMonthType>();
+                for i in 0..array.len() {
+                    values.push(if array.is_null(i) {
+                        JsonValue::Null
+                    } else {
+                        let value = array.value(i);
+                        let mut obj = Map::new();
+                        obj.insert(
+                            "type".to_string(),
+                            JsonValue::String("interval".to_string()),
+                        );
+                        obj.insert(
+                            "unit".to_string(),
+                            JsonValue::String("YearMonth".to_string()),
+                        );
+                        obj.insert("value".to_string(), JsonValue::String(value.to_string()));
+                        JsonValue::Object(obj)
+                    });
+                }
+            }
+            IntervalUnit::DayTime => {
+                let array = array.as_primitive::<IntervalDayTimeType>();
+                for i in 0..array.len() {
+                    values.push(if array.is_null(i) {
+                        JsonValue::Null
+                    } else {
+                        let value = array.value(i);
+                        let mut obj = Map::new();
+                        obj.insert(
+                            "type".to_string(),
+                            JsonValue::String("interval".to_string()),
+                        );
+                        obj.insert("unit".to_string(), JsonValue::String("DayTime".to_string()));
+                        obj.insert(
+                            "days".to_string(),
+                            JsonValue::String(value.days.to_string()),
+                        );
+                        obj.insert(
+                            "milliseconds".to_string(),
+                            JsonValue::String(value.milliseconds.to_string()),
+                        );
+                        JsonValue::Object(obj)
+                    });
+                }
+            }
+            IntervalUnit::MonthDayNano => {
+                let array = array.as_primitive::<IntervalMonthDayNanoType>();
+                for i in 0..array.len() {
+                    values.push(if array.is_null(i) {
+                        JsonValue::Null
+                    } else {
+                        let value = array.value(i);
+                        let mut obj = Map::new();
+                        obj.insert(
+                            "type".to_string(),
+                            JsonValue::String("interval".to_string()),
+                        );
+                        obj.insert(
+                            "unit".to_string(),
+                            JsonValue::String("MonthDayNano".to_string()),
+                        );
+                        obj.insert(
+                            "months".to_string(),
+                            JsonValue::String(value.months.to_string()),
+                        );
+                        obj.insert(
+                            "days".to_string(),
+                            JsonValue::String(value.days.to_string()),
+                        );
+                        obj.insert(
+                            "nanoseconds".to_string(),
+                            JsonValue::String(value.nanoseconds.to_string()),
+                        );
+                        JsonValue::Object(obj)
+                    });
+                }
+            }
+        },
+        _ => {
+            return Err(ArrowError::InvalidArgumentError(format!(
+                "Expected Interval array, got {:?}",
+                array.data_type()
+            )));
+        }
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a List Arrow array into a JSON array
+pub fn encode_list_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_list::<i32>();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            let list = array.value(i);
+            encode_array(list)?
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a ListView Arrow array into a JSON array
+pub fn encode_list_view_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_list_view::<i32>();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            let list = array.value(i);
+            encode_array(list)?
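+            // Recursive step: each non-null child list is itself dispatched
+            // through `encode_array`, so nested lists become nested JSON arrays.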
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a FixedSizeList Arrow array into a JSON array
+pub fn encode_fixed_size_list_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_fixed_size_list();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            let list = array.value(i);
+            encode_array(list)?
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a LargeList Arrow array into a JSON array
+pub fn encode_large_list_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_list::<i64>();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            let list = array.value(i);
+            encode_array(list)?
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Struct Arrow array into a JSON array
+pub fn encode_struct_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_struct();
+    let mut values = Vec::with_capacity(array.len());
+
+    let encoded_arrays = array
+        .columns()
+        .iter()
+        .map(|column| encode_array(column.clone()))
+        .collect::<Result<Vec<_>, _>>()?
+        .into_iter()
+        .zip(array.fields())
+        .collect::<Vec<_>>();
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            let mut struct_value = Map::new();
+
+            #[allow(clippy::unwrap_used)]
+            for (encoded_array, field) in &encoded_arrays {
+                let encoded_value = encoded_array.get(i).unwrap();
+                struct_value.insert(field.name().to_string(), encoded_value.clone());
+            }
+            JsonValue::Object(struct_value)
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Decimal128 Arrow array into a JSON array
+pub fn encode_decimal128_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_primitive::<Decimal128Type>();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            JsonValue::String(array.value(i).to_string())
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Decimal256 Arrow array into a JSON array
+pub fn encode_decimal256_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_primitive::<Decimal256Type>();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            let v = array.value(i);
+            let (low, high) = v.to_parts();
+            JsonValue::Array(vec![
+                JsonValue::String(low.to_string()),
+                JsonValue::String(high.to_string()),
+            ])
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a FixedSizeBinary Arrow array into a JSON array
+pub fn encode_fixed_size_binary_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_fixed_size_binary();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            JsonValue::String(base64::engine::general_purpose::STANDARD.encode(array.value(i)))
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a BinaryView Arrow array into a JSON array
+pub fn encode_binary_view_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_binary_view();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            JsonValue::String(base64::engine::general_purpose::STANDARD.encode(array.value(i)))
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Utf8View Arrow array into a JSON array
+pub fn encode_utf8_view_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array
+        .as_ref()
+        .as_any()
+        .downcast_ref::<StringViewArray>()
+        .ok_or_else(|| ArrowError::InvalidArgumentError("Expected utf8 view array".into()))?;
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            JsonValue::String(array.value(i).to_string())
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a RunEndEncoded Arrow array into a JSON array
+pub fn encode_run_end_encoded_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            let mut encoded_value = Map::new();
+            encoded_value.insert(
+                "value".to_string(),
+                encode_array(array.slice(i, 1))?,
+            );
+            JsonValue::Object(encoded_value)
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Map Arrow array into a JSON array
+pub fn encode_map_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            let map = array.slice(i, 1);
+            let mut map_value = Vec::new();
+            let mut entry = Map::new();
+            entry.insert("entries".to_string(), encode_array(map)?);
+            map_value.push(JsonValue::Object(entry));
+            JsonValue::Array(map_value)
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Union Arrow array into a JSON array
+pub fn encode_union_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array.as_union();
+    let mut values = Vec::with_capacity(array.len());
+
+    for i in 0..array.len() {
+        values.push(if array.is_null(i) {
+            JsonValue::Null
+        } else {
+            let mut union_value = Map::new();
+            let type_id = array.type_id(i);
+            union_value.insert(
+                "type_id".to_string(),
+                JsonValue::Number(Number::from(type_id)),
+            );
+            union_value.insert(
+                "value".to_string(),
+                encode_array(array.child(type_id).slice(i, 1))?,
+            );
+            JsonValue::Object(union_value)
+        });
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes a Null Arrow array into a JSON array
+pub fn encode_null_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    let array = array
+        .as_ref()
+        .as_any()
+        .downcast_ref::<NullArray>()
+        .ok_or_else(|| ArrowError::InvalidArgumentError("Expected null array".into()))?;
+    let mut values = Vec::with_capacity(array.len());
+
+    for _ in 0..array.len() {
+        values.push(JsonValue::Null);
+    }
+    Ok(JsonValue::Array(values))
+}
+
+/// Encodes any Arrow array into a JSON array based on its data type
+pub fn encode_array(array: ArrayRef) -> Result<JsonValue, ArrowError> {
+    match array.as_ref().data_type() {
+        DataType::Null => encode_null_array(array),
+        DataType::Boolean => encode_boolean_array(array),
+        DataType::Int8 => encode_signed_array::<Int8Type>(array),
+        DataType::Int16 => encode_signed_array::<Int16Type>(array),
+        DataType::Int32 => encode_signed_array::<Int32Type>(array),
+        DataType::Int64 => encode_signed_array::<Int64Type>(array),
+        DataType::UInt8 => encode_unsigned_array::<UInt8Type>(array),
+        DataType::UInt16 => encode_unsigned_array::<UInt16Type>(array),
+        DataType::UInt32 => encode_unsigned_array::<UInt32Type>(array),
+        DataType::UInt64 => encode_unsigned_array::<UInt64Type>(array),
+        DataType::Float16 => encode_float16_array(array),
+        DataType::Float32 => encode_float32_array(array),
+        DataType::Float64 => encode_float64_array(array),
+        DataType::Utf8 => encode_utf8_array(array),
+        DataType::LargeUtf8 => encode_large_utf8_array(array),
+        DataType::Binary => encode_binary_array(array),
+        DataType::LargeBinary => encode_large_binary_array(array),
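+        // Note (illustrative, not from the original source): dictionary arrays
+        // are encoded structurally rather than logically. A hypothetical
+        // dictionary of ["a", "b", "a"] with Int8 keys would serialize as
+        // {"keys":[0,1,0],"values":["a","b"]} instead of ["a","b","a"].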
+        DataType::Dictionary(key_type, _) => match key_type.as_ref() {
+            DataType::Int8 => encode_dictionary_array::<Int8Type>(array),
+            DataType::Int16 => encode_dictionary_array::<Int16Type>(array),
+            DataType::Int32 => encode_dictionary_array::<Int32Type>(array),
+            DataType::Int64 => encode_dictionary_array::<Int64Type>(array),
+            DataType::UInt8 => encode_dictionary_array::<UInt8Type>(array),
+            DataType::UInt16 => encode_dictionary_array::<UInt16Type>(array),
+            DataType::UInt32 => encode_dictionary_array::<UInt32Type>(array),
+            DataType::UInt64 => encode_dictionary_array::<UInt64Type>(array),
+            _ => Err(ArrowError::InvalidArgumentError(
+                "Unsupported dictionary key type".into(),
+            )),
+        },
+        DataType::Timestamp(_, _) => encode_timestamp_array(array),
+        DataType::Date32 => encode_date32_array(array),
+        DataType::Date64 => encode_date64_array(array),
+        DataType::Time32(_) => encode_time32_array(array),
+        DataType::Time64(_) => encode_time64_array(array),
+        DataType::Duration(_) => encode_duration_array(array),
+        DataType::Interval(_) => encode_interval_array(array),
+        DataType::List(_) => encode_list_array(array),
+        DataType::ListView(_) => encode_list_view_array(array),
+        DataType::FixedSizeList(_, _) => encode_fixed_size_list_array(array),
+        DataType::LargeList(_) | DataType::LargeListView(_) => encode_large_list_array(array),
+        DataType::Struct(_) => encode_struct_array(array),
+        DataType::Decimal128(_, _) => encode_decimal128_array(array),
+        DataType::Decimal256(_, _) => encode_decimal256_array(array),
+        DataType::FixedSizeBinary(_) => encode_fixed_size_binary_array(array),
+        DataType::BinaryView => encode_binary_view_array(array),
+        DataType::Utf8View => encode_utf8_view_array(array),
+        DataType::RunEndEncoded(_, _) => encode_run_end_encoded_array(array),
+        DataType::Map(_, _) => encode_map_array(array),
+        DataType::Union(_, _) => encode_union_array(array),
+    }
+}
+
+pub fn encode_scalar(scalar: &ScalarValue) -> Result<JsonValue, ArrowError> {
+    let scalar_array = scalar.to_array_of_size(1)?;
+    let encoded_array = encode_array(scalar_array)?;
+    if let JsonValue::Array(mut array) = encoded_array {
+        Ok(array.swap_remove(0))
+    } else {
+        Err(ArrowError::InvalidArgumentError("Expected array".into()))
+    }
+}
+
+/// Encodes a ColumnarValue into a JSON value
+#[allow(dead_code)]
+pub fn encode_columnar_value(value: &ColumnarValue) -> Result<JsonValue, ArrowError> {
+    match value {
+        ColumnarValue::Array(array) => encode_array(array.clone()),
+        ColumnarValue::Scalar(scalar) => encode_scalar(scalar),
+    }
+}
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::arrow::array::*;
+    use datafusion::arrow::datatypes::{i256, Field};
+    use datafusion::arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano};
+
+    #[test]
+    fn test_boolean_array() {
+        let array = BooleanArray::from(vec![Some(true), None, Some(false)]);
+        let json = encode_array(Arc::new(array)).unwrap();
+
+        let expected = JsonValue::Array(vec![
+            JsonValue::Bool(true),
+            JsonValue::Null,
+            JsonValue::Bool(false),
+        ]);
+        assert_eq!(json, expected);
+    }
+
+    #[test]
+    fn test_primitive_signed_arrays() {
+        // Int8
+        let array = Int8Array::from(vec![Some(1), None, Some(-2)]);
+        let json = encode_array(Arc::new(array)).unwrap();
+        let expected = JsonValue::Array(vec![
+            JsonValue::Number(Number::from(1)),
+            JsonValue::Null,
+            JsonValue::Number(Number::from(-2)),
+        ]);
+        assert_eq!(json, expected);
+
+        // Int16
+        let array = Int16Array::from(vec![Some(1), None, Some(-2)]);
+        let json = encode_array(Arc::new(array)).unwrap();
+        assert_eq!(json, expected);
+
+        // Int32
+        let array = Int32Array::from(vec![Some(1), None, Some(-2)]);
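+        // Illustrative check (not in the original test), assuming the
+        // `encode_scalar` helper above: scalars reuse the array encoders by
+        // materializing a one-element array first.
+        let scalar_json = encode_scalar(&ScalarValue::Int32(Some(1))).unwrap();
+        assert_eq!(scalar_json, JsonValue::Number(Number::from(1)));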
+        let json = encode_array(Arc::new(array)).unwrap();
+        assert_eq!(json, expected);
+
+        // Int64
+        let array = Int64Array::from(vec![Some(1), None, Some(-2)]);
+        let json = encode_array(Arc::new(array)).unwrap();
+        assert_eq!(json, expected);
+    }
+
+    #[test]
+    fn test_primitive_unsigned_arrays() {
+        let expected = JsonValue::Array(vec![
+            JsonValue::Number(Number::from(1u64)),
+            JsonValue::Null,
+            JsonValue::Number(Number::from(2u64)),
+        ]);
+
+        // UInt8
+        let array = UInt8Array::from(vec![Some(1), None, Some(2)]);
+        let json = encode_array(Arc::new(array)).unwrap();
+        assert_eq!(json, expected);
+
+        // UInt16
+        let array = UInt16Array::from(vec![Some(1), None, Some(2)]);
+        let json = encode_array(Arc::new(array)).unwrap();
+        assert_eq!(json, expected);
+
+        // UInt32
+        let array = UInt32Array::from(vec![Some(1), None, Some(2)]);
+        let json = encode_array(Arc::new(array)).unwrap();
+        assert_eq!(json, expected);
+
+        // UInt64
+        let array = UInt64Array::from(vec![Some(1), None, Some(2)]);
+        let json = encode_array(Arc::new(array)).unwrap();
+        assert_eq!(json, expected);
+    }
+
+    #[test]
+    fn test_float_arrays() {
+        let expected = JsonValue::Array(vec![
+            JsonValue::Number(Number::from_f64(1.5).unwrap()),
+            JsonValue::Null,
+            JsonValue::Number(Number::from_f64(-2.5).unwrap()),
+        ]);
+
+        // Float32
+        let array = Float32Array::from(vec![Some(1.5), None, Some(-2.5)]);
+        let json = encode_array(Arc::new(array)).unwrap();
+        assert_eq!(json, expected);
+
+        // Float64
+        let array = Float64Array::from(vec![Some(1.5), None, Some(-2.5)]);
+        let json = encode_array(Arc::new(array)).unwrap();
+        assert_eq!(json, expected);
+    }
+
+    #[test]
+    fn test_string_arrays() {
+        let expected = JsonValue::Array(vec![
+            JsonValue::String("hello".to_string()),
+            JsonValue::Null,
+            JsonValue::String("world".to_string()),
+        ]);
+
+        // Utf8
+        let array = StringArray::from(vec![Some("hello"), None, Some("world")]);
+        let json = encode_array(Arc::new(array)).unwrap();
+        assert_eq!(json, expected);
+
+        // LargeUtf8
+        let array = LargeStringArray::from(vec![Some("hello"), None, Some("world")]);
+        let json = encode_array(Arc::new(array)).unwrap();
+        assert_eq!(json, expected);
+    }
+
+    #[test]
+    fn test_binary_arrays() {
+        let data = vec![1u8, 2u8, 3u8];
+        let expected = JsonValue::Array(vec![
+            JsonValue::String(base64::engine::general_purpose::STANDARD.encode(&data)),
+            JsonValue::Null,
+            JsonValue::String(base64::engine::general_purpose::STANDARD.encode(&data)),
+        ]);
+
+        // Binary
+        let array = BinaryArray::from(vec![Some(data.as_slice()), None, Some(data.as_slice())]);
+        let json = encode_array(Arc::new(array)).unwrap();
+        assert_eq!(json, expected);
+
+        // LargeBinary
+        let array =
+            LargeBinaryArray::from(vec![Some(data.as_slice()), None, Some(data.as_slice())]);
+        let json = encode_array(Arc::new(array)).unwrap();
+        assert_eq!(json, expected);
+    }
+
+    #[test]
+    fn test_list_arrays() {
+        let mut builder = ListBuilder::new(Int32Builder::new());
+
+        // First list: [1,2,3]
+        builder.values().append_value(1);
+        builder.values().append_value(2);
+        builder.values().append_value(3);
+        builder.append(true);
+
+        // Second list: null
+        builder.append(false);
+
+        // Third list: [4,5,6]
+        builder.values().append_value(4);
+        builder.values().append_value(5);
+        builder.values().append_value(6);
+        builder.append(true);
+
+        let list_array = builder.finish();
+
+        let json = encode_array(Arc::new(list_array)).unwrap();
+
+        let expected = JsonValue::Array(vec![
+            JsonValue::Array(vec![
+                JsonValue::Number(Number::from(1)),
+                JsonValue::Number(Number::from(2)),
+                JsonValue::Number(Number::from(3)),
+            ]),
+            JsonValue::Null,
+            JsonValue::Array(vec![
+                JsonValue::Number(Number::from(4)),
+                JsonValue::Number(Number::from(5)),
+                JsonValue::Number(Number::from(6)),
+            ]),
+        ]);
+        assert_eq!(json, expected);
+    }
+
+    #[test]
+    #[allow(clippy::as_conversions)]
+    fn test_struct_array() {
+        let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
+        let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
+
+        let struct_array = StructArray::from(vec![
+            (
+                Arc::new(Field::new("b", DataType::Boolean, false)),
+                boolean as ArrayRef,
+            ),
+            (
+                Arc::new(Field::new("c", DataType::Int32, false)),
+                int as ArrayRef,
+            ),
+        ]);
+
+        let json = encode_array(Arc::new(struct_array)).unwrap();
+
+        let mut obj1 = Map::new();
+        obj1.insert("b".to_string(), JsonValue::Bool(false));
+        obj1.insert("c".to_string(), JsonValue::Number(Number::from(42)));
+
+        let mut obj2 = Map::new();
+        obj2.insert("b".to_string(), JsonValue::Bool(false));
+        obj2.insert("c".to_string(), JsonValue::Number(Number::from(28)));
+
+        let mut obj3 = Map::new();
+        obj3.insert("b".to_string(), JsonValue::Bool(true));
+        obj3.insert("c".to_string(), JsonValue::Number(Number::from(19)));
+
+        let mut obj4 = Map::new();
+        obj4.insert("b".to_string(), JsonValue::Bool(true));
+        obj4.insert("c".to_string(), JsonValue::Number(Number::from(31)));
+
+        let expected = JsonValue::Array(vec![
+            JsonValue::Object(obj1),
+            JsonValue::Object(obj2),
+            JsonValue::Object(obj3),
+            JsonValue::Object(obj4),
+        ]);
+        assert_eq!(json, expected);
+    }
+
+    #[test]
+    fn test_null_array() {
+        let array = NullArray::new(3);
+        let json = encode_array(Arc::new(array)).unwrap();
+
+        let expected = JsonValue::Array(vec![JsonValue::Null, JsonValue::Null, JsonValue::Null]);
+        assert_eq!(json, expected);
+    }
+
+    #[test]
+    #[allow(clippy::too_many_lines)]
+    fn test_interval_arrays() {
+        // Test YearMonth interval
+        let array = IntervalYearMonthArray::from(vec![Some(12), None, Some(-24)]);
+        let json = encode_array(Arc::new(array)).unwrap();
+
+        let mut obj1 = Map::new();
+        obj1.insert(
+            "type".to_string(),
+            JsonValue::String("interval".to_string()),
+        );
+        obj1.insert(
+            "unit".to_string(),
+            JsonValue::String("YearMonth".to_string()),
+        );
+        obj1.insert("value".to_string(), JsonValue::String("12".to_string()));
+
+        let mut obj2 = Map::new();
+        obj2.insert(
+            "type".to_string(),
+            JsonValue::String("interval".to_string()),
+        );
+        obj2.insert(
+            "unit".to_string(),
+            JsonValue::String("YearMonth".to_string()),
+        );
+        obj2.insert("value".to_string(), JsonValue::String("0".to_string()));
+
+        let mut obj3 = Map::new();
+        obj3.insert(
+            "type".to_string(),
+            JsonValue::String("interval".to_string()),
+        );
+        obj3.insert(
+            "unit".to_string(),
+            JsonValue::String("YearMonth".to_string()),
+        );
+        obj3.insert("value".to_string(), JsonValue::String("-24".to_string()));
+
+        let expected = JsonValue::Array(vec![
+            JsonValue::Object(obj1),
+            JsonValue::Null,
+            JsonValue::Object(obj3),
+        ]);
+        assert_eq!(json, expected);
+
+        // Test DayTime interval
+        let array = IntervalDayTimeArray::from(vec![
+            Some(IntervalDayTime::new(1, 0)),
+            None,
+            Some(IntervalDayTime::new(-2, 0)),
+        ]);
+        let json = encode_array(Arc::new(array)).unwrap();
+
+        let mut obj1 = Map::new();
+        obj1.insert(
+            "type".to_string(),
+            JsonValue::String("interval".to_string()),
+        );
+        obj1.insert("unit".to_string(), JsonValue::String("DayTime".to_string()));
+        obj1.insert("days".to_string(), JsonValue::String("1".to_string()));
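+        // The DayTime encoder serializes both fields as strings (see
+        // `encode_interval_array` above), hence "1" rather than 1 here.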
+        obj1.insert(
+            "milliseconds".to_string(),
+            JsonValue::String("0".to_string()),
+        );
+
+        let mut obj3 = Map::new();
+        obj3.insert(
+            "type".to_string(),
+            JsonValue::String("interval".to_string()),
+        );
+        obj3.insert("unit".to_string(), JsonValue::String("DayTime".to_string()));
+        obj3.insert("days".to_string(), JsonValue::String("-2".to_string()));
+        obj3.insert(
+            "milliseconds".to_string(),
+            JsonValue::String("0".to_string()),
+        );
+
+        let expected = JsonValue::Array(vec![
+            JsonValue::Object(obj1),
+            JsonValue::Null,
+            JsonValue::Object(obj3),
+        ]);
+        assert_eq!(json, expected);
+
+        // Test MonthDayNano interval
+        let array = IntervalMonthDayNanoArray::from(vec![
+            Some(IntervalMonthDayNano::new(1, 0, 0)),
+            None,
+            Some(IntervalMonthDayNano::new(-2, 0, 0)),
+        ]);
+        let json = encode_array(Arc::new(array)).unwrap();
+
+        let mut obj1 = Map::new();
+        obj1.insert(
+            "type".to_string(),
+            JsonValue::String("interval".to_string()),
+        );
+        obj1.insert(
+            "unit".to_string(),
+            JsonValue::String("MonthDayNano".to_string()),
+        );
+        obj1.insert("months".to_string(), JsonValue::String("1".to_string()));
+        obj1.insert("days".to_string(), JsonValue::String("0".to_string()));
+        obj1.insert(
+            "nanoseconds".to_string(),
+            JsonValue::String("0".to_string()),
+        );
+
+        let mut obj3 = Map::new();
+        obj3.insert(
+            "type".to_string(),
+            JsonValue::String("interval".to_string()),
+        );
+        obj3.insert(
+            "unit".to_string(),
+            JsonValue::String("MonthDayNano".to_string()),
+        );
+        obj3.insert("months".to_string(), JsonValue::String("-2".to_string()));
+        obj3.insert("days".to_string(), JsonValue::String("0".to_string()));
+        obj3.insert(
+            "nanoseconds".to_string(),
+            JsonValue::String("0".to_string()),
+        );
+
+        let expected = JsonValue::Array(vec![
+            JsonValue::Object(obj1),
+            JsonValue::Null,
+            JsonValue::Object(obj3),
+        ]);
+        assert_eq!(json, expected);
+    }
+
+    #[test]
+    fn test_decimal256_array() {
+        let array = Decimal256Array::from(vec![Some(i256::from(1)), None, Some(i256::from(-2))]);
+        let json = encode_array(Arc::new(array)).unwrap();
+
+        let expected = JsonValue::Array(vec![
+            JsonValue::Array(vec![
+                JsonValue::String("1".to_string()),
+                JsonValue::String("0".to_string()),
+            ]),
+            JsonValue::Null,
+            JsonValue::Array(vec![
+                JsonValue::String("340282366920938463463374607431768211454".to_string()),
+                JsonValue::String("-1".to_string()),
+            ]),
+        ]);
+        assert_eq!(json, expected);
+    }
+}
diff --git a/crates/df-builtins/src/variant/mod.rs b/crates/df-builtins/src/variant/mod.rs
new file mode 100644
index 00000000..e705b8ae
--- /dev/null
+++ b/crates/df-builtins/src/variant/mod.rs
@@ -0,0 +1,75 @@
+pub mod array_append;
+pub mod array_cat;
+pub mod array_compact;
+pub mod array_construct;
+pub mod array_contains;
+pub mod array_distinct;
+pub mod array_except;
+pub mod array_flatten;
+pub mod array_generate_range;
+pub mod array_insert;
+pub mod array_intersection;
+pub mod array_max;
+pub mod array_min;
+pub mod array_position;
+pub mod array_prepend;
+pub mod array_remove;
+pub mod array_remove_at;
+pub mod array_reverse;
+pub mod array_size;
+pub mod array_slice;
+pub mod array_sort;
+pub mod array_to_string;
+pub mod arrays_overlap;
+pub mod arrays_to_object;
+pub mod arrays_zip;
+pub mod json;
+pub mod object_construct;
+pub mod object_delete;
+pub mod object_insert;
+pub mod object_pick;
+pub mod variant_element;
+
+use std::sync::Arc;
+use datafusion::common::Result;
+use datafusion_expr::registry::FunctionRegistry;
+use datafusion_expr::ScalarUDF;
+
+pub fn register_udfs(registry: &mut dyn FunctionRegistry) -> Result<()> {
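+    // Each UDF below is built once and handed to the registry; in DataFusion,
+    // `register_udf` replaces any previously registered function of the same
+    // name rather than erroring.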
+    let functions: Vec<Arc<ScalarUDF>> = vec![
+        array_append::get_udf(),
+        array_cat::get_udf(),
+        array_compact::get_udf(),
+        array_construct::get_udf(),
+        array_contains::get_udf(),
+        array_distinct::get_udf(),
+        array_except::get_udf(),
+        array_generate_range::get_udf(),
+        array_insert::get_udf(),
+        array_intersection::get_udf(),
+        array_max::get_udf(),
+        array_min::get_udf(),
+        array_position::get_udf(),
+        array_prepend::get_udf(),
+        array_remove::get_udf(),
+        array_remove_at::get_udf(),
+        array_reverse::get_udf(),
+        array_size::get_udf(),
+        array_slice::get_udf(),
+        array_sort::get_udf(),
+        arrays_overlap::get_udf(),
+        arrays_to_object::get_udf(),
+        arrays_zip::get_udf(),
+        variant_element::get_udf(),
+        object_delete::get_udf(),
+        object_insert::get_udf(),
+        object_pick::get_udf(),
+        object_construct::get_udf(),
+    ];
+
+    for func in functions {
+        registry.register_udf(func)?;
+    }
+
+    Ok(())
+}
diff --git a/crates/df-builtins/src/variant/object_construct.rs b/crates/df-builtins/src/variant/object_construct.rs
new file mode 100644
index 00000000..1a77346a
--- /dev/null
+++ b/crates/df-builtins/src/variant/object_construct.rs
@@ -0,0 +1,292 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::{array::AsArray, datatypes::DataType};
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::Value;
+use std::collections::HashMap;
+
+#[derive(Debug, Clone)]
+pub struct ObjectConstructUDF {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl ObjectConstructUDF {
+    #[must_use]
+    pub fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::OneOf(vec![
+                    TypeSignature::VariadicAny,
+                    TypeSignature::Nullary,
+                ]),
+                volatility: Volatility::Immutable,
+            },
+            aliases: vec!["make_object".to_string()],
+        }
+    }
+}
+
+impl Default for ObjectConstructUDF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ScalarUDFImpl for ObjectConstructUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "object_construct"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let ScalarFunctionArgs {
+            args, number_rows, ..
+        } = args;
+
+        if args.is_empty() {
+            return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(
+                "{}".to_string(),
+            ))));
+        }
+
+        if args.len() % 2 != 0 {
+            return Err(datafusion_common::error::DataFusionError::Execution(
+                "object_construct requires an even number of arguments (key-value pairs)"
+                    .to_string(),
+            ));
+        }
+
+        let mut object = HashMap::new();
+
+        for chunk in args.chunks(2) {
+            let key_array = chunk[0].clone().into_array(number_rows)?;
+            let value_array = chunk[1].clone().into_array(number_rows)?;
+
+            for i in 0..key_array.len() {
+                if key_array.is_null(i) {
+                    return Err(datafusion_common::error::DataFusionError::Execution(
+                        "object_construct key cannot be null".to_string(),
+                    ));
+                }
+
+                let key = if let Some(str_array) = key_array.as_string_opt::<i32>() {
+                    str_array.value(i).to_string()
+                } else {
+                    return Err(datafusion_common::error::DataFusionError::Execution(
+                        "object_construct key must be a string".to_string(),
+                    ));
+                };
+
+                let value = if value_array.is_null(i) {
+                    Value::Null
+                } else if let Some(str_array) = value_array.as_string_opt::<i32>() {
+                    let istr = str_array.value(i);
+                    if let Ok(json_obj) = serde_json::from_str(istr) {
+                        json_obj
+                    } else {
+                        Value::String(istr.to_string())
+                    }
+                } else {
+                    super::json::encode_array(value_array.clone())?
+                };
+
+                object.insert(key, value);
+            }
+        }
+
+        let json_str = serde_json::to_string(&Value::Object(serde_json::Map::from_iter(object)))
+            .map_err(|e| {
+                datafusion_common::error::DataFusionError::Internal(format!(
+                    "Failed to serialize JSON: {e}",
+                ))
+            })?;
+
+        Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(json_str))))
+    }
+}
+
+// This is no longer possible to use due to the deprecation of the Wildcard expr.
+// This will require fixing the sqlparser crate to support the new wildcard syntax inside a function call.
+// #[derive(Debug)]
+// pub struct ObjectConstructAnalyzerRule;
+
+// impl AnalyzerRule for ObjectConstructAnalyzerRule {
+//     fn name(&self) -> &'static str {
+//         "object_construct"
+//     }
+
+//     fn analyze(
+//         &self,
+//         plan: LogicalPlan,
+//         _config: &ConfigOptions,
+//     ) -> DFResult<LogicalPlan> {
+//         plan.transform_up(analyze_object_construct_wildcard)
+//             .map(|transformed| transformed.data)
+//     }
+// }
+
+// fn analyze_object_construct_wildcard(plan: LogicalPlan) -> DFResult<Transformed<LogicalPlan>> {
+//     let transformed_plan = plan.map_subqueries(|plan| plan.transform_up(analyze_object_construct_wildcard))?;
+
+//     let transformed_plan = transformed_plan.transform_data(|plan| {
+//         match &plan {
+//             LogicalPlan::Projection(projection) => {
+//                 let has_construct_schema = projection.expr.iter().any(|expr| {
+//                     match expr {
+//                         Expr::ScalarFunction(func) => {
+//                             func.name() == "object_construct" && func.args.iter().any(|arg| matches!(arg, Expr::Wildcard { .. }))
+//                         }
+//                         _ => false,
+//                     }
+//                 });
+//                 //dbg!(&projection);
+//                 Ok(Transformed::yes(plan.clone()))
+//             }
+//             _ => Ok(Transformed::no(plan)),
+//         }
+//     })?;
+
+//     Ok(transformed_plan)
+// }
+
+make_udf_function!(ObjectConstructUDF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use crate::variant::array_construct::ArrayConstructUDF;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+
+    #[tokio::test]
+    async fn test_object_construct() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        ctx.register_udf(ScalarUDF::from(ObjectConstructUDF::new()));
+
+        // Test basic object construction
+        let sql = "SELECT object_construct('a', 1, 'b', 'hello', 'c', 2.5) as obj1";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+-----------------------------+",
+                "| obj1                        |",
+                "+-----------------------------+",
+                "| {\"a\":1,\"b\":\"hello\",\"c\":2.5} |",
+                "+-----------------------------+",
+            ],
+            &result
+        );
+
+        // Test empty object
+        let sql = "SELECT object_construct() as obj2";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            ["+------+", "| obj2 |", "+------+", "| {}   |", "+------+"],
+            &result
+        );
+
+        // Test with null values
+        let sql = "SELECT object_construct('a', 1, 'b', NULL, 'c', 3) as obj3";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+------------------------+",
+                "| obj3                   |",
+                "+------------------------+",
+                "| {\"a\":1,\"b\":null,\"c\":3} |",
+                "+------------------------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_object_construct_nested() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        ctx.register_udf(ScalarUDF::from(ObjectConstructUDF::new()));
+        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
+
+        // Test nested object construction
+        let sql = "SELECT object_construct('a', object_construct('x', 1, 'y', 2), 'b', array_construct(1, 2, 3)) as obj1";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------------------------+",
+                "| obj1                            |",
+                "+---------------------------------+",
+                "| {\"a\":{\"x\":1,\"y\":2},\"b\":[1,2,3]} |",
+                "+---------------------------------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+
+    // #[tokio::test]
+    // async fn test_object_construct_wildcard() -> DFResult<()> {
+    //     let mut ctx = SessionContext::new();
+    //     register_udf(&mut ctx);
+
+    //     // Create and populate the test table
+    //     let sql = "CREATE TABLE demo_table_1 (province VARCHAR, created_date DATE);";
+    //     ctx.sql(sql).await?;
+
+    //     let sql = "INSERT INTO demo_table_1 (province, created_date) VALUES
+    //         ('Manitoba', '2024-01-18'::DATE),
+    //         ('Alberta', '2024-01-19'::DATE);
+    //     ";
+    //     ctx.sql(sql).await?;
+
+    //     // Test object_construct with wildcard
+    //     let sql = "SELECT object_construct(*) AS oc FROM demo_table_1 ORDER BY oc['PROVINCE']";
+    //     let result = ctx.sql(sql).await?.collect().await?;
+
+    //     assert_batches_eq!(
+    //         [
+    //             "+---------------------------------+",
+    //             "| oc |",
+    //             "+---------------------------------+",
+    //             "| { |",
+    //             "| \"CREATED_DATE\": \"2024-01-19\", |",
+    //             "| \"PROVINCE\": \"Alberta\" |",
+    //             "| } |",
+    //             "| { |",
+    //             "| \"CREATED_DATE\": \"2024-01-18\", |",
+    //             "| \"PROVINCE\": \"Manitoba\" |",
+    //             "| } |",
+    //             "+---------------------------------+",
+    //         ],
+    //         &result
+    //     );
+
+    //     Ok(())
+    // }
+}
diff --git a/crates/df-builtins/src/variant/object_delete.rs b/crates/df-builtins/src/variant/object_delete.rs
new file mode 100644
index 00000000..5e33615b
--- /dev/null
+++ b/crates/df-builtins/src/variant/object_delete.rs
@@ -0,0 +1,235 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::array::Array;
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::{from_str, to_string, Value};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub struct ObjectDeleteUDF {
+    signature: Signature,
+}
+
+impl ObjectDeleteUDF {
+    #[must_use]
+    pub const fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::VariadicAny,
+                volatility: Volatility::Immutable,
+            },
+        }
+    }
+
+    fn delete_keys(object_value: Value, keys: &[Value]) -> DFResult<Option<String>> {
+        // Ensure the first argument is an object
+        if let Value::Object(mut obj) = object_value {
+            // Remove each key from the object
+            for key in keys {
+                if let Value::String(key_str) = key {
+                    obj.remove(key_str);
+                }
+            }
+
+            // Convert back to JSON string
+            Ok(Some(to_string(&obj).map_err(|e| {
+                datafusion_common::error::DataFusionError::Internal(format!(
+                    "Failed to serialize result: {e}",
+                ))
+            })?))
+        } else {
+            Err(datafusion_common::error::DataFusionError::Internal(
+                "First argument must be a JSON object".to_string(),
+            ))
+        }
+    }
+}
+
+impl Default for ObjectDeleteUDF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ScalarUDFImpl for ObjectDeleteUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "object_delete"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
+        let object_str =
+            args.first()
+                .ok_or(datafusion_common::error::DataFusionError::Internal(
+                    "Expected object argument".to_string(),
+                ))?;
+
+        // Collect all keys to delete
+        let keys: Vec<Value> = args[1..]
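+            // Each scalar key below is round-tripped through the JSON encoder,
+            // which yields a one-element JSON array that is then unwrapped; a
+            // Utf8 scalar 'b' arrives here as Value::String("b") (illustrative).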
+            .iter()
+            .map(|arg| {
+                if let ColumnarValue::Scalar(value) = arg {
+                    if value.is_null() {
+                        Ok(Value::Null)
+                    } else {
+                        let key_json = super::json::encode_array(value.to_array_of_size(1)?)?;
+                        if let Value::Array(array) = key_json {
+                            match array.first() {
+                                Some(value) => Ok(value.clone()),
+                                None => Err(datafusion_common::error::DataFusionError::Internal(
+                                    "Expected array for scalar value".to_string(),
+                                )),
+                            }
+                        } else {
+                            Err(datafusion_common::error::DataFusionError::Internal(
+                                "Expected array for scalar value".to_string(),
+                            ))
+                        }
+                    }
+                } else {
+                    Err(datafusion_common::error::DataFusionError::Internal(
+                        "Key arguments must be scalar values".to_string(),
+                    ))
+                }
+            })
+            .collect::<DFResult<Vec<Value>>>()?;
+
+        match object_str {
+            ColumnarValue::Array(array) => {
+                let string_array = array.as_string::<i32>();
+                let mut results = Vec::new();
+
+                for i in 0..string_array.len() {
+                    if string_array.is_null(i) {
+                        results.push(None);
+                    } else {
+                        let object_str = string_array.value(i);
+                        let object_json: Value = from_str(object_str).map_err(|e| {
+                            datafusion_common::error::DataFusionError::Internal(format!(
+                                "Failed to parse object JSON: {e}"
+                            ))
+                        })?;
+                        results.push(Self::delete_keys(object_json, &keys)?);
+                    }
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::StringArray::from(results),
+                )))
+            }
+            ColumnarValue::Scalar(object_value) => {
+                match object_value {
+                    ScalarValue::Utf8(Some(object_str)) => {
+                        // Parse object string to JSON Value
+                        let object_json: Value = from_str(object_str).map_err(|e| {
+                            datafusion_common::error::DataFusionError::Internal(format!(
+                                "Failed to parse object JSON: {e}"
+                            ))
+                        })?;
+
+                        let result = Self::delete_keys(object_json, &keys)?;
+                        Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
+                    }
+                    _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
+                }
+            }
+        }
+    }
+}
+
+make_udf_function!(ObjectDeleteUDF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+
+    #[tokio::test]
+    async fn test_object_delete() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register UDF
+        ctx.register_udf(ScalarUDF::from(ObjectDeleteUDF::new()));
+
+        // Test removing single key
+        let sql = "SELECT object_delete('{\"a\": 1, \"b\": 2, \"c\": 3}', 'b') as removed";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------+",
+                "| removed       |",
+                "+---------------+",
+                "| {\"a\":1,\"c\":3} |",
+                "+---------------+",
+            ],
+            &result
+        );
+
+        // Test removing multiple keys
+        let sql = "SELECT object_delete('{\"a\": 1, \"b\": 2, \"c\": 3, \"d\": 4}', 'b', 'd') as removed2";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------+",
+                "| removed2      |",
+                "+---------------+",
+                "| {\"a\":1,\"c\":3} |",
+                "+---------------+",
+            ],
+            &result
+        );
+
+        // Test removing non-existent key
+        let sql = "SELECT object_delete('{\"a\": 1, \"b\": 2}', 'c') as no_remove";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------+",
+                "| no_remove     |",
+                "+---------------+",
+                "| {\"a\":1,\"b\":2} |",
+                "+---------------+",
+            ],
+            &result
+        );
+
+        // Test with NULL input
+        let sql = "SELECT object_delete(NULL, 'a') as null_input";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+------------+",
+                "| null_input |",
+                "+------------+",
+                "|            |",
+                "+------------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/variant/object_insert.rs b/crates/df-builtins/src/variant/object_insert.rs
new file mode 100644
index 00000000..fbce549e
--- /dev/null
+++ b/crates/df-builtins/src/variant/object_insert.rs
@@ -0,0 +1,333 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::array::Array;
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::{from_str, to_string, Value};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub struct ObjectInsertUDF {
+    signature: Signature,
+}
+
+impl ObjectInsertUDF {
+    #[must_use]
+    pub const fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::VariadicAny,
+                volatility: Volatility::Immutable,
+            },
+        }
+    }
+
+    fn insert_key_value(
+        object_value: Value,
+        key: &Value,
+        value: &Value,
+        update_flag: bool,
+    ) -> DFResult<Option<String>> {
+        // Ensure the first argument is an object
+        if let Value::Object(mut obj) = object_value {
+            // Get the key string
+            let Value::String(key_str) = key else {
+                return Err(datafusion_common::error::DataFusionError::Internal(
+                    "Key must be a string".to_string(),
+                ));
+            };
+
+            // Check if key exists and handle according to update_flag
+            if obj.contains_key(key_str) && !update_flag {
+                return Err(datafusion_common::error::DataFusionError::Internal(
+                    format!("Key '{key_str}' already exists and update_flag is false",),
+                ));
+            }
+
+            // Insert or update the key-value pair
+            obj.insert(key_str.clone(), value.clone());
+
+            // Convert back to JSON string
+            Ok(Some(to_string(&obj).map_err(|e| {
+                datafusion_common::error::DataFusionError::Internal(format!(
+                    "Failed to serialize result: {e}",
+                ))
+            })?))
+        } else {
+            Err(datafusion_common::error::DataFusionError::Internal(
+                "First argument must be a JSON object".to_string(),
+            ))
+        }
+    }
+}
+
+impl Default for ObjectInsertUDF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ScalarUDFImpl for ObjectInsertUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "object_insert"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    #[allow(clippy::too_many_lines)]
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
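+        // Expected argument layout, per the extraction below: object, key,
+        // value, and an optional boolean update flag, e.g.
+        // object_insert('{"a":1}', 'b', 2, false) (illustrative call).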
+        let object_str =
+            args.first()
+                .ok_or(datafusion_common::error::DataFusionError::Internal(
+                    "Expected object argument".to_string(),
+                ))?;
+
+        // Get key argument
+        let key_arg = args
+            .get(1)
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected key argument".to_string(),
+            ))?;
+
+        // Get value argument
+        let value_arg = args
+            .get(2)
+            .ok_or(datafusion_common::error::DataFusionError::Internal(
+                "Expected value argument".to_string(),
+            ))?;
+
+        // Get update flag (optional)
+        let update_flag = args.get(3).map_or(Ok(false), |arg| {
+            if let ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) = arg {
+                Ok(*b)
+            } else {
+                Err(datafusion_common::error::DataFusionError::Internal(
+                    "Update flag must be a boolean".to_string(),
+                ))
+            }
+        })?;
+
+        // Convert key and value to JSON Values
+        let key_json = match key_arg {
+            ColumnarValue::Scalar(value) => {
+                if value.is_null() {
+                    return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+                }
+                let key_json = super::json::encode_array(value.to_array_of_size(1)?)?;
+                if let Value::Array(array) = key_json {
+                    match array.first() {
+                        Some(value) => value.clone(),
+                        None => {
+                            return Err(datafusion_common::error::DataFusionError::Internal(
+                                "Expected array for scalar value".to_string(),
+                            ))
+                        }
+                    }
+                } else {
+                    return Err(datafusion_common::error::DataFusionError::Internal(
+                        "Expected array for scalar value".to_string(),
+                    ));
+                }
+            }
+            ColumnarValue::Array(_) => {
+                return Err(datafusion_common::error::DataFusionError::Internal(
+                    "Key argument must be a scalar value".to_string(),
+                ))
+            }
+        };
+
+        let value_json = match value_arg {
+            ColumnarValue::Scalar(value) => {
+                if value.is_null() {
+                    return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+                }
+                let value_json = super::json::encode_array(value.to_array_of_size(1)?)?;
+                if let Value::Array(array) = value_json {
+                    match array.first() {
+                        Some(value) => value.clone(),
+                        None => {
+                            return Err(datafusion_common::error::DataFusionError::Internal(
+                                "Expected array for scalar value".to_string(),
+                            ))
+                        }
+                    }
+                } else {
+                    return Err(datafusion_common::error::DataFusionError::Internal(
+                        "Expected array for scalar value".to_string(),
+                    ));
+                }
+            }
+            ColumnarValue::Array(_) => {
+                return Err(datafusion_common::error::DataFusionError::Internal(
+                    "Value argument must be a scalar value".to_string(),
+                ))
+            }
+        };
+
+        match object_str {
+            ColumnarValue::Array(array) => {
+                let string_array = array.as_string::<i32>();
+                let mut results = Vec::new();
+
+                for i in 0..string_array.len() {
+                    if string_array.is_null(i) {
+                        results.push(None);
+                    } else {
+                        let object_str = string_array.value(i);
+                        let object_json: Value = from_str(object_str).map_err(|e| {
+                            datafusion_common::error::DataFusionError::Internal(format!(
+                                "Failed to parse object JSON: {e}"
+                            ))
+                        })?;
+                        results.push(Self::insert_key_value(
+                            object_json,
+                            &key_json,
+                            &value_json,
+                            update_flag,
+                        )?);
+                    }
+                }
+
+                Ok(ColumnarValue::Array(Arc::new(
+                    datafusion::arrow::array::StringArray::from(results),
+                )))
+            }
+            ColumnarValue::Scalar(object_value) => {
+                match object_value {
+                    ScalarValue::Utf8(Some(object_str)) => {
+                        // Parse object string to JSON Value
+                        let object_json: Value = from_str(object_str).map_err(|e| {
+                            datafusion_common::error::DataFusionError::Internal(format!(
+                                "Failed to parse object JSON: {e}"
+                            ))
+                        })?;
+
+                        let result = Self::insert_key_value(
+                            object_json,
+                            &key_json,
+                            &value_json,
+                            update_flag,
+                        )?;
+                        Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
+                    }
+                    _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
+                }
+            }
+        }
+    }
+}
+
+make_udf_function!(ObjectInsertUDF);
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use datafusion::assert_batches_eq;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+
+    #[tokio::test]
+    async fn test_object_insert() -> DFResult<()> {
+        let mut ctx = SessionContext::new();
+
+        // Register UDF
+        ctx.register_udf(ScalarUDF::from(ObjectInsertUDF::new()));
+
+        // Test inserting new key-value pair
+        let sql = "SELECT object_insert('{\"a\": 1, \"b\": 2}', 'c', 3) as inserted";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------------+",
+                "| inserted            |",
+                "+---------------------+",
+                "| {\"a\":1,\"b\":2,\"c\":3} |",
+                "+---------------------+",
+            ],
+            &result
+        );
+
+        // Test updating existing key with update_flag=true
+        let sql = "SELECT object_insert('{\"a\": 1, \"b\": 2}', 'b', 3, true) as updated";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+---------------+",
+                "| updated       |",
+                "+---------------+",
+                "| {\"a\":1,\"b\":3} |",
+                "+---------------+",
+            ],
+            &result
+        );
+
+        // Test error when updating existing key without update_flag
+        let sql = "SELECT object_insert('{\"a\": 1, \"b\": 2}', 'b', 3) as error";
+        let result = ctx.sql(sql).await?.collect().await;
+        assert!(result.is_err());
+
+        // Test with NULL input
+        let sql = "SELECT object_insert(NULL, 'a', 1) as null_input";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+------------+",
+                "| null_input |",
+                "+------------+",
+                "|            |",
+                "+------------+",
+            ],
+            &result
+        );
+
+        // Test with NULL key
+        let sql = "SELECT object_insert('{\"a\": 1}', NULL, 2) as null_key";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+----------+",
+                "| null_key |",
+                "+----------+",
+                "|          |",
+                "+----------+",
+            ],
+            &result
+        );
+
+        // Test with NULL value
+        let sql = "SELECT object_insert('{\"a\": 1}', 'b', NULL) as null_value";
+        let result = ctx.sql(sql).await?.collect().await?;
+
+        assert_batches_eq!(
+            [
+                "+------------+",
+                "| null_value |",
+                "+------------+",
+                "|            |",
+                "+------------+",
+            ],
+            &result
+        );
+
+        Ok(())
+    }
+}
diff --git a/crates/df-builtins/src/variant/object_pick.rs b/crates/df-builtins/src/variant/object_pick.rs
new file mode 100644
index 00000000..9bbf04d7
--- /dev/null
+++ b/crates/df-builtins/src/variant/object_pick.rs
@@ -0,0 +1,234 @@
+use super::super::macros::make_udf_function;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::array::Array;
+use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
+};
+use serde_json::{from_str, to_string, Value};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub struct ObjectPickUDF {
+    signature: Signature,
+}
+
+impl ObjectPickUDF {
+    #[must_use]
+    pub const fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::VariadicAny,
+                volatility: Volatility::Immutable,
+            },
+        }
+    }
+
+    fn pick_keys(object_value: Value, keys: Vec<String>) -> DFResult<Option<String>> {
+        // Ensure the first argument is an object
+        if let Value::Object(obj) = object_value {
+            let mut new_obj = serde_json::Map::new();
+
+            // Only include specified keys that exist in the original object
+            for key in keys {
+                if let Some(value) = obj.get(&key) {
obj.get(&key) { + new_obj.insert(key, value.clone()); + } + } + + // Convert back to JSON string + Ok(Some(to_string(&Value::Object(new_obj)).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to serialize result: {e}", + )) + })?)) + } else { + Err(datafusion_common::error::DataFusionError::Internal( + "First argument must be a JSON object".to_string(), + )) + } + } +} + +impl Default for ObjectPickUDF { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for ObjectPickUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "object_pick" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> { + let ScalarFunctionArgs { args, .. } = args; + let object_str = + args.first() + .ok_or(datafusion_common::error::DataFusionError::Internal( + "Expected object argument".to_string(), + ))?; + + // Get all keys from remaining arguments + let mut keys = Vec::new(); + + // Check if second argument is an array + if args.len() == 2 { + if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(array_str))) = &args[1] { + // Try to parse as JSON array first + if let Ok(Value::Array(json_array)) = from_str::<Value>(array_str) { + for value in json_array { + if let Value::String(key) = value { + keys.push(key); + } + } + } else { + // If not a JSON array, treat as a single key + keys.push(array_str.clone()); + } + } + } else { + // Handle individual key arguments + for arg in args.iter().skip(1) { + if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(key))) = arg { + keys.push(key.clone()); + } + } + } + + match object_str { + ColumnarValue::Array(array) => { + let string_array = array.as_string::<i32>(); + let mut results = Vec::new(); + + for i in 0..string_array.len() { + if string_array.is_null(i) { + results.push(None); + } else { + let object_str = string_array.value(i); + let object_json: Value = from_str(object_str).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to parse object JSON: {e}" + )) + })?; + results.push(Self::pick_keys(object_json, keys.clone())?); + } + } + + Ok(ColumnarValue::Array(Arc::new( + datafusion::arrow::array::StringArray::from(results), + ))) + } + ColumnarValue::Scalar(object_value) => { + match object_value { + ScalarValue::Utf8(Some(object_str)) => { + // Parse object string to JSON Value + let object_json: Value = from_str(object_str).map_err(|e| { + datafusion_common::error::DataFusionError::Internal(format!( + "Failed to parse object JSON: {e}" + )) + })?; + + let result = Self::pick_keys(object_json, keys)?; + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + } + _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + } + } + } + } +} + +make_udf_function!(ObjectPickUDF); + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + use datafusion::assert_batches_eq; + use datafusion::prelude::SessionContext; + use datafusion_expr::ScalarUDF; + + #[tokio::test] + async fn test_object_pick() -> DFResult<()> { + let mut ctx = SessionContext::new(); + + // Register UDF + ctx.register_udf(ScalarUDF::from(ObjectPickUDF::new())); + + // Test picking specific keys + let sql = "SELECT object_pick('{\"a\": 1, \"b\": 2, \"c\": 3}', 'a', 'b') as picked"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------------+", + "| picked |",
"+---------------+", + "| {\"a\":1,\"b\":2} |", + "+---------------+", + ], + &result + ); + + // Test picking with array argument + let sql = "SELECT object_pick('{\"a\": 1, \"b\": 2, \"c\": 3}', array_construct('a', 'b')) as picked2"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------------+", + "| picked2 |", + "+---------------+", + "| {\"a\":1,\"b\":2} |", + "+---------------+", + ], + &result + ); + + // Test with non-existent keys + let sql = "SELECT object_pick('{\"a\": 1, \"b\": 2}', 'c') as picked3"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------+", + "| picked3 |", + "+---------+", + "| {} |", + "+---------+", + ], + &result + ); + + // Test with NULL input + let sql = "SELECT object_pick(NULL, 'a') as null_input"; + let result = ctx.sql(sql).await?.collect().await?; + + assert_batches_eq!( + [ + "+------------+", + "| null_input |", + "+------------+", + "| |", + "+------------+", + ], + &result + ); + + Ok(()) + } +} diff --git a/crates/df-builtins/src/variant/variant_element.rs b/crates/df-builtins/src/variant/variant_element.rs new file mode 100644 index 00000000..33573fe6 --- /dev/null +++ b/crates/df-builtins/src/variant/variant_element.rs @@ -0,0 +1,377 @@ +use super::super::macros::make_udf_function; +use datafusion::arrow::array::Array; +use datafusion::arrow::array::builder::StringBuilder; +use datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::datatypes::DataType; +use datafusion_common::{ + Result as DFResult, ScalarValue, + types::{ + NativeType, logical_binary, logical_boolean, logical_int8, logical_int16, logical_int32, + logical_int64, logical_string, logical_uint8, logical_uint16, logical_uint32, + }, +}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, +}; +use serde_json::Value; +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub struct VariantArrayElementUDF { + signature: Signature, + aliases: Vec, +} + +impl VariantArrayElementUDF { + #[must_use] + pub fn new() -> Self { + Self { + signature: Signature { + type_signature: TypeSignature::OneOf(vec![ + TypeSignature::Coercible(vec![ + Coercion::new_implicit( + TypeSignatureClass::Native(logical_string()), + vec![TypeSignatureClass::Native(logical_binary())], + NativeType::String, + ), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_string()), + vec![TypeSignatureClass::Native(logical_binary())], + NativeType::String, + ), + Coercion::new_exact(TypeSignatureClass::Native(logical_boolean())), + ]), + TypeSignature::Coercible(vec![ + Coercion::new_implicit( + TypeSignatureClass::Native(logical_string()), + vec![TypeSignatureClass::Native(logical_binary())], + NativeType::String, + ), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_string()), + vec![ + TypeSignatureClass::Native(logical_binary()), + TypeSignatureClass::Native(logical_int8()), + TypeSignatureClass::Native(logical_int16()), + TypeSignatureClass::Native(logical_int32()), + TypeSignatureClass::Native(logical_int64()), + TypeSignatureClass::Native(logical_uint8()), + TypeSignatureClass::Native(logical_uint16()), + TypeSignatureClass::Native(logical_uint32()), + ], + NativeType::String, + ), + ]), + ]), + volatility: Volatility::Immutable, + }, + aliases: vec!["array_element".to_string()], + } + } +} + +impl Default for VariantArrayElementUDF { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for 
VariantArrayElementUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "variant_element" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> { + Ok(DataType::Utf8) + } + + #[allow(clippy::too_many_lines, clippy::unwrap_used)] + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> { + let ScalarFunctionArgs { mut args, .. } = args; + let (flatten, index, array_str) = if args.len() == 3 { + (args.pop(), args.pop().unwrap(), args.pop().unwrap()) + } else if args.len() == 2 { + (None, args.pop().unwrap(), args.pop().unwrap()) + } else { + return Err(datafusion_common::error::DataFusionError::Internal( + "Invalid number of arguments".to_string(), + )); + }; + match (array_str, index) { + (ColumnarValue::Array(array), ColumnarValue::Scalar(index_value)) => { + let ScalarValue::Utf8(Some(index)) = index_value else { + return Err(datafusion_common::error::DataFusionError::Internal( + "Expected JSONPath value for index".to_string(), + )); + }; + + let flatten = + if let Some(ColumnarValue::Scalar(ScalarValue::Boolean(Some(b)))) = flatten { + b + } else { + false + }; + + let string_array = array.as_string::<i32>(); + let mut builder = StringBuilder::new(); + + for i in 0..string_array.len() { + if string_array.is_null(i) { + builder.append_null(); + } else { + let value: Option<Vec<Value>> = + jsonpath_lib::select_as(string_array.value(i), &index).ok(); + match value { + Some(s) => { + if s.is_empty() { + builder.append_null(); + } else if flatten { + if s.len() == 1 { + builder.append_value(s[0].to_string()); + } else { + builder.append_value(Value::Array(s).to_string()); + } + } else { + builder.append_value(Value::Array(s).to_string()); + } + } + None => builder.append_null(), + } + } + } + + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } + (ColumnarValue::Scalar(array_value), ColumnarValue::Scalar(index_value)) => { + let ScalarValue::Utf8(Some(index)) = index_value else { + return Err(datafusion_common::error::DataFusionError::Internal( + "Expected JSONPath value for index".to_string(), + )); + }; + + let ScalarValue::Utf8(Some(array_str)) = array_value else { + return Err(datafusion_common::error::DataFusionError::Internal( + "Expected string array".to_string(), + )); + }; + + let flatten = + if let Some(ColumnarValue::Scalar(ScalarValue::Boolean(Some(b)))) = flatten { + b + } else { + false + }; + let value: Option<Vec<Value>> = jsonpath_lib::select_as(&array_str, &index).ok(); + match value { + Some(s) => { + if s.is_empty() { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) + } else if flatten && s.len() == 1 { + // Mirror the array branch: only unwrap when exactly one value matched + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + s[0].to_string(), + )))) + } else { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + Value::Array(s).to_string(), + )))) + } + } + None => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + } + } + _ => Err(datafusion_common::error::DataFusionError::Internal( + "Invalid argument types".to_string(), + )), + } + } +} + +make_udf_function!(VariantArrayElementUDF); + +#[cfg(test)] +mod tests { + use super::super::array_construct::ArrayConstructUDF; + use super::*; + use crate::visitors::variant::variant_element; + use datafusion::assert_batches_eq; + use datafusion::prelude::SessionContext; + use datafusion::sql::parser::Statement; + use datafusion_expr::ScalarUDF; + #[tokio::test] + async fn test_array_indexing() -> DFResult<()> { + let ctx = SessionContext::new();
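+ // NOTE: the bracket-indexing SQL below resolves to variant_element only because + // each statement is first run through the variant_element AST visitor before planning.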
+ + // Register both UDFs + ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new())); + ctx.register_udf(ScalarUDF::from(VariantArrayElementUDF::new())); + + // Create a table with ID and arrvar columns + let sql = "CREATE TABLE test_table (id INT, arrvar STRING)"; + ctx.sql(sql).await?.collect().await?; + + // Insert some test data + let sql = "INSERT INTO test_table VALUES (1, array_construct(1, 2, 3)), (2, array_construct('a', 'b', 'c'))"; + ctx.sql(sql).await?.collect().await?; + + // Test basic array indexing + let sql = "SELECT arrvar[0] as first, \ + arrvar[1] as second, \ + arrvar[2] as third \ + FROM test_table WHERE id = 1"; + + let mut statement = ctx.state().sql_to_statement(sql, "snowflake")?; + if let Statement::Statement(ref mut stmt) = statement { + variant_element::visit(stmt); + } + let plan = ctx.state().statement_to_plan(statement).await?; + let result = ctx.execute_logical_plan(plan).await?.collect().await?; + + assert_batches_eq!( + [ + "+-------+--------+-------+", + "| first | second | third |", + "+-------+--------+-------+", + "| 1 | 2 | 3 |", + "+-------+--------+-------+" + ], + &result + ); + + // Test out of bounds indexing + let sql = "SELECT arrvar[5] as out_of_bounds FROM test_table WHERE id = 1"; + let mut statement = ctx.state().sql_to_statement(sql, "snowflake")?; + if let Statement::Statement(ref mut stmt) = statement { + variant_element::visit(stmt); + } + let plan = ctx.state().statement_to_plan(statement).await?; + let result = ctx.execute_logical_plan(plan).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------------+", + "| out_of_bounds |", + "+---------------+", + "| |", + "+---------------+", + ], + &result + ); + + // Test mixed type array + let sql = "SELECT arrvar[1] as str_element FROM test_table WHERE id = 2"; + let mut statement = ctx.state().sql_to_statement(sql, "snowflake")?; + if let Statement::Statement(ref mut stmt) = statement { + variant_element::visit(stmt); + } + let plan = ctx.state().statement_to_plan(statement).await?; + let result = ctx.execute_logical_plan(plan).await?.collect().await?; + + assert_batches_eq!( + [ + "+-------------+", + "| str_element |", + "+-------------+", + "| \"b\" |", + "+-------------+" + ], + &result + ); + + // Test empty array + let sql = "SELECT array_construct()[0] as empty_array"; + let mut statement = ctx.state().sql_to_statement(sql, "snowflake")?; + if let Statement::Statement(ref mut stmt) = statement { + variant_element::visit(stmt); + } + let plan = ctx.state().statement_to_plan(statement).await?; + let result = ctx.execute_logical_plan(plan).await?.collect().await?; + + assert_batches_eq!( + [ + "+-------------+", + "| empty_array |", + "+-------------+", + "| |", + "+-------------+" + ], + &result + ); + + Ok(()) + } + + #[tokio::test] + async fn test_variant_object_path() -> DFResult<()> { + let ctx = SessionContext::new(); + + // Register UDFs + ctx.register_udf(ScalarUDF::from(VariantArrayElementUDF::new())); + + // Create a table with JSON data + let sql = "CREATE TABLE json_table (id INT, json_col STRING)"; + ctx.sql(sql).await?.collect().await?; + + // Insert test JSON data + let sql = "INSERT INTO json_table VALUES + (1, '{\"a\": {\"b\": [1,2,3]}}'), + (2, '{\"a\": {\"b\": [\"x\",\"y\",\"z\"]}}')"; + ctx.sql(sql).await?.collect().await?; + + // Test JSON path access + let sql = "SELECT json_col:a.b[0] as first_elem FROM json_table"; + let mut statement = ctx.state().sql_to_statement(sql, "snowflake")?; + if let Statement::Statement(ref mut stmt) = statement { + 
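+ // The visitor rewrites json_col:a.b[0] into variant_element(json_col, '$.a.b[0]', true)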
variant_element::visit(stmt); + } + let plan = ctx.state().statement_to_plan(statement).await?; + let result = ctx.execute_logical_plan(plan).await?.collect().await?; + + assert_batches_eq!( + [ + "+------------+", + "| first_elem |", + "+------------+", + "| 1 |", + "| \"x\" |", + "+------------+" + ], + &result + ); + + // Test JSON path access that returns a whole nested array + let sql = "SELECT json_col:a.b as array_elem FROM json_table"; + let mut statement = ctx.state().sql_to_statement(sql, "snowflake")?; + if let Statement::Statement(ref mut stmt) = statement { + variant_element::visit(stmt); + } + let plan = ctx.state().statement_to_plan(statement).await?; + let result = ctx.execute_logical_plan(plan).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------------+", + "| array_elem |", + "+---------------+", + "| [1,2,3] |", + "| [\"x\",\"y\",\"z\"] |", + "+---------------+", + ], + &result + ); + + Ok(()) + } +} diff --git a/crates/df-builtins/src/visitors/mod.rs b/crates/df-builtins/src/visitors/mod.rs new file mode 100644 index 00000000..cb1eb8cd --- /dev/null +++ b/crates/df-builtins/src/visitors/mod.rs @@ -0,0 +1 @@ +pub mod variant; \ No newline at end of file diff --git a/crates/df-builtins/src/visitors/variant/array_agg.rs b/crates/df-builtins/src/visitors/variant/array_agg.rs new file mode 100644 index 00000000..8e50a004 --- /dev/null +++ b/crates/df-builtins/src/visitors/variant/array_agg.rs @@ -0,0 +1,122 @@ +use datafusion::logical_expr::sqlparser::tokenizer::{Location, Span}; +use datafusion_expr::sqlparser::ast::VisitMut; +use datafusion_expr::sqlparser::ast::{ + Expr, Function, FunctionArg, FunctionArgumentList, FunctionArguments, Ident, JsonPath, + JsonPathElem, ObjectName, ObjectNamePart, Statement, Value, ValueWithSpan, VisitorMut, +}; + +#[derive(Debug)] +pub struct VariantArrayAggRewriter; + +impl VisitorMut for VariantArrayAggRewriter { + type Break = bool; + + fn post_visit_expr( + &mut self, + expr: &mut datafusion_expr::sqlparser::ast::Expr, + ) -> std::ops::ControlFlow<Self::Break> { + if let datafusion_expr::sqlparser::ast::Expr::Function(Function { name, .. 
}) = expr { + if let Some(part) = name.0.last() { + if part.as_ident().is_some_and(|i| i.value == "array_agg") { + let wrapped_function = Function { + name: ObjectName(vec![ObjectNamePart::Identifier(Ident::new( + "array_construct".to_string(), + ))]), + uses_odbc_syntax: false, + parameters: FunctionArguments::None, + args: FunctionArguments::List(FunctionArgumentList { + args: vec![FunctionArg::Unnamed( + datafusion_expr::sqlparser::ast::FunctionArgExpr::Expr( + expr.clone(), + ), + )], + duplicate_treatment: None, + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: Vec::default(), + }; + let fn_call_expr = + datafusion_expr::sqlparser::ast::Expr::Function(wrapped_function); + + let index_expr = Expr::JsonAccess { + value: Box::new(fn_call_expr), + path: JsonPath { + path: vec![JsonPathElem::Bracket { + key: Expr::Value(ValueWithSpan { + value: Value::Number("0".to_string(), false), + span: Span::new(Location::new(0, 0), Location::new(0, 0)), + }), + }], + }, + }; + + *expr = index_expr; + } + } + } + std::ops::ControlFlow::Continue(()) + } +} + +pub fn visit(stmt: &mut Statement) { + stmt.visit(&mut VariantArrayAggRewriter {}); +} + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + use datafusion::assert_batches_eq; + use datafusion::prelude::SessionContext; + use datafusion::sql::parser::Statement as DFStatement; + use datafusion_common::Result as DFResult; + use datafusion_expr::ScalarUDF; + use crate::variant::array_compact::ArrayCompactUDF; + use crate::variant::array_construct::ArrayConstructUDF; + use crate::visitors::variant::variant_element; + + #[tokio::test] + async fn test_array_agg_rewrite() -> DFResult<()> { + let mut ctx = SessionContext::new(); + // Register array_construct UDF + + ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new())); + ctx.register_udf(ScalarUDF::from(ArrayCompactUDF::new())); + + // Create table and insert data + let create_sql = "CREATE TABLE test_table (id INT, val INT)"; + let mut create_stmt = ctx.state().sql_to_statement(create_sql, "snowflake")?; + if let DFStatement::Statement(ref mut stmt) = create_stmt { + visit(stmt); + } + ctx.sql(&create_stmt.to_string()).await?.collect().await?; + + let insert_sql = "INSERT INTO test_table VALUES (1, 1), (1, 2), (1, 3)"; + ctx.sql(insert_sql).await?.collect().await?; + + // Test array_agg rewrite by validating JSON output + let sql = "SELECT array_agg(val) as json_arr FROM test_table GROUP BY id"; + let mut stmt = ctx.state().sql_to_statement(sql, "snowflake")?; + if let DFStatement::Statement(ref mut s) = stmt { + visit(s); + variant_element::visit(s); + } + let result = ctx.sql(&stmt.to_string()).await?.collect().await?; + + assert_batches_eq!( + [ + "+----------+", + "| json_arr |", + "+----------+", + "| [1,2,3] |", + "+----------+", + ], + &result + ); + + Ok(()) + } +} diff --git a/crates/df-builtins/src/visitors/variant/array_construct.rs b/crates/df-builtins/src/visitors/variant/array_construct.rs new file mode 100644 index 00000000..fc5722cb --- /dev/null +++ b/crates/df-builtins/src/visitors/variant/array_construct.rs @@ -0,0 +1,131 @@ +use datafusion_expr::sqlparser::ast::VisitMut; +use datafusion_expr::sqlparser::ast::{ + Expr as ASTExpr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentList, + FunctionArguments, Ident, ObjectName, ObjectNamePart, Statement, VisitorMut, +}; + +#[derive(Debug, Default)] +pub struct ArrayConstructVisitor; + +impl ArrayConstructVisitor { + #[must_use] + pub const fn new() -> Self { + 
Self + } +} + +impl VisitorMut for ArrayConstructVisitor { + fn post_visit_expr(&mut self, expr: &mut ASTExpr) -> std::ops::ControlFlow<Self::Break> { + if let ASTExpr::Array(elements) = expr { + let args = elements + .elem + .iter() + .map(|e| FunctionArg::Unnamed(FunctionArgExpr::Expr(e.clone()))) + .collect(); + + let new_expr = ASTExpr::Function(Function { + name: ObjectName(vec![ObjectNamePart::Identifier(Ident::new( + "array_construct", + ))]), + args: FunctionArguments::List(FunctionArgumentList { + args, + duplicate_treatment: None, + clauses: vec![], + }), + uses_odbc_syntax: false, + parameters: FunctionArguments::None, + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + }); + *expr = new_expr; + } + std::ops::ControlFlow::Continue(()) + } + type Break = bool; +} + +pub fn visit(stmt: &mut Statement) { + stmt.visit(&mut ArrayConstructVisitor::new()); +} + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + use datafusion::assert_batches_eq; + use datafusion::prelude::SessionContext; + use datafusion::sql::parser::Statement as DFStatement; + use datafusion_common::Result as DFResult; + use datafusion_expr::ScalarUDF; + use crate::variant::array_construct::ArrayConstructUDF; + + #[tokio::test] + async fn test_array_construct_rewrite() -> DFResult<()> { + let mut ctx = SessionContext::new(); + // Register array_construct UDF + ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new())); + + + // Test simple array construction + let sql = "SELECT [1, 2, 3] as arr"; + let mut stmt = ctx.state().sql_to_statement(sql, "snowflake")?; + if let DFStatement::Statement(ref mut s) = stmt { + visit(s); + } + let result = ctx.sql(&stmt.to_string()).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------+", + "| arr |", + "+---------+", + "| [1,2,3] |", + "+---------+", + ], + &result + ); + + // Test array with mixed types + let sql = "SELECT [1, 'test', null] as mixed_arr"; + let mut stmt = ctx.state().sql_to_statement(sql, "snowflake")?; + if let DFStatement::Statement(ref mut s) = stmt { + visit(s); + } + let result = ctx.sql(&stmt.to_string()).await?.collect().await?; + + assert_batches_eq!( + [ + "+-----------------+", + "| mixed_arr |", + "+-----------------+", + "| [1,\"test\",null] |", + "+-----------------+", + ], + &result + ); + + // Test nested arrays + let sql = "SELECT [[1, 2], [3, 4]] as nested_arr"; + let mut stmt = ctx.state().sql_to_statement(sql, "snowflake")?; + if let DFStatement::Statement(ref mut s) = stmt { + visit(s); + } + + let result = ctx.sql(&stmt.to_string()).await?.collect().await?; + + assert_batches_eq!( + [ + "+---------------+", + "| nested_arr |", + "+---------------+", + "| [[1,2],[3,4]] |", + "+---------------+", + ], + &result + ); + + Ok(()) + } +} diff --git a/crates/df-builtins/src/visitors/variant/array_construct_compact.rs b/crates/df-builtins/src/visitors/variant/array_construct_compact.rs new file mode 100644 index 00000000..89cb72a7 --- /dev/null +++ b/crates/df-builtins/src/visitors/variant/array_construct_compact.rs @@ -0,0 +1,111 @@ +use datafusion_expr::sqlparser::ast::VisitMut; +use datafusion_expr::sqlparser::ast::{ + Expr, Function, FunctionArg, FunctionArgumentList, FunctionArguments, Ident, ObjectName, + ObjectNamePart, Statement, VisitorMut, +}; + +#[derive(Debug)] +pub struct ArrayConstructCompactRewriter; + +impl VisitorMut for ArrayConstructCompactRewriter { + type Break = bool; + + fn post_visit_expr( + &mut self, + expr: &mut datafusion_expr::sqlparser::ast::Expr, + ) -> 
std::ops::ControlFlow<Self::Break> { + if let datafusion_expr::sqlparser::ast::Expr::Function(Function { name, args, .. }) = expr { + if let Some(part) = name.0.last() { + if part + .as_ident() + .is_some_and(|i| i.value == "array_construct_compact") + { + // Create the inner array_construct function + let array_construct = Function { + name: ObjectName(vec![ObjectNamePart::Identifier(Ident::new( + "array_construct".to_string(), + ))]), + uses_odbc_syntax: false, + parameters: FunctionArguments::None, + args: args.clone(), + filter: None, + null_treatment: None, + over: None, + within_group: Vec::default(), + }; + + // Create the outer array_compact function that wraps array_construct + let array_compact = Function { + name: ObjectName(vec![ObjectNamePart::Identifier(Ident::new( + "array_compact".to_string(), + ))]), + uses_odbc_syntax: false, + parameters: FunctionArguments::None, + args: FunctionArguments::List(FunctionArgumentList { + args: vec![FunctionArg::Unnamed( + datafusion_expr::sqlparser::ast::FunctionArgExpr::Expr( + Expr::Function(array_construct), + ), + )], + duplicate_treatment: None, + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: Vec::default(), + }; + + *expr = Expr::Function(array_compact); + } + } + } + std::ops::ControlFlow::Continue(()) + } +} + +pub fn visit(stmt: &mut Statement) { + stmt.visit(&mut ArrayConstructCompactRewriter {}); +} + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + use datafusion::assert_batches_eq; + use datafusion::prelude::SessionContext; + use datafusion::sql::parser::Statement as DFStatement; + use datafusion_common::Result as DFResult; + use datafusion_expr::ScalarUDF; + use crate::variant::array_compact::ArrayCompactUDF; + use crate::variant::array_construct::ArrayConstructUDF; + + #[tokio::test] + async fn test_array_construct_compact_rewrite() -> DFResult<()> { + let mut ctx = SessionContext::new(); + // Register array_construct and array_compact UDFs + ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new())); + ctx.register_udf(ScalarUDF::from(ArrayCompactUDF::new())); + + // Test array_construct_compact rewrite + let sql = "SELECT array_construct_compact(1, NULL, 2, NULL, 3) as compact_arr"; + let mut stmt = ctx.state().sql_to_statement(sql, "snowflake")?; + if let DFStatement::Statement(ref mut s) = stmt { + visit(s); + } + let result = ctx.sql(&stmt.to_string()).await?.collect().await?; + + assert_batches_eq!( + [ + "+-------------+", + "| compact_arr |", + "+-------------+", + "| [1,2,3] |", + "+-------------+", + ], + &result + ); + + Ok(()) + } +} diff --git a/crates/df-builtins/src/visitors/variant/mod.rs b/crates/df-builtins/src/visitors/variant/mod.rs new file mode 100644 index 00000000..7454f36d --- /dev/null +++ b/crates/df-builtins/src/visitors/variant/mod.rs @@ -0,0 +1,15 @@ +pub mod array_agg; +pub mod array_construct; +pub mod array_construct_compact; +pub mod type_rewrite; +pub mod variant_element; + +use datafusion_expr::sqlparser::ast::Statement; + +pub fn visit_all(stmt: &mut Statement) { + array_agg::visit(stmt); + array_construct_compact::visit(stmt); + array_construct::visit(stmt); + type_rewrite::visit(stmt); + variant_element::visit(stmt); +} diff --git a/crates/df-builtins/src/visitors/variant/type_rewrite.rs b/crates/df-builtins/src/visitors/variant/type_rewrite.rs new file mode 100644 index 00000000..6f41055d --- /dev/null +++ b/crates/df-builtins/src/visitors/variant/type_rewrite.rs @@ -0,0 +1,108 @@ +use 
datafusion_expr::sqlparser::ast::VisitMut; +use datafusion_expr::sqlparser::ast::{DataType, Statement, VisitorMut}; +use std::ops::ControlFlow; + +#[derive(Debug, Default)] +pub struct ArrayObjectToBinaryVisitor; + +impl ArrayObjectToBinaryVisitor { + #[must_use] + pub const fn new() -> Self { + Self + } +} + +impl VisitorMut for ArrayObjectToBinaryVisitor { + type Break = (); + + fn post_visit_statement(&mut self, stmt: &mut Statement) -> ControlFlow<Self::Break> { + if let Statement::CreateTable(create_table) = stmt { + for column in &mut create_table.columns { + if let DataType::Array(_) = &column.data_type { + column.data_type = DataType::String(None); + } + if let DataType::Custom(_, _) = &column.data_type { + if column.data_type.to_string() == "OBJECT" { + column.data_type = DataType::String(None); + } + if column.data_type.to_string() == "VARIANT" { + column.data_type = DataType::String(None); + } + } + } + } + ControlFlow::Continue(()) + } +} + +pub fn visit(stmt: &mut Statement) { + stmt.visit(&mut ArrayObjectToBinaryVisitor::new()); +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::prelude::SessionContext; + use datafusion::sql::parser::Statement as DFStatement; + use datafusion_common::Result as DFResult; + + #[tokio::test] + async fn test_array_to_binary_rewrite() -> DFResult<()> { + let ctx = SessionContext::new(); + + // Test table creation with ARRAY type + let sql = "CREATE TABLE test_table (id INT, arr ARRAY)"; + let mut statement = ctx.state().sql_to_statement(sql, "snowflake")?; + + if let DFStatement::Statement(ref mut stmt) = statement { + visit(stmt); + } + + assert_eq!( + statement.to_string(), + "CREATE TABLE test_table (id INT, arr STRING)" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_object_to_binary_rewrite() -> DFResult<()> { + let ctx = SessionContext::new(); + + // Test table creation with OBJECT type + let sql = "CREATE TABLE test_table (id INT, obj OBJECT)"; + let mut statement = ctx.state().sql_to_statement(sql, "snowflake")?; + + if let DFStatement::Statement(ref mut stmt) = statement { + visit(stmt); + } + + assert_eq!( + statement.to_string(), + "CREATE TABLE test_table (id INT, obj STRING)" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_variant_to_binary_rewrite() -> DFResult<()> { + let ctx = SessionContext::new(); + + // Test table creation with VARIANT type + let sql = "CREATE TABLE test_table (id INT, variant VARIANT)"; + let mut statement = ctx.state().sql_to_statement(sql, "snowflake")?; + + if let DFStatement::Statement(ref mut stmt) = statement { + visit(stmt); + } + + assert_eq!( + statement.to_string(), + "CREATE TABLE test_table (id INT, variant STRING)" + ); + + Ok(()) + } +} diff --git a/crates/df-builtins/src/visitors/variant/variant_element.rs b/crates/df-builtins/src/visitors/variant/variant_element.rs new file mode 100644 index 00000000..f7e8d2dc --- /dev/null +++ b/crates/df-builtins/src/visitors/variant/variant_element.rs @@ -0,0 +1,69 @@ +use datafusion_expr::sqlparser::ast::VisitMut; +use datafusion_expr::sqlparser::ast::{ + Expr as ASTExpr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentList, + FunctionArguments, Ident, ObjectName, ObjectNamePart, Statement, Value, ValueWithSpan, + VisitorMut, +}; +use datafusion_expr::sqlparser::tokenizer::Location; +use datafusion_expr::sqlparser::tokenizer::Span; + +#[derive(Debug, Default)] +pub struct VariantElementVisitor {} + +impl VariantElementVisitor { + #[must_use] + pub const fn new() -> Self { + Self {} + } +} + +impl VisitorMut for VariantElementVisitor { 
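+ // Rewrites Snowflake-style JSON access (e.g. col:a.b[0]) into a variant_element call: + // a leading ':' is normalized to '.', the path is prefixed with '$' to form a JSONPath, + // and the flatten flag is passed as true.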
fn post_visit_expr(&mut self, expr: &mut ASTExpr) -> std::ops::ControlFlow<Self::Break> { + if let ASTExpr::JsonAccess { value, path } = expr { + let mut path = path.to_string(); + if path.starts_with(':') { + path = format!(".{}", path.split_at(1).1); + } + let path = format!("${path}"); + let new_expr = ASTExpr::Function(Function { + name: ObjectName(vec![ObjectNamePart::Identifier(Ident::new( + "variant_element".to_string(), + ))]), + args: FunctionArguments::List(FunctionArgumentList { + args: vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr(*value.clone())), + FunctionArg::Unnamed(FunctionArgExpr::Expr(ASTExpr::Value( + ValueWithSpan { + value: Value::SingleQuotedString(path), + span: Span::new(Location::new(0, 0), Location::new(0, 0)), + }, + ))), + FunctionArg::Unnamed(FunctionArgExpr::Expr(ASTExpr::Value( + ValueWithSpan { + value: Value::Boolean(true), + span: Span::new(Location::new(0, 0), Location::new(0, 0)), + }, + ))), + ], + duplicate_treatment: None, + clauses: vec![], + }), + uses_odbc_syntax: false, + parameters: FunctionArguments::None, + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + }); + *expr = new_expr; + } + std::ops::ControlFlow::Continue(()) + } + type Break = bool; +} + +pub fn visit(stmt: &mut Statement) { + stmt.visit(&mut VariantElementVisitor::new()); +} + +// For unit tests see variant/variant_element.rs From 96d50e7d27b60a75f04c7b1c93e2f42aaf315b8d Mon Sep 17 00:00:00 2001 From: Maxim Bogdanov Date: Tue, 20 May 2025 16:24:12 +0200 Subject: [PATCH 2/2] visitors as a separate crate --- Cargo.toml | 1 + crates/core-executor/Cargo.toml | 2 +- crates/core-executor/src/datafusion/mod.rs | 1 - crates/core-executor/src/query.rs | 2 +- crates/core-executor/src/session.rs | 2 +- crates/df-builtins/Cargo.toml | 1 + crates/df-builtins/src/aggregate/object_agg.rs | 14 +++++++------- crates/df-builtins/src/variant/array_cat.rs | 9 ++++----- crates/df-builtins/src/variant/array_compact.rs | 8 ++++---- crates/df-builtins/src/variant/array_construct.rs | 2 +- crates/df-builtins/src/variant/array_contains.rs | 6 +++--- crates/df-builtins/src/variant/array_distinct.rs | 10 +++++----- crates/df-builtins/src/variant/array_except.rs | 6 +++--- crates/df-builtins/src/variant/array_flatten.rs | 2 +- crates/df-builtins/src/variant/array_insert.rs | 8 ++++---- .../df-builtins/src/variant/array_intersection.rs | 6 +++--- crates/df-builtins/src/variant/array_max.rs | 10 +++++----- crates/df-builtins/src/variant/array_min.rs | 10 +++++----- crates/df-builtins/src/variant/array_position.rs | 8 ++++---- crates/df-builtins/src/variant/array_prepend.rs | 8 ++++---- crates/df-builtins/src/variant/array_remove.rs | 8 ++++---- crates/df-builtins/src/variant/array_remove_at.rs | 8 ++++---- crates/df-builtins/src/variant/array_reverse.rs | 8 ++++---- crates/df-builtins/src/variant/array_size.rs | 8 ++++---- crates/df-builtins/src/variant/array_slice.rs | 8 ++++---- crates/df-builtins/src/variant/array_sort.rs | 14 +++++--------- crates/df-builtins/src/variant/array_to_string.rs | 2 +- crates/df-builtins/src/variant/arrays_overlap.rs | 8 ++++---- .../df-builtins/src/variant/arrays_to_object.rs | 8 ++++---- crates/df-builtins/src/variant/arrays_zip.rs | 6 +++--- crates/df-builtins/src/variant/json.rs | 12 +++++++----- crates/df-builtins/src/variant/mod.rs | 10 +++++----- crates/df-builtins/src/variant/object_delete.rs | 6 +++--- crates/df-builtins/src/variant/object_insert.rs | 14 +++++++------- crates/df-builtins/src/variant/object_pick.rs | 6 +++---
crates/df-builtins/src/variant/variant_element.rs | 2 +- crates/df-builtins/src/visitors/mod.rs | 1 - crates/visitors/Cargo.toml | 15 +++++++++++++++ .../src}/copy_into_identifiers.rs | 0 .../src}/functions_rewriter.rs | 0 .../visitors => visitors/src}/json_element.rs | 0 .../visitors/mod.rs => visitors/src/lib.rs} | 3 +-- .../src}/variant/array_agg.rs | 8 ++++---- .../src}/variant/array_construct.rs | 3 +-- .../src}/variant/array_construct_compact.rs | 4 ++-- .../src/visitors => visitors/src}/variant/mod.rs | 0 .../src}/variant/type_rewrite.rs | 0 .../src}/variant/variant_element.rs | 0 48 files changed, 144 insertions(+), 134 deletions(-) delete mode 100644 crates/df-builtins/src/visitors/mod.rs create mode 100644 crates/visitors/Cargo.toml rename crates/{core-executor/src/datafusion/visitors => visitors/src}/copy_into_identifiers.rs (100%) rename crates/{core-executor/src/datafusion/visitors => visitors/src}/functions_rewriter.rs (100%) rename crates/{core-executor/src/datafusion/visitors => visitors/src}/json_element.rs (100%) rename crates/{core-executor/src/datafusion/visitors/mod.rs => visitors/src/lib.rs} (68%) rename crates/{df-builtins/src/visitors => visitors/src}/variant/array_agg.rs (99%) rename crates/{df-builtins/src/visitors => visitors/src}/variant/array_construct.rs (99%) rename crates/{df-builtins/src/visitors => visitors/src}/variant/array_construct_compact.rs (100%) rename crates/{df-builtins/src/visitors => visitors/src}/variant/mod.rs (100%) rename crates/{df-builtins/src/visitors => visitors/src}/variant/type_rewrite.rs (100%) rename crates/{df-builtins/src/visitors => visitors/src}/variant/variant_element.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index aba9ac39..eb34f87a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ members = [ "crates/core-metastore", "crates/core-utils", "crates/api-sessions", + "crates/visitors", ] resolver = "2" package.license-file = "LICENSE" diff --git a/crates/core-executor/Cargo.toml b/crates/core-executor/Cargo.toml index b723aec5..c8844c5e 100644 --- a/crates/core-executor/Cargo.toml +++ b/crates/core-executor/Cargo.toml @@ -9,7 +9,7 @@ core-utils = { path = "../core-utils" } core-metastore = { path = "../core-metastore" } df-builtins = { path = "../df-builtins" } df-catalog = { path = "../df-catalog" } - +visitors = { path = "../visitors" } async-trait = { workspace = true } aws-config = { workspace = true } aws-credential-types = { workspace = true } diff --git a/crates/core-executor/src/datafusion/mod.rs b/crates/core-executor/src/datafusion/mod.rs index 95513f5b..e4482f19 100644 --- a/crates/core-executor/src/datafusion/mod.rs +++ b/crates/core-executor/src/datafusion/mod.rs @@ -5,6 +5,5 @@ pub mod error; pub mod physical_optimizer; pub mod planner; pub mod type_planner; -pub mod visitors; pub use df_builtins as functions; diff --git a/crates/core-executor/src/query.rs b/crates/core-executor/src/query.rs index cc1c2a52..6dbdeb08 100644 --- a/crates/core-executor/src/query.rs +++ b/crates/core-executor/src/query.rs @@ -53,12 +53,12 @@ use super::datafusion::planner::ExtendedSqlToRel; use super::error::{self as ex_error, ExecutionError, ExecutionResult, RefreshCatalogListSnafu}; use super::session::UserSession; use super::utils::{NormalizedIdent, is_logical_plan_effectively_empty}; -use crate::datafusion::visitors::{copy_into_identifiers, functions_rewriter, json_element}; use df_catalog::catalog::CachingCatalog; use df_catalog::catalogs::slatedb::schema::{ SLATEDB_CATALOG, SLATEDB_SCHEMA, SlateDBViewSchemaProvider, 
}; use tracing_attributes::instrument; +use visitors::{copy_into_identifiers, functions_rewriter, json_element}; #[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub struct QueryContext { diff --git a/crates/core-executor/src/session.rs b/crates/core-executor/src/session.rs index 29e9bdda..809bed60 100644 --- a/crates/core-executor/src/session.rs +++ b/crates/core-executor/src/session.rs @@ -29,6 +29,7 @@ use df_catalog::catalog_list::{DEFAULT_CATALOG, EmbucketCatalogList}; // TODO: We need to fix this after geodatafusion is updated to datafusion 47 //use geodatafusion::udf::native::register_native as register_geo_native; use crate::datafusion::physical_optimizer::physical_optimizer_rules; +use df_builtins::table::register_udtfs; use iceberg_rust::object_store::ObjectStoreBuilder; use iceberg_s3tables_catalog::S3TablesCatalog; use snafu::ResultExt; @@ -36,7 +37,6 @@ use std::any::Any; use std::collections::HashMap; use std::env; use std::sync::Arc; -use df_builtins::table::register_udtfs; pub struct UserSession { pub metastore: Arc, diff --git a/crates/df-builtins/Cargo.toml b/crates/df-builtins/Cargo.toml index c4dc7b65..9de8f669 100644 --- a/crates/df-builtins/Cargo.toml +++ b/crates/df-builtins/Cargo.toml @@ -5,6 +5,7 @@ edition = "2024" license-file.workspace = true [dependencies] +visitors = { path = "../visitors" } chrono = { workspace = true } datafusion = { workspace = true } datafusion-common = { workspace = true } diff --git a/crates/df-builtins/src/aggregate/object_agg.rs b/crates/df-builtins/src/aggregate/object_agg.rs index 2458c163..82d95dcf 100644 --- a/crates/df-builtins/src/aggregate/object_agg.rs +++ b/crates/df-builtins/src/aggregate/object_agg.rs @@ -4,18 +4,18 @@ use std::any::Any; use std::collections::HashSet; use std::sync::Arc; -use datafusion::arrow::array::as_list_array; -use datafusion::arrow::datatypes::{DataType, Field, Fields}; use datafusion::arrow::array::StringArray; -use datafusion::arrow::array::{new_empty_array, Array}; +use datafusion::arrow::array::as_list_array; +use datafusion::arrow::array::{Array, new_empty_array}; use datafusion::arrow::array::{ArrayRef, StructArray}; +use datafusion::arrow::datatypes::{DataType, Field, Fields}; use datafusion::common::ScalarValue; use datafusion_common::utils::SingleRowListArrayBuilder; -use datafusion_common::{exec_err, internal_err, DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result, exec_err, internal_err}; +use datafusion_expr::Volatility; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; -use datafusion_expr::Volatility; use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature}; #[derive(Debug, Clone)] @@ -252,9 +252,9 @@ mod tests { use super::*; use datafusion::arrow::datatypes::{Field, Schema}; use datafusion::physical_expr::LexOrdering; - use datafusion_common::{internal_err, Result}; - use datafusion_physical_plan::expressions::Column; + use datafusion_common::{Result, internal_err}; use datafusion_physical_plan::Accumulator; + use datafusion_physical_plan::expressions::Column; use serde_json::json; use serde_json::{Map as JsonMap, Value as JsonValue}; use std::sync::Arc; diff --git a/crates/df-builtins/src/variant/array_cat.rs b/crates/df-builtins/src/variant/array_cat.rs index 2a3260f6..b023a1fb 100644 --- a/crates/df-builtins/src/variant/array_cat.rs +++ b/crates/df-builtins/src/variant/array_cat.rs @@ -1,12 +1,12 @@ use super::super::macros::make_udf_function; -use 
datafusion::arrow::datatypes::DataType; -use datafusion::arrow::array::cast::AsArray; use datafusion::arrow::array::Array; +use datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::datatypes::DataType; use datafusion_common::{Result as DFResult, ScalarValue}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; -use serde_json::{from_str, to_string, Value}; +use serde_json::{Value, from_str, to_string}; use std::sync::Arc; #[derive(Debug, Clone)] @@ -184,10 +184,10 @@ make_udf_function!(ArrayCatUDF); #[allow(clippy::unwrap_used)] mod tests { use super::*; + use crate::variant::array_construct::ArrayConstructUDF; use datafusion::assert_batches_eq; use datafusion::prelude::SessionContext; use datafusion_expr::ScalarUDF; - use crate::variant::array_construct::ArrayConstructUDF; #[tokio::test] async fn test_array_cat() -> DFResult<()> { @@ -197,7 +197,6 @@ mod tests { ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new())); ctx.register_udf(ScalarUDF::from(ArrayCatUDF::new())); - // Test concatenating two arrays let sql = "SELECT array_cat(array_construct(1, 2), array_construct(3, 4)) as concatenated"; let result = ctx.sql(sql).await?.collect().await?; diff --git a/crates/df-builtins/src/variant/array_compact.rs b/crates/df-builtins/src/variant/array_compact.rs index 7b75e49b..33508f4b 100644 --- a/crates/df-builtins/src/variant/array_compact.rs +++ b/crates/df-builtins/src/variant/array_compact.rs @@ -1,12 +1,12 @@ use super::super::macros::make_udf_function; -use datafusion::arrow::datatypes::DataType; -use datafusion::arrow::array::cast::AsArray; use datafusion::arrow::array::Array; +use datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::datatypes::DataType; use datafusion_common::{Result as DFResult, ScalarValue}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; -use serde_json::{to_string, Value}; +use serde_json::{Value, to_string}; use std::sync::Arc; #[derive(Debug, Clone)] @@ -126,10 +126,10 @@ make_udf_function!(ArrayCompactUDF); #[allow(clippy::unwrap_used)] mod tests { use super::*; + use crate::variant::array_construct::ArrayConstructUDF; use datafusion::assert_batches_eq; use datafusion::prelude::SessionContext; use datafusion_expr::ScalarUDF; - use crate::variant::array_construct::ArrayConstructUDF; #[tokio::test] async fn test_array_compact() -> DFResult<()> { diff --git a/crates/df-builtins/src/variant/array_construct.rs b/crates/df-builtins/src/variant/array_construct.rs index 206e0584..1fbea39b 100644 --- a/crates/df-builtins/src/variant/array_construct.rs +++ b/crates/df-builtins/src/variant/array_construct.rs @@ -112,10 +112,10 @@ make_udf_function!(ArrayConstructUDF); #[allow(clippy::unwrap_used)] mod tests { use super::*; + use crate::variant::array_cat::ArrayCatUDF; use datafusion::assert_batches_eq; use datafusion::prelude::SessionContext; use datafusion_expr::ScalarUDF; - use crate::variant::array_cat::ArrayCatUDF; #[tokio::test] async fn test_array_construct() -> DFResult<()> { diff --git a/crates/df-builtins/src/variant/array_contains.rs b/crates/df-builtins/src/variant/array_contains.rs index 7319aa9d..c3a1e9af 100644 --- a/crates/df-builtins/src/variant/array_contains.rs +++ b/crates/df-builtins/src/variant/array_contains.rs @@ -1,12 +1,12 @@ use super::super::macros::make_udf_function; use super::json::{encode_array, encode_scalar}; -use datafusion::arrow::datatypes::DataType; use 
datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::datatypes::DataType; use datafusion_common::{Result as DFResult, ScalarValue}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; -use serde_json::{from_slice, Value}; +use serde_json::{Value, from_slice}; use std::sync::Arc; #[derive(Debug, Clone)] @@ -140,10 +140,10 @@ make_udf_function!(ArrayContainsUDF); #[allow(clippy::unwrap_used)] mod tests { use super::*; + use crate::variant::array_construct::ArrayConstructUDF; use datafusion::assert_batches_eq; use datafusion::prelude::SessionContext; use datafusion_expr::ScalarUDF; - use crate::variant::array_construct::ArrayConstructUDF; #[tokio::test] async fn test_array_contains() -> DFResult<()> { diff --git a/crates/df-builtins/src/variant/array_distinct.rs b/crates/df-builtins/src/variant/array_distinct.rs index 3feb4177..fe568d14 100644 --- a/crates/df-builtins/src/variant/array_distinct.rs +++ b/crates/df-builtins/src/variant/array_distinct.rs @@ -1,14 +1,14 @@ use super::super::macros::make_udf_function; -use datafusion::arrow::datatypes::DataType; -use datafusion::arrow::array::cast::AsArray; use datafusion::arrow::array::Array; +use datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::datatypes::DataType; use datafusion_common::types::{logical_binary, logical_string}; -use datafusion_common::{types::NativeType, Result as DFResult, ScalarValue}; +use datafusion_common::{Result as DFResult, ScalarValue, types::NativeType}; use datafusion_expr::{ Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, }; -use serde_json::{from_slice, to_string, Value}; +use serde_json::{Value, from_slice, to_string}; use std::sync::Arc; #[derive(Debug, Clone)] @@ -130,10 +130,10 @@ make_udf_function!(ArrayDistinctUDF); #[allow(clippy::unwrap_used)] mod tests { use super::*; + use crate::variant::array_construct::ArrayConstructUDF; use datafusion::assert_batches_eq; use datafusion::prelude::SessionContext; use datafusion_expr::ScalarUDF; - use crate::variant::array_construct::ArrayConstructUDF; #[tokio::test] async fn test_array_distinct() -> DFResult<()> { diff --git a/crates/df-builtins/src/variant/array_except.rs b/crates/df-builtins/src/variant/array_except.rs index 997c99b2..08db7e16 100644 --- a/crates/df-builtins/src/variant/array_except.rs +++ b/crates/df-builtins/src/variant/array_except.rs @@ -1,11 +1,11 @@ use super::super::macros::make_udf_function; -use datafusion::arrow::datatypes::DataType; use datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::datatypes::DataType; use datafusion_common::{Result as DFResult, ScalarValue}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; -use serde_json::{from_slice, Value}; +use serde_json::{Value, from_slice}; use std::sync::Arc; #[derive(Debug, Clone)] @@ -168,10 +168,10 @@ make_udf_function!(ArrayExceptUDF); #[allow(clippy::unwrap_used)] mod tests { use super::*; + use crate::variant::array_construct::ArrayConstructUDF; use datafusion::assert_batches_eq; use datafusion::prelude::SessionContext; use datafusion_expr::ScalarUDF; - use crate::variant::array_construct::ArrayConstructUDF; #[tokio::test] async fn test_array_except() -> DFResult<()> { diff --git a/crates/df-builtins/src/variant/array_flatten.rs b/crates/df-builtins/src/variant/array_flatten.rs index 6da74286..55069122 100644 --- 
a/crates/df-builtins/src/variant/array_flatten.rs +++ b/crates/df-builtins/src/variant/array_flatten.rs @@ -1,3 +1,4 @@ +use crate::macros::make_udf_function; use datafusion::arrow::array::as_string_array; use datafusion::arrow::datatypes::DataType; use datafusion::error::Result as DFResult; @@ -8,7 +9,6 @@ use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl}; use serde_json::{Map, Value}; use std::any::Any; use std::sync::Arc; -use crate::macros::make_udf_function; // array_flatten SQL function // Transforms a nested ARRAY (an ARRAY of ARRAYs) into a single, flat ARRAY by combining all inner ARRAYs into one continuous sequence. diff --git a/crates/df-builtins/src/variant/array_insert.rs b/crates/df-builtins/src/variant/array_insert.rs index daf749b2..96800c23 100644 --- a/crates/df-builtins/src/variant/array_insert.rs +++ b/crates/df-builtins/src/variant/array_insert.rs @@ -1,12 +1,12 @@ use super::super::macros::make_udf_function; -use datafusion::arrow::datatypes::DataType; -use datafusion::arrow::array::cast::AsArray; use datafusion::arrow::array::Array; +use datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::datatypes::DataType; use datafusion_common::{Result as DFResult, ScalarValue}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; -use serde_json::{to_string, Value}; +use serde_json::{Value, to_string}; use std::sync::Arc; #[derive(Debug, Clone)] @@ -175,10 +175,10 @@ make_udf_function!(ArrayInsertUDF); #[allow(clippy::unwrap_used)] mod tests { use super::*; + use crate::variant::array_construct::ArrayConstructUDF; use datafusion::assert_batches_eq; use datafusion::prelude::SessionContext; use datafusion_expr::ScalarUDF; - use crate::variant::array_construct::ArrayConstructUDF; #[tokio::test] async fn test_array_insert() -> DFResult<()> { diff --git a/crates/df-builtins/src/variant/array_intersection.rs b/crates/df-builtins/src/variant/array_intersection.rs index 62b6cd49..9836d9fa 100644 --- a/crates/df-builtins/src/variant/array_intersection.rs +++ b/crates/df-builtins/src/variant/array_intersection.rs @@ -1,11 +1,11 @@ use super::super::macros::make_udf_function; -use datafusion::arrow::datatypes::DataType; use datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::datatypes::DataType; use datafusion_common::{Result as DFResult, ScalarValue}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; -use serde_json::{from_slice, Value}; +use serde_json::{Value, from_slice}; use std::sync::Arc; #[derive(Debug, Clone)] @@ -168,10 +168,10 @@ make_udf_function!(ArrayIntersectionUDF); #[allow(clippy::unwrap_used)] mod tests { use super::*; + use crate::variant::array_construct::ArrayConstructUDF; use datafusion::assert_batches_eq; use datafusion::prelude::SessionContext; use datafusion_expr::ScalarUDF; - use crate::variant::array_construct::ArrayConstructUDF; #[tokio::test] async fn test_array_intersection() -> DFResult<()> { diff --git a/crates/df-builtins/src/variant/array_max.rs b/crates/df-builtins/src/variant/array_max.rs index b9fa96dc..d72b8910 100644 --- a/crates/df-builtins/src/variant/array_max.rs +++ b/crates/df-builtins/src/variant/array_max.rs @@ -1,14 +1,14 @@ use super::super::macros::make_udf_function; -use datafusion::arrow::datatypes::DataType; -use datafusion::arrow::array::cast::AsArray; use datafusion::arrow::array::Array; -use datafusion_common::types::{logical_binary, logical_string, NativeType}; +use 
datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::datatypes::DataType; +use datafusion_common::types::{NativeType, logical_binary, logical_string}; use datafusion_common::{Result as DFResult, ScalarValue}; use datafusion_expr::{ Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, }; -use serde_json::{from_slice, Value}; +use serde_json::{Value, from_slice}; use std::sync::Arc; #[derive(Debug, Clone)] @@ -171,10 +171,10 @@ make_udf_function!(ArrayMaxUDF); #[allow(clippy::unwrap_used)] mod tests { use super::*; + use crate::variant::array_construct::ArrayConstructUDF; use datafusion::assert_batches_eq; use datafusion::prelude::SessionContext; use datafusion_expr::ScalarUDF; - use crate::variant::array_construct::ArrayConstructUDF; #[tokio::test] async fn test_array_max() -> DFResult<()> { diff --git a/crates/df-builtins/src/variant/array_min.rs b/crates/df-builtins/src/variant/array_min.rs index bc9fb2bf..a9939473 100644 --- a/crates/df-builtins/src/variant/array_min.rs +++ b/crates/df-builtins/src/variant/array_min.rs @@ -1,14 +1,14 @@ use super::super::macros::make_udf_function; -use datafusion::arrow::datatypes::DataType; -use datafusion::arrow::array::cast::AsArray; use datafusion::arrow::array::Array; -use datafusion_common::types::{logical_binary, logical_string, NativeType}; +use datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::datatypes::DataType; +use datafusion_common::types::{NativeType, logical_binary, logical_string}; use datafusion_common::{Result as DFResult, ScalarValue}; use datafusion_expr::{ Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, }; -use serde_json::{from_slice, Value}; +use serde_json::{Value, from_slice}; use std::sync::Arc; #[derive(Debug, Clone)] @@ -171,10 +171,10 @@ make_udf_function!(ArrayMinUDF); #[allow(clippy::unwrap_used)] mod tests { use super::*; + use crate::variant::array_construct::ArrayConstructUDF; use datafusion::assert_batches_eq; use datafusion::prelude::SessionContext; use datafusion_expr::ScalarUDF; - use crate::variant::array_construct::ArrayConstructUDF; #[tokio::test] async fn test_array_min() -> DFResult<()> { diff --git a/crates/df-builtins/src/variant/array_position.rs b/crates/df-builtins/src/variant/array_position.rs index d2e6b824..9131d849 100644 --- a/crates/df-builtins/src/variant/array_position.rs +++ b/crates/df-builtins/src/variant/array_position.rs @@ -1,13 +1,13 @@ use super::super::macros::make_udf_function; use super::json::{encode_array, encode_scalar}; -use datafusion::arrow::datatypes::DataType; -use datafusion::arrow::array::cast::AsArray; use datafusion::arrow::array::Array; +use datafusion::arrow::array::cast::AsArray; +use datafusion::arrow::datatypes::DataType; use datafusion_common::{Result as DFResult, ScalarValue}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; -use serde_json::{from_slice, Value}; +use serde_json::{Value, from_slice}; use std::sync::Arc; #[derive(Debug, Clone)] @@ -138,10 +138,10 @@ make_udf_function!(ArrayPositionUDF); #[allow(clippy::unwrap_used)] mod tests { use super::*; + use crate::variant::array_construct::ArrayConstructUDF; use datafusion::assert_batches_eq; use datafusion::prelude::SessionContext; use datafusion_expr::ScalarUDF; - use crate::variant::array_construct::ArrayConstructUDF; #[tokio::test] async fn test_array_position() -> DFResult<()> { diff 
diff --git a/crates/df-builtins/src/variant/array_prepend.rs b/crates/df-builtins/src/variant/array_prepend.rs
index 907d1aff..cef4b73c 100644
--- a/crates/df-builtins/src/variant/array_prepend.rs
+++ b/crates/df-builtins/src/variant/array_prepend.rs
@@ -1,12 +1,12 @@
 use super::super::macros::make_udf_function;
-use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::array::cast::AsArray;
 use datafusion::arrow::array::Array;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::datatypes::DataType;
 use datafusion_common::{Result as DFResult, ScalarValue};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
 };
-use serde_json::{from_slice, to_string, Value};
+use serde_json::{Value, from_slice, to_string};
 use std::sync::Arc;
 
 use super::json::encode_scalar;
@@ -129,10 +129,10 @@ make_udf_function!(ArrayPrependUDF);
 #[allow(clippy::unwrap_used)]
 mod tests {
     use super::*;
+    use crate::variant::array_construct::ArrayConstructUDF;
     use datafusion::assert_batches_eq;
     use datafusion::prelude::SessionContext;
     use datafusion_expr::ScalarUDF;
-    use crate::variant::array_construct::ArrayConstructUDF;
 
     #[tokio::test]
     async fn test_array_prepend() -> DFResult<()> {
diff --git a/crates/df-builtins/src/variant/array_remove.rs b/crates/df-builtins/src/variant/array_remove.rs
index 8bac5400..6d3fa80a 100644
--- a/crates/df-builtins/src/variant/array_remove.rs
+++ b/crates/df-builtins/src/variant/array_remove.rs
@@ -1,12 +1,12 @@
 use super::super::macros::make_udf_function;
-use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::array::cast::AsArray;
 use datafusion::arrow::array::Array;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::datatypes::DataType;
 use datafusion_common::{Result as DFResult, ScalarValue};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
 };
-use serde_json::{from_str, to_string, Value};
+use serde_json::{Value, from_str, to_string};
 use std::sync::Arc;
 
 #[derive(Debug, Clone)]
@@ -180,10 +180,10 @@ make_udf_function!(ArrayRemoveUDF);
 #[allow(clippy::unwrap_used)]
 mod tests {
     use super::*;
+    use crate::variant::array_construct::ArrayConstructUDF;
     use datafusion::assert_batches_eq;
     use datafusion::prelude::SessionContext;
     use datafusion_expr::ScalarUDF;
-    use crate::variant::array_construct::ArrayConstructUDF;
 
     #[tokio::test]
     async fn test_array_remove() -> DFResult<()> {
diff --git a/crates/df-builtins/src/variant/array_remove_at.rs b/crates/df-builtins/src/variant/array_remove_at.rs
index 77caaa96..c7e5d2e2 100644
--- a/crates/df-builtins/src/variant/array_remove_at.rs
+++ b/crates/df-builtins/src/variant/array_remove_at.rs
@@ -1,12 +1,12 @@
 use super::super::macros::make_udf_function;
-use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::array::cast::AsArray;
 use datafusion::arrow::array::Array;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::datatypes::DataType;
 use datafusion_common::{Result as DFResult, ScalarValue};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
 };
-use serde_json::{from_str, to_string, Value};
+use serde_json::{Value, from_str, to_string};
 use std::sync::Arc;
 
 #[derive(Debug, Clone)]
@@ -184,10 +184,10 @@ make_udf_function!(ArrayRemoveAtUDF);
 #[allow(clippy::unwrap_used)]
 mod tests {
     use super::*;
+    use crate::variant::array_construct::ArrayConstructUDF;
     use datafusion::assert_batches_eq;
     use datafusion::prelude::SessionContext;
     use datafusion_expr::ScalarUDF;
-    use crate::variant::array_construct::ArrayConstructUDF;
 
     #[tokio::test]
     async fn test_array_remove_at() -> DFResult<()> {
diff --git a/crates/df-builtins/src/variant/array_reverse.rs b/crates/df-builtins/src/variant/array_reverse.rs
index 2e95767c..b2a7e48c 100644
--- a/crates/df-builtins/src/variant/array_reverse.rs
+++ b/crates/df-builtins/src/variant/array_reverse.rs
@@ -1,12 +1,12 @@
 use super::super::macros::make_udf_function;
-use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::array::cast::AsArray;
 use datafusion::arrow::array::Array;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::datatypes::DataType;
 use datafusion_common::{Result as DFResult, ScalarValue};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
 };
-use serde_json::{from_str, to_string, Value};
+use serde_json::{Value, from_str, to_string};
 use std::sync::Arc;
 
 #[derive(Debug, Clone)]
@@ -132,10 +132,10 @@ make_udf_function!(ArrayReverseUDF);
 #[allow(clippy::unwrap_used)]
 mod tests {
     use super::*;
+    use crate::variant::array_construct::ArrayConstructUDF;
     use datafusion::assert_batches_eq;
     use datafusion::prelude::SessionContext;
     use datafusion_expr::ScalarUDF;
-    use crate::variant::array_construct::ArrayConstructUDF;
 
     #[tokio::test]
     async fn test_array_reverse() -> DFResult<()> {
diff --git a/crates/df-builtins/src/variant/array_size.rs b/crates/df-builtins/src/variant/array_size.rs
index a76df889..c825fb07 100644
--- a/crates/df-builtins/src/variant/array_size.rs
+++ b/crates/df-builtins/src/variant/array_size.rs
@@ -1,12 +1,12 @@
 use super::super::macros::make_udf_function;
-use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::array::cast::AsArray;
 use datafusion::arrow::array::Array;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::datatypes::DataType;
 use datafusion_common::{Result as DFResult, ScalarValue};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
 };
-use serde_json::{from_str, Value};
+use serde_json::{Value, from_str};
 use std::sync::Arc;
 
 #[derive(Debug, Clone)]
@@ -117,10 +117,10 @@ make_udf_function!(ArraySizeUDF);
 #[allow(clippy::unwrap_used)]
 mod tests {
     use super::*;
+    use crate::variant::array_construct::ArrayConstructUDF;
     use datafusion::assert_batches_eq;
     use datafusion::prelude::SessionContext;
     use datafusion_expr::ScalarUDF;
-    use crate::variant::array_construct::ArrayConstructUDF;
 
     #[tokio::test]
     async fn test_array_size() -> DFResult<()> {
diff --git a/crates/df-builtins/src/variant/array_slice.rs b/crates/df-builtins/src/variant/array_slice.rs
index b79adf80..7c320076 100644
--- a/crates/df-builtins/src/variant/array_slice.rs
+++ b/crates/df-builtins/src/variant/array_slice.rs
@@ -1,12 +1,12 @@
 use super::super::macros::make_udf_function;
-use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::array::cast::AsArray;
 use datafusion::arrow::array::Array;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::datatypes::DataType;
 use datafusion_common::{Result as DFResult, ScalarValue};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
 };
-use serde_json::{from_str, to_string, Value};
+use serde_json::{Value, from_str, to_string};
 use std::sync::Arc;
 
 #[derive(Debug, Clone)]
@@ -197,10 +197,10 @@ make_udf_function!(ArraySliceUDF);
 #[allow(clippy::unwrap_used)]
 mod tests {
     use super::*;
+    use crate::variant::array_construct::ArrayConstructUDF;
     use datafusion::assert_batches_eq;
     use datafusion::prelude::SessionContext;
     use datafusion_expr::ScalarUDF;
-    use crate::variant::array_construct::ArrayConstructUDF;
 
     #[tokio::test]
     async fn test_array_slice() -> DFResult<()> {
diff --git a/crates/df-builtins/src/variant/array_sort.rs b/crates/df-builtins/src/variant/array_sort.rs
index 0e9f3b90..3d58dc33 100644
--- a/crates/df-builtins/src/variant/array_sort.rs
+++ b/crates/df-builtins/src/variant/array_sort.rs
@@ -1,12 +1,12 @@
 use super::super::macros::make_udf_function;
-use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::array::cast::AsArray;
 use datafusion::arrow::array::Array;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::datatypes::DataType;
 use datafusion_common::{Result as DFResult, ScalarValue};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
 };
-use serde_json::{from_str, to_string, Value};
+use serde_json::{Value, from_str, to_string};
 use std::sync::Arc;
 
 #[derive(Debug, Clone)]
@@ -89,11 +89,7 @@ impl ArraySortUDF {
             }
             (Some(a_val), Some(b_val)) => {
                 let cmp = Self::compare_json_values(a_val, b_val);
-                if sort_ascending {
-                    cmp
-                } else {
-                    cmp.reverse()
-                }
+                if sort_ascending { cmp } else { cmp.reverse() }
             }
         }
     });
@@ -212,10 +212,10 @@ make_udf_function!(ArraySortUDF);
 #[allow(clippy::unwrap_used)]
 mod tests {
     use super::*;
+    use crate::variant::array_construct::ArrayConstructUDF;
     use datafusion::assert_batches_eq;
     use datafusion::prelude::SessionContext;
     use datafusion_expr::ScalarUDF;
-    use crate::variant::array_construct::ArrayConstructUDF;
 
     #[tokio::test]
     async fn test_array_sort() -> DFResult<()> {
diff --git a/crates/df-builtins/src/variant/array_to_string.rs b/crates/df-builtins/src/variant/array_to_string.rs
index 62cc9c20..92207309 100644
--- a/crates/df-builtins/src/variant/array_to_string.rs
+++ b/crates/df-builtins/src/variant/array_to_string.rs
@@ -1,3 +1,4 @@
+use crate::macros::make_udf_function;
 use datafusion::arrow::array::as_string_array;
 use datafusion::arrow::datatypes::DataType;
 use datafusion::error::Result as DFResult;
@@ -8,7 +9,6 @@
 use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl};
 use serde_json::Value;
 use std::any::Any;
 use std::sync::Arc;
-use crate::macros::make_udf_function;
 
 // array_to_string SQL function
 // Converts the input array to a string by first casting each element to a string,
diff --git a/crates/df-builtins/src/variant/arrays_overlap.rs b/crates/df-builtins/src/variant/arrays_overlap.rs
index 51068a35..604dfe4f 100644
--- a/crates/df-builtins/src/variant/arrays_overlap.rs
+++ b/crates/df-builtins/src/variant/arrays_overlap.rs
@@ -1,12 +1,12 @@
 use super::super::macros::make_udf_function;
-use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::array::cast::AsArray;
 use datafusion::arrow::array::Array;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::datatypes::DataType;
 use datafusion_common::{Result as DFResult, ScalarValue};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
 };
-use serde_json::{from_str, Value};
+use serde_json::{Value, from_str};
 use std::sync::Arc;
 
 #[derive(Debug, Clone)]
@@ -163,10 +163,10 @@ make_udf_function!(ArraysOverlapUDF);
 #[allow(clippy::unwrap_used)]
 mod tests {
     use super::*;
+    use crate::variant::array_construct::ArrayConstructUDF;
     use datafusion::assert_batches_eq;
     use datafusion::prelude::SessionContext;
     use datafusion_expr::ScalarUDF;
-    use crate::variant::array_construct::ArrayConstructUDF;
 
     #[tokio::test]
     async fn test_arrays_overlap() -> DFResult<()> {
diff --git a/crates/df-builtins/src/variant/arrays_to_object.rs b/crates/df-builtins/src/variant/arrays_to_object.rs
index c18442ce..880328a0 100644
--- a/crates/df-builtins/src/variant/arrays_to_object.rs
+++ b/crates/df-builtins/src/variant/arrays_to_object.rs
@@ -1,12 +1,12 @@
 use super::super::macros::make_udf_function;
-use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::array::cast::AsArray;
 use datafusion::arrow::array::Array;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::datatypes::DataType;
 use datafusion_common::{Result as DFResult, ScalarValue};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
 };
-use serde_json::{json, Value};
+use serde_json::{Value, json};
 use std::sync::Arc;
 
 #[derive(Debug, Clone)]
@@ -181,10 +181,10 @@ make_udf_function!(ArraysToObjectUDF);
 #[allow(clippy::unwrap_used)]
 mod tests {
     use super::*;
+    use crate::variant::array_construct::ArrayConstructUDF;
     use datafusion::assert_batches_eq;
     use datafusion::prelude::SessionContext;
     use datafusion_expr::ScalarUDF;
-    use crate::variant::array_construct::ArrayConstructUDF;
 
     #[tokio::test]
     async fn test_arrays_to_object() -> DFResult<()> {
diff --git a/crates/df-builtins/src/variant/arrays_zip.rs b/crates/df-builtins/src/variant/arrays_zip.rs
index 97f4cc2e..618607b1 100644
--- a/crates/df-builtins/src/variant/arrays_zip.rs
+++ b/crates/df-builtins/src/variant/arrays_zip.rs
@@ -1,7 +1,7 @@
 use super::super::macros::make_udf_function;
-use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::array::cast::AsArray;
 use datafusion::arrow::array::Array;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::datatypes::DataType;
 use datafusion_common::{Result as DFResult, ScalarValue};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
@@ -183,10 +183,10 @@ make_udf_function!(ArraysZipUDF);
 #[allow(clippy::unwrap_used)]
 mod tests {
     use super::*;
+    use crate::variant::array_construct::ArrayConstructUDF;
     use datafusion::assert_batches_eq;
     use datafusion::prelude::SessionContext;
     use datafusion_expr::ScalarUDF;
-    use crate::variant::array_construct::ArrayConstructUDF;
 
     #[tokio::test]
     async fn test_arrays_zip() -> DFResult<()> {
diff --git a/crates/df-builtins/src/variant/json.rs b/crates/df-builtins/src/variant/json.rs
index 408f28ac..77f0c620 100644
--- a/crates/df-builtins/src/variant/json.rs
+++ b/crates/df-builtins/src/variant/json.rs
@@ -1,18 +1,20 @@
 #![allow(clippy::needless_pass_by_value, clippy::unnecessary_wraps)]
 
+use base64::engine::Engine;
 use datafusion::arrow::array::AsArray;
-use datafusion::arrow::array::{Array, ArrayRef, BooleanArray, NullArray, PrimitiveArray, StringArray};
+use datafusion::arrow::array::{
+    Array, ArrayRef, BooleanArray, NullArray, PrimitiveArray, StringArray,
+};
 use datafusion::arrow::datatypes::{
     DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType,
     DurationMillisecondType, DurationNanosecondType, DurationSecondType, Float16Type, Float32Type,
-    Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType,
+    Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, IntervalDayTimeType,
     IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Time32MillisecondType,
     Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit,
     TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
-    TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+    TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
 };
 use datafusion::arrow::error::ArrowError;
-use base64::engine::Engine;
 use datafusion_common::ScalarValue;
 use datafusion_expr::ColumnarValue;
 use serde_json::{Map, Number, Value as JsonValue};
@@ -902,7 +904,7 @@ pub fn encode_columnar_value(value: &ColumnarValue) -> Result
diff --git a/crates/df-builtins/src/variant/mod.rs b/crates/df-builtins/src/variant/mod.rs
--- a/crates/df-builtins/src/variant/mod.rs
+++ b/crates/df-builtins/src/variant/mod.rs
@@ -64,12 +64,12 @@ pub fn register_udfs(registry: &mut dyn FunctionRegistry) -> Result<()> {
     let functions: Vec<Arc<ScalarUDF>> = vec![
         object_delete::get_udf(),
         object_insert::get_udf(),
         object_pick::get_udf(),
-        object_construct::get_udf()
+        object_construct::get_udf(),
     ];
     for func in functions {
         registry.register_udf(func)?;
     }
-    
+
     Ok(())
 }
diff --git a/crates/df-builtins/src/variant/object_delete.rs b/crates/df-builtins/src/variant/object_delete.rs
index 5e33615b..20af0c39 100644
--- a/crates/df-builtins/src/variant/object_delete.rs
+++ b/crates/df-builtins/src/variant/object_delete.rs
@@ -1,12 +1,12 @@
 use super::super::macros::make_udf_function;
-use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::array::cast::AsArray;
 use datafusion::arrow::array::Array;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::datatypes::DataType;
 use datafusion_common::{Result as DFResult, ScalarValue};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
 };
-use serde_json::{from_str, to_string, Value};
+use serde_json::{Value, from_str, to_string};
 use std::sync::Arc;
 
 #[derive(Debug, Clone)]
diff --git a/crates/df-builtins/src/variant/object_insert.rs b/crates/df-builtins/src/variant/object_insert.rs
index fbce549e..26ff4dc0 100644
--- a/crates/df-builtins/src/variant/object_insert.rs
+++ b/crates/df-builtins/src/variant/object_insert.rs
@@ -1,12 +1,12 @@
 use super::super::macros::make_udf_function;
-use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::array::cast::AsArray;
 use datafusion::arrow::array::Array;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::datatypes::DataType;
 use datafusion_common::{Result as DFResult, ScalarValue};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
 };
-use serde_json::{from_str, to_string, Value};
+use serde_json::{Value, from_str, to_string};
 use std::sync::Arc;
 
 #[derive(Debug, Clone)]
@@ -134,7 +134,7 @@ impl ScalarUDFImpl for ObjectInsertUDF {
             None => {
                 return Err(datafusion_common::error::DataFusionError::Internal(
                     "Expected array for scalar value".to_string(),
-                ))
+                ));
             }
         }
     } else {
@@ -146,7 +146,7 @@ impl ScalarUDFImpl for ObjectInsertUDF {
             ColumnarValue::Array(_) => {
                 return Err(datafusion_common::error::DataFusionError::Internal(
                     "Key argument must be a scalar value".to_string(),
-                ))
+                ));
             }
         };
@@ -162,7 +162,7 @@ impl ScalarUDFImpl for ObjectInsertUDF {
             None => {
                 return Err(datafusion_common::error::DataFusionError::Internal(
                     "Expected array for scalar value".to_string(),
-                ))
+                ));
             }
         }
     } else {
@@ -174,7 +174,7 @@ impl ScalarUDFImpl for ObjectInsertUDF {
             ColumnarValue::Array(_) => {
                 return Err(datafusion_common::error::DataFusionError::Internal(
                     "Value argument must be a scalar value".to_string(),
-                ))
+                ));
             }
         };
diff --git a/crates/df-builtins/src/variant/object_pick.rs b/crates/df-builtins/src/variant/object_pick.rs
index 9bbf04d7..f08ac769 100644
--- a/crates/df-builtins/src/variant/object_pick.rs
+++ b/crates/df-builtins/src/variant/object_pick.rs
@@ -1,12 +1,12 @@
 use super::super::macros::make_udf_function;
-use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::array::cast::AsArray;
 use datafusion::arrow::array::Array;
+use datafusion::arrow::array::cast::AsArray;
+use datafusion::arrow::datatypes::DataType;
 use datafusion_common::{Result as DFResult, ScalarValue};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
 };
-use serde_json::{from_str, to_string, Value};
+use serde_json::{Value, from_str, to_string};
 use std::sync::Arc;
 
 #[derive(Debug, Clone)]
diff --git a/crates/df-builtins/src/variant/variant_element.rs b/crates/df-builtins/src/variant/variant_element.rs
index 33573fe6..14c335fb 100644
--- a/crates/df-builtins/src/variant/variant_element.rs
+++ b/crates/df-builtins/src/variant/variant_element.rs
@@ -205,11 +205,11 @@ make_udf_function!(VariantArrayElementUDF);
 mod tests {
     use super::super::array_construct::ArrayConstructUDF;
     use super::*;
-    use crate::visitors::variant::variant_element;
     use datafusion::assert_batches_eq;
     use datafusion::prelude::SessionContext;
     use datafusion::sql::parser::Statement;
     use datafusion_expr::ScalarUDF;
+    use visitors::variant::variant_element;
 
     #[tokio::test]
     async fn test_array_indexing() -> DFResult<()> {
         let ctx = SessionContext::new();
diff --git a/crates/df-builtins/src/visitors/mod.rs b/crates/df-builtins/src/visitors/mod.rs
deleted file mode 100644
index cb1eb8cd..00000000
--- a/crates/df-builtins/src/visitors/mod.rs
+++ /dev/null
@@ -1 +0,0 @@
-pub mod variant;
\ No newline at end of file
diff --git a/crates/visitors/Cargo.toml b/crates/visitors/Cargo.toml
new file mode 100644
index 00000000..31842018
--- /dev/null
+++ b/crates/visitors/Cargo.toml
@@ -0,0 +1,15 @@
+[package]
+name = "visitors"
+version = "0.1.0"
+edition = "2024"
+license-file.workspace = true
+
+[dependencies]
+datafusion = { workspace = true }
+datafusion-expr = { workspace = true }
+sqlparser = { git = "https://github.com/Embucket/datafusion-sqlparser-rs.git", rev = "ed416548dcfe4a73a3240bbf625fb9010a4925c8", features = [
+    "visitor",
+] }
+
+[lints]
+workspace = true
diff --git a/crates/core-executor/src/datafusion/visitors/copy_into_identifiers.rs b/crates/visitors/src/copy_into_identifiers.rs
similarity index 100%
rename from crates/core-executor/src/datafusion/visitors/copy_into_identifiers.rs
rename to crates/visitors/src/copy_into_identifiers.rs
diff --git a/crates/core-executor/src/datafusion/visitors/functions_rewriter.rs b/crates/visitors/src/functions_rewriter.rs
similarity index 100%
rename from crates/core-executor/src/datafusion/visitors/functions_rewriter.rs
rename to crates/visitors/src/functions_rewriter.rs
diff --git a/crates/core-executor/src/datafusion/visitors/json_element.rs b/crates/visitors/src/json_element.rs
similarity index 100%
rename from crates/core-executor/src/datafusion/visitors/json_element.rs
rename to crates/visitors/src/json_element.rs
diff --git a/crates/core-executor/src/datafusion/visitors/mod.rs b/crates/visitors/src/lib.rs
similarity index 68%
rename from crates/core-executor/src/datafusion/visitors/mod.rs
rename to crates/visitors/src/lib.rs
index d737a680..e2879460 100644
--- a/crates/core-executor/src/datafusion/visitors/mod.rs
+++ b/crates/visitors/src/lib.rs
@@ -1,5 +1,4 @@
-//pub mod analyzer;
-//pub mod error;
 pub mod copy_into_identifiers;
 pub mod functions_rewriter;
 pub mod json_element;
+pub mod variant;
diff --git a/crates/df-builtins/src/visitors/variant/array_agg.rs b/crates/visitors/src/variant/array_agg.rs
similarity index 99%
rename from crates/df-builtins/src/visitors/variant/array_agg.rs
rename to crates/visitors/src/variant/array_agg.rs
index 8e50a004..b862071e 100644
--- a/crates/df-builtins/src/visitors/variant/array_agg.rs
+++ b/crates/visitors/src/variant/array_agg.rs
@@ -69,14 +69,14 @@ pub fn visit(stmt: &mut Statement) {
 #[allow(clippy::unwrap_used)]
 mod tests {
     use super::*;
+    use crate::variant::array_compact::ArrayCompactUDF;
+    use crate::variant::array_construct::ArrayConstructUDF;
+    use crate::visitors::variant::variant_element;
     use datafusion::assert_batches_eq;
     use datafusion::prelude::SessionContext;
     use datafusion::sql::parser::Statement as DFStatement;
     use datafusion_common::Result as DFResult;
     use datafusion_expr::ScalarUDF;
-    use crate::variant::array_compact::ArrayCompactUDF;
-    use crate::variant::array_construct::ArrayConstructUDF;
-    use crate::visitors::variant::variant_element;
 
     #[tokio::test]
     async fn test_array_agg_rewrite() -> DFResult<()> {
@@ -85,7 +85,7 @@ mod tests {
         ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
         ctx.register_udf(ScalarUDF::from(ArrayCompactUDF::new()));
-        
+
         // Create table and insert data
         let create_sql = "CREATE TABLE test_table (id INT, val INT)";
         let mut create_stmt = ctx.state().sql_to_statement(create_sql, "snowflake")?;
diff --git a/crates/df-builtins/src/visitors/variant/array_construct.rs b/crates/visitors/src/variant/array_construct.rs
similarity index 99%
rename from crates/df-builtins/src/visitors/variant/array_construct.rs
rename to crates/visitors/src/variant/array_construct.rs
index fc5722cb..bb7a6421 100644
--- a/crates/df-builtins/src/visitors/variant/array_construct.rs
+++ b/crates/visitors/src/variant/array_construct.rs
@@ -54,12 +54,12 @@ pub fn visit(stmt: &mut Statement) {
 #[allow(clippy::unwrap_used)]
 mod tests {
     use super::*;
+    use crate::variant::array_construct::ArrayConstructUDF;
     use datafusion::assert_batches_eq;
     use datafusion::prelude::SessionContext;
     use datafusion::sql::parser::Statement as DFStatement;
     use datafusion_common::Result as DFResult;
     use datafusion_expr::ScalarUDF;
-    use crate::variant::array_construct::ArrayConstructUDF;
 
     #[tokio::test]
     async fn test_array_construct_rewrite() -> DFResult<()> {
@@ -67,7 +67,6 @@ mod tests {
         // Register array_construct UDF
         ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
-
         // Test simple array construction
         let sql = "SELECT [1, 2, 3] as arr";
         let mut stmt = ctx.state().sql_to_statement(sql, "snowflake")?;
diff --git a/crates/df-builtins/src/visitors/variant/array_construct_compact.rs b/crates/visitors/src/variant/array_construct_compact.rs
similarity index 100%
rename from crates/df-builtins/src/visitors/variant/array_construct_compact.rs
rename to crates/visitors/src/variant/array_construct_compact.rs
index 89cb72a7..537b7acf 100644
--- a/crates/df-builtins/src/visitors/variant/array_construct_compact.rs
+++ b/crates/visitors/src/variant/array_construct_compact.rs
@@ -72,13 +72,13 @@ pub fn visit(stmt: &mut Statement) {
 #[allow(clippy::unwrap_used)]
 mod tests {
     use super::*;
+    use crate::variant::array_compact::ArrayCompactUDF;
+    use crate::variant::array_construct::ArrayConstructUDF;
     use datafusion::assert_batches_eq;
     use datafusion::prelude::SessionContext;
     use datafusion::sql::parser::Statement as DFStatement;
     use datafusion_common::Result as DFResult;
     use datafusion_expr::ScalarUDF;
-    use crate::variant::array_compact::ArrayCompactUDF;
-    use crate::variant::array_construct::ArrayConstructUDF;
 
     #[tokio::test]
     async fn test_array_construct_compact_rewrite() -> DFResult<()> {
diff --git a/crates/df-builtins/src/visitors/variant/mod.rs b/crates/visitors/src/variant/mod.rs
similarity index 100%
rename from crates/df-builtins/src/visitors/variant/mod.rs
rename to crates/visitors/src/variant/mod.rs
diff --git a/crates/df-builtins/src/visitors/variant/type_rewrite.rs b/crates/visitors/src/variant/type_rewrite.rs
similarity index 100%
rename from crates/df-builtins/src/visitors/variant/type_rewrite.rs
rename to crates/visitors/src/variant/type_rewrite.rs
diff --git a/crates/df-builtins/src/visitors/variant/variant_element.rs b/crates/visitors/src/variant/variant_element.rs
similarity index 100%
rename from crates/df-builtins/src/visitors/variant/variant_element.rs
rename to crates/visitors/src/variant/variant_element.rs
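
Not part of the patch, for reviewers: a minimal sketch of the UDF registration path the in-diff tests exercise. It assumes `ArrayConstructUDF` is publicly reachable at `df_builtins::variant::array_construct` and that the UDF registers under the SQL name `array_construct`; neither is stated outside the test code above, so treat both as assumptions.

    // Illustrative sketch only -- not applied by this patch.
    use datafusion::prelude::SessionContext;
    use datafusion_expr::ScalarUDF;
    use df_builtins::variant::array_construct::ArrayConstructUDF; // assumed pub path

    #[tokio::main]
    async fn main() -> datafusion_common::Result<()> {
        let ctx = SessionContext::new();
        // Same registration call the tests in this patch use.
        ctx.register_udf(ScalarUDF::from(ArrayConstructUDF::new()));
        // Should print a single JSON-encoded array value, e.g. [1,2,3].
        ctx.sql("SELECT array_construct(1, 2, 3) AS arr")
            .await?
            .show()
            .await?;
        Ok(())
    }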
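
Likewise illustrative: how the statement-rewrite visitors, now split out into the standalone `visitors` crate, might be driven before planning. The `visit(stmt: &mut Statement)` signature and the `sql_to_statement(sql, "snowflake")` call are taken from the hunks above; the `DFStatement::Statement` unwrapping is an assumption about the caller side, not something this diff shows.

    // Illustrative sketch only -- not applied by this patch.
    use datafusion::prelude::SessionContext;
    use datafusion::sql::parser::Statement as DFStatement;
    use visitors::variant::array_construct;

    fn rewrite(ctx: &SessionContext) -> datafusion_common::Result<DFStatement> {
        // Parse with the Snowflake dialect, as the tests above do.
        let mut stmt = ctx
            .state()
            .sql_to_statement("SELECT [1, 2, 3] AS arr", "snowflake")?;
        // Rewrite `[...]` literals into array_construct(...) calls in place.
        if let DFStatement::Statement(ref mut s) = stmt {
            array_construct::visit(s);
        }
        Ok(stmt)
    }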