diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs index f8383bbe3d2f..22991c4f2876 100644 --- a/arrow-select/src/take.rs +++ b/arrow-select/src/take.rs @@ -19,15 +19,14 @@ use std::sync::Arc; +use arrow_array::builder::BufferBuilder; +use arrow_array::types::*; use arrow_array::*; -use arrow_array::{builder::PrimitiveRunBuilder, types::*}; use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Field}; -use arrow_array::cast::{ - as_generic_binary_array, as_largestring_array, as_primitive_array, as_string_array, -}; +use arrow_array::cast::{as_generic_binary_array, as_largestring_array, as_string_array}; use num::{ToPrimitive, Zero}; /// Take elements by index from [Array], creating a new [Array] from those indexes. @@ -816,22 +815,14 @@ where Ok(DictionaryArray::::from(data)) } -macro_rules! primitive_run_take { - ($t:ty, $o:ty, $indices:ident, $value:ident) => { - take_primitive_run_values::<$o, $t>( - $indices, - as_primitive_array::<$t>($value.values()), - ) - }; -} - /// `take` implementation for run arrays /// /// Finds physical indices for the given logical indices and builds output run array -/// by taking values in the input run array at the physical indices. -/// for e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `indices=[2,7]` -/// would be converted to `physical_indices=[1,3]` which will be used to build -/// output `RunArray{ run_ends=[2], values=[2] }` +/// by taking values in the input run_array.values at the physical indices. +/// The output run array will be run encoded on the physical indices and not on output values. +/// For e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `logical_indices=[2,3,6,7]` +/// would be converted to `physical_indices=[1,1,3,3]` which will be used to build +/// output `RunArray{ run_ends=[2,4], values=[2,2] }`. fn take_run( run_array: &RunArray, logical_indices: &PrimitiveArray, @@ -842,43 +833,60 @@ where I: ArrowPrimitiveType, I::Native: ToPrimitive, { - match run_array.data_type() { - DataType::RunEndEncoded(_, fl) => { - let physical_indices = - run_array.get_physical_indices(logical_indices.values())?; - - downcast_primitive! { - fl.data_type() => (primitive_run_take, T, physical_indices, run_array), - dt => Err(ArrowError::NotYetImplemented(format!("take_run is not implemented for {dt:?}"))) - } + // get physical indices for the input logical indices + let physical_indices = run_array.get_physical_indices(logical_indices.values())?; + + // Run encode the physical indices into new_run_ends_builder + // Keep track of the physical indices to take in take_value_indices + // `unwrap` is used in this function because the unwrapped values are bounded by the corresponding `::Native`. + let mut new_run_ends_builder = BufferBuilder::::new(1); + let mut take_value_indices = BufferBuilder::::new(1); + let mut new_physical_len = 1; + for ix in 1..physical_indices.len() { + if physical_indices[ix] != physical_indices[ix - 1] { + take_value_indices + .append(I::Native::from_usize(physical_indices[ix - 1]).unwrap()); + new_run_ends_builder.append(T::Native::from_usize(ix).unwrap()); + new_physical_len += 1; } - dt => Err(ArrowError::InvalidArgumentError(format!( - "Expected DataType::RunEndEncoded found {dt:?}" - ))), } -} + take_value_indices.append( + I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap(), + ); + new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap()); + let new_run_ends = unsafe { + // Safety: + // The function builds a valid run_ends array and hence need not be validated. + ArrayDataBuilder::new(T::DATA_TYPE) + .len(new_physical_len) + .null_count(0) + .add_buffer(new_run_ends_builder.finish()) + .build_unchecked() + }; -// Builds a `RunArray` by taking values from given array for the given indices. -fn take_primitive_run_values( - physical_indices: Vec, - values: &PrimitiveArray, -) -> Result, ArrowError> -where - R: RunEndIndexType, - V: ArrowPrimitiveType, -{ - let mut builder = PrimitiveRunBuilder::::new(); - let values_len = values.len(); - for ix in physical_indices { - if ix >= values_len { - return Err(ArrowError::InvalidArgumentError("The requested index {ix} is out of bounds for values array with length {values_len}".to_string())); - } else if values.is_null(ix) { - builder.append_null() - } else { - builder.append_value(values.value(ix)) - } - } - Ok(builder.finish()) + let take_value_indices: PrimitiveArray = unsafe { + // Safety: + // The function builds a valid take_value_indices array and hence need not be validated. + ArrayDataBuilder::new(I::DATA_TYPE) + .len(new_physical_len) + .null_count(0) + .add_buffer(take_value_indices.finish()) + .build_unchecked() + .into() + }; + + let new_values = take(run_array.values(), &take_value_indices, None)?; + + let builder = ArrayDataBuilder::new(run_array.data_type().clone()) + .len(physical_indices.len()) + .add_child_data(new_run_ends) + .add_child_data(new_values.into_data()); + let array_data = unsafe { + // Safety: + // This function builds a valid run array and hence can skip validation. + builder.build_unchecked() + }; + Ok(array_data.into()) } /// Takes/filters a list array's inner data using the offsets of the list array. @@ -983,7 +991,7 @@ where #[cfg(test)] mod tests { use super::*; - use arrow_array::builder::*; + use arrow_array::{builder::*, cast::as_primitive_array}; use arrow_schema::TimeUnit; fn test_take_decimal_arrays( @@ -2159,24 +2167,24 @@ mod tests { #[test] fn test_take_runs() { - let logical_array: Vec = vec![1_i32, 1, 2, 2, 1, 1, 2, 2, 1, 1, 2, 2]; + let logical_array: Vec = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2]; let mut builder = PrimitiveRunBuilder::::new(); builder.extend(logical_array.into_iter().map(Some)); let run_array = builder.finish(); let take_indices: PrimitiveArray = - vec![2, 7, 10].into_iter().collect(); + vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect(); let take_out = take_run(&run_array, &take_indices).unwrap(); - assert_eq!(take_out.len(), 3); + assert_eq!(take_out.len(), 7); - assert_eq!(take_out.run_ends().len(), 1); - assert_eq!(take_out.run_ends().value(0), 3); + assert_eq!(take_out.run_ends().len(), 5); + assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]); let take_out_values = as_primitive_array::(take_out.values()); - assert_eq!(take_out_values.value(0), 2); + assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]); } #[test]