Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 133 additions & 49 deletions cpp/csp/python/adapters/ArrowInputAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class RecordBatchIterator
{
public:
RecordBatchIterator() {}
RecordBatchIterator( PyObjectPtr iter, std::shared_ptr<::arrow::Schema> schema ): m_iter( std::move( iter ) ), m_schema( schema )

RecordBatchIterator( PyObjectPtr iter, std::shared_ptr<::arrow::Schema> schema ): m_iter( std::move( iter ) ), m_schema( std::move( schema ) )
{
}

Expand All @@ -36,7 +37,7 @@ class RecordBatchIterator

if( py_tuple.get() == NULL )
{
// No more data in the input steam
// No more data in the input stream
return nullptr;
}

Expand All @@ -47,6 +48,20 @@ class RecordBatchIterator
if( num_elems != 2 )
CSP_THROW( csp::TypeError, "Invalid arrow data, expected tuple (using the PyCapsule C interface) with 2 elements got " << num_elems );

// Extract schema from first batch if not provided upfront
if( !m_schema )
{
PyObject * py_schema = PyTuple_GetItem( py_tuple.get(), 0 );
if( !PyCapsule_IsValid( py_schema, "arrow_schema" ) )
CSP_THROW( csp::TypeError, "Invalid arrow data, expected schema capsule from the PyCapsule C interface" );

ArrowSchema * c_schema = reinterpret_cast<ArrowSchema*>( PyCapsule_GetPointer( py_schema, "arrow_schema" ) );
auto schema_result = ::arrow::ImportSchema( c_schema );
if( !schema_result.ok() )
CSP_THROW( ValueError, "Failed to load schema through PyCapsule C Data interface: " << schema_result.status().ToString() );
m_schema = std::move( schema_result.ValueUnsafe() );
}

// Extract the record batch
PyObject * py_array = PyTuple_GetItem( py_tuple.get(), 1 );
if( !PyCapsule_IsValid( py_array, "arrow_array" ) )
Expand All @@ -60,20 +75,22 @@ class RecordBatchIterator
return result.ValueUnsafe();
}

std::shared_ptr<::arrow::Schema> schema() const { return m_schema; }

private:
PyObjectPtr m_iter;
std::shared_ptr<::arrow::Schema> m_schema;
};

void ReleaseArrowSchemaPyCapsule( PyObject * capsule )
inline void ReleaseArrowSchemaPyCapsule( PyObject * capsule )
{
    // Capsule destructor for "arrow_schema" capsules: invoke the C Data
    // Interface release callback (unless the consumer already moved the
    // schema out, in which case release is null) and free the malloc'd shell.
    auto * exported = reinterpret_cast<ArrowSchema*>( PyCapsule_GetPointer( capsule, "arrow_schema" ) );
    if( exported -> release != nullptr )
        exported -> release( exported );
    free( exported );
}

