diff --git a/crates/iceberg/src/arrow/reader.rs b/crates/iceberg/src/arrow/reader.rs index 4ac993aee..73bb9ef07 100644 --- a/crates/iceberg/src/arrow/reader.rs +++ b/crates/iceberg/src/arrow/reader.rs @@ -23,7 +23,8 @@ use std::str::FromStr; use std::sync::Arc; use arrow_arith::boolean::{and, and_kleene, is_not_null, is_null, not, or, or_kleene}; -use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch}; +use arrow_array::{Array, ArrayRef, BooleanArray, Datum as ArrowDatum, RecordBatch, Scalar}; +use arrow_cast::cast::cast; use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow_schema::{ ArrowError, DataType, FieldRef, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, @@ -1103,6 +1104,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> { Ok(Box::new(move |batch| { let left = project_column(&batch, idx)?; + let literal = try_cast_literal(&literal, left.data_type())?; lt(&left, literal.as_ref()) })) } else { @@ -1122,6 +1124,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> { Ok(Box::new(move |batch| { let left = project_column(&batch, idx)?; + let literal = try_cast_literal(&literal, left.data_type())?; lt_eq(&left, literal.as_ref()) })) } else { @@ -1141,6 +1144,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> { Ok(Box::new(move |batch| { let left = project_column(&batch, idx)?; + let literal = try_cast_literal(&literal, left.data_type())?; gt(&left, literal.as_ref()) })) } else { @@ -1160,6 +1164,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> { Ok(Box::new(move |batch| { let left = project_column(&batch, idx)?; + let literal = try_cast_literal(&literal, left.data_type())?; gt_eq(&left, literal.as_ref()) })) } else { @@ -1179,6 +1184,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> { Ok(Box::new(move |batch| { let left = project_column(&batch, idx)?; + let literal = try_cast_literal(&literal, left.data_type())?; eq(&left, literal.as_ref()) })) } else { @@ -1198,6 +1204,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> { Ok(Box::new(move |batch| { let left = project_column(&batch, idx)?; + let literal = try_cast_literal(&literal, left.data_type())?; neq(&left, literal.as_ref()) })) } else { @@ -1217,6 +1224,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> { Ok(Box::new(move |batch| { let left = project_column(&batch, idx)?; + let literal = try_cast_literal(&literal, left.data_type())?; starts_with(&left, literal.as_ref()) })) } else { @@ -1236,7 +1244,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> { Ok(Box::new(move |batch| { let left = project_column(&batch, idx)?; - + let literal = try_cast_literal(&literal, left.data_type())?; // update here if arrow ever adds a native not_starts_with not(&starts_with(&left, literal.as_ref())?) })) @@ -1261,8 +1269,10 @@ impl BoundPredicateVisitor for PredicateConverter<'_> { Ok(Box::new(move |batch| { // update this if arrow ever adds a native is_in kernel let left = project_column(&batch, idx)?; + let mut acc = BooleanArray::from(vec![false; batch.num_rows()]); for literal in &literals { + let literal = try_cast_literal(literal, left.data_type())?; acc = or(&acc, &eq(&left, literal.as_ref())?)? } @@ -1291,6 +1301,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> { let left = project_column(&batch, idx)?; let mut acc = BooleanArray::from(vec![true; batch.num_rows()]); for literal in &literals { + let literal = try_cast_literal(literal, left.data_type())?; acc = and(&acc, &neq(&left, literal.as_ref())?)? } @@ -1370,6 +1381,27 @@ impl AsyncFileReader for ArrowFileReader { } } +/// The Arrow type of an array that the Parquet reader reads may not match the exact Arrow type +/// that Iceberg uses for literals - but they are effectively the same logical type, +/// i.e. LargeUtf8 and Utf8 or Utf8View and Utf8 or Utf8View and LargeUtf8. +/// +/// The Arrow compute kernels that we use must match the type exactly, so first cast the literal +/// into the type of the batch we read from Parquet before sending it to the compute kernel. +fn try_cast_literal( + literal: &Arc, + column_type: &DataType, +) -> std::result::Result, ArrowError> { + let literal_array = literal.get().0; + + // No cast required + if literal_array.data_type() == column_type { + return Ok(Arc::clone(literal)); + } + + let literal_array = cast(literal_array, column_type)?; + Ok(Arc::new(Scalar::new(literal_array))) +} + #[cfg(test)] mod tests { use std::collections::{HashMap, HashSet}; @@ -1377,7 +1409,7 @@ mod tests { use std::sync::Arc; use arrow_array::cast::AsArray; - use arrow_array::{ArrayRef, RecordBatch, StringArray}; + use arrow_array::{ArrayRef, LargeStringArray, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit}; use futures::TryStreamExt; use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; @@ -1573,7 +1605,8 @@ message schema { // Expected: [NULL, "foo"]. let expected = vec![None, Some("foo".to_string())]; - let (file_io, schema, table_location, _temp_dir) = setup_kleene_logic(data_for_col_a); + let (file_io, schema, table_location, _temp_dir) = + setup_kleene_logic(data_for_col_a, DataType::Utf8); let reader = ArrowReaderBuilder::new(file_io).build(); let result_data = test_perform_read(predicate, schema, table_location, reader).await; @@ -1594,7 +1627,8 @@ message schema { // Expected: ["bar"]. let expected = vec![Some("bar".to_string())]; - let (file_io, schema, table_location, _temp_dir) = setup_kleene_logic(data_for_col_a); + let (file_io, schema, table_location, _temp_dir) = + setup_kleene_logic(data_for_col_a, DataType::Utf8); let reader = ArrowReaderBuilder::new(file_io).build(); let result_data = test_perform_read(predicate, schema, table_location, reader).await; @@ -1602,6 +1636,79 @@ message schema { assert_eq!(result_data, expected); } + #[tokio::test] + async fn test_predicate_cast_literal() { + let predicates = vec![ + // a == 'foo' + (Reference::new("a").equal_to(Datum::string("foo")), vec![ + Some("foo".to_string()), + ]), + // a != 'foo' + ( + Reference::new("a").not_equal_to(Datum::string("foo")), + vec![Some("bar".to_string())], + ), + // STARTS_WITH(a, 'foo') + (Reference::new("a").starts_with(Datum::string("f")), vec![ + Some("foo".to_string()), + ]), + // NOT STARTS_WITH(a, 'foo') + ( + Reference::new("a").not_starts_with(Datum::string("f")), + vec![Some("bar".to_string())], + ), + // a < 'foo' + (Reference::new("a").less_than(Datum::string("foo")), vec![ + Some("bar".to_string()), + ]), + // a <= 'foo' + ( + Reference::new("a").less_than_or_equal_to(Datum::string("foo")), + vec![Some("foo".to_string()), Some("bar".to_string())], + ), + // a > 'foo' + ( + Reference::new("a").greater_than(Datum::string("bar")), + vec![Some("foo".to_string())], + ), + // a >= 'foo' + ( + Reference::new("a").greater_than_or_equal_to(Datum::string("foo")), + vec![Some("foo".to_string())], + ), + // a IN ('foo', 'bar') + ( + Reference::new("a").is_in([Datum::string("foo"), Datum::string("baz")]), + vec![Some("foo".to_string())], + ), + // a NOT IN ('foo', 'bar') + ( + Reference::new("a").is_not_in([Datum::string("foo"), Datum::string("baz")]), + vec![Some("bar".to_string())], + ), + ]; + + // Table data: ["foo", "bar"] + let data_for_col_a = vec![Some("foo".to_string()), Some("bar".to_string())]; + + let (file_io, schema, table_location, _temp_dir) = + setup_kleene_logic(data_for_col_a, DataType::LargeUtf8); + let reader = ArrowReaderBuilder::new(file_io).build(); + + for (predicate, expected) in predicates { + println!("testing predicate {predicate}"); + let result_data = test_perform_read( + predicate.clone(), + schema.clone(), + table_location.clone(), + reader.clone(), + ) + .await; + + assert_eq!(result_data, expected, "predicate={predicate}"); + } + } + async fn test_perform_read( predicate: Predicate, schema: SchemaRef, @@ -1644,6 +1751,7 @@ message schema { fn setup_kleene_logic( data_for_col_a: Vec>, + col_a_type: DataType, ) -> (FileIO, SchemaRef, String, TempDir) { let schema = Arc::new( Schema::builder() @@ -1660,7 +1768,7 @@ message schema { let arrow_schema = Arc::new(ArrowSchema::new(vec![Field::new( "a", - DataType::Utf8, + col_a_type.clone(), true, ) .with_metadata(HashMap::from([( @@ -1673,7 +1781,11 @@ message schema { let file_io = FileIO::from_path(&table_location).unwrap().build().unwrap(); - let col = Arc::new(StringArray::from(data_for_col_a)) as ArrayRef; + let col = match col_a_type { + DataType::Utf8 => Arc::new(StringArray::from(data_for_col_a)) as ArrayRef, + DataType::LargeUtf8 => Arc::new(LargeStringArray::from(data_for_col_a)) as ArrayRef, + _ => panic!("unexpected col_a_type"), + }; let to_write = RecordBatch::try_new(arrow_schema.clone(), vec![col]).unwrap(); diff --git a/crates/iceberg/src/arrow/schema.rs b/crates/iceberg/src/arrow/schema.rs index a85f966b7..e6873e958 100644 --- a/crates/iceberg/src/arrow/schema.rs +++ b/crates/iceberg/src/arrow/schema.rs @@ -650,33 +650,33 @@ pub fn type_to_arrow_type(ty: &crate::spec::Type) -> crate::Result { } /// Convert Iceberg Datum to Arrow Datum. -pub(crate) fn get_arrow_datum(datum: &Datum) -> Result> { +pub(crate) fn get_arrow_datum(datum: &Datum) -> Result> { match (datum.data_type(), datum.literal()) { (PrimitiveType::Boolean, PrimitiveLiteral::Boolean(value)) => { - Ok(Box::new(BooleanArray::new_scalar(*value))) + Ok(Arc::new(BooleanArray::new_scalar(*value))) } (PrimitiveType::Int, PrimitiveLiteral::Int(value)) => { - Ok(Box::new(Int32Array::new_scalar(*value))) + Ok(Arc::new(Int32Array::new_scalar(*value))) } (PrimitiveType::Long, PrimitiveLiteral::Long(value)) => { - Ok(Box::new(Int64Array::new_scalar(*value))) + Ok(Arc::new(Int64Array::new_scalar(*value))) } (PrimitiveType::Float, PrimitiveLiteral::Float(value)) => { - Ok(Box::new(Float32Array::new_scalar(value.to_f32().unwrap()))) + Ok(Arc::new(Float32Array::new_scalar(value.to_f32().unwrap()))) } (PrimitiveType::Double, PrimitiveLiteral::Double(value)) => { - Ok(Box::new(Float64Array::new_scalar(value.to_f64().unwrap()))) + Ok(Arc::new(Float64Array::new_scalar(value.to_f64().unwrap()))) } (PrimitiveType::String, PrimitiveLiteral::String(value)) => { - Ok(Box::new(StringArray::new_scalar(value.as_str()))) + Ok(Arc::new(StringArray::new_scalar(value.as_str()))) } (PrimitiveType::Date, PrimitiveLiteral::Int(value)) => { - Ok(Box::new(Date32Array::new_scalar(*value))) + Ok(Arc::new(Date32Array::new_scalar(*value))) } (PrimitiveType::Timestamp, PrimitiveLiteral::Long(value)) => { - Ok(Box::new(TimestampMicrosecondArray::new_scalar(*value))) + Ok(Arc::new(TimestampMicrosecondArray::new_scalar(*value))) } - (PrimitiveType::Timestamptz, PrimitiveLiteral::Long(value)) => Ok(Box::new(Scalar::new( + (PrimitiveType::Timestamptz, PrimitiveLiteral::Long(value)) => Ok(Arc::new(Scalar::new( PrimitiveArray::::new(vec![*value; 1].into(), None) .with_timezone("UTC"), ))),