Skip to content

Fix predicates not matching the Arrow type of columns read from parquet files #1308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 119 additions & 7 deletions crates/iceberg/src/arrow/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ use std::str::FromStr;
use std::sync::Arc;

use arrow_arith::boolean::{and, and_kleene, is_not_null, is_null, not, or, or_kleene};
use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
use arrow_array::{Array, ArrayRef, BooleanArray, Datum as ArrowDatum, RecordBatch, Scalar};
use arrow_cast::cast::cast;
use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
use arrow_schema::{
ArrowError, DataType, FieldRef, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef,
Expand Down Expand Up @@ -1103,6 +1104,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {

Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let literal = try_cast_literal(&literal, left.data_type())?;
lt(&left, literal.as_ref())
}))
} else {
Expand All @@ -1122,6 +1124,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {

Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let literal = try_cast_literal(&literal, left.data_type())?;
lt_eq(&left, literal.as_ref())
}))
} else {
Expand All @@ -1141,6 +1144,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {

Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let literal = try_cast_literal(&literal, left.data_type())?;
gt(&left, literal.as_ref())
}))
} else {
Expand All @@ -1160,6 +1164,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {

Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let literal = try_cast_literal(&literal, left.data_type())?;
gt_eq(&left, literal.as_ref())
}))
} else {
Expand All @@ -1179,6 +1184,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {

Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let literal = try_cast_literal(&literal, left.data_type())?;
eq(&left, literal.as_ref())
}))
} else {
Expand All @@ -1198,6 +1204,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {

Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let literal = try_cast_literal(&literal, left.data_type())?;
neq(&left, literal.as_ref())
}))
} else {
Expand All @@ -1217,6 +1224,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {

Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let literal = try_cast_literal(&literal, left.data_type())?;
starts_with(&left, literal.as_ref())
}))
} else {
Expand All @@ -1236,7 +1244,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {

Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;

let literal = try_cast_literal(&literal, left.data_type())?;
// update here if arrow ever adds a native not_starts_with
not(&starts_with(&left, literal.as_ref())?)
}))
Expand All @@ -1261,8 +1269,10 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
Ok(Box::new(move |batch| {
// update this if arrow ever adds a native is_in kernel
let left = project_column(&batch, idx)?;

let mut acc = BooleanArray::from(vec![false; batch.num_rows()]);
for literal in &literals {
let literal = try_cast_literal(literal, left.data_type())?;
acc = or(&acc, &eq(&left, literal.as_ref())?)?
}

Expand Down Expand Up @@ -1291,6 +1301,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
let left = project_column(&batch, idx)?;
let mut acc = BooleanArray::from(vec![true; batch.num_rows()]);
for literal in &literals {
let literal = try_cast_literal(literal, left.data_type())?;
acc = and(&acc, &neq(&left, literal.as_ref())?)?
}

Expand Down Expand Up @@ -1370,14 +1381,35 @@ impl<R: FileRead> AsyncFileReader for ArrowFileReader<R> {
}
}

/// Align a literal's Arrow type with the Arrow type of the column it is compared against.
///
/// The Arrow type that the Parquet reader produces for a column may differ from the exact
/// Arrow type Iceberg chose for its literals while still being the same logical type —
/// e.g. `LargeUtf8` vs `Utf8`, or `Utf8View` vs either of them. The Arrow compute kernels
/// we call require an exact type match, so the literal is cast to `column_type` before it
/// is handed to a kernel.
///
/// When the types already match, the existing datum is returned via a cheap `Arc` clone.
fn try_cast_literal(
    literal: &Arc<dyn ArrowDatum + Send + Sync>,
    column_type: &DataType,
) -> std::result::Result<Arc<dyn ArrowDatum + Send + Sync>, ArrowError> {
    let (array, _is_scalar) = literal.get();
    if array.data_type() == column_type {
        // Types already agree — no cast needed, reuse the datum as-is.
        Ok(Arc::clone(literal))
    } else {
        // Cast the literal's backing array and re-wrap it as a scalar datum.
        cast(array, column_type)
            .map(|casted| Arc::new(Scalar::new(casted)) as Arc<dyn ArrowDatum + Send + Sync>)
    }
}

#[cfg(test)]
mod tests {
use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::sync::Arc;

use arrow_array::cast::AsArray;
use arrow_array::{ArrayRef, RecordBatch, StringArray};
use arrow_array::{ArrayRef, LargeStringArray, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
use futures::TryStreamExt;
use parquet::arrow::arrow_reader::{RowSelection, RowSelector};
Expand Down Expand Up @@ -1573,7 +1605,8 @@ message schema {
// Expected: [NULL, "foo"].
let expected = vec![None, Some("foo".to_string())];

let (file_io, schema, table_location, _temp_dir) = setup_kleene_logic(data_for_col_a);
let (file_io, schema, table_location, _temp_dir) =
setup_kleene_logic(data_for_col_a, DataType::Utf8);
let reader = ArrowReaderBuilder::new(file_io).build();

let result_data = test_perform_read(predicate, schema, table_location, reader).await;
Expand All @@ -1594,14 +1627,88 @@ message schema {
// Expected: ["bar"].
let expected = vec![Some("bar".to_string())];

let (file_io, schema, table_location, _temp_dir) = setup_kleene_logic(data_for_col_a);
let (file_io, schema, table_location, _temp_dir) =
setup_kleene_logic(data_for_col_a, DataType::Utf8);
let reader = ArrowReaderBuilder::new(file_io).build();

let result_data = test_perform_read(predicate, schema, table_location, reader).await;

assert_eq!(result_data, expected);
}

/// Verifies that every predicate kind still matches when the Parquet column's Arrow type
/// (`LargeUtf8` here) differs from the Arrow type of the Iceberg literals (`Utf8`),
/// i.e. that literals are cast to the column type before the compute kernels run.
#[tokio::test]
async fn test_predicate_cast_literal() {
let predicates = vec![
// a == 'foo'
(Reference::new("a").equal_to(Datum::string("foo")), vec![
Some("foo".to_string()),
]),
// a != 'foo'
(
Reference::new("a").not_equal_to(Datum::string("foo")),
vec![Some("bar".to_string())],
),
// STARTS_WITH(a, 'f')
(Reference::new("a").starts_with(Datum::string("f")), vec![
Some("foo".to_string()),
]),
// NOT STARTS_WITH(a, 'f')
(
Reference::new("a").not_starts_with(Datum::string("f")),
vec![Some("bar".to_string())],
),
// a < 'foo'
(Reference::new("a").less_than(Datum::string("foo")), vec![
Some("bar".to_string()),
]),
// a <= 'foo'
(
Reference::new("a").less_than_or_equal_to(Datum::string("foo")),
vec![Some("foo".to_string()), Some("bar".to_string())],
),
// a > 'bar'
(
Reference::new("a").greater_than(Datum::string("bar")),
vec![Some("foo".to_string())],
),
// a >= 'foo'
(
Reference::new("a").greater_than_or_equal_to(Datum::string("foo")),
vec![Some("foo".to_string())],
),
// a IN ('foo', 'baz')
(
Reference::new("a").is_in([Datum::string("foo"), Datum::string("baz")]),
vec![Some("foo".to_string())],
),
// a NOT IN ('foo', 'baz')
(
Reference::new("a").is_not_in([Datum::string("foo"), Datum::string("baz")]),
vec![Some("bar".to_string())],
),
];

// Table data: ["foo", "bar"]
let data_for_col_a = vec![Some("foo".to_string()), Some("bar".to_string())];

// Write the column as LargeUtf8 so the literal type (Utf8) requires a cast.
let (file_io, schema, table_location, _temp_dir) =
setup_kleene_logic(data_for_col_a, DataType::LargeUtf8);
let reader = ArrowReaderBuilder::new(file_io).build();

for (predicate, expected) in predicates {
// Print which predicate is under test so a failure is easy to attribute.
println!("testing predicate {predicate}");
let result_data = test_perform_read(
predicate.clone(),
schema.clone(),
table_location.clone(),
reader.clone(),
)
.await;

assert_eq!(result_data, expected, "predicate={predicate}");
}
}

async fn test_perform_read(
predicate: Predicate,
schema: SchemaRef,
Expand Down Expand Up @@ -1644,6 +1751,7 @@ message schema {

fn setup_kleene_logic(
data_for_col_a: Vec<Option<String>>,
col_a_type: DataType,
) -> (FileIO, SchemaRef, String, TempDir) {
let schema = Arc::new(
Schema::builder()
Expand All @@ -1660,7 +1768,7 @@ message schema {

let arrow_schema = Arc::new(ArrowSchema::new(vec![Field::new(
"a",
DataType::Utf8,
col_a_type.clone(),
true,
)
.with_metadata(HashMap::from([(
Expand All @@ -1673,7 +1781,11 @@ message schema {

let file_io = FileIO::from_path(&table_location).unwrap().build().unwrap();

let col = Arc::new(StringArray::from(data_for_col_a)) as ArrayRef;
let col = match col_a_type {
DataType::Utf8 => Arc::new(StringArray::from(data_for_col_a)) as ArrayRef,
DataType::LargeUtf8 => Arc::new(LargeStringArray::from(data_for_col_a)) as ArrayRef,
_ => panic!("unexpected col_a_type"),
};

let to_write = RecordBatch::try_new(arrow_schema.clone(), vec![col]).unwrap();

Expand Down
20 changes: 10 additions & 10 deletions crates/iceberg/src/arrow/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -650,33 +650,33 @@ pub fn type_to_arrow_type(ty: &crate::spec::Type) -> crate::Result<DataType> {
}

/// Convert Iceberg Datum to Arrow Datum.
pub(crate) fn get_arrow_datum(datum: &Datum) -> Result<Box<dyn ArrowDatum + Send>> {
pub(crate) fn get_arrow_datum(datum: &Datum) -> Result<Arc<dyn ArrowDatum + Send + Sync>> {
match (datum.data_type(), datum.literal()) {
(PrimitiveType::Boolean, PrimitiveLiteral::Boolean(value)) => {
Ok(Box::new(BooleanArray::new_scalar(*value)))
Ok(Arc::new(BooleanArray::new_scalar(*value)))
}
(PrimitiveType::Int, PrimitiveLiteral::Int(value)) => {
Ok(Box::new(Int32Array::new_scalar(*value)))
Ok(Arc::new(Int32Array::new_scalar(*value)))
}
(PrimitiveType::Long, PrimitiveLiteral::Long(value)) => {
Ok(Box::new(Int64Array::new_scalar(*value)))
Ok(Arc::new(Int64Array::new_scalar(*value)))
}
(PrimitiveType::Float, PrimitiveLiteral::Float(value)) => {
Ok(Box::new(Float32Array::new_scalar(value.to_f32().unwrap())))
Ok(Arc::new(Float32Array::new_scalar(value.to_f32().unwrap())))
}
(PrimitiveType::Double, PrimitiveLiteral::Double(value)) => {
Ok(Box::new(Float64Array::new_scalar(value.to_f64().unwrap())))
Ok(Arc::new(Float64Array::new_scalar(value.to_f64().unwrap())))
}
(PrimitiveType::String, PrimitiveLiteral::String(value)) => {
Ok(Box::new(StringArray::new_scalar(value.as_str())))
Ok(Arc::new(StringArray::new_scalar(value.as_str())))
}
(PrimitiveType::Date, PrimitiveLiteral::Int(value)) => {
Ok(Box::new(Date32Array::new_scalar(*value)))
Ok(Arc::new(Date32Array::new_scalar(*value)))
}
(PrimitiveType::Timestamp, PrimitiveLiteral::Long(value)) => {
Ok(Box::new(TimestampMicrosecondArray::new_scalar(*value)))
Ok(Arc::new(TimestampMicrosecondArray::new_scalar(*value)))
}
(PrimitiveType::Timestamptz, PrimitiveLiteral::Long(value)) => Ok(Box::new(Scalar::new(
(PrimitiveType::Timestamptz, PrimitiveLiteral::Long(value)) => Ok(Arc::new(Scalar::new(
PrimitiveArray::<TimestampMicrosecondType>::new(vec![*value; 1].into(), None)
.with_timezone("UTC"),
))),
Expand Down
Loading