diff --git a/datafusion/functions-aggregate-common/src/aggregate.rs b/datafusion/functions-aggregate-common/src/aggregate.rs index c9cbaa8396fc..d0ae7ee6b19e 100644 --- a/datafusion/functions-aggregate-common/src/aggregate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate.rs @@ -17,3 +17,4 @@ pub mod count_distinct; pub mod groups_accumulator; +pub mod mode; diff --git a/datafusion/functions-aggregate-common/src/aggregate/mode.rs b/datafusion/functions-aggregate-common/src/aggregate/mode.rs new file mode 100644 index 000000000000..b5d7b5316135 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/mode.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod bytes; +mod native; + +pub use bytes::BytesModeAccumulator; +pub use bytes::BytesViewModeAccumulator; +pub use native::FloatModeAccumulator; +pub use native::PrimitiveModeAccumulator; diff --git a/datafusion/functions-aggregate-common/src/aggregate/mode/bytes.rs b/datafusion/functions-aggregate-common/src/aggregate/mode/bytes.rs new file mode 100644 index 000000000000..2f3b906360f3 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/mode/bytes.rs @@ -0,0 +1,412 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
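
The new `bytes.rs` below provides the string-typed accumulators: `BytesModeAccumulator` for `Utf8`/`LargeUtf8` (backed by `ArrowBytesSet`/`ArrowBytesMap`) and `BytesViewModeAccumulator` for `Utf8View` (backed by the view variants). A minimal sketch of driving the byte accumulator directly, mirroring the unit tests later in this file; the import paths follow the `pub use` re-exports in `mode.rs` above, and the wrapper function name is only for illustration:

    use std::sync::Arc;

    use arrow::array::{ArrayRef, StringArray};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr_common::accumulator::Accumulator;
    use datafusion_functions_aggregate_common::aggregate::mode::BytesModeAccumulator;
    use datafusion_physical_expr_common::binary_map::OutputType;

    fn string_mode_sketch() -> Result<()> {
        // i32 offsets correspond to Utf8; i64 would be LargeUtf8
        let mut acc = BytesModeAccumulator::<i32>::new(OutputType::Utf8);
        let values: ArrayRef = Arc::new(StringArray::from(vec![
            Some("apple"),
            Some("banana"),
            Some("apple"),
            Some("banana"),
            Some("apple"),
        ]));
        // each batch updates the per-value counts kept in an ArrowBytesMap
        acc.update_batch(&[values])?;
        // "apple" occurs three times, so it is reported as the mode
        assert_eq!(acc.evaluate()?, ScalarValue::Utf8(Some("apple".to_string())));
        Ok(())
    }
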
+ +use std::sync::Arc; + +use arrow::array::ArrayRef; +use arrow::array::AsArray; +use arrow::array::OffsetSizeTrait; +use arrow::datatypes::DataType; +use datafusion_common::cast::as_list_array; +use datafusion_common::cast::as_primitive_array; +use datafusion_common::utils::array_into_list_array_nullable; +use datafusion_common::Result; +use datafusion_common::ScalarValue; +use datafusion_expr_common::accumulator::Accumulator; +use datafusion_physical_expr_common::binary_map::ArrowBytesSet; +use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType}; +use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap; +use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewSet; + +#[derive(Debug)] +pub struct BytesModeAccumulator { + values: ArrowBytesSet, + value_counts: ArrowBytesMap, +} + +impl BytesModeAccumulator { + pub fn new(output_type: OutputType) -> Self { + Self { + values: ArrowBytesSet::new(output_type), + value_counts: ArrowBytesMap::new(output_type), + } + } +} + +impl Accumulator for BytesModeAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + self.values.insert(&values[0]); + + self.value_counts + .insert_or_update(&values[0], |_| 1i64, |count| *count += 1); + + Ok(()) + } + + fn state(&mut self) -> Result> { + let values = self.values.take().into_state(); + let payloads: Vec = self + .value_counts + .take() + .get_payloads(&values) + .into_iter() + .map(|count| match count { + Some(c) => ScalarValue::Int64(Some(c)), + None => ScalarValue::Int64(None), + }) + .collect(); + + let values_list = Arc::new(array_into_list_array_nullable(values)); + let payloads_list = ScalarValue::new_list_nullable(&payloads, &DataType::Int64); + + Ok(vec![ + ScalarValue::List(values_list), + ScalarValue::List(payloads_list), + ]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + + let arr = as_list_array(&states[0])?; + let counts = as_primitive_array::(&states[1])?; + + arr.iter() + .zip(counts.iter()) + .try_for_each(|(maybe_list, maybe_count)| { + if let (Some(list), Some(count)) = (maybe_list, maybe_count) { + // Insert or update the count for each value + self.value_counts.insert_or_update( + &list, + |_| count, + |existing_count| *existing_count += count, + ); + } + Ok(()) + }) + } + + fn evaluate(&mut self) -> Result { + let mut max_index: Option = None; + let mut max_count: i64 = 0; + + let values = self.values.take().into_state(); + let counts = self.value_counts.take().get_payloads(&values); + + for (i, count) in counts.into_iter().enumerate() { + if let Some(c) = count { + if c > max_count { + max_count = c; + max_index = Some(i); + } + } + } + + match max_index { + Some(index) => { + let array = values.as_string::(); + let mode_value = array.value(index); + if mode_value.is_empty() { + Ok(ScalarValue::Utf8(None)) + } else if O::IS_LARGE { + Ok(ScalarValue::LargeUtf8(Some(mode_value.to_string()))) + } else { + Ok(ScalarValue::Utf8(Some(mode_value.to_string()))) + } + } + None => { + if O::IS_LARGE { + Ok(ScalarValue::LargeUtf8(None)) + } else { + Ok(ScalarValue::Utf8(None)) + } + } + } + } + + fn size(&self) -> usize { + self.values.size() + self.value_counts.size() + } +} + +#[derive(Debug)] +pub struct BytesViewModeAccumulator { + values: ArrowBytesViewSet, + value_counts: ArrowBytesViewMap, +} + +impl BytesViewModeAccumulator { + pub fn new(output_type: OutputType) -> Self { + Self { + values: 
ArrowBytesViewSet::new(output_type), + value_counts: ArrowBytesViewMap::new(output_type), + } + } +} + +impl Accumulator for BytesViewModeAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + self.values.insert(&values[0]); + + self.value_counts + .insert_or_update(&values[0], |_| 1i64, |count| *count += 1); + + Ok(()) + } + + fn state(&mut self) -> Result> { + let values = self.values.take().into_state(); + let payloads: Vec = self + .value_counts + .take() + .get_payloads(&values) + .into_iter() + .map(|count| match count { + Some(c) => ScalarValue::Int64(Some(c)), + None => ScalarValue::Int64(None), + }) + .collect(); + + let values_list = Arc::new(array_into_list_array_nullable(values)); + let payloads_list = ScalarValue::new_list_nullable(&payloads, &DataType::Int64); + + Ok(vec![ + ScalarValue::List(values_list), + ScalarValue::List(payloads_list), + ]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + + let arr = as_list_array(&states[0])?; + let counts = as_primitive_array::(&states[1])?; + + arr.iter() + .zip(counts.iter()) + .try_for_each(|(maybe_list, maybe_count)| { + if let (Some(list), Some(count)) = (maybe_list, maybe_count) { + // Insert or update the count for each value + self.value_counts.insert_or_update( + &list, + |_| count, + |existing_count| *existing_count += count, + ); + } + Ok(()) + }) + } + + fn evaluate(&mut self) -> Result { + let mut max_index: Option = None; + let mut max_count: i64 = 0; + + let values = self.values.take().into_state(); + let counts = self.value_counts.take().get_payloads(&values); + + for (i, count) in counts.into_iter().enumerate() { + if let Some(c) = count { + if c > max_count { + max_count = c; + max_index = Some(i); + } + } + } + + match max_index { + Some(index) => { + let array = values.as_string_view(); + let mode_value = array.value(index); + if mode_value.is_empty() { + Ok(ScalarValue::Utf8View(None)) + } else { + Ok(ScalarValue::Utf8View(Some(mode_value.to_string()))) + } + } + None => Ok(ScalarValue::Utf8View(None)), + } + } + + fn size(&self) -> usize { + self.values.size() + self.value_counts.size() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ArrayRef, GenericByteViewArray, StringArray}; + use datafusion_common::ScalarValue; + use std::sync::Arc; + + #[test] + fn test_mode_accumulator_single_mode_utf8() -> Result<()> { + let mut acc = BytesModeAccumulator::::new(OutputType::Utf8); + let values: ArrayRef = Arc::new(StringArray::from(vec![ + Some("apple"), + Some("banana"), + Some("apple"), + Some("orange"), + Some("banana"), + Some("apple"), + ])); + + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + + assert_eq!(result, ScalarValue::Utf8(Some("apple".to_string()))); + Ok(()) + } + + #[test] + fn test_mode_accumulator_tie_utf8() -> Result<()> { + let mut acc = BytesModeAccumulator::::new(OutputType::Utf8); + let values: ArrayRef = Arc::new(StringArray::from(vec![ + Some("apple"), + Some("banana"), + Some("apple"), + Some("orange"), + Some("banana"), + ])); + + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + + assert_eq!(result, ScalarValue::Utf8(Some("apple".to_string()))); + Ok(()) + } + + #[test] + fn test_mode_accumulator_all_nulls_utf8() -> Result<()> { + let mut acc = BytesModeAccumulator::::new(OutputType::Utf8); + let values: ArrayRef = + Arc::new(StringArray::from(vec![None as Option<&str>, None, None])); + + 
acc.update_batch(&[values])?; + let result = acc.evaluate()?; + + assert_eq!(result, ScalarValue::Utf8(None)); + Ok(()) + } + + #[test] + fn test_mode_accumulator_with_nulls_utf8() -> Result<()> { + let mut acc = BytesModeAccumulator::::new(OutputType::Utf8); + let values: ArrayRef = Arc::new(StringArray::from(vec![ + Some("apple"), + None, + Some("banana"), + Some("apple"), + None, + None, + None, + Some("banana"), + ])); + + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + + assert_eq!(result, ScalarValue::Utf8(Some("apple".to_string()))); + Ok(()) + } + + #[test] + fn test_mode_accumulator_single_mode_utf8view() -> Result<()> { + let mut acc = BytesViewModeAccumulator::new(OutputType::Utf8View); + let values: ArrayRef = Arc::new(GenericByteViewArray::from(vec![ + Some("apple"), + Some("banana"), + Some("apple"), + Some("orange"), + Some("banana"), + Some("apple"), + ])); + + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + + assert_eq!(result, ScalarValue::Utf8View(Some("apple".to_string()))); + Ok(()) + } + + #[test] + fn test_mode_accumulator_tie_utf8view() -> Result<()> { + let mut acc = BytesViewModeAccumulator::new(OutputType::Utf8View); + let values: ArrayRef = Arc::new(GenericByteViewArray::from(vec![ + Some("apple"), + Some("banana"), + Some("apple"), + Some("orange"), + Some("banana"), + ])); + + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + + assert_eq!(result, ScalarValue::Utf8View(Some("apple".to_string()))); + Ok(()) + } + + #[test] + fn test_mode_accumulator_all_nulls_utf8view() -> Result<()> { + let mut acc = BytesViewModeAccumulator::new(OutputType::Utf8View); + let values: ArrayRef = Arc::new(GenericByteViewArray::from(vec![ + None as Option<&str>, + None, + None, + ])); + + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + + assert_eq!(result, ScalarValue::Utf8View(None)); + Ok(()) + } + + #[test] + fn test_mode_accumulator_with_nulls_utf8view() -> Result<()> { + let mut acc = BytesViewModeAccumulator::new(OutputType::Utf8View); + let values: ArrayRef = Arc::new(GenericByteViewArray::from(vec![ + Some("apple"), + None, + Some("banana"), + Some("apple"), + None, + None, + None, + Some("banana"), + ])); + + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + + assert_eq!(result, ScalarValue::Utf8View(Some("apple".to_string()))); + Ok(()) + } +} diff --git a/datafusion/functions-aggregate-common/src/aggregate/mode/native.rs b/datafusion/functions-aggregate-common/src/aggregate/mode/native.rs new file mode 100644 index 000000000000..678c12308fa8 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/mode/native.rs @@ -0,0 +1,578 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
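
This second new file adds the numeric accumulators: `PrimitiveModeAccumulator` keeps a `HashMap<T::Native, i64>` of counts for integer, date, time, and timestamp types, and `FloatModeAccumulator` wraps float values in `Hashable` so they can serve as map keys. A minimal sketch of driving the primitive accumulator, mirroring the tests below; the wrapper function name is illustrative only:

    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int64Array};
    use arrow::datatypes::{DataType, Int64Type};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr_common::accumulator::Accumulator;
    use datafusion_functions_aggregate_common::aggregate::mode::PrimitiveModeAccumulator;

    fn int64_mode_sketch() -> Result<()> {
        let mut acc = PrimitiveModeAccumulator::<Int64Type>::new(&DataType::Int64);
        let values: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 2, 3, 3, 3]));
        // update_batch increments a per-value counter and skips nulls
        acc.update_batch(&[values])?;
        // 3 occurs most often, so it is the mode
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3)));
        Ok(())
    }
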
+ +use std::collections::HashMap; +use std::fmt::Debug; +use std::hash::Hash; + +use crate::utils::Hashable; +use arrow::{ + array::{ArrayRef, ArrowPrimitiveType}, + datatypes::DataType, +}; +use datafusion_common::{cast::as_primitive_array, Result, ScalarValue}; +use datafusion_expr_common::accumulator::Accumulator; + +#[derive(Debug)] +pub struct PrimitiveModeAccumulator +where + T: ArrowPrimitiveType + Send, + T::Native: Eq + Hash, +{ + value_counts: HashMap, + data_type: DataType, +} + +impl PrimitiveModeAccumulator +where + T: ArrowPrimitiveType + Send, + T::Native: Eq + Hash + Clone, +{ + pub fn new(data_type: &DataType) -> Self { + Self { + value_counts: HashMap::default(), + data_type: data_type.clone(), + } + } +} + +impl Accumulator for PrimitiveModeAccumulator +where + T: ArrowPrimitiveType + Send + Debug, + T::Native: Eq + Hash + Clone + PartialOrd + Debug, +{ + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + let arr = as_primitive_array::(&values[0])?; + + for value in arr.iter().flatten() { + let counter = self.value_counts.entry(value).or_insert(0); + *counter += 1; + } + + Ok(()) + } + + fn state(&mut self) -> Result> { + let values: Vec = self + .value_counts + .keys() + .map(|key| ScalarValue::new_primitive::(Some(*key), &self.data_type)) + .collect::>>()?; + + let frequencies: Vec = self + .value_counts + .values() + .map(|count| ScalarValue::from(*count)) + .collect(); + + let values_scalar = + ScalarValue::new_list_nullable(&values, &self.data_type.clone()); + let frequencies_scalar = + ScalarValue::new_list_nullable(&frequencies, &DataType::Int64); + + Ok(vec![ + ScalarValue::List(values_scalar), + ScalarValue::List(frequencies_scalar), + ]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + + let values_array = as_primitive_array::(&states[0])?; + let counts_array = as_primitive_array::(&states[1])?; + + for i in 0..values_array.len() { + let value = values_array.value(i); + let count = counts_array.value(i); + let entry = self.value_counts.entry(value).or_insert(0); + *entry += count; + } + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let mut max_value: Option = None; + let mut max_count: i64 = 0; + + self.value_counts.iter().for_each(|(value, &count)| { + match count.cmp(&max_count) { + std::cmp::Ordering::Greater => { + max_value = Some(*value); + max_count = count; + } + std::cmp::Ordering::Equal => { + max_value = match max_value { + Some(ref current_max_value) if value < current_max_value => { + Some(*value) + } + Some(ref current_max_value) => Some(*current_max_value), + None => Some(*value), + }; + } + _ => {} // Do nothing if count is less than max_count + } + }); + + match max_value { + Some(val) => ScalarValue::new_primitive::(Some(val), &self.data_type), + None => ScalarValue::new_primitive::(None, &self.data_type), + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(&self.value_counts) + + self.value_counts.len() * std::mem::size_of::<(T::Native, i64)>() + } +} + +#[derive(Debug)] +pub struct FloatModeAccumulator +where + T: ArrowPrimitiveType, +{ + value_counts: HashMap, i64>, + data_type: DataType, +} + +impl FloatModeAccumulator +where + T: ArrowPrimitiveType, +{ + pub fn new(data_type: &DataType) -> Self { + Self { + value_counts: HashMap::default(), + data_type: data_type.clone(), + } + } +} + +impl Accumulator for FloatModeAccumulator +where + T: ArrowPrimitiveType + Send + Debug, + T::Native: 
PartialOrd + Debug + Clone, +{ + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let arr = as_primitive_array::(&values[0])?; + + for value in arr.iter().flatten() { + let counter = self.value_counts.entry(Hashable(value)).or_insert(0); + *counter += 1; + } + + Ok(()) + } + + fn state(&mut self) -> Result> { + let values: Vec = self + .value_counts + .keys() + .map(|key| ScalarValue::new_primitive::(Some(key.0), &self.data_type)) + .collect::>>()?; + + let frequencies: Vec = self + .value_counts + .values() + .map(|count| ScalarValue::from(*count)) + .collect(); + + let values_scalar = + ScalarValue::new_list_nullable(&values, &self.data_type.clone()); + let frequencies_scalar = + ScalarValue::new_list_nullable(&frequencies, &DataType::Int64); + + Ok(vec![ + ScalarValue::List(values_scalar), + ScalarValue::List(frequencies_scalar), + ]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + + let values_array = as_primitive_array::(&states[0])?; + let counts_array = as_primitive_array::(&states[1])?; + + for i in 0..values_array.len() { + let count = counts_array.value(i); + let entry = self + .value_counts + .entry(Hashable(values_array.value(i))) + .or_insert(0); + *entry += count; + } + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let mut max_value: Option = None; + let mut max_count: i64 = 0; + + self.value_counts.iter().for_each(|(value, &count)| { + match count.cmp(&max_count) { + std::cmp::Ordering::Greater => { + max_value = Some(value.0); + max_count = count; + } + std::cmp::Ordering::Equal => { + max_value = match max_value { + Some(current_max_value) if value.0 < current_max_value => { + Some(value.0) + } + Some(current_max_value) => Some(current_max_value), + None => Some(value.0), + }; + } + _ => {} // Do nothing if count is less than max_count + } + }); + + match max_value { + Some(val) => ScalarValue::new_primitive::(Some(val), &self.data_type), + None => ScalarValue::new_primitive::(None, &self.data_type), + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(&self.value_counts) + + self.value_counts.len() * std::mem::size_of::<(Hashable, i64)>() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ + ArrayRef, Date64Array, Float64Array, Int64Array, Time64MicrosecondArray, + }; + use arrow::datatypes::{ + DataType, Date64Type, Float64Type, Int64Type, Time64MicrosecondType, TimeUnit, + }; + use datafusion_common::ScalarValue; + use std::sync::Arc; + + #[test] + fn test_mode_accumulator_single_mode_int64() -> Result<()> { + let mut acc = PrimitiveModeAccumulator::::new(&DataType::Int64); + let values: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 2, 3, 3, 3])); + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + assert_eq!( + result, + ScalarValue::new_primitive::(Some(3), &DataType::Int64)? + ); + Ok(()) + } + + #[test] + fn test_mode_accumulator_with_nulls_int64() -> Result<()> { + let mut acc = PrimitiveModeAccumulator::::new(&DataType::Int64); + let values: ArrayRef = Arc::new(Int64Array::from(vec![ + None, + Some(1), + Some(2), + Some(2), + Some(3), + Some(3), + Some(3), + ])); + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + assert_eq!( + result, + ScalarValue::new_primitive::(Some(3), &DataType::Int64)? 
+ ); + Ok(()) + } + + #[test] + fn test_mode_accumulator_tie_case_int64() -> Result<()> { + let mut acc = PrimitiveModeAccumulator::::new(&DataType::Int64); + let values: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 2, 3, 3])); + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + assert_eq!( + result, + ScalarValue::new_primitive::(Some(2), &DataType::Int64)? + ); + Ok(()) + } + + #[test] + fn test_mode_accumulator_only_nulls_int64() -> Result<()> { + let mut acc = PrimitiveModeAccumulator::::new(&DataType::Int64); + let values: ArrayRef = Arc::new(Int64Array::from(vec![None, None, None, None])); + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + assert_eq!( + result, + ScalarValue::new_primitive::(None, &DataType::Int64)? + ); + Ok(()) + } + + #[test] + fn test_mode_accumulator_single_mode_float64() -> Result<()> { + let mut acc = FloatModeAccumulator::::new(&DataType::Float64); + let values: ArrayRef = + Arc::new(Float64Array::from(vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0])); + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + assert_eq!( + result, + ScalarValue::new_primitive::(Some(3.0), &DataType::Float64)? + ); + Ok(()) + } + + #[test] + fn test_mode_accumulator_with_nulls_float64() -> Result<()> { + let mut acc = FloatModeAccumulator::::new(&DataType::Float64); + let values: ArrayRef = Arc::new(Float64Array::from(vec![ + None, + Some(1.0), + Some(2.0), + Some(2.0), + Some(3.0), + Some(3.0), + Some(3.0), + ])); + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + assert_eq!( + result, + ScalarValue::new_primitive::(Some(3.0), &DataType::Float64)? + ); + Ok(()) + } + + #[test] + fn test_mode_accumulator_tie_case_float64() -> Result<()> { + let mut acc = FloatModeAccumulator::::new(&DataType::Float64); + let values: ArrayRef = + Arc::new(Float64Array::from(vec![1.0, 2.0, 2.0, 3.0, 3.0])); + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + assert_eq!( + result, + ScalarValue::new_primitive::(Some(2.0), &DataType::Float64)? + ); + Ok(()) + } + + #[test] + fn test_mode_accumulator_only_nulls_float64() -> Result<()> { + let mut acc = FloatModeAccumulator::::new(&DataType::Float64); + let values: ArrayRef = Arc::new(Float64Array::from(vec![None, None, None, None])); + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + assert_eq!( + result, + ScalarValue::new_primitive::(None, &DataType::Float64)? + ); + Ok(()) + } + + #[test] + fn test_mode_accumulator_single_mode_date64() -> Result<()> { + let mut acc = PrimitiveModeAccumulator::::new(&DataType::Date64); + let values: ArrayRef = Arc::new(Date64Array::from(vec![ + 1609459200000, + 1609545600000, + 1609545600000, + 1609632000000, + 1609632000000, + 1609632000000, + ])); + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + assert_eq!( + result, + ScalarValue::new_primitive::( + Some(1609632000000), + &DataType::Date64 + )? + ); + Ok(()) + } + + #[test] + fn test_mode_accumulator_with_nulls_date64() -> Result<()> { + let mut acc = PrimitiveModeAccumulator::::new(&DataType::Date64); + let values: ArrayRef = Arc::new(Date64Array::from(vec![ + None, + Some(1609459200000), + Some(1609545600000), + Some(1609545600000), + Some(1609632000000), + Some(1609632000000), + Some(1609632000000), + ])); + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + assert_eq!( + result, + ScalarValue::new_primitive::( + Some(1609632000000), + &DataType::Date64 + )? 
+ ); + Ok(()) + } + + #[test] + fn test_mode_accumulator_tie_case_date64() -> Result<()> { + let mut acc = PrimitiveModeAccumulator::::new(&DataType::Date64); + let values: ArrayRef = Arc::new(Date64Array::from(vec![ + 1609459200000, + 1609545600000, + 1609545600000, + 1609632000000, + 1609632000000, + ])); + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + assert_eq!( + result, + ScalarValue::new_primitive::( + Some(1609545600000), + &DataType::Date64 + )? + ); + Ok(()) + } + + #[test] + fn test_mode_accumulator_only_nulls_date64() -> Result<()> { + let mut acc = PrimitiveModeAccumulator::::new(&DataType::Date64); + let values: ArrayRef = Arc::new(Date64Array::from(vec![None, None, None, None])); + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + assert_eq!( + result, + ScalarValue::new_primitive::(None, &DataType::Date64)? + ); + Ok(()) + } + + #[test] + fn test_mode_accumulator_single_mode_time64() -> Result<()> { + let mut acc = PrimitiveModeAccumulator::::new( + &DataType::Time64(TimeUnit::Microsecond), + ); + let values: ArrayRef = Arc::new(Time64MicrosecondArray::from(vec![ + 3600000000, + 7200000000, + 7200000000, + 10800000000, + 10800000000, + 10800000000, + ])); + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + assert_eq!( + result, + ScalarValue::new_primitive::( + Some(10800000000), + &DataType::Time64(TimeUnit::Microsecond) + )? + ); + Ok(()) + } + + #[test] + fn test_mode_accumulator_with_nulls_time64() -> Result<()> { + let mut acc = PrimitiveModeAccumulator::::new( + &DataType::Time64(TimeUnit::Microsecond), + ); + let values: ArrayRef = Arc::new(Time64MicrosecondArray::from(vec![ + None, + Some(3600000000), + Some(7200000000), + Some(7200000000), + Some(10800000000), + Some(10800000000), + Some(10800000000), + ])); + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + assert_eq!( + result, + ScalarValue::new_primitive::( + Some(10800000000), + &DataType::Time64(TimeUnit::Microsecond) + )? + ); + Ok(()) + } + + #[test] + fn test_mode_accumulator_tie_case_time64() -> Result<()> { + let mut acc = PrimitiveModeAccumulator::::new( + &DataType::Time64(TimeUnit::Microsecond), + ); + let values: ArrayRef = Arc::new(Time64MicrosecondArray::from(vec![ + 3600000000, + 7200000000, + 7200000000, + 10800000000, + 10800000000, + ])); + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + assert_eq!( + result, + ScalarValue::new_primitive::( + Some(7200000000), + &DataType::Time64(TimeUnit::Microsecond) + )? + ); + Ok(()) + } + + #[test] + fn test_mode_accumulator_only_nulls_time64() -> Result<()> { + let mut acc = PrimitiveModeAccumulator::::new( + &DataType::Time64(TimeUnit::Microsecond), + ); + let values: ArrayRef = + Arc::new(Time64MicrosecondArray::from(vec![None, None, None, None])); + acc.update_batch(&[values])?; + let result = acc.evaluate()?; + assert_eq!( + result, + ScalarValue::new_primitive::( + None, + &DataType::Time64(TimeUnit::Microsecond) + )? 
+ ); + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 60e2602eb6ed..b27a390ca20a 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -79,6 +79,7 @@ pub mod bit_and_or_xor; pub mod bool_and_or; pub mod grouping; pub mod kurtosis_pop; +pub mod mode; pub mod nth_value; pub mod string_agg; @@ -172,6 +173,7 @@ pub fn all_default_aggregate_functions() -> Vec> { grouping::grouping_udaf(), nth_value::nth_value_udaf(), kurtosis_pop::kurtosis_pop_udaf(), + mode::mode_udaf(), ] } diff --git a/datafusion/functions-aggregate/src/mode.rs b/datafusion/functions-aggregate/src/mode.rs new file mode 100644 index 000000000000..63a9dcfc53ac --- /dev/null +++ b/datafusion/functions-aggregate/src/mode.rs @@ -0,0 +1,193 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{ + Date32Type, Date64Type, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, + Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, + UInt8Type, +}; +use arrow_schema::{DataType, Field, TimeUnit}; +use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_functions_aggregate_common::accumulator::{ + AccumulatorArgs, StateFieldsArgs, +}; +use datafusion_functions_aggregate_common::aggregate::mode::{ + BytesModeAccumulator, BytesViewModeAccumulator, FloatModeAccumulator, + PrimitiveModeAccumulator, +}; +use datafusion_physical_expr::binary_map::OutputType; +use std::any::Any; +use std::fmt::Debug; + +make_udaf_expr_and_func!( + ModeFunction, + mode, + x, + "Calculates the most frequent value.", + mode_udaf +); + +/// The `ModeFunction` calculates the mode (most frequent value) from a set of values. +/// +/// - Null values are ignored during the calculation. +/// - If multiple values have the same frequency, the first encountered value with the highest frequency is returned. +/// - In the case of `Utf8` or `Utf8View`, the first value encountered in the original order with the highest frequency is returned. 
+pub struct ModeFunction { + signature: Signature, +} + +impl Debug for ModeFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ModeFunction") + .field("signature", &self.signature) + .finish() + } +} + +impl Default for ModeFunction { + fn default() -> Self { + Self::new() + } +} + +impl ModeFunction { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for ModeFunction { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "mode" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let value_type = args.input_types[0].clone(); + + Ok(vec![ + Field::new("values", value_type, true), + Field::new("frequencies", DataType::UInt64, true), + ]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?; + + let accumulator: Box = match data_type { + DataType::Int8 => { + Box::new(PrimitiveModeAccumulator::::new(data_type)) + } + DataType::Int16 => { + Box::new(PrimitiveModeAccumulator::::new(data_type)) + } + DataType::Int32 => { + Box::new(PrimitiveModeAccumulator::::new(data_type)) + } + DataType::Int64 => { + Box::new(PrimitiveModeAccumulator::::new(data_type)) + } + DataType::UInt8 => { + Box::new(PrimitiveModeAccumulator::::new(data_type)) + } + DataType::UInt16 => { + Box::new(PrimitiveModeAccumulator::::new(data_type)) + } + DataType::UInt32 => { + Box::new(PrimitiveModeAccumulator::::new(data_type)) + } + DataType::UInt64 => { + Box::new(PrimitiveModeAccumulator::::new(data_type)) + } + + DataType::Date32 => { + Box::new(PrimitiveModeAccumulator::::new(data_type)) + } + DataType::Date64 => { + Box::new(PrimitiveModeAccumulator::::new(data_type)) + } + DataType::Time32(TimeUnit::Millisecond) => Box::new( + PrimitiveModeAccumulator::::new(data_type), + ), + DataType::Time32(TimeUnit::Second) => { + Box::new(PrimitiveModeAccumulator::::new(data_type)) + } + DataType::Time64(TimeUnit::Microsecond) => Box::new( + PrimitiveModeAccumulator::::new(data_type), + ), + DataType::Time64(TimeUnit::Nanosecond) => Box::new( + PrimitiveModeAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new( + PrimitiveModeAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new( + PrimitiveModeAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new( + PrimitiveModeAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Second, _) => Box::new( + PrimitiveModeAccumulator::::new(data_type), + ), + + DataType::Float16 => { + Box::new(FloatModeAccumulator::::new(data_type)) + } + DataType::Float32 => { + Box::new(FloatModeAccumulator::::new(data_type)) + } + DataType::Float64 => { + Box::new(FloatModeAccumulator::::new(data_type)) + } + + DataType::Utf8 => { + Box::new(BytesModeAccumulator::::new(OutputType::Utf8)) + } + DataType::LargeUtf8 => { + Box::new(BytesModeAccumulator::::new(OutputType::Utf8)) + } + DataType::Utf8View => { + Box::new(BytesViewModeAccumulator::new(OutputType::Utf8View)) + } + _ => { + return not_impl_err!( + "Unsupported data type: {:?} for mode function", + data_type + ); + } + }; + + Ok(accumulator) + } +} diff --git a/datafusion/physical-expr-common/src/binary_map.rs 
b/datafusion/physical-expr-common/src/binary_map.rs index d21bdb3434c4..fcad87b7d666 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -475,6 +475,304 @@ where } } + /// Inserts each value from `values` into the map, invoking `make_payload_fn` for + /// each value if not already present, or `update_payload_fn` if the value already exists. + /// + /// This function handles both the insert and update cases. + /// + /// # Arguments: + /// + /// `values`: The array whose values are inserted or updated in the map. + /// + /// `make_payload_fn`: Invoked for each value that is not already present + /// to create the payload, in the order of the values in `values`. + /// + /// `update_payload_fn`: Invoked for each value that is already present, + /// allowing the payload to be updated in-place. + /// + /// # Safety: + /// + /// Note that `make_payload_fn` and `update_payload_fn` are only invoked + /// with valid values from `values`, not for the `NULL` value. + pub fn insert_or_update( + &mut self, + values: &ArrayRef, + make_payload_fn: MP, + update_payload_fn: UP, + ) where + MP: FnMut(Option<&[u8]>) -> V, + UP: FnMut(&mut V), + { + // Check the output type and dispatch to the appropriate internal function + match self.output_type { + OutputType::Binary => { + assert!(matches!( + values.data_type(), + DataType::Binary | DataType::LargeBinary + )); + self.insert_or_update_inner::>( + values, + make_payload_fn, + update_payload_fn, + ) + } + OutputType::Utf8 => { + assert!(matches!( + values.data_type(), + DataType::Utf8 | DataType::LargeUtf8 + )); + self.insert_or_update_inner::>( + values, + make_payload_fn, + update_payload_fn, + ) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + }; + } + + /// Generic version of [`Self::insert_or_update`] that handles `ByteArrayType` + /// (both String and Binary). + /// + /// This is the only function that is generic on [`ByteArrayType`], which avoids having + /// to template the entire structure, simplifying the code and reducing code bloat due + /// to duplication. + /// + /// See comments on `insert_or_update` for more details. 
+ fn insert_or_update_inner( + &mut self, + values: &ArrayRef, + mut make_payload_fn: MP, + mut update_payload_fn: UP, + ) where + MP: FnMut(Option<&[u8]>) -> V, // Function to create a new entry + UP: FnMut(&mut V), // Function to update an existing entry + B: ByteArrayType, + { + // Step 1: Compute hashes + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(values.len(), 0); + create_hashes(&[values.clone()], &self.random_state, batch_hashes).unwrap(); // Compute the hashes for the values + + // Step 2: Insert or update each value + let values = values.as_bytes::(); + + assert_eq!(values.len(), batch_hashes.len()); // Ensure hash count matches value count + + for (value, &hash) in values.iter().zip(batch_hashes.iter()) { + // Handle null value + let Some(value) = value else { + let _payload = if let Some(&(payload, _)) = self.null.as_ref() { + payload + } else { + let payload = make_payload_fn(None); + let null_index = self.offsets.len() - 1; + let offset = self.buffer.len(); + self.offsets.push(O::usize_as(offset)); + self.null = Some((payload, null_index)); + payload + }; + continue; + }; + + let value: &[u8] = value.as_ref(); + let value_len = O::usize_as(value.len()); + + // Small value optimization + if value.len() <= SHORT_VALUE_LEN { + let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x as usize); + + // Check if the value is already present in the set + let entry = self.map.get_mut(hash, |header| { + if header.len != value_len { + return false; + } + inline == header.offset_or_inline + }); + + if let Some(entry) = entry { + update_payload_fn(&mut entry.payload); + } else { + // Insert a new value if not found + self.buffer.append_slice(value); + self.offsets.push(O::usize_as(self.buffer.len())); + let payload = make_payload_fn(Some(value)); + let new_entry = Entry { + hash, + len: value_len, + offset_or_inline: inline, + payload, + }; + self.map.insert_accounted( + new_entry, + |header| header.hash, + &mut self.map_size, + ); + } + } else { + // Handle larger values + let entry = self.map.get_mut(hash, |header| { + if header.len != value_len { + return false; + } + let existing_value = + unsafe { self.buffer.as_slice().get_unchecked(header.range()) }; + value == existing_value + }); + + if let Some(entry) = entry { + update_payload_fn(&mut entry.payload); + } else { + // Insert a new large value if not found + let offset = self.buffer.len(); + self.buffer.append_slice(value); + self.offsets.push(O::usize_as(self.buffer.len())); + let payload = make_payload_fn(Some(value)); + let new_entry = Entry { + hash, + len: value_len, + offset_or_inline: offset, + payload, + }; + self.map.insert_accounted( + new_entry, + |header| header.hash, + &mut self.map_size, + ); + } + }; + } + + // Ensure no overflow in offsets + if O::from_usize(self.buffer.len()).is_none() { + panic!( + "Put {} bytes in buffer, more than can be represented by a {}", + self.buffer.len(), + type_name::() + ); + } + } + + /// Generic version of [`Self::get_payloads`] that handles `ByteArrayType` + /// (both String and Binary). + /// + /// This function computes the hashes for each value and retrieves the payloads + /// stored in the map, leveraging small value optimizations when possible. + /// + /// # Arguments: + /// + /// `values`: The array whose payloads are being retrieved. + /// + /// # Returns + /// + /// A vector of payloads for each value, or `None` if the value is not found. 
+ /// + /// # Safety: + /// + /// This function ensures that small values are handled using inline optimization + /// and larger values are safely retrieved from the buffer. + fn get_payloads_inner(self, values: &ArrayRef) -> Vec> + where + B: ByteArrayType, + { + // Step 1: Compute hashes + let mut batch_hashes = vec![0u64; values.len()]; + batch_hashes.clear(); + batch_hashes.resize(values.len(), 0); + create_hashes(&[values.clone()], &self.random_state, &mut batch_hashes).unwrap(); // Compute the hashes for the values + + // Step 2: Get payloads for each value + let values = values.as_bytes::(); + assert_eq!(values.len(), batch_hashes.len()); // Ensure hash count matches value count + + let mut payloads = Vec::with_capacity(values.len()); + + for (value, &hash) in values.iter().zip(batch_hashes.iter()) { + // Handle null value + let Some(value) = value else { + if let Some(&(payload, _)) = self.null.as_ref() { + payloads.push(Some(payload)); + } else { + payloads.push(None); + } + continue; + }; + + let value: &[u8] = value.as_ref(); + let value_len = O::usize_as(value.len()); + + // Small value optimization + let payload = if value.len() <= SHORT_VALUE_LEN { + let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x as usize); + + // Check if the value is already present in the set + let entry = self.map.get(hash, |header| { + if header.len != value_len { + return false; + } + inline == header.offset_or_inline + }); + + entry.map(|entry| entry.payload) + } else { + // Handle larger values + let entry = self.map.get(hash, |header| { + if header.len != value_len { + return false; + } + let existing_value = + unsafe { self.buffer.as_slice().get_unchecked(header.range()) }; + value == existing_value + }); + + entry.map(|entry| entry.payload) + }; + + payloads.push(payload); + } + + payloads + } + + /// Retrieves the payloads for each value from `values`, either by using + /// small value optimizations or larger value handling. + /// + /// This function will compute hashes for each value and attempt to retrieve + /// the corresponding payload from the map. If the value is not found, it will return `None`. + /// + /// # Arguments: + /// + /// `values`: The array whose payloads need to be retrieved. + /// + /// # Returns + /// + /// A vector of payloads for each value, or `None` if the value is not found. + /// + /// # Safety: + /// + /// This function handles both small and large values in a safe manner, though `unsafe` code is + /// used internally for performance optimization. + pub fn get_payloads(self, values: &ArrayRef) -> Vec> { + match self.output_type { + OutputType::Binary => { + assert!(matches!( + values.data_type(), + DataType::Binary | DataType::LargeBinary + )); + self.get_payloads_inner::>(values) + } + OutputType::Utf8 => { + assert!(matches!( + values.data_type(), + DataType::Utf8 | DataType::LargeUtf8 + )); + self.get_payloads_inner::>(values) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + } + } + /// Converts this set into a `StringArray`, `LargeStringArray`, /// `BinaryArray`, or `LargeBinaryArray` containing each distinct value /// that was inserted. This is done without copying the values. 
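
Taken together, the two new entry points turn `ArrowBytesMap` into a small counting map: `insert_or_update` creates the payload the first time a value is seen and mutates it on every repeat, while `get_payloads` (which consumes the map) returns the payloads in the order of the probe array, with `None` for values that were never inserted. A minimal sketch, mirroring the tests added in the next hunk; `u8` is only an example payload type and the wrapper function is illustrative:

    use std::sync::Arc;

    use arrow::array::{ArrayRef, StringArray};
    use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType};

    fn count_strings_sketch() {
        let mut map: ArrowBytesMap<i32, u8> = ArrowBytesMap::new(OutputType::Utf8);
        let arr: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("a")]));

        // first sight of a value -> payload 1, repeat -> increment in place
        map.insert_or_update(&arr, |_| 1u8, |count| *count += 1);

        // read the counts back for the same probe values
        let counts = map.get_payloads(&arr);
        assert_eq!(counts, vec![Some(2), Some(1), Some(2)]);
    }
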
@@ -871,6 +1169,124 @@ mod tests { assert!(size_after_values2 > total_strings1_len + total_strings2_len); } + #[test] + fn test_insert_or_update_count_u8() { + let input = vec![ + Some("A"), + Some("bcdefghijklmnop"), + Some("X"), + Some("Y"), + None, + Some("qrstuvqxyzhjwya"), + Some("✨🔥"), + Some("🔥"), + Some("🔥🔥🔥🔥🔥🔥"), + Some("A"), // Duplicate to test the count increment + Some("Y"), // Another duplicate to test the count increment + ]; + + let mut map: ArrowBytesMap = ArrowBytesMap::new(OutputType::Utf8); + + let string_array = StringArray::from(input.clone()); + let arr: ArrayRef = Arc::new(string_array); + + map.insert_or_update( + &arr, + |_| 1u8, + |count| { + *count += 1; + }, + ); + + let expected_counts = [ + ("A", 2), + ("bcdefghijklmnop", 1), + ("X", 1), + ("Y", 2), + ("qrstuvqxyzhjwya", 1), + ("✨🔥", 1), + ("🔥", 1), + ("🔥🔥🔥🔥🔥🔥", 1), + ]; + + for &value in input.iter() { + if let Some(value) = value { + let string_array = StringArray::from(vec![Some(value)]); + let arr: ArrayRef = Arc::new(string_array); + + let mut result_payload: Option = None; + + map.insert_or_update( + &arr, + |_| { + panic!("Unexpected new entry during verification"); + }, + |count| { + result_payload = Some(*count); + }, + ); + + if let Some(expected_count) = + expected_counts.iter().find(|&&(s, _)| s == value) + { + assert_eq!(result_payload.unwrap(), expected_count.1); + } + } + } + } + + #[test] + fn test_get_payloads_u8() { + let input = vec![ + Some("A"), + Some("bcdefghijklmnop"), + Some("X"), + Some("Y"), + None, + Some("qrstuvqxyzhjwya"), + Some("✨🔥"), + Some("🔥"), + Some("🔥🔥🔥🔥🔥🔥"), + Some("A"), // Duplicate to test the count increment + Some("Y"), // Another duplicate to test the count increment + ]; + + let mut map: ArrowBytesMap = ArrowBytesMap::new(OutputType::Utf8); + + let string_array = StringArray::from(input.clone()); + let arr: ArrayRef = Arc::new(string_array); + + map.insert_or_update( + &arr, + |_| 1u8, + |count| { + *count += 1; + }, + ); + + let expected_payloads = [ + Some(2u8), + Some(1u8), + Some(1u8), + Some(2u8), + Some(1u8), + Some(1u8), + Some(1u8), + Some(1u8), + Some(1u8), + Some(2u8), + Some(2u8), + ]; + + let payloads = map.get_payloads(&arr); + + assert_eq!(payloads.len(), expected_payloads.len()); + + for (i, payload) in payloads.iter().enumerate() { + assert_eq!(*payload, expected_payloads[i]); + } + } + #[test] fn test_map() { let input = vec![ diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs index bdcf7bbacc69..f1f87d4ca22f 100644 --- a/datafusion/physical-expr-common/src/binary_view_map.rs +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -306,6 +306,221 @@ where } } + /// Inserts each value from `values` into the map, invoking `make_payload_fn` for + /// each value if not already present, or `update_payload_fn` if the value already exists. + /// + /// This function handles both the insert and update cases. + /// + /// # Arguments: + /// + /// `values`: The array whose values are inserted or updated in the map. + /// + /// `make_payload_fn`: Invoked for each value that is not already present + /// to create the payload, in the order of the values in `values`. + /// + /// `update_payload_fn`: Invoked for each value that is already present, + /// allowing the payload to be updated in-place. + /// + /// # Safety: + /// + /// Note that `make_payload_fn` and `update_payload_fn` are only invoked + /// with valid values from `values`, not for the `NULL` value. 
+ pub fn insert_or_update( + &mut self, + values: &ArrayRef, + make_payload_fn: MP, + update_payload_fn: UP, + ) where + MP: FnMut(Option<&[u8]>) -> V, + UP: FnMut(&mut V), + { + // Check the output type and dispatch to the appropriate internal function + match self.output_type { + OutputType::BinaryView => { + assert!(matches!(values.data_type(), DataType::BinaryView)); + self.insert_or_update_inner::( + values, + make_payload_fn, + update_payload_fn, + ) + } + OutputType::Utf8View => { + assert!(matches!(values.data_type(), DataType::Utf8View)); + self.insert_or_update_inner::( + values, + make_payload_fn, + update_payload_fn, + ) + } + _ => unreachable!("Utf8/Binary should use `ArrowBytesMap`"), + }; + } + + /// Generic version of [`Self::insert_or_update`] that handles `ByteViewType` + /// (both StringView and BinaryView). + /// + /// This is the only function that is generic on [`ByteViewType`], which avoids having + /// to template the entire structure, simplifying the code and reducing code bloat due + /// to duplication. + /// + /// See comments on `insert_or_update` for more details. + fn insert_or_update_inner( + &mut self, + values: &ArrayRef, + mut make_payload_fn: MP, + mut update_payload_fn: UP, + ) where + MP: FnMut(Option<&[u8]>) -> V, + UP: FnMut(&mut V), + B: ByteViewType, + { + // step 1: compute hashes + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(values.len(), 0); + create_hashes(&[values.clone()], &self.random_state, batch_hashes) + // hash is supported for all types and create_hashes only + // returns errors for unsupported types + .unwrap(); + + // step 2: insert each value into the set, if not already present + let values = values.as_byte_view::(); + + // Ensure lengths are equivalent + assert_eq!(values.len(), batch_hashes.len()); + + for (value, &hash) in values.iter().zip(batch_hashes.iter()) { + // Handle null value + let Some(value) = value else { + let _payload = if let Some(&(payload, _)) = self.null.as_ref() { + payload + } else { + let payload = make_payload_fn(None); + let null_index = self.builder.len(); + self.builder.append_null(); + self.null = Some((payload, null_index)); + payload + }; + continue; + }; + + let value: &[u8] = value.as_ref(); + + let entry = self.map.get_mut(hash, |header| { + let v = self.builder.get_value(header.view_idx); + + if v.len() != value.len() { + return false; + } + + v == value + }); + + if let Some(entry) = entry { + update_payload_fn(&mut entry.payload); + } else { + // no existing value, make a new one. + let payload = make_payload_fn(Some(value)); + + let inner_view_idx = self.builder.len(); + let new_header = Entry { + view_idx: inner_view_idx, + hash, + payload, + }; + + self.builder.append_value(value); + + self.map + .insert_accounted(new_header, |h| h.hash, &mut self.map_size); + }; + } + } + + /// Generic version of [`Self::get_payloads`] that handles `ByteViewType` + /// (both StringView and BinaryView). + /// + /// This function computes the hashes for each value and retrieves the payloads + /// stored in the map, leveraging small value optimizations when possible. + /// + /// # Arguments: + /// + /// `values`: The array whose payloads are being retrieved. + /// + /// # Returns + /// + /// A vector of payloads for each value, or `None` if the value is not found. + /// + /// # Safety: + /// + /// This function ensures that small values are handled using inline optimization + /// and larger values are safely retrieved from the builder. 
+ fn get_payloads_inner(self, values: &ArrayRef) -> Vec> + where + B: ByteViewType, + { + // Step 1: Compute hashes + let mut batch_hashes = vec![0u64; values.len()]; + create_hashes(&[values.clone()], &self.random_state, &mut batch_hashes).unwrap(); // Compute the hashes for the values + + // Step 2: Get payloads for each value + let values = values.as_byte_view::(); + assert_eq!(values.len(), batch_hashes.len()); // Ensure hash count matches value count + + let mut payloads = Vec::with_capacity(values.len()); + + for (value, &hash) in values.iter().zip(batch_hashes.iter()) { + // Handle null value + let Some(value) = value else { + if let Some(&(payload, _)) = self.null.as_ref() { + payloads.push(Some(payload)); + } else { + payloads.push(None); + } + continue; + }; + + let value: &[u8] = value.as_ref(); + + let entry = self.map.get(hash, |header| { + let v = self.builder.get_value(header.view_idx); + v.len() == value.len() && v == value + }); + + let payload = entry.map(|e| e.payload); + payloads.push(payload); + } + + payloads + } + + /// Retrieves the payloads for each value from `values`, either by using + /// small value optimizations or larger value handling. + /// + /// This function will compute hashes for each value and attempt to retrieve + /// the corresponding payload from the map. If the value is not found, it will return `None`. + /// + /// # Arguments: + /// + /// `values`: The array whose payloads need to be retrieved. + /// + /// # Returns + /// + /// A vector of payloads for each value, or `None` if the value is not found. + pub fn get_payloads(self, values: &ArrayRef) -> Vec> { + match self.output_type { + OutputType::BinaryView => { + assert!(matches!(values.data_type(), DataType::BinaryView)); + self.get_payloads_inner::(values) + } + OutputType::Utf8View => { + assert!(matches!(values.data_type(), DataType::Utf8View)); + self.get_payloads_inner::(values) + } + _ => unreachable!("Utf8/Binary should use `ArrowBytesMap`"), + } + } + /// Converts this set into a `StringViewArray`, or `BinaryViewArray`, /// containing each distinct value /// that was inserted. This is done without copying the values. 
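
`ArrowBytesViewMap` gains the same pair of entry points, so `Utf8View`/`BinaryView` inputs can be counted without materializing the strings. A short sketch analogous to the byte-map example above, also showing that counts accumulate across repeated calls and that probing with values never inserted yields `None`; the wrapper function is illustrative:

    use std::sync::Arc;

    use arrow::array::{ArrayRef, StringViewArray};
    use datafusion_physical_expr_common::binary_map::OutputType;
    use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap;

    fn count_string_views_sketch() {
        let mut map: ArrowBytesViewMap<u8> = ArrowBytesViewMap::new(OutputType::Utf8View);

        let batch1: ArrayRef = Arc::new(StringViewArray::from(vec!["🔥", "a"]));
        let batch2: ArrayRef = Arc::new(StringViewArray::from(vec!["🔥", "b"]));

        // one call per input batch; counts keep accumulating
        map.insert_or_update(&batch1, |_| 1u8, |count| *count += 1);
        map.insert_or_update(&batch2, |_| 1u8, |count| *count += 1);

        // probe with a fresh array; "zzz" was never inserted, so it comes back as None
        let probe: ArrayRef = Arc::new(StringViewArray::from(vec!["🔥", "a", "b", "zzz"]));
        assert_eq!(map.get_payloads(&probe), vec![Some(2), Some(1), Some(1), None]);
    }
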
@@ -581,6 +796,107 @@ mod tests { assert_eq!(set.len(), 10); } + #[test] + fn test_insert_or_update_count_u8() { + let values = GenericByteViewArray::from(vec![ + Some("a"), + Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"), + Some("🔥"), + Some("✨✨✨"), + Some("foobarbaz"), + Some("🔥"), + Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"), + ]); + + let mut map: ArrowBytesViewMap = ArrowBytesViewMap::new(OutputType::Utf8View); + let arr: ArrayRef = Arc::new(values); + + map.insert_or_update( + &arr, + |_| 1u8, + |count| { + *count += 1; + }, + ); + + let expected_counts = [ + ("a", 1), + ("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥", 2), + ("🔥", 2), + ("✨✨✨", 1), + ("foobarbaz", 1), + ]; + + for value in expected_counts.iter() { + let string_array = GenericByteViewArray::from(vec![Some(value.0)]); + let arr: ArrayRef = Arc::new(string_array); + + let mut result_payload: Option = None; + + map.insert_or_update( + &arr, + |_| { + panic!("Unexpected new entry during verification"); + }, + |count| { + result_payload = Some(*count); + }, + ); + + assert_eq!(result_payload.unwrap(), value.1); + } + } + + #[test] + fn test_get_payloads_u8() { + let values = GenericByteViewArray::from(vec![ + Some("A"), + Some("bcdefghijklmnop"), + Some("X"), + Some("Y"), + None, + Some("qrstuvqxyzhjwya"), + Some("✨🔥"), + Some("🔥"), + Some("🔥🔥🔥🔥🔥🔥"), + Some("A"), // Duplicate to test the count increment + Some("Y"), // Another duplicate to test the count increment + ]); + + let mut map: ArrowBytesViewMap = ArrowBytesViewMap::new(OutputType::Utf8View); + let arr: ArrayRef = Arc::new(values); + + map.insert_or_update( + &arr, + |_| 1u8, + |count| { + *count += 1; + }, + ); + + let expected_payloads = [ + Some(2u8), + Some(1u8), + Some(1u8), + Some(2u8), + Some(1u8), + Some(1u8), + Some(1u8), + Some(1u8), + Some(1u8), + Some(2u8), + Some(2u8), + ]; + + let payloads = map.get_payloads(&arr); + + assert_eq!(payloads.len(), expected_payloads.len()); + + for (i, payload) in payloads.iter().enumerate() { + assert_eq!(*payload, expected_payloads[i]); + } + } + #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)] struct TestPayload { // store the string value to check against input diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 576abe5c6f5a..89a742b3b1cb 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5921,5 +5921,80 @@ SELECT kurtosis_pop(c1) FROM t1; ---- 0.194323231917 +query I +SELECT mode(col) FROM VALUES (1), (2), (3), (2), (3), (3) as tab(col); +---- +3 + +query R +SELECT mode(col) FROM VALUES (1.0), (2.115), (3.225), (2.115), (3.225), (3.225) as tab(col); +---- +3.225 + +query D +SELECT mode(col) +FROM VALUES + (TO_DATE('2022-01-01')), + (TO_DATE('2022-01-02')), + (TO_DATE('2022-01-03')), + (TO_DATE('2022-01-02')), + (TO_DATE('2022-01-03')), + (TO_DATE('2022-01-03')) as tab(col); +---- +2022-01-03 + +query T +SELECT mode(col) +FROM VALUES + (NULL), + (NULL), + (NULL), + (NULL) as tab(col); +---- +NULL + +query T +SELECT mode(col) FROM VALUES + ('apple'), + ('banana'), + ('apple'), + ('orange'), + ('banana'), + ('apple') as tab(col); +---- +apple + +query T +SELECT mode(col) FROM VALUES + ('long_text_data_1'), + ('long_text_data_2'), + ('long_text_data_1'), + ('long_text_data_2'), + ('long_text_data_1') as tab(col); +---- +long_text_data_1 + +query T +SELECT mode(col) FROM VALUES + (arrow_cast('Andrew', 'Utf8View')), + (arrow_cast('Andrew', 'Utf8View')), + (arrow_cast('Xiangpeng', 'Utf8View')), + (arrow_cast('Andrew', 'Utf8View')), + 
(arrow_cast('Xiangpeng', 'Utf8View')), + (arrow_cast('Xiangpeng', 'Utf8View')), + (arrow_cast('Xiangpeng', 'Utf8View')) as tab(col); +---- +Xiangpeng + +query T +SELECT mode(col) FROM VALUES + (arrow_cast('apple', 'Utf8View')), + (arrow_cast('banana', 'Utf8View')), + (arrow_cast('apple', 'Utf8View')), + (NULL), + (arrow_cast('banana', 'Utf8View')) as tab(col); +---- +apple + statement ok DROP TABLE t1;
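
End to end, the new aggregate is reachable through SQL once the default registry includes `mode_udaf()` (see the `lib.rs` change above). A sketch of the first query above driven from Rust; `SessionContext`, `sql`, and `show` are the usual DataFusion driver APIs and are shown only as an illustration, assuming the `datafusion` crate and a Tokio runtime:

    use datafusion::error::Result;
    use datafusion::prelude::*;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();
        let df = ctx
            .sql("SELECT mode(col) FROM VALUES (1), (2), (3), (2), (3), (3) as tab(col)")
            .await?;
        df.show().await?; // prints 3, matching the sqllogictest result above
        Ok(())
    }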