diff --git a/cpp/csp/python/adapters/ArrowInputAdapter.h b/cpp/csp/python/adapters/ArrowInputAdapter.h index 9d3c8f059..f9a1b22f1 100644 --- a/cpp/csp/python/adapters/ArrowInputAdapter.h +++ b/cpp/csp/python/adapters/ArrowInputAdapter.h @@ -24,7 +24,8 @@ class RecordBatchIterator { public: RecordBatchIterator() {} - RecordBatchIterator( PyObjectPtr iter, std::shared_ptr<::arrow::Schema> schema ): m_iter( std::move( iter ) ), m_schema( schema ) + + RecordBatchIterator( PyObjectPtr iter, std::shared_ptr<::arrow::Schema> schema ): m_iter( std::move( iter ) ), m_schema( std::move( schema ) ) { } @@ -36,7 +37,7 @@ class RecordBatchIterator if( py_tuple.get() == NULL ) { - // No more data in the input steam + // No more data in the input stream return nullptr; } @@ -47,6 +48,20 @@ class RecordBatchIterator if( num_elems != 2 ) CSP_THROW( csp::TypeError, "Invalid arrow data, expected tuple (using the PyCapsule C interface) with 2 elements got " << num_elems ); + // Extract schema from first batch if not provided upfront + if( !m_schema ) + { + PyObject * py_schema = PyTuple_GetItem( py_tuple.get(), 0 ); + if( !PyCapsule_IsValid( py_schema, "arrow_schema" ) ) + CSP_THROW( csp::TypeError, "Invalid arrow data, expected schema capsule from the PyCapsule C interface" ); + + ArrowSchema * c_schema = reinterpret_cast<ArrowSchema *>( PyCapsule_GetPointer( py_schema, "arrow_schema" ) ); + auto schema_result = ::arrow::ImportSchema( c_schema ); + if( !schema_result.ok() ) + CSP_THROW( ValueError, "Failed to load schema through PyCapsule C Data interface: " << schema_result.status().ToString() ); + m_schema = std::move( schema_result.ValueUnsafe() ); + } + + // Extract the record batch PyObject * py_array = PyTuple_GetItem( py_tuple.get(), 1 ); if( !PyCapsule_IsValid( py_array, "arrow_array" ) ) @@ -60,12 +75,14 @@ class RecordBatchIterator return result.ValueUnsafe(); } + std::shared_ptr<::arrow::Schema> schema() const { return m_schema; } + private: PyObjectPtr m_iter; 
std::shared_ptr<::arrow::Schema> m_schema; }; -void ReleaseArrowSchemaPyCapsule( PyObject * capsule ) +inline void ReleaseArrowSchemaPyCapsule( PyObject * capsule ) { ArrowSchema * schema = reinterpret_cast<ArrowSchema *>( PyCapsule_GetPointer( capsule, "arrow_schema" ) ); if ( schema -> release != NULL ) @@ -73,7 +90,7 @@ void ReleaseArrowSchemaPyCapsule( PyObject * capsule ) free( schema ); } -void ReleaseArrowArrayPyCapsule( PyObject * capsule ) +inline void ReleaseArrowArrayPyCapsule( PyObject * capsule ) { ArrowArray * array = reinterpret_cast<ArrowArray *>( PyCapsule_GetPointer( capsule, "arrow_array" ) ); if ( array -> release != NULL ) @@ -88,36 +105,33 @@ class RecordBatchInputAdapter: public PullInputAdapter>( engine, type, PushMode::LAST_VALUE ), m_tsColName( tsColName ), m_expectSmallBatches( expectSmallBatches ), - m_finished( false ) + m_finished( false ), + m_initialized( false ), + m_multiplier( 0 ), + m_arrayLastTime( 0 ), + m_numRows( 0 ), + m_startTime( 0 ), + m_endTime( 0 ), + m_curBatchIdx( 0 ) { - // Extract the arrow schema - ArrowSchema * c_schema = reinterpret_cast<ArrowSchema *>( PyCapsule_GetPointer( pySchema.get(), "arrow_schema" ) ); - auto result = ::arrow::ImportSchema( c_schema ); - if( !result.ok() ) - CSP_THROW( ValueError, "Failed to load schema for record batches through the PyCapsule C Data interface: " << result.status().ToString() ); - m_schema = std::move( result.ValueUnsafe() ); + // Check if schema is provided or should be deferred to first batch + if( pySchema.get() != Py_None ) + { + // Extract the arrow schema upfront + ArrowSchema * c_schema = reinterpret_cast<ArrowSchema *>( PyCapsule_GetPointer( pySchema.get(), "arrow_schema" ) ); + auto result = ::arrow::ImportSchema( c_schema ); + if( !result.ok() ) + CSP_THROW( ValueError, "Failed to load schema for record batches through the PyCapsule C Data interface: " << result.status().ToString() ); - auto tsField = m_schema -> GetFieldByName( m_tsColName ); - auto timestampType = std::static_pointer_cast<::arrow::TimestampType>( 
tsField -> type() ); - switch( timestampType -> unit() ) + m_source = RecordBatchIterator( source, std::move( result.ValueUnsafe() ) ); + initializeTimestampMultiplier(); + m_initialized = true; + } + else { - case ::arrow::TimeUnit::SECOND: - m_multiplier = csp::NANOS_PER_SECOND; - break; - case ::arrow::TimeUnit::MILLI: - m_multiplier = csp::NANOS_PER_MILLISECOND; - break; - case ::arrow::TimeUnit::MICRO: - m_multiplier = csp::NANOS_PER_MICROSECOND; - break; - case ::arrow::TimeUnit::NANO: - m_multiplier = 1; - break; - default: - CSP_THROW( ValueError, "Unsupported unit type for arrow timestamp column" ); + // Schema will be extracted lazily from first batch + m_source = RecordBatchIterator( source, nullptr ); } - - m_source = RecordBatchIterator( source, m_schema ); } int64_t findFirstMatchingIndex() @@ -174,6 +188,16 @@ class RecordBatchInputAdapter: public PullInputAdapter num_rows() == 0 ) ) {} if( rb ) { + // Lazy schema initialization on first batch (when schema was not provided upfront) + if( !m_initialized ) [[unlikely]] + { + if( !m_source.schema() ) + CSP_THROW( ValueError, "Failed to extract schema from first record batch" ); + initializeTimestampMultiplier(); + computeTimeRange(); + m_initialized = true; + } + auto array = rb -> GetColumnByName( m_tsColName ); if( !array ) CSP_THROW( ValueError, "Failed to get timestamp column " << m_tsColName << " from record batch " << rb -> ToString() ); @@ -189,24 +213,10 @@ class RecordBatchInputAdapter: public PullInputAdapter= start - while( !m_finished ) - { - updateStateFromNextRecordBatch(); - if( !m_curRecordBatch ) - { - m_finished = true; - break; - } - m_curBatchIdx = findFirstMatchingIndex(); - if( m_curBatchIdx < m_numRows ) - break; - } + initializeStartPosition(); PullInputAdapter>::start( start, end ); } @@ -215,6 +225,12 @@ class RecordBatchInputAdapter: public PullInputAdapter GetFieldByName( m_tsColName ); + if( !tsField ) + CSP_THROW( ValueError, "Timestamp column '" << m_tsColName << "' not 
found in schema" ); + + if( tsField -> type() -> id() != ::arrow::Type::TIMESTAMP ) + CSP_THROW( ValueError, "Column '" << m_tsColName << "' is not a valid timestamp column" ); + + auto timestampType = std::static_pointer_cast<::arrow::TimestampType>( tsField -> type() ); + switch( timestampType -> unit() ) + { + case ::arrow::TimeUnit::SECOND: + m_multiplier = csp::NANOS_PER_SECOND; + break; + case ::arrow::TimeUnit::MILLI: + m_multiplier = csp::NANOS_PER_MILLISECOND; + break; + case ::arrow::TimeUnit::MICRO: + m_multiplier = csp::NANOS_PER_MICROSECOND; + break; + case ::arrow::TimeUnit::NANO: + m_multiplier = 1; + break; + default: + CSP_THROW( ValueError, "Unsupported unit type for arrow timestamp column" ); + } + } + + void computeTimeRange() + { + auto start_nanos = m_rawStartTime.asNanoseconds(); + m_startTime = ( start_nanos % m_multiplier == 0 ) + ? start_nanos / m_multiplier + : start_nanos / m_multiplier + 1; + m_endTime = m_rawEndTime.asNanoseconds() / m_multiplier; + } + + void initializeStartPosition() + { + // If schema was provided upfront, compute time range now + // Otherwise, time range will be computed lazily in updateStateFromNextRecordBatch() + if( m_initialized ) + computeTimeRange(); + + // Find the starting index where time >= start + while( !m_finished ) + { + updateStateFromNextRecordBatch(); + if( !m_curRecordBatch ) + { + m_finished = true; + break; + } + m_curBatchIdx = findFirstMatchingIndex(); + if( m_curBatchIdx < m_numRows ) + break; + m_curRecordBatch = nullptr; + } + } + std::string m_tsColName; RecordBatchIterator m_source; int m_expectSmallBatches; bool m_finished; - std::shared_ptr<::arrow::Schema> m_schema; + bool m_initialized; // Whether schema and timestamp multiplier have been resolved + int64_t m_multiplier; std::shared_ptr<::arrow::RecordBatch> m_curRecordBatch; std::shared_ptr<::arrow::TimestampArray> m_tsArray; ::arrow::TimestampArray::IteratorType m_endIt; int64_t m_arrayLastTime; - int64_t m_multiplier, m_numRows, 
m_startTime, m_endTime, m_curBatchIdx; + int64_t m_numRows; + int64_t m_startTime; + int64_t m_endTime; + int64_t m_curBatchIdx; + DateTime m_rawStartTime; + DateTime m_rawEndTime; }; }; diff --git a/csp/adapters/arrow.py b/csp/adapters/arrow.py index e01dd7106..e9e7f63f3 100644 --- a/csp/adapters/arrow.py +++ b/csp/adapters/arrow.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Tuple +from typing import Iterable, List, Optional, Tuple import pyarrow as pa import pyarrow.parquet as pq @@ -52,25 +52,32 @@ def __arrow_c_array__(self, requested_schema=None): @csp.graph def RecordBatchPullInputAdapter( - ts_col_name: str, source: Iterable[pa.RecordBatch], schema: pa.Schema, expect_small_batches: bool = False + ts_col_name: str, + source: Iterable[pa.RecordBatch], + schema: Optional[pa.Schema] = None, + expect_small_batches: bool = False, ) -> csp.ts[[pa.RecordBatch]]: """Stream record batches from an iterator/generator into csp Args: ts_col_name: Name of the timestamp column containing timestamps in ascending order source: Iterator/generator of record batches - schema: The schema of the record batches + schema: Schema of the record batches. If None, extracted from first batch at runtime. 
expect_small_batches: Optional flag to optimize performance for scenarios where there are few rows (<10) per timestamp NOTE: The ascending order of the timestamp column must be enforced by the caller """ - # Safety checks - ts_col = schema.field(ts_col_name) - if not pa.types.is_timestamp(ts_col.type): - raise ValueError(f"{ts_col_name} is not a valid timestamp column in the schema") + # Validate only if schema provided upfront + if schema is not None: + ts_col = schema.field(ts_col_name) + if not pa.types.is_timestamp(ts_col.type): + raise ValueError(f"{ts_col_name} is not a valid timestamp column in the schema") + c_schema = schema.__arrow_c_schema__() + else: + c_schema = None # C++ will extract from first batch c_source = map(lambda rb: rb.__arrow_c_array__(), source) - c_data = CRecordBatchPullInputAdapter(ts_col_name, c_source, schema.__arrow_c_schema__(), expect_small_batches) + c_data = CRecordBatchPullInputAdapter(ts_col_name, c_source, c_schema, expect_small_batches) return csp.apply( c_data, lambda c_tups: [pa.record_batch(_RecordBatchCSource(c_tup)) for c_tup in c_tups], List[pa.RecordBatch] ) diff --git a/csp/tests/adapters/test_arrow.py b/csp/tests/adapters/test_arrow.py index 8f59846ea..e146033da 100644 --- a/csp/tests/adapters/test_arrow.py +++ b/csp/tests/adapters/test_arrow.py @@ -11,9 +11,38 @@ _STARTTIME = datetime(2020, 1, 1, 9, 0, 0) +def _make_record_batch(ts_col_name: str, row_size: int, ts: datetime) -> pa.RecordBatch: + data = { + ts_col_name: pa.array([ts] * row_size, type=pa.timestamp("ms")), + "name": pa.array([chr(ord("A") + idx % 26) for idx in range(row_size)]), + } + schema = pa.schema([(ts_col_name, pa.timestamp("ms")), ("name", pa.string())]) + return pa.RecordBatch.from_pydict(data, schema=schema) + + +def _make_data(ts_col_name: str, row_sizes: list[int], start: datetime = _STARTTIME, interval: int = 1): + res = [ + _make_record_batch(ts_col_name, row_size, start + timedelta(seconds=interval * idx)) + for idx, row_size in 
enumerate(row_sizes) + ] + return res, start, start + timedelta(seconds=interval * (len(row_sizes) - 1)) + + +def _make_data_with_schema(ts_col_name: str, row_sizes: list[int], start: datetime = _STARTTIME, interval: int = 1): + res, dt_start, dt_end = _make_data(ts_col_name, row_sizes, start, interval) + return res[0].schema, res, dt_start, dt_end + + @csp.graph def G(ts_col_name: str, schema: pa.Schema, batches: object, expect_small: bool): - data = RecordBatchPullInputAdapter(ts_col_name, batches, schema, expect_small) + data = RecordBatchPullInputAdapter(ts_col_name, batches, schema, expect_small_batches=expect_small) + csp.add_graph_output("data", data) + + +@csp.graph +def G_lazy_schema(ts_col_name: str, batches: object, expect_small: bool): + """Graph that passes schema=None for lazy schema extraction.""" + data = RecordBatchPullInputAdapter(ts_col_name, batches, schema=None, expect_small_batches=expect_small) csp.add_graph_output("data", data) @@ -31,30 +60,15 @@ def _concat_batches(batches: list[pa.RecordBatch]) -> pa.RecordBatch: class TestArrow: - def make_record_batch(self, ts_col_name: str, row_size: int, ts: datetime) -> pa.RecordBatch: - data = { - ts_col_name: pa.array([ts] * row_size, type=pa.timestamp("ms")), - "name": pa.array([chr(ord("A") + idx % 26) for idx in range(row_size)]), - } - schema = pa.schema([(ts_col_name, pa.timestamp("ms")), ("name", pa.string())]) - return pa.RecordBatch.from_pydict(data, schema=schema) - - def make_data(self, ts_col_name: str, row_sizes: [int], start: datetime = _STARTTIME, interval: int = 1): - res = [ - self.make_record_batch(ts_col_name, row_size, start + timedelta(seconds=interval * idx)) - for idx, row_size in enumerate(row_sizes) - ] - return res[0].schema, res, start, start + timedelta(seconds=interval * (len(row_sizes) - 1)) - @pytest.mark.parametrize("small_batches", (True, False)) def test_bad_ts_col_name(self, small_batches: bool): - schema, rbs, dt_start, dt_end = self.make_data(ts_col_name="TsCol", 
row_sizes=[1]) + schema, rbs, dt_start, dt_end = _make_data_with_schema(ts_col_name="TsCol", row_sizes=[1]) with pytest.raises(KeyError): results = csp.run(G, "NotTsCol", schema, rbs, small_batches, starttime=_STARTTIME) @pytest.mark.parametrize("small_batches", (True, False)) def test_bad_ts_col_type(self, small_batches: bool): - schema, rbs, dt_start, dt_end = self.make_data(ts_col_name="TsCol", row_sizes=[1]) + schema, rbs, dt_start, dt_end = _make_data_with_schema(ts_col_name="TsCol", row_sizes=[1]) with pytest.raises(ValueError): results = csp.run(G, "name", schema, rbs, small_batches, starttime=_STARTTIME) @@ -66,19 +80,19 @@ def test_bad_source(self, small_batches: bool): @pytest.mark.parametrize("small_batches", (True, False)) def test_empty_rb(self, small_batches: bool): - schema, rbs, dt_start, dt_end = self.make_data(ts_col_name="TsCol", row_sizes=[0] * 1) + schema, rbs, dt_start, dt_end = _make_data_with_schema(ts_col_name="TsCol", row_sizes=[0] * 1) results = csp.run(G, "TsCol", schema, rbs, small_batches, starttime=_STARTTIME) assert len(results["data"]) == 0 - schema, rbs, dt_start, dt_end = self.make_data(ts_col_name="TsCol", row_sizes=[0] * 3) + schema, rbs, dt_start, dt_end = _make_data_with_schema(ts_col_name="TsCol", row_sizes=[0] * 3) results = csp.run(G, "TsCol", schema, rbs, small_batches, starttime=_STARTTIME) assert len(results["data"]) == 0 - schema, rbs, dt_start, dt_end = self.make_data(ts_col_name="TsCol", row_sizes=[0] * 4) + schema, rbs, dt_start, dt_end = _make_data_with_schema(ts_col_name="TsCol", row_sizes=[0] * 4) results = csp.run(G, "TsCol", schema, rbs, small_batches, starttime=_STARTTIME) assert len(results["data"]) == 0 - schema, rbs, dt_start, dt_end = self.make_data(ts_col_name="TsCol", row_sizes=[0] * 1024) + schema, rbs, dt_start, dt_end = _make_data_with_schema(ts_col_name="TsCol", row_sizes=[0] * 1024) results = csp.run(G, "TsCol", schema, rbs, small_batches, starttime=_STARTTIME) assert len(results["data"]) == 0 @@ 
-86,7 +100,7 @@ def test_empty_rb(self, small_batches: bool): @pytest.mark.parametrize("row_sizes", ([10], [100, 10], [100, 10, 1, 0, 0, 1, 2, 3, 4])) @pytest.mark.parametrize("delta", (timedelta(microseconds=1), timedelta(seconds=1), timedelta(days=1))) def test_start_not_found(self, small_batches: bool, row_sizes: [int], delta: timedelta): - schema, rbs, dt_start, dt_end = self.make_data(ts_col_name="TsCol", row_sizes=[10]) + schema, rbs, dt_start, dt_end = _make_data_with_schema(ts_col_name="TsCol", row_sizes=[10]) results = csp.run(G, "TsCol", schema, rbs, small_batches, starttime=dt_start + delta) assert len(results["data"]) == 0 @@ -96,8 +110,8 @@ def test_start_not_found(self, small_batches: bool, row_sizes: [int], delta: tim @pytest.mark.parametrize("delta", (timedelta(microseconds=1), timedelta(seconds=1), timedelta(days=1))) def test_start_found(self, small_batches: bool, row_sizes: [int], row_sizes_prev: [int], delta: timedelta): clean_row_sizes = [r for r in row_sizes if r != 0] - schema, rbs_prev, _, old_dt_end = self.make_data(ts_col_name="TsCol", row_sizes=row_sizes) - schema, rbs, dt_start, dt_end = self.make_data( + schema, rbs_prev, _, old_dt_end = _make_data_with_schema(ts_col_name="TsCol", row_sizes=row_sizes) + schema, rbs, dt_start, dt_end = _make_data_with_schema( ts_col_name="TsCol", row_sizes=row_sizes, start=old_dt_end + timedelta(days=10) ) clean_rbs = [rb for rb in rbs if len(rb) != 0] @@ -117,12 +131,14 @@ def test_start_found(self, small_batches: bool, row_sizes: [int], row_sizes_prev @pytest.mark.parametrize("repeat", (1, 10, 100)) @pytest.mark.parametrize("dt_count", (1, 5)) def test_split(self, small_batches: bool, row_sizes: [int], repeat: int, dt_count: int): - schema, _, dt_start, dt_end = self.make_data(ts_col_name="TsCol", row_sizes=row_sizes) + schema, _, dt_start, dt_end = _make_data_with_schema(ts_col_name="TsCol", row_sizes=row_sizes) rbs_indivs = [[]] * dt_count rbs_full = [] for idx in range(dt_count): _data = [ - 
self.make_data(ts_col_name="TsCol", row_sizes=row_sizes, start=dt_start + timedelta(seconds=idx))[1] + _make_data_with_schema( + ts_col_name="TsCol", row_sizes=row_sizes, start=dt_start + timedelta(seconds=idx) + )[1] for i in range(repeat) ] rbs_indivs[idx] = [item for sublist in _data for item in sublist] @@ -141,7 +157,7 @@ def test_split(self, small_batches: bool, row_sizes: [int], repeat: int, dt_coun @pytest.mark.parametrize("small_batches", (True, False)) @pytest.mark.parametrize("row_sizes", ([10, 0, 0, 1], [0, 1, 0, 10])) def test_end_time_early(self, small_batches: bool, row_sizes: [int]): - schema, rbs, _, _ = self.make_data(ts_col_name="TsCol", row_sizes=row_sizes) + schema, rbs, _, _ = _make_data_with_schema(ts_col_name="TsCol", row_sizes=row_sizes) results = csp.run( G, "TsCol", @@ -161,7 +177,7 @@ def test_different_size_rbs(self, small_batches: bool, seed: int): random.seed(seed) row_sizes = [random.randint(0, 100) for i in range(10000)] clean_row_sizes = [r for r in row_sizes if r != 0] - schema, rbs, _, _ = self.make_data(ts_col_name="TsCol", row_sizes=row_sizes) + schema, rbs, _, _ = _make_data_with_schema(ts_col_name="TsCol", row_sizes=row_sizes) clean_rbs = [rb for rb in rbs if len(rb) != 0] results = csp.run( G, @@ -179,7 +195,7 @@ def test_different_size_rbs(self, small_batches: bool, seed: int): @pytest.mark.parametrize("row_sizes", ([1], [10], [1, 2, 3, 4, 5])) @pytest.mark.parametrize("batch_size", (1, 5, 10)) def test_write_record_batches(self, row_sizes: [int], concat: bool, batch_size: int): - _, rbs, _, _ = self.make_data(ts_col_name="TsCol", row_sizes=row_sizes) + _, rbs, _, _ = _make_data_with_schema(ts_col_name="TsCol", row_sizes=row_sizes) if not concat: rbs_ts = [[rb] for rb in rbs] else: @@ -194,7 +210,7 @@ def test_write_record_batches(self, row_sizes: [int], concat: bool, batch_size: @pytest.mark.parametrize("concat", (False, True)) @pytest.mark.parametrize("row_sizes", ([1], [10], [1, 2, 3, 4, 5])) def 
test_write_record_batches_concat(self, row_sizes: [int], concat: bool): - _, rbs, _, _ = self.make_data(ts_col_name="TsCol", row_sizes=row_sizes) + _, rbs, _, _ = _make_data_with_schema(ts_col_name="TsCol", row_sizes=row_sizes) if not concat: rbs_ts = [[rb] for rb in rbs] else: @@ -213,7 +229,7 @@ def test_write_record_batches_concat(self, row_sizes: [int], concat: bool): def test_write_record_batches_batch_sizes(self): row_sizes = [10] * 10 - _, rbs, _, _ = self.make_data(ts_col_name="TsCol", row_sizes=row_sizes) + _, rbs, _, _ = _make_data_with_schema(ts_col_name="TsCol", row_sizes=row_sizes) rbs_ts = [rbs] with tempfile.NamedTemporaryFile(prefix="csp_unit_tests", mode="w") as temp_file: temp_file.close() @@ -225,7 +241,7 @@ def test_write_record_batches_batch_sizes(self): assert rbs_ts_expected == res.to_batches() row_sizes = [10] * 10 - _, rbs, _, _ = self.make_data(ts_col_name="TsCol", row_sizes=row_sizes) + _, rbs, _, _ = _make_data_with_schema(ts_col_name="TsCol", row_sizes=row_sizes) rbs_ts = [rbs] with tempfile.NamedTemporaryFile(prefix="csp_unit_tests", mode="w") as temp_file: temp_file.close() @@ -235,3 +251,157 @@ def test_write_record_batches_batch_sizes(self): assert res.equals(orig) rbs_ts_expected = [_concat_batches(rbs[3 * i : 3 * i + 3]) for i in range(4)] assert rbs_ts_expected == res.to_batches() + + +class TestArrowLazySchema: + """Tests for lazy schema initialization (schema=None). + + These tests verify that RecordBatchPullInputAdapter correctly extracts + the schema from the first record batch when schema=None is passed. 
+ """ + + @pytest.mark.parametrize("small_batches", (True, False)) + def test_lazy_schema_basic(self, small_batches: bool): + """Test basic lazy schema extraction from first batch.""" + rbs, _, _ = _make_data(ts_col_name="TsCol", row_sizes=[5]) + results = csp.run(G_lazy_schema, "TsCol", rbs, small_batches, starttime=_STARTTIME) + assert len(results["data"]) == 1 + assert len(results["data"][0][1][0]) == 5 + + @pytest.mark.parametrize("small_batches", (True, False)) + def test_lazy_schema_multiple_batches(self, small_batches: bool): + """Test lazy schema with multiple record batches.""" + rbs, _, _ = _make_data(ts_col_name="TsCol", row_sizes=[5, 10, 3]) + results = csp.run(G_lazy_schema, "TsCol", rbs, small_batches, starttime=_STARTTIME) + assert len(results["data"]) == 3 + assert [len(r[1][0]) for r in results["data"]] == [5, 10, 3] + + @pytest.mark.parametrize("small_batches", (True, False)) + def test_lazy_schema_empty_batches_before_data(self, small_batches: bool): + """Test lazy schema extraction skips empty batches to find first non-empty.""" + rbs, _, _ = _make_data(ts_col_name="TsCol", row_sizes=[0, 0, 5, 10]) + results = csp.run(G_lazy_schema, "TsCol", rbs, small_batches, starttime=_STARTTIME) + # Should get 2 results (the non-empty batches with rows 5 and 10) + assert len(results["data"]) == 2 + assert [len(r[1][0]) for r in results["data"]] == [5, 10] + + @pytest.mark.parametrize("small_batches", (True, False)) + def test_lazy_schema_all_empty(self, small_batches: bool): + """Test lazy schema with all empty batches.""" + rbs, _, _ = _make_data(ts_col_name="TsCol", row_sizes=[0, 0, 0]) + results = csp.run(G_lazy_schema, "TsCol", rbs, small_batches, starttime=_STARTTIME) + assert len(results["data"]) == 0 + + @pytest.mark.parametrize("small_batches", (True, False)) + def test_lazy_schema_bad_ts_col_name(self, small_batches: bool): + """Test error handling for wrong timestamp column with lazy schema.""" + rbs, _, _ = _make_data(ts_col_name="TsCol", 
row_sizes=[5]) + with pytest.raises(ValueError): + csp.run(G_lazy_schema, "NotTsCol", rbs, small_batches, starttime=_STARTTIME) + + @pytest.mark.parametrize("small_batches", (True, False)) + def test_lazy_schema_bad_ts_col_type(self, small_batches: bool): + """Test error handling for non-timestamp column with lazy schema.""" + rbs, _, _ = _make_data(ts_col_name="TsCol", row_sizes=[5]) + with pytest.raises(ValueError): + csp.run(G_lazy_schema, "name", rbs, small_batches, starttime=_STARTTIME) + + @pytest.mark.parametrize("small_batches", (True, False)) + @pytest.mark.parametrize("seed", (1, 42, 100)) + def test_lazy_schema_random_sizes(self, small_batches: bool, seed: int): + """Test lazy schema with random sized batches.""" + import random + + random.seed(seed) + row_sizes = [random.randint(0, 50) for i in range(100)] + clean_row_sizes = [r for r in row_sizes if r != 0] + rbs, _, _ = _make_data(ts_col_name="TsCol", row_sizes=row_sizes) + clean_rbs = [rb for rb in rbs if len(rb) != 0] + results = csp.run(G_lazy_schema, "TsCol", rbs, small_batches, starttime=_STARTTIME) + assert len(results["data"]) == len(clean_row_sizes) + assert [len(r[1][0]) for r in results["data"]] == clean_row_sizes + assert [r[1][0] for r in results["data"]] == clean_rbs + + @pytest.mark.parametrize("small_batches", (True, False)) + def test_lazy_schema_matches_explicit_schema(self, small_batches: bool): + """Test that lazy schema produces same results as explicit schema.""" + schema, rbs, _, _ = _make_data_with_schema(ts_col_name="TsCol", row_sizes=[5, 0, 10, 3]) + + # Run with explicit schema + results_explicit = csp.run(G, "TsCol", schema, rbs, small_batches, starttime=_STARTTIME) + + # Run with lazy schema + results_lazy = csp.run(G_lazy_schema, "TsCol", rbs, small_batches, starttime=_STARTTIME) + + # Results should match + assert len(results_explicit["data"]) == len(results_lazy["data"]) + for explicit, lazy in zip(results_explicit["data"], results_lazy["data"]): + assert explicit[0] == 
lazy[0] # timestamps match + assert explicit[1] == lazy[1] # data matches + + +class TestDeferredIterator: + """Tests for deferred iterator behavior. + + These tests verify that the iterator is not consumed at graph build time, + allowing lazy iterators to be set up after graph construction. + """ + + def test_deferred_iterator_not_consumed_at_build_time(self): + """Test that iterator is not consumed during graph build, but is consumed at run time.""" + iteration_started = [] + + def tracking_generator(): + iteration_started.append(True) + yield _make_record_batch("TsCol", 5, _STARTTIME) + + gen = tracking_generator() + + @csp.graph + def test_graph(): + data = RecordBatchPullInputAdapter("TsCol", gen, schema=None, expect_small_batches=False) + csp.add_graph_output("data", data) + + # Graph definition should NOT consume the iterator + assert len(iteration_started) == 0 + + # Running the graph SHOULD consume it + results = csp.run(test_graph, starttime=_STARTTIME) + assert len(iteration_started) == 1 + assert len(results["data"]) == 1 + assert len(results["data"][0][1][0]) == 5 + + def test_lazy_iterator_pattern(self): + """Test the lazy iterator pattern used by LazyParquetIterator.""" + + class LazyIterator: + """Simulates LazyParquetIterator behavior.""" + + def __init__(self): + self._data = None + + def set_data(self, data): + self._data = data + + def __iter__(self): + if self._data is None: + raise RuntimeError("Data not set") + for item in self._data: + yield item + + # Create lazy iterator without data + lazy_iter = LazyIterator() + + # Create batches + rbs = [ + _make_record_batch("TsCol", 5, _STARTTIME), + _make_record_batch("TsCol", 3, _STARTTIME + timedelta(seconds=1)), + ] + + # Set data before running (simulates what GraphComputeSimManager.start() does) + lazy_iter.set_data(rbs) + + # Run with lazy schema + results = csp.run(G_lazy_schema, "TsCol", lazy_iter, False, starttime=_STARTTIME) + assert len(results["data"]) == 2 + assert [len(r[1][0]) for r 
in results["data"]] == [5, 3]