Skip to content

Commit 7ebce5c

Browse files
committed
Add node to convert record batches into csp.Structs
Signed-off-by: Arham Chopra <[email protected]>
1 parent f70df59 commit 7ebce5c

File tree

6 files changed

+238
-104
lines changed

6 files changed

+238
-104
lines changed

cpp/csp/adapters/parquet/ParquetReader.cpp

+15-4
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,17 @@ void SingleTableParquetReader::setColumnAdaptersFromCurrentTable()
192192
columnAdapter = createColumnAdapter( *this, *field, getCurFileOrTableName(), &getStructColumnMeta() );
193193
auto &fieldInfo = fieldsInfo[ index ];
194194

195-
for( std::size_t i = 0; i < fieldInfo.m_width; ++i )
195+
if( isArrowIPC() )
196+
{
197+
// Needed for all memory tables
198+
m_neededColumnIndices.push_back( index );
199+
}
200+
else
196201
{
197-
m_neededColumnIndices.push_back( fieldInfo.m_startColumnIndex + i );
202+
for( std::size_t i = 0; i < fieldInfo.m_width; ++i )
203+
{
204+
m_neededColumnIndices.push_back( fieldInfo.m_startColumnIndex + i );
205+
}
198206
}
199207
}
200208
else
@@ -382,10 +390,13 @@ void SingleFileParquetReader::clear()
382390
}
383391

384392
InMemoryTableParquetReader::InMemoryTableParquetReader( GeneratorPtr generatorPtr, std::vector<std::string> columns,
385-
bool allowMissingColumns, std::optional<std::string> symbolColumnName )
393+
bool allowMissingColumns, std::optional<std::string> symbolColumnName, bool call_init )
386394
: SingleTableParquetReader( columns, true, allowMissingColumns, symbolColumnName ), m_generatorPtr( generatorPtr )
387395
{
388-
init();
396+
if( call_init )
397+
{
398+
init();
399+
}
389400
}
390401

391402
bool InMemoryTableParquetReader::openNextFile()

cpp/csp/adapters/parquet/ParquetReader.h

+12-5
Original file line numberDiff line numberDiff line change
@@ -424,21 +424,28 @@ class SingleFileParquetReader final : public SingleTableParquetReader
424424
bool m_allowMissingFiles;
425425
};
426426

427-
class InMemoryTableParquetReader final : public SingleTableParquetReader
427+
class InMemoryTableParquetReader : public SingleTableParquetReader
428428
{
429429
public:
430430
using GeneratorPtr = csp::Generator<std::shared_ptr<arrow::Table>, csp::DateTime, csp::DateTime>::Ptr;
431431

432432
InMemoryTableParquetReader( GeneratorPtr generatorPtr, std::vector<std::string> columns,
433433
bool allowMissingColumns,
434-
std::optional<std::string> symbolColumnName = {} );
434+
std::optional<std::string> symbolColumnName = {},
435+
bool call_init = true);
435436
std::string getCurFileOrTableName() const override{ return "IN_MEMORY_TABLE"; }
436437

437438
protected:
438-
bool openNextFile() override;
439-
bool readNextRowGroup() override;
439+
virtual bool openNextFile() override;
440+
virtual bool readNextRowGroup() override;
441+
void setTable( std::shared_ptr<arrow::Table> table )
442+
{
443+
m_fullTable = table;
444+
m_nextChunkIndex = 0;
445+
m_curTable = nullptr;
446+
}
440447

441-
void clear() override;
448+
virtual void clear() override;
442449

443450
private:
444451
GeneratorPtr m_generatorPtr;

cpp/csp/python/CMakeLists.txt

