@@ -748,6 +748,40 @@ pub fn get_dict_value<K: ArrowDictionaryKeyType>(
748
748
Ok ( ( dict_array. values ( ) , dict_array. key ( index) ) )
749
749
}
750
750
751
+ /// Create a dictionary array representing all the values in values
752
+ fn dict_from_values < K : ArrowDictionaryKeyType > (
753
+ values_array : ArrayRef ,
754
+ ) -> Result < ArrayRef > {
755
+ // Create a key array with `size` elements of 0..array_len for all
756
+ // non-null value elements
757
+ let key_array: PrimitiveArray < K > = ( 0 ..values_array. len ( ) )
758
+ . map ( |index| {
759
+ if values_array. is_valid ( index) {
760
+ let native_index = K :: Native :: from_usize ( index) . ok_or_else ( || {
761
+ DataFusionError :: Internal ( format ! (
762
+ "Can not create index of type {} from value {}" ,
763
+ K :: DATA_TYPE ,
764
+ index
765
+ ) )
766
+ } ) ?;
767
+ Ok ( Some ( native_index) )
768
+ } else {
769
+ Ok ( None )
770
+ }
771
+ } )
772
+ . collect :: < Result < Vec < _ > > > ( ) ?
773
+ . into_iter ( )
774
+ . collect ( ) ;
775
+
776
+ // create a new DictionaryArray
777
+ //
778
+ // Note: this path could be made faster by using the ArrayData
779
+ // APIs and skipping validation, if it every comes up in
780
+ // performance traces.
781
+ let dict_array = DictionaryArray :: < K > :: try_new ( key_array, values_array) ?;
782
+ Ok ( Arc :: new ( dict_array) )
783
+ }
784
+
751
785
macro_rules! typed_cast_tz {
752
786
( $array: expr, $index: expr, $ARRAYTYPE: ident, $SCALAR: ident, $TZ: expr) => { {
753
787
use std:: any:: type_name;
@@ -1545,6 +1579,7 @@ impl ScalarValue {
1545
1579
Ok ( Scalar :: new ( self . to_array_of_size ( 1 ) ?) )
1546
1580
}
1547
1581
1582
+
1548
1583
/// Converts an iterator of references [`ScalarValue`] into an [`ArrayRef`]
1549
1584
/// corresponding to those values. For example, an iterator of
1550
1585
/// [`ScalarValue::Int32`] would be converted to an [`Int32Array`].
@@ -1596,6 +1631,15 @@ impl ScalarValue {
1596
1631
Some ( sv) => sv. data_type ( ) ,
1597
1632
} ;
1598
1633
1634
+ Self :: iter_to_array_of_type ( scalars. collect ( ) , & data_type)
1635
+ }
1636
+
1637
+ fn iter_to_array_of_type (
1638
+ scalars : Vec < ScalarValue > ,
1639
+ data_type : & DataType ,
1640
+ ) -> Result < ArrayRef > {
1641
+ let scalars = scalars. into_iter ( ) ;
1642
+
1599
1643
/// Creates an array of $ARRAY_TY by unpacking values of
1600
1644
/// SCALAR_TY for primitive types
1601
1645
macro_rules! build_array_primitive {
@@ -1685,7 +1729,9 @@ impl ScalarValue {
1685
1729
DataType :: UInt32 => build_array_primitive ! ( UInt32Array , UInt32 ) ,
1686
1730
DataType :: UInt64 => build_array_primitive ! ( UInt64Array , UInt64 ) ,
1687
1731
DataType :: Utf8 => build_array_string ! ( StringArray , Utf8 ) ,
1732
+ DataType :: LargeUtf8 => build_array_string ! ( LargeStringArray , Utf8 ) ,
1688
1733
DataType :: Binary => build_array_string ! ( BinaryArray , Binary ) ,
1734
+ DataType :: LargeBinary => build_array_string ! ( LargeBinaryArray , Binary ) ,
1689
1735
DataType :: Date32 => build_array_primitive ! ( Date32Array , Date32 ) ,
1690
1736
DataType :: Date64 => build_array_primitive ! ( Date64Array , Date64 ) ,
1691
1737
DataType :: Time32 ( TimeUnit :: Second ) => {
@@ -1758,11 +1804,8 @@ impl ScalarValue {
1758
1804
if let Some ( DataType :: FixedSizeList ( f, l) ) = first_non_null_data_type {
1759
1805
for array in arrays. iter_mut ( ) {
1760
1806
if array. is_null ( 0 ) {
1761
- * array = Arc :: new ( FixedSizeListArray :: new_null (
1762
- Arc :: clone ( & f) ,
1763
- l,
1764
- 1 ,
1765
- ) ) ;
1807
+ * array =
1808
+ Arc :: new ( FixedSizeListArray :: new_null ( f. clone ( ) , l, 1 ) ) ;
1766
1809
}
1767
1810
}
1768
1811
}
@@ -1771,13 +1814,28 @@ impl ScalarValue {
1771
1814
}
1772
1815
DataType :: List ( _)
1773
1816
| DataType :: LargeList ( _)
1774
- | DataType :: Map ( _, _)
1775
1817
| DataType :: Struct ( _)
1776
1818
| DataType :: Union ( _, _) => {
1777
1819
let arrays = scalars. map ( |s| s. to_array ( ) ) . collect :: < Result < Vec < _ > > > ( ) ?;
1778
1820
let arrays = arrays. iter ( ) . map ( |a| a. as_ref ( ) ) . collect :: < Vec < _ > > ( ) ;
1779
1821
arrow:: compute:: concat ( arrays. as_slice ( ) ) ?
1780
1822
}
1823
+ DataType :: Dictionary ( key_type, value_type) => {
1824
+ let values = Self :: iter_to_array ( scalars) ?;
1825
+ assert_eq ! ( values. data_type( ) , value_type. as_ref( ) ) ;
1826
+
1827
+ match key_type. as_ref ( ) {
1828
+ DataType :: Int8 => dict_from_values :: < Int8Type > ( values) ?,
1829
+ DataType :: Int16 => dict_from_values :: < Int16Type > ( values) ?,
1830
+ DataType :: Int32 => dict_from_values :: < Int32Type > ( values) ?,
1831
+ DataType :: Int64 => dict_from_values :: < Int64Type > ( values) ?,
1832
+ DataType :: UInt8 => dict_from_values :: < UInt8Type > ( values) ?,
1833
+ DataType :: UInt16 => dict_from_values :: < UInt16Type > ( values) ?,
1834
+ DataType :: UInt32 => dict_from_values :: < UInt32Type > ( values) ?,
1835
+ DataType :: UInt64 => dict_from_values :: < UInt64Type > ( values) ?,
1836
+ _ => unreachable ! ( "Invalid dictionary keys type: {:?}" , key_type) ,
1837
+ }
1838
+ }
1781
1839
DataType :: FixedSizeBinary ( size) => {
1782
1840
let array = scalars
1783
1841
. map ( |sv| {
@@ -1806,18 +1864,15 @@ impl ScalarValue {
1806
1864
| DataType :: Time32 ( TimeUnit :: Nanosecond )
1807
1865
| DataType :: Time64 ( TimeUnit :: Second )
1808
1866
| DataType :: Time64 ( TimeUnit :: Millisecond )
1867
+ | DataType :: Map ( _, _)
1809
1868
| DataType :: RunEndEncoded ( _, _)
1810
- | DataType :: ListView ( _)
1811
- | DataType :: LargeBinary
1812
- | DataType :: BinaryView
1813
- | DataType :: LargeUtf8
1814
1869
| DataType :: Utf8View
1815
- | DataType :: Dictionary ( _, _)
1870
+ | DataType :: BinaryView
1871
+ | DataType :: ListView ( _)
1816
1872
| DataType :: LargeListView ( _) => {
1817
1873
return _internal_err ! (
1818
- "Unsupported creation of {:?} array from ScalarValue {:?}" ,
1819
- data_type,
1820
- scalars. peek( )
1874
+ "Unsupported creation of {:?} array" ,
1875
+ data_type
1821
1876
) ;
1822
1877
}
1823
1878
} ;
@@ -1940,7 +1995,7 @@ impl ScalarValue {
1940
1995
let values = if values. is_empty ( ) {
1941
1996
new_empty_array ( data_type)
1942
1997
} else {
1943
- Self :: iter_to_array ( values. iter ( ) . cloned ( ) ) . unwrap ( )
1998
+ Self :: iter_to_array_of_type ( values. to_vec ( ) , data_type ) . unwrap ( )
1944
1999
} ;
1945
2000
Arc :: new ( array_into_list_array ( values, nullable) )
1946
2001
}
@@ -2931,6 +2986,11 @@ impl ScalarValue {
2931
2986
. map ( |sv| sv. size ( ) - std:: mem:: size_of_val ( sv) )
2932
2987
. sum :: < usize > ( )
2933
2988
}
2989
+
2990
+ pub fn supported_datatype ( data_type : & DataType ) -> Result < DataType , DataFusionError > {
2991
+ let scalar = Self :: try_from ( data_type) ?;
2992
+ Ok ( scalar. data_type ( ) )
2993
+ }
2934
2994
}
2935
2995
2936
2996
macro_rules! impl_scalar {
@@ -5456,22 +5516,23 @@ mod tests {
5456
5516
5457
5517
check_scalar_cast ( ScalarValue :: Float64 ( None ) , DataType :: Int16 ) ;
5458
5518
5459
- check_scalar_cast (
5460
- ScalarValue :: from ( "foo" ) ,
5461
- DataType :: Dictionary ( Box :: new ( DataType :: Int32 ) , Box :: new ( DataType :: Utf8 ) ) ,
5462
- ) ;
5463
-
5464
- check_scalar_cast (
5465
- ScalarValue :: Utf8 ( None ) ,
5466
- DataType :: Dictionary ( Box :: new ( DataType :: Int32 ) , Box :: new ( DataType :: Utf8 ) ) ,
5467
- ) ;
5468
-
5469
- check_scalar_cast ( ScalarValue :: Utf8 ( None ) , DataType :: Utf8View ) ;
5470
- check_scalar_cast ( ScalarValue :: from ( "foo" ) , DataType :: Utf8View ) ;
5471
- check_scalar_cast (
5472
- ScalarValue :: from ( "larger than 12 bytes string" ) ,
5473
- DataType :: Utf8View ,
5474
- ) ;
5519
+ // TODO(@notfilippo): this tests fails but it should check if logically equal
5520
+ // check_scalar_cast(
5521
+ // ScalarValue::from("foo"),
5522
+ // DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
5523
+ // );
5524
+ //
5525
+ // check_scalar_cast(
5526
+ // ScalarValue::Utf8(None),
5527
+ // DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
5528
+ // );
5529
+ //
5530
+ // check_scalar_cast(ScalarValue::Utf8(None), DataType::Utf8View);
5531
+ // check_scalar_cast(ScalarValue::from("foo"), DataType::Utf8View);
5532
+ // check_scalar_cast(
5533
+ // ScalarValue::from("larger than 12 bytes string"),
5534
+ // DataType::Utf8View,
5535
+ // );
5475
5536
}
5476
5537
5477
5538
// mimics how casting work on scalar values by `casting` `scalar` to `desired_type`
0 commit comments