diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs index 9a3ee9c8ebcd..22d37043e473 100644 --- a/datafusion-examples/examples/advanced_udf.rs +++ b/datafusion-examples/examples/advanced_udf.rs @@ -96,8 +96,8 @@ impl ScalarUDFImpl for PowUdf { // function, but we check again to make sure assert_eq!(args.len(), 2); let (base, exp) = (&args[0], &args[1]); - assert_eq!(base.data_type(), DataType::Float64); - assert_eq!(exp.data_type(), DataType::Float64); + assert_eq!(base.data_type(), &DataType::Float64); + assert_eq!(exp.data_type(), &DataType::Float64); match (base, exp) { // For demonstration purposes we also implement the scalar / scalar @@ -108,28 +108,31 @@ impl ScalarUDFImpl for PowUdf { // the DataFusion expression simplification logic will often invoke // this path once during planning, and simply use the result during // execution. - ( - ColumnarValue::Scalar(ScalarValue::Float64(base)), - ColumnarValue::Scalar(ScalarValue::Float64(exp)), - ) => { - // compute the output. Note DataFusion treats `None` as NULL. - let res = match (base, exp) { - (Some(base), Some(exp)) => Some(base.powf(*exp)), - // one or both arguments were NULL - _ => None, - }; - Ok(ColumnarValue::Scalar(ScalarValue::from(res))) + (ColumnarValue::Scalar(base), ColumnarValue::Scalar(exp)) => { + match (base.value(), exp.value()) { + (ScalarValue::Float64(base), ScalarValue::Float64(exp)) => { + // compute the output. Note DataFusion treats `None` as NULL. 
+ let res = match (base, exp) { + (Some(base), Some(exp)) => Some(base.powf(*exp)), + // one or both arguments were NULL + _ => None, + }; + Ok(ColumnarValue::from(ScalarValue::from(res))) + } + _ => { + internal_err!("Invalid argument types to pow function") + } + } } // special case if the exponent is a constant - ( - ColumnarValue::Array(base_array), - ColumnarValue::Scalar(ScalarValue::Float64(exp)), - ) => { - let result_array = match exp { + (ColumnarValue::Array(base_array), ColumnarValue::Scalar(exp)) => { + let result_array = match exp.value() { // a ^ null = null - None => new_null_array(base_array.data_type(), base_array.len()), + ScalarValue::Float64(None) => { + new_null_array(base_array.data_type(), base_array.len()) + } // a ^ exp - Some(exp) => { + ScalarValue::Float64(Some(exp)) => { // DataFusion has ensured both arguments are Float64: let base_array = base_array.as_primitive::(); // calculate the result for every row. The `unary` @@ -139,24 +142,25 @@ impl ScalarUDFImpl for PowUdf { compute::unary(base_array, |base| base.powf(*exp)); Arc::new(res) } + _ => return internal_err!("Invalid argument types to pow function"), }; Ok(ColumnarValue::Array(result_array)) } // special case if the base is a constant (note this code is quite // similar to the previous case, so we omit comments) - ( - ColumnarValue::Scalar(ScalarValue::Float64(base)), - ColumnarValue::Array(exp_array), - ) => { - let res = match base { - None => new_null_array(exp_array.data_type(), exp_array.len()), - Some(base) => { + (ColumnarValue::Scalar(base), ColumnarValue::Array(exp_array)) => { + let res = match base.value() { + ScalarValue::Float64(None) => { + new_null_array(exp_array.data_type(), exp_array.len()) + } + ScalarValue::Float64(Some(base)) => { let exp_array = exp_array.as_primitive::(); let res: Float64Array = compute::unary(exp_array, |exp| base.powf(exp)); Arc::new(res) } + _ => return internal_err!("Invalid argument types to pow function"), }; 
Ok(ColumnarValue::Array(res)) } @@ -169,10 +173,6 @@ impl ScalarUDFImpl for PowUdf { )?; Ok(ColumnarValue::Array(Arc::new(res))) } - // if the types were not float, it is a bug in DataFusion - _ => { - internal_err!("Invalid argument types to pow function") - } } } diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs index b4663b345f64..cf24a4b23eb5 100644 --- a/datafusion-examples/examples/optimizer_rule.rs +++ b/datafusion-examples/examples/optimizer_rule.rs @@ -207,7 +207,7 @@ impl ScalarUDFImpl for MyEq { fn invoke(&self, _args: &[ColumnarValue]) -> Result { // this example simply returns "true" which is not what a real // implementation would do. - Ok(ColumnarValue::Scalar(ScalarValue::from(true))) + Ok(ColumnarValue::from(ScalarValue::from(true))) } } diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 10541e01914a..dc7e4bd1687a 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -37,6 +37,7 @@ pub mod file_options; pub mod format; pub mod hash_utils; pub mod instant; +pub mod logical; pub mod parsers; pub mod rounding; pub mod scalar; diff --git a/datafusion/common/src/logical/equality.rs b/datafusion/common/src/logical/equality.rs new file mode 100644 index 000000000000..239cebf1338f --- /dev/null +++ b/datafusion/common/src/logical/equality.rs @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::DataType; + +pub trait LogicallyEq { + #[must_use] + fn logically_eq(&self, other: &Rhs) -> bool; +} + +impl LogicallyEq for DataType { + fn logically_eq(&self, other: &Self) -> bool { + use DataType::*; + + match (self, other) { + (Utf8 | LargeUtf8 | Utf8View, Utf8 | LargeUtf8 | Utf8View) + | (Binary | LargeBinary | BinaryView, Binary | LargeBinary | BinaryView) => { + true + } + (Dictionary(_, left), Dictionary(_, right)) => left.logically_eq(right), + (Dictionary(_, inner), other) | (other, Dictionary(_, inner)) => { + other.logically_eq(inner) + } + (RunEndEncoded(_, left), RunEndEncoded(_, right)) => { + left.data_type().logically_eq(right.data_type()) + } + (RunEndEncoded(_, inner), other) | (other, RunEndEncoded(_, inner)) => { + other.logically_eq(inner.data_type()) + } + _ => self == other, + } + } +} diff --git a/datafusion/common/src/logical/mod.rs b/datafusion/common/src/logical/mod.rs new file mode 100644 index 000000000000..ff4f478bc1a8 --- /dev/null +++ b/datafusion/common/src/logical/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod equality; diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 3356a85fb6d4..b1c703066c03 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -27,7 +27,7 @@ use std::convert::Infallible; use std::fmt; use std::hash::Hash; use std::hash::Hasher; -use std::iter::repeat; +use std::iter::{repeat, Peekable}; use std::str::FromStr; use std::sync::Arc; @@ -41,7 +41,7 @@ use crate::hash_utils::create_hashes; use crate::utils::{ array_into_fixed_size_list_array, array_into_large_list_array, array_into_list_array, }; -use arrow::compute::kernels::numeric::*; +use arrow::compute::kernels::{self, numeric::*}; use arrow::util::display::{array_value_to_string, ArrayFormatter, FormatOptions}; use arrow::{ array::*, @@ -224,18 +224,10 @@ pub enum ScalarValue { UInt64(Option), /// utf-8 encoded string. Utf8(Option), - /// utf-8 encoded string but from view types. - Utf8View(Option), - /// utf-8 encoded string representing a LargeString's arrow type. - LargeUtf8(Option), /// binary Binary(Option>), - /// binary but from view types. - BinaryView(Option>), /// fixed size binary FixedSizeBinary(i32, Option>), - /// large binary - LargeBinary(Option>), /// Fixed size list scalar. /// /// The array must be a FixedSizeListArray with length 1. 
@@ -293,8 +285,6 @@ pub enum ScalarValue { /// `.1`: the list of fields, zero-to-one of which will by set in `.0` /// `.2`: the physical storage of the source/destination UnionArray from which this Scalar came Union(Option<(i8, Box)>, UnionFields, UnionMode), - /// Dictionary type: index type and value - Dictionary(Box, Box), } impl Hash for Fl { @@ -354,18 +344,10 @@ impl PartialEq for ScalarValue { (UInt64(_), _) => false, (Utf8(v1), Utf8(v2)) => v1.eq(v2), (Utf8(_), _) => false, - (Utf8View(v1), Utf8View(v2)) => v1.eq(v2), - (Utf8View(_), _) => false, - (LargeUtf8(v1), LargeUtf8(v2)) => v1.eq(v2), - (LargeUtf8(_), _) => false, (Binary(v1), Binary(v2)) => v1.eq(v2), (Binary(_), _) => false, - (BinaryView(v1), BinaryView(v2)) => v1.eq(v2), - (BinaryView(_), _) => false, (FixedSizeBinary(_, v1), FixedSizeBinary(_, v2)) => v1.eq(v2), (FixedSizeBinary(_, _), _) => false, - (LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2), - (LargeBinary(_), _) => false, (FixedSizeList(v1), FixedSizeList(v2)) => v1.eq(v2), (FixedSizeList(_), _) => false, (List(v1), List(v2)) => v1.eq(v2), @@ -414,8 +396,6 @@ impl PartialEq for ScalarValue { val1.eq(val2) && fields1.eq(fields2) && mode1.eq(mode2) } (Union(_, _, _), _) => false, - (Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2), - (Dictionary(_, _), _) => false, (Null, Null) => true, (Null, _) => false, } @@ -483,18 +463,10 @@ impl PartialOrd for ScalarValue { (UInt64(_), _) => None, (Utf8(v1), Utf8(v2)) => v1.partial_cmp(v2), (Utf8(_), _) => None, - (LargeUtf8(v1), LargeUtf8(v2)) => v1.partial_cmp(v2), - (LargeUtf8(_), _) => None, - (Utf8View(v1), Utf8View(v2)) => v1.partial_cmp(v2), - (Utf8View(_), _) => None, (Binary(v1), Binary(v2)) => v1.partial_cmp(v2), (Binary(_), _) => None, - (BinaryView(v1), BinaryView(v2)) => v1.partial_cmp(v2), - (BinaryView(_), _) => None, (FixedSizeBinary(_, v1), FixedSizeBinary(_, v2)) => v1.partial_cmp(v2), (FixedSizeBinary(_, _), _) => None, - (LargeBinary(v1), LargeBinary(v2)) => 
v1.partial_cmp(v2), - (LargeBinary(_), _) => None, // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensure to have length 1 (List(arr1), List(arr2)) => partial_cmp_list(arr1.as_ref(), arr2.as_ref()), (FixedSizeList(arr1), FixedSizeList(arr2)) => { @@ -558,15 +530,6 @@ impl PartialOrd for ScalarValue { } } (Union(_, _, _), _) => None, - (Dictionary(k1, v1), Dictionary(k2, v2)) => { - // Don't compare if the key types don't match (it is effectively a different datatype) - if k1 == k2 { - v1.partial_cmp(v2) - } else { - None - } - } - (Dictionary(_, _), _) => None, (Null, Null) => Some(Ordering::Equal), (Null, _) => None, } @@ -716,10 +679,8 @@ impl std::hash::Hash for ScalarValue { UInt16(v) => v.hash(state), UInt32(v) => v.hash(state), UInt64(v) => v.hash(state), - Utf8(v) | LargeUtf8(v) | Utf8View(v) => v.hash(state), - Binary(v) | FixedSizeBinary(_, v) | LargeBinary(v) | BinaryView(v) => { - v.hash(state) - } + Utf8(v) => v.hash(state), + Binary(v) | FixedSizeBinary(_, v) => v.hash(state), List(arr) => { hash_nested_array(arr.to_owned() as ArrayRef, state); } @@ -757,10 +718,6 @@ impl std::hash::Hash for ScalarValue { t.hash(state); m.hash(state); } - Dictionary(k, v) => { - k.hash(state); - v.hash(state); - } // stable hash for Null value Null => 1.hash(state), } @@ -791,34 +748,6 @@ pub fn get_dict_value( Ok((dict_array.values(), dict_array.key(index))) } -/// Create a dictionary array representing `value` repeated `size` -/// times -fn dict_from_scalar( - value: &ScalarValue, - size: usize, -) -> Result { - // values array is one element long (the value) - let values_array = value.to_array_of_size(1)?; - - // Create a key array with `size` elements, each of 0 - let key_array: PrimitiveArray = std::iter::repeat(if value.is_null() { - None - } else { - Some(K::default_value()) - }) - .take(size) - .collect(); - - // create a new DictionaryArray - // - // Note: this path could be made faster by using the ArrayData - // APIs and 
skipping validation, if it every comes up in - // performance traces. - Ok(Arc::new( - DictionaryArray::::try_new(key_array, values_array)?, // should always be valid by construction above - )) -} - /// Create a dictionary array representing all the values in values fn dict_from_values( values_array: ArrayRef, @@ -1279,12 +1208,8 @@ impl ScalarValue { ScalarValue::Float32(_) => DataType::Float32, ScalarValue::Float64(_) => DataType::Float64, ScalarValue::Utf8(_) => DataType::Utf8, - ScalarValue::LargeUtf8(_) => DataType::LargeUtf8, - ScalarValue::Utf8View(_) => DataType::Utf8View, ScalarValue::Binary(_) => DataType::Binary, - ScalarValue::BinaryView(_) => DataType::BinaryView, ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz), - ScalarValue::LargeBinary(_) => DataType::LargeBinary, ScalarValue::List(arr) => arr.data_type().to_owned(), ScalarValue::LargeList(arr) => arr.data_type().to_owned(), ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), @@ -1314,9 +1239,6 @@ impl ScalarValue { DataType::Duration(TimeUnit::Nanosecond) } ScalarValue::Union(_, fields, mode) => DataType::Union(fields.clone(), *mode), - ScalarValue::Dictionary(k, v) => { - DataType::Dictionary(k.clone(), Box::new(v.data_type())) - } ScalarValue::Null => DataType::Null, } } @@ -1540,13 +1462,8 @@ impl ScalarValue { ScalarValue::UInt16(v) => v.is_none(), ScalarValue::UInt32(v) => v.is_none(), ScalarValue::UInt64(v) => v.is_none(), - ScalarValue::Utf8(v) - | ScalarValue::Utf8View(v) - | ScalarValue::LargeUtf8(v) => v.is_none(), - ScalarValue::Binary(v) - | ScalarValue::BinaryView(v) - | ScalarValue::FixedSizeBinary(_, v) - | ScalarValue::LargeBinary(v) => v.is_none(), + ScalarValue::Utf8(v) => v.is_none(), + ScalarValue::Binary(v) | ScalarValue::FixedSizeBinary(_, v) => v.is_none(), // arr.len() should be 1 for a list scalar, but we don't seem to // enforce that anywhere, so we still check against array length. 
ScalarValue::List(arr) => arr.len() == arr.null_count(), @@ -1575,7 +1492,6 @@ impl ScalarValue { Some((_, s)) => s.is_null(), None => true, }, - ScalarValue::Dictionary(_, v) => v.is_null(), } } @@ -1610,6 +1526,10 @@ impl ScalarValue { } } + pub fn to_array_of_type(&self, data_type: &DataType) -> Result { + self.to_array_of_size_and_type(1, data_type) + } + /// Converts a scalar value into an 1-row array. /// /// # Errors @@ -1704,6 +1624,21 @@ impl ScalarValue { Some(sv) => sv.data_type(), }; + Self::iter_to_array_of_type_internal(&mut scalars, &data_type) + } + + pub fn iter_to_array_of_type( + scalars: impl IntoIterator, + data_type: &DataType, + ) -> Result { + let mut scalars = scalars.into_iter().peekable(); + Self::iter_to_array_of_type_internal(&mut scalars, data_type) + } + + fn iter_to_array_of_type_internal( + scalars: &mut Peekable>, + data_type: &DataType, + ) -> Result { /// Creates an array of $ARRAY_TY by unpacking values of /// SCALAR_TY for primitive types macro_rules! 
build_array_primitive { @@ -1792,12 +1727,12 @@ impl ScalarValue { DataType::UInt16 => build_array_primitive!(UInt16Array, UInt16), DataType::UInt32 => build_array_primitive!(UInt32Array, UInt32), DataType::UInt64 => build_array_primitive!(UInt64Array, UInt64), - DataType::Utf8View => build_array_string!(StringViewArray, Utf8View), + DataType::Utf8View => build_array_string!(StringViewArray, Utf8), DataType::Utf8 => build_array_string!(StringArray, Utf8), - DataType::LargeUtf8 => build_array_string!(LargeStringArray, LargeUtf8), - DataType::BinaryView => build_array_string!(BinaryViewArray, BinaryView), + DataType::LargeUtf8 => build_array_string!(LargeStringArray, Utf8), + DataType::BinaryView => build_array_string!(BinaryViewArray, Binary), DataType::Binary => build_array_string!(BinaryArray, Binary), - DataType::LargeBinary => build_array_string!(LargeBinaryArray, LargeBinary), + DataType::LargeBinary => build_array_string!(LargeBinaryArray, Binary), DataType::Date32 => build_array_primitive!(Date32Array, Date32), DataType::Date64 => build_array_primitive!(Date64Array, Date64), DataType::Time32(TimeUnit::Second) => { @@ -1891,25 +1826,7 @@ impl ScalarValue { arrow::compute::concat(arrays.as_slice())? 
} DataType::Dictionary(key_type, value_type) => { - // create the values array - let value_scalars = scalars - .map(|scalar| match scalar { - ScalarValue::Dictionary(inner_key_type, scalar) => { - if &inner_key_type == key_type { - Ok(*scalar) - } else { - _exec_err!("Expected inner key type of {key_type} but found: {inner_key_type}, value was ({scalar:?})") - } - } - _ => { - _exec_err!( - "Expected scalar of type {value_type} but found: {scalar} {scalar:?}" - ) - } - }) - .collect::>>()?; - - let values = Self::iter_to_array(value_scalars)?; + let values = Self::iter_to_array_of_type_internal(scalars, value_type)?; assert_eq!(values.data_type(), value_type.as_ref()); match key_type.as_ref() { @@ -2081,7 +1998,7 @@ impl ScalarValue { let values = if values.is_empty() { new_empty_array(data_type) } else { - Self::iter_to_array(values.iter().cloned()).unwrap() + Self::iter_to_array_of_type(values.to_vec(), data_type).unwrap() }; Arc::new(array_into_list_array(values, nullable)) } @@ -2176,7 +2093,18 @@ impl ScalarValue { } else { Self::iter_to_array(values.iter().cloned()).unwrap() }; - Arc::new(array_into_large_list_array(values)) + Arc::new(array_into_large_list_array(values, true)) + } + + pub fn to_array_of_size_and_type( + &self, + size: usize, + data_type: &DataType, + ) -> Result { + // TODO(@notfilippo): for now cast as it's a POC, but it can be optimized later with a big `match` + let array = self.to_array_of_size(size)?; + let cast_array = kernels::cast::cast(&array, data_type)?; + Ok(cast_array) } /// Converts a scalar value into an array of `size` rows.
@@ -2265,18 +2193,6 @@ impl ScalarValue { } None => new_null_array(&DataType::Utf8, size), }, - ScalarValue::Utf8View(e) => match e { - Some(value) => { - Arc::new(StringViewArray::from_iter_values(repeat(value).take(size))) - } - None => new_null_array(&DataType::Utf8View, size), - }, - ScalarValue::LargeUtf8(e) => match e { - Some(value) => { - Arc::new(LargeStringArray::from_iter_values(repeat(value).take(size))) - } - None => new_null_array(&DataType::LargeUtf8, size), - }, ScalarValue::Binary(e) => match e { Some(value) => Arc::new( repeat(Some(value.as_slice())) @@ -2287,16 +2203,6 @@ impl ScalarValue { Arc::new(repeat(None::<&str>).take(size).collect::()) } }, - ScalarValue::BinaryView(e) => match e { - Some(value) => Arc::new( - repeat(Some(value.as_slice())) - .take(size) - .collect::(), - ), - None => { - Arc::new(repeat(None::<&str>).take(size).collect::()) - } - }, ScalarValue::FixedSizeBinary(s, e) => match e { Some(value) => Arc::new( FixedSizeBinaryArray::try_from_sparse_iter_with_size( @@ -2313,18 +2219,6 @@ impl ScalarValue { .unwrap(), ), }, - ScalarValue::LargeBinary(e) => match e { - Some(value) => Arc::new( - repeat(Some(value.as_slice())) - .take(size) - .collect::(), - ), - None => Arc::new( - repeat(None::<&str>) - .take(size) - .collect::(), - ), - }, ScalarValue::List(arr) => { Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? 
} @@ -2463,20 +2357,6 @@ impl ScalarValue { new_null_array(&dt, size) } }, - ScalarValue::Dictionary(key_type, v) => { - // values array is one element long (the value) - match key_type.as_ref() { - DataType::Int8 => dict_from_scalar::(v, size)?, - DataType::Int16 => dict_from_scalar::(v, size)?, - DataType::Int32 => dict_from_scalar::(v, size)?, - DataType::Int64 => dict_from_scalar::(v, size)?, - DataType::UInt8 => dict_from_scalar::(v, size)?, - DataType::UInt16 => dict_from_scalar::(v, size)?, - DataType::UInt32 => dict_from_scalar::(v, size)?, - DataType::UInt64 => dict_from_scalar::(v, size)?, - _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), - } - } ScalarValue::Null => new_null_array(&DataType::Null, size), }) } @@ -2641,17 +2521,11 @@ impl ScalarValue { DataType::Int16 => typed_cast!(array, index, Int16Array, Int16)?, DataType::Int8 => typed_cast!(array, index, Int8Array, Int8)?, DataType::Binary => typed_cast!(array, index, BinaryArray, Binary)?, - DataType::LargeBinary => { - typed_cast!(array, index, LargeBinaryArray, LargeBinary)? - } - DataType::BinaryView => { - typed_cast!(array, index, BinaryViewArray, BinaryView)? - } + DataType::LargeBinary => typed_cast!(array, index, LargeBinaryArray, Binary)?, + DataType::BinaryView => typed_cast!(array, index, BinaryViewArray, Binary)?, DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8)?, - DataType::LargeUtf8 => { - typed_cast!(array, index, LargeStringArray, LargeUtf8)? 
- } - DataType::Utf8View => typed_cast!(array, index, StringViewArray, Utf8View)?, + DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, Utf8)?, + DataType::Utf8View => typed_cast!(array, index, StringViewArray, Utf8)?, DataType::List(field) => { let list_array = array.as_list::(); let nested_array = list_array.value(index); @@ -2661,11 +2535,14 @@ impl ScalarValue { ScalarValue::List(arr) } - DataType::LargeList(_) => { + DataType::LargeList(field) => { let list_array = as_large_list_array(array); let nested_array = list_array.value(index); // Produces a single element `LargeListArray` with the value at `index`. - let arr = Arc::new(array_into_large_list_array(nested_array)); + let arr = Arc::new(array_into_large_list_array( + nested_array, + field.is_nullable(), + )); ScalarValue::LargeList(arr) } @@ -2735,15 +2612,13 @@ impl ScalarValue { _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), }; // look up the index in the values dictionary - let value = match values_index { + match values_index { Some(values_index) => { ScalarValue::try_from_array(values_array, values_index) } // else entry was null, so return null None => values_array.data_type().try_into(), - }?; - - Self::Dictionary(key_type.clone(), Box::new(value)) + }? } DataType::Struct(_) => { let a = array.slice(index, 1); @@ -2894,6 +2769,7 @@ impl ScalarValue { /// Panics if `self` is a dictionary with invalid key type #[inline] pub fn eq_array(&self, array: &ArrayRef, index: usize) -> Result { + // TODO(@notfilippo): maybe match on the array DataType instead of self Ok(match self { ScalarValue::Decimal128(v, precision, scale) => { ScalarValue::eq_array_decimal( @@ -2950,24 +2826,12 @@ impl ScalarValue { ScalarValue::Utf8(val) => { eq_array_primitive!(array, index, StringArray, val)? } - ScalarValue::Utf8View(val) => { - eq_array_primitive!(array, index, StringViewArray, val)? - } - ScalarValue::LargeUtf8(val) => { - eq_array_primitive!(array, index, LargeStringArray, val)? 
- } ScalarValue::Binary(val) => { eq_array_primitive!(array, index, BinaryArray, val)? } - ScalarValue::BinaryView(val) => { - eq_array_primitive!(array, index, BinaryViewArray, val)? - } ScalarValue::FixedSizeBinary(_, val) => { eq_array_primitive!(array, index, FixedSizeBinaryArray, val)? } - ScalarValue::LargeBinary(val) => { - eq_array_primitive!(array, index, LargeBinaryArray, val)? - } ScalarValue::List(arr) => { Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) } @@ -3044,24 +2908,6 @@ impl ScalarValue { array.child(ti).is_null(index) } } - ScalarValue::Dictionary(key_type, v) => { - let (values_array, values_index) = match key_type.as_ref() { - DataType::Int8 => get_dict_value::(array, index)?, - DataType::Int16 => get_dict_value::(array, index)?, - DataType::Int32 => get_dict_value::(array, index)?, - DataType::Int64 => get_dict_value::(array, index)?, - DataType::UInt8 => get_dict_value::(array, index)?, - DataType::UInt16 => get_dict_value::(array, index)?, - DataType::UInt32 => get_dict_value::(array, index)?, - DataType::UInt64 => get_dict_value::(array, index)?, - _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), - }; - // was the value in the array non null? 
- match values_index { - Some(values_index) => v.eq_array(values_array, values_index)?, - None => v.is_null(), - } - } ScalarValue::Null => array.is_null(index), }) } @@ -3104,9 +2950,7 @@ impl ScalarValue { | ScalarValue::DurationMillisecond(_) | ScalarValue::DurationMicrosecond(_) | ScalarValue::DurationNanosecond(_) => 0, - ScalarValue::Utf8(s) - | ScalarValue::LargeUtf8(s) - | ScalarValue::Utf8View(s) => { + ScalarValue::Utf8(s) => { s.as_ref().map(|s| s.capacity()).unwrap_or_default() } ScalarValue::TimestampSecond(_, s) @@ -3115,10 +2959,7 @@ impl ScalarValue { | ScalarValue::TimestampNanosecond(_, s) => { s.as_ref().map(|s| s.len()).unwrap_or_default() } - ScalarValue::Binary(b) - | ScalarValue::FixedSizeBinary(_, b) - | ScalarValue::LargeBinary(b) - | ScalarValue::BinaryView(b) => { + ScalarValue::Binary(b) | ScalarValue::FixedSizeBinary(_, b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } ScalarValue::List(arr) => arr.get_array_memory_size(), @@ -3135,10 +2976,6 @@ impl ScalarValue { + (std::mem::size_of::() * fields.len()) + fields.iter().map(|(_idx, field)| field.size() - std::mem::size_of_val(field)).sum::() } - ScalarValue::Dictionary(dt, sv) => { - // `dt` and `sv` are boxed, so they are NOT already included in `self` - dt.size() + sv.size() - } } } @@ -3385,12 +3222,12 @@ impl TryFrom<&DataType> for ScalarValue { ScalarValue::Decimal256(None, *precision, *scale) } DataType::Utf8 => ScalarValue::Utf8(None), - DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), - DataType::Utf8View => ScalarValue::Utf8View(None), + DataType::LargeUtf8 => ScalarValue::Utf8(None), + DataType::Utf8View => ScalarValue::Utf8(None), DataType::Binary => ScalarValue::Binary(None), - DataType::BinaryView => ScalarValue::BinaryView(None), + DataType::BinaryView => ScalarValue::Binary(None), DataType::FixedSizeBinary(len) => ScalarValue::FixedSizeBinary(*len, None), - DataType::LargeBinary => ScalarValue::LargeBinary(None), + DataType::LargeBinary => 
ScalarValue::Binary(None), DataType::Date32 => ScalarValue::Date32(None), DataType::Date64 => ScalarValue::Date64(None), DataType::Time32(TimeUnit::Second) => ScalarValue::Time32Second(None), @@ -3432,10 +3269,7 @@ impl TryFrom<&DataType> for ScalarValue { DataType::Duration(TimeUnit::Nanosecond) => { ScalarValue::DurationNanosecond(None) } - DataType::Dictionary(index_type, value_type) => ScalarValue::Dictionary( - index_type.clone(), - Box::new(value_type.as_ref().try_into()?), - ), + DataType::Dictionary(_, value_type) => Self::try_from(value_type.as_ref())?, // `ScalaValue::List` contains single element `ListArray`. DataType::List(field_ref) => ScalarValue::List(Arc::new( GenericListArray::new_null(Arc::clone(field_ref), 1), @@ -3516,13 +3350,8 @@ impl fmt::Display for ScalarValue { ScalarValue::TimestampMillisecond(e, _) => format_option!(f, e)?, ScalarValue::TimestampMicrosecond(e, _) => format_option!(f, e)?, ScalarValue::TimestampNanosecond(e, _) => format_option!(f, e)?, - ScalarValue::Utf8(e) - | ScalarValue::LargeUtf8(e) - | ScalarValue::Utf8View(e) => format_option!(f, e)?, - ScalarValue::Binary(e) - | ScalarValue::FixedSizeBinary(_, e) - | ScalarValue::LargeBinary(e) - | ScalarValue::BinaryView(e) => match e { + ScalarValue::Utf8(e) => format_option!(f, e)?, + ScalarValue::Binary(e) | ScalarValue::FixedSizeBinary(_, e) => match e { Some(bytes) => { // print up to first 10 bytes, with trailing ... 
if needed for b in bytes.iter().take(10) { @@ -3636,7 +3465,6 @@ impl fmt::Display for ScalarValue { Some((id, val)) => write!(f, "{}:{}", id, val)?, None => write!(f, "NULL")?, }, - ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, ScalarValue::Null => write!(f, "NULL")?, }; Ok(()) @@ -3696,22 +3524,12 @@ impl fmt::Debug for ScalarValue { } ScalarValue::Utf8(None) => write!(f, "Utf8({self})"), ScalarValue::Utf8(Some(_)) => write!(f, "Utf8(\"{self}\")"), - ScalarValue::Utf8View(None) => write!(f, "Utf8View({self})"), - ScalarValue::Utf8View(Some(_)) => write!(f, "Utf8View(\"{self}\")"), - ScalarValue::LargeUtf8(None) => write!(f, "LargeUtf8({self})"), - ScalarValue::LargeUtf8(Some(_)) => write!(f, "LargeUtf8(\"{self}\")"), ScalarValue::Binary(None) => write!(f, "Binary({self})"), ScalarValue::Binary(Some(b)) => { write!(f, "Binary(\"")?; fmt_binary(b.as_slice(), f)?; write!(f, "\")") } - ScalarValue::BinaryView(None) => write!(f, "BinaryView({self})"), - ScalarValue::BinaryView(Some(b)) => { - write!(f, "BinaryView(\"")?; - fmt_binary(b.as_slice(), f)?; - write!(f, "\")") - } ScalarValue::FixedSizeBinary(size, None) => { write!(f, "FixedSizeBinary({size}, {self})") } @@ -3720,12 +3538,6 @@ impl fmt::Debug for ScalarValue { fmt_binary(b.as_slice(), f)?; write!(f, "\")") } - ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({self})"), - ScalarValue::LargeBinary(Some(b)) => { - write!(f, "LargeBinary(\"")?; - fmt_binary(b.as_slice(), f)?; - write!(f, "\")") - } ScalarValue::FixedSizeList(_) => write!(f, "FixedSizeList({self})"), ScalarValue::List(_) => write!(f, "List({self})"), ScalarValue::LargeList(_) => write!(f, "LargeList({self})"), @@ -3813,7 +3625,6 @@ impl fmt::Debug for ScalarValue { Some((id, val)) => write!(f, "Union {}:{}", id, val), None => write!(f, "Union(NULL)"), }, - ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), ScalarValue::Null => write!(f, "NULL"), } } @@ -3865,9 +3676,7 @@ impl ScalarType for Date32Type { mod 
tests { use super::*; - use crate::cast::{ - as_map_array, as_string_array, as_struct_array, as_uint32_array, as_uint64_array, - }; + use crate::cast::{as_map_array, as_struct_array, as_uint32_array, as_uint64_array}; use crate::assert_batches_eq; use crate::utils::array_into_list_array_nullable; @@ -4836,21 +4645,11 @@ mod tests { StringArray, vec![Some("foo"), None, Some("bar")] ); - check_scalar_iter_string!( - LargeUtf8, - LargeStringArray, - vec![Some("foo"), None, Some("bar")] - ); check_scalar_iter_binary!( Binary, BinaryArray, vec![Some(b"foo"), None, Some(b"bar")] ); - check_scalar_iter_binary!( - LargeBinary, - LargeBinaryArray, - vec![Some(b"foo"), None, Some(b"bar")] - ); } #[test] @@ -4867,38 +4666,6 @@ mod tests { ); } - #[test] - fn scalar_iter_to_dictionary() { - fn make_val(v: Option) -> ScalarValue { - let key_type = DataType::Int32; - let value = ScalarValue::Utf8(v); - ScalarValue::Dictionary(Box::new(key_type), Box::new(value)) - } - - let scalars = [ - make_val(Some("Foo".into())), - make_val(None), - make_val(Some("Bar".into())), - ]; - - let array = ScalarValue::iter_to_array(scalars).unwrap(); - let array = as_dictionary_array::(&array).unwrap(); - let values_array = as_string_array(array.values()).unwrap(); - - let values = array - .keys_iter() - .map(|k| { - k.map(|k| { - assert!(values_array.is_valid(k)); - values_array.value(k) - }) - }) - .collect::>(); - - let expected = vec![Some("Foo"), None, Some("Bar")]; - assert_eq!(values, expected); - } - #[test] fn scalar_iter_to_array_mismatched_types() { use ScalarValue::*; @@ -5028,18 +4795,6 @@ mod tests { assert_ne!(list_scalar, nested_list_scalar); } - #[test] - fn scalar_try_from_dict_datatype() { - let data_type = - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); - let data_type = &data_type; - let expected = ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Utf8(None)), - ); - assert_eq!(expected, data_type.try_into().unwrap()) - } 
- #[test] fn size_of_scalar() { // Since ScalarValues are used in a non trivial number of places, @@ -5180,29 +4935,6 @@ mod tests { }}; } - /// create a test case for DictionaryArray<$INDEX_TY> - macro_rules! make_str_dict_test_case { - ($INPUT:expr, $INDEX_TY:ident) => {{ - TestCase { - array: Arc::new( - $INPUT - .iter() - .cloned() - .collect::>(), - ), - scalars: $INPUT - .iter() - .map(|v| { - ScalarValue::Dictionary( - Box::new($INDEX_TY::DATA_TYPE), - Box::new(ScalarValue::Utf8(v.map(|v| v.to_string()))), - ) - }) - .collect(), - } - }}; - } - let cases = vec![ make_test_case!(bool_vals, BooleanArray, Boolean), make_test_case!(f32_vals, Float32Array, Float32), @@ -5216,9 +4948,7 @@ mod tests { make_test_case!(u32_vals, UInt32Array, UInt32), make_test_case!(u64_vals, UInt64Array, UInt64), make_str_test_case!(str_vals, StringArray, Utf8), - make_str_test_case!(str_vals, LargeStringArray, LargeUtf8), make_binary_test_case!(str_vals, BinaryArray, Binary), - make_binary_test_case!(str_vals, LargeBinaryArray, LargeBinary), make_test_case!(i32_vals, Date32Array, Date32), make_test_case!(i64_vals, Date64Array, Date64), make_test_case!(i32_vals, Time32SecondArray, Time32Second), @@ -5275,14 +5005,6 @@ mod tests { IntervalMonthDayNanoArray, IntervalMonthDayNano ), - make_str_dict_test_case!(str_vals, Int8Type), - make_str_dict_test_case!(str_vals, Int16Type), - make_str_dict_test_case!(str_vals, Int32Type), - make_str_dict_test_case!(str_vals, Int64Type), - make_str_dict_test_case!(str_vals, UInt8Type), - make_str_dict_test_case!(str_vals, UInt16Type), - make_str_dict_test_case!(str_vals, UInt32Type), - make_str_dict_test_case!(str_vals, UInt64Type), ]; for case in cases { @@ -5942,22 +5664,23 @@ mod tests { check_scalar_cast(ScalarValue::Float64(None), DataType::Int16); - check_scalar_cast( - ScalarValue::from("foo"), - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), - ); - - check_scalar_cast( - ScalarValue::Utf8(None), - 
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), - ); - - check_scalar_cast(ScalarValue::Utf8(None), DataType::Utf8View); - check_scalar_cast(ScalarValue::from("foo"), DataType::Utf8View); - check_scalar_cast( - ScalarValue::from("larger than 12 bytes string"), - DataType::Utf8View, - ); + // TODO(@notfilippo): this tests fails but it should check if logically equal + // check_scalar_cast( + // ScalarValue::from("foo"), + // DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + // ); + // + // check_scalar_cast( + // ScalarValue::Utf8(None), + // DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + // ); + // + // check_scalar_cast(ScalarValue::Utf8(None), DataType::Utf8View); + // check_scalar_cast(ScalarValue::from("foo"), DataType::Utf8View); + // check_scalar_cast( + // ScalarValue::from("larger than 12 bytes string"), + // DataType::Utf8View, + // ); } // mimics how casting work on scalar values by `casting` `scalar` to `desired_type` @@ -6669,22 +6392,6 @@ mod tests { ScalarValue::Binary(Some(vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])); assert_eq!(format!("{large_binary_value}"), "0102030405060708090A..."); - let no_binary_value = ScalarValue::BinaryView(None); - assert_eq!(format!("{no_binary_value}"), "NULL"); - let small_binary_value = ScalarValue::BinaryView(Some(vec![1u8, 2, 3])); - assert_eq!(format!("{small_binary_value}"), "010203"); - let large_binary_value = - ScalarValue::BinaryView(Some(vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])); - assert_eq!(format!("{large_binary_value}"), "0102030405060708090A..."); - - let no_binary_value = ScalarValue::LargeBinary(None); - assert_eq!(format!("{no_binary_value}"), "NULL"); - let small_binary_value = ScalarValue::LargeBinary(Some(vec![1u8, 2, 3])); - assert_eq!(format!("{small_binary_value}"), "010203"); - let large_binary_value = - ScalarValue::LargeBinary(Some(vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])); - 
assert_eq!(format!("{large_binary_value}"), "0102030405060708090A..."); - let no_binary_value = ScalarValue::FixedSizeBinary(3, None); assert_eq!(format!("{no_binary_value}"), "NULL"); let small_binary_value = ScalarValue::FixedSizeBinary(3, Some(vec![1u8, 2, 3])); @@ -6711,28 +6418,6 @@ mod tests { "Binary(\"1,2,3,4,5,6,7,8,9,10,11\")" ); - let no_binary_value = ScalarValue::BinaryView(None); - assert_eq!(format!("{no_binary_value:?}"), "BinaryView(NULL)"); - let small_binary_value = ScalarValue::BinaryView(Some(vec![1u8, 2, 3])); - assert_eq!(format!("{small_binary_value:?}"), "BinaryView(\"1,2,3\")"); - let large_binary_value = - ScalarValue::BinaryView(Some(vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])); - assert_eq!( - format!("{large_binary_value:?}"), - "BinaryView(\"1,2,3,4,5,6,7,8,9,10,11\")" - ); - - let no_binary_value = ScalarValue::LargeBinary(None); - assert_eq!(format!("{no_binary_value:?}"), "LargeBinary(NULL)"); - let small_binary_value = ScalarValue::LargeBinary(Some(vec![1u8, 2, 3])); - assert_eq!(format!("{small_binary_value:?}"), "LargeBinary(\"1,2,3\")"); - let large_binary_value = - ScalarValue::LargeBinary(Some(vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])); - assert_eq!( - format!("{large_binary_value:?}"), - "LargeBinary(\"1,2,3,4,5,6,7,8,9,10,11\")" - ); - let no_binary_value = ScalarValue::FixedSizeBinary(3, None); assert_eq!(format!("{no_binary_value:?}"), "FixedSizeBinary(3, NULL)"); let small_binary_value = ScalarValue::FixedSizeBinary(3, Some(vec![1u8, 2, 3])); @@ -6908,15 +6593,4 @@ mod tests { ); assert!(dense_scalar.is_null()); } - - #[test] - fn null_dictionary_scalar_produces_null_dictionary_array() { - let dictionary_scalar = ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::Null), - ); - assert!(dictionary_scalar.is_null()); - let dictionary_array = dictionary_scalar.to_array().unwrap(); - assert!(dictionary_array.is_null(0)); - } } diff --git a/datafusion/common/src/utils/mod.rs 
b/datafusion/common/src/utils/mod.rs index 83f98ff9aff6..c623109e7b78 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -379,10 +379,10 @@ pub fn array_into_list_array(arr: ArrayRef, nullable: bool) -> ListArray { /// Wrap an array into a single element `LargeListArray`. /// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` -pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray { +pub fn array_into_large_list_array(arr: ArrayRef, nullable: bool) -> LargeListArray { let offsets = OffsetBuffer::from_lengths([arr.len()]); LargeListArray::new( - Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), + Arc::new(Field::new_list_field(arr.data_type().to_owned(), nullable)), offsets, arr, None, diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 2a862dd6dcb3..636d27d6f1b6 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -1664,11 +1664,7 @@ mod tests { let schema = format.infer_schema(&state, &store, &files).await.unwrap(); let null_i64 = ScalarValue::Int64(None); - let null_utf8 = if force_views { - ScalarValue::Utf8View(None) - } else { - ScalarValue::Utf8(None) - }; + let null_utf8 = ScalarValue::Utf8(None); // Fetch statistics for first file let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[0], None).await?; @@ -1677,18 +1673,13 @@ mod tests { // column c1 let c1_stats = &stats.column_statistics[0]; assert_eq!(c1_stats.null_count, Precision::Exact(1)); - let expected_type = if force_views { - ScalarValue::Utf8View - } else { - ScalarValue::Utf8 - }; assert_eq!( c1_stats.max_value, - Precision::Exact(expected_type(Some("bar".to_string()))) + Precision::Exact(ScalarValue::Utf8(Some("bar".to_string()))) ); assert_eq!( c1_stats.min_value, - Precision::Exact(expected_type(Some("Foo".to_string()))) + 
Precision::Exact(ScalarValue::Utf8(Some("Foo".to_string()))) ); // column c2: missing from the file so the table treats all 3 rows as null let c2_stats = &stats.column_statistics[1]; diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index c5a441aacf1d..98e0dc48d034 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -63,13 +63,8 @@ pub struct PartitionedFile { pub object_meta: ObjectMeta, /// Values of partition columns to be appended to each row. /// - /// These MUST have the same count, order, and type than the [`table_partition_cols`]. + /// These MUST have the same count, order, than the [`table_partition_cols`]. /// - /// You may use [`wrap_partition_value_in_dict`] to wrap them if you have used [`wrap_partition_type_in_dict`] to wrap the column type. - /// - /// - /// [`wrap_partition_type_in_dict`]: crate::datasource::physical_plan::wrap_partition_type_in_dict - /// [`wrap_partition_value_in_dict`]: crate::datasource::physical_plan::wrap_partition_value_in_dict /// [`table_partition_cols`]: table::ListingOptions::table_partition_cols pub partition_values: Vec, /// An optional file range for a more fine-grained parallel execution diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 9f674185694d..b22a4307d460 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -35,8 +35,6 @@ use datafusion_common::stats::Precision; use datafusion_common::{exec_err, ColumnStatistics, DataFusionError, Statistics}; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; -use log::warn; - /// Convert type to a type suitable for use as a [`ListingTable`] /// partition column. 
Returns `Dictionary(UInt16, val_type)`, which is /// a reasonable trade off between a reasonable number of partition @@ -46,20 +44,11 @@ use log::warn; /// you MAY also choose not to dictionary-encode the data or to use a /// different dictionary type. /// -/// Use [`wrap_partition_value_in_dict`] to wrap a [`ScalarValue`] in the same say. -/// /// [`ListingTable`]: crate::datasource::listing::ListingTable pub fn wrap_partition_type_in_dict(val_type: DataType) -> DataType { DataType::Dictionary(Box::new(DataType::UInt16), Box::new(val_type)) } -/// Convert a [`ScalarValue`] of partition columns to a type, as -/// described in the documentation of [`wrap_partition_type_in_dict`], -/// which can wrap the types. -pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue { - ScalarValue::Dictionary(Box::new(DataType::UInt16), Box::new(val)) -} - /// The base configurations to provide when creating a physical plan for /// any given file format. /// @@ -430,27 +419,17 @@ impl PartitionColumnProjector { "Invalid partitioning found on disk".to_string(), ))?; - let mut partition_value = Cow::Borrowed(p_value); + let partition_value = Cow::Borrowed(p_value); - // check if user forgot to dict-encode the partition value let field = self.projected_schema.field(sidx); let expected_data_type = field.data_type(); - let actual_data_type = partition_value.data_type(); - if let DataType::Dictionary(key_type, _) = expected_data_type { - if !matches!(actual_data_type, DataType::Dictionary(_, _)) { - warn!("Partition value for column {} was not dictionary-encoded, applied auto-fix.", field.name()); - partition_value = Cow::Owned(ScalarValue::Dictionary( - key_type.clone(), - Box::new(partition_value.as_ref().clone()), - )); - } - } cols.insert( sidx, create_output_array( &mut self.key_buffer_cache, partition_value.as_ref(), + expected_data_type, file_batch.num_rows(), )?, ) @@ -509,19 +488,20 @@ where fn create_dict_array( buffer_gen: &mut ZeroBufferGenerator, - dict_val: 
&ScalarValue, + val: &ScalarValue, len: usize, - data_type: DataType, + dict_type: &DataType, + inner_type: &DataType, ) -> Result where T: ArrowNativeType, { - let dict_vals = dict_val.to_array()?; + let dict_vals = val.to_array_of_type(inner_type)?; let sliced_key_buffer = buffer_gen.get_buffer(len); // assemble pieces together - let mut builder = ArrayData::builder(data_type) + let mut builder = ArrayData::builder(dict_type.clone()) .len(len) .add_buffer(sliced_key_buffer); builder = builder.add_child_data(dict_vals.to_data()); @@ -533,79 +513,88 @@ where fn create_output_array( key_buffer_cache: &mut ZeroBufferGenerators, val: &ScalarValue, + data_type: &DataType, len: usize, ) -> Result { - if let ScalarValue::Dictionary(key_type, dict_val) = &val { + if let DataType::Dictionary(key_type, inner_type) = data_type { match key_type.as_ref() { DataType::Int8 => { return create_dict_array( &mut key_buffer_cache.gen_i8, - dict_val, + val, len, - val.data_type(), + data_type, + inner_type, ); } DataType::Int16 => { return create_dict_array( &mut key_buffer_cache.gen_i16, - dict_val, + val, len, - val.data_type(), + data_type, + inner_type, ); } DataType::Int32 => { return create_dict_array( &mut key_buffer_cache.gen_i32, - dict_val, + val, len, - val.data_type(), + data_type, + inner_type, ); } DataType::Int64 => { return create_dict_array( &mut key_buffer_cache.gen_i64, - dict_val, + val, len, - val.data_type(), + data_type, + inner_type, ); } DataType::UInt8 => { return create_dict_array( &mut key_buffer_cache.gen_u8, - dict_val, + val, len, - val.data_type(), + data_type, + inner_type, ); } DataType::UInt16 => { return create_dict_array( &mut key_buffer_cache.gen_u16, - dict_val, + val, len, - val.data_type(), + data_type, + inner_type, ); } DataType::UInt32 => { return create_dict_array( &mut key_buffer_cache.gen_u32, - dict_val, + val, len, - val.data_type(), + data_type, + inner_type, ); } DataType::UInt64 => { return create_dict_array( &mut 
key_buffer_cache.gen_u64, - dict_val, + val, len, - val.data_type(), + data_type, + inner_type, ); } _ => {} } } - val.to_array_of_size(len) + val.to_array_of_size_and_type(len, data_type) } #[cfg(test)] @@ -770,9 +759,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::from("2021")), - wrap_partition_value_in_dict(ScalarValue::from("10")), - wrap_partition_value_in_dict(ScalarValue::from("26")), + ScalarValue::from("2021"), + ScalarValue::from("10"), + ScalarValue::from("26"), ], ) .expect("Projection of partition columns into record batch failed"); @@ -798,9 +787,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::from("2021")), - wrap_partition_value_in_dict(ScalarValue::from("10")), - wrap_partition_value_in_dict(ScalarValue::from("27")), + ScalarValue::from("2021"), + ScalarValue::from("10"), + ScalarValue::from("27"), ], ) .expect("Projection of partition columns into record batch failed"); @@ -828,9 +817,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::from("2021")), - wrap_partition_value_in_dict(ScalarValue::from("10")), - wrap_partition_value_in_dict(ScalarValue::from("28")), + ScalarValue::from("2021"), + ScalarValue::from("10"), + ScalarValue::from("28"), ], ) .expect("Projection of partition columns into record batch failed"); diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index f810fb86bd89..940c1451abc0 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -37,9 +37,7 @@ pub use arrow_file::ArrowExec; pub use avro::AvroExec; pub use csv::{CsvConfig, CsvExec, CsvExecBuilder, CsvOpener}; pub use 
file_groups::FileGroupPartitioner; -pub use file_scan_config::{ - wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, -}; +pub use file_scan_config::{wrap_partition_type_in_dict, FileScanConfig}; pub use file_stream::{FileOpenFuture, FileOpener, FileStream, OnError}; pub use json::{JsonOpener, NdJsonExec}; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 54d4d7262a8e..18216fdc0ae7 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -1636,10 +1636,7 @@ mod tests { partition_values: vec![ ScalarValue::from("2021"), ScalarValue::UInt8(Some(10)), - ScalarValue::Dictionary( - Box::new(DataType::UInt16), - Box::new(ScalarValue::from("26")), - ), + ScalarValue::from("26"), ], range: None, statistics: None, @@ -1651,14 +1648,7 @@ mod tests { Field::new("bool_col", DataType::Boolean, true), Field::new("tinyint_col", DataType::Int32, true), Field::new("month", DataType::UInt8, false), - Field::new( - "day", - DataType::Dictionary( - Box::new(DataType::UInt16), - Box::new(DataType::Utf8), - ), - false, - ), + Field::new("day", DataType::Utf8, false), ]); let parquet_exec = ParquetExec::builder( @@ -1669,14 +1659,7 @@ mod tests { .with_table_partition_cols(vec![ Field::new("year", DataType::Utf8, false), Field::new("month", DataType::UInt8, false), - Field::new( - "day", - DataType::Dictionary( - Box::new(DataType::UInt16), - Box::new(DataType::Utf8), - ), - false, - ), + Field::new("day", DataType::Utf8, false), ]), ) .build(); diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 9bc2bb1d1db9..3c8e5ddd1c74 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -687,14 +687,16 @@ impl BoolVecBuilder { ColumnarValue::Array(array) => { 
self.combine_array(array.as_boolean()); } - ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))) => { - // False means all containers can not pass the predicate - self.inner = vec![false; self.inner.len()]; - } - _ => { - // Null or true means the rows in container may pass this - // conjunct so we can't prune any containers based on that - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Boolean(Some(false)) => { + // False means all containers can not pass the predicate + self.inner = vec![false; self.inner.len()]; + } + _ => { + // Null or true means the rows in container may pass this + // conjunct so we can't prune any containers based on that + } + }, } } diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index f17d13a42060..0686b954207f 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -56,7 +56,8 @@ fn init() { #[test] fn select_arrow_cast() { let sql = "SELECT arrow_cast(1234, 'Float64') as f64, arrow_cast('foo', 'LargeUtf8') as large"; - let expected = "Projection: Float64(1234) AS f64, LargeUtf8(\"foo\") AS large\ + let expected = + "Projection: Float64(1234) AS f64, CAST(Utf8(\"foo\") AS LargeUtf8) AS large\ \n EmptyRelation"; quick_test(sql, expected); } diff --git a/datafusion/core/tests/sql/path_partition.rs b/datafusion/core/tests/sql/path_partition.rs index 7e7544bdb7c0..82a24a57c49c 100644 --- a/datafusion/core/tests/sql/path_partition.rs +++ b/datafusion/core/tests/sql/path_partition.rs @@ -64,13 +64,7 @@ async fn parquet_distinct_partition_col() -> Result<()> { ], &[ ("year", DataType::Int32), - ( - "month", - DataType::Dictionary( - Box::new(DataType::UInt16), - Box::new(DataType::Utf8), - ), - ), + ("month", DataType::Utf8), ("day", DataType::Utf8), ], "mirror:///", @@ -170,7 +164,7 @@ async fn parquet_distinct_partition_col() -> Result<()> { let s = ScalarValue::try_from_array(results[0].column(1), 0)?; let month = match 
extract_as_utf(&s) { Some(month) => month, - s => panic!("Expected month as Dict(_, Utf8) found {s:?}"), + _ => panic!("Expected month as Utf8 found {s:?}"), }; let sql_on_partition_boundary = format!( @@ -192,10 +186,8 @@ async fn parquet_distinct_partition_col() -> Result<()> { } fn extract_as_utf(v: &ScalarValue) -> Option { - if let ScalarValue::Dictionary(_, v) = v { - if let ScalarValue::Utf8(v) = v.as_ref() { - return v.clone(); - } + if let ScalarValue::Utf8(v) = v { + return v.clone(); } None } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 0f1c3b8e53c4..c5944e9a7ac3 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -212,7 +212,7 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF { } fn invoke_no_args(&self, _number_rows: usize) -> Result { - Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100)))) + Ok(ColumnarValue::from(ScalarValue::Int32(Some(100)))) } } @@ -323,7 +323,7 @@ async fn scalar_udf_override_built_in_scalar_function() -> Result<()> { vec![DataType::Int32], Arc::new(DataType::Int32), Volatility::Immutable, - Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(1))))), + Arc::new(move |_| Ok(ColumnarValue::from(ScalarValue::Int32(Some(1))))), )); // Make sure that the UDF is used instead of the built-in function @@ -669,7 +669,10 @@ impl ScalarUDFImpl for TakeUDF { // The actual implementation fn invoke(&self, args: &[ColumnarValue]) -> Result { let take_idx = match &args[2] { - ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize, + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Int64(Some(v)) if v < &2 => *v as usize, + _ => unreachable!(), + }, _ => unreachable!(), }; match &args[take_idx] { @@ -1070,11 +1073,12 @@ impl ScalarUDFImpl for MyRegexUdf { fn invoke(&self, 
args: &[ColumnarValue]) -> Result { match args { - [ColumnarValue::Scalar(ScalarValue::Utf8(value))] => { - Ok(ColumnarValue::Scalar(ScalarValue::Boolean( - self.matches(value.as_deref()), - ))) - } + [ColumnarValue::Scalar(scalar)] => match scalar.value() { + ScalarValue::Utf8(value) => Ok(ColumnarValue::from( + ScalarValue::Boolean(self.matches(value.as_deref())), + )), + _ => exec_err!("regex_udf only accepts a Utf8 arguments"), + }, [ColumnarValue::Array(values)] => { let mut builder = BooleanBuilder::with_capacity(values.len()); for value in values.as_string::() { diff --git a/datafusion/expr-common/src/accumulator.rs b/datafusion/expr-common/src/accumulator.rs index 75335209451e..43f1e86a1ff7 100644 --- a/datafusion/expr-common/src/accumulator.rs +++ b/datafusion/expr-common/src/accumulator.rs @@ -17,6 +17,7 @@ //! Accumulator module contains the trait definition for aggregation function's accumulators. +use crate::columnar_value::Scalar; use arrow::array::ArrayRef; use datafusion_common::{internal_err, Result, ScalarValue}; use std::fmt::Debug; @@ -72,6 +73,10 @@ pub trait Accumulator: Send + Sync + Debug { /// when possible (for example distinct strings) fn evaluate(&mut self) -> Result; + fn evaluate_as_scalar(&mut self) -> Result { + self.evaluate().map(Scalar::from) + } + /// Returns the allocated size required for this accumulator, in /// bytes, including `Self`. /// @@ -250,6 +255,11 @@ pub trait Accumulator: Send + Sync + Debug { /// ``` fn state(&mut self) -> Result>; + fn state_as_scalars(&mut self) -> Result> { + self.state() + .map(|scalars| scalars.into_iter().map(Scalar::from).collect()) + } + /// Updates the accumulator's state from an `Array` containing one /// or more intermediate values. 
/// diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index bfefb37c98d7..3191f653e6db 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -17,12 +17,12 @@ //! [`ColumnarValue`] represents the result of evaluating an expression. -use arrow::array::ArrayRef; use arrow::array::NullArray; +use arrow::array::{Array, ArrayRef}; use arrow::compute::{kernels, CastOptions}; use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::format::DEFAULT_CAST_OPTIONS; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; use std::sync::Arc; /// The result of evaluating an expression. @@ -89,7 +89,94 @@ pub enum ColumnarValue { /// Array of values Array(ArrayRef), /// A single value - Scalar(ScalarValue), + Scalar(Scalar), +} + +#[derive(Clone, Debug)] +pub struct Scalar { + value: ScalarValue, + data_type: DataType, +} + +impl From for Scalar { + fn from(value: ScalarValue) -> Self { + Self { + data_type: value.data_type(), + value, + } + } +} + +impl TryFrom for Scalar { + type Error = DataFusionError; + fn try_from(value: DataType) -> Result { + Ok(Self { + value: ScalarValue::try_from(&value)?, + data_type: value, + }) + } +} + +impl PartialEq for Scalar { + fn eq(&self, other: &Self) -> bool { + self.value.eq(&other.value) + } +} + +impl Scalar { + pub fn new(value: ScalarValue, data_type: DataType) -> Self { + Self { value, data_type } + } + + pub fn try_from_array(array: &dyn Array, index: usize) -> Result { + let value = ScalarValue::try_from_array(array, index)?; + Ok(Self::new(value, array.data_type().clone())) + } + + #[inline] + pub fn value(&self) -> &ScalarValue { + &self.value + } + + #[inline] + pub fn into_value(self) -> ScalarValue { + self.value + } + + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + pub fn with_data_type(mut self, 
data_type: DataType) -> Self { + self.data_type = data_type; + self + } + + pub fn to_array_of_size(&self, size: usize) -> Result { + self.value.to_array_of_size_and_type(size, &self.data_type) + } + + pub fn to_array(&self) -> Result { + self.to_array_of_size(1) + } + + pub fn to_scalar(&self) -> Result> { + Ok(arrow::array::Scalar::new(self.to_array()?)) + } + + pub fn iter_to_array(scalars: impl IntoIterator) -> Result { + let mut scalars = scalars.into_iter().peekable(); + + // figure out the type based on the first element + let data_type = match scalars.peek() { + None => { + return exec_err!("Empty iterator passed to Scalar::iter_to_array"); + } + Some(sv) => sv.data_type().clone(), + }; + + ScalarValue::iter_to_array_of_type(scalars.map(|scalar| scalar.value), &data_type) + } } impl From for ColumnarValue { @@ -100,15 +187,15 @@ impl From for ColumnarValue { impl From for ColumnarValue { fn from(value: ScalarValue) -> Self { - ColumnarValue::Scalar(value) + ColumnarValue::Scalar(Scalar::from(value)) } } impl ColumnarValue { - pub fn data_type(&self) -> DataType { + pub fn data_type(&self) -> &DataType { match self { - ColumnarValue::Array(array_value) => array_value.data_type().clone(), - ColumnarValue::Scalar(scalar_value) => scalar_value.data_type(), + ColumnarValue::Array(array_value) => array_value.data_type(), + ColumnarValue::Scalar(scalar) => scalar.data_type(), } } @@ -197,7 +284,7 @@ impl ColumnarValue { ColumnarValue::Scalar(scalar) => { let scalar_array = if cast_type == &DataType::Timestamp(TimeUnit::Nanosecond, None) { - if let ScalarValue::Float64(Some(float_ts)) = scalar { + if let ScalarValue::Float64(Some(float_ts)) = scalar.value() { ScalarValue::Int64(Some( (float_ts * 1_000_000_000_f64).trunc() as i64, )) @@ -214,7 +301,10 @@ impl ColumnarValue { &cast_options, )?; let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; - Ok(ColumnarValue::Scalar(cast_scalar)) + Ok(ColumnarValue::Scalar(Scalar::new( + cast_scalar, + 
cast_type.clone(), + ))) } } } @@ -250,7 +340,7 @@ mod tests { TestCase { input: vec![ ColumnarValue::Array(make_array(1, 3)), - ColumnarValue::Scalar(ScalarValue::Int32(Some(100))), + ColumnarValue::from(ScalarValue::Int32(Some(100))), ], expected: vec![ make_array(1, 3), @@ -260,7 +350,7 @@ mod tests { // scalar and array TestCase { input: vec![ - ColumnarValue::Scalar(ScalarValue::Int32(Some(100))), + ColumnarValue::from(ScalarValue::Int32(Some(100))), ColumnarValue::Array(make_array(1, 3)), ], expected: vec![ @@ -271,9 +361,9 @@ mod tests { // multiple scalars and array TestCase { input: vec![ - ColumnarValue::Scalar(ScalarValue::Int32(Some(100))), + ColumnarValue::from(ScalarValue::Int32(Some(100))), ColumnarValue::Array(make_array(1, 3)), - ColumnarValue::Scalar(ScalarValue::Int32(Some(200))), + ColumnarValue::from(ScalarValue::Int32(Some(200))), ], expected: vec![ make_array(100, 3), // scalar is expanded @@ -306,7 +396,7 @@ mod tests { fn values_to_arrays_mixed_length_and_scalar() { ColumnarValue::values_to_arrays(&[ ColumnarValue::Array(make_array(1, 3)), - ColumnarValue::Scalar(ScalarValue::Int32(Some(100))), + ColumnarValue::from(ScalarValue::Int32(Some(100))), ColumnarValue::Array(make_array(2, 7)), ]) .unwrap(); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index db0bfd6b1bc2..c033237a54a4 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2530,7 +2530,7 @@ mod test { } fn invoke(&self, _args: &[ColumnarValue]) -> Result { - Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) + Ok(ColumnarValue::from(ScalarValue::from("a"))) } } let udf = Arc::new(ScalarUDF::from(TestScalarUDF { diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 260065f69af9..2b3faf2b8d21 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -68,7 +68,7 @@ pub mod window_state; pub use built_in_window_function::BuiltInWindowFunction; pub use datafusion_expr_common::accumulator::Accumulator; 
-pub use datafusion_expr_common::columnar_value::ColumnarValue; +pub use datafusion_expr_common::columnar_value::{ColumnarValue, Scalar}; pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; pub use datafusion_expr_common::operator::Operator; pub use datafusion_expr_common::signature::{ diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index b5eb36c3fac7..8ddb2567094d 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -30,9 +30,9 @@ use arrow::{ }; use datafusion_common::{ arrow_datafusion_err, utils::get_arrayref_at_indices, DataFusionError, Result, - ScalarValue, }; use datafusion_expr_common::accumulator::Accumulator; +use datafusion_expr_common::columnar_value::Scalar; use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; /// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`] @@ -278,15 +278,15 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { let states = emit_to.take_needed(&mut self.states); - let results: Vec = states + let results: Vec = states .into_iter() .map(|mut state| { self.free_allocation(state.size()); - state.accumulator.evaluate() + state.accumulator.evaluate_as_scalar() }) .collect::>()?; - let result = ScalarValue::iter_to_array(results); + let result = Scalar::iter_to_array(results); self.adjust_allocation(vec_size_pre, self.states.allocated_size()); @@ -300,11 +300,11 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { // each accumulator produces a potential vector of values // which we need to form into columns - let mut results: Vec> = vec![]; + let mut results: Vec> = vec![]; for mut state in states { self.free_allocation(state.size()); - let accumulator_state = state.accumulator.state()?; + let accumulator_state = 
state.accumulator.state_as_scalars()?; results.resize_with(accumulator_state.len(), Vec::new); for (idx, state_val) in accumulator_state.into_iter().enumerate() { results[idx].push(state_val); @@ -314,7 +314,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { // create an array for each intermediate column let arrays = results .into_iter() - .map(ScalarValue::iter_to_array) + .map(Scalar::iter_to_array) .collect::>>()?; // double check each array has the same length (aka the @@ -370,7 +370,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { let values_to_accumulate = slice_and_maybe_filter(values, opt_filter, &[row_idx, row_idx + 1])?; converted_accumulator.update_batch(&values_to_accumulate)?; - let states = converted_accumulator.state()?; + let states = converted_accumulator.state_as_scalars()?; // Resize results to have enough columns according to the converted states results.resize_with(states.len(), || Vec::with_capacity(num_rows)); @@ -383,7 +383,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { let arrays = results .into_iter() - .map(ScalarValue::iter_to_array) + .map(Scalar::iter_to_array) .collect::>>()?; Ok(arrays) diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 5578aebbf403..51d9ac764c40 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -147,7 +147,7 @@ fn get_scalar_value(expr: &Arc) -> Result { let empty_schema = Arc::new(Schema::empty()); let batch = RecordBatch::new_empty(Arc::clone(&empty_schema)); if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? 
{ - Ok(s) + Ok(s.into_value()) } else { internal_err!("Didn't expect ColumnarValue::Array") } diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 961e8639604c..fe9d188c563c 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -63,10 +63,10 @@ use arrow::datatypes::{ }; use datafusion_common::ScalarValue; -use datafusion_expr::GroupsAccumulator; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility, }; +use datafusion_expr::{GroupsAccumulator, Scalar}; use half::f16; use std::ops::Deref; @@ -452,15 +452,10 @@ fn min_batch(values: &ArrayRef) -> Result { typed_min_max_batch_string!(values, StringArray, Utf8, min_string) } DataType::LargeUtf8 => { - typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string) + typed_min_max_batch_string!(values, LargeStringArray, Utf8, min_string) } DataType::Utf8View => { - typed_min_max_batch_string!( - values, - StringViewArray, - Utf8View, - min_string_view - ) + typed_min_max_batch_string!(values, StringViewArray, Utf8, min_string_view) } DataType::Boolean => { typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean) @@ -469,20 +464,10 @@ fn min_batch(values: &ArrayRef) -> Result { typed_min_max_batch_binary!(&values, BinaryArray, Binary, min_binary) } DataType::LargeBinary => { - typed_min_max_batch_binary!( - &values, - LargeBinaryArray, - LargeBinary, - min_binary - ) + typed_min_max_batch_binary!(&values, LargeBinaryArray, Binary, min_binary) } DataType::BinaryView => { - typed_min_max_batch_binary!( - &values, - BinaryViewArray, - BinaryView, - min_binary_view - ) + typed_min_max_batch_binary!(&values, BinaryViewArray, Binary, min_binary_view) } _ => min_max_batch!(values, min), }) @@ -495,15 +480,10 @@ fn max_batch(values: &ArrayRef) -> Result { typed_min_max_batch_string!(values, StringArray, Utf8, max_string) } DataType::LargeUtf8 
=> { - typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string) + typed_min_max_batch_string!(values, LargeStringArray, Utf8, max_string) } DataType::Utf8View => { - typed_min_max_batch_string!( - values, - StringViewArray, - Utf8View, - max_string_view - ) + typed_min_max_batch_string!(values, StringViewArray, Utf8, max_string_view) } DataType::Boolean => { typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean) @@ -512,20 +492,10 @@ fn max_batch(values: &ArrayRef) -> Result { typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary) } DataType::BinaryView => { - typed_min_max_batch_binary!( - &values, - BinaryViewArray, - BinaryView, - max_binary_view - ) + typed_min_max_batch_binary!(&values, BinaryViewArray, Binary, max_binary_view) } DataType::LargeBinary => { - typed_min_max_batch_binary!( - &values, - LargeBinaryArray, - LargeBinary, - max_binary - ) + typed_min_max_batch_binary!(&values, LargeBinaryArray, Binary, max_binary) } _ => min_max_batch!(values, max), }) @@ -662,21 +632,9 @@ macro_rules! min_max { (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => { typed_min_max_string!(lhs, rhs, Utf8, $OP) } - (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { - typed_min_max_string!(lhs, rhs, LargeUtf8, $OP) - } - (ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => { - typed_min_max_string!(lhs, rhs, Utf8View, $OP) - } (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => { typed_min_max_string!(lhs, rhs, Binary, $OP) } - (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => { - typed_min_max_string!(lhs, rhs, LargeBinary, $OP) - } - (ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => { - typed_min_max_string!(lhs, rhs, BinaryView, $OP) - } (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => { typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz) } @@ -811,6 +769,7 @@ macro_rules! 
min_max { #[derive(Debug)] pub struct MaxAccumulator { max: ScalarValue, + return_type: DataType, } impl MaxAccumulator { @@ -818,6 +777,7 @@ impl MaxAccumulator { pub fn try_new(datatype: &DataType) -> Result { Ok(Self { max: ScalarValue::try_from(datatype)?, + return_type: datatype.clone(), }) } } @@ -839,9 +799,15 @@ impl Accumulator for MaxAccumulator { fn state(&mut self) -> Result> { Ok(vec![self.evaluate()?]) } + fn state_as_scalars(&mut self) -> Result> { + Ok(vec![self.evaluate_as_scalar()?]) + } fn evaluate(&mut self) -> Result { Ok(self.max.clone()) } + fn evaluate_as_scalar(&mut self) -> Result { + Ok(Scalar::new(self.evaluate()?, self.return_type.clone())) + } fn size(&self) -> usize { std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() @@ -852,6 +818,7 @@ impl Accumulator for MaxAccumulator { pub struct SlidingMaxAccumulator { max: ScalarValue, moving_max: MovingMax, + return_type: DataType, } impl SlidingMaxAccumulator { @@ -860,6 +827,7 @@ impl SlidingMaxAccumulator { Ok(Self { max: ScalarValue::try_from(datatype)?, moving_max: MovingMax::::new(), + return_type: datatype.clone(), }) } } @@ -898,6 +866,13 @@ impl Accumulator for SlidingMaxAccumulator { Ok(self.max.clone()) } + fn state_as_scalars(&mut self) -> Result> { + Ok(vec![self.evaluate_as_scalar()?]) + } + + fn evaluate_as_scalar(&mut self) -> Result { + Ok(Scalar::new(self.evaluate()?, self.return_type.clone())) + } fn supports_retract_batch(&self) -> bool { true } @@ -1068,6 +1043,7 @@ impl AggregateUDFImpl for Min { #[derive(Debug)] pub struct MinAccumulator { min: ScalarValue, + return_type: DataType, } impl MinAccumulator { @@ -1075,6 +1051,7 @@ impl MinAccumulator { pub fn try_new(datatype: &DataType) -> Result { Ok(Self { min: ScalarValue::try_from(datatype)?, + return_type: datatype.clone(), }) } } @@ -1084,6 +1061,10 @@ impl Accumulator for MinAccumulator { Ok(vec![self.evaluate()?]) } + fn state_as_scalars(&mut self) -> Result> { + 
Ok(vec![self.evaluate_as_scalar()?]) + } + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; let delta = &min_batch(values)?; @@ -1101,6 +1082,10 @@ impl Accumulator for MinAccumulator { Ok(self.min.clone()) } + fn evaluate_as_scalar(&mut self) -> Result { + Ok(Scalar::new(self.evaluate()?, self.return_type.clone())) + } + fn size(&self) -> usize { std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() } @@ -1110,6 +1095,7 @@ impl Accumulator for MinAccumulator { pub struct SlidingMinAccumulator { min: ScalarValue, moving_min: MovingMin, + return_type: DataType, } impl SlidingMinAccumulator { @@ -1117,6 +1103,7 @@ impl SlidingMinAccumulator { Ok(Self { min: ScalarValue::try_from(datatype)?, moving_min: MovingMin::::new(), + return_type: datatype.clone(), }) } } @@ -1126,6 +1113,14 @@ impl Accumulator for SlidingMinAccumulator { Ok(vec![self.min.clone()]) } + fn state_as_scalars(&mut self) -> Result> { + Ok(vec![self.evaluate_as_scalar()?]) + } + + fn evaluate_as_scalar(&mut self) -> Result { + Ok(Scalar::new(self.evaluate()?, self.return_type.clone())) + } + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { for idx in 0..values[0].len() { let val = ScalarValue::try_from_array(&values[0], idx)?; diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index a7e9a37e23ad..5269cff8b9e2 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -24,7 +24,7 @@ use datafusion_common::Result; use datafusion_common::{not_impl_err, ScalarValue}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, Scalar, Signature, TypeSignature, Volatility, }; use datafusion_physical_expr::expressions::Literal; use std::any::Any; @@ -85,13 +85,12 @@ impl AggregateUDFImpl for 
StringAgg { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { if let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::() { return match lit.value() { - ScalarValue::Utf8(Some(delimiter)) - | ScalarValue::LargeUtf8(Some(delimiter)) => { + ScalarValue::Utf8(Some(delimiter)) => { Ok(Box::new(StringAggAccumulator::new(delimiter.as_str()))) } - ScalarValue::Utf8(None) - | ScalarValue::LargeUtf8(None) - | ScalarValue::Null => Ok(Box::new(StringAggAccumulator::new(""))), + ScalarValue::Utf8(None) | ScalarValue::Null => { + Ok(Box::new(StringAggAccumulator::new(""))) + } e => not_impl_err!("StringAgg not supported for delimiter {}", e), }; } @@ -141,8 +140,19 @@ impl Accumulator for StringAggAccumulator { Ok(vec![self.evaluate()?]) } + fn state_as_scalars(&mut self) -> Result> { + Ok(vec![self.evaluate_as_scalar()?]) + } + fn evaluate(&mut self) -> Result { - Ok(ScalarValue::LargeUtf8(self.values.clone())) + Ok(ScalarValue::Utf8(self.values.clone())) + } + + fn evaluate_as_scalar(&mut self) -> Result { + Ok(Scalar::new( + ScalarValue::Utf8(self.values.clone()), + DataType::LargeUtf8, + )) } fn size(&self) -> usize { diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index ca23d8b7ff4c..24e892f8b715 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -91,8 +91,8 @@ fn criterion_benchmark(c: &mut Criterion) { Arc::new(Int32Array::from(values(&mut rng))), None, ); - let keys = ColumnarValue::Scalar(ScalarValue::List(Arc::new(key_list))); - let values = ColumnarValue::Scalar(ScalarValue::List(Arc::new(value_list))); + let keys = ColumnarValue::from(ScalarValue::List(Arc::new(key_list))); + let values = ColumnarValue::from(ScalarValue::List(Arc::new(value_list))); b.iter(|| { black_box( diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index 8f8d123bf5f9..df1a336426d7 100644 --- 
a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -106,8 +106,8 @@ impl ScalarUDFImpl for ArrayHas { ColumnarValue::Scalar(scalar_needle) => { // Always return null if the second argument is null // i.e. array_has(array, null) -> null - if scalar_needle.is_null() { - return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + if scalar_needle.value().is_null() { + return Ok(ColumnarValue::from(ScalarValue::Boolean(None))); } // since the needle is a scalar, convert it to an array of size 1 @@ -118,7 +118,7 @@ impl ScalarUDFImpl for ArrayHas { if let ColumnarValue::Scalar(_) = &args[0] { // If both inputs are scalar, keeps output as scalar let scalar_value = ScalarValue::try_from_array(&array, 0)?; - Ok(ColumnarValue::Scalar(scalar_value)) + Ok(ColumnarValue::from(scalar_value)) } else { Ok(ColumnarValue::Array(array)) } diff --git a/datafusion/functions-nested/src/map.rs b/datafusion/functions-nested/src/map.rs index 29afe4a7f3be..e7ee42697bb2 100644 --- a/datafusion/functions-nested/src/map.rs +++ b/datafusion/functions-nested/src/map.rs @@ -93,7 +93,12 @@ fn make_map_batch(args: &[ColumnarValue]) -> Result { } let values = get_first_array_ref(&args[1])?; - make_map_batch_internal(keys, values, can_evaluate_to_const, args[0].data_type()) + make_map_batch_internal( + keys, + values, + can_evaluate_to_const, + args[0].data_type().clone(), + ) } fn check_unique_keys(array: &dyn Array) -> Result<()> { @@ -111,7 +116,7 @@ fn check_unique_keys(array: &dyn Array) -> Result<()> { fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result { match columnar_value { - ColumnarValue::Scalar(value) => match value { + ColumnarValue::Scalar(value) => match value.value() { ScalarValue::List(array) => Ok(array.value(0)), ScalarValue::LargeList(array) => Ok(array.value(0)), ScalarValue::FixedSizeList(array) => Ok(array.value(0)), @@ -172,7 +177,7 @@ fn make_map_batch_internal( let map_array = 
Arc::new(MapArray::from(map_data)); Ok(if can_evaluate_to_const { - ColumnarValue::Scalar(ScalarValue::try_from_array(map_array.as_ref(), 0)?) + ColumnarValue::from(ScalarValue::try_from_array(map_array.as_ref(), 0)?) } else { ColumnarValue::Array(map_array) }) diff --git a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index 0765f6cd237d..4e93477ea56d 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -83,7 +83,7 @@ where if is_scalar { // If all inputs are scalar, keeps output as scalar let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); - result.map(ColumnarValue::Scalar) + result.map(ColumnarValue::from) } else { result.map(ColumnarValue::Array) } diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 91c46ac775a8..bd3bc31b0c65 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -28,7 +28,7 @@ fn create_args(size: usize, str_len: usize) -> Vec { let scalar = ScalarValue::Utf8(Some(", ".to_string())); vec![ ColumnarValue::Array(Arc::clone(&array) as ArrayRef), - ColumnarValue::Scalar(scalar), + ColumnarValue::from(scalar), ColumnarValue::Array(array), ] } diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index c881947354fd..7a92037ccc5d 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -40,7 +40,7 @@ fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("date_bin_1000", |b| { let mut rng = rand::thread_rng(); - let interval = ColumnarValue::Scalar(ScalarValue::new_interval_dt(0, 1_000_000)); + let interval = ColumnarValue::from(ScalarValue::new_interval_dt(0, 1_000_000)); let timestamps = ColumnarValue::Array(Arc::new(timestamps(&mut rng)) as ArrayRef); let udf = date_bin(); diff --git 
a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 01acb9de3381..525430e21ea1 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -30,7 +30,7 @@ fn create_args(size: usize, characters: &str) -> Vec { let array = Arc::new(StringArray::from_iter_values(iter)) as ArrayRef; vec![ ColumnarValue::Array(array), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(characters.to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some(characters.to_string()))), ] } diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index cb8f1abe6d5d..a865953897eb 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -72,7 +72,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("make_date_scalar_col_col_1000", |b| { let mut rng = rand::thread_rng(); - let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); + let year = ColumnarValue::from(ScalarValue::Int32(Some(2025))); let months = ColumnarValue::Array(Arc::new(months(&mut rng)) as ArrayRef); let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); @@ -87,8 +87,8 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("make_date_scalar_scalar_col_1000", |b| { let mut rng = rand::thread_rng(); - let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); - let month = ColumnarValue::Scalar(ScalarValue::Int32(Some(11))); + let year = ColumnarValue::from(ScalarValue::Int32(Some(2025))); + let month = ColumnarValue::from(ScalarValue::Int32(Some(11))); let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); b.iter(|| { @@ -101,9 +101,9 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("make_date_scalar_scalar_scalar", |b| { - let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); - let month = ColumnarValue::Scalar(ScalarValue::Int32(Some(11))); - let day = 
ColumnarValue::Scalar(ScalarValue::Int32(Some(26))); + let year = ColumnarValue::from(ScalarValue::Int32(Some(2025))); + let month = ColumnarValue::from(ScalarValue::Int32(Some(11))); + let day = ColumnarValue::from(ScalarValue::Int32(Some(26))); b.iter(|| { black_box( diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index dfabad335835..31192c1a749f 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -29,7 +29,7 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let array = Arc::new(create_string_array_with_len::(size, 0.2, 32)); let args = vec![ - ColumnarValue::Scalar(ScalarValue::Utf8(Some("abcd".to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some("abcd".to_string()))), ColumnarValue::Array(array), ]; c.bench_function(&format!("nullif scalar array: {}", size), |b| { diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index d9a153e64abc..14819c90a7a5 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -98,7 +98,7 @@ fn criterion_benchmark(c: &mut Criterion) { let mut rng = rand::thread_rng(); let data = ColumnarValue::Array(Arc::new(data(&mut rng)) as ArrayRef); let patterns = - ColumnarValue::Scalar(ScalarValue::Utf8(Some("%Y-%m-%d".to_string()))); + ColumnarValue::from(ScalarValue::Utf8(Some("%Y-%m-%d".to_string()))); b.iter(|| { black_box( @@ -118,10 +118,9 @@ fn criterion_benchmark(c: &mut Criterion) { .and_utc() .timestamp_nanos_opt() .unwrap(); - let data = ColumnarValue::Scalar(TimestampNanosecond(Some(timestamp), None)); - let pattern = ColumnarValue::Scalar(ScalarValue::Utf8(Some( - "%d-%m-%Y %H:%M:%S".to_string(), - ))); + let data = ColumnarValue::from(TimestampNanosecond(Some(timestamp), None)); + let pattern = + ColumnarValue::from(ScalarValue::Utf8(Some("%d-%m-%Y %H:%M:%S".to_string()))); b.iter(|| { black_box( diff --git 
a/datafusion/functions/src/core/arrowtypeof.rs b/datafusion/functions/src/core/arrowtypeof.rs index cc5e7e619bd8..dd502fb2686d 100644 --- a/datafusion/functions/src/core/arrowtypeof.rs +++ b/datafusion/functions/src/core/arrowtypeof.rs @@ -65,7 +65,7 @@ impl ScalarUDFImpl for ArrowTypeOfFunc { } let input_data_type = args[0].data_type(); - Ok(ColumnarValue::Scalar(ScalarValue::from(format!( + Ok(ColumnarValue::from(ScalarValue::from(format!( "{input_data_type}" )))) } diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index 19db58c181e7..e925b94426d0 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -86,7 +86,7 @@ impl ScalarUDFImpl for CoalesceFunc { if let Some(size) = return_array.next() { // start with nulls as default output - let mut current_value = new_null_array(&return_type, size); + let mut current_value = new_null_array(return_type, size); let mut remainder = BooleanArray::from(vec![true; size]); for arg in args { @@ -96,11 +96,11 @@ impl ScalarUDFImpl for CoalesceFunc { current_value = zip(&to_apply, array, ¤t_value)?; remainder = and(&remainder, &is_null(array)?)?; } - ColumnarValue::Scalar(value) => { - if value.is_null() { + ColumnarValue::Scalar(scalar) => { + if scalar.value().is_null() { continue; } else { - let last_value = value.to_scalar()?; + let last_value = scalar.to_scalar()?; current_value = zip(&remainder, &last_value, ¤t_value)?; break; } @@ -115,7 +115,7 @@ impl ScalarUDFImpl for CoalesceFunc { let result = args .iter() .filter_map(|x| match x { - ColumnarValue::Scalar(s) if !s.is_null() => Some(x.clone()), + ColumnarValue::Scalar(s) if !s.value().is_null() => Some(x.clone()), _ => None, }) .next() diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index a51f895c5084..756109edc4e0 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ 
-168,7 +168,7 @@ impl ScalarUDFImpl for GetFieldFunc { } if args[0].data_type().is_null() { - return Ok(ColumnarValue::Scalar(ScalarValue::Null)); + return Ok(ColumnarValue::from(ScalarValue::Null)); } let arrays = ColumnarValue::values_to_arrays(args)?; @@ -183,7 +183,7 @@ impl ScalarUDFImpl for GetFieldFunc { } }; - match (array.data_type(), name) { + match (array.data_type(), name.value()) { (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { let map_array = as_map_array(array.as_ref())?; let key_scalar: Scalar>> = Scalar::new(StringArray::from(vec![k.clone()])); @@ -227,7 +227,7 @@ impl ScalarUDFImpl for GetFieldFunc { "get indexed field is only possible on struct with utf8 indexes. \ Tried with {name:?} index" ), - (DataType::Null, _) => Ok(ColumnarValue::Scalar(ScalarValue::Null)), + (DataType::Null, _) => Ok(ColumnarValue::from(ScalarValue::Null)), (dt, name) => exec_err!( "get indexed field is only possible on lists with int64 indexes or struct \ with utf8 indexes. Tried {dt:?} with {name:?} index" diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 85c332745355..d7835945a82e 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -47,8 +47,15 @@ fn named_struct_expr(args: &[ColumnarValue]) -> Result { let name_column = &chunk[0]; let name = match name_column { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(name_scalar))) => name_scalar, - _ => return exec_err!("named_struct even arguments must be string literals, got {name_column:?} instead at position {}", i * 2) + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(name_scalar)) => name_scalar, + _ => return exec_err!( + "named_struct even arguments must be string literals, got {name_column:?} instead at position {}", i * 2 + ) + }, + _ => return exec_err!( + "named_struct even arguments must be string literals, got {name_column:?} instead at position {}", 
i * 2 + ) }; Ok((name, chunk[1].clone())) diff --git a/datafusion/functions/src/core/nullif.rs b/datafusion/functions/src/core/nullif.rs index 6fcfbd36416e..de1099f7b0ed 100644 --- a/datafusion/functions/src/core/nullif.rs +++ b/datafusion/functions/src/core/nullif.rs @@ -17,11 +17,10 @@ use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, Scalar}; use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::nullif::nullif; -use datafusion_common::ScalarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -131,8 +130,8 @@ fn nullif_func(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(array)) } (ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => { - let val: ScalarValue = match lhs.eq(rhs) { - true => lhs.data_type().try_into()?, + let val = match lhs.eq(rhs) { + true => Scalar::try_from(lhs.data_type().clone())?, false => lhs.clone(), }; @@ -146,6 +145,7 @@ mod tests { use std::sync::Arc; use arrow::array::*; + use datafusion_common::ScalarValue; use super::*; @@ -164,7 +164,7 @@ mod tests { ]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); + let lit_array = ColumnarValue::from(ScalarValue::Int32(Some(2i32))); let result = nullif_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -190,7 +190,7 @@ mod tests { let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); + let lit_array = ColumnarValue::from(ScalarValue::Int32(Some(1i32))); let result = nullif_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -215,7 +215,7 @@ mod tests { let a = BooleanArray::from(vec![Some(true), Some(false), None]); let a = 
ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); + let lit_array = ColumnarValue::from(ScalarValue::Boolean(Some(false))); let result = nullif_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -232,7 +232,7 @@ mod tests { let a = StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::from("bar")); + let lit_array = ColumnarValue::from(ScalarValue::from("bar")); let result = nullif_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -253,7 +253,7 @@ mod tests { let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), Some(4)]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); + let lit_array = ColumnarValue::from(ScalarValue::Int32(Some(2i32))); let result = nullif_func(&[lit_array, a])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -272,8 +272,8 @@ mod tests { #[test] fn nullif_scalar() -> Result<()> { - let a_eq = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); - let b_eq = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); + let a_eq = ColumnarValue::from(ScalarValue::Int32(Some(2i32))); + let b_eq = ColumnarValue::from(ScalarValue::Int32(Some(2i32))); let result_eq = nullif_func(&[a_eq, b_eq])?; let result_eq = result_eq.into_array(1).expect("Failed to convert to array"); @@ -282,8 +282,8 @@ mod tests { assert_eq!(expected_eq.as_ref(), result_eq.as_ref()); - let a_neq = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); - let b_neq = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); + let a_neq = ColumnarValue::from(ScalarValue::Int32(Some(2i32))); + let b_neq = ColumnarValue::from(ScalarValue::Int32(Some(1i32))); let result_neq = nullif_func(&[a_neq, b_neq])?; let 
result_neq = result_neq diff --git a/datafusion/functions/src/core/nvl.rs b/datafusion/functions/src/core/nvl.rs index a09224acefcd..09d86ddfc0d9 100644 --- a/datafusion/functions/src/core/nvl.rs +++ b/datafusion/functions/src/core/nvl.rs @@ -112,7 +112,7 @@ fn nvl_func(args: &[ColumnarValue]) -> Result { } (ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => { let mut current_value = lhs; - if lhs.is_null() { + if lhs.value().is_null() { current_value = rhs; } return Ok(ColumnarValue::Scalar(current_value.clone())); @@ -147,7 +147,7 @@ mod tests { ]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(6i32))); + let lit_array = ColumnarValue::from(ScalarValue::Int32(Some(6i32))); let result = nvl_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -173,7 +173,7 @@ mod tests { let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(20i32))); + let lit_array = ColumnarValue::from(ScalarValue::Int32(Some(20i32))); let result = nvl_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -198,7 +198,7 @@ mod tests { let a = BooleanArray::from(vec![Some(true), Some(false), None]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); + let lit_array = ColumnarValue::from(ScalarValue::Boolean(Some(false))); let result = nvl_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -218,7 +218,7 @@ mod tests { let a = StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::from("bax")); + let lit_array = ColumnarValue::from(ScalarValue::from("bax")); let result = nvl_func(&[a, lit_array])?; let result 
= result.into_array(0).expect("Failed to convert to array"); @@ -239,7 +239,7 @@ mod tests { let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), Some(4)]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); + let lit_array = ColumnarValue::from(ScalarValue::Int32(Some(2i32))); let result = nvl_func(&[lit_array, a])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -258,8 +258,8 @@ mod tests { #[test] fn nvl_scalar() -> Result<()> { - let a_null = ColumnarValue::Scalar(ScalarValue::Int32(None)); - let b_null = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); + let a_null = ColumnarValue::from(ScalarValue::Int32(None)); + let b_null = ColumnarValue::from(ScalarValue::Int32(Some(2i32))); let result_null = nvl_func(&[a_null, b_null])?; let result_null = result_null @@ -270,8 +270,8 @@ mod tests { assert_eq!(expected_null.as_ref(), result_null.as_ref()); - let a_nnull = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); - let b_nnull = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); + let a_nnull = ColumnarValue::from(ScalarValue::Int32(Some(2i32))); + let b_nnull = ColumnarValue::from(ScalarValue::Int32(Some(1i32))); let result_nnull = nvl_func(&[a_nnull, b_nnull])?; let result_nnull = result_nnull diff --git a/datafusion/functions/src/core/nvl2.rs b/datafusion/functions/src/core/nvl2.rs index 1144dc0fb7c5..f3027925b26a 100644 --- a/datafusion/functions/src/core/nvl2.rs +++ b/datafusion/functions/src/core/nvl2.rs @@ -126,7 +126,7 @@ fn nvl2_func(args: &[ColumnarValue]) -> Result { internal_err!("except Scalar value, but got Array") } ColumnarValue::Scalar(scalar) => { - if scalar.is_null() { + if scalar.value().is_null() { current_value = &args[2]; } Ok(current_value.clone()) diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 212349e68981..5f337081bd17 100644 --- 
a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -76,7 +76,7 @@ impl ScalarUDFImpl for VersionFunc { std::env::consts::ARCH, std::env::consts::OS, ); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(version)))) + Ok(ColumnarValue::from(ScalarValue::Utf8(Some(version)))) } } @@ -90,10 +90,14 @@ mod test { let version_udf = ScalarUDF::from(VersionFunc::new()); let version = version_udf.invoke_no_args(0).unwrap(); - if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(version))) = version { - assert!(version.starts_with("Apache DataFusion")); + if let ColumnarValue::Scalar(scalar) = version { + if let ScalarValue::Utf8(Some(version)) = scalar.value() { + assert!(version.starts_with("Apache DataFusion")); + } else { + panic!("Expected version string scalar"); + } } else { - panic!("Expected version string"); + panic!("Expected version string scalar"); } } } diff --git a/datafusion/functions/src/crypto/basic.rs b/datafusion/functions/src/crypto/basic.rs index 716afd84a9c9..85b0b4a0bc6d 100644 --- a/datafusion/functions/src/crypto/basic.rs +++ b/datafusion/functions/src/crypto/basic.rs @@ -94,6 +94,7 @@ macro_rules! 
digest_to_scalar { digest.update(v); digest.finalize().as_slice().to_vec() })) + .into() }}; } @@ -120,10 +121,8 @@ pub fn digest(args: &[ColumnarValue]) -> Result { ); } let digest_algorithm = match &args[1] { - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { - method.parse::() - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(method)) => method.parse::(), other => exec_err!("Unsupported data type {other:?} for function digest"), }, ColumnarValue::Array(_) => { @@ -191,10 +190,12 @@ pub fn md5(args: &[ColumnarValue]) -> Result { .collect(); ColumnarValue::Array(Arc::new(string_array)) } - ColumnarValue::Scalar(ScalarValue::Binary(opt)) => { - ColumnarValue::Scalar(ScalarValue::Utf8(opt.map(hex_encode::<_>))) - } - _ => return exec_err!("Impossibly got invalid results from digest"), + ColumnarValue::Scalar(scalar) => match scalar.into_value() { + ScalarValue::Binary(opt) => { + ColumnarValue::from(ScalarValue::Utf8(opt.map(hex_encode::<_>))) + } + _ => return exec_err!("Impossibly got invalid results from digest"), + }, }) } @@ -256,7 +257,8 @@ impl DigestAlgorithm { let mut digest = Blake3::default(); digest.update(v); Blake3::finalize(&digest).as_bytes().to_vec() - })), + })) + .into(), }) } @@ -338,16 +340,16 @@ pub fn digest_process( "Unsupported data type {other:?} for function {digest_algorithm}" ), }, - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { - Ok(digest_algorithm - .digest_scalar(a.as_ref().map(|s: &String| s.as_bytes()))) + ColumnarValue::Scalar(scalar) => { + match scalar.value() { + ScalarValue::Utf8(a) => Ok(digest_algorithm + .digest_scalar(a.as_ref().map(|s: &String| s.as_bytes()))), + ScalarValue::Binary(a) => Ok(digest_algorithm + .digest_scalar(a.as_ref().map(|v: &Vec| v.as_slice()))), + other => exec_err!( + "Unsupported data type {other:?} for function 
{digest_algorithm}" + ), } - ScalarValue::Binary(a) | ScalarValue::LargeBinary(a) => Ok(digest_algorithm - .digest_scalar(a.as_ref().map(|v: &Vec| v.as_slice()))), - other => exec_err!( - "Unsupported data type {other:?} for function {digest_algorithm}" - ), - }, + } } } diff --git a/datafusion/functions/src/datetime/common.rs b/datafusion/functions/src/datetime/common.rs index 89b40a3534d3..d27a04e7c2b6 100644 --- a/datafusion/functions/src/datetime/common.rs +++ b/datafusion/functions/src/datetime/common.rs @@ -195,10 +195,10 @@ where ))), other => exec_err!("Unsupported data type {other:?} for function {name}"), }, - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(a) => { let result = a.as_ref().map(|x| (op)(x)).transpose()?; - Ok(ColumnarValue::Scalar(S::scalar(result))) + Ok(ColumnarValue::from(S::scalar(result))) } other => exec_err!("Unsupported data type {other:?} for function {name}"), }, @@ -252,8 +252,8 @@ where } }, // if the first argument is a scalar utf8 all arguments are expected to be scalar utf8 - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(a) => { let a = a.as_ref(); // ASK: Why do we trust `a` to be non-null at this point? 
let a = unwrap_or_internal_err!(a); @@ -261,17 +261,26 @@ where let mut ret = None; for (pos, v) in args.iter().enumerate().skip(1) { - let ColumnarValue::Scalar( - ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x), - ) = v - else { - return exec_err!("Unsupported data type {v:?} for function {name}, arg # {pos}"); + let x = match v { + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(v) => v, + _ => { + return exec_err!( + "Unsupported data type {v:?} for function {name}, arg # {pos}" + ) + } + }, + _ => { + return exec_err!( + "Unsupported data type {v:?} for function {name}, arg # {pos}" + ) + } }; if let Some(s) = x { match op(a.as_str(), s.as_str()) { Ok(r) => { - ret = Some(Ok(ColumnarValue::Scalar(S::scalar(Some( + ret = Some(Ok(ColumnarValue::from(S::scalar(Some( op2(r), ))))); break; @@ -328,8 +337,8 @@ where ColumnarValue::Array(a) => { Ok(Either::Left(as_generic_string_array::(a.as_ref())?)) } - ColumnarValue::Scalar(s) => match s { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => Ok(Either::Right(a)), + ColumnarValue::Scalar(s) => match s.value() { + ScalarValue::Utf8(a) => Ok(Either::Right(a)), other => exec_err!( "Unexpected scalar type encountered '{other}' for function '{name}'" ), diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 997f1a36ad04..36b97d2ef3c3 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -135,7 +135,7 @@ impl ScalarUDFImpl for DateBinFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { if args.len() == 2 { // Default to unix EPOCH - let origin = ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + let origin = ColumnarValue::from(ScalarValue::TimestampNanosecond( Some(0), Some("+00:00".into()), )); @@ -260,67 +260,67 @@ fn date_bin_impl( array: &ColumnarValue, origin: &ColumnarValue, ) -> Result { - let stride = match stride { - 
ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(v))) => { - let (days, ms) = IntervalDayTimeType::to_parts(*v); - let nanos = (TimeDelta::try_days(days as i64).unwrap() - + TimeDelta::try_milliseconds(ms as i64).unwrap()) - .num_nanoseconds(); - - match nanos { - Some(v) => Interval::Nanoseconds(v), - _ => return exec_err!("DATE_BIN stride argument is too large"), - } - } - ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some(v))) => { - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v); - - // If interval is months, its origin must be midnight of first date of the month - if months != 0 { - // Return error if days or nanos is not zero - if days != 0 || nanos != 0 { - return not_impl_err!( - "DATE_BIN stride does not support combination of month, day and nanosecond intervals" - ); - } else { - Interval::Months(months as i64) - } - } else { + let stride = if let ColumnarValue::Scalar(scalar) = stride { + match scalar.value() { + ScalarValue::IntervalDayTime(Some(v)) => { + let (days, ms) = IntervalDayTimeType::to_parts(*v); let nanos = (TimeDelta::try_days(days as i64).unwrap() - + Duration::nanoseconds(nanos)) + + TimeDelta::try_milliseconds(ms as i64).unwrap()) .num_nanoseconds(); + match nanos { Some(v) => Interval::Nanoseconds(v), _ => return exec_err!("DATE_BIN stride argument is too large"), } } + ScalarValue::IntervalMonthDayNano(Some(v)) => { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v); + + // If interval is months, its origin must be midnight of first date of the month + if months != 0 { + // Return error if days or nanos is not zero + if days != 0 || nanos != 0 { + return not_impl_err!( + "DATE_BIN stride does not support combination of month, day and nanosecond intervals" + ); + } else { + Interval::Months(months as i64) + } + } else { + let nanos = (TimeDelta::try_days(days as i64).unwrap() + + Duration::nanoseconds(nanos)) + .num_nanoseconds(); + match nanos { + Some(v) => 
Interval::Nanoseconds(v), + _ => return exec_err!("DATE_BIN stride argument is too large"), + } + } + } + _ => { + return exec_err!( + "DATE_BIN expects stride argument to be an INTERVAL but got {}", + scalar.data_type() + ); + } } - ColumnarValue::Scalar(v) => { - return exec_err!( - "DATE_BIN expects stride argument to be an INTERVAL but got {}", - v.data_type() - ); - } - ColumnarValue::Array(_) => { - return not_impl_err!( + } else { + return not_impl_err!( "DATE_BIN only supports literal values for the stride argument, not arrays" ); - } }; - let origin = match origin { - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(v), _)) => *v, - ColumnarValue::Scalar(v) => { - return exec_err!( + let origin = if let ColumnarValue::Scalar(scalar) = origin { + match scalar.value() { + ScalarValue::TimestampNanosecond(Some(v), _) => *v, + _ => return exec_err!( "DATE_BIN expects origin argument to be a TIMESTAMP with nanosecond precision but got {}", - v.data_type() - ); + scalar.data_type() + ) } - ColumnarValue::Array(_) => { - return not_impl_err!( + } else { + return not_impl_err!( "DATE_BIN only supports literal values for the origin argument, not arrays" ); - } }; let (stride, stride_fn) = stride.bin_fn(); @@ -345,38 +345,37 @@ fn date_bin_impl( } Ok(match array { - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { - let apply_stride_fn = - stride_map_fn::(origin, stride, stride_fn); - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - v.map(apply_stride_fn), - tz_opt.clone(), - )) - } - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => { - let apply_stride_fn = - stride_map_fn::(origin, stride, stride_fn); - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( - v.map(apply_stride_fn), - tz_opt.clone(), - )) - } - ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => { - let apply_stride_fn = - stride_map_fn::(origin, stride, stride_fn); - 
ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( - v.map(apply_stride_fn), - tz_opt.clone(), - )) - } - ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { - let apply_stride_fn = - stride_map_fn::(origin, stride, stride_fn); - ColumnarValue::Scalar(ScalarValue::TimestampSecond( - v.map(apply_stride_fn), - tz_opt.clone(), - )) - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::TimestampNanosecond(v, tz_opt) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); + ScalarValue::TimestampNanosecond(v.map(apply_stride_fn), tz_opt.clone()) + .into() + } + ScalarValue::TimestampMicrosecond(v, tz_opt) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); + ScalarValue::TimestampMicrosecond(v.map(apply_stride_fn), tz_opt.clone()) + .into() + } + ScalarValue::TimestampMillisecond(v, tz_opt) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); + ScalarValue::TimestampMillisecond(v.map(apply_stride_fn), tz_opt.clone()) + .into() + } + ScalarValue::TimestampSecond(v, tz_opt) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); + ScalarValue::TimestampSecond(v.map(apply_stride_fn), tz_opt.clone()) + .into() + } + value => { + return exec_err!( + "DATE_BIN expects source argument to be a TIMESTAMP scalar but got {}", value + ); + } + }, ColumnarValue::Array(array) => { fn transform_array_with_stride( @@ -427,11 +426,6 @@ fn date_bin_impl( } } } - _ => { - return exec_err!( - "DATE_BIN expects source argument to be a TIMESTAMP scalar or array" - ); - } }) } @@ -454,46 +448,46 @@ mod tests { #[test] fn test_date_bin() { let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 0, milliseconds: 1, }))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - 
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert!(res.is_ok()); let timestamps = Arc::new((1..6).map(Some).collect::()); let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 0, milliseconds: 1, }))), ColumnarValue::Array(timestamps), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert!(res.is_ok()); let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 0, milliseconds: 1, }))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert!(res.is_ok()); // stride supports month-day-nano let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some( + ColumnarValue::from(ScalarValue::IntervalMonthDayNano(Some( IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 1, }, ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert!(res.is_ok()); @@ -502,7 +496,7 @@ mod tests { // // invalid number of arguments - let res = DateBinFunc::new().invoke(&[ColumnarValue::Scalar( + let res = DateBinFunc::new().invoke(&[ColumnarValue::from( ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 0, milliseconds: 1, @@ -515,9 +509,9 @@ mod 
tests { // stride: invalid type let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::IntervalYearMonth(Some(1))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -526,12 +520,12 @@ mod tests { // stride: invalid value let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 0, milliseconds: 0, }))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -540,11 +534,9 @@ mod tests { // stride: overflow of day-time interval let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime::MAX, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime::MAX))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -553,9 +545,9 @@ mod tests { // stride: overflow of month-day-nano interval let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::new_interval_mdn(0, i32::MAX, 1)), - 
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::new_interval_mdn(0, i32::MAX, 1)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -564,9 +556,9 @@ mod tests { // stride: month intervals let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1, 1, 1)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::new_interval_mdn(1, 1, 1)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -575,12 +567,12 @@ mod tests { // origin: invalid type let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 0, milliseconds: 1, }))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampMicrosecond(Some(1), None)), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -588,12 +580,12 @@ mod tests { ); let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 0, milliseconds: 1, }))), - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), - 
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampMicrosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert!(res.is_ok()); @@ -610,8 +602,8 @@ mod tests { ); let res = DateBinFunc::new().invoke(&[ ColumnarValue::Array(intervals), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -621,11 +613,11 @@ mod tests { // unsupported array type for origin let timestamps = Arc::new((1..6).map(Some).collect::()); let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 0, milliseconds: 1, }))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Array(timestamps), ]); assert_eq!( @@ -744,9 +736,9 @@ mod tests { .with_timezone_opt(tz_opt.clone()); let result = DateBinFunc::new() .invoke(&[ - ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)), + ColumnarValue::from(ScalarValue::new_interval_dt(1, 0)), ColumnarValue::Array(Arc::new(input)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + ColumnarValue::from(ScalarValue::TimestampNanosecond( Some(string_to_timestamp_nanos(origin).unwrap()), tz_opt.clone(), )), diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 8ee82d872651..52030c814407 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -140,14 +140,20 @@ impl ScalarUDFImpl for 
DatePartFunc { } let (part, array) = (&args[0], &args[1]); - let part = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = part { - v - } else if let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v))) = part { - v - } else { - return exec_err!( - "First argument of `DATE_PART` must be non-null scalar Utf8" - ); + let part = match part { + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(v)) => v, + _ => { + return exec_err!( + "First argument of `DATE_PART` must be non-null scalar Utf8" + ) + } + }, + _ => { + return exec_err!( + "First argument of `DATE_PART` must be non-null scalar Utf8" + ) + } }; let is_scalar = matches!(array, ColumnarValue::Scalar(_)); @@ -192,7 +198,7 @@ impl ScalarUDFImpl for DatePartFunc { }; Ok(if is_scalar { - ColumnarValue::Scalar(ScalarValue::try_from_array(arr.as_ref(), 0)?) + ColumnarValue::from(ScalarValue::try_from_array(arr.as_ref(), 0)?) } else { ColumnarValue::Array(arr) }) diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index 0ef839c49f0b..87115dcb61e6 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -137,15 +137,20 @@ impl ScalarUDFImpl for DateTruncFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { let (granularity, array) = (&args[0], &args[1]); - let granularity = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = - granularity - { - v.to_lowercase() - } else if let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v))) = granularity - { - v.to_lowercase() - } else { - return exec_err!("Granularity of `date_trunc` must be non-null scalar Utf8"); + let granularity = match granularity { + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(v)) => v.to_lowercase(), + _ => { + return exec_err!( + "Granularity of `date_trunc` must be non-null scalar Utf8" + ) + } + }, + _ => { + return exec_err!( + "Granularity 
of `date_trunc` must be non-null scalar Utf8" + ) + } }; fn process_array( @@ -171,22 +176,29 @@ impl ScalarUDFImpl for DateTruncFunc { let parsed_tz = parse_tz(tz_opt)?; let value = general_date_trunc(T::UNIT, v, parsed_tz, granularity.as_str())?; let value = ScalarValue::new_timestamp::(value, tz_opt.clone()); - Ok(ColumnarValue::Scalar(value)) + Ok(ColumnarValue::from(value)) } Ok(match array { - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { - process_scalar::(v, granularity, tz_opt)? - } - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => { - process_scalar::(v, granularity, tz_opt)? - } - ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => { - process_scalar::(v, granularity, tz_opt)? - } - ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { - process_scalar::(v, granularity, tz_opt)? - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::TimestampNanosecond(v, tz_opt) => { + process_scalar::(v, granularity, tz_opt)? + } + ScalarValue::TimestampMicrosecond(v, tz_opt) => { + process_scalar::(v, granularity, tz_opt)? + } + ScalarValue::TimestampMillisecond(v, tz_opt) => { + process_scalar::(v, granularity, tz_opt)? + } + ScalarValue::TimestampSecond(v, tz_opt) => { + process_scalar::(v, granularity, tz_opt)? 
+ } + _ => { + return exec_err!( + "second argument of `date_trunc` must be timestamp scalar or array" + ); + } + }, ColumnarValue::Array(array) => { let array_type = array.data_type(); if let Timestamp(unit, tz_opt) = array_type { @@ -216,11 +228,6 @@ impl ScalarUDFImpl for DateTruncFunc { return exec_err!("second argument of `date_trunc` is an unsupported array type: {array_type}"); } } - _ => { - return exec_err!( - "second argument of `date_trunc` must be timestamp scalar or array" - ); - } }) } @@ -689,7 +696,7 @@ mod tests { .with_timezone_opt(tz_opt.clone()); let result = DateTruncFunc::new() .invoke(&[ - ColumnarValue::Scalar(ScalarValue::from("day")), + ColumnarValue::from(ScalarValue::from("day")), ColumnarValue::Array(Arc::new(input)), ]) .unwrap(); @@ -847,7 +854,7 @@ mod tests { .with_timezone_opt(tz_opt.clone()); let result = DateTruncFunc::new() .invoke(&[ - ColumnarValue::Scalar(ScalarValue::from("hour")), + ColumnarValue::from(ScalarValue::from("hour")), ColumnarValue::Array(Arc::new(input)), ]) .unwrap(); diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index ded7b454f9eb..75dfa305351a 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -94,7 +94,7 @@ impl ScalarUDFImpl for MakeDateFunc { let ColumnarValue::Scalar(s) = col else { return exec_err!("Expected scalar value"); }; - let ScalarValue::Int32(Some(i)) = s else { + let ScalarValue::Int32(Some(i)) = s.value() else { return exec_err!("Unable to parse date from null/empty value"); }; Ok(*i) @@ -143,7 +143,7 @@ impl ScalarUDFImpl for MakeDateFunc { |days: i32| value = days, )?; - ColumnarValue::Scalar(ScalarValue::Date32(Some(value))) + ColumnarValue::from(ScalarValue::Date32(Some(value))) }; Ok(value) @@ -192,42 +192,51 @@ mod tests { fn test_make_date() { let res = MakeDateFunc::new() .invoke(&[ - ColumnarValue::Scalar(ScalarValue::Int32(Some(2024))), - 
ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), + ColumnarValue::from(ScalarValue::Int32(Some(2024))), + ColumnarValue::from(ScalarValue::Int64(Some(1))), + ColumnarValue::from(ScalarValue::UInt32(Some(14))), ]) .expect("that make_date parsed values without error"); - if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { - assert_eq!(19736, date.unwrap()); + if let ColumnarValue::Scalar(scalar) = res { + match scalar.value() { + ScalarValue::Date32(date) => assert_eq!(19736, date.unwrap()), + _ => panic!("Expected a Date32"), + } } else { panic!("Expected a scalar value") } let res = MakeDateFunc::new() .invoke(&[ - ColumnarValue::Scalar(ScalarValue::Int64(Some(2024))), - ColumnarValue::Scalar(ScalarValue::UInt64(Some(1))), - ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), + ColumnarValue::from(ScalarValue::Int64(Some(2024))), + ColumnarValue::from(ScalarValue::UInt64(Some(1))), + ColumnarValue::from(ScalarValue::UInt32(Some(14))), ]) .expect("that make_date parsed values without error"); - if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { - assert_eq!(19736, date.unwrap()); + if let ColumnarValue::Scalar(scalar) = res { + match scalar.value() { + ScalarValue::Date32(date) => assert_eq!(19736, date.unwrap()), + _ => panic!("Expected a Date32"), + } } else { panic!("Expected a scalar value") } let res = MakeDateFunc::new() .invoke(&[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some("2024".to_string()))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("1".to_string()))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some("14".to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some("2024".to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some("1".to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some("14".to_string()))), ]) .expect("that make_date parsed values without error"); - if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { - assert_eq!(19736, 
date.unwrap()); + if let ColumnarValue::Scalar(scalar) = res { + match scalar.value() { + ScalarValue::Date32(date) => assert_eq!(19736, date.unwrap()), + _ => panic!("Expected a Date32"), + } } else { panic!("Expected a scalar value") } @@ -261,7 +270,7 @@ mod tests { // invalid number of arguments let res = MakeDateFunc::new() - .invoke(&[ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))]); + .invoke(&[ColumnarValue::from(ScalarValue::Int32(Some(1)))]); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: make_date function requires 3 arguments, got 1" @@ -269,9 +278,9 @@ mod tests { // invalid type let res = MakeDateFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::IntervalYearMonth(Some(1))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -280,9 +289,9 @@ mod tests { // overflow of month let res = MakeDateFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), - ColumnarValue::Scalar(ScalarValue::UInt64(Some(u64::MAX))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), + ColumnarValue::from(ScalarValue::Int32(Some(2023))), + ColumnarValue::from(ScalarValue::UInt64(Some(u64::MAX))), + ColumnarValue::from(ScalarValue::Int32(Some(22))), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -291,9 +300,9 @@ mod tests { // overflow of day let res = MakeDateFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), - ColumnarValue::Scalar(ScalarValue::UInt32(Some(u32::MAX))), + ColumnarValue::from(ScalarValue::Int32(Some(2023))), + 
ColumnarValue::from(ScalarValue::Int32(Some(22))), + ColumnarValue::from(ScalarValue::UInt32(Some(u32::MAX))), ]); assert_eq!( res.err().unwrap().strip_backtrace(), diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index f2e5af978ca0..5c0f57ca09e2 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -115,22 +115,23 @@ impl ScalarUDFImpl for ToCharFunc { } match &args[1] { - ColumnarValue::Scalar(ScalarValue::Utf8(None)) - | ColumnarValue::Scalar(ScalarValue::Null) => { - _to_char_scalar(args[0].clone(), None) - } - // constant format - ColumnarValue::Scalar(ScalarValue::Utf8(Some(format))) => { - // invoke to_char_scalar with the known string, without converting to array - _to_char_scalar(args[0].clone(), Some(format)) - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(None) | ScalarValue::Null => { + _to_char_scalar(args[0].clone(), None) + } + // constant format + ScalarValue::Utf8(Some(format)) => { + // invoke to_char_scalar with the known string, without converting to array + _to_char_scalar(args[0].clone(), Some(format)) + } + _ => { + exec_err!( + "Format for `to_char` must be non-null Utf8, received {:?}", + args[1].data_type() + ) + } + }, ColumnarValue::Array(_) => _to_char_array(args), - _ => { - exec_err!( - "Format for `to_char` must be non-null Utf8, received {:?}", - args[1].data_type() - ) - } } } @@ -177,13 +178,13 @@ fn _to_char_scalar( ) -> Result { // it's possible that the expression is a scalar however because // of the implementation in arrow-rs we need to convert it to an array - let data_type = &expression.data_type(); + let data_type = expression.data_type().clone(); let is_scalar_expression = matches!(&expression, ColumnarValue::Scalar(_)); let array = expression.into_array(1)?; if format.is_none() { if is_scalar_expression { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + return 
Ok(ColumnarValue::from(ScalarValue::Utf8(None))); } else { return Ok(ColumnarValue::Array(new_null_array( &DataType::Utf8, @@ -192,7 +193,7 @@ fn _to_char_scalar( } } - let format_options = match _build_format_options(data_type, format) { + let format_options = match _build_format_options(&data_type, format) { Ok(value) => value, Err(value) => return value, }; @@ -204,7 +205,7 @@ fn _to_char_scalar( if let Ok(formatted) = formatted { if is_scalar_expression { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + Ok(ColumnarValue::from(ScalarValue::Utf8(Some( formatted.first().unwrap().to_string(), )))) } else { @@ -252,10 +253,10 @@ fn _to_char_array(args: &[ColumnarValue]) -> Result { results, )) as ArrayRef)), ColumnarValue::Scalar(_) => match results.first().unwrap() { - Some(value) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + Some(value) => Ok(ColumnarValue::from(ScalarValue::Utf8(Some( value.to_string(), )))), - None => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + None => Ok(ColumnarValue::from(ScalarValue::Utf8(None))), }, } } @@ -351,13 +352,15 @@ mod tests { for (value, format, expected) in scalar_data { let result = ToCharFunc::new() - .invoke(&[ColumnarValue::Scalar(value), ColumnarValue::Scalar(format)]) + .invoke(&[ColumnarValue::from(value), ColumnarValue::from(format)]) .expect("that to_char parsed values without error"); - if let ColumnarValue::Scalar(ScalarValue::Utf8(date)) = result { - assert_eq!(expected, date.unwrap()); - } else { - panic!("Expected a scalar value") + match result { + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(date)) => assert_eq!(&expected, date), + _ => panic!("Expected a scalar value"), + }, + _ => panic!("Expected a scalar value"), } } @@ -426,15 +429,17 @@ mod tests { for (value, format, expected) in scalar_array_data { let result = ToCharFunc::new() .invoke(&[ - ColumnarValue::Scalar(value), + ColumnarValue::from(value), ColumnarValue::Array(Arc::new(format) as 
ArrayRef), ]) .expect("that to_char parsed values without error"); - if let ColumnarValue::Scalar(ScalarValue::Utf8(date)) = result { - assert_eq!(expected, date.unwrap()); - } else { - panic!("Expected a scalar value") + match result { + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(date)) => assert_eq!(&expected, date), + _ => panic!("Expected a scalar value"), + }, + _ => panic!("Expected a scalar value"), } } @@ -552,7 +557,7 @@ mod tests { let result = ToCharFunc::new() .invoke(&[ ColumnarValue::Array(value as ArrayRef), - ColumnarValue::Scalar(format), + ColumnarValue::from(format), ]) .expect("that to_char parsed values without error"); @@ -585,8 +590,8 @@ mod tests { // // invalid number of arguments - let result = ToCharFunc::new() - .invoke(&[ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))]); + let result = + ToCharFunc::new().invoke(&[ColumnarValue::from(ScalarValue::Int32(Some(1)))]); assert_eq!( result.err().unwrap().strip_backtrace(), "Execution error: to_char function requires 2 arguments, got 1" @@ -594,8 +599,8 @@ mod tests { // invalid type let result = ToCharFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::Int32(Some(1))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( result.err().unwrap().strip_backtrace(), diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index 288641b84dd7..ee87da0dea08 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -156,17 +156,21 @@ mod tests { for tc in &test_cases { let date_scalar = ScalarValue::Utf8(Some(tc.date_str.to_string())); let to_date_result = - ToDateFunc::new().invoke(&[ColumnarValue::Scalar(date_scalar)]); + ToDateFunc::new().invoke(&[ColumnarValue::from(date_scalar)]); match 
to_date_result { - Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { - let expected = Date32Type::parse_formatted(tc.date_str, "%Y-%m-%d"); - assert_eq!( - date_val, expected, - "{}: to_date created wrong value", - tc.name - ); - } + Ok(ColumnarValue::Scalar(scalar)) => match scalar.value() { + ScalarValue::Date32(date_val) => { + let expected = + Date32Type::parse_formatted(tc.date_str, "%Y-%m-%d"); + assert_eq!( + date_val, &expected, + "{}: to_date created wrong value", + tc.name + ); + } + _ => panic!("Could not convert '{}' to Date", tc.date_str), + }, _ => panic!("Could not convert '{}' to Date", tc.date_str), } } @@ -226,15 +230,22 @@ mod tests { let format_scalar = ScalarValue::Utf8(Some(tc.format_str.to_string())); let to_date_result = ToDateFunc::new().invoke(&[ - ColumnarValue::Scalar(formatted_date_scalar), - ColumnarValue::Scalar(format_scalar), + ColumnarValue::from(formatted_date_scalar), + ColumnarValue::from(format_scalar), ]); match to_date_result { - Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { - let expected = Date32Type::parse_formatted(tc.date_str, "%Y-%m-%d"); - assert_eq!(date_val, expected, "{}: to_date created wrong value for date '{}' with format string '{}'", tc.name, tc.formatted_date, tc.format_str); - } + Ok(ColumnarValue::Scalar(scalar)) => match scalar.value() { + ScalarValue::Date32(date_val) => { + let expected = + Date32Type::parse_formatted(tc.date_str, "%Y-%m-%d"); + assert_eq!(date_val, &expected, "{}: to_date created wrong value for date '{}' with format string '{}'", tc.name, tc.formatted_date, tc.format_str); + } + _ => panic!( + "Could not convert '{}' with format string '{}'to Date", + tc.date_str, tc.format_str + ), + }, _ => panic!( "Could not convert '{}' with format string '{}'to Date", tc.date_str, tc.format_str @@ -250,19 +261,22 @@ mod tests { let format2_scalar = ScalarValue::Utf8(Some("%Y/%m/%d".into())); let to_date_result = ToDateFunc::new().invoke(&[ - 
ColumnarValue::Scalar(formatted_date_scalar), - ColumnarValue::Scalar(format1_scalar), - ColumnarValue::Scalar(format2_scalar), + ColumnarValue::from(formatted_date_scalar), + ColumnarValue::from(format1_scalar), + ColumnarValue::from(format2_scalar), ]); match to_date_result { - Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { - let expected = Date32Type::parse_formatted("2023-01-31", "%Y-%m-%d"); - assert_eq!( - date_val, expected, - "to_date created wrong value for date with 2 format strings" - ); - } + Ok(ColumnarValue::Scalar(scalar)) => match scalar.value() { + ScalarValue::Date32(date_val) => { + let expected = Date32Type::parse_formatted("2023-01-31", "%Y-%m-%d"); + assert_eq!( + date_val, &expected, + "to_date created wrong value for date with 2 format strings" + ); + } + _ => panic!("Conversion failed",), + }, _ => panic!("Conversion failed",), } } @@ -278,13 +292,17 @@ mod tests { let formatted_date_scalar = ScalarValue::Utf8(Some(date_str.into())); let to_date_result = - ToDateFunc::new().invoke(&[ColumnarValue::Scalar(formatted_date_scalar)]); + ToDateFunc::new().invoke(&[ColumnarValue::from(formatted_date_scalar)]); match to_date_result { - Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { - let expected = Date32Type::parse_formatted("2020-09-08", "%Y-%m-%d"); - assert_eq!(date_val, expected, "to_date created wrong value"); - } + Ok(ColumnarValue::Scalar(scalar)) => match scalar.value() { + ScalarValue::Date32(date_val) => { + let expected = + Date32Type::parse_formatted("2020-09-08", "%Y-%m-%d"); + assert_eq!(date_val, &expected, "to_date created wrong value"); + } + _ => panic!("Conversion of {} failed", date_str), + }, _ => panic!("Conversion of {} failed", date_str), } } @@ -296,17 +314,20 @@ mod tests { let date_scalar = ScalarValue::Utf8(Some(date_str.into())); let to_date_result = - ToDateFunc::new().invoke(&[ColumnarValue::Scalar(date_scalar)]); + ToDateFunc::new().invoke(&[ColumnarValue::from(date_scalar)]); match 
to_date_result { - Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { - let expected = Date32Type::parse_formatted("2024-12-31", "%Y-%m-%d"); - assert_eq!( - date_val, expected, - "to_date created wrong value for {}", - date_str - ); - } + Ok(ColumnarValue::Scalar(scalar)) => match scalar.value() { + ScalarValue::Date32(date_val) => { + let expected = Date32Type::parse_formatted("2024-12-31", "%Y-%m-%d"); + assert_eq!( + date_val, &expected, + "to_date created wrong value for {}", + date_str + ); + } + _ => panic!("Conversion of {} failed", date_str), + }, _ => panic!("Conversion of {} failed", date_str), } } @@ -317,13 +338,15 @@ mod tests { let date_scalar = ScalarValue::Utf8(Some(date_str.into())); let to_date_result = - ToDateFunc::new().invoke(&[ColumnarValue::Scalar(date_scalar)]); + ToDateFunc::new().invoke(&[ColumnarValue::from(date_scalar)]); - if let Ok(ColumnarValue::Scalar(ScalarValue::Date32(_))) = to_date_result { - panic!( - "Conversion of {} succeded, but should have failed, ", - date_str - ); + if let Ok(ColumnarValue::Scalar(scalar)) = to_date_result { + if matches!(scalar.value(), ScalarValue::Date32(_)) { + panic!( + "Conversion of {} succeded, but should have failed, ", + date_str + ); + } } } } diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 634e28e6f393..3fa36223b918 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -97,50 +97,48 @@ impl ToLocalTimeFunc { let tz: Tz = timezone.parse()?; match time_value { - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - Some(ts), - Some(_), - )) => { - let adjusted_ts = - adjust_to_local_time::(*ts, tz)?; - Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - Some(adjusted_ts), - None, - ))) - } - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( - Some(ts), - Some(_), - )) => { - let adjusted_ts = - 
adjust_to_local_time::(*ts, tz)?; - Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( - Some(adjusted_ts), - None, - ))) - } - ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( - Some(ts), - Some(_), - )) => { - let adjusted_ts = - adjust_to_local_time::(*ts, tz)?; - Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( - Some(adjusted_ts), - None, - ))) - } - ColumnarValue::Scalar(ScalarValue::TimestampSecond( - Some(ts), - Some(_), - )) => { - let adjusted_ts = - adjust_to_local_time::(*ts, tz)?; - Ok(ColumnarValue::Scalar(ScalarValue::TimestampSecond( - Some(adjusted_ts), - None, - ))) - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::TimestampNanosecond(Some(ts), Some(_)) => { + let adjusted_ts = + adjust_to_local_time::(*ts, tz)?; + Ok(ColumnarValue::from(ScalarValue::TimestampNanosecond( + Some(adjusted_ts), + None, + ))) + } + ScalarValue::TimestampMicrosecond(Some(ts), Some(_)) => { + let adjusted_ts = adjust_to_local_time::< + TimestampMicrosecondType, + >(*ts, tz)?; + Ok(ColumnarValue::from(ScalarValue::TimestampMicrosecond( + Some(adjusted_ts), + None, + ))) + } + ScalarValue::TimestampMillisecond(Some(ts), Some(_)) => { + let adjusted_ts = adjust_to_local_time::< + TimestampMillisecondType, + >(*ts, tz)?; + Ok(ColumnarValue::from(ScalarValue::TimestampMillisecond( + Some(adjusted_ts), + None, + ))) + } + ScalarValue::TimestampSecond(Some(ts), Some(_)) => { + let adjusted_ts = + adjust_to_local_time::(*ts, tz)?; + Ok(ColumnarValue::from(ScalarValue::TimestampSecond( + Some(adjusted_ts), + None, + ))) + } + _ => { + exec_err!( + "to_local_time function requires timestamp argument, got {:?}", + time_value.data_type() + ) + } + }, ColumnarValue::Array(array) => { fn transform_array( array: &ArrayRef, @@ -185,12 +183,6 @@ impl ToLocalTimeFunc { } } } - _ => { - exec_err!( - "to_local_time function requires timestamp argument, got {:?}", - time_value.data_type() - ) - } } } _ => { @@ -486,11 +478,11 @@ mod 
tests { fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { let res = ToLocalTimeFunc::new() - .invoke(&[ColumnarValue::Scalar(input)]) + .invoke(&[ColumnarValue::from(input)]) .unwrap(); match res { ColumnarValue::Scalar(res) => { - assert_eq!(res, expected); + assert_eq!(res.into_value(), expected); } _ => panic!("unexpected return type"), } diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index cbb6f37603d2..4a6f62e39c70 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -161,7 +161,7 @@ impl ScalarUDFImpl for ToTimestampFunc { validate_data_types(args, "to_timestamp")?; } - match args[0].data_type() { + match args[0].data_type().clone() { DataType::Int32 | DataType::Int64 => args[0] .cast_to(&Timestamp(Second, None), None)? .cast_to(&Timestamp(Nanosecond, None), None), @@ -214,7 +214,7 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { validate_data_types(args, "to_timestamp")?; } - match args[0].data_type() { + match args[0].data_type().clone() { DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Second, None), None) } @@ -264,7 +264,7 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { validate_data_types(args, "to_timestamp")?; } - match args[0].data_type() { + match args[0].data_type().clone() { DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Millisecond, None), None) } @@ -314,7 +314,7 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { validate_data_types(args, "to_timestamp")?; } - match args[0].data_type() { + match args[0].data_type().clone() { DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Microsecond, None), None) } @@ -364,7 +364,7 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { validate_data_types(args, "to_timestamp")?; } - match 
args[0].data_type() { + match args[0].data_type().clone() { DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Nanosecond, None), None) } @@ -803,7 +803,7 @@ mod tests { for udf in &udfs { for array in arrays { - let rt = udf.return_type(&[array.data_type()]).unwrap(); + let rt = udf.return_type(&[array.data_type().clone()]).unwrap(); assert!(matches!(rt, DataType::Timestamp(_, Some(_)))); let res = udf @@ -846,7 +846,7 @@ mod tests { for udf in &udfs { for array in arrays { - let rt = udf.return_type(&[array.data_type()]).unwrap(); + let rt = udf.return_type(&[array.data_type().clone()]).unwrap(); assert!(matches!(rt, DataType::Timestamp(_, None))); let res = udf @@ -896,9 +896,9 @@ mod tests { // test UTF8 let string_array = [ ColumnarValue::Array(Arc::new(data.clone()) as ArrayRef), - ColumnarValue::Scalar(ScalarValue::Utf8(Some("%s".to_string()))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some("%c".to_string()))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some("%+".to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some("%s".to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some("%c".to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some("%+".to_string()))), ]; let parsed_timestamps = func(&string_array) .expect("that to_timestamp with format args parsed values without error"); @@ -922,46 +922,12 @@ mod tests { panic!("Expected a columnar array") } - // test LargeUTF8 - let string_array = [ - ColumnarValue::Array(Arc::new(data.clone()) as ArrayRef), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("%s".to_string()))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("%c".to_string()))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("%+".to_string()))), - ]; - let parsed_timestamps = func(&string_array) - .expect("that to_timestamp with format args parsed values without error"); - if let ColumnarValue::Array(parsed_array) = parsed_timestamps { - assert_eq!(parsed_array.len(), 1); 
- assert!(matches!( - parsed_array.data_type(), - DataType::Timestamp(_, None) - )); - - match time_unit { - Nanosecond => { - assert_eq!(nanos_expected_timestamps, parsed_array.as_ref()) - } - Millisecond => { - assert_eq!(millis_expected_timestamps, parsed_array.as_ref()) - } - Microsecond => { - assert_eq!(micros_expected_timestamps, parsed_array.as_ref()) - } - Second => { - assert_eq!(sec_expected_timestamps, parsed_array.as_ref()) - } - }; - } else { - panic!("Expected a columnar array") - } - // test other types let string_array = [ ColumnarValue::Array(Arc::new(data.clone()) as ArrayRef), - ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(3))), + ColumnarValue::from(ScalarValue::Int32(Some(1))), + ColumnarValue::from(ScalarValue::Int32(Some(2))), + ColumnarValue::from(ScalarValue::Int32(Some(3))), ]; let expected = "Unsupported data type Int32 for function".to_string(); diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index d9ce299a2602..e2dbf5164fde 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -28,7 +28,7 @@ use datafusion_common::{ }; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, Scalar}; use std::sync::Arc; use std::{fmt, str::FromStr}; @@ -173,23 +173,28 @@ fn encode_process(value: &ColumnarValue, encoding: Encoding) -> Result { - match scalar { - ScalarValue::Utf8(a) => { - Ok(encoding.encode_scalar(a.as_ref().map(|s: &String| s.as_bytes()))) - } - ScalarValue::LargeUtf8(a) => Ok(encoding - .encode_large_scalar(a.as_ref().map(|s: &String| s.as_bytes()))), - ScalarValue::Binary(a) => Ok( - encoding.encode_scalar(a.as_ref().map(|v: &Vec| v.as_slice())) - ), - ScalarValue::LargeBinary(a) => Ok(encoding - 
.encode_large_scalar(a.as_ref().map(|v: &Vec| v.as_slice()))), - other => exec_err!( - "Unsupported data type {other:?} for function encode({encoding})" - ), - } - } + ColumnarValue::Scalar(scalar) => match (scalar.value(), scalar.data_type()) { + (ScalarValue::Utf8(a), DataType::Utf8) => Ok(encoding.encode_scalar( + a.as_ref().map(|s: &String| s.as_bytes()), + DataType::Utf8, + )), + (ScalarValue::Utf8(a), DataType::LargeUtf8) => Ok(encoding.encode_scalar( + a.as_ref().map(|s: &String| s.as_bytes()), + DataType::LargeUtf8, + )), + (ScalarValue::Binary(a), DataType::Binary) => Ok(encoding.encode_scalar( + a.as_ref().map(|v: &Vec| v.as_slice()), + DataType::Utf8, + )), + (ScalarValue::Binary(a), DataType::LargeBinary) => Ok(encoding + .encode_scalar( + a.as_ref().map(|v: &Vec| v.as_slice()), + DataType::LargeUtf8, + )), + other => exec_err!( + "Unsupported data type {other:?} for function encode({encoding})" + ), + }, } } @@ -204,23 +209,27 @@ fn decode_process(value: &ColumnarValue, encoding: Encoding) -> Result { - match scalar { - ScalarValue::Utf8(a) => { - encoding.decode_scalar(a.as_ref().map(|s: &String| s.as_bytes())) - } - ScalarValue::LargeUtf8(a) => encoding - .decode_large_scalar(a.as_ref().map(|s: &String| s.as_bytes())), - ScalarValue::Binary(a) => { - encoding.decode_scalar(a.as_ref().map(|v: &Vec| v.as_slice())) - } - ScalarValue::LargeBinary(a) => encoding - .decode_large_scalar(a.as_ref().map(|v: &Vec| v.as_slice())), - other => exec_err!( - "Unsupported data type {other:?} for function decode({encoding})" - ), - } - } + ColumnarValue::Scalar(scalar) => match (scalar.value(), scalar.data_type()) { + (ScalarValue::Utf8(a), DataType::Utf8) => encoding.decode_scalar( + a.as_ref().map(|s: &String| s.as_bytes()), + DataType::Binary, + ), + (ScalarValue::Utf8(a), DataType::LargeUtf8) => encoding.decode_scalar( + a.as_ref().map(|s: &String| s.as_bytes()), + DataType::LargeBinary, + ), + (ScalarValue::Binary(a), DataType::Binary) => encoding.decode_scalar( 
+ a.as_ref().map(|v: &Vec| v.as_slice()), + DataType::Binary, + ), + (ScalarValue::Binary(a), DataType::LargeBinary) => encoding.decode_scalar( + a.as_ref().map(|v: &Vec| v.as_slice()), + DataType::LargeBinary, + ), + other => exec_err!( + "Unsupported data type {other:?} for function decode({encoding})" + ), + }, } } @@ -265,22 +274,14 @@ macro_rules! decode_to_array { } impl Encoding { - fn encode_scalar(self, value: Option<&[u8]>) -> ColumnarValue { - ColumnarValue::Scalar(match self { + fn encode_scalar(self, value: Option<&[u8]>, data_type: DataType) -> ColumnarValue { + let value = match self { Self::Base64 => ScalarValue::Utf8( value.map(|v| general_purpose::STANDARD_NO_PAD.encode(v)), ), Self::Hex => ScalarValue::Utf8(value.map(hex::encode)), - }) - } - - fn encode_large_scalar(self, value: Option<&[u8]>) -> ColumnarValue { - ColumnarValue::Scalar(match self { - Self::Base64 => ScalarValue::LargeUtf8( - value.map(|v| general_purpose::STANDARD_NO_PAD.encode(v)), - ), - Self::Hex => ScalarValue::LargeUtf8(value.map(hex::encode)), - }) + }; + ColumnarValue::Scalar(Scalar::new(value, data_type)) } fn encode_binary_array(self, value: &dyn Array) -> Result @@ -307,38 +308,19 @@ impl Encoding { Ok(ColumnarValue::Array(array)) } - fn decode_scalar(self, value: Option<&[u8]>) -> Result { + fn decode_scalar( + self, + value: Option<&[u8]>, + data_type: DataType, + ) -> Result { let value = match value { Some(value) => value, - None => return Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))), - }; - - let out = match self { - Self::Base64 => { - general_purpose::STANDARD_NO_PAD - .decode(value) - .map_err(|e| { - DataFusionError::Internal(format!( - "Failed to decode value using base64: {}", - e - )) - })? 
+ None => { + return Ok(ColumnarValue::Scalar(Scalar::new( + ScalarValue::Binary(None), + data_type, + ))) } - Self::Hex => hex::decode(value).map_err(|e| { - DataFusionError::Internal(format!( - "Failed to decode value using hex: {}", - e - )) - })?, - }; - - Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(out)))) - } - - fn decode_large_scalar(self, value: Option<&[u8]>) -> Result { - let value = match value { - Some(value) => value, - None => return Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(None))), }; let out = match self { @@ -360,7 +342,10 @@ impl Encoding { })?, }; - Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(out)))) + Ok(ColumnarValue::Scalar(Scalar::new( + ScalarValue::Binary(Some(out)), + data_type, + ))) } fn decode_binary_array(self, value: &dyn Array) -> Result @@ -425,8 +410,8 @@ fn encode(args: &[ColumnarValue]) -> Result { ); } let encoding = match &args[1] { - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(method)) => { method.parse::() } _ => not_impl_err!( @@ -451,8 +436,8 @@ fn decode(args: &[ColumnarValue]) -> Result { ); } let encoding = match &args[1] { - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(method)) => { method.parse::() } _ => not_impl_err!( diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index ad7cff1f7149..7c65439478df 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -110,7 +110,7 @@ impl ScalarUDFImpl for LogFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; - let mut base = ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))); + let mut 
base = ColumnarValue::from(ScalarValue::Float32(Some(10.0))); let mut x = &args[0]; if args.len() == 2 { @@ -120,11 +120,18 @@ impl ScalarUDFImpl for LogFunc { // note in f64::log params order is different than in sql. e.g in sql log(base, x) == f64::log(x, base) let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - Arc::new(make_function_scalar_inputs!(x, "x", Float64Array, { - |value: f64| f64::log(value, base as f64) - })) - } + ColumnarValue::Scalar(scalar) => match scalar.into_value() { + ScalarValue::Float32(Some(base)) => { + Arc::new(make_function_scalar_inputs!(x, "x", Float64Array, { + |value: f64| f64::log(value, base as f64) + })) + } + _ => { + return exec_err!( + "log function requires a scalar or array for base" + ) + } + }, ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( x, base, @@ -133,17 +140,21 @@ impl ScalarUDFImpl for LogFunc { Float64Array, { f64::log } )), - _ => { - return exec_err!("log function requires a scalar or array for base") - } }, DataType::Float32 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - Arc::new(make_function_scalar_inputs!(x, "x", Float32Array, { - |value: f32| f32::log(value, base) - })) - } + ColumnarValue::Scalar(scalar) => match scalar.into_value() { + ScalarValue::Float32(Some(base)) => { + Arc::new(make_function_scalar_inputs!(x, "x", Float32Array, { + |value: f32| f32::log(value, base) + })) + } + _ => { + return exec_err!( + "log function requires a scalar or array for base" + ) + } + }, ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( x, base, @@ -152,9 +163,6 @@ impl ScalarUDFImpl for LogFunc { Float32Array, { f32::log } )), - _ => { - return exec_err!("log function requires a scalar or array for base") - } }, other => { return exec_err!("Unsupported data type {other:?} for function log") diff --git a/datafusion/functions/src/math/pi.rs 
b/datafusion/functions/src/math/pi.rs index c2fe4efb1139..b4adfc190e69 100644 --- a/datafusion/functions/src/math/pi.rs +++ b/datafusion/functions/src/math/pi.rs @@ -64,7 +64,7 @@ impl ScalarUDFImpl for PiFunc { } fn invoke_no_args(&self, _number_rows: usize) -> Result { - Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some( + Ok(ColumnarValue::from(ScalarValue::Float64(Some( std::f64::consts::PI, )))) } diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 89554a76febb..05d3d055c2b8 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -108,7 +108,7 @@ pub fn round(args: &[ArrayRef]) -> Result { ); } - let mut decimal_places = ColumnarValue::Scalar(ScalarValue::Int64(Some(0))); + let mut decimal_places = ColumnarValue::from(ScalarValue::Int64(Some(0))); if args.len() == 2 { decimal_places = ColumnarValue::Array(Arc::clone(&args[1])); @@ -116,25 +116,32 @@ pub fn round(args: &[ArrayRef]) -> Result { match args[0].data_type() { DataType::Float64 => match decimal_places { - ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { - let decimal_places: i32 = decimal_places.try_into().map_err(|e| { - exec_datafusion_err!( - "Invalid value for decimal places: {decimal_places}: {e}" - ) - })?; + ColumnarValue::Scalar(scalar) => match scalar.into_value() { + ScalarValue::Int64(Some(decimal_places)) => { + let decimal_places: i32 = decimal_places.try_into().map_err(|e| { + exec_datafusion_err!( + "Invalid value for decimal places: {decimal_places}: {e}" + ) + })?; - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float64Array, - { - |value: f64| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) + Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "value", + Float64Array, + { + |value: f64| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + } } - } - )) as ArrayRef) - } + )) as 
ArrayRef) + } + _ => { + exec_err!( + "round function requires a scalar or array for decimal_places" + ) + } + }, ColumnarValue::Array(decimal_places) => { let options = CastOptions { safe: false, // raise error if the cast is not possible @@ -159,31 +166,35 @@ pub fn round(args: &[ArrayRef]) -> Result { } )) as ArrayRef) } - _ => { - exec_err!("round function requires a scalar or array for decimal_places") - } }, DataType::Float32 => match decimal_places { - ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { - let decimal_places: i32 = decimal_places.try_into().map_err(|e| { - exec_datafusion_err!( - "Invalid value for decimal places: {decimal_places}: {e}" - ) - })?; + ColumnarValue::Scalar(scalar) => match scalar.into_value() { + ScalarValue::Int64(Some(decimal_places)) => { + let decimal_places: i32 = decimal_places.try_into().map_err(|e| { + exec_datafusion_err!( + "Invalid value for decimal places: {decimal_places}: {e}" + ) + })?; - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float32Array, - { - |value: f32| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) + Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "value", + Float32Array, + { + |value: f32| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + } } - } - )) as ArrayRef) - } + )) as ArrayRef) + } + _ => { + exec_err!( + "round function requires a scalar or array for decimal_places" + ) + } + }, ColumnarValue::Array(_) => { let ColumnarValue::Array(decimal_places) = decimal_places.cast_to(&Int32, None).map_err(|e| { @@ -208,9 +219,6 @@ pub fn round(args: &[ArrayRef]) -> Result { } )) as ArrayRef) } - _ => { - exec_err!("round function requires a scalar or array for decimal_places") - } }, other => exec_err!("Unsupported data type {other:?} for function round"), diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs index 3344438454c4..e6020f18d168 
100644 --- a/datafusion/functions/src/math/trunc.rs +++ b/datafusion/functions/src/math/trunc.rs @@ -115,16 +115,22 @@ fn trunc(args: &[ArrayRef]) -> Result { //or then invoke the compute_truncate method to process precision let num = &args[0]; let precision = if args.len() == 1 { - ColumnarValue::Scalar(Int64(Some(0))) + ColumnarValue::from(Int64(Some(0))) } else { ColumnarValue::Array(Arc::clone(&args[1])) }; match args[0].data_type() { Float64 => match precision { - ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( - make_function_scalar_inputs!(num, "num", Float64Array, { f64::trunc }), - ) as ArrayRef), + ColumnarValue::Scalar(scalar) => match scalar.value() { + Int64(Some(0)) => Ok(Arc::new(make_function_scalar_inputs!( + num, + "num", + Float64Array, + { f64::trunc } + )) as ArrayRef), + _ => exec_err!("trunc function requires a scalar or array for precision"), + }, ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( num, precision, @@ -134,12 +140,17 @@ fn trunc(args: &[ArrayRef]) -> Result { Int64Array, { compute_truncate64 } )) as ArrayRef), - _ => exec_err!("trunc function requires a scalar or array for precision"), }, Float32 => match precision { - ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( - make_function_scalar_inputs!(num, "num", Float32Array, { f32::trunc }), - ) as ArrayRef), + ColumnarValue::Scalar(scalar) => match scalar.value() { + Int64(Some(0)) => Ok(Arc::new(make_function_scalar_inputs!( + num, + "num", + Float32Array, + { f32::trunc } + )) as ArrayRef), + _ => exec_err!("trunc function requires a scalar or array for precision"), + }, ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( num, precision, @@ -149,7 +160,6 @@ fn trunc(args: &[ArrayRef]) -> Result { Int64Array, { compute_truncate32 } )) as ArrayRef), - _ => exec_err!("trunc function requires a scalar or array for precision"), }, other => exec_err!("Unsupported data type {other:?} for function trunc"), } diff --git 
a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index 20029ba005c4..c67bc4157298 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -20,13 +20,12 @@ use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; use arrow::compute::kernels::regexp; use arrow::datatypes::DataType; use datafusion_common::exec_err; -use datafusion_common::ScalarValue; use datafusion_common::{arrow_datafusion_err, plan_err}; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; -use datafusion_expr::ColumnarValue; use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Scalar}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; @@ -99,7 +98,7 @@ impl ScalarUDFImpl for RegexpLikeFunc { let result = regexp_like_func(&args); if is_scalar { // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + let result = result.and_then(|arr| Scalar::try_from_array(&arr, 0)); result.map(ColumnarValue::Scalar) } else { result.map(ColumnarValue::Array) diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index bf40eff11d30..e7285bf3887e 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -21,13 +21,12 @@ use arrow::compute::kernels::regexp; use arrow::datatypes::DataType; use arrow::datatypes::Field; use datafusion_common::exec_err; -use datafusion_common::ScalarValue; use datafusion_common::{arrow_datafusion_err, plan_err}; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; -use datafusion_expr::ColumnarValue; use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Scalar}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use 
std::any::Any; use std::sync::Arc; @@ -101,7 +100,7 @@ impl ScalarUDFImpl for RegexpMatchFunc { let result = regexp_match_func(&args); if is_scalar { // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + let result = result.and_then(|arr| Scalar::try_from_array(&arr, 0)); result.map(ColumnarValue::Scalar) } else { result.map(ColumnarValue::Array) diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index 3eb72a1fb5f5..f25894d54fe4 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -27,12 +27,12 @@ use arrow::datatypes::DataType; use datafusion_common::cast::as_string_view_array; use datafusion_common::exec_err; use datafusion_common::plan_err; -use datafusion_common::ScalarValue; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; use datafusion_expr::function::Hint; use datafusion_expr::ColumnarValue; +use datafusion_expr::Scalar; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use regex::Regex; @@ -117,7 +117,7 @@ impl ScalarUDFImpl for RegexpReplaceFunc { let result = regexp_replace_func(args); if is_scalar { // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + let result = result.and_then(|arr| Scalar::try_from_array(&arr, 0)); result.map(ColumnarValue::Scalar) } else { result.map(ColumnarValue::Array) diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 68ba3f5ff15f..196a5a607639 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -122,25 +122,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], - 
$EXPECTED, - i32, - Int32, - Int32Array - ); - - test_function!( - AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], - $EXPECTED, - i32, - Int32, - Int32Array - ); - - test_function!( - AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + &[ColumnarValue::from(ScalarValue::Utf8($INPUT))], $EXPECTED, i32, Int32, diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index 65ec1a4a7734..0dd07fa87f69 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -77,13 +77,10 @@ impl ScalarUDFImpl for BitLengthFunc { match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), - ColumnarValue::Scalar(v) => match v { - ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( + ColumnarValue::Scalar(v) => match v.value() { + ScalarValue::Utf8(v) => Ok(ColumnarValue::from(ScalarValue::Int32( v.as_ref().map(|x| (x.len() * 8) as i32), ))), - ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( - ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)), - )), _ => unreachable!(), }, } diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 9365a6d83331..aab8df62cc9e 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -230,18 +230,10 @@ where } other => exec_err!("Unsupported data type {other:?} for function {name}"), }, - ColumnarValue::Scalar(scalar) => match scalar { + ColumnarValue::Scalar(scalar) => match scalar.value() { ScalarValue::Utf8(a) => { let result = a.as_ref().map(|x| op(x)); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) - } - ScalarValue::LargeUtf8(a) => { - let result = a.as_ref().map(|x| op(x)); - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) - } - ScalarValue::Utf8View(a) => { - let result = a.as_ref().map(|x| op(x)); - 
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + Ok(ColumnarValue::from(ScalarValue::Utf8(result))) } other => exec_err!("Unsupported data type {other:?} for function {name}"), }, diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 98f57efef90d..ce274f0ccbfc 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -24,7 +24,7 @@ use datafusion_common::cast::{as_string_array, as_string_view_array}; use datafusion_common::{internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{lit, ColumnarValue, Expr, Volatility}; +use datafusion_expr::{lit, ColumnarValue, Expr, Scalar, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use crate::string::common::*; @@ -86,13 +86,13 @@ impl ScalarUDFImpl for ConcatFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { let mut return_datatype = DataType::Utf8; args.iter().for_each(|col| { - if col.data_type() == DataType::Utf8View { - return_datatype = col.data_type(); + if col.data_type() == &DataType::Utf8View { + return_datatype = col.data_type().clone(); } - if col.data_type() == DataType::LargeUtf8 + if col.data_type() == &DataType::LargeUtf8 && return_datatype != DataType::Utf8View { - return_datatype = col.data_type(); + return_datatype = col.data_type().clone(); } }); @@ -108,20 +108,19 @@ impl ScalarUDFImpl for ConcatFunc { if array_len.is_none() { let mut result = String::new(); for arg in args { - if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = arg { - result.push_str(v); + if let ColumnarValue::Scalar(scalar) = arg { + if let ScalarValue::Utf8(Some(v)) = scalar.value() { + result.push_str(v); + } } } return match return_datatype { - DataType::Utf8View => { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result)))) - } - DataType::Utf8 => { - 
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))) - } - DataType::LargeUtf8 => { - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result)))) + DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => { + Ok(ColumnarValue::Scalar(Scalar::new( + ScalarValue::Utf8(Some(result)), + return_datatype, + ))) } other => { plan_err!("Concat function does not support datatype of {other}") @@ -136,14 +135,15 @@ impl ScalarUDFImpl for ConcatFunc { for arg in args { match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value)) - | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => { - if let Some(s) = maybe_value { - data_size += s.len() * len; - columns.push(ColumnarValueRef::Scalar(s.as_bytes())); + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(maybe_value) => { + if let Some(s) = maybe_value { + data_size += s.len() * len; + columns.push(ColumnarValueRef::Scalar(s.as_bytes())); + } } - } + _ => unreachable!(), + }, ColumnarValue::Array(array) => { match array.data_type() { DataType::Utf8 => { @@ -184,7 +184,6 @@ impl ScalarUDFImpl for ConcatFunc { } }; } - _ => unreachable!(), } } @@ -253,11 +252,11 @@ pub fn simplify_concat(args: Vec) -> Result { for arg in args.clone() { match arg { // filter out `null` args - Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {} + Expr::Literal(ScalarValue::Utf8(None)) => {} // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. // Concatenate it with the `contiguous_scalar`. 
Expr::Literal( - ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)), + ScalarValue::Utf8(Some(v)), ) => contiguous_scalar += &v, Expr::Literal(x) => { return internal_err!( @@ -297,7 +296,7 @@ pub fn simplify_concat(args: Vec) -> Result { mod tests { use super::*; use crate::utils::test::test_function; - use arrow::array::{Array, LargeStringArray, StringViewArray}; + use arrow::array::Array; use arrow::array::{ArrayRef, StringArray}; use DataType::*; @@ -306,9 +305,9 @@ mod tests { test_function!( ConcatFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("aa")), - ColumnarValue::Scalar(ScalarValue::from("bb")), - ColumnarValue::Scalar(ScalarValue::from("cc")), + ColumnarValue::from(ScalarValue::from("aa")), + ColumnarValue::from(ScalarValue::from("bb")), + ColumnarValue::from(ScalarValue::from("cc")), ], Ok(Some("aabbcc")), &str, @@ -318,9 +317,9 @@ mod tests { test_function!( ConcatFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("aa")), - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from("cc")), + ColumnarValue::from(ScalarValue::from("aa")), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("cc")), ], Ok(Some("aacc")), &str, @@ -329,37 +328,12 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + &[ColumnarValue::from(ScalarValue::Utf8(None))], Ok(Some("")), &str, Utf8, StringArray ); - test_function!( - ConcatFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("aa")), - ColumnarValue::Scalar(ScalarValue::Utf8View(None)), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)), - ColumnarValue::Scalar(ScalarValue::from("cc")), - ], - Ok(Some("aacc")), - &str, - Utf8View, - StringViewArray - ); - test_function!( - ConcatFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("aa")), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)), - 
ColumnarValue::Scalar(ScalarValue::from("cc")), - ], - Ok(Some("aacc")), - &str, - LargeUtf8, - LargeStringArray - ); Ok(()) } @@ -368,7 +342,7 @@ mod tests { fn concat() -> Result<()> { let c0 = ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); - let c1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))); + let c1 = ColumnarValue::from(ScalarValue::Utf8(Some(",".to_string()))); let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ Some("x"), None, diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 1134c525cfca..ed6dcc26f507 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -94,14 +94,13 @@ impl ScalarUDFImpl for ConcatWsFunc { // Scalar if array_len.is_none() { let sep = match &args[0] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) - | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => s, - ColumnarValue::Scalar(ScalarValue::Utf8(None)) - | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(s)) => s, + ScalarValue::Utf8(None) => { + return Ok(ColumnarValue::from(ScalarValue::Utf8(None))); + } + _ => unreachable!(), + }, _ => unreachable!(), }; @@ -110,35 +109,33 @@ impl ScalarUDFImpl for ConcatWsFunc { for arg in iter.by_ref() { match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) - | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => { - result.push_str(s); - break; - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) - | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} + 
ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(s)) => { + result.push_str(s); + break; + } + ScalarValue::Utf8(None) => {} + _ => unreachable!(), + }, _ => unreachable!(), } } for arg in iter.by_ref() { match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) - | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => { - result.push_str(sep); - result.push_str(s); - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) - | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(s)) => { + result.push_str(sep); + result.push_str(s); + } + ScalarValue::Utf8(None) => {} + _ => unreachable!(), + }, _ => unreachable!(), } } - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))); + return Ok(ColumnarValue::from(ScalarValue::Utf8(Some(result)))); } // Array @@ -147,13 +144,18 @@ impl ScalarUDFImpl for ConcatWsFunc { // parse sep let sep = match &args[0] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { - data_size += s.len() * len * (args.len() - 2); // estimate - ColumnarValueRef::Scalar(s.as_bytes()) - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null(len)))); - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(s)) => { + data_size += s.len() * len * (args.len() - 2); // estimate + ColumnarValueRef::Scalar(s.as_bytes()) + } + ScalarValue::Utf8(None) => { + return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null( + len, + )))); + } + _ => unreachable!(), + }, ColumnarValue::Array(array) => { let string_array = as_string_array(array)?; data_size += string_array.values().len() * (args.len() - 2); // estimate @@ -163,20 +165,20 @@ impl ScalarUDFImpl for ConcatWsFunc { 
ColumnarValueRef::NonNullableArray(string_array) } } - _ => unreachable!(), }; let mut columns = Vec::with_capacity(args.len() - 1); for arg in &args[1..] { match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value)) - | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => { - if let Some(s) = maybe_value { - data_size += s.len() * len; - columns.push(ColumnarValueRef::Scalar(s.as_bytes())); + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(maybe_value) => { + if let Some(s) = maybe_value { + data_size += s.len() * len; + columns.push(ColumnarValueRef::Scalar(s.as_bytes())); + } } - } + _ => unreachable!(), + }, ColumnarValue::Array(array) => { match array.data_type() { DataType::Utf8 => { @@ -217,7 +219,6 @@ impl ScalarUDFImpl for ConcatWsFunc { } }; } - _ => unreachable!(), } } @@ -268,11 +269,7 @@ impl ScalarUDFImpl for ConcatWsFunc { fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { match delimiter { - Expr::Literal( - ScalarValue::Utf8(delimiter) - | ScalarValue::LargeUtf8(delimiter) - | ScalarValue::Utf8View(delimiter), - ) => { + Expr::Literal(ScalarValue::Utf8(delimiter)) => { match delimiter { // when the delimiter is an empty string, // we can use `concat` to replace `concat_ws` @@ -284,8 +281,8 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result {} - Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v))) => { + Expr::Literal(ScalarValue::Utf8(None)) => {} + Expr::Literal(ScalarValue::Utf8(Some(v))) => { match contiguous_scalar { None => contiguous_scalar = Some(v.to_string()), Some(mut pre) => { @@ -366,10 +363,10 @@ mod tests { test_function!( ConcatWsFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("|")), - ColumnarValue::Scalar(ScalarValue::from("aa")), - ColumnarValue::Scalar(ScalarValue::from("bb")), - ColumnarValue::Scalar(ScalarValue::from("cc")), 
+ ColumnarValue::from(ScalarValue::from("|")), + ColumnarValue::from(ScalarValue::from("aa")), + ColumnarValue::from(ScalarValue::from("bb")), + ColumnarValue::from(ScalarValue::from("cc")), ], Ok(Some("aa|bb|cc")), &str, @@ -379,8 +376,8 @@ mod tests { test_function!( ConcatWsFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("|")), - ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("|")), + ColumnarValue::from(ScalarValue::Utf8(None)), ], Ok(Some("")), &str, @@ -390,10 +387,10 @@ mod tests { test_function!( ConcatWsFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from("aa")), - ColumnarValue::Scalar(ScalarValue::from("bb")), - ColumnarValue::Scalar(ScalarValue::from("cc")), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("aa")), + ColumnarValue::from(ScalarValue::from("bb")), + ColumnarValue::from(ScalarValue::from("cc")), ], Ok(None), &str, @@ -403,10 +400,10 @@ mod tests { test_function!( ConcatWsFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("|")), - ColumnarValue::Scalar(ScalarValue::from("aa")), - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from("cc")), + ColumnarValue::from(ScalarValue::from("|")), + ColumnarValue::from(ScalarValue::from("aa")), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("cc")), ], Ok(Some("aa|cc")), &str, @@ -420,7 +417,7 @@ mod tests { #[test] fn concat_ws() -> Result<()> { // sep is scalar - let c0 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))); + let c0 = ColumnarValue::from(ScalarValue::Utf8(Some(",".to_string()))); let c1 = ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 
c319f80661c3..52bc9ce58252 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -209,8 +209,8 @@ mod tests { test_function!( ContainsFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from("alph")), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from("alph")), ], Ok(Some(true)), bool, @@ -220,8 +220,8 @@ mod tests { test_function!( ContainsFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from("dddddd")), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from("dddddd")), ], Ok(Some(false)), bool, @@ -231,8 +231,8 @@ mod tests { test_function!( ContainsFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from("pha")), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from("pha")), ], Ok(Some(true)), bool, @@ -240,48 +240,6 @@ mod tests { BooleanArray ); - test_function!( - ContainsFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "Apache" - )))), - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("pac")))), - ], - Ok(Some(true)), - bool, - Boolean, - BooleanArray - ); - test_function!( - ContainsFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "Apache" - )))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ap")))), - ], - Ok(Some(false)), - bool, - Boolean, - BooleanArray - ); - test_function!( - ContainsFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "Apache" - )))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from( - "DataFusion" - )))), - ], - Ok(Some(false)), - bool, - Boolean, - BooleanArray - ); - Ok(()) } } diff --git a/datafusion/functions/src/string/ends_with.rs 
b/datafusion/functions/src/string/ends_with.rs index 03a1795954d0..82fde772c282 100644 --- a/datafusion/functions/src/string/ends_with.rs +++ b/datafusion/functions/src/string/ends_with.rs @@ -111,8 +111,8 @@ mod tests { test_function!( EndsWithFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from("alph")), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from("alph")), ], Ok(Some(false)), bool, @@ -122,8 +122,8 @@ mod tests { test_function!( EndsWithFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from("bet")), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from("bet")), ], Ok(Some(true)), bool, @@ -133,8 +133,8 @@ mod tests { test_function!( EndsWithFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from("alph")), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("alph")), ], Ok(None), bool, @@ -144,8 +144,8 @@ mod tests { test_function!( EndsWithFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::Utf8(None)), ], Ok(None), bool, diff --git a/datafusion/functions/src/string/initcap.rs b/datafusion/functions/src/string/initcap.rs index 4e1eb213ef57..66e095b84e0f 100644 --- a/datafusion/functions/src/string/initcap.rs +++ b/datafusion/functions/src/string/initcap.rs @@ -137,7 +137,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from("hi THOMAS"))], + &[ColumnarValue::from(ScalarValue::from("hi THOMAS"))], Ok(Some("Hi Thomas")), &str, Utf8, @@ -145,7 +145,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - 
&[ColumnarValue::Scalar(ScalarValue::from(""))], + &[ColumnarValue::from(ScalarValue::from(""))], Ok(Some("")), &str, Utf8, @@ -153,7 +153,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from(""))], + &[ColumnarValue::from(ScalarValue::from(""))], Ok(Some("")), &str, Utf8, @@ -161,45 +161,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - "hi THOMAS".to_string() - )))], - Ok(Some("Hi Thomas")), - &str, - Utf8, - StringArray - ); - test_function!( - InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - "hi THOMAS wIth M0re ThAN 12 ChaRs".to_string() - )))], - Ok(Some("Hi Thomas With M0re Than 12 Chars")), - &str, - Utf8, - StringArray - ); - test_function!( - InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - "".to_string() - )))], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - test_function!( - InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(None))], + &[ColumnarValue::from(ScalarValue::Utf8(None))], Ok(None), &str, Utf8, diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index f792914d862e..f228291a7925 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -77,16 +77,10 @@ impl ScalarUDFImpl for OctetLengthFunc { match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), - ColumnarValue::Scalar(v) => match v { - ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( + ColumnarValue::Scalar(v) => match v.value() { + ScalarValue::Utf8(v) => Ok(ColumnarValue::from(ScalarValue::Int32( v.as_ref().map(|x| x.len() as i32), ))), - ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( - 
ScalarValue::Int64(v.as_ref().map(|x| x.len() as i64)), - )), - ScalarValue::Utf8View(v) => Ok(ColumnarValue::Scalar( - ScalarValue::Int32(v.as_ref().map(|x| x.len() as i32)), - )), _ => unreachable!(), }, } @@ -111,7 +105,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Int32(Some(12)))], + &[ColumnarValue::from(ScalarValue::Int32(Some(12)))], exec_err!( "The OCTET_LENGTH function can only accept strings, but got Int32." ), @@ -133,8 +127,8 @@ mod tests { test_function!( OctetLengthFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("chars")))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("chars")))) + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("chars")))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("chars")))) ], exec_err!("octet_length function requires 1 argument, got 2"), i32, @@ -143,9 +137,9 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( - String::from("chars") - )))], + &[ColumnarValue::from(ScalarValue::Utf8(Some(String::from( + "chars" + ))))], Ok(Some(5)), i32, Int32, @@ -153,9 +147,9 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( - String::from("josé") - )))], + &[ColumnarValue::from(ScalarValue::Utf8(Some(String::from( + "josé" + ))))], Ok(Some(5)), i32, Int32, @@ -163,9 +157,9 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( - String::from("") - )))], + &[ColumnarValue::from(ScalarValue::Utf8(Some(String::from( + "" + ))))], Ok(Some(0)), i32, Int32, @@ -173,42 +167,12 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + &[ColumnarValue::from(ScalarValue::Utf8(None))], Ok(None), i32, Int32, Int32Array ); - test_function!( - OctetLengthFunc::new(), - 
&[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - String::from("joséjoséjoséjosé") - )))], - Ok(Some(20)), - i32, - Int32, - Int32Array - ); - test_function!( - OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - String::from("josé") - )))], - Ok(Some(5)), - i32, - Int32, - Int32Array - ); - test_function!( - OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - String::from("") - )))], - Ok(Some(0)), - i32, - Int32, - Int32Array - ); Ok(()) } diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 20e4462784b8..64ba7e5f65c8 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -147,8 +147,8 @@ mod tests { test_function!( RepeatFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("Pg")))), + ColumnarValue::from(ScalarValue::Int64(Some(4))), ], Ok(Some("PgPgPgPg")), &str, @@ -158,8 +158,8 @@ mod tests { test_function!( RepeatFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::Int64(Some(4))), ], Ok(None), &str, @@ -169,42 +169,8 @@ mod tests { test_function!( RepeatFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))), - ColumnarValue::Scalar(ScalarValue::Int64(None)), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - - test_function!( - RepeatFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), - ], - Ok(Some("PgPgPgPg")), - &str, - Utf8, - StringArray - ); - test_function!( - RepeatFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(None)), - 
ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - RepeatFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))), - ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("Pg")))), + ColumnarValue::from(ScalarValue::Int64(None)), ], Ok(None), &str, diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index 13fa3d55672d..0a58b9e55727 100644 --- a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -127,18 +127,17 @@ mod tests { use super::*; use crate::utils::test::test_function; use arrow::array::Array; - use arrow::array::LargeStringArray; use arrow::array::StringArray; - use arrow::datatypes::DataType::{LargeUtf8, Utf8}; + use arrow::datatypes::DataType::Utf8; use datafusion_common::ScalarValue; #[test] fn test_functions() -> Result<()> { test_function!( ReplaceFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("aabbdqcbb")))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("bb")))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ccc")))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("aabbdqcbb")))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("bb")))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("ccc")))), ], Ok(Some("aacccdqcccc")), &str, @@ -146,36 +145,6 @@ mod tests { StringArray ); - test_function!( - ReplaceFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from( - "aabbb" - )))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("bbb")))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("cc")))), - ], - Ok(Some("aacc")), - &str, - LargeUtf8, - LargeStringArray - ); - - test_function!( - ReplaceFunc::new(), - &[ - 
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "aabbbcw" - )))), - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("bb")))), - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("cc")))), - ], - Ok(Some("aaccbcw")), - &str, - Utf8, - StringArray - ); - Ok(()) } } diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index 8d292315a35a..438e2e611359 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -173,7 +173,7 @@ impl ScalarUDFImpl for SplitPartFunc { if is_scalar { // If all inputs are scalar, keep the output as scalar let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); - result.map(ColumnarValue::Scalar) + result.map(ColumnarValue::from) } else { result.map(ColumnarValue::Array) } @@ -242,11 +242,11 @@ mod tests { test_function!( SplitPartFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::from(ScalarValue::Int64(Some(2))), ], Ok(Some("def")), &str, @@ -256,11 +256,11 @@ mod tests { test_function!( SplitPartFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(20))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::from(ScalarValue::Int64(Some(20))), ], Ok(Some("")), &str, @@ -270,11 +270,11 @@ mod tests { test_function!( SplitPartFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + 
ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::from(ScalarValue::Int64(Some(-1))), ], Ok(Some("ghi")), &str, @@ -284,11 +284,11 @@ mod tests { test_function!( SplitPartFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::from(ScalarValue::Int64(Some(0))), ], exec_err!("field position must not be zero"), &str, diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 8450697cbf30..07302aa06465 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -117,21 +117,11 @@ mod tests { .into_iter() .flat_map(|(a, b, c)| { let utf_8_args = vec![ - ColumnarValue::Scalar(ScalarValue::Utf8(a.map(|s| s.to_string()))), - ColumnarValue::Scalar(ScalarValue::Utf8(b.map(|s| s.to_string()))), + ColumnarValue::from(ScalarValue::Utf8(a.map(|s| s.to_string()))), + ColumnarValue::from(ScalarValue::Utf8(b.map(|s| s.to_string()))), ]; - let large_utf_8_args = vec![ - ColumnarValue::Scalar(ScalarValue::LargeUtf8(a.map(|s| s.to_string()))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(b.map(|s| s.to_string()))), - ]; - - let utf_8_view_args = vec![ - ColumnarValue::Scalar(ScalarValue::Utf8View(a.map(|s| s.to_string()))), - ColumnarValue::Scalar(ScalarValue::Utf8View(b.map(|s| s.to_string()))), - ]; - - vec![(utf_8_args, c), (large_utf_8_args, c), (utf_8_view_args, c)] + vec![(utf_8_args, c)] }); for (args, expected) in 
test_cases { diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index c9dc96b2a935..43bd9eceb0c6 100644 --- a/datafusion/functions/src/unicode/character_length.rs +++ b/datafusion/functions/src/unicode/character_length.rs @@ -130,8 +130,8 @@ where mod tests { use crate::unicode::character_length::CharacterLengthFunc; use crate::utils::test::test_function; - use arrow::array::{Array, Int32Array, Int64Array}; - use arrow::datatypes::DataType::{Int32, Int64}; + use arrow::array::{Array, Int32Array}; + use arrow::datatypes::DataType::Int32; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -139,25 +139,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], - $EXPECTED, - i32, - Int32, - Int32Array - ); - - test_function!( - CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], - $EXPECTED, - i64, - Int64, - Int64Array - ); - - test_function!( - CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + &[ColumnarValue::from(ScalarValue::Utf8($INPUT))], $EXPECTED, i32, Int32, diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs index c49784948dd0..f1f84c98ef5e 100644 --- a/datafusion/functions/src/unicode/left.rs +++ b/datafusion/functions/src/unicode/left.rs @@ -153,8 +153,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(2i64)), ], Ok(Some("ab")), &str, @@ -164,8 +164,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(200i64)), + 
ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(200i64)), ], Ok(Some("abcde")), &str, @@ -175,8 +175,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(-2i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(-2i64)), ], Ok(Some("abc")), &str, @@ -186,8 +186,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(-200i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(-200i64)), ], Ok(Some("")), &str, @@ -197,8 +197,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(0i64)), ], Ok(Some("")), &str, @@ -208,8 +208,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from(2i64)), ], Ok(None), &str, @@ -219,8 +219,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::Int64(None)), ], Ok(None), &str, @@ -230,8 +230,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("joséésoj")), + ColumnarValue::from(ScalarValue::from(5i64)), ], Ok(Some("joséé")), &str, @@ -241,8 +241,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - 
ColumnarValue::Scalar(ScalarValue::from("joséésoj")), - ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ColumnarValue::from(ScalarValue::from("joséésoj")), + ColumnarValue::from(ScalarValue::from(-3i64)), ], Ok(Some("joséé")), &str, @@ -253,8 +253,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(2i64)), ], internal_err!( "function left requires compilation with feature flag: unicode_expressions." diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index e102673c4253..d45b2639cb6c 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -254,8 +254,8 @@ mod tests { use crate::unicode::lpad::LPadFunc; use crate::utils::test::test_function; - use arrow::array::{Array, LargeStringArray, StringArray}; - use arrow::datatypes::DataType::{LargeUtf8, Utf8}; + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -265,32 +265,8 @@ mod tests { test_function!( LPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), - ColumnarValue::Scalar($LENGTH) - ], - $EXPECTED, - &str, - Utf8, - StringArray - ); - - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), - ColumnarValue::Scalar($LENGTH) - ], - $EXPECTED, - &str, - LargeUtf8, - LargeStringArray - ); - - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), - ColumnarValue::Scalar($LENGTH) + ColumnarValue::from(ScalarValue::Utf8($INPUT)), + ColumnarValue::from($LENGTH) ], $EXPECTED, &str, @@ -304,115 +280,9 @@ mod tests { test_function!( LPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), - 
ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) - ], - $EXPECTED, - &str, - Utf8, - StringArray - ); - // utf8, largeutf8 - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) - ], - $EXPECTED, - &str, - Utf8, - StringArray - ); - // utf8, utf8view - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) - ], - $EXPECTED, - &str, - Utf8, - StringArray - ); - - // largeutf8, utf8 - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) - ], - $EXPECTED, - &str, - LargeUtf8, - LargeStringArray - ); - // largeutf8, largeutf8 - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) - ], - $EXPECTED, - &str, - LargeUtf8, - LargeStringArray - ); - // largeutf8, utf8view - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) - ], - $EXPECTED, - &str, - LargeUtf8, - LargeStringArray - ); - - // utf8view, utf8 - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) - ], - $EXPECTED, - &str, - Utf8, - StringArray - ); - // utf8view, largeutf8 - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) - ], - $EXPECTED, - &str, - Utf8, - 
StringArray - ); - // utf8view, utf8view - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) + ColumnarValue::from(ScalarValue::Utf8($INPUT)), + ColumnarValue::from($LENGTH), + ColumnarValue::from(ScalarValue::Utf8($REPLACE)) ], $EXPECTED, &str, diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs index da16d3ee3752..7b72570a83e4 100644 --- a/datafusion/functions/src/unicode/reverse.rs +++ b/datafusion/functions/src/unicode/reverse.rs @@ -104,8 +104,8 @@ fn reverse_impl<'a, T: OffsetSizeTrait, V: ArrayAccessor>( #[cfg(test)] mod tests { - use arrow::array::{Array, LargeStringArray, StringArray}; - use arrow::datatypes::DataType::{LargeUtf8, Utf8}; + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -117,25 +117,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], - $EXPECTED, - &str, - Utf8, - StringArray - ); - - test_function!( - ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], - $EXPECTED, - &str, - LargeUtf8, - LargeStringArray - ); - - test_function!( - ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + &[ColumnarValue::from(ScalarValue::Utf8($INPUT))], $EXPECTED, &str, Utf8, diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs index 9d542bb2c006..7fadb058c19b 100644 --- a/datafusion/functions/src/unicode/right.rs +++ b/datafusion/functions/src/unicode/right.rs @@ -156,8 +156,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + 
ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(2i64)), ], Ok(Some("de")), &str, @@ -167,8 +167,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(200i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(200i64)), ], Ok(Some("abcde")), &str, @@ -178,8 +178,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(-2i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(-2i64)), ], Ok(Some("cde")), &str, @@ -189,8 +189,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(-200i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(-200i64)), ], Ok(Some("")), &str, @@ -200,8 +200,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(0i64)), ], Ok(Some("")), &str, @@ -211,8 +211,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from(2i64)), ], Ok(None), &str, @@ -222,8 +222,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::Int64(None)), ], Ok(None), &str, @@ -233,8 +233,8 @@ mod tests { test_function!( RightFunc::new(), &[ - 
ColumnarValue::Scalar(ScalarValue::from("joséésoj")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("joséésoj")), + ColumnarValue::from(ScalarValue::from(5i64)), ], Ok(Some("éésoj")), &str, @@ -244,8 +244,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), - ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ColumnarValue::from(ScalarValue::from("joséésoj")), + ColumnarValue::from(ScalarValue::from(-3i64)), ], Ok(Some("éésoj")), &str, @@ -256,8 +256,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(2i64)), ], internal_err!( "function right requires compilation with feature flag: unicode_expressions." diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index c1d6f327928f..4c8c2c2ca5f5 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -281,8 +281,8 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("josé")), + ColumnarValue::from(ScalarValue::from(5i64)), ], Ok(Some("josé ")), &str, @@ -292,8 +292,8 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::from(5i64)), ], Ok(Some("hi ")), &str, @@ -303,8 +303,8 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), + ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::from(0i64)), ], 
Ok(Some("")), &str, @@ -314,8 +314,8 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::Int64(None)), ], Ok(None), &str, @@ -325,8 +325,8 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from(5i64)), ], Ok(None), &str, @@ -336,9 +336,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from("xy")), + ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("xy")), ], Ok(Some("hixyx")), &str, @@ -348,9 +348,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(21i64)), - ColumnarValue::Scalar(ScalarValue::from("abcdef")), + ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::from(21i64)), + ColumnarValue::from(ScalarValue::from("abcdef")), ], Ok(Some("hiabcdefabcdefabcdefa")), &str, @@ -360,9 +360,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from(" ")), + ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from(" ")), ], Ok(Some("hi ")), &str, @@ -372,9 +372,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from("")), + 
ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("")), ], Ok(Some("hi")), &str, @@ -384,9 +384,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from("xy")), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("xy")), ], Ok(None), &str, @@ -396,9 +396,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), - ColumnarValue::Scalar(ScalarValue::from("xy")), + ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::Int64(None)), + ColumnarValue::from(ScalarValue::from("xy")), ], Ok(None), &str, @@ -408,9 +408,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::Utf8(None)), ], Ok(None), &str, @@ -420,9 +420,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(10i64)), - ColumnarValue::Scalar(ScalarValue::from("xy")), + ColumnarValue::from(ScalarValue::from("josé")), + ColumnarValue::from(ScalarValue::from(10i64)), + ColumnarValue::from(ScalarValue::from("xy")), ], Ok(Some("joséxyxyxy")), &str, @@ -432,9 +432,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(10i64)), - ColumnarValue::Scalar(ScalarValue::from("éñ")), + ColumnarValue::from(ScalarValue::from("josé")), + 
ColumnarValue::from(ScalarValue::from(10i64)), + ColumnarValue::from(ScalarValue::from("éñ")), ], Ok(Some("josééñéñéñ")), &str, @@ -445,8 +445,8 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("josé")), + ColumnarValue::from(ScalarValue::from(5i64)), ], internal_err!( "function rpad requires compilation with feature flag: unicode_expressions." diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index cf10b18ae338..a98d13e372c6 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -168,8 +168,8 @@ where #[cfg(test)] mod tests { - use arrow::array::{Array, Int32Array, Int64Array}; - use arrow::datatypes::DataType::{Int32, Int64}; + use arrow::array::{Array, Int32Array}; + use arrow::datatypes::DataType::Int32; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -182,8 +182,8 @@ mod tests { test_function!( StrposFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::$t1(Some($lhs.to_owned()))), - ColumnarValue::Scalar(ScalarValue::$t2(Some($rhs.to_owned()))), + ColumnarValue::from(ScalarValue::$t1(Some($lhs.to_owned()))), + ColumnarValue::from(ScalarValue::$t2(Some($rhs.to_owned()))), ], Ok(Some($result)), $t3, @@ -201,47 +201,5 @@ mod tests { test_strpos!("alphabet", "z" -> 0; Utf8 Utf8 i32 Int32 Int32Array); test_strpos!("alphabet", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array); test_strpos!("", "a" -> 0; Utf8 Utf8 i32 Int32 Int32Array); - - // LargeUtf8 and LargeUtf8 combinations - test_strpos!("alphabet", "ph" -> 3; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); - test_strpos!("alphabet", "a" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); - test_strpos!("alphabet", "z" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); - test_strpos!("alphabet", "" -> 1; LargeUtf8 LargeUtf8 i64 
Int64 Int64Array); - test_strpos!("", "a" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); - - // Utf8 and LargeUtf8 combinations - test_strpos!("alphabet", "ph" -> 3; Utf8 LargeUtf8 i32 Int32 Int32Array); - test_strpos!("alphabet", "a" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array); - test_strpos!("alphabet", "z" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array); - test_strpos!("alphabet", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array); - test_strpos!("", "a" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array); - - // LargeUtf8 and Utf8 combinations - test_strpos!("alphabet", "ph" -> 3; LargeUtf8 Utf8 i64 Int64 Int64Array); - test_strpos!("alphabet", "a" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array); - test_strpos!("alphabet", "z" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array); - test_strpos!("alphabet", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array); - test_strpos!("", "a" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array); - - // Utf8View and Utf8View combinations - test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8View i32 Int32 Int32Array); - test_strpos!("alphabet", "a" -> 1; Utf8View Utf8View i32 Int32 Int32Array); - test_strpos!("alphabet", "z" -> 0; Utf8View Utf8View i32 Int32 Int32Array); - test_strpos!("alphabet", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array); - test_strpos!("", "a" -> 0; Utf8View Utf8View i32 Int32 Int32Array); - - // Utf8View and Utf8 combinations - test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8 i32 Int32 Int32Array); - test_strpos!("alphabet", "a" -> 1; Utf8View Utf8 i32 Int32 Int32Array); - test_strpos!("alphabet", "z" -> 0; Utf8View Utf8 i32 Int32 Int32Array); - test_strpos!("alphabet", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array); - test_strpos!("", "a" -> 0; Utf8View Utf8 i32 Int32 Int32Array); - - // Utf8View and LargeUtf8 combinations - test_strpos!("alphabet", "ph" -> 3; Utf8View LargeUtf8 i32 Int32 Int32Array); - test_strpos!("alphabet", "a" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array); - test_strpos!("alphabet", "z" -> 0; Utf8View LargeUtf8 i32 Int32 
Int32Array); - test_strpos!("alphabet", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array); - test_strpos!("", "a" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array); } } diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 40d3a4d13e97..a565c50680ca 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -344,8 +344,8 @@ where #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray, StringViewArray}; - use arrow::datatypes::DataType::{Utf8, Utf8View}; + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -358,100 +358,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(None)), - ColumnarValue::Scalar(ScalarValue::from(1i64)), - ], - Ok(None), - &str, - Utf8View, - StringViewArray - ); - test_function!( - SubstrFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "alphabet" - )))), - ColumnarValue::Scalar(ScalarValue::from(0i64)), - ], - Ok(Some("alphabet")), - &str, - Utf8View, - StringViewArray - ); - test_function!( - SubstrFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "this és longer than 12B" - )))), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from(2i64)), - ], - Ok(Some(" é")), - &str, - Utf8View, - StringViewArray - ); - test_function!( - SubstrFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "this is longer than 12B" - )))), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ], - Ok(Some(" is longer than 12B")), - &str, - Utf8View, - StringViewArray - ); - test_function!( - SubstrFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "joséésoj" - )))), - 
ColumnarValue::Scalar(ScalarValue::from(5i64)), - ], - Ok(Some("ésoj")), - &str, - Utf8View, - StringViewArray - ); - test_function!( - SubstrFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "alphabet" - )))), - ColumnarValue::Scalar(ScalarValue::from(3i64)), - ColumnarValue::Scalar(ScalarValue::from(2i64)), - ], - Ok(Some("ph")), - &str, - Utf8View, - StringViewArray - ); - test_function!( - SubstrFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "alphabet" - )))), - ColumnarValue::Scalar(ScalarValue::from(3i64)), - ColumnarValue::Scalar(ScalarValue::from(20i64)), - ], - Ok(Some("phabet")), - &str, - Utf8View, - StringViewArray - ); - test_function!( - SubstrFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(0i64)), ], Ok(Some("alphabet")), &str, @@ -461,8 +369,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("joséésoj")), + ColumnarValue::from(ScalarValue::from(5i64)), ], Ok(Some("ésoj")), &str, @@ -472,8 +380,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), - ColumnarValue::Scalar(ScalarValue::from(-5i64)), + ColumnarValue::from(ScalarValue::from("joséésoj")), + ColumnarValue::from(ScalarValue::from(-5i64)), ], Ok(Some("joséésoj")), &str, @@ -483,8 +391,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(1i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(1i64)), ], Ok(Some("alphabet")), &str, @@ -494,8 +402,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - 
ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(2i64)), ], Ok(Some("lphabet")), &str, @@ -505,8 +413,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(3i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(3i64)), ], Ok(Some("phabet")), &str, @@ -516,8 +424,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(-3i64)), ], Ok(Some("alphabet")), &str, @@ -527,8 +435,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(30i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(30i64)), ], Ok(Some("")), &str, @@ -538,8 +446,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::Int64(None)), ], Ok(None), &str, @@ -549,9 +457,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(3i64)), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(3i64)), + ColumnarValue::from(ScalarValue::from(2i64)), ], Ok(Some("ph")), &str, @@ -561,9 +469,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - 
ColumnarValue::Scalar(ScalarValue::from(3i64)), - ColumnarValue::Scalar(ScalarValue::from(20i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(3i64)), + ColumnarValue::from(ScalarValue::from(20i64)), ], Ok(Some("phabet")), &str, @@ -573,9 +481,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(0i64)), + ColumnarValue::from(ScalarValue::from(5i64)), ], Ok(Some("alph")), &str, @@ -586,9 +494,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(-5i64)), - ColumnarValue::Scalar(ScalarValue::from(10i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(-5i64)), + ColumnarValue::from(ScalarValue::from(10i64)), ], Ok(Some("alph")), &str, @@ -599,9 +507,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(-5i64)), - ColumnarValue::Scalar(ScalarValue::from(4i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(-5i64)), + ColumnarValue::from(ScalarValue::from(4i64)), ], Ok(Some("")), &str, @@ -612,9 +520,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(-5i64)), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(-5i64)), + ColumnarValue::from(ScalarValue::from(5i64)), ], Ok(Some("")), &str, @@ -624,9 +532,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - 
ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), - ColumnarValue::Scalar(ScalarValue::from(20i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::Int64(None)), + ColumnarValue::from(ScalarValue::from(20i64)), ], Ok(None), &str, @@ -636,9 +544,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(3i64)), - ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(3i64)), + ColumnarValue::from(ScalarValue::Int64(None)), ], Ok(None), &str, @@ -648,9 +556,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(1i64)), - ColumnarValue::Scalar(ScalarValue::from(-1i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(1i64)), + ColumnarValue::from(ScalarValue::from(-1i64)), ], exec_err!("negative substring length not allowed: substr(, 1, -1)"), &str, @@ -660,9 +568,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::from("joséésoj")), + ColumnarValue::from(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from(2i64)), ], Ok(Some("és")), &str, @@ -673,8 +581,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(0i64)), ], internal_err!( "function substr requires compilation with feature flag: unicode_expressions." 
@@ -683,29 +591,6 @@ mod tests { Utf8, StringArray ); - test_function!( - SubstrFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("abc")), - ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), - ], - Ok(Some("abc")), - &str, - Utf8, - StringArray - ); - test_function!( - SubstrFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("overflow")), - ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), - ColumnarValue::Scalar(ScalarValue::from(1i64)), - ], - exec_err!("negative overflow when calculating skip value"), - &str, - Utf8, - StringArray - ); Ok(()) } diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index 6591ee26403a..9ca3d018d884 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -213,9 +213,9 @@ mod tests { test_function!( SubstrIndexFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), - ColumnarValue::Scalar(ScalarValue::from(".")), - ColumnarValue::Scalar(ScalarValue::from(1i64)), + ColumnarValue::from(ScalarValue::from("www.apache.org")), + ColumnarValue::from(ScalarValue::from(".")), + ColumnarValue::from(ScalarValue::from(1i64)), ], Ok(Some("www")), &str, @@ -225,9 +225,9 @@ mod tests { test_function!( SubstrIndexFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), - ColumnarValue::Scalar(ScalarValue::from(".")), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::from("www.apache.org")), + ColumnarValue::from(ScalarValue::from(".")), + ColumnarValue::from(ScalarValue::from(2i64)), ], Ok(Some("www.apache")), &str, @@ -237,9 +237,9 @@ mod tests { test_function!( SubstrIndexFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), - ColumnarValue::Scalar(ScalarValue::from(".")), - ColumnarValue::Scalar(ScalarValue::from(-2i64)), + 
ColumnarValue::from(ScalarValue::from("www.apache.org")), + ColumnarValue::from(ScalarValue::from(".")), + ColumnarValue::from(ScalarValue::from(-2i64)), ], Ok(Some("apache.org")), &str, @@ -249,9 +249,9 @@ mod tests { test_function!( SubstrIndexFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), - ColumnarValue::Scalar(ScalarValue::from(".")), - ColumnarValue::Scalar(ScalarValue::from(-1i64)), + ColumnarValue::from(ScalarValue::from("www.apache.org")), + ColumnarValue::from(ScalarValue::from(".")), + ColumnarValue::from(ScalarValue::from(-1i64)), ], Ok(Some("org")), &str, @@ -261,9 +261,9 @@ mod tests { test_function!( SubstrIndexFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), - ColumnarValue::Scalar(ScalarValue::from(".")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), + ColumnarValue::from(ScalarValue::from("www.apache.org")), + ColumnarValue::from(ScalarValue::from(".")), + ColumnarValue::from(ScalarValue::from(0i64)), ], Ok(Some("")), &str, @@ -273,9 +273,9 @@ mod tests { test_function!( SubstrIndexFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("")), - ColumnarValue::Scalar(ScalarValue::from(".")), - ColumnarValue::Scalar(ScalarValue::from(1i64)), + ColumnarValue::from(ScalarValue::from("")), + ColumnarValue::from(ScalarValue::from(".")), + ColumnarValue::from(ScalarValue::from(1i64)), ], Ok(Some("")), &str, @@ -285,9 +285,9 @@ mod tests { test_function!( SubstrIndexFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), - ColumnarValue::Scalar(ScalarValue::from("")), - ColumnarValue::Scalar(ScalarValue::from(1i64)), + ColumnarValue::from(ScalarValue::from("www.apache.org")), + ColumnarValue::from(ScalarValue::from("")), + ColumnarValue::from(ScalarValue::from(1i64)), ], Ok(Some("")), &str, diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index a42b9c6cb857..d49559d452c8 100644 --- 
a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -171,9 +171,9 @@ mod tests { test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("12345")), - ColumnarValue::Scalar(ScalarValue::from("143")), - ColumnarValue::Scalar(ScalarValue::from("ax")) + ColumnarValue::from(ScalarValue::from("12345")), + ColumnarValue::from(ScalarValue::from("143")), + ColumnarValue::from(ScalarValue::from("ax")) ], Ok(Some("a2x5")), &str, @@ -183,9 +183,9 @@ mod tests { test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from("143")), - ColumnarValue::Scalar(ScalarValue::from("ax")) + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("143")), + ColumnarValue::from(ScalarValue::from("ax")) ], Ok(None), &str, @@ -195,9 +195,9 @@ mod tests { test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("12345")), - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from("ax")) + ColumnarValue::from(ScalarValue::from("12345")), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("ax")) ], Ok(None), &str, @@ -207,9 +207,9 @@ mod tests { test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("12345")), - ColumnarValue::Scalar(ScalarValue::from("143")), - ColumnarValue::Scalar(ScalarValue::Utf8(None)) + ColumnarValue::from(ScalarValue::from("12345")), + ColumnarValue::from(ScalarValue::from("143")), + ColumnarValue::from(ScalarValue::Utf8(None)) ], Ok(None), &str, @@ -219,9 +219,9 @@ mod tests { test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("é2íñ5")), - ColumnarValue::Scalar(ScalarValue::from("éñí")), - ColumnarValue::Scalar(ScalarValue::from("óü")), + ColumnarValue::from(ScalarValue::from("é2íñ5")), + ColumnarValue::from(ScalarValue::from("éñí")), + 
ColumnarValue::from(ScalarValue::from("óü")), ], Ok(Some("ó2ü5")), &str, @@ -232,9 +232,9 @@ mod tests { test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("12345")), - ColumnarValue::Scalar(ScalarValue::from("143")), - ColumnarValue::Scalar(ScalarValue::from("ax")), + ColumnarValue::from(ScalarValue::from("12345")), + ColumnarValue::from(ScalarValue::from("143")), + ColumnarValue::from(ScalarValue::from("ax")), ], internal_err!( "function translate requires compilation with feature flag: unicode_expressions." diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index d36c5473ba01..163e1141b7fd 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -17,12 +17,13 @@ use std::sync::Arc; +use arrow::array::Array; use arrow::array::ArrayRef; use arrow::datatypes::DataType; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::Result; use datafusion_expr::function::Hint; -use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +use datafusion_expr::{ColumnarValue, Scalar, ScalarFunctionImplementation}; /// Creates a function to identify the optimal return type of a string function given /// the type of its first argument. 
@@ -114,7 +115,7 @@ where let result = (inner)(&args); if is_scalar { // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + let result = result.and_then(|arr| Scalar::try_from_array(&arr, 0)); result.map(ColumnarValue::Scalar) } else { result.map(ColumnarValue::Array) @@ -135,7 +136,7 @@ pub mod test { let expected: Result> = $EXPECTED; let func = $FUNC; - let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); + let type_array = $ARGS.iter().map(|arg| arg.data_type().clone()).collect::>(); let return_type = func.return_type(&type_array); match expected { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 7a8746572cfd..b20a303eb5c0 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -1238,7 +1238,7 @@ mod test { } fn invoke(&self, _args: &[ColumnarValue]) -> Result { - Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) + Ok(ColumnarValue::from(ScalarValue::from("a"))) } } diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 0623be504b9b..9e78bc98c384 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -1675,6 +1675,21 @@ mod tests { assert_optimized_plan_equal(projection, expected) } + #[test] + fn cast_literal() -> Result<()> { + let projection = LogicalPlanBuilder::empty(false) + .project(vec![Expr::Cast(Cast::new( + Box::new(lit("hello")), + DataType::LargeUtf8, + ))])? 
+ .build()?; + + let expected = "Projection: CAST(Utf8(\"hello\") AS LargeUtf8)\ + \n EmptyRelation"; + + assert_optimized_plan_equal(projection, expected) + } + #[test] fn table_scan_projected_schema() -> Result<()> { let table_scan = test_table_scan()?; diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 6f0a64b85cb6..b1a19fbf1086 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -3047,7 +3047,7 @@ Projection: a, b } fn invoke(&self, _args: &[ColumnarValue]) -> Result { - Ok(ColumnarValue::Scalar(ScalarValue::from(1))) + Ok(ColumnarValue::from(ScalarValue::from(1))) } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index fc3921d29615..333ba55d05c2 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -27,16 +27,19 @@ use arrow::{ record_batch::RecordBatch, }; +// use datafusion_common::logical::eq::LogicallyEq; +use datafusion_common::logical::equality::LogicallyEq; +use datafusion_common::{cast::as_large_list_array, exec_datafusion_err}; use datafusion_common::{ - cast::{as_large_list_array, as_list_array}, + cast::as_list_array, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{InList, InSubquery, WindowFunction}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ - and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, - WindowFunctionDefinition, + and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Like, Operator, + Volatility, WindowFunctionDefinition, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use 
datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; @@ -626,15 +629,22 @@ impl<'a> ConstEvaluator<'a> { return ConstSimplifyResult::NotSimplified(s); } + let expected_type = match expr.get_type(&self.input_schema) { + Ok(t) => t, + Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), + }; + let phys_expr = match create_physical_expr(&expr, &self.input_schema, self.execution_props) { Ok(e) => e, Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), }; + let col_val = match phys_expr.evaluate(&self.input_batch) { Ok(v) => v, Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), }; + match col_val { ColumnarValue::Array(a) => { if a.len() != 1 { @@ -669,8 +679,22 @@ impl<'a> ConstEvaluator<'a> { } } ColumnarValue::Scalar(s) => { + // TODO(@notfilippo): a fix for the select_arrow_cast error + let actual_type = s.value().data_type(); + if expected_type.logically_eq(&actual_type) + && expected_type.ne(&actual_type) + { + return ConstSimplifyResult::SimplifyRuntimeError( + exec_datafusion_err!( + "Skipping, actual_type {} is logically equal to expected_type {} but not strictly equal", + actual_type, expected_type + ), + expr, + ); + } + // TODO: support the optimization for `Map` type after support impl hash for it - if matches!(&s, ScalarValue::Map(_)) { + if matches!(s.value(), ScalarValue::Map(_)) { ConstSimplifyResult::SimplifyRuntimeError( DataFusionError::NotImplemented( "Const evaluate for Map type is still not supported" @@ -679,7 +703,7 @@ impl<'a> ConstEvaluator<'a> { expr, ) } else { - ConstSimplifyResult::Simplified(s) + ConstSimplifyResult::Simplified(s.into_value()) } } } diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index afcbe528083b..2a4dd3e49a3e 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ 
b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -406,7 +406,6 @@ mod tests { ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(None), ScalarValue::from("abc"), - ScalarValue::LargeUtf8(Some("def".to_string())), ScalarValue::Date32(Some(18628)), ScalarValue::Date32(None), ScalarValue::Decimal128(Some(1000), 19, 2), diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 2118ae96a198..4da1d41249c6 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -327,7 +327,6 @@ fn try_cast_literal_to_type( } try_cast_numeric_literal(lit_value, target_type) .or_else(|| try_cast_string_literal(lit_value, target_type)) - .or_else(|| try_cast_dictionary(lit_value, target_type)) } /// Convert a numeric value from one numeric data type to another @@ -476,46 +475,16 @@ fn try_cast_string_literal( target_type: &DataType, ) -> Option { let string_value = match lit_value { - ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) | ScalarValue::Utf8View(s) => { - s.clone() - } + ScalarValue::Utf8(s) => s.clone(), _ => return None, }; let scalar_value = match target_type { DataType::Utf8 => ScalarValue::Utf8(string_value), - DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value), - DataType::Utf8View => ScalarValue::Utf8View(string_value), _ => return None, }; Some(scalar_value) } -/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary -fn try_cast_dictionary( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - let lit_value_type = lit_value.data_type(); - let result_scalar = match (lit_value, target_type) { - // Unwrap dictionary when inner type matches target type - (ScalarValue::Dictionary(_, inner_value), _) - if inner_value.data_type() == *target_type => - { - (**inner_value).clone() - } - // Wrap type when target type is dictionary - (_, DataType::Dictionary(index_type, 
inner_type)) - if **inner_type == lit_value_type => - { - ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone())) - } - _ => { - return None; - } - }; - Some(result_scalar) -} - /// Cast a timestamp value from one unit to another fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option { let value = value as i64; @@ -606,45 +575,6 @@ mod tests { assert_eq!(optimize_test(expr_input, &schema), expected); } - #[test] - fn test_unwrap_cast_comparison_string() { - let schema = expr_test_schema(); - let dict = ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::from("value")), - ); - - // cast(str1 as Dictionary) = arrow_cast('value', 'Dictionary') => str1 = Utf8('value1') - let expr_input = cast(col("str1"), dict.data_type()).eq(lit(dict.clone())); - let expected = col("str1").eq(lit("value")); - assert_eq!(optimize_test(expr_input, &schema), expected); - - // cast(tag as Utf8) = Utf8('value') => tag = arrow_cast('value', 'Dictionary') - let expr_input = cast(col("tag"), DataType::Utf8).eq(lit("value")); - let expected = col("tag").eq(lit(dict.clone())); - assert_eq!(optimize_test(expr_input, &schema), expected); - - // Verify reversed argument order - // arrow_cast('value', 'Dictionary') = cast(str1 as Dictionary) => Utf8('value1') = str1 - let expr_input = lit(dict.clone()).eq(cast(col("str1"), dict.data_type())); - let expected = lit("value").eq(col("str1")); - assert_eq!(optimize_test(expr_input, &schema), expected); - } - - #[test] - fn test_unwrap_cast_comparison_large_string() { - let schema = expr_test_schema(); - // cast(largestr as Dictionary) = arrow_cast('value', 'Dictionary') => str1 = LargeUtf8('value1') - let dict = ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::LargeUtf8(Some("value".to_owned()))), - ); - let expr_input = cast(col("largestr"), dict.data_type()).eq(lit(dict)); - let expected = - 
col("largestr").eq(lit(ScalarValue::LargeUtf8(Some("value".to_owned())))); - assert_eq!(optimize_test(expr_input, &schema), expected); - } - #[test] fn test_not_unwrap_cast_with_decimal_comparison() { let schema = expr_test_schema(); @@ -925,7 +855,6 @@ mod tests { ScalarValue::Decimal128(None, 3, 0), ScalarValue::Decimal128(None, 8, 2), ScalarValue::Utf8(None), - ScalarValue::LargeUtf8(None), ]; for s1 in &scalars { @@ -1378,45 +1307,4 @@ mod tests { .unwrap(); assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None)); } - - #[test] - fn test_try_cast_to_string_type() { - let scalars = vec![ - ScalarValue::from("string"), - ScalarValue::LargeUtf8(Some("string".to_owned())), - ]; - - for s1 in &scalars { - for s2 in &scalars { - let expected_value = ExpectedCast::Value(s2.clone()); - - expect_cast(s1.clone(), s2.data_type(), expected_value); - } - } - } - #[test] - fn test_try_cast_to_dictionary_type() { - fn dictionary_type(t: DataType) -> DataType { - DataType::Dictionary(Box::new(DataType::Int32), Box::new(t)) - } - fn dictionary_value(value: ScalarValue) -> ScalarValue { - ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value)) - } - let scalars = vec![ - ScalarValue::from("string"), - ScalarValue::LargeUtf8(Some("string".to_owned())), - ]; - for s in &scalars { - expect_cast( - s.clone(), - dictionary_type(s.data_type()), - ExpectedCast::Value(dictionary_value(s.clone())), - ); - expect_cast( - dictionary_value(s.clone()), - s.data_type(), - ExpectedCast::Value(s.clone()), - ) - } - } } diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs index c47ec9d75d50..b6f64245ae64 100644 --- a/datafusion/physical-expr-common/src/datum.rs +++ b/datafusion/physical-expr-common/src/datum.rs @@ -21,9 +21,9 @@ use arrow::buffer::NullBuffer; use arrow::compute::SortOptions; use arrow::error::ArrowError; use datafusion_common::DataFusionError; +use datafusion_common::Result; use 
datafusion_common::{arrow_datafusion_err, internal_err}; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::columnar_value::{ColumnarValue, Scalar}; use datafusion_expr_common::operator::Operator; use std::sync::Arc; @@ -40,15 +40,14 @@ pub fn apply( Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?)) } (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok( - ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?), + ColumnarValue::Array(f(&left.value().to_scalar()?, &right.as_ref())?), ), (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok( ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?), ), (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => { let array = f(&left.to_scalar()?, &right.to_scalar()?)?; - let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?; - Ok(ColumnarValue::Scalar(scalar)) + Ok(ColumnarValue::Scalar(Scalar::try_from_array(&array, 0)?)) } } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index e115ec3c74fe..9288c63c587e 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -213,7 +213,7 @@ macro_rules! compute_utf8_flag_op_scalar { .downcast_ref::<$ARRAYTYPE>() .expect("compute_utf8_flag_op_scalar failed to downcast array"); - if let ScalarValue::Utf8(Some(string_value))|ScalarValue::LargeUtf8(Some(string_value)) = $RIGHT { + if let ScalarValue::Utf8(Some(string_value)) = $RIGHT { let flag = if $FLAG { Some("i") } else { None }; let mut array = paste::expr! 
{[<$OP _utf8_scalar>]}(&ll, &string_value, flag)?; @@ -253,8 +253,8 @@ impl PhysicalExpr for BinaryExpr { let lhs = self.left.evaluate(batch)?; let rhs = self.right.evaluate(batch)?; - let left_data_type = lhs.data_type(); - let right_data_type = rhs.data_type(); + let left_data_type = lhs.data_type().clone(); + let right_data_type = rhs.data_type().clone(); let schema = batch.schema(); let input_schema = schema.as_ref(); @@ -296,12 +296,15 @@ impl PhysicalExpr for BinaryExpr { let scalar_result = match (&lhs, &rhs) { (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => { // if left is array and right is literal(not NULL) - use scalar operations - if scalar.is_null() { + if scalar.value().is_null() { None } else { - self.evaluate_array_scalar(array, scalar.clone())?.map(|r| { - r.and_then(|a| to_result_type_array(&self.op, a, &result_type)) - }) + self.evaluate_array_scalar(array, scalar.clone().into_value())? + .map(|r| { + r.and_then(|a| { + to_result_type_array(&self.op, a, &result_type) + }) + }) } } (_, _) => None, // default to array implementation @@ -1492,87 +1495,6 @@ mod tests { Ok(()) } - #[test] - fn plus_op_dict_scalar() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), - true, - )]); - - let mut dict_builder = PrimitiveDictionaryBuilder::::new(); - - dict_builder.append(1)?; - dict_builder.append_null(); - dict_builder.append(2)?; - dict_builder.append(5)?; - - let a = dict_builder.finish(); - - let expected: PrimitiveArray = - PrimitiveArray::from(vec![Some(2), None, Some(3), Some(6)]); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Plus, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Int32(Some(1))), - ), - Arc::new(expected), - )?; - - Ok(()) - } - - #[test] - fn plus_op_dict_scalar_decimal() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - 
DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Decimal128(10, 0)), - ), - true, - )]); - - let value = 123; - let decimal_array = Arc::new(create_decimal_array( - &[Some(value), None, Some(value - 1), Some(value + 1)], - 10, - 0, - )); - - let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); - let a = DictionaryArray::try_new(keys, decimal_array)?; - - let decimal_array = Arc::new(create_decimal_array( - &[ - Some(value + 1), - Some(value), - None, - Some(value + 2), - Some(value + 1), - ], - 11, - 0, - )); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Plus, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Decimal128(Some(1), 10, 0)), - ), - decimal_array, - )?; - - Ok(()) - } - #[test] fn minus_op() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -1603,98 +1525,6 @@ mod tests { Ok(()) } - #[test] - fn minus_op_dict() -> Result<()> { - let schema = Schema::new(vec![ - Field::new( - "a", - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), - true, - ), - Field::new( - "b", - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), - true, - ), - ]); - - let a = Int32Array::from(vec![1, 2, 3, 4, 5]); - let keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]); - let a = DictionaryArray::try_new(keys, Arc::new(a))?; - - let b = Int32Array::from(vec![1, 2, 4, 8, 16]); - let keys = Int8Array::from(vec![0, 1, 1, 2, 1]); - let b = DictionaryArray::try_new(keys, Arc::new(b))?; - - apply_arithmetic::( - Arc::new(schema), - vec![Arc::new(a), Arc::new(b)], - Operator::Minus, - Int32Array::from(vec![Some(0), None, Some(0), Some(0), None]), - )?; - - Ok(()) - } - - #[test] - fn minus_op_dict_decimal() -> Result<()> { - let schema = Schema::new(vec![ - Field::new( - "a", - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Decimal128(10, 0)), - ), - true, - ), - Field::new( - "b", - DataType::Dictionary( - 
Box::new(DataType::Int8), - Box::new(DataType::Decimal128(10, 0)), - ), - true, - ), - ]); - - let value = 123; - let decimal_array = Arc::new(create_decimal_array( - &[ - Some(value), - Some(value + 2), - Some(value - 1), - Some(value + 1), - ], - 10, - 0, - )); - - let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]); - let a = DictionaryArray::try_new(keys, decimal_array)?; - - let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]); - let decimal_array = Arc::new(create_decimal_array( - &[ - Some(value + 1), - Some(value + 3), - Some(value), - Some(value + 2), - ], - 10, - 0, - )); - let b = DictionaryArray::try_new(keys, decimal_array)?; - - apply_arithmetic( - Arc::new(schema), - vec![Arc::new(a), Arc::new(b)], - Operator::Minus, - create_decimal_array(&[Some(-1), None, None, Some(1), Some(0)], 11, 0), - )?; - - Ok(()) - } - #[test] fn minus_op_scalar() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); @@ -1711,87 +1541,6 @@ mod tests { Ok(()) } - #[test] - fn minus_op_dict_scalar() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), - true, - )]); - - let mut dict_builder = PrimitiveDictionaryBuilder::::new(); - - dict_builder.append(1)?; - dict_builder.append_null(); - dict_builder.append(2)?; - dict_builder.append(5)?; - - let a = dict_builder.finish(); - - let expected: PrimitiveArray = - PrimitiveArray::from(vec![Some(0), None, Some(1), Some(4)]); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Minus, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Int32(Some(1))), - ), - Arc::new(expected), - )?; - - Ok(()) - } - - #[test] - fn minus_op_dict_scalar_decimal() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Decimal128(10, 0)), - ), 
- true, - )]); - - let value = 123; - let decimal_array = Arc::new(create_decimal_array( - &[Some(value), None, Some(value - 1), Some(value + 1)], - 10, - 0, - )); - - let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); - let a = DictionaryArray::try_new(keys, decimal_array)?; - - let decimal_array = Arc::new(create_decimal_array( - &[ - Some(value - 1), - Some(value - 2), - None, - Some(value), - Some(value - 1), - ], - 11, - 0, - )); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Minus, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Decimal128(Some(1), 10, 0)), - ), - decimal_array, - )?; - - Ok(()) - } - #[test] fn multiply_op() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -1923,81 +1672,6 @@ mod tests { Ok(()) } - #[test] - fn multiply_op_dict_scalar() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), - true, - )]); - - let mut dict_builder = PrimitiveDictionaryBuilder::::new(); - - dict_builder.append(1)?; - dict_builder.append_null(); - dict_builder.append(2)?; - dict_builder.append(5)?; - - let a = dict_builder.finish(); - - let expected: PrimitiveArray = - PrimitiveArray::from(vec![Some(2), None, Some(4), Some(10)]); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Multiply, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Int32(Some(2))), - ), - Arc::new(expected), - )?; - - Ok(()) - } - - #[test] - fn multiply_op_dict_scalar_decimal() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Decimal128(10, 0)), - ), - true, - )]); - - let value = 123; - let decimal_array = Arc::new(create_decimal_array( - &[Some(value), None, Some(value - 1), Some(value + 1)], - 10, - 0, - )); - - let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); - let a = 
DictionaryArray::try_new(keys, decimal_array)?; - - let decimal_array = Arc::new(create_decimal_array( - &[Some(246), Some(244), None, Some(248), Some(246)], - 21, - 0, - )); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Multiply, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Decimal128(Some(2), 10, 0)), - ), - decimal_array, - )?; - - Ok(()) - } - #[test] fn divide_op() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -2141,81 +1815,6 @@ mod tests { Ok(()) } - #[test] - fn divide_op_dict_scalar() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), - true, - )]); - - let mut dict_builder = PrimitiveDictionaryBuilder::::new(); - - dict_builder.append(1)?; - dict_builder.append_null(); - dict_builder.append(2)?; - dict_builder.append(5)?; - - let a = dict_builder.finish(); - - let expected: PrimitiveArray = - PrimitiveArray::from(vec![Some(0), None, Some(1), Some(2)]); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Divide, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Int32(Some(2))), - ), - Arc::new(expected), - )?; - - Ok(()) - } - - #[test] - fn divide_op_dict_scalar_decimal() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Decimal128(10, 0)), - ), - true, - )]); - - let value = 123; - let decimal_array = Arc::new(create_decimal_array( - &[Some(value), None, Some(value - 1), Some(value + 1)], - 10, - 0, - )); - - let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); - let a = DictionaryArray::try_new(keys, decimal_array)?; - - let decimal_array = Arc::new(create_decimal_array( - &[Some(615000), Some(610000), None, Some(620000), Some(615000)], - 14, - 4, - )); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - 
Operator::Divide, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Decimal128(Some(2), 10, 0)), - ), - decimal_array, - )?; - - Ok(()) - } - #[test] fn modulus_op() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -2349,81 +1948,6 @@ mod tests { Ok(()) } - #[test] - fn modules_op_dict_scalar() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), - true, - )]); - - let mut dict_builder = PrimitiveDictionaryBuilder::::new(); - - dict_builder.append(1)?; - dict_builder.append_null(); - dict_builder.append(2)?; - dict_builder.append(5)?; - - let a = dict_builder.finish(); - - let expected: PrimitiveArray = - PrimitiveArray::from(vec![Some(1), None, Some(0), Some(1)]); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Modulo, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Int32(Some(2))), - ), - Arc::new(expected), - )?; - - Ok(()) - } - - #[test] - fn modulus_op_dict_scalar_decimal() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Decimal128(10, 0)), - ), - true, - )]); - - let value = 123; - let decimal_array = Arc::new(create_decimal_array( - &[Some(value), None, Some(value - 1), Some(value + 1)], - 10, - 0, - )); - - let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); - let a = DictionaryArray::try_new(keys, decimal_array)?; - - let decimal_array = Arc::new(create_decimal_array( - &[Some(1), Some(0), None, Some(0), Some(1)], - 10, - 0, - )); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Modulo, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Decimal128(Some(2), 10, 0)), - ), - decimal_array, - )?; - - Ok(()) - } - fn apply_arithmetic( schema: SchemaRef, data: Vec, @@ -3216,97 +2740,6 @@ mod tests { .unwrap() } - #[test] - 
fn comparison_dict_decimal_scalar_expr_test() -> Result<()> { - // scalar of decimal compare with dictionary decimal array - let value_i128 = 123; - let decimal_scalar = ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Decimal128(Some(value_i128), 25, 3)), - ); - let schema = Arc::new(Schema::new(vec![Field::new( - "a", - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Decimal128(25, 3)), - ), - true, - )])); - let decimal_array = Arc::new(create_decimal_array( - &[ - Some(value_i128), - None, - Some(value_i128 - 1), - Some(value_i128 + 1), - ], - 25, - 3, - )); - - let keys = Int8Array::from(vec![Some(0), None, Some(2), Some(3)]); - let dictionary = - Arc::new(DictionaryArray::try_new(keys, decimal_array)?) as ArrayRef; - - // array = scalar - apply_logic_op_arr_scalar( - &schema, - &dictionary, - &decimal_scalar, - Operator::Eq, - &BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]), - ) - .unwrap(); - // array != scalar - apply_logic_op_arr_scalar( - &schema, - &dictionary, - &decimal_scalar, - Operator::NotEq, - &BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]), - ) - .unwrap(); - // array < scalar - apply_logic_op_arr_scalar( - &schema, - &dictionary, - &decimal_scalar, - Operator::Lt, - &BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]), - ) - .unwrap(); - - // array <= scalar - apply_logic_op_arr_scalar( - &schema, - &dictionary, - &decimal_scalar, - Operator::LtEq, - &BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]), - ) - .unwrap(); - // array > scalar - apply_logic_op_arr_scalar( - &schema, - &dictionary, - &decimal_scalar, - Operator::Gt, - &BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]), - ) - .unwrap(); - - // array >= scalar - apply_logic_op_arr_scalar( - &schema, - &dictionary, - &decimal_scalar, - Operator::GtEq, - &BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]), - ) - .unwrap(); - - 
Ok(()) - } - #[test] fn comparison_decimal_expr_test() -> Result<()> { // scalar of decimal compare with decimal array diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 712175c9afbe..a0852f217a90 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -223,12 +223,12 @@ impl CaseExpr { .evaluate_selection(batch, &when_match)?; current_value = match then_value { - ColumnarValue::Scalar(ScalarValue::Null) => { - nullif(current_value.as_ref(), &when_match)? - } - ColumnarValue::Scalar(then_value) => { - zip(&when_match, &then_value.to_scalar()?, ¤t_value)? - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Null => nullif(current_value.as_ref(), &when_match)?, + then_value => { + zip(&when_match, &then_value.to_scalar()?, ¤t_value)? + } + }, ColumnarValue::Array(then_value) => { zip(&when_match, &then_value, ¤t_value)? } @@ -294,12 +294,12 @@ impl CaseExpr { .evaluate_selection(batch, &when_value)?; current_value = match then_value { - ColumnarValue::Scalar(ScalarValue::Null) => { - nullif(current_value.as_ref(), &when_value)? - } - ColumnarValue::Scalar(then_value) => { - zip(&when_value, &then_value.to_scalar()?, ¤t_value)? - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Null => nullif(current_value.as_ref(), &when_value)?, + then_value => { + zip(&when_value, &then_value.to_scalar()?, ¤t_value)? + } + }, ColumnarValue::Array(then_value) => { zip(&when_value, &then_value, ¤t_value)? 
} diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 0a3e5fcefcf6..63824d40bb7d 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -37,10 +37,8 @@ use datafusion_common::cast::{ as_boolean_array, as_generic_binary_array, as_string_array, }; use datafusion_common::hash_utils::HashValue; -use datafusion_common::{ - exec_err, internal_err, not_impl_err, DFSchema, Result, ScalarValue, -}; -use datafusion_expr::ColumnarValue; +use datafusion_common::{exec_err, internal_err, not_impl_err, DFSchema, Result}; +use datafusion_expr::{ColumnarValue, Scalar}; use datafusion_physical_expr_common::datum::compare_with_eq; use ahash::RandomState; @@ -222,13 +220,12 @@ fn evaluate_list( exec_err!("InList expression must evaluate to a scalar") } // Flatten dictionary values - ColumnarValue::Scalar(ScalarValue::Dictionary(_, v)) => Ok(*v), ColumnarValue::Scalar(s) => Ok(s), }) }) .collect::>>()?; - ScalarValue::iter_to_array(scalars) + Scalar::iter_to_array(scalars) } fn try_cast_static_filter_to_set( @@ -453,7 +450,7 @@ mod tests { use super::*; use crate::expressions; use crate::expressions::{col, lit, try_cast}; - use datafusion_common::plan_err; + use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::type_coercion::binary::comparison_coercion; type InListCastResult = (Arc, Vec>); @@ -1332,96 +1329,4 @@ mod tests { Ok(()) } - - #[test] - fn in_list_utf8_with_dict_types() -> Result<()> { - fn dict_lit(key_type: DataType, value: &str) -> Arc { - lit(ScalarValue::Dictionary( - Box::new(key_type), - Box::new(ScalarValue::new_utf8(value.to_string())), - )) - } - - fn null_dict_lit(key_type: DataType) -> Arc { - lit(ScalarValue::Dictionary( - Box::new(key_type), - Box::new(ScalarValue::Utf8(None)), - )) - } - - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary(Box::new(DataType::UInt16), 
Box::new(DataType::Utf8)), - true, - )]); - let a: UInt16DictionaryArray = - vec![Some("a"), Some("d"), None].into_iter().collect(); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in ("a", "b")" - let lists = [ - vec![lit("a"), lit("b")], - vec![ - dict_lit(DataType::Int8, "a"), - dict_lit(DataType::UInt16, "b"), - ], - ]; - for list in lists.iter() { - in_list_raw!( - batch, - list.clone(), - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - } - - // expression: "a not in ("a", "b")" - for list in lists.iter() { - in_list_raw!( - batch, - list.clone(), - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - } - - // expression: "a in ("a", "b", null)" - let lists = [ - vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))], - vec![ - dict_lit(DataType::Int8, "a"), - dict_lit(DataType::UInt16, "b"), - null_dict_lit(DataType::UInt16), - ], - ]; - for list in lists.iter() { - in_list_raw!( - batch, - list.clone(), - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - } - - // expression: "a not in ("a", "b", null)" - for list in lists.iter() { - in_list_raw!( - batch, - list.clone(), - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - } - - Ok(()) - } } diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 58559352d44c..50c3cbab9baf 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -76,8 +76,8 @@ impl PhysicalExpr for IsNotNullExpr { let is_not_null = super::is_null::compute_is_not_null(array)?; Ok(ColumnarValue::Array(Arc::new(is_not_null))) } - ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( - ScalarValue::Boolean(Some(!scalar.is_null())), + ColumnarValue::Scalar(scalar) => 
Ok(ColumnarValue::from( + ScalarValue::Boolean(Some(!scalar.value().is_null())), )), } } diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 3cdb49bcab42..cdc5f101002e 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -80,8 +80,8 @@ impl PhysicalExpr for IsNullExpr { ColumnarValue::Array(array) => { Ok(ColumnarValue::Array(Arc::new(compute_is_null(array)?))) } - ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( - ScalarValue::Boolean(Some(scalar.is_null())), + ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::from( + ScalarValue::Boolean(Some(scalar.value().is_null())), )), } } diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index ed24e9028153..e064abbca35c 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -72,7 +72,7 @@ impl PhysicalExpr for Literal { } fn evaluate(&self, _batch: &RecordBatch) -> Result { - Ok(ColumnarValue::Scalar(self.value.clone())) + Ok(ColumnarValue::from(self.value.clone())) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index b5ebc250cb89..01429614d552 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -83,9 +83,9 @@ impl PhysicalExpr for NegativeExpr { let result = neg_wrapping(array.as_ref())?; Ok(ColumnarValue::Array(result)) } - ColumnarValue::Scalar(scalar) => { - Ok(ColumnarValue::Scalar((scalar.arithmetic_negate())?)) - } + ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::from( + (scalar.into_value().arithmetic_negate())?, + )), } } diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 
b69954e00bba..7a0afaa1a637 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -78,13 +78,11 @@ impl PhysicalExpr for NotExpr { ))) } ColumnarValue::Scalar(scalar) => { - if scalar.is_null() { - return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + if scalar.value().is_null() { + return Ok(ColumnarValue::from(ScalarValue::Boolean(None))); } - let bool_value: bool = scalar.try_into()?; - Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some( - !bool_value, - )))) + let bool_value: bool = scalar.into_value().try_into()?; + Ok(ColumnarValue::from(ScalarValue::Boolean(Some(!bool_value)))) } } } diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index 43b6c993d2b2..8da0717165e3 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -23,12 +23,12 @@ use std::sync::Arc; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::compute; -use arrow::compute::{cast_with_options, CastOptions}; +use arrow::compute::CastOptions; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use compute::can_cast_types; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; -use datafusion_common::{not_impl_err, Result, ScalarValue}; +use datafusion_common::{not_impl_err, Result}; use datafusion_expr::ColumnarValue; /// TRY_CAST expression casts an expression to a specific data type and returns NULL on invalid cast @@ -83,18 +83,7 @@ impl PhysicalExpr for TryCastExpr { safe: true, format_options: DEFAULT_FORMAT_OPTIONS, }; - match value { - ColumnarValue::Array(array) => { - let cast = cast_with_options(&array, &self.cast_type, &options)?; - Ok(ColumnarValue::Array(cast)) - } - ColumnarValue::Scalar(scalar) => { - let array = scalar.to_array()?; - let cast_array = cast_with_options(&array, &self.cast_type, &options)?; - let 
cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; - Ok(ColumnarValue::Scalar(cast_scalar)) - } - } + value.cast_to(&self.cast_type, Some(&options)) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index e33c28df1988..7935839d1c98 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -34,8 +34,8 @@ use std::sync::Arc; use arrow::array::{Array, ArrayRef}; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +use datafusion_common::Result; +use datafusion_expr::{ColumnarValue, Scalar, ScalarFunctionImplementation}; pub use crate::scalar_function::create_physical_expr; // For backward compatibility @@ -46,7 +46,7 @@ pub fn columnar_values_to_array(args: &[ColumnarValue]) -> Result> ColumnarValue::values_to_arrays(args) } -/// Decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function +/// Decorates a function to handle [`Scalar`]s by converting them to arrays before calling the function /// and vice-versa after evaluation. /// Note that this function makes a scalar function with no arguments or all scalar inputs return a scalar. /// That's said its output will be same for all input rows in a batch. @@ -69,9 +69,9 @@ where make_scalar_function_with_hints(inner, vec![]) } -/// Just like [`make_scalar_function`], decorates the given function to handle both [`ScalarValue`]s and arrays. +/// Just like [`make_scalar_function`], decorates the given function to handle both [`Scalar`]s and arrays. /// Additionally can receive a `hints` vector which can be used to control the output arrays when generating them -/// from [`ScalarValue`]s. +/// from [`Scalar`]s. /// /// Each element of the `hints` vector gets mapped to the corresponding argument of the function. 
The number of hints /// can be less or greater than the number of arguments (for functions with variable number of arguments). Each unmapped @@ -113,7 +113,7 @@ where let result = (inner)(&args); if is_scalar { // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + let result = result.and_then(|arr| Scalar::try_from_array(&arr, 0)); result.map(ColumnarValue::Scalar) } else { result.map(ColumnarValue::Array) diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 4c37db4849a7..195a32042e6f 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -503,10 +503,7 @@ pub(crate) mod tests { let schema_big = Arc::new(Schema::new(vec![int_field, dict_field])); let pred = in_list( Arc::new(Column::new_with_schema("id", &schema_big).unwrap()), - vec![lit(ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::from("2")), - ))], + vec![lit(ScalarValue::from("2"))], &false, &schema_big, ) @@ -516,10 +513,7 @@ pub(crate) mod tests { let expected = in_list( Arc::new(Column::new_with_schema("id", &schema_small).unwrap()), - vec![lit(ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::from("2")), - ))], + vec![lit(ScalarValue::from("2"))], &false, &schema_small, ) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index c3bc7b042e65..98292a951626 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1074,7 +1074,7 @@ pub fn finalize_aggregation( accumulators .iter_mut() .map(|accumulator| { - accumulator.state().and_then(|e| { + accumulator.state_as_scalars().and_then(|e| { e.iter() .map(|v| v.to_array()) .collect::>>() @@ -1090,7 +1090,9 @@ pub fn finalize_aggregation( // Merge the state to the final value accumulators .iter_mut() - 
.map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array())) + .map(|accumulator| { + accumulator.evaluate_as_scalar().and_then(|v| v.to_array()) + }) .collect() } } diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index e01aea1fdd6b..35397adb01cf 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -33,6 +33,7 @@ use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; use datafusion_common::{internal_err, plan_err, Result, ScalarValue}; use datafusion_execution::TaskContext; +use datafusion_expr::Scalar; use datafusion_physical_expr::EquivalenceProperties; /// Execution plan for values list based relation (produces constant rows) @@ -74,7 +75,7 @@ impl ValuesExec { match r { Ok(ColumnarValue::Scalar(scalar)) => Ok(scalar), Ok(ColumnarValue::Array(a)) if a.len() == 1 => { - ScalarValue::try_from_array(&a, 0) + Ok(Scalar::from(ScalarValue::try_from_array(&a, 0)?)) } Ok(ColumnarValue::Array(a)) => { plan_err!( @@ -85,7 +86,7 @@ impl ValuesExec { } }) .collect::>>() - .and_then(ScalarValue::iter_to_array) + .and_then(Scalar::iter_to_array) }) .collect::>>()?; let batch = RecordBatch::try_new_with_options( diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index d1b4374fc0e7..95a482c991cf 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -368,8 +368,8 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Ok(match value { Value::BoolValue(v) => Self::Boolean(Some(*v)), Value::Utf8Value(v) => Self::Utf8(Some(v.to_owned())), - Value::Utf8ViewValue(v) => Self::Utf8View(Some(v.to_owned())), - Value::LargeUtf8Value(v) => Self::LargeUtf8(Some(v.to_owned())), + Value::Utf8ViewValue(v) => Self::Utf8(Some(v.to_owned())), + Value::LargeUtf8Value(v) => Self::Utf8(Some(v.to_owned())), Value::Int8Value(v) => 
Self::Int8(Some(*v as i8)), Value::Int16Value(v) => Self::Int16(Some(*v as i16)), Value::Int32Value(v) => Self::Int32(Some(*v)), @@ -564,25 +564,15 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { } } } - Value::DictionaryValue(v) => { - let index_type: DataType = v - .index_type - .as_ref() - .ok_or_else(|| Error::required("index_type"))? - .try_into()?; - - let value: Self = v - .value - .as_ref() - .ok_or_else(|| Error::required("value"))? - .as_ref() - .try_into()?; - - Self::Dictionary(Box::new(index_type), Box::new(value)) - } + Value::DictionaryValue(v) => v + .value + .as_ref() + .ok_or_else(|| Error::required("value"))? + .as_ref() + .try_into()?, Value::BinaryValue(v) => Self::Binary(Some(v.clone())), - Value::BinaryViewValue(v) => Self::BinaryView(Some(v.clone())), - Value::LargeBinaryValue(v) => Self::LargeBinary(Some(v.clone())), + Value::BinaryViewValue(v) => Self::Binary(Some(v.clone())), + Value::LargeBinaryValue(v) => Self::Binary(Some(v.clone())), Value::IntervalDaytimeValue(v) => Self::IntervalDayTime(Some( IntervalDayTimeType::make_value(v.days, v.milliseconds), )), diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index ebb53ae7577c..d7411579230c 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -347,16 +347,6 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { Value::Utf8Value(s.to_owned()) }) } - ScalarValue::LargeUtf8(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::LargeUtf8Value(s.to_owned()) - }) - } - ScalarValue::Utf8View(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::Utf8ViewValue(s.to_owned()) - }) - } ScalarValue::List(arr) => { encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) } @@ -474,16 +464,6 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { Value::BinaryValue(s.to_owned()) }) } - ScalarValue::BinaryView(val) => { - 
create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::BinaryViewValue(s.to_owned()) - }) - } - ScalarValue::LargeBinary(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::LargeBinaryValue(s.to_owned()) - }) - } ScalarValue::FixedSizeBinary(length, val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::FixedSizeBinaryValue(protobuf::ScalarFixedSizeBinary { @@ -624,18 +604,6 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { let val = protobuf::ScalarValue { value: Some(val) }; Ok(val) } - - ScalarValue::Dictionary(index_type, val) => { - let value: protobuf::ScalarValue = val.as_ref().try_into()?; - Ok(protobuf::ScalarValue { - value: Some(Value::DictionaryValue(Box::new( - protobuf::ScalarDictionaryValue { - index_type: Some(index_type.as_ref().try_into()?), - value: Some(Box::new(value)), - }, - ))), - }) - } } } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 1ff39e9e65b7..2603e772b3cf 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1211,7 +1211,6 @@ fn round_trip_scalar_values_and_data_types() { ScalarValue::UInt32(None), ScalarValue::UInt64(None), ScalarValue::Utf8(None), - ScalarValue::LargeUtf8(None), ScalarValue::List(ScalarValue::new_list_nullable(&[], &DataType::Boolean)), ScalarValue::LargeList(ScalarValue::new_large_list(&[], &DataType::Boolean)), ScalarValue::Date32(None), @@ -1250,9 +1249,6 @@ fn round_trip_scalar_values_and_data_types() { ScalarValue::UInt64(Some(u64::MAX)), ScalarValue::UInt64(Some(0)), ScalarValue::Utf8(Some(String::from("Test string "))), - ScalarValue::LargeUtf8(Some(String::from("Test Large utf8"))), - ScalarValue::Utf8View(Some(String::from("Test stringview"))), - ScalarValue::BinaryView(Some(b"binaryview".to_vec())), ScalarValue::Date32(Some(0)), ScalarValue::Date32(Some(i32::MAX)), ScalarValue::Date32(None), @@ 
-1372,18 +1368,8 @@ fn round_trip_scalar_values_and_data_types() { vec![Some(vec![Some(1), Some(2), Some(3)])], 3, ))), - ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::from("foo")), - ), - ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::Utf8(None)), - ), ScalarValue::Binary(Some(b"bar".to_vec())), ScalarValue::Binary(None), - ScalarValue::LargeBinary(Some(b"bar".to_vec())), - ScalarValue::LargeBinary(None), ScalarStructBuilder::new() .with_scalar( Field::new("a", DataType::Int32, true), @@ -1404,20 +1390,6 @@ fn round_trip_scalar_values_and_data_types() { Field::new("b", DataType::Boolean, false), ScalarValue::from(false), ) - .with_scalar( - Field::new( - "c", - DataType::Dictionary( - Box::new(DataType::UInt16), - Box::new(DataType::Utf8), - ), - false, - ), - ScalarValue::Dictionary( - Box::new(DataType::UInt16), - Box::new("value".into()), - ), - ) .build() .unwrap(), ScalarValue::try_from(&DataType::Struct(Fields::from(vec![ @@ -1777,7 +1749,6 @@ fn roundtrip_null_scalar_values() { ScalarValue::UInt32(None), ScalarValue::UInt64(None), ScalarValue::Utf8(None), - ScalarValue::LargeUtf8(None), ScalarValue::Date32(None), ScalarValue::TimestampMicrosecond(None, None), ScalarValue::TimestampNanosecond(None, None), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 58f6015ee336..72be5bb5efd6 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -42,8 +42,7 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::datasource::listing::{ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{ - wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, - FileSinkConfig, ParquetExec, + wrap_partition_type_in_dict, FileScanConfig, 
FileSinkConfig, ParquetExec, }; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::sum::sum_udaf; @@ -691,8 +690,7 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> { let mut file_group = PartitionedFile::new("/path/to/part=0/file.parquet".to_string(), 1024); - file_group.partition_values = - vec![wrap_partition_value_in_dict(ScalarValue::Int64(Some(0)))]; + file_group.partition_values = vec![ScalarValue::Int64(Some(0))]; let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, false)])); let scan_config = FileScanConfig { diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 516833a39f1e..6fa1a9c8be32 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1010,27 +1010,11 @@ impl Unparser<'_> { ast::Value::SingleQuotedString(str.to_string()), )), ScalarValue::Utf8(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::Utf8View(Some(str)) => Ok(ast::Expr::Value( - ast::Value::SingleQuotedString(str.to_string()), - )), - ScalarValue::Utf8View(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::LargeUtf8(Some(str)) => Ok(ast::Expr::Value( - ast::Value::SingleQuotedString(str.to_string()), - )), - ScalarValue::LargeUtf8(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Binary(Some(_)) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Binary(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::BinaryView(Some(_)) => { - not_impl_err!("Unsupported scalar: {v:?}") - } - ScalarValue::BinaryView(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::FixedSizeBinary(..) 
=> { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::LargeBinary(Some(_)) => { - not_impl_err!("Unsupported scalar: {v:?}") - } - ScalarValue::LargeBinary(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::FixedSizeList(_a) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::List(_a) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::LargeList(_a) => not_impl_err!("Unsupported scalar: {v:?}"), @@ -1161,7 +1145,6 @@ impl Unparser<'_> { ScalarValue::Struct(_) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Map(_) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Union(..) => not_impl_err!("Unsupported scalar: {v:?}"), - ScalarValue::Dictionary(..) => not_impl_err!("Unsupported scalar: {v:?}"), } } diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt index ec8a51488564..48cbd36d2132 100644 --- a/datafusion/sqllogictest/test_files/dictionary.slt +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -407,11 +407,11 @@ query TT explain SELECT * from test where column2 = '1'; ---- logical_plan -01)Filter: test.column2 = Dictionary(Int32, Utf8("1")) +01)Filter: test.column2 = CAST(Utf8("1") AS Dictionary(Int32, Utf8)) 02)--TableScan: test projection=[column1, column2] physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 -02)--FilterExec: column2@1 = 1 +02)--FilterExec: column2@1 = CAST(1 AS Dictionary(Int32, Utf8)) 03)----MemoryExec: partitions=1, partition_sizes=[1] # try literal = col to verify order doesn't matter @@ -420,11 +420,11 @@ query TT explain SELECT * from test where '1' = column2 ---- logical_plan -01)Filter: test.column2 = Dictionary(Int32, Utf8("1")) +01)Filter: CAST(Utf8("1") AS Dictionary(Int32, Utf8)) = test.column2 02)--TableScan: test projection=[column1, column2] physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 -02)--FilterExec: column2@1 = 1 +02)--FilterExec: CAST(1 AS Dictionary(Int32, Utf8)) = 
column2@1 03)----MemoryExec: partitions=1, partition_sizes=[1] @@ -438,9 +438,9 @@ query TT explain SELECT * from test where column2 = 1; ---- logical_plan -01)Filter: test.column2 = Dictionary(Int32, Utf8("1")) +01)Filter: CAST(test.column2 AS Utf8) = Utf8("1") 02)--TableScan: test projection=[column1, column2] physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 -02)--FilterExec: column2@1 = 1 +02)--FilterExec: CAST(column2@1 AS Utf8) = 1 03)----MemoryExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 7df43bb7eddb..8aa9cf87a31d 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -260,7 +260,7 @@ explain SELECT column1_utf8 from test where column1_utf8view = 'Andrew'; ---- logical_plan 01)Projection: test.column1_utf8 -02)--Filter: test.column1_utf8view = Utf8View("Andrew") +02)--Filter: test.column1_utf8view = CAST(Utf8("Andrew") AS Utf8View) 03)----TableScan: test projection=[column1_utf8, column1_utf8view] # reverse order should be the same @@ -269,21 +269,21 @@ explain SELECT column1_utf8 from test where 'Andrew' = column1_utf8view; ---- logical_plan 01)Projection: test.column1_utf8 -02)--Filter: test.column1_utf8view = Utf8View("Andrew") +02)--Filter: CAST(Utf8("Andrew") AS Utf8View) = test.column1_utf8view 03)----TableScan: test projection=[column1_utf8, column1_utf8view] query TT explain SELECT column1_utf8 from test where column1_utf8 = arrow_cast('Andrew', 'Utf8View'); ---- logical_plan -01)Filter: test.column1_utf8 = Utf8("Andrew") +01)Filter: CAST(test.column1_utf8 AS Utf8View) = CAST(Utf8("Andrew") AS Utf8View) 02)--TableScan: test projection=[column1_utf8] query TT explain SELECT column1_utf8 from test where arrow_cast('Andrew', 'Utf8View') = column1_utf8; ---- logical_plan -01)Filter: test.column1_utf8 = Utf8("Andrew") +01)Filter: CAST(Utf8("Andrew") AS 
Utf8View) = CAST(test.column1_utf8 AS Utf8View) 02)--TableScan: test projection=[column1_utf8] query TT @@ -291,7 +291,7 @@ explain SELECT column1_utf8 from test where column1_utf8view = arrow_cast('Andre ---- logical_plan 01)Projection: test.column1_utf8 -02)--Filter: test.column1_utf8view = Utf8View("Andrew") +02)--Filter: test.column1_utf8view = CAST(CAST(Utf8("Andrew") AS Dictionary(Int32, Utf8)) AS Utf8View) 03)----TableScan: test projection=[column1_utf8, column1_utf8view] query TT @@ -299,7 +299,7 @@ explain SELECT column1_utf8 from test where arrow_cast('Andrew', 'Dictionary(Int ---- logical_plan 01)Projection: test.column1_utf8 -02)--Filter: test.column1_utf8view = Utf8View("Andrew") +02)--Filter: CAST(CAST(Utf8("Andrew") AS Dictionary(Int32, Utf8)) AS Utf8View) = test.column1_utf8view 03)----TableScan: test projection=[column1_utf8, column1_utf8view] # compare string / stringview @@ -421,8 +421,10 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: starts_with(test.column1_utf8view, Utf8View("äöüß")) AS c1, starts_with(test.column1_utf8view, Utf8View("")) AS c2, starts_with(test.column1_utf8view, Utf8View(NULL)) AS c3, starts_with(Utf8View(NULL), test.column1_utf8view) AS c4 -02)--TableScan: test projection=[column1_utf8view] +01)Projection: starts_with(test.column1_utf8view, CAST(Utf8("äöüß") AS Utf8View)) AS c1, starts_with(test.column1_utf8view, CAST(Utf8("") AS Utf8View)) AS c2, starts_with(test.column1_utf8view, __common_expr_1) AS c3, starts_with(__common_expr_1, test.column1_utf8view) AS c4 +02)--Projection: CAST(NULL AS Utf8View) AS __common_expr_1, test.column1_utf8view +03)----TableScan: test projection=[column1_utf8view] + ### Test TRANSLATE @@ -605,8 +607,9 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: test.column1_utf8view LIKE Utf8View("foo") AS like, test.column1_utf8view ILIKE Utf8View("foo") AS ilike -02)--TableScan: test projection=[column1_utf8view] +01)Projection: test.column1_utf8view LIKE __common_expr_1 
AS like, test.column1_utf8view ILIKE __common_expr_1 AS ilike +02)--Projection: CAST(Utf8("foo") AS Utf8View) AS __common_expr_1, test.column1_utf8view +03)----TableScan: test projection=[column1_utf8view] ## Ensure no casts for SUBSTR @@ -727,7 +730,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: btrim(test.column1_utf8view, Utf8View("foo")) AS l +01)Projection: btrim(test.column1_utf8view, CAST(Utf8("foo") AS Utf8View)) AS l 02)--TableScan: test projection=[column1_utf8view] # Test BTRIM with Utf8View bytes longer than 12 @@ -737,7 +740,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: btrim(test.column1_utf8view, Utf8View("this is longer than 12")) AS l +01)Projection: btrim(test.column1_utf8view, CAST(Utf8("this is longer than 12") AS Utf8View)) AS l 02)--TableScan: test projection=[column1_utf8view] # Test BTRIM outputs @@ -772,7 +775,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: ltrim(test.column1_utf8view, Utf8View("foo")) AS l +01)Projection: ltrim(test.column1_utf8view, CAST(Utf8("foo") AS Utf8View)) AS l 02)--TableScan: test projection=[column1_utf8view] # Test LTRIM with Utf8View bytes longer than 12 @@ -782,7 +785,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: ltrim(test.column1_utf8view, Utf8View("this is longer than 12")) AS l +01)Projection: ltrim(test.column1_utf8view, CAST(Utf8("this is longer than 12") AS Utf8View)) AS l 02)--TableScan: test projection=[column1_utf8view] # Test LTRIM outputs @@ -818,7 +821,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: rtrim(test.column1_utf8view, Utf8View("foo")) AS l +01)Projection: rtrim(test.column1_utf8view, CAST(Utf8("foo") AS Utf8View)) AS l 02)--TableScan: test projection=[column1_utf8view] # Test RTRIM with Utf8View bytes longer than 12 @@ -828,7 +831,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: rtrim(test.column1_utf8view, Utf8View("this is longer than 12")) AS l +01)Projection: 
rtrim(test.column1_utf8view, CAST(Utf8("this is longer than 12") AS Utf8View)) AS l 02)--TableScan: test projection=[column1_utf8view] # Test RTRIM outputs @@ -932,7 +935,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: ends_with(test.column1_utf8view, Utf8View("foo")) AS c1, ends_with(test.column2_utf8view, test.column2_utf8view) AS c2 +01)Projection: ends_with(test.column1_utf8view, CAST(Utf8("foo") AS Utf8View)) AS c1, ends_with(test.column2_utf8view, test.column2_utf8view) AS c2 02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for LEVENSHTEIN @@ -943,7 +946,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: levenshtein(test.column1_utf8view, Utf8View("foo")) AS c1, levenshtein(test.column1_utf8view, test.column2_utf8view) AS c2 +01)Projection: levenshtein(test.column1_utf8view, CAST(Utf8("foo") AS Utf8View)) AS c1, levenshtein(test.column1_utf8view, test.column2_utf8view) AS c2 02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for LOWER @@ -1161,7 +1164,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: overlay(test.column1_utf8view, Utf8View("foo"), Int64(2)) AS c1 +01)Projection: overlay(test.column1_utf8view, CAST(Utf8("foo") AS Utf8View), Int64(2)) AS c1 02)--TableScan: test projection=[column1_utf8view] query T @@ -1220,8 +1223,10 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: replace(test.column1_utf8view, Utf8View("foo"), Utf8View("bar")) AS c1, replace(test.column1_utf8view, test.column2_utf8view, Utf8View("bar")) AS c2 -02)--TableScan: test projection=[column1_utf8view, column2_utf8view] +01)Projection: replace(test.column1_utf8view, CAST(Utf8("foo") AS Utf8View), __common_expr_1) AS c1, replace(test.column1_utf8view, test.column2_utf8view, __common_expr_1) AS c2 +02)--Projection: CAST(Utf8("bar") AS Utf8View) AS __common_expr_1, test.column1_utf8view, test.column2_utf8view +03)----TableScan: test 
projection=[column1_utf8view, column2_utf8view] + query TT SELECT @@ -1362,8 +1367,9 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: substr_index(test.column1_utf8view, Utf8View("a"), Int64(1)) AS c, substr_index(test.column1_utf8view, Utf8View("a"), Int64(2)) AS c2 -02)--TableScan: test projection=[column1_utf8view] +01)Projection: substr_index(test.column1_utf8view, __common_expr_1, Int64(1)) AS c, substr_index(test.column1_utf8view, __common_expr_1, Int64(2)) AS c2 +02)--Projection: CAST(Utf8("a") AS Utf8View) AS __common_expr_1, test.column1_utf8view +03)----TableScan: test projection=[column1_utf8view] query TT SELECT @@ -1384,7 +1390,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: starts_with(test.column1_utf8view, Utf8View("foo")) AS c, starts_with(test.column1_utf8view, test.column2_utf8view) AS c2 +01)Projection: starts_with(test.column1_utf8view, CAST(Utf8("foo") AS Utf8View)) AS c, starts_with(test.column1_utf8view, test.column2_utf8view) AS c2 02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for TRANSLATE @@ -1404,7 +1410,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: find_in_set(test.column1_utf8view, Utf8View("a,b,c,d")) AS c +01)Projection: find_in_set(test.column1_utf8view, CAST(Utf8("a,b,c,d") AS Utf8View)) AS c 02)--TableScan: test projection=[column1_utf8view] query I diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index e6bfc67eda81..b163dcc02b7f 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1871,18 +1871,12 @@ fn from_substrait_literal( Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), Some(LiteralType::String(s)) => match lit.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8(Some(s.clone())), - LARGE_CONTAINER_TYPE_VARIATION_REF => 
ScalarValue::LargeUtf8(Some(s.clone())), - VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8View(Some(s.clone())), others => { return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Binary(b)) => match lit.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Binary(Some(b.clone())), - LARGE_CONTAINER_TYPE_VARIATION_REF => { - ScalarValue::LargeBinary(Some(b.clone())) - } - VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::BinaryView(Some(b.clone())), others => { return substrait_err!("Unknown type variation reference {others}"); } @@ -2282,7 +2276,6 @@ fn from_substrait_null( }, r#type::Kind::Binary(binary) => match binary.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::Binary(None)), - LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::LargeBinary(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), @@ -2290,7 +2283,6 @@ fn from_substrait_null( // FixedBinary is not supported because `None` doesn't have length r#type::Kind::String(string) => match string.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::Utf8(None)), - LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::LargeUtf8(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index a923aaf31abb..0481c002b495 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1904,14 +1904,6 @@ fn to_substrait_literal( LiteralType::Binary(b.clone()), DEFAULT_CONTAINER_TYPE_VARIATION_REF, ), - ScalarValue::LargeBinary(Some(b)) => ( - LiteralType::Binary(b.clone()), - LARGE_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::BinaryView(Some(b)) => ( - LiteralType::Binary(b.clone()), - VIEW_CONTAINER_TYPE_VARIATION_REF, - ), 
ScalarValue::FixedSizeBinary(_, Some(b)) => ( LiteralType::FixedBinary(b.clone()), DEFAULT_TYPE_VARIATION_REF, @@ -1920,14 +1912,6 @@ fn to_substrait_literal( LiteralType::String(s.clone()), DEFAULT_CONTAINER_TYPE_VARIATION_REF, ), - ScalarValue::LargeUtf8(Some(s)) => ( - LiteralType::String(s.clone()), - LARGE_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::Utf8View(Some(s)) => ( - LiteralType::String(s.clone()), - VIEW_CONTAINER_TYPE_VARIATION_REF, - ), ScalarValue::Decimal128(v, p, s) if v.is_some() => ( LiteralType::Decimal(Decimal { value: v.unwrap().to_le_bytes().to_vec(),