@@ -23,7 +23,8 @@ use std::str::FromStr;
 use std::sync::Arc;
 
 use arrow_arith::boolean::{and, and_kleene, is_not_null, is_null, not, or, or_kleene};
-use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
+use arrow_array::{Array, ArrayRef, BooleanArray, Datum as ArrowDatum, RecordBatch, Scalar};
+use arrow_cast::cast::cast;
 use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
 use arrow_schema::{
     ArrowError, DataType, FieldRef, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef,
@@ -1103,6 +1104,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
 
             Ok(Box::new(move |batch| {
                 let left = project_column(&batch, idx)?;
+                let literal = try_cast_literal(&literal, left.data_type())?;
                 lt(&left, literal.as_ref())
             }))
         } else {
@@ -1122,6 +1124,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
 
             Ok(Box::new(move |batch| {
                 let left = project_column(&batch, idx)?;
+                let literal = try_cast_literal(&literal, left.data_type())?;
                 lt_eq(&left, literal.as_ref())
             }))
         } else {
@@ -1141,6 +1144,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
 
             Ok(Box::new(move |batch| {
                 let left = project_column(&batch, idx)?;
+                let literal = try_cast_literal(&literal, left.data_type())?;
                 gt(&left, literal.as_ref())
             }))
         } else {
@@ -1160,6 +1164,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
 
             Ok(Box::new(move |batch| {
                 let left = project_column(&batch, idx)?;
+                let literal = try_cast_literal(&literal, left.data_type())?;
                 gt_eq(&left, literal.as_ref())
             }))
         } else {
@@ -1179,6 +1184,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
 
             Ok(Box::new(move |batch| {
                 let left = project_column(&batch, idx)?;
+                let literal = try_cast_literal(&literal, left.data_type())?;
                 eq(&left, literal.as_ref())
             }))
         } else {
@@ -1198,6 +1204,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
 
             Ok(Box::new(move |batch| {
                 let left = project_column(&batch, idx)?;
+                let literal = try_cast_literal(&literal, left.data_type())?;
                 neq(&left, literal.as_ref())
             }))
         } else {
@@ -1217,6 +1224,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
 
             Ok(Box::new(move |batch| {
                 let left = project_column(&batch, idx)?;
+                let literal = try_cast_literal(&literal, left.data_type())?;
                 starts_with(&left, literal.as_ref())
             }))
         } else {
@@ -1236,7 +1244,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
 
             Ok(Box::new(move |batch| {
                 let left = project_column(&batch, idx)?;
-
+                let literal = try_cast_literal(&literal, left.data_type())?;
                 // update here if arrow ever adds a native not_starts_with
                 not(&starts_with(&left, literal.as_ref())?)
             }))
@@ -1261,8 +1269,10 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
             Ok(Box::new(move |batch| {
                 // update this if arrow ever adds a native is_in kernel
                 let left = project_column(&batch, idx)?;
+
                 let mut acc = BooleanArray::from(vec![false; batch.num_rows()]);
                 for literal in &literals {
+                    let literal = try_cast_literal(literal, left.data_type())?;
                     acc = or(&acc, &eq(&left, literal.as_ref())?)?
                 }
 
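As the comment in the hunk above notes, arrow has no native `is_in` kernel, so the closure emulates `IN` by OR-ing one equality mask per literal (and `NOT IN` does the same with `and`/`neq`). Below is a minimal standalone sketch of that emulation, not part of the diff; the column values and literals are illustrative, and it only assumes the `arrow_array`, `arrow_arith`, and `arrow_ord` crates already imported above.

```rust
use arrow_arith::boolean::or;
use arrow_array::{Array, BooleanArray, Scalar, StringArray};
use arrow_ord::cmp::eq;
use arrow_schema::ArrowError;

fn main() -> Result<(), ArrowError> {
    // Column to filter and the IN-list literals (illustrative values).
    let col = StringArray::from(vec!["foo", "bar", "baz"]);
    let literals = vec![
        Scalar::new(StringArray::from(vec!["foo"])),
        Scalar::new(StringArray::from(vec!["baz"])),
    ];

    // `col IN (...)`: start from all-false and OR in one equality mask per literal.
    let mut acc = BooleanArray::from(vec![false; col.len()]);
    for literal in &literals {
        acc = or(&acc, &eq(&col, literal)?)?;
    }

    assert_eq!(acc, BooleanArray::from(vec![true, false, true]));
    Ok(())
}
```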
@@ -1291,6 +1301,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
                 let left = project_column(&batch, idx)?;
                 let mut acc = BooleanArray::from(vec![true; batch.num_rows()]);
                 for literal in &literals {
+                    let literal = try_cast_literal(literal, left.data_type())?;
                     acc = and(&acc, &neq(&left, literal.as_ref())?)?
                 }
 
@@ -1370,14 +1381,35 @@ impl<R: FileRead> AsyncFileReader for ArrowFileReader<R> {
     }
 }
 
+/// The Arrow type of an array that the Parquet reader reads may not match the exact Arrow type
+/// that Iceberg uses for literals - but they are effectively the same logical type,
+/// i.e. LargeUtf8 and Utf8 or Utf8View and Utf8 or Utf8View and LargeUtf8.
+///
+/// The Arrow compute kernels that we use must match the type exactly, so first cast the literal
+/// into the type of the batch we read from Parquet before sending it to the compute kernel.
+fn try_cast_literal(
+    literal: &Arc<dyn ArrowDatum + Send + Sync>,
+    column_type: &DataType,
+) -> std::result::Result<Arc<dyn ArrowDatum + Send + Sync>, ArrowError> {
+    let literal_array = literal.get().0;
+
+    // No cast required
+    if literal_array.data_type() == column_type {
+        return Ok(Arc::clone(literal));
+    }
+
+    let literal_array = cast(literal_array, column_type)?;
+    Ok(Arc::new(Scalar::new(literal_array)))
+}
+
 #[cfg(test)]
 mod tests {
     use std::collections::{HashMap, HashSet};
     use std::fs::File;
     use std::sync::Arc;
 
     use arrow_array::cast::AsArray;
-    use arrow_array::{ArrayRef, RecordBatch, StringArray};
+    use arrow_array::{ArrayRef, LargeStringArray, RecordBatch, StringArray};
     use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
     use futures::TryStreamExt;
     use parquet::arrow::arrow_reader::{RowSelection, RowSelector};
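To illustrate the mismatch that the new `try_cast_literal` doc comment describes, here is a small self-contained sketch, not part of the diff: the Parquet reader can hand back a `LargeUtf8` column while the Iceberg literal is materialized as a `Utf8` scalar, and arrow's comparison kernels reject the pair until the literal's backing array is cast to the column type and re-wrapped in a `Scalar` — the same steps the new helper performs. Values and the choice of `LargeUtf8` vs `Utf8` are illustrative.

```rust
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, BooleanArray, Datum, LargeStringArray, Scalar, StringArray};
use arrow_cast::cast::cast;
use arrow_ord::cmp::eq;
use arrow_schema::ArrowError;

fn main() -> Result<(), ArrowError> {
    // Column as it comes back from the Parquet reader: LargeUtf8.
    let column: ArrayRef = Arc::new(LargeStringArray::from(vec!["foo", "bar"]));

    // Literal built for the Iceberg predicate: a Utf8 scalar.
    let literal = Scalar::new(StringArray::from(vec!["foo"]));

    // Comparing Utf8 against LargeUtf8 directly fails with a type mismatch,
    // so cast the literal's backing array to the column's type first.
    let (literal_array, _is_scalar) = literal.get();
    let casted: ArrayRef = cast(literal_array, column.data_type())?;
    let casted_literal = Scalar::new(casted);

    // With matching types the kernel evaluates the predicate normally.
    let mask = eq(&column, &casted_literal)?;
    assert_eq!(mask, BooleanArray::from(vec![true, false]));
    Ok(())
}
```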
@@ -1573,7 +1605,8 @@ message schema {
         // Expected: [NULL, "foo"].
         let expected = vec![None, Some("foo".to_string())];
 
-        let (file_io, schema, table_location, _temp_dir) = setup_kleene_logic(data_for_col_a);
+        let (file_io, schema, table_location, _temp_dir) =
+            setup_kleene_logic(data_for_col_a, DataType::Utf8);
         let reader = ArrowReaderBuilder::new(file_io).build();
 
         let result_data = test_perform_read(predicate, schema, table_location, reader).await;
@@ -1594,14 +1627,88 @@ message schema {
         // Expected: ["bar"].
         let expected = vec![Some("bar".to_string())];
 
-        let (file_io, schema, table_location, _temp_dir) = setup_kleene_logic(data_for_col_a);
+        let (file_io, schema, table_location, _temp_dir) =
+            setup_kleene_logic(data_for_col_a, DataType::Utf8);
         let reader = ArrowReaderBuilder::new(file_io).build();
 
         let result_data = test_perform_read(predicate, schema, table_location, reader).await;
 
         assert_eq!(result_data, expected);
     }
 
+    #[tokio::test]
+    async fn test_predicate_cast_literal() {
+        let predicates = vec![
+            // a == 'foo'
+            (Reference::new("a").equal_to(Datum::string("foo")), vec![
+                Some("foo".to_string()),
+            ]),
+            // a != 'foo'
+            (
+                Reference::new("a").not_equal_to(Datum::string("foo")),
+                vec![Some("bar".to_string())],
+            ),
+            // STARTS_WITH(a, 'foo')
+            (Reference::new("a").starts_with(Datum::string("f")), vec![
+                Some("foo".to_string()),
+            ]),
+            // NOT STARTS_WITH(a, 'foo')
+            (
+                Reference::new("a").not_starts_with(Datum::string("f")),
+                vec![Some("bar".to_string())],
+            ),
+            // a < 'foo'
+            (Reference::new("a").less_than(Datum::string("foo")), vec![
+                Some("bar".to_string()),
+            ]),
+            // a <= 'foo'
+            (
+                Reference::new("a").less_than_or_equal_to(Datum::string("foo")),
+                vec![Some("foo".to_string()), Some("bar".to_string())],
+            ),
+            // a > 'foo'
+            (
+                Reference::new("a").greater_than(Datum::string("bar")),
+                vec![Some("foo".to_string())],
+            ),
+            // a >= 'foo'
+            (
+                Reference::new("a").greater_than_or_equal_to(Datum::string("foo")),
+                vec![Some("foo".to_string())],
+            ),
+            // a IN ('foo', 'bar')
+            (
+                Reference::new("a").is_in([Datum::string("foo"), Datum::string("baz")]),
+                vec![Some("foo".to_string())],
+            ),
+            // a NOT IN ('foo', 'bar')
+            (
+                Reference::new("a").is_not_in([Datum::string("foo"), Datum::string("baz")]),
+                vec![Some("bar".to_string())],
+            ),
+        ];
+
+        // Table data: ["foo", "bar"]
+        let data_for_col_a = vec![Some("foo".to_string()), Some("bar".to_string())];
+
+        let (file_io, schema, table_location, _temp_dir) =
+            setup_kleene_logic(data_for_col_a, DataType::LargeUtf8);
+        let reader = ArrowReaderBuilder::new(file_io).build();
+
+        for (predicate, expected) in predicates {
+            println!("testing predicate {predicate}");
+            let result_data = test_perform_read(
+                predicate.clone(),
+                schema.clone(),
+                table_location.clone(),
+                reader.clone(),
+            )
+            .await;
+
+            assert_eq!(result_data, expected, "predicate={predicate}");
+        }
+    }
+
     async fn test_perform_read(
         predicate: Predicate,
         schema: SchemaRef,
@@ -1644,6 +1751,7 @@ message schema {
 
     fn setup_kleene_logic(
         data_for_col_a: Vec<Option<String>>,
+        col_a_type: DataType,
     ) -> (FileIO, SchemaRef, String, TempDir) {
         let schema = Arc::new(
             Schema::builder()
@@ -1660,7 +1768,7 @@ message schema {
 
         let arrow_schema = Arc::new(ArrowSchema::new(vec![Field::new(
             "a",
-            DataType::Utf8,
+            col_a_type.clone(),
             true,
         )
         .with_metadata(HashMap::from([(
@@ -1673,7 +1781,11 @@ message schema {
 
         let file_io = FileIO::from_path(&table_location).unwrap().build().unwrap();
 
-        let col = Arc::new(StringArray::from(data_for_col_a)) as ArrayRef;
+        let col = match col_a_type {
+            DataType::Utf8 => Arc::new(StringArray::from(data_for_col_a)) as ArrayRef,
+            DataType::LargeUtf8 => Arc::new(LargeStringArray::from(data_for_col_a)) as ArrayRef,
+            _ => panic!("unexpected col_a_type"),
+        };
 
         let to_write = RecordBatch::try_new(arrow_schema.clone(), vec![col]).unwrap();
 