Skip to content

Commit 7ba36b0

Browse files
Jefffreytustvold
andauthored
Parquet: read/write f16 for Arrow (#5003)
* Support for read/write f16 Parquet to Arrow * Update parquet/src/arrow/arrow_writer/mod.rs Co-authored-by: Raphael Taylor-Davies <[email protected]> * Update parquet/src/arrow/arrow_reader/mod.rs Co-authored-by: Raphael Taylor-Davies <[email protected]> * Update test with null version * Fix schema tests and parsing for f16 * f16 for record api * Handle NaN for f16 statistics writing * Revert formatting changes * Fix num trait * Fix half feature * Handle writing signed zero statistics * Bump parquet-testing and read new f16 files for test --------- Co-authored-by: Raphael Taylor-Davies <[email protected]>
1 parent 924b6e9 commit 7ba36b0

File tree

18 files changed

+646
-25
lines changed

18 files changed

+646
-25
lines changed

parquet/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ tokio = { version = "1.0", optional = true, default-features = false, features =
6666
hashbrown = { version = "0.14", default-features = false }
6767
twox-hash = { version = "1.6", default-features = false }
6868
paste = { version = "1.0" }
69+
half = { version = "2.1", default-features = false, features = ["num-traits"] }
6970

7071
[dev-dependencies]
7172
base64 = { version = "0.21", default-features = false, features = ["std"] }

parquet/regen.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# specific language governing permissions and limitations
1818
# under the License.
1919

20-
REVISION=aeae80660c1d0c97314e9da837de1abdebd49c37
20+
REVISION=46cc3a0647d301bb9579ca8dd2cc356caf2a72d2
2121

2222
SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)"
2323

parquet/src/arrow/array_reader/fixed_len_byte_array.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,14 @@ use crate::column::reader::decoder::{ColumnValueDecoder, ValuesBufferSlice};
2727
use crate::errors::{ParquetError, Result};
2828
use crate::schema::types::ColumnDescPtr;
2929
use arrow_array::{
30-
ArrayRef, Decimal128Array, Decimal256Array, FixedSizeBinaryArray,
30+
ArrayRef, Decimal128Array, Decimal256Array, FixedSizeBinaryArray, Float16Array,
3131
IntervalDayTimeArray, IntervalYearMonthArray,
3232
};
3333
use arrow_buffer::{i256, Buffer};
3434
use arrow_data::ArrayDataBuilder;
3535
use arrow_schema::{DataType as ArrowType, IntervalUnit};
3636
use bytes::Bytes;
37+
use half::f16;
3738
use std::any::Any;
3839
use std::ops::Range;
3940
use std::sync::Arc;
@@ -88,6 +89,14 @@ pub fn make_fixed_len_byte_array_reader(
8889
));
8990
}
9091
}
92+
ArrowType::Float16 => {
93+
if byte_length != 2 {
94+
return Err(general_err!(
95+
"float 16 type must be 2 bytes, got {}",
96+
byte_length
97+
));
98+
}
99+
}
91100
_ => {
92101
return Err(general_err!(
93102
"invalid data type for fixed length byte array reader - {}",
@@ -208,6 +217,12 @@ impl ArrayReader for FixedLenByteArrayReader {
208217
}
209218
}
210219
}
220+
ArrowType::Float16 => Arc::new(
221+
binary
222+
.iter()
223+
.map(|o| o.map(|b| f16::from_le_bytes(b[..2].try_into().unwrap())))
224+
.collect::<Float16Array>(),
225+
) as ArrayRef,
211226
_ => Arc::new(binary) as ArrayRef,
212227
};
213228

