Skip to content

Commit b81e824

Browse files
fix: predicates not matching the Arrow type of columns read from parquet files (#1308)
## Which issue does this PR close? - Closes #1307 ## What changes are included in this PR? I check the type of the literal scalar against the value we read from the parquet file and convert the literal to match the Parquet Arrow data type. ## Are these changes tested? Tested with a new unit test to cover the different cases.
1 parent 71b1307 commit b81e824

File tree

2 files changed

+129
-17
lines changed

2 files changed

+129
-17
lines changed

crates/iceberg/src/arrow/reader.rs

Lines changed: 119 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ use std::str::FromStr;
2323
use std::sync::Arc;
2424

2525
use arrow_arith::boolean::{and, and_kleene, is_not_null, is_null, not, or, or_kleene};
26-
use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
26+
use arrow_array::{Array, ArrayRef, BooleanArray, Datum as ArrowDatum, RecordBatch, Scalar};
27+
use arrow_cast::cast::cast;
2728
use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
2829
use arrow_schema::{
2930
ArrowError, DataType, FieldRef, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef,
@@ -1103,6 +1104,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11031104

11041105
Ok(Box::new(move |batch| {
11051106
let left = project_column(&batch, idx)?;
1107+
let literal = try_cast_literal(&literal, left.data_type())?;
11061108
lt(&left, literal.as_ref())
11071109
}))
11081110
} else {
@@ -1122,6 +1124,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11221124

11231125
Ok(Box::new(move |batch| {
11241126
let left = project_column(&batch, idx)?;
1127+
let literal = try_cast_literal(&literal, left.data_type())?;
11251128
lt_eq(&left, literal.as_ref())
11261129
}))
11271130
} else {
@@ -1141,6 +1144,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11411144

11421145
Ok(Box::new(move |batch| {
11431146
let left = project_column(&batch, idx)?;
1147+
let literal = try_cast_literal(&literal, left.data_type())?;
11441148
gt(&left, literal.as_ref())
11451149
}))
11461150
} else {
@@ -1160,6 +1164,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11601164

11611165
Ok(Box::new(move |batch| {
11621166
let left = project_column(&batch, idx)?;
1167+
let literal = try_cast_literal(&literal, left.data_type())?;
11631168
gt_eq(&left, literal.as_ref())
11641169
}))
11651170
} else {
@@ -1179,6 +1184,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11791184

11801185
Ok(Box::new(move |batch| {
11811186
let left = project_column(&batch, idx)?;
1187+
let literal = try_cast_literal(&literal, left.data_type())?;
11821188
eq(&left, literal.as_ref())
11831189
}))
11841190
} else {
@@ -1198,6 +1204,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11981204

11991205
Ok(Box::new(move |batch| {
12001206
let left = project_column(&batch, idx)?;
1207+
let literal = try_cast_literal(&literal, left.data_type())?;
12011208
neq(&left, literal.as_ref())
12021209
}))
12031210
} else {
@@ -1217,6 +1224,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
12171224

12181225
Ok(Box::new(move |batch| {
12191226
let left = project_column(&batch, idx)?;
1227+
let literal = try_cast_literal(&literal, left.data_type())?;
12201228
starts_with(&left, literal.as_ref())
12211229
}))
12221230
} else {
@@ -1236,7 +1244,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
12361244

12371245
Ok(Box::new(move |batch| {
12381246
let left = project_column(&batch, idx)?;
1239-
1247+
let literal = try_cast_literal(&literal, left.data_type())?;
12401248
// update here if arrow ever adds a native not_starts_with
12411249
not(&starts_with(&left, literal.as_ref())?)
12421250
}))
@@ -1261,8 +1269,10 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
12611269
Ok(Box::new(move |batch| {
12621270
// update this if arrow ever adds a native is_in kernel
12631271
let left = project_column(&batch, idx)?;
1272+
12641273
let mut acc = BooleanArray::from(vec![false; batch.num_rows()]);
12651274
for literal in &literals {
1275+
let literal = try_cast_literal(literal, left.data_type())?;
12661276
acc = or(&acc, &eq(&left, literal.as_ref())?)?
12671277
}
12681278

@@ -1291,6 +1301,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
12911301
let left = project_column(&batch, idx)?;
12921302
let mut acc = BooleanArray::from(vec![true; batch.num_rows()]);
12931303
for literal in &literals {
1304+
let literal = try_cast_literal(literal, left.data_type())?;
12941305
acc = and(&acc, &neq(&left, literal.as_ref())?)?
12951306
}
12961307

@@ -1370,14 +1381,35 @@ impl<R: FileRead> AsyncFileReader for ArrowFileReader<R> {
13701381
}
13711382
}
13721383

1384+
/// The Arrow type of an array that the Parquet reader reads may not match the exact Arrow type
1385+
/// that Iceberg uses for literals - but they are effectively the same logical type,
1386+
/// e.g. LargeUtf8 and Utf8, Utf8View and Utf8, or Utf8View and LargeUtf8.
1387+
///
1388+
/// The Arrow compute kernels that we use must match the type exactly, so first cast the literal
1389+
/// into the type of the batch we read from Parquet before sending it to the compute kernel.
1390+
fn try_cast_literal(
1391+
literal: &Arc<dyn ArrowDatum + Send + Sync>,
1392+
column_type: &DataType,
1393+
) -> std::result::Result<Arc<dyn ArrowDatum + Send + Sync>, ArrowError> {
1394+
let literal_array = literal.get().0;
1395+
1396+
// No cast required
1397+
if literal_array.data_type() == column_type {
1398+
return Ok(Arc::clone(literal));
1399+
}
1400+
1401+
let literal_array = cast(literal_array, column_type)?;
1402+
Ok(Arc::new(Scalar::new(literal_array)))
1403+
}
1404+
13731405
#[cfg(test)]
13741406
mod tests {
13751407
use std::collections::{HashMap, HashSet};
13761408
use std::fs::File;
13771409
use std::sync::Arc;
13781410

13791411
use arrow_array::cast::AsArray;
1380-
use arrow_array::{ArrayRef, RecordBatch, StringArray};
1412+
use arrow_array::{ArrayRef, LargeStringArray, RecordBatch, StringArray};
13811413
use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
13821414
use futures::TryStreamExt;
13831415
use parquet::arrow::arrow_reader::{RowSelection, RowSelector};
@@ -1573,7 +1605,8 @@ message schema {
15731605
// Expected: [NULL, "foo"].
15741606
let expected = vec![None, Some("foo".to_string())];
15751607

1576-
let (file_io, schema, table_location, _temp_dir) = setup_kleene_logic(data_for_col_a);
1608+
let (file_io, schema, table_location, _temp_dir) =
1609+
setup_kleene_logic(data_for_col_a, DataType::Utf8);
15771610
let reader = ArrowReaderBuilder::new(file_io).build();
15781611

15791612
let result_data = test_perform_read(predicate, schema, table_location, reader).await;
@@ -1594,14 +1627,88 @@ message schema {
15941627
// Expected: ["bar"].
15951628
let expected = vec![Some("bar".to_string())];
15961629

1597-
let (file_io, schema, table_location, _temp_dir) = setup_kleene_logic(data_for_col_a);
1630+
let (file_io, schema, table_location, _temp_dir) =
1631+
setup_kleene_logic(data_for_col_a, DataType::Utf8);
15981632
let reader = ArrowReaderBuilder::new(file_io).build();
15991633

16001634
let result_data = test_perform_read(predicate, schema, table_location, reader).await;
16011635

16021636
assert_eq!(result_data, expected);
16031637
}
16041638

1639+
#[tokio::test]
1640+
async fn test_predicate_cast_literal() {
1641+
let predicates = vec![
1642+
// a == 'foo'
1643+
(Reference::new("a").equal_to(Datum::string("foo")), vec![
1644+
Some("foo".to_string()),
1645+
]),
1646+
// a != 'foo'
1647+
(
1648+
Reference::new("a").not_equal_to(Datum::string("foo")),
1649+
vec![Some("bar".to_string())],
1650+
),
1651+
// STARTS_WITH(a, 'f')
1652+
(Reference::new("a").starts_with(Datum::string("f")), vec![
1653+
Some("foo".to_string()),
1654+
]),
1655+
// NOT STARTS_WITH(a, 'f')
1656+
(
1657+
Reference::new("a").not_starts_with(Datum::string("f")),
1658+
vec![Some("bar".to_string())],
1659+
),
1660+
// a < 'foo'
1661+
(Reference::new("a").less_than(Datum::string("foo")), vec![
1662+
Some("bar".to_string()),
1663+
]),
1664+
// a <= 'foo'
1665+
(
1666+
Reference::new("a").less_than_or_equal_to(Datum::string("foo")),
1667+
vec![Some("foo".to_string()), Some("bar".to_string())],
1668+
),
1669+
// a > 'bar'
1670+
(
1671+
Reference::new("a").greater_than(Datum::string("bar")),
1672+
vec![Some("foo".to_string())],
1673+
),
1674+
// a >= 'foo'
1675+
(
1676+
Reference::new("a").greater_than_or_equal_to(Datum::string("foo")),
1677+
vec![Some("foo".to_string())],
1678+
),
1679+
// a IN ('foo', 'baz')
1680+
(
1681+
Reference::new("a").is_in([Datum::string("foo"), Datum::string("baz")]),
1682+
vec![Some("foo".to_string())],
1683+
),
1684+
// a NOT IN ('foo', 'baz')
1685+
(
1686+
Reference::new("a").is_not_in([Datum::string("foo"), Datum::string("baz")]),
1687+
vec![Some("bar".to_string())],
1688+
),
1689+
];
1690+
1691+
// Table data: ["foo", "bar"]
1692+
let data_for_col_a = vec![Some("foo".to_string()), Some("bar".to_string())];
1693+
1694+
let (file_io, schema, table_location, _temp_dir) =
1695+
setup_kleene_logic(data_for_col_a, DataType::LargeUtf8);
1696+
let reader = ArrowReaderBuilder::new(file_io).build();
1697+
1698+
for (predicate, expected) in predicates {
1699+
println!("testing predicate {predicate}");
1700+
let result_data = test_perform_read(
1701+
predicate.clone(),
1702+
schema.clone(),
1703+
table_location.clone(),
1704+
reader.clone(),
1705+
)
1706+
.await;
1707+
1708+
assert_eq!(result_data, expected, "predicate={predicate}");
1709+
}
1710+
}
1711+
16051712
async fn test_perform_read(
16061713
predicate: Predicate,
16071714
schema: SchemaRef,
@@ -1644,6 +1751,7 @@ message schema {
16441751

16451752
fn setup_kleene_logic(
16461753
data_for_col_a: Vec<Option<String>>,
1754+
col_a_type: DataType,
16471755
) -> (FileIO, SchemaRef, String, TempDir) {
16481756
let schema = Arc::new(
16491757
Schema::builder()
@@ -1660,7 +1768,7 @@ message schema {
16601768

16611769
let arrow_schema = Arc::new(ArrowSchema::new(vec![Field::new(
16621770
"a",
1663-
DataType::Utf8,
1771+
col_a_type.clone(),
16641772
true,
16651773
)
16661774
.with_metadata(HashMap::from([(
@@ -1673,7 +1781,11 @@ message schema {
16731781

16741782
let file_io = FileIO::from_path(&table_location).unwrap().build().unwrap();
16751783

1676-
let col = Arc::new(StringArray::from(data_for_col_a)) as ArrayRef;
1784+
let col = match col_a_type {
1785+
DataType::Utf8 => Arc::new(StringArray::from(data_for_col_a)) as ArrayRef,
1786+
DataType::LargeUtf8 => Arc::new(LargeStringArray::from(data_for_col_a)) as ArrayRef,
1787+
_ => panic!("unexpected col_a_type"),
1788+
};
16771789

16781790
let to_write = RecordBatch::try_new(arrow_schema.clone(), vec![col]).unwrap();
16791791

crates/iceberg/src/arrow/schema.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -650,33 +650,33 @@ pub fn type_to_arrow_type(ty: &crate::spec::Type) -> crate::Result<DataType> {
650650
}
651651

652652
/// Convert Iceberg Datum to Arrow Datum.
653-
pub(crate) fn get_arrow_datum(datum: &Datum) -> Result<Box<dyn ArrowDatum + Send>> {
653+
pub(crate) fn get_arrow_datum(datum: &Datum) -> Result<Arc<dyn ArrowDatum + Send + Sync>> {
654654
match (datum.data_type(), datum.literal()) {
655655
(PrimitiveType::Boolean, PrimitiveLiteral::Boolean(value)) => {
656-
Ok(Box::new(BooleanArray::new_scalar(*value)))
656+
Ok(Arc::new(BooleanArray::new_scalar(*value)))
657657
}
658658
(PrimitiveType::Int, PrimitiveLiteral::Int(value)) => {
659-
Ok(Box::new(Int32Array::new_scalar(*value)))
659+
Ok(Arc::new(Int32Array::new_scalar(*value)))
660660
}
661661
(PrimitiveType::Long, PrimitiveLiteral::Long(value)) => {
662-
Ok(Box::new(Int64Array::new_scalar(*value)))
662+
Ok(Arc::new(Int64Array::new_scalar(*value)))
663663
}
664664
(PrimitiveType::Float, PrimitiveLiteral::Float(value)) => {
665-
Ok(Box::new(Float32Array::new_scalar(value.to_f32().unwrap())))
665+
Ok(Arc::new(Float32Array::new_scalar(value.to_f32().unwrap())))
666666
}
667667
(PrimitiveType::Double, PrimitiveLiteral::Double(value)) => {
668-
Ok(Box::new(Float64Array::new_scalar(value.to_f64().unwrap())))
668+
Ok(Arc::new(Float64Array::new_scalar(value.to_f64().unwrap())))
669669
}
670670
(PrimitiveType::String, PrimitiveLiteral::String(value)) => {
671-
Ok(Box::new(StringArray::new_scalar(value.as_str())))
671+
Ok(Arc::new(StringArray::new_scalar(value.as_str())))
672672
}
673673
(PrimitiveType::Date, PrimitiveLiteral::Int(value)) => {
674-
Ok(Box::new(Date32Array::new_scalar(*value)))
674+
Ok(Arc::new(Date32Array::new_scalar(*value)))
675675
}
676676
(PrimitiveType::Timestamp, PrimitiveLiteral::Long(value)) => {
677-
Ok(Box::new(TimestampMicrosecondArray::new_scalar(*value)))
677+
Ok(Arc::new(TimestampMicrosecondArray::new_scalar(*value)))
678678
}
679-
(PrimitiveType::Timestamptz, PrimitiveLiteral::Long(value)) => Ok(Box::new(Scalar::new(
679+
(PrimitiveType::Timestamptz, PrimitiveLiteral::Long(value)) => Ok(Arc::new(Scalar::new(
680680
PrimitiveArray::<TimestampMicrosecondType>::new(vec![*value; 1].into(), None)
681681
.with_timezone("UTC"),
682682
))),

0 commit comments

Comments
 (0)