+27-1
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,35 @@ target_compile_definitions(cspimpl PUBLIC NPY_NO_DEPRECATED_API=NPY_1_7_API_VERS
9090
target_compile_definitions(cspimpl PRIVATE CSPIMPL_EXPORTS=1)
9191

9292

93+
find_package(Arrow REQUIRED)
94+
find_package(Parquet REQUIRED)
95+
96+
if(WIN32)
97+
if(CSP_USE_VCPKG)
98+
set(ARROW_PACKAGES_TO_LINK Arrow::arrow_static Parquet::parquet_static )
99+
target_compile_definitions(csp_parquet_adapter PUBLIC ARROW_STATIC)
100+
target_compile_definitions(csp_parquet_adapter PUBLIC PARQUET_STATIC)
101+
else()
102+
# use dynamic variants
103+
# Until we manage to get the fix for ws3_32.dll in arrow-16 into conda, manually fix the error here
104+
get_target_property(LINK_LIBS Arrow::arrow_shared INTERFACE_LINK_LIBRARIES)
105+
string(REPLACE "ws2_32.dll" "ws2_32" FIXED_LINK_LIBS "${LINK_LIBS}")
106+
set_target_properties(Arrow::arrow_shared PROPERTIES INTERFACE_LINK_LIBRARIES "${FIXED_LINK_LIBS}")
107+
set(ARROW_PACKAGES_TO_LINK parquet_shared arrow_shared)
108+
endif()
109+
else()
110+
if(CSP_USE_VCPKG)
111+
# use static variants
112+
set(ARROW_PACKAGES_TO_LINK parquet_static arrow_static)
113+
else()
114+
# use dynamic variants
115+
set(ARROW_PACKAGES_TO_LINK parquet arrow)
116+
endif()
117+
endif()
118+
93119
## Baselib c++ module
94120
add_library(cspbaselibimpl SHARED cspbaselibimpl.cpp)
95-
target_link_libraries(cspbaselibimpl cspimpl baselibimpl)
121+
target_link_libraries(cspbaselibimpl cspimpl baselibimpl csp_parquet_adapter ${ARROW_PACKAGES_TO_LINK})
96122

97123
# Include exprtk include directory for exprtk node
98124
target_include_directories(cspbaselibimpl PRIVATE ${EXPRTK_INCLUDE_DIRS})

cpp/csp/python/cspbaselibimpl.cpp

+140
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@
55
#include <exprtk.hpp>
66
#include <numpy/ndarrayobject.h>
77

8+
#include <arrow/type.h>
9+
#include <arrow/table.h>
10+
#include <arrow/c/abi.h>
11+
#include <arrow/c/bridge.h>
12+
13+
#include <csp/adapters/parquet/ParquetReader.h>
14+
#include <csp/adapters/utils/StructAdapterInfo.h>
15+
#include <csp/adapters/utils/ValueDispatcher.h>
16+
817
static void * init_nparray()
918
{
1019
csp::python::AcquireGIL gil;
@@ -325,6 +334,136 @@ DECLARE_CPPNODE( exprtk_impl )
325334

326335
EXPORT_CPPNODE( exprtk_impl );
327336

337+
// Node that converts each ticked list of Arrow record batches (passed from Python as
// ( schema capsule, array capsule ) tuples of the Arrow C data interface) into a list of structs
// of type `cls`, using properties[ "field_map" ] to map column names onto struct fields.
DECLARE_CPPNODE( record_batches_to_struct )
{
    using InMemoryTableParquetReader = csp::adapters::parquet::InMemoryTableParquetReader;

    // Reader that is fed tables built from the incoming record batches rather than pulling
    // them from a generator.
    class MyTableReader : public InMemoryTableParquetReader
    {
    public:
        // The base class is constructed with a null generator and call_init == false; the schema
        // is assigned here so that initialize() can build the column adapters from it.
        MyTableReader( std::vector<std::string> columns, std::shared_ptr<arrow::Schema> schema ):
            InMemoryTableParquetReader( nullptr, columns, false, {}, false )
        {
            m_schema = schema;
        }

        std::string getCurFileOrTableName() const override{ return "IN_RECORD_BATCH"; }

        // Builds the column adapters; must be called once before parseBatches().
        void initialize() { setColumnAdaptersFromCurrentTable(); }

        // Combines the batches into a single table, then reads it row by row, dispatching every
        // row through the struct adapters (which in turn invoke their subscribers).
        // Throws NotImplemented if the table cannot be built or its row group cannot be read.
        void parseBatches( const std::vector<std::shared_ptr<arrow::RecordBatch>> & record_batches )
        {
            // arrow::Table::FromRecordBatches cannot infer a schema from an empty list and always
            // fails for it; an empty input simply means there are no rows to dispatch.
            if( record_batches.empty() )
                return;

            // TODO: Check if the schema has not changed
            auto table_result = arrow::Table::FromRecordBatches( record_batches );
            if( !table_result.ok() )
                CSP_THROW( NotImplemented, "Unable to make table from record batches" );

            setTable( table_result.ValueUnsafe() );

            if( !readNextRowGroup() )
                CSP_THROW( NotImplemented, "Unable to read row group from table" );

            while( readNextRow() )
            {
                for( auto & adapter : getStructAdapters() )
                {
                    adapter -> dispatchValue( nullptr );
                }
            }
        }

        void stop()
        {
            InMemoryTableParquetReader::clear();
        }

    protected:
        // There is never a "next file": tables are pushed in through parseBatches().
        bool openNextFile() override { return false; }
        void clear() override { setTable( nullptr ); }
    };

    SCALAR_INPUT( DialectGenericType, schema_ptr );
    SCALAR_INPUT( StructMetaPtr, cls );
    SCALAR_INPUT( DictionaryPtr, properties );
    TS_INPUT( Generic, data );

    TS_OUTPUT( Generic );

    std::shared_ptr<MyTableReader> reader;
    CspTypePtr outType;
    // Destination for structs produced by the adapter subscriber. Only non-null while INVOKE()
    // is running parseBatches(); initialised so it never holds an indeterminate value.
    std::vector<StructPtr>* m_structsVecPtr = nullptr;

    using StructAdapterInfo = csp::adapters::utils::StructAdapterInfo;
    using ValueDispatcher = csp::adapters::utils::ValueDispatcher<StructPtr &>;

    INIT_CPPNODE( record_batches_to_struct )
    {
        // Input must be a ts of array-of-DIALECT_GENERIC, output a ts of array.
        auto & input_def = tsinputDef( "data" );
        if( input_def.type -> type() != CspType::Type::ARRAY )
            CSP_THROW( TypeError, "record_batches_to_struct expected ts array type, got " << input_def.type -> type() );

        auto * aType = static_cast<const CspArrayType *>( input_def.type.get() );
        CspTypePtr elemType = aType -> elemType();
        if( elemType -> type() != CspType::Type::DIALECT_GENERIC )
            CSP_THROW( TypeError, "record_batches_to_struct expected ts array of DIALECT_GENERIC type, got " << elemType -> type() );

        auto & output_def = tsoutputDef( "" );
        if( output_def.type -> type() != CspType::Type::ARRAY )
            CSP_THROW( NotImplemented, "record_batches_to_struct expected ts array type, got " << output_def.type -> type() );
    }

    START()
    {
        // Create Adapters for Schema
        PyObject* capsule = csp::python::toPythonBorrowed(schema_ptr);
        // PyCapsule_GetPointer returns NULL (with a Python error set) when the object is not a
        // capsule with the expected name; importing from a null pointer would crash.
        auto * c_schema = static_cast<struct ArrowSchema*>( PyCapsule_GetPointer( capsule, "arrow_schema" ) );
        if( !c_schema )
        {
            PyErr_Clear();
            CSP_THROW( TypeError, "record_batches_to_struct expected schema_ptr to be a PyCapsule named \"arrow_schema\"" );
        }

        auto result = arrow::ImportSchema( c_schema );
        if( !result.ok() )
            CSP_THROW( NotImplemented, "Unable to import schema" );
        std::shared_ptr<arrow::Schema> schema = result.ValueUnsafe();

        // The columns to read are exactly the keys of the field map.
        std::vector<std::string> columns;
        auto field_map = properties.value() -> get<DictionaryPtr>( "field_map" );
        for( auto it = field_map -> begin(); it != field_map -> end(); ++it )
        {
            // TODO: Check if the column exists in the table
            columns.push_back( it.key() );
        }
        reader = std::make_shared<MyTableReader>( columns, schema );
        reader -> initialize();

        outType = std::make_shared<csp::CspStructType>( cls.value() );
        StructAdapterInfo key{ outType, field_map };
        auto & struct_adapter = reader -> getStructAdapter( key );
        // Every struct dispatched by the reader is appended to the vector INVOKE() is filling.
        struct_adapter.addSubscriber( [this]( StructPtr * s )
            {
                if( s ) this -> m_structsVecPtr -> push_back( *s );
                else CSP_THROW( NotImplemented, "StructPtr was null" );
            }, {} );
    }

    INVOKE()
    {
        if( csp.ticked( data ) )
        {
            auto & py_batches = data.lastValue<std::vector<DialectGenericType>>();
            std::vector<std::shared_ptr<arrow::RecordBatch>> batches;
            batches.reserve( py_batches.size() );
            for( auto & py_batch : py_batches )
            {
                PyObject* py_tuple = csp::python::toPythonBorrowed( py_batch );
                // PyTuple_GET_ITEM performs no checking, so validate the shape first.
                if( !PyTuple_Check( py_tuple ) || PyTuple_GET_SIZE( py_tuple ) < 2 )
                    CSP_THROW( TypeError, "record_batches_to_struct expected each record batch as a ( schema capsule, array capsule ) tuple" );

                PyObject* py_schema = PyTuple_GET_ITEM( py_tuple, 0 );
                PyObject* py_array = PyTuple_GET_ITEM( py_tuple, 1 );

                auto * c_schema = static_cast<struct ArrowSchema*>( PyCapsule_GetPointer( py_schema, "arrow_schema" ) );
                if( !c_schema )
                {
                    PyErr_Clear();
                    CSP_THROW( TypeError, "record_batches_to_struct expected a PyCapsule named \"arrow_schema\" as first tuple element" );
                }
                auto * c_array = static_cast<struct ArrowArray*>( PyCapsule_GetPointer( py_array, "arrow_array" ) );
                if( !c_array )
                {
                    PyErr_Clear();
                    CSP_THROW( TypeError, "record_batches_to_struct expected a PyCapsule named \"arrow_array\" as second tuple element" );
                }

                auto result = arrow::ImportRecordBatch( c_array, c_schema );
                if( !result.ok() )
                    CSP_THROW( NotImplemented, "Unable to import record batch from c interface" );
                batches.emplace_back( result.ValueUnsafe() );
            }

            std::vector<StructPtr> & out = unnamed_output().reserveSpace<std::vector<StructPtr>>();
            out.clear();

            // Point the subscriber at the output buffer only for the duration of the parse, and
            // make sure the pointer is reset even when parsing throws.
            m_structsVecPtr = &out;
            try
            {
                reader -> parseBatches( batches );
            }
            catch( ... )
            {
                m_structsVecPtr = nullptr;
                throw;
            }
            m_structsVecPtr = nullptr;
        }
    }
};
464+
465+
EXPORT_CPPNODE( record_batches_to_struct );
466+
328467
}
329468

330469
// Base nodes
@@ -350,6 +489,7 @@ REGISTER_CPPNODE( csp::cppnodes, struct_fromts );
350489
REGISTER_CPPNODE( csp::cppnodes, struct_collectts );
351490

352491
REGISTER_CPPNODE( csp::cppnodes, exprtk_impl );
492+
REGISTER_CPPNODE( csp::cppnodes, record_batches_to_struct );
353493

354494
static PyModuleDef _cspbaselibimpl_module = {
355495
PyModuleDef_HEAD_INIT,

0 commit comments

Comments
 (0)