parquet/src/arrow/arrow_reader/mod.rs

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -712,13 +712,14 @@ mod tests {
712712
use std::sync::Arc;
713713

714714
use bytes::Bytes;
715+
use half::f16;
715716
use num::PrimInt;
716717
use rand::{thread_rng, Rng, RngCore};
717718
use tempfile::tempfile;
718719

719720
use arrow_array::builder::*;
720721
use arrow_array::cast::AsArray;
721-
use arrow_array::types::{Decimal128Type, Decimal256Type, DecimalType};
722+
use arrow_array::types::{Decimal128Type, Decimal256Type, DecimalType, Float16Type};
722723
use arrow_array::*;
723724
use arrow_array::{RecordBatch, RecordBatchReader};
724725
use arrow_buffer::{i256, ArrowNativeType, Buffer};
@@ -924,6 +925,66 @@ mod tests {
924925
.unwrap();
925926
}
926927

928+
#[test]
929+
fn test_float16_roundtrip() -> Result<()> {
930+
let schema = Arc::new(Schema::new(vec![
931+
Field::new("float16", ArrowDataType::Float16, false),
932+
Field::new("float16-nullable", ArrowDataType::Float16, true),
933+
]));
934+
935+
let mut buf = Vec::with_capacity(1024);
936+
let mut writer = ArrowWriter::try_new(&mut buf, schema.clone(), None)?;
937+
938+
let original = RecordBatch::try_new(
939+
schema,
940+
vec![
941+
Arc::new(Float16Array::from_iter_values([
942+
f16::EPSILON,
943+
f16::MIN,
944+
f16::MAX,
945+
f16::NAN,
946+
f16::INFINITY,
947+
f16::NEG_INFINITY,
948+
f16::ONE,
949+
f16::NEG_ONE,
950+
f16::ZERO,
951+
f16::NEG_ZERO,
952+
f16::E,
953+
f16::PI,
954+
f16::FRAC_1_PI,
955+
])),
956+
Arc::new(Float16Array::from(vec![
957+
None,
958+
None,
959+
None,
960+
Some(f16::NAN),
961+
Some(f16::INFINITY),
962+
Some(f16::NEG_INFINITY),
963+
None,
964+
None,
965+
None,
966+
None,
967+
None,
968+
None,
969+
Some(f16::FRAC_1_PI),
970+
])),
971+
],
972+
)?;
973+
974+
writer.write(&original)?;
975+
writer.close()?;
976+
977+
let mut reader = ParquetRecordBatchReader::try_new(Bytes::from(buf), 1024)?;
978+
let ret = reader.next().unwrap()?;
979+
assert_eq!(ret, original);
980+
981+
// Ensure can be downcast to the correct type
982+
ret.column(0).as_primitive::<Float16Type>();
983+
ret.column(1).as_primitive::<Float16Type>();
984+
985+
Ok(())
986+
}
987+
927988
struct RandFixedLenGen {}
928989

929990
impl RandGen<FixedLenByteArrayType> for RandFixedLenGen {
@@ -1255,6 +1316,62 @@ mod tests {
12551316
}
12561317
}
12571318

1319+
#[test]
1320+
fn test_read_float16_nonzeros_file() {
1321+
use arrow_array::Float16Array;
1322+
let testdata = arrow::util::test_util::parquet_test_data();
1323+
// see https://github.com/apache/parquet-testing/pull/40
1324+
let path = format!("{testdata}/float16_nonzeros_and_nans.parquet");
1325+
let file = File::open(path).unwrap();
1326+
let mut record_reader = ParquetRecordBatchReader::try_new(file, 32).unwrap();
1327+
1328+
let batch = record_reader.next().unwrap().unwrap();
1329+
assert_eq!(batch.num_rows(), 8);
1330+
let col = batch
1331+
.column(0)
1332+
.as_any()
1333+
.downcast_ref::<Float16Array>()
1334+
.unwrap();
1335+
1336+
let f16_two = f16::ONE + f16::ONE;
1337+
1338+
assert_eq!(col.null_count(), 1);
1339+
assert!(col.is_null(0));
1340+
assert_eq!(col.value(1), f16::ONE);
1341+
assert_eq!(col.value(2), -f16_two);
1342+
assert!(col.value(3).is_nan());
1343+
assert_eq!(col.value(4), f16::ZERO);
1344+
assert!(col.value(4).is_sign_positive());
1345+
assert_eq!(col.value(5), f16::NEG_ONE);
1346+
assert_eq!(col.value(6), f16::NEG_ZERO);
1347+
assert!(col.value(6).is_sign_negative());
1348+
assert_eq!(col.value(7), f16_two);
1349+
}
1350+
1351+
#[test]
1352+
fn test_read_float16_zeros_file() {
1353+
use arrow_array::Float16Array;
1354+
let testdata = arrow::util::test_util::parquet_test_data();
1355+
// see https://github.com/apache/parquet-testing/pull/40
1356+
let path = format!("{testdata}/float16_zeros_and_nans.parquet");
1357+
let file = File::open(path).unwrap();
1358+
let mut record_reader = ParquetRecordBatchReader::try_new(file, 32).unwrap();
1359+
1360+
let batch = record_reader.next().unwrap().unwrap();
1361+
assert_eq!(batch.num_rows(), 3);
1362+
let col = batch
1363+
.column(0)
1364+
.as_any()
1365+
.downcast_ref::<Float16Array>()
1366+
.unwrap();
1367+
1368+
assert_eq!(col.null_count(), 1);
1369+
assert!(col.is_null(0));
1370+
assert_eq!(col.value(1), f16::ZERO);
1371+
assert!(col.value(1).is_sign_positive());
1372+
assert!(col.value(2).is_nan());
1373+
}
1374+
12581375
/// Parameters for single_column_reader_test
12591376
#[derive(Clone)]
12601377
struct TestOptions {

parquet/src/arrow/arrow_writer/mod.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,10 @@ fn write_leaf(writer: &mut ColumnWriter<'_>, levels: &ArrayLevels) -> Result<usi
771771
.unwrap();
772772
get_decimal_256_array_slice(array, indices)
773773
}
774+
ArrowDataType::Float16 => {
775+
let array = column.as_primitive::<Float16Type>();
776+
get_float_16_array_slice(array, indices)
777+
}
774778
_ => {
775779
return Err(ParquetError::NYI(
776780
"Attempting to write an Arrow type that is not yet implemented".to_string(),
@@ -867,6 +871,18 @@ fn get_decimal_256_array_slice(
867871
values
868872
}
869873

874+
fn get_float_16_array_slice(
875+
array: &arrow_array::Float16Array,
876+
indices: &[usize],
877+
) -> Vec<FixedLenByteArray> {
878+
let mut values = Vec::with_capacity(indices.len());
879+
for i in indices {
880+
let value = array.value(*i).to_le_bytes().to_vec();
881+
values.push(FixedLenByteArray::from(ByteArray::from(value)));
882+
}
883+
values
884+
}
885+
870886
fn get_fsb_array_slice(
871887
array: &arrow_array::FixedSizeBinaryArray,
872888
indices: &[usize],

parquet/src/arrow/schema/mod.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,12 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
373373
.with_repetition(repetition)
374374
.with_id(id)
375375
.build(),
376-
DataType::Float16 => Err(arrow_err!("Float16 arrays not supported")),
376+
DataType::Float16 => Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY)
377+
.with_repetition(repetition)
378+
.with_id(id)
379+
.with_logical_type(Some(LogicalType::Float16))
380+
.with_length(2)
381+
.build(),
377382
DataType::Float32 => Type::primitive_type_builder(name, PhysicalType::FLOAT)
378383
.with_repetition(repetition)
379384
.with_id(id)
@@ -604,9 +609,10 @@ mod tests {
604609
REQUIRED INT32 uint8 (INTEGER(8,false));
605610
REQUIRED INT32 uint16 (INTEGER(16,false));
606611
REQUIRED INT32 int32;
607-
REQUIRED INT64 int64 ;
612+
REQUIRED INT64 int64;
608613
OPTIONAL DOUBLE double;
609614
OPTIONAL FLOAT float;
615+
OPTIONAL FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16);
610616
OPTIONAL BINARY string (UTF8);
611617
OPTIONAL BINARY string_2 (STRING);
612618
OPTIONAL BINARY json (JSON);
@@ -628,6 +634,7 @@ mod tests {
628634
Field::new("int64", DataType::Int64, false),
629635
Field::new("double", DataType::Float64, true),
630636
Field::new("float", DataType::Float32, true),
637+
Field::new("float16", DataType::Float16, true),
631638
Field::new("string", DataType::Utf8, true),
632639
Field::new("string_2", DataType::Utf8, true),
633640
Field::new("json", DataType::Utf8, true),
@@ -1303,6 +1310,7 @@ mod tests {
13031310
REQUIRED INT64 int64;
13041311
OPTIONAL DOUBLE double;
13051312
OPTIONAL FLOAT float;
1313+
OPTIONAL FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16);
13061314
OPTIONAL BINARY string (UTF8);
13071315
REPEATED BOOLEAN bools;
13081316
OPTIONAL INT32 date (DATE);
@@ -1339,6 +1347,7 @@ mod tests {
13391347
Field::new("int64", DataType::Int64, false),
13401348
Field::new("double", DataType::Float64, true),
13411349
Field::new("float", DataType::Float32, true),
1350+
Field::new("float16", DataType::Float16, true),
13421351
Field::new("string", DataType::Utf8, true),
13431352
Field::new_list(
13441353
"bools",
@@ -1398,6 +1407,7 @@ mod tests {
13981407
REQUIRED INT64 int64;
13991408
OPTIONAL DOUBLE double;
14001409
OPTIONAL FLOAT float;
1410+
OPTIONAL FIXED_LEN_BYTE_ARRAY (2) float16 (FLOAT16);
14011411
OPTIONAL BINARY string (STRING);
14021412
OPTIONAL GROUP bools (LIST) {
14031413
REPEATED GROUP list {
@@ -1448,6 +1458,7 @@ mod tests {
14481458
Field::new("int64", DataType::Int64, false),
14491459
Field::new("double", DataType::Float64, true),
14501460
Field::new("float", DataType::Float32, true),
1461+
Field::new("float16", DataType::Float16, true),
14511462
Field::new("string", DataType::Utf8, true),
14521463
Field::new_list(
14531464
"bools",
@@ -1661,6 +1672,8 @@ mod tests {
16611672
vec![
16621673
Field::new("a", DataType::Int16, true),
16631674
Field::new("b", DataType::Float64, false),
1675+
Field::new("c", DataType::Float32, false),
1676+
Field::new("d", DataType::Float16, false),
16641677
]
16651678
.into(),
16661679
),

parquet/src/arrow/schema/primitive.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,16 @@ fn from_fixed_len_byte_array(
304304
// would be incorrect if all 12 bytes of the interval are populated
305305
Ok(DataType::Interval(IntervalUnit::DayTime))
306306
}
307+
(Some(LogicalType::Float16), _) => {
308+
if type_length == 2 {
309+
Ok(DataType::Float16)
310+
} else {
311+
Err(ParquetError::General(
312+
"FLOAT16 logical type must be Fixed Length Byte Array with length 2"
313+
.to_string(),
314+
))
315+
}
316+
}
307317
_ => Ok(DataType::FixedSizeBinary(type_length)),
308318
}
309319
}

0 commit comments

Comments
 (0)