@@ -23,7 +23,10 @@ use std::fmt::{self, Debug};
23
23
use std:: sync:: Arc ;
24
24
25
25
use super :: write:: orchestration:: stateless_multipart_put;
26
- use super :: { FileFormat , FileFormatFactory , DEFAULT_SCHEMA_INFER_MAX_RECORD } ;
26
+ use super :: {
27
+ Decoder , DecoderDeserializer , FileFormat , FileFormatFactory ,
28
+ DEFAULT_SCHEMA_INFER_MAX_RECORD ,
29
+ } ;
27
30
use crate :: datasource:: file_format:: file_compression_type:: FileCompressionType ;
28
31
use crate :: datasource:: file_format:: write:: BatchSerializer ;
29
32
use crate :: datasource:: physical_plan:: {
@@ -38,8 +41,8 @@ use crate::physical_plan::{
38
41
39
42
use arrow:: array:: RecordBatch ;
40
43
use arrow:: csv:: WriterBuilder ;
41
- use arrow:: datatypes:: SchemaRef ;
42
- use arrow :: datatypes :: { DataType , Field , Fields , Schema } ;
44
+ use arrow:: datatypes:: { DataType , Field , Fields , Schema , SchemaRef } ;
45
+ use arrow_schema :: ArrowError ;
43
46
use datafusion_common:: config:: { ConfigField , ConfigFileType , CsvOptions } ;
44
47
use datafusion_common:: file_options:: csv_writer:: CsvWriterOptions ;
45
48
use datafusion_common:: {
@@ -293,6 +296,45 @@ impl CsvFormat {
293
296
}
294
297
}
295
298
299
/// Adapter that wraps the arrow CSV [`arrow::csv::reader::Decoder`] so it can
/// be driven through the crate's generic [`Decoder`] trait (and therefore
/// through [`DecoderDeserializer`]).
#[derive(Debug)]
pub(crate) struct CsvDecoder {
    // Underlying arrow decoder that performs the actual CSV parsing.
    inner: arrow::csv::reader::Decoder,
}
303
+
304
impl CsvDecoder {
    /// Wraps an arrow CSV decoder so it satisfies the generic [`Decoder`] trait.
    pub(crate) fn new(decoder: arrow::csv::reader::Decoder) -> Self {
        Self { inner: decoder }
    }
}
309
+
310
impl Decoder for CsvDecoder {
    /// Forwards raw bytes to the arrow decoder; returns the number of bytes
    /// consumed.
    fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        self.inner.decode(buf)
    }

    /// Flushes the currently buffered rows as a `RecordBatch`, or `None` if
    /// nothing is buffered.
    fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        self.inner.flush()
    }

    fn can_flush_early(&self) -> bool {
        // A remaining capacity of 0 indicates the decoder has buffered a full
        // batch, so flushing before the input is exhausted is worthwhile
        // (see arrow's `Decoder::capacity` documentation — confirm on upgrade).
        self.inner.capacity() == 0
    }
}
323
+
324
impl Debug for CsvSerializer {
    /// Manual `Debug`: only the `header` flag is shown; any other fields of
    /// `CsvSerializer` are intentionally omitted from the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvSerializer")
            .field("header", &self.header)
            .finish()
    }
}
331
+
332
/// Convenience conversion: turn a raw arrow CSV decoder directly into the
/// generic decoder-based deserializer.
impl From<arrow::csv::reader::Decoder> for DecoderDeserializer<CsvDecoder> {
    fn from(decoder: arrow::csv::reader::Decoder) -> Self {
        DecoderDeserializer::new(CsvDecoder::new(decoder))
    }
}
337
+
296
338
#[ async_trait]
297
339
impl FileFormat for CsvFormat {
298
340
fn as_any ( & self ) -> & dyn Any {
@@ -692,23 +734,28 @@ impl DataSink for CsvSink {
692
734
mod tests {
693
735
use super :: super :: test_util:: scan_format;
694
736
use super :: * ;
695
- use crate :: arrow:: util:: pretty;
696
737
use crate :: assert_batches_eq;
697
738
use crate :: datasource:: file_format:: file_compression_type:: FileCompressionType ;
698
739
use crate :: datasource:: file_format:: test_util:: VariableStream ;
740
+ use crate :: datasource:: file_format:: {
741
+ BatchDeserializer , DecoderDeserializer , DeserializerOutput ,
742
+ } ;
699
743
use crate :: datasource:: listing:: ListingOptions ;
744
+ use crate :: execution:: session_state:: SessionStateBuilder ;
700
745
use crate :: physical_plan:: collect;
701
746
use crate :: prelude:: { CsvReadOptions , SessionConfig , SessionContext } ;
702
747
use crate :: test_util:: arrow_test_data;
703
748
704
749
use arrow:: compute:: concat_batches;
750
+ use arrow:: csv:: ReaderBuilder ;
751
+ use arrow:: util:: pretty:: pretty_format_batches;
752
+ use arrow_array:: { BooleanArray , Float64Array , Int32Array , StringArray } ;
705
753
use datafusion_common:: cast:: as_string_array;
706
754
use datafusion_common:: internal_err;
707
755
use datafusion_common:: stats:: Precision ;
708
756
use datafusion_execution:: runtime_env:: RuntimeEnvBuilder ;
709
757
use datafusion_expr:: { col, lit} ;
710
758
711
- use crate :: execution:: session_state:: SessionStateBuilder ;
712
759
use chrono:: DateTime ;
713
760
use object_store:: local:: LocalFileSystem ;
714
761
use object_store:: path:: Path ;
@@ -1097,7 +1144,7 @@ mod tests {
1097
1144
) -> Result < usize > {
1098
1145
let df = ctx. sql ( & format ! ( "EXPLAIN {sql}" ) ) . await ?;
1099
1146
let result = df. collect ( ) . await ?;
1100
- let plan = format ! ( "{}" , & pretty :: pretty_format_batches( & result) ?) ;
1147
+ let plan = format ! ( "{}" , & pretty_format_batches( & result) ?) ;
1101
1148
1102
1149
let re = Regex :: new ( r"CsvExec: file_groups=\{(\d+) group" ) . unwrap ( ) ;
1103
1150
@@ -1464,4 +1511,180 @@ mod tests {
1464
1511
1465
1512
Ok ( ( ) )
1466
1513
}
1514
+
1515
+ #[ rstest]
1516
+ fn test_csv_deserializer_with_finish (
1517
+ #[ values( 1 , 5 , 17 ) ] batch_size : usize ,
1518
+ #[ values( 0 , 5 , 93 ) ] line_count : usize ,
1519
+ ) -> Result < ( ) > {
1520
+ let schema = csv_schema ( ) ;
1521
+ let generator = CsvBatchGenerator :: new ( batch_size, line_count) ;
1522
+ let mut deserializer = csv_deserializer ( batch_size, & schema) ;
1523
+
1524
+ for data in generator {
1525
+ deserializer. digest ( data) ;
1526
+ }
1527
+ deserializer. finish ( ) ;
1528
+
1529
+ let batch_count = line_count. div_ceil ( batch_size) ;
1530
+
1531
+ let mut all_batches = RecordBatch :: new_empty ( schema. clone ( ) ) ;
1532
+ for _ in 0 ..batch_count {
1533
+ let output = deserializer. next ( ) ?;
1534
+ let DeserializerOutput :: RecordBatch ( batch) = output else {
1535
+ panic ! ( "Expected RecordBatch, got {:?}" , output) ;
1536
+ } ;
1537
+ all_batches = concat_batches ( & schema, & [ all_batches, batch] ) ?;
1538
+ }
1539
+ assert_eq ! ( deserializer. next( ) ?, DeserializerOutput :: InputExhausted ) ;
1540
+
1541
+ let expected = csv_expected_batch ( schema, line_count) ?;
1542
+
1543
+ assert_eq ! (
1544
+ expected. clone( ) ,
1545
+ all_batches. clone( ) ,
1546
+ "Expected:\n {}\n Actual:\n {}" ,
1547
+ pretty_format_batches( & [ expected] ) ?,
1548
+ pretty_format_batches( & [ all_batches] ) ?,
1549
+ ) ;
1550
+
1551
+ Ok ( ( ) )
1552
+ }
1553
+
1554
+ #[ rstest]
1555
+ fn test_csv_deserializer_without_finish (
1556
+ #[ values( 1 , 5 , 17 ) ] batch_size : usize ,
1557
+ #[ values( 0 , 5 , 93 ) ] line_count : usize ,
1558
+ ) -> Result < ( ) > {
1559
+ let schema = csv_schema ( ) ;
1560
+ let generator = CsvBatchGenerator :: new ( batch_size, line_count) ;
1561
+ let mut deserializer = csv_deserializer ( batch_size, & schema) ;
1562
+
1563
+ for data in generator {
1564
+ deserializer. digest ( data) ;
1565
+ }
1566
+
1567
+ let batch_count = line_count / batch_size;
1568
+
1569
+ let mut all_batches = RecordBatch :: new_empty ( schema. clone ( ) ) ;
1570
+ for _ in 0 ..batch_count {
1571
+ let output = deserializer. next ( ) ?;
1572
+ let DeserializerOutput :: RecordBatch ( batch) = output else {
1573
+ panic ! ( "Expected RecordBatch, got {:?}" , output) ;
1574
+ } ;
1575
+ all_batches = concat_batches ( & schema, & [ all_batches, batch] ) ?;
1576
+ }
1577
+ assert_eq ! ( deserializer. next( ) ?, DeserializerOutput :: RequiresMoreData ) ;
1578
+
1579
+ let expected = csv_expected_batch ( schema, batch_count * batch_size) ?;
1580
+
1581
+ assert_eq ! (
1582
+ expected. clone( ) ,
1583
+ all_batches. clone( ) ,
1584
+ "Expected:\n {}\n Actual:\n {}" ,
1585
+ pretty_format_batches( & [ expected] ) ?,
1586
+ pretty_format_batches( & [ all_batches] ) ?,
1587
+ ) ;
1588
+
1589
+ Ok ( ( ) )
1590
+ }
1591
+
1592
/// Test helper that yields CSV data in chunks of `batch_size` lines at a time
/// until `line_count` lines have been produced.
struct CsvBatchGenerator {
    // Maximum number of lines emitted per `next()` call.
    batch_size: usize,
    // Total number of lines to generate across all chunks.
    line_count: usize,
    // Index of the next line to generate.
    offset: usize,
}
1597
+
1598
impl CsvBatchGenerator {
    /// Creates a generator that produces `line_count` CSV lines in chunks of
    /// at most `batch_size` lines.
    fn new(batch_size: usize, line_count: usize) -> Self {
        Self {
            batch_size,
            line_count,
            offset: 0,
        }
    }
}
1607
+
1608
+ impl Iterator for CsvBatchGenerator {
1609
+ type Item = Bytes ;
1610
+
1611
+ fn next ( & mut self ) -> Option < Self :: Item > {
1612
+ // Return `batch_size` rows per batch:
1613
+ let mut buffer = Vec :: new ( ) ;
1614
+ for _ in 0 ..self . batch_size {
1615
+ if self . offset >= self . line_count {
1616
+ break ;
1617
+ }
1618
+ buffer. extend_from_slice ( & csv_line ( self . offset ) ) ;
1619
+ self . offset += 1 ;
1620
+ }
1621
+
1622
+ ( !buffer. is_empty ( ) ) . then ( || buffer. into ( ) )
1623
+ }
1624
+ }
1625
+
1626
+ fn csv_expected_batch (
1627
+ schema : SchemaRef ,
1628
+ line_count : usize ,
1629
+ ) -> Result < RecordBatch , DataFusionError > {
1630
+ let mut c1 = Vec :: with_capacity ( line_count) ;
1631
+ let mut c2 = Vec :: with_capacity ( line_count) ;
1632
+ let mut c3 = Vec :: with_capacity ( line_count) ;
1633
+ let mut c4 = Vec :: with_capacity ( line_count) ;
1634
+
1635
+ for i in 0 ..line_count {
1636
+ let ( int_value, float_value, bool_value, char_value) = csv_values ( i) ;
1637
+ c1. push ( int_value) ;
1638
+ c2. push ( float_value) ;
1639
+ c3. push ( bool_value) ;
1640
+ c4. push ( char_value) ;
1641
+ }
1642
+
1643
+ let expected = RecordBatch :: try_new (
1644
+ schema. clone ( ) ,
1645
+ vec ! [
1646
+ Arc :: new( Int32Array :: from( c1) ) ,
1647
+ Arc :: new( Float64Array :: from( c2) ) ,
1648
+ Arc :: new( BooleanArray :: from( c3) ) ,
1649
+ Arc :: new( StringArray :: from( c4) ) ,
1650
+ ] ,
1651
+ ) ?;
1652
+ Ok ( expected)
1653
+ }
1654
+
1655
+ fn csv_line ( line_number : usize ) -> Bytes {
1656
+ let ( int_value, float_value, bool_value, char_value) = csv_values ( line_number) ;
1657
+ format ! (
1658
+ "{},{},{},{}\n " ,
1659
+ int_value, float_value, bool_value, char_value
1660
+ )
1661
+ . into ( )
1662
+ }
1663
+
1664
/// Deterministic per-line test values: the line number as an int and a float,
/// an even/odd flag, and a derived string.
fn csv_values(line_number: usize) -> (i32, f64, bool, String) {
    (
        line_number as i32,
        line_number as f64,
        line_number % 2 == 0,
        format!("{}-string", line_number),
    )
}
1671
+
1672
/// Schema matching the rows produced by `csv_values`: one nullable column per
/// tuple element (int, float, bool, string).
fn csv_schema() -> Arc<Schema> {
    Arc::new(Schema::new(vec![
        Field::new("c1", DataType::Int32, true),
        Field::new("c2", DataType::Float64, true),
        Field::new("c3", DataType::Boolean, true),
        Field::new("c4", DataType::Utf8, true),
    ]))
}
1680
+
1681
+ fn csv_deserializer (
1682
+ batch_size : usize ,
1683
+ schema : & Arc < Schema > ,
1684
+ ) -> impl BatchDeserializer < Bytes > {
1685
+ let decoder = ReaderBuilder :: new ( schema. clone ( ) )
1686
+ . with_batch_size ( batch_size)
1687
+ . build_decoder ( ) ;
1688
+ DecoderDeserializer :: new ( CsvDecoder :: new ( decoder) )
1689
+ }
1467
1690
}
0 commit comments