void ReleaseArrowArrayPyCapsule( PyObject * capsule )
inline void ReleaseArrowArrayPyCapsule( PyObject * capsule )
{
ArrowArray * array = reinterpret_cast<ArrowArray*>( PyCapsule_GetPointer( capsule, "arrow_array" ) );
if ( array -> release != NULL )
Expand All @@ -88,36 +105,33 @@ class RecordBatchInputAdapter: public PullInputAdapter<std::vector<DialectGeneri
: PullInputAdapter<std::vector<DialectGenericType>>( engine, type, PushMode::LAST_VALUE ),
m_tsColName( tsColName ),
m_expectSmallBatches( expectSmallBatches ),
m_finished( false )
m_finished( false ),
m_initialized( false ),
m_multiplier( 0 ),
m_arrayLastTime( 0 ),
m_numRows( 0 ),
m_startTime( 0 ),
m_endTime( 0 ),
m_curBatchIdx( 0 )
{
// Extract the arrow schema
ArrowSchema * c_schema = reinterpret_cast<ArrowSchema*>( PyCapsule_GetPointer( pySchema.get(), "arrow_schema" ) );
auto result = ::arrow::ImportSchema( c_schema );
if( !result.ok() )
CSP_THROW( ValueError, "Failed to load schema for record batches through the PyCapsule C Data interface: " << result.status().ToString() );
m_schema = std::move( result.ValueUnsafe() );
// Check if schema is provided or should be deferred to first batch
if( pySchema.get() != Py_None )
{
// Extract the arrow schema upfront
ArrowSchema * c_schema = reinterpret_cast<ArrowSchema*>( PyCapsule_GetPointer( pySchema.get(), "arrow_schema" ) );
auto result = ::arrow::ImportSchema( c_schema );
if( !result.ok() )
CSP_THROW( ValueError, "Failed to load schema for record batches through the PyCapsule C Data interface: " << result.status().ToString() );

auto tsField = m_schema -> GetFieldByName( m_tsColName );
auto timestampType = std::static_pointer_cast<::arrow::TimestampType>( tsField -> type() );
switch( timestampType -> unit() )
m_source = RecordBatchIterator( source, std::move( result.ValueUnsafe() ) );
initializeTimestampMultiplier();
m_initialized = true;
}
else
{
case ::arrow::TimeUnit::SECOND:
m_multiplier = csp::NANOS_PER_SECOND;
break;
case ::arrow::TimeUnit::MILLI:
m_multiplier = csp::NANOS_PER_MILLISECOND;
break;
case ::arrow::TimeUnit::MICRO:
m_multiplier = csp::NANOS_PER_MICROSECOND;
break;
case ::arrow::TimeUnit::NANO:
m_multiplier = 1;
break;
default:
CSP_THROW( ValueError, "Unsupported unit type for arrow timestamp column" );
// Schema will be extracted lazily from first batch
m_source = RecordBatchIterator( source, nullptr );
}

m_source = RecordBatchIterator( source, m_schema );
}

int64_t findFirstMatchingIndex()
Expand Down Expand Up @@ -174,6 +188,16 @@ class RecordBatchInputAdapter: public PullInputAdapter<std::vector<DialectGeneri
while( ( rb = m_source.next() ) && ( rb -> num_rows() == 0 ) ) {}
if( rb )
{
// Lazy schema initialization on first batch (when schema was not provided upfront)
if( !m_initialized ) [[unlikely]]
{
if( !m_source.schema() )
CSP_THROW( ValueError, "Failed to extract schema from first record batch" );
initializeTimestampMultiplier();
computeTimeRange();
m_initialized = true;
}

auto array = rb -> GetColumnByName( m_tsColName );
if( !array )
CSP_THROW( ValueError, "Failed to get timestamp column " << m_tsColName << " from record batch " << rb -> ToString() );
Expand All @@ -189,24 +213,10 @@ class RecordBatchInputAdapter: public PullInputAdapter<std::vector<DialectGeneri

void start( DateTime start, DateTime end ) override
{
// start and end as multiples of the unit in timestamp column
auto start_nanos = start.asNanoseconds();
m_startTime = ( start_nanos % m_multiplier == 0 ) ? start_nanos / m_multiplier : start_nanos / m_multiplier + 1;
m_endTime = end.asNanoseconds() / m_multiplier;
m_rawStartTime = start;
m_rawEndTime = end;

// Find the starting index where time >= start
while( !m_finished )
{
updateStateFromNextRecordBatch();
if( !m_curRecordBatch )
{
m_finished = true;
break;
}
m_curBatchIdx = findFirstMatchingIndex();
if( m_curBatchIdx < m_numRows )
break;
}
initializeStartPosition();
PullInputAdapter<std::vector<DialectGenericType>>::start( start, end );
}

Expand All @@ -215,6 +225,12 @@ class RecordBatchInputAdapter: public PullInputAdapter<std::vector<DialectGeneri
ArrowSchema* rb_schema = ( ArrowSchema* )malloc( sizeof( ArrowSchema ) );
ArrowArray* rb_array = ( ArrowArray* )malloc( sizeof( ArrowArray ) );
::arrow::Status st = ::arrow::ExportRecordBatch( *rb, rb_array, rb_schema );
if( !st.ok() )
{
free( rb_array );
free( rb_schema );
CSP_THROW( ValueError, "Failed to export record batch through C Data interface: " << st.ToString() );
}
auto py_schema = csp::python::PyObjectPtr::own( PyCapsule_New( rb_schema, "arrow_schema", ReleaseArrowSchemaPyCapsule ) );
auto py_array = csp::python::PyObjectPtr::own( PyCapsule_New( rb_array, "arrow_array", ReleaseArrowArrayPyCapsule ) );
auto py_tuple = csp::python::PyObjectPtr::own( PyTuple_Pack( 2, py_schema.get(), py_array.get() ) );
Expand Down Expand Up @@ -269,17 +285,85 @@ class RecordBatchInputAdapter: public PullInputAdapter<std::vector<DialectGeneri
}

private:
void initializeTimestampMultiplier()
{
    // Resolve how many nanoseconds one unit of the timestamp column
    // represents, so engine DateTimes (nanos) can be compared against the
    // raw column values.  Validates that the column exists and is a
    // timestamp before reading its unit.
    auto field = m_source.schema() -> GetFieldByName( m_tsColName );
    if( !field )
        CSP_THROW( ValueError, "Timestamp column '" << m_tsColName << "' not found in schema" );

    if( field -> type() -> id() != ::arrow::Type::TIMESTAMP )
        CSP_THROW( ValueError, "Column '" << m_tsColName << "' is not a valid timestamp column" );

    auto unit = std::static_pointer_cast<::arrow::TimestampType>( field -> type() ) -> unit();
    if( unit == ::arrow::TimeUnit::SECOND )
        m_multiplier = csp::NANOS_PER_SECOND;
    else if( unit == ::arrow::TimeUnit::MILLI )
        m_multiplier = csp::NANOS_PER_MILLISECOND;
    else if( unit == ::arrow::TimeUnit::MICRO )
        m_multiplier = csp::NANOS_PER_MICROSECOND;
    else if( unit == ::arrow::TimeUnit::NANO )
        m_multiplier = 1;
    else
        CSP_THROW( ValueError, "Unsupported unit type for arrow timestamp column" );
}

void computeTimeRange()
{
    // Translate the engine start/end DateTimes (nanoseconds) into multiples
    // of the timestamp column's unit.  The start boundary is rounded up to
    // the next whole unit when it does not fall exactly on one; the end
    // boundary truncates.
    const auto startNanos = m_rawStartTime.asNanoseconds();
    if( startNanos % m_multiplier == 0 )
        m_startTime = startNanos / m_multiplier;
    else
        m_startTime = startNanos / m_multiplier + 1;
    m_endTime = m_rawEndTime.asNanoseconds() / m_multiplier;
}

void initializeStartPosition()
{
// Advance through the source until the first row whose timestamp is within
// the engine's time window, leaving m_curRecordBatch / m_curBatchIdx
// positioned there, or m_finished = true if the stream runs out first.
// If schema was provided upfront, compute time range now
// Otherwise, time range will be computed lazily in updateStateFromNextRecordBatch()
if( m_initialized )
computeTimeRange();

// Find the starting index where time >= start
while( !m_finished )
{
updateStateFromNextRecordBatch();
if( !m_curRecordBatch )
{
// Source exhausted before any row reached the start time
m_finished = true;
break;
}
m_curBatchIdx = findFirstMatchingIndex();
if( m_curBatchIdx < m_numRows )
break;
// Entire batch precedes the start time; drop it and pull the next one
m_curRecordBatch = nullptr;
}
}

std::string m_tsColName;
RecordBatchIterator m_source;

int m_expectSmallBatches;
bool m_finished;
std::shared_ptr<::arrow::Schema> m_schema;
bool m_initialized; // Whether schema and timestamp multiplier have been resolved
int64_t m_multiplier;
std::shared_ptr<::arrow::RecordBatch> m_curRecordBatch;
std::shared_ptr<::arrow::TimestampArray> m_tsArray;
::arrow::TimestampArray::IteratorType m_endIt;
int64_t m_arrayLastTime;
int64_t m_multiplier, m_numRows, m_startTime, m_endTime, m_curBatchIdx;
int64_t m_numRows;
int64_t m_startTime;
int64_t m_endTime;
int64_t m_curBatchIdx;
DateTime m_rawStartTime;
DateTime m_rawEndTime;
};

};
Expand Down
23 changes: 15 additions & 8 deletions csp/adapters/arrow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, List, Tuple
from typing import Iterable, List, Optional, Tuple

import pyarrow as pa
import pyarrow.parquet as pq
Expand Down Expand Up @@ -52,25 +52,32 @@ def __arrow_c_array__(self, requested_schema=None):

@csp.graph
def RecordBatchPullInputAdapter(
ts_col_name: str, source: Iterable[pa.RecordBatch], schema: pa.Schema, expect_small_batches: bool = False
ts_col_name: str,
source: Iterable[pa.RecordBatch],
schema: Optional[pa.Schema] = None,
expect_small_batches: bool = False,
) -> csp.ts[[pa.RecordBatch]]:
"""Stream record batches from an iterator/generator into csp

Args:
ts_col_name: Name of the timestamp column containing timestamps in ascending order
source: Iterator/generator of record batches
schema: The schema of the record batches
schema: Schema of the record batches. If None, extracted from first batch at runtime.
expect_small_batches: Optional flag to optimize performance for scenarios where there are few rows (<10) per timestamp

NOTE: The ascending order of the timestamp column must be enforced by the caller
"""
# Safety checks
ts_col = schema.field(ts_col_name)
if not pa.types.is_timestamp(ts_col.type):
raise ValueError(f"{ts_col_name} is not a valid timestamp column in the schema")
# Validate only if schema provided upfront
if schema is not None:
ts_col = schema.field(ts_col_name)
if not pa.types.is_timestamp(ts_col.type):
raise ValueError(f"{ts_col_name} is not a valid timestamp column in the schema")
c_schema = schema.__arrow_c_schema__()
else:
c_schema = None # C++ will extract from first batch

c_source = map(lambda rb: rb.__arrow_c_array__(), source)
c_data = CRecordBatchPullInputAdapter(ts_col_name, c_source, schema.__arrow_c_schema__(), expect_small_batches)
c_data = CRecordBatchPullInputAdapter(ts_col_name, c_source, c_schema, expect_small_batches)
return csp.apply(
c_data, lambda c_tups: [pa.record_batch(_RecordBatchCSource(c_tup)) for c_tup in c_tups], List[pa.RecordBatch]
)
Expand Down
Loading
Loading