Skip to content

Commit abb25ed

Browse files
committed
Add node to convert csp.Structs into record batches and cleanup
Signed-off-by: Arham Chopra <[email protected]>
1 parent 01efd9f commit abb25ed

7 files changed

+218
-31
lines changed

cpp/csp/adapters/parquet/ParquetDictBasketOutputWriter.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ void ParquetDictBasketOutputWriter::start()
2020
m_indexFileWriterContainer = std::make_unique<MultipleFileWriterWrapperContainer>(
2121
arrow::schema( { arrow::field( m_cycleIndexOutputAdapter -> getColumnArrayBuilder( 0 ) -> getColumnName(),
2222
m_cycleIndexOutputAdapter -> getColumnArrayBuilder( 0 ) -> getDataType() ) } ),
23-
m_adapterMgr.isWriteArrowBinary() );
24-
if( !m_adapterMgr.getFileName().empty() )
23+
m_adapterMgr -> isWriteArrowBinary() );
24+
if( !m_adapterMgr -> getFileName().empty() )
2525
{
26-
m_indexFileWriterContainer -> open( m_adapterMgr.getFileName(),
27-
m_adapterMgr.getCompression(), m_adapterMgr.isAllowOverwrite() );
26+
m_indexFileWriterContainer -> open( m_adapterMgr -> getFileName(),
27+
m_adapterMgr -> getCompression(), m_adapterMgr -> isAllowOverwrite() );
2828

2929
}
3030
}
@@ -45,7 +45,7 @@ void ParquetDictBasketOutputWriter::stop()
4545

4646
void ParquetDictBasketOutputWriter::writeValue( const std::string &valueKey, const TimeSeriesProvider *ts )
4747
{
48-
m_adapterMgr.scheduleEndCycle();
48+
m_adapterMgr -> scheduleEndCycle();
4949
m_symbolOutputAdapter -> writeValue<std::string, StringArrayBuilder>( valueKey );
5050
ParquetWriter::onEndCycle();
5151
++m_nextCycleIndex;
@@ -86,7 +86,7 @@ void ParquetDictBasketOutputWriter::onFileNameChange( const std::string &fileNam
8686
if(!fileName.empty())
8787
{
8888
m_indexFileWriterContainer
89-
-> open( fileName, m_adapterMgr.getCompression(), m_adapterMgr.isAllowOverwrite() );
89+
-> open( fileName, m_adapterMgr -> getCompression(), m_adapterMgr -> isAllowOverwrite() );
9090
}
9191

9292
}

cpp/csp/adapters/parquet/ParquetOutputAdapter.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,16 @@ StructParquetOutputHandler::StructParquetOutputHandler( Engine *engine, ParquetW
192192
}
193193
}
194194

195+
void StructParquetOutputHandler::writeValueFromArgs( const StructPtr input )
196+
{
197+
const Struct *structData = input.get();
198+
199+
for( auto &&valueHandler: m_valueHandlers )
200+
{
201+
valueHandler( structData );
202+
}
203+
}
204+
195205
void StructParquetOutputHandler::writeValueFromTs( const TimeSeriesProvider *input )
196206
{
197207
const Struct *structData = input -> lastValueTyped<StructPtr>().get();

cpp/csp/adapters/parquet/ParquetOutputAdapter.h

+1
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ class StructParquetOutputHandler : public ParquetOutputHandler
151151
}
152152

153153
void writeValueFromTs( const TimeSeriesProvider *input ) override final;
154+
void writeValueFromArgs( const StructPtr input );
154155

155156
private:
156157
using ValueHandler = std::function<void( const Struct * )>;

cpp/csp/adapters/parquet/ParquetWriter.cpp

+17-13
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,12 @@
1313
namespace csp::adapters::parquet
1414
{
1515

16+
// Default constructor: builds a writer that is not attached to any adapter
// manager or engine. Both pointers are null, the current chunk size is 0 and
// the "write timestamp column" flag is left unset.
// NOTE(review): members that dereference m_adapterMgr must not be reached on
// an instance built this way unless a subclass overrides them - confirm callers.
ParquetWriter::ParquetWriter()
    : m_adapterMgr( nullptr ),
      m_engine( nullptr ),
      m_curChunkSize( 0 ),
      m_writeTimestampColumn( std::nullopt )
{
}
19+
1620
ParquetWriter::ParquetWriter( ParquetOutputAdapterManager *mgr, std::optional<bool> writeTimestampColumn )
17-
: m_adapterMgr( *mgr ), m_engine( mgr -> engine() ), m_curChunkSize( 0 ), m_writeTimestampColumn( writeTimestampColumn )
21+
: m_adapterMgr( mgr ), m_engine( mgr -> engine() ), m_curChunkSize( 0 ), m_writeTimestampColumn( writeTimestampColumn )
1822
{}
1923

2024
ParquetWriter::ParquetWriter( ParquetOutputAdapterManager *mgr, const Dictionary & properties ) : ParquetWriter( mgr, std::optional<bool>{} )
@@ -128,20 +132,20 @@ PushInputAdapter *ParquetWriter::getStatusAdapter()
128132
void ParquetWriter::start()
129133
{
130134
std::vector<std::shared_ptr<arrow::Field>> arrowFields;
131-
if( !m_writeTimestampColumn.has_value() && !m_adapterMgr.getTimestampColumnName().empty() )
135+
if( !m_writeTimestampColumn.has_value() && !m_adapterMgr -> getTimestampColumnName().empty() )
132136
{
133137
m_writeTimestampColumn = true;
134-
m_columnBuilders.push_back( std::make_shared<DatetimeArrayBuilder>( m_adapterMgr.getTimestampColumnName(), getChunkSize() ) );
138+
m_columnBuilders.push_back( std::make_shared<DatetimeArrayBuilder>( m_adapterMgr -> getTimestampColumnName(), getChunkSize() ) );
135139
std::shared_ptr<arrow::KeyValueMetadata> colMetaData;
136-
auto colMetaIt = m_columnMetaData.find( m_adapterMgr.getTimestampColumnName() );
140+
auto colMetaIt = m_columnMetaData.find( m_adapterMgr -> getTimestampColumnName() );
137141
if( colMetaIt != m_columnMetaData.end() )
138142
{
139143
colMetaData = colMetaIt -> second;
140144
m_columnMetaData.erase( colMetaIt );
141145
}
142146

143147
arrowFields.push_back(
144-
arrow::field( m_adapterMgr.getTimestampColumnName(), m_columnBuilders.back() -> getDataType(), colMetaData ) );
148+
arrow::field( m_adapterMgr -> getTimestampColumnName(), m_columnBuilders.back() -> getDataType(), colMetaData ) );
145149
}
146150
else
147151
{
@@ -194,7 +198,7 @@ void ParquetWriter::onEndCycle()
194198
if( m_writeTimestampColumn.value() )
195199
{
196200
// Set the timestamp value it's always the first
197-
now = m_adapterMgr.rootEngine() -> now();
201+
now = m_adapterMgr -> rootEngine() -> now();
198202
static_cast<DatetimeArrayBuilder *>(m_columnBuilders[ 0 ].get()) -> setValue( now );
199203
}
200204
for( auto &&columnBuilder:m_columnBuilders )
@@ -221,7 +225,7 @@ void ParquetWriter::onFileNameChange( const std::string &fileName )
221225
if( !fileName.empty() )
222226
{
223227
m_fileWriterWrapperContainer
224-
-> open( fileName, m_adapterMgr.getCompression(), m_adapterMgr.isAllowOverwrite() );
228+
-> open( fileName, m_adapterMgr -> getCompression(), m_adapterMgr -> isAllowOverwrite() );
225229
}
226230
}
227231

@@ -245,20 +249,20 @@ StructParquetOutputHandler *ParquetWriter::createStructOutputHandler( CspTypePtr
245249

246250
void ParquetWriter::initFileWriterContainer( std::shared_ptr<arrow::Schema> schema )
247251
{
248-
if( m_adapterMgr.isSplitColumnsToFiles() )
252+
if( m_adapterMgr -> isSplitColumnsToFiles() )
249253
{
250254
m_fileWriterWrapperContainer = std::make_unique<MultipleFileWriterWrapperContainer>( schema,
251-
m_adapterMgr.isWriteArrowBinary() );
255+
m_adapterMgr -> isWriteArrowBinary() );
252256
}
253257
else
254258
{
255259
m_fileWriterWrapperContainer = std::make_unique<SingleFileWriterWrapperContainer>( schema,
256-
m_adapterMgr.isWriteArrowBinary() );
260+
m_adapterMgr -> isWriteArrowBinary() );
257261
}
258-
if( !m_adapterMgr.getFileName().empty() )
262+
if( !m_adapterMgr -> getFileName().empty() )
259263
{
260-
m_fileWriterWrapperContainer -> open( m_adapterMgr.getFileName(),
261-
m_adapterMgr.getCompression(), m_adapterMgr.isAllowOverwrite() );
264+
m_fileWriterWrapperContainer -> open( m_adapterMgr -> getFileName(),
265+
m_adapterMgr -> getCompression(), m_adapterMgr -> isAllowOverwrite() );
262266
}
263267
}
264268

cpp/csp/adapters/parquet/ParquetWriter.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class FileWriterWrapperContainer;
3131
class ParquetWriter : public EndCycleListener
3232
{
3333
public:
34+
ParquetWriter();
3435
ParquetWriter( ParquetOutputAdapterManager *mgr, std::optional<bool> writeTimestampColumn = {} );
3536
ParquetWriter( ParquetOutputAdapterManager *mgr, const Dictionary & properties );
3637

@@ -53,11 +54,11 @@ class ParquetWriter : public EndCycleListener
5354

5455
void onEndCycle() override;
5556

56-
std::uint32_t getChunkSize() const{ return m_adapterMgr.getBatchSize(); }
57+
virtual std::uint32_t getChunkSize() const{ return m_adapterMgr -> getBatchSize(); }
5758

5859
virtual void scheduleEndCycleEvent()
5960
{
60-
m_adapterMgr.scheduleEndCycle();
61+
m_adapterMgr -> scheduleEndCycle();
6162
}
6263

6364
bool isFileOpen() const;
@@ -76,7 +77,7 @@ class ParquetWriter : public EndCycleListener
7677
using Adapters = std::vector<ParquetOutputHandler *>;
7778
using PublishedColumnNames = std::unordered_set<std::string>;
7879

79-
ParquetOutputAdapterManager &m_adapterMgr;
80+
ParquetOutputAdapterManager *m_adapterMgr;
8081
Engine *m_engine;
8182
private:
8283
Adapters m_adapters;

cpp/csp/python/cspbaselibimpl.cpp

+143-7
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212

1313
#include <csp/adapters/parquet/ParquetReader.h>
1414
#include <csp/adapters/utils/StructAdapterInfo.h>
15+
#include <csp/adapters/parquet/ParquetOutputAdapter.h>
16+
#include <csp/adapters/parquet/ParquetWriter.h>
17+
#include <csp/python/PyObjectPtr.h>
1518

1619
static void * init_nparray()
1720
{
@@ -21,6 +24,24 @@ static void * init_nparray()
2124
}
2225
static void * s_init_array = init_nparray();
2326

27+
// PyCapsule destructor for capsules named "arrow_schema".
// Invokes the ArrowSchema's release callback if it has not been released yet
// (release != NULL) and then frees the struct itself with free().
// If the pointer cannot be retrieved from the capsule, the resulting Python
// error is reported as unraisable and nothing is dereferenced.
void ReleaseArrowSchemaPyCapsule( PyObject* capsule ) {
    struct ArrowSchema* schema =
        static_cast<struct ArrowSchema*>( PyCapsule_GetPointer( capsule, "arrow_schema" ) );
    if ( schema == NULL ) {
        // PyCapsule_GetPointer set an exception; a destructor cannot propagate it.
        PyErr_WriteUnraisable( capsule );
        return;
    }
    if ( schema->release != NULL ) {
        schema->release( schema );
    }
    free( schema );
}
35+
36+
// PyCapsule destructor for capsules named "arrow_array".
// Invokes the ArrowArray's release callback if it has not been released yet
// (release != NULL) and then frees the struct itself with free().
// If the pointer cannot be retrieved from the capsule, the resulting Python
// error is reported as unraisable and nothing is dereferenced.
void ReleaseArrowArrayPyCapsule( PyObject* capsule ) {
    struct ArrowArray* array =
        static_cast<struct ArrowArray*>( PyCapsule_GetPointer( capsule, "arrow_array" ) );
    if ( array == NULL ) {
        // PyCapsule_GetPointer set an exception; a destructor cannot propagate it.
        PyErr_WriteUnraisable( capsule );
        return;
    }
    if ( array->release != NULL ) {
        array->release( array );
    }
    free( array );
}
44+
2445
namespace csp::cppnodes
2546
{
2647
DECLARE_CPPNODE( exprtk_impl )
@@ -403,9 +424,9 @@ DECLARE_CPPNODE( record_batches_to_struct )
403424
START()
404425
{
405426
// Create Adapters for Schema
406-
PyObject* capsule = csp::python::toPythonBorrowed(schema_ptr);
407-
struct ArrowSchema* c_schema = reinterpret_cast<struct ArrowSchema*>( PyCapsule_GetPointer(capsule, "arrow_schema") );
408-
auto result = arrow::ImportSchema(c_schema);
427+
PyObject* capsule = csp::python::toPythonBorrowed( schema_ptr );
428+
struct ArrowSchema* c_schema = reinterpret_cast<struct ArrowSchema*>( PyCapsule_GetPointer( capsule, "arrow_schema") );
429+
auto result = arrow::ImportSchema( c_schema );
409430
if( !result.ok() )
410431
CSP_THROW( ValueError, "Failed to load the arrow schema: " << result.status().ToString() );
411432
std::shared_ptr<arrow::Schema> schema = result.ValueUnsafe();
@@ -414,15 +435,15 @@ DECLARE_CPPNODE( record_batches_to_struct )
414435
for( auto it = field_map -> begin(); it != field_map -> end(); ++it )
415436
{
416437
if( schema -> GetFieldByName( it.key() ) )
417-
columns.push_back(it.key());
438+
columns.push_back( it.key() );
418439
else
419440
CSP_THROW( ValueError, "column " << it.key() << " not found in schema" );
420441
}
421442
reader = std::make_shared<RecordBatchReader>( columns, schema );
422443
reader -> initialize();
423444

424445
CspTypePtr outType = std::make_shared<csp::CspStructType>( cls.value() );
425-
csp::adapters::utils::StructAdapterInfo key{ std::move(outType), std::move(field_map) };
446+
csp::adapters::utils::StructAdapterInfo key{ std::move( outType ), std::move( field_map ) };
426447
auto& struct_adapter = reader -> getStructAdapter( key );
427448
struct_adapter.addSubscriber( [this]( StructPtr * s )
428449
{
@@ -444,10 +465,10 @@ DECLARE_CPPNODE( record_batches_to_struct )
444465
PyObject* py_array = PyTuple_GET_ITEM( py_tuple, 1 );
445466
struct ArrowSchema* c_schema = reinterpret_cast<struct ArrowSchema*>( PyCapsule_GetPointer( py_schema, "arrow_schema" ) );
446467
struct ArrowArray* c_array = reinterpret_cast<struct ArrowArray*>( PyCapsule_GetPointer( py_array, "arrow_array" ) );
447-
auto result = arrow::ImportRecordBatch(c_array, c_schema);
468+
auto result = arrow::ImportRecordBatch( c_array, c_schema );
448469
if( !result.ok() )
449470
CSP_THROW( ValueError, "Failed to load record batches through PyCapsule C Data interface: " << result.status().ToString() );
450-
batches.emplace_back(result.ValueUnsafe());
471+
batches.emplace_back( result.ValueUnsafe() );
451472
}
452473
std::vector<StructPtr> & out = unnamed_output().reserveSpace<std::vector<StructPtr>>();
453474
out.clear();
@@ -460,6 +481,120 @@ DECLARE_CPPNODE( record_batches_to_struct )
460481

461482
EXPORT_CPPNODE( record_batches_to_struct );
462483

484+
// Node that converts each ticked list of structs into a list of Arrow record
// batches, handed out as ( "arrow_schema" capsule, "arrow_array" capsule ) tuples.
//
// Inputs:
//   schema_ptr - declared scalar input; not read by this node's code.
//   cls        - struct meta used to build the CspStructType given to the handler.
//   properties - dictionary; its "field_map" entry is passed to the handler.
//   chunk_size - maximum number of rows per emitted record batch; must be > 0
//                and fit in 32 bits.
//   data       - ts of array-of-struct values.
// Output: ts array with one tuple per record batch produced for the tick.
DECLARE_CPPNODE( struct_to_record_batches )
{
    SCALAR_INPUT( DialectGenericType, schema_ptr );
    SCALAR_INPUT( StructMetaPtr, cls );
    SCALAR_INPUT( DictionaryPtr, properties );
    SCALAR_INPUT( int64_t, chunk_size );
    TS_INPUT( Generic, data );

    TS_OUTPUT( Generic );

    using StructParquetOutputHandler = csp::adapters::parquet::StructParquetOutputHandler;
    using ParquetWriter = csp::adapters::parquet::ParquetWriter;

    // ParquetWriter built with the default (manager-less) constructor; the chunk
    // size comes from the node argument instead of an adapter manager.
    class MyParquetWriter : public ParquetWriter
    {
    public:
        explicit MyParquetWriter( int64_t chunk_size ): ParquetWriter()
        {
            if( chunk_size <= 0 )
                CSP_THROW( ValueError, "Chunk size should be > 0" );

            // getChunkSize() returns 32 bits: reject values that would be truncated.
            if( static_cast<int64_t>( static_cast<std::uint32_t>( chunk_size ) ) != chunk_size )
                CSP_THROW( ValueError, "Chunk size " << chunk_size << " does not fit in 32 bits" );

            m_chunkSize = static_cast<std::uint32_t>( chunk_size );
        }
        std::uint32_t getChunkSize() const override{ return m_chunkSize; }
    private:
        std::uint32_t m_chunkSize = 0;
    };

    std::shared_ptr<StructParquetOutputHandler> m_handler;
    CspTypePtr m_cspType;
    std::shared_ptr<MyParquetWriter> m_writer;
    std::shared_ptr<arrow::Schema> m_schema;

    INIT_CPPNODE( struct_to_record_batches )
    {
        // Input must be a ts of array-of-struct, output must be a ts of array.
        auto & input_def = tsinputDef( "data" );
        if( input_def.type -> type() != CspType::Type::ARRAY )
            CSP_THROW( TypeError, "struct_to_record_batches expected ts array type, got " << input_def.type -> type() );

        auto * aType = static_cast<const CspArrayType *>( input_def.type.get() );
        CspTypePtr elemType = aType -> elemType();
        if( elemType -> type() != CspType::Type::STRUCT )
            CSP_THROW( TypeError, "struct_to_record_batches expected ts array of structs type, got " << elemType -> type() );

        auto & output_def = tsoutputDef( "" );
        if( output_def.type -> type() != CspType::Type::ARRAY )
            CSP_THROW( TypeError, "struct_to_record_batches expected ts array type, got " << output_def.type -> type() );
    }

    START()
    {
        // Create the struct handler, then derive the Arrow schema from its column builders
        auto field_map = properties.value() -> get<DictionaryPtr>( "field_map" );
        m_writer = std::make_shared<MyParquetWriter>( chunk_size.value() );
        m_cspType = std::make_shared<csp::CspStructType>( cls.value() );
        m_handler = std::make_shared<StructParquetOutputHandler>( engine(), *m_writer, m_cspType, field_map );

        std::vector<std::shared_ptr<arrow::Field>> arrowFields;
        arrowFields.reserve( m_handler -> getNumColumns() );
        for( unsigned i = 0; i < m_handler -> getNumColumns(); i++ )
        {
            arrowFields.push_back( arrow::field( m_handler -> getColumnArrayBuilder( i ) -> getColumnName(),
                                                 m_handler -> getColumnArrayBuilder( i ) -> getDataType() ) );
        }
        m_schema = arrow::schema( arrowFields );
    }

    // Builds the arrays accumulated in the handler's column builders, wraps them
    // in a record batch of `num_rows` rows and exports it as a Python tuple of
    // ( "arrow_schema" capsule, "arrow_array" capsule ).
    // Throws ValueError if allocation, the Arrow export, or creation of the
    // Python objects fails; structs not yet owned by a capsule are released here.
    DialectGenericType getData( const std::shared_ptr<StructParquetOutputHandler> & handler, int64_t num_rows )
    {
        std::vector<std::shared_ptr<arrow::Array>> columns;
        columns.reserve( handler -> getNumColumns() );
        for( unsigned i = 0; i < handler -> getNumColumns(); i++ )
        {
            columns.push_back( handler -> getColumnArrayBuilder( i ) -> buildArray() );
        }
        auto rb_ptr = arrow::RecordBatch::Make( m_schema, num_rows, columns );

        // The structs are malloc'ed because the capsule destructors free() them.
        struct ArrowSchema* rb_schema = static_cast<struct ArrowSchema*>( malloc( sizeof( struct ArrowSchema ) ) );
        struct ArrowArray* rb_array = static_cast<struct ArrowArray*>( malloc( sizeof( struct ArrowArray ) ) );
        if( rb_schema == NULL || rb_array == NULL )
        {
            free( rb_schema );
            free( rb_array );
            CSP_THROW( ValueError, "Failed to allocate Arrow C Data Interface structures" );
        }

        arrow::Status st = arrow::ExportRecordBatch( *rb_ptr, rb_array, rb_schema );
        if( !st.ok() )
        {
            // The export failed, so the structs are not treated as exported:
            // free the raw memory without calling any release callback.
            free( rb_schema );
            free( rb_array );
            CSP_THROW( ValueError, "Failed to export record batch through PyCapsule C Data interface: " << st.ToString() );
        }

        // Releases and frees an exported struct that no capsule owns yet.
        auto destroyExported = []( auto * exported )
        {
            if( exported -> release != NULL )
                exported -> release( exported );
            free( exported );
        };

        auto py_schema = csp::python::PyObjectPtr::own( PyCapsule_New( rb_schema, "arrow_schema", ReleaseArrowSchemaPyCapsule ) );
        if( !py_schema.get() )
        {
            destroyExported( rb_schema );
            destroyExported( rb_array );
            CSP_THROW( ValueError, "Failed to create PyCapsule for arrow schema" );
        }

        auto py_array = csp::python::PyObjectPtr::own( PyCapsule_New( rb_array, "arrow_array", ReleaseArrowArrayPyCapsule ) );
        if( !py_array.get() )
        {
            // rb_schema is already owned by py_schema and is released by its destructor.
            destroyExported( rb_array );
            CSP_THROW( ValueError, "Failed to create PyCapsule for arrow array" );
        }

        auto py_tuple = csp::python::PyObjectPtr::own( PyTuple_Pack( 2, py_schema.get(), py_array.get() ) );
        if( !py_tuple.get() )
            CSP_THROW( ValueError, "Failed to create ( schema, array ) tuple for record batch" );

        return csp::python::fromPython<DialectGenericType>( py_tuple.get() );
    }

    INVOKE()
    {
        if( csp.ticked( data ) )
        {
            std::vector<DialectGenericType> & out = unnamed_output().reserveSpace<std::vector<DialectGenericType>>();
            out.clear();
            auto & structs = data.lastValue<std::vector<StructPtr>>();
            const std::uint32_t max_chunk_size = m_writer -> getChunkSize();
            uint32_t cur_chunk_size = 0;
            for( auto& st: structs )
            {
                // Append one row: write the struct's fields, then close the row on every column
                m_handler -> writeValueFromArgs( st );
                for( unsigned i = 0; i < m_handler -> getNumColumns(); i++ )
                {
                    m_handler -> getColumnArrayBuilder( i ) -> handleRowFinished();
                }
                // Emit a batch as soon as a full chunk has been accumulated
                if( ++cur_chunk_size >= max_chunk_size )
                {
                    out.emplace_back( getData( m_handler, cur_chunk_size ) );
                    cur_chunk_size = 0;
                }
            }
            // Emit the remaining rows as a final, smaller batch
            if( cur_chunk_size > 0 )
            {
                out.emplace_back( getData( m_handler, cur_chunk_size ) );
            }
        }
    }
};
595+
596+
EXPORT_CPPNODE( struct_to_record_batches );
597+
463598
}
464599

465600
// Base nodes
@@ -486,6 +621,7 @@ REGISTER_CPPNODE( csp::cppnodes, struct_collectts );
486621

487622
REGISTER_CPPNODE( csp::cppnodes, exprtk_impl );
488623
REGISTER_CPPNODE( csp::cppnodes, record_batches_to_struct );
624+
REGISTER_CPPNODE( csp::cppnodes, struct_to_record_batches );
489625

490626
static PyModuleDef _cspbaselibimpl_module = {
491627
PyModuleDef_HEAD_INIT,

0 commit comments

Comments
 (0)