Skip to content

Commit 06db9ed

Browse files
Deduplicate and standardize deserialization logic for streams (#13412)
* Add BatchDeserializer
* Fix formatting
* Remove unused enum value
* Update datafusion/core/src/datasource/file_format/mod.rs

Co-authored-by: Mehmet Ozan Kabak <[email protected]>
1 parent a09814a commit 06db9ed

File tree

5 files changed

+547
-72
lines changed

5 files changed

+547
-72
lines changed

datafusion/core/src/datasource/file_format/csv.rs

Lines changed: 229 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ use std::fmt::{self, Debug};
2323
use std::sync::Arc;
2424

2525
use super::write::orchestration::stateless_multipart_put;
26-
use super::{FileFormat, FileFormatFactory, DEFAULT_SCHEMA_INFER_MAX_RECORD};
26+
use super::{
27+
Decoder, DecoderDeserializer, FileFormat, FileFormatFactory,
28+
DEFAULT_SCHEMA_INFER_MAX_RECORD,
29+
};
2730
use crate::datasource::file_format::file_compression_type::FileCompressionType;
2831
use crate::datasource::file_format::write::BatchSerializer;
2932
use crate::datasource::physical_plan::{
@@ -38,8 +41,8 @@ use crate::physical_plan::{
3841

3942
use arrow::array::RecordBatch;
4043
use arrow::csv::WriterBuilder;
41-
use arrow::datatypes::SchemaRef;
42-
use arrow::datatypes::{DataType, Field, Fields, Schema};
44+
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
45+
use arrow_schema::ArrowError;
4346
use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions};
4447
use datafusion_common::file_options::csv_writer::CsvWriterOptions;
4548
use datafusion_common::{
@@ -293,6 +296,45 @@ impl CsvFormat {
293296
}
294297
}
295298

299+
#[derive(Debug)]
300+
pub(crate) struct CsvDecoder {
301+
inner: arrow::csv::reader::Decoder,
302+
}
303+
304+
impl CsvDecoder {
305+
pub(crate) fn new(decoder: arrow::csv::reader::Decoder) -> Self {
306+
Self { inner: decoder }
307+
}
308+
}
309+
310+
impl Decoder for CsvDecoder {
311+
fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
312+
self.inner.decode(buf)
313+
}
314+
315+
fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
316+
self.inner.flush()
317+
}
318+
319+
fn can_flush_early(&self) -> bool {
320+
self.inner.capacity() == 0
321+
}
322+
}
323+
324+
impl Debug for CsvSerializer {
325+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326+
f.debug_struct("CsvSerializer")
327+
.field("header", &self.header)
328+
.finish()
329+
}
330+
}
331+
332+
impl From<arrow::csv::reader::Decoder> for DecoderDeserializer<CsvDecoder> {
333+
fn from(decoder: arrow::csv::reader::Decoder) -> Self {
334+
DecoderDeserializer::new(CsvDecoder::new(decoder))
335+
}
336+
}
337+
296338
#[async_trait]
297339
impl FileFormat for CsvFormat {
298340
fn as_any(&self) -> &dyn Any {
@@ -692,23 +734,28 @@ impl DataSink for CsvSink {
692734
mod tests {
693735
use super::super::test_util::scan_format;
694736
use super::*;
695-
use crate::arrow::util::pretty;
696737
use crate::assert_batches_eq;
697738
use crate::datasource::file_format::file_compression_type::FileCompressionType;
698739
use crate::datasource::file_format::test_util::VariableStream;
740+
use crate::datasource::file_format::{
741+
BatchDeserializer, DecoderDeserializer, DeserializerOutput,
742+
};
699743
use crate::datasource::listing::ListingOptions;
744+
use crate::execution::session_state::SessionStateBuilder;
700745
use crate::physical_plan::collect;
701746
use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext};
702747
use crate::test_util::arrow_test_data;
703748

704749
use arrow::compute::concat_batches;
750+
use arrow::csv::ReaderBuilder;
751+
use arrow::util::pretty::pretty_format_batches;
752+
use arrow_array::{BooleanArray, Float64Array, Int32Array, StringArray};
705753
use datafusion_common::cast::as_string_array;
706754
use datafusion_common::internal_err;
707755
use datafusion_common::stats::Precision;
708756
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
709757
use datafusion_expr::{col, lit};
710758

711-
use crate::execution::session_state::SessionStateBuilder;
712759
use chrono::DateTime;
713760
use object_store::local::LocalFileSystem;
714761
use object_store::path::Path;
@@ -1097,7 +1144,7 @@ mod tests {
10971144
) -> Result<usize> {
10981145
let df = ctx.sql(&format!("EXPLAIN {sql}")).await?;
10991146
let result = df.collect().await?;
1100-
let plan = format!("{}", &pretty::pretty_format_batches(&result)?);
1147+
let plan = format!("{}", &pretty_format_batches(&result)?);
11011148

11021149
let re = Regex::new(r"CsvExec: file_groups=\{(\d+) group").unwrap();
11031150

@@ -1464,4 +1511,180 @@ mod tests {
14641511

14651512
Ok(())
14661513
}
1514+
1515+
#[rstest]
1516+
fn test_csv_deserializer_with_finish(
1517+
#[values(1, 5, 17)] batch_size: usize,
1518+
#[values(0, 5, 93)] line_count: usize,
1519+
) -> Result<()> {
1520+
let schema = csv_schema();
1521+
let generator = CsvBatchGenerator::new(batch_size, line_count);
1522+
let mut deserializer = csv_deserializer(batch_size, &schema);
1523+
1524+
for data in generator {
1525+
deserializer.digest(data);
1526+
}
1527+
deserializer.finish();
1528+
1529+
let batch_count = line_count.div_ceil(batch_size);
1530+
1531+
let mut all_batches = RecordBatch::new_empty(schema.clone());
1532+
for _ in 0..batch_count {
1533+
let output = deserializer.next()?;
1534+
let DeserializerOutput::RecordBatch(batch) = output else {
1535+
panic!("Expected RecordBatch, got {:?}", output);
1536+
};
1537+
all_batches = concat_batches(&schema, &[all_batches, batch])?;
1538+
}
1539+
assert_eq!(deserializer.next()?, DeserializerOutput::InputExhausted);
1540+
1541+
let expected = csv_expected_batch(schema, line_count)?;
1542+
1543+
assert_eq!(
1544+
expected.clone(),
1545+
all_batches.clone(),
1546+
"Expected:\n{}\nActual:\n{}",
1547+
pretty_format_batches(&[expected])?,
1548+
pretty_format_batches(&[all_batches])?,
1549+
);
1550+
1551+
Ok(())
1552+
}
1553+
1554+
#[rstest]
1555+
fn test_csv_deserializer_without_finish(
1556+
#[values(1, 5, 17)] batch_size: usize,
1557+
#[values(0, 5, 93)] line_count: usize,
1558+
) -> Result<()> {
1559+
let schema = csv_schema();
1560+
let generator = CsvBatchGenerator::new(batch_size, line_count);
1561+
let mut deserializer = csv_deserializer(batch_size, &schema);
1562+
1563+
for data in generator {
1564+
deserializer.digest(data);
1565+
}
1566+
1567+
let batch_count = line_count / batch_size;
1568+
1569+
let mut all_batches = RecordBatch::new_empty(schema.clone());
1570+
for _ in 0..batch_count {
1571+
let output = deserializer.next()?;
1572+
let DeserializerOutput::RecordBatch(batch) = output else {
1573+
panic!("Expected RecordBatch, got {:?}", output);
1574+
};
1575+
all_batches = concat_batches(&schema, &[all_batches, batch])?;
1576+
}
1577+
assert_eq!(deserializer.next()?, DeserializerOutput::RequiresMoreData);
1578+
1579+
let expected = csv_expected_batch(schema, batch_count * batch_size)?;
1580+
1581+
assert_eq!(
1582+
expected.clone(),
1583+
all_batches.clone(),
1584+
"Expected:\n{}\nActual:\n{}",
1585+
pretty_format_batches(&[expected])?,
1586+
pretty_format_batches(&[all_batches])?,
1587+
);
1588+
1589+
Ok(())
1590+
}
1591+
1592+
/// Test helper that yields CSV data in chunks of at most `batch_size` lines
/// until `line_count` lines have been produced.
struct CsvBatchGenerator {
    /// Maximum number of CSV lines emitted per chunk.
    batch_size: usize,
    /// Total number of CSV lines to emit across all chunks.
    line_count: usize,
    /// Index of the next line to emit.
    offset: usize,
}
1597+
1598+
impl CsvBatchGenerator {
1599+
fn new(batch_size: usize, line_count: usize) -> Self {
1600+
Self {
1601+
batch_size,
1602+
line_count,
1603+
offset: 0,
1604+
}
1605+
}
1606+
}
1607+
1608+
impl Iterator for CsvBatchGenerator {
1609+
type Item = Bytes;
1610+
1611+
fn next(&mut self) -> Option<Self::Item> {
1612+
// Return `batch_size` rows per batch:
1613+
let mut buffer = Vec::new();
1614+
for _ in 0..self.batch_size {
1615+
if self.offset >= self.line_count {
1616+
break;
1617+
}
1618+
buffer.extend_from_slice(&csv_line(self.offset));
1619+
self.offset += 1;
1620+
}
1621+
1622+
(!buffer.is_empty()).then(|| buffer.into())
1623+
}
1624+
}
1625+
1626+
fn csv_expected_batch(
1627+
schema: SchemaRef,
1628+
line_count: usize,
1629+
) -> Result<RecordBatch, DataFusionError> {
1630+
let mut c1 = Vec::with_capacity(line_count);
1631+
let mut c2 = Vec::with_capacity(line_count);
1632+
let mut c3 = Vec::with_capacity(line_count);
1633+
let mut c4 = Vec::with_capacity(line_count);
1634+
1635+
for i in 0..line_count {
1636+
let (int_value, float_value, bool_value, char_value) = csv_values(i);
1637+
c1.push(int_value);
1638+
c2.push(float_value);
1639+
c3.push(bool_value);
1640+
c4.push(char_value);
1641+
}
1642+
1643+
let expected = RecordBatch::try_new(
1644+
schema.clone(),
1645+
vec![
1646+
Arc::new(Int32Array::from(c1)),
1647+
Arc::new(Float64Array::from(c2)),
1648+
Arc::new(BooleanArray::from(c3)),
1649+
Arc::new(StringArray::from(c4)),
1650+
],
1651+
)?;
1652+
Ok(expected)
1653+
}
1654+
1655+
fn csv_line(line_number: usize) -> Bytes {
1656+
let (int_value, float_value, bool_value, char_value) = csv_values(line_number);
1657+
format!(
1658+
"{},{},{},{}\n",
1659+
int_value, float_value, bool_value, char_value
1660+
)
1661+
.into()
1662+
}
1663+
1664+
/// Derives the four column values for a given line number:
/// the number as `i32`, the number as `f64`, whether it is even, and the
/// string `"<line_number>-string"`.
fn csv_values(line_number: usize) -> (i32, f64, bool, String) {
    (
        line_number as i32,
        line_number as f64,
        line_number % 2 == 0,
        format!("{}-string", line_number),
    )
}
1671+
1672+
fn csv_schema() -> Arc<Schema> {
1673+
Arc::new(Schema::new(vec![
1674+
Field::new("c1", DataType::Int32, true),
1675+
Field::new("c2", DataType::Float64, true),
1676+
Field::new("c3", DataType::Boolean, true),
1677+
Field::new("c4", DataType::Utf8, true),
1678+
]))
1679+
}
1680+
1681+
fn csv_deserializer(
1682+
batch_size: usize,
1683+
schema: &Arc<Schema>,
1684+
) -> impl BatchDeserializer<Bytes> {
1685+
let decoder = ReaderBuilder::new(schema.clone())
1686+
.with_batch_size(batch_size)
1687+
.build_decoder();
1688+
DecoderDeserializer::new(CsvDecoder::new(decoder))
1689+
}
14671690
}

0 commit comments

Comments
 (0)