From 768266d5792b91eec936792ced801a5274c3c2ae Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Sun, 8 Mar 2026 20:51:18 -0700 Subject: [PATCH 01/14] feat: add async 'for' loop support to LogScanner (#424) --- bindings/python/fluss/__init__.pyi | 8 +- bindings/python/src/table.rs | 154 ++++++++++++++++++++++--- bindings/python/test/test_log_table.py | 49 ++++++++ 3 files changed, 188 insertions(+), 23 deletions(-) diff --git a/bindings/python/fluss/__init__.pyi b/bindings/python/fluss/__init__.pyi index 417ac9b2..2534f638 100644 --- a/bindings/python/fluss/__init__.pyi +++ b/bindings/python/fluss/__init__.pyi @@ -125,7 +125,9 @@ class ScanRecords: def __getitem__(self, index: slice) -> List[ScanRecord]: ... @overload def __getitem__(self, bucket: TableBucket) -> List[ScanRecord]: ... - def __getitem__(self, key: Union[int, slice, TableBucket]) -> Union[ScanRecord, List[ScanRecord]]: ... + def __getitem__( + self, key: Union[int, slice, TableBucket] + ) -> Union[ScanRecord, List[ScanRecord]]: ... def __contains__(self, bucket: TableBucket) -> bool: ... def __iter__(self) -> Iterator[ScanRecord]: ... def __str__(self) -> str: ... @@ -369,7 +371,6 @@ class FlussAdmin: ... def __repr__(self) -> str: ... - class DatabaseDescriptor: """Descriptor for a Fluss database (comment and custom properties).""" @@ -383,7 +384,6 @@ class DatabaseDescriptor: def get_custom_properties(self) -> Dict[str, str]: ... def __repr__(self) -> str: ... - class DatabaseInfo: """Information about a Fluss database.""" @@ -604,7 +604,6 @@ class UpsertWriter: ... def __repr__(self) -> str: ... - class WriteResultHandle: """Handle for a pending write (append/upsert/delete). Ignore for fire-and-forget, or await handle.wait() for ack.""" @@ -613,7 +612,6 @@ class WriteResultHandle: ... def __repr__(self) -> str: ... - class Lookuper: """Lookuper for performing primary key lookups on a Fluss table.""" diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 660cd6be..b68492e1 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -30,6 +30,9 @@ use pyo3::types::{ PyDeltaAccess, PyDict, PyList, PySequence, PySlice, PyTime, PyTimeAccess, PyTuple, PyType, PyTzInfo, }; +use pyo3::{ + Bound, IntoPyObjectExt, Py, PyAny, PyClassInitializer, PyErr, PyRef, PyRefMut, PyResult, Python, +}; use pyo3_async_runtimes::tokio::future_into_py; use std::collections::HashMap; use std::sync::Arc; @@ -1863,6 +1866,13 @@ enum ScannerKind { Batch(fcore::client::RecordBatchLogScanner), } +/// The internal state of the scanner, protected by a Tokio Mutex for async cross-thread sharing +struct ScannerState { + kind: ScannerKind, + /// A buffer to hold records polled from the network before yielding them one-by-one to Python + pending_records: std::collections::VecDeque>, +} + impl ScannerKind { fn as_record(&self) -> PyResult<&fcore::client::LogScanner> { match self { @@ -1901,7 +1911,7 @@ macro_rules! with_scanner { /// - Batch-based scanning via `poll_arrow()` / `poll_record_batch()` - returns Arrow batches #[pyclass] pub struct LogScanner { - scanner: ScannerKind, + state: Arc>, admin: fcore::client::FlussAdmin, table_info: fcore::metadata::TableInfo, /// The projected Arrow schema to use for empty table creation @@ -1922,7 +1932,8 @@ impl LogScanner { fn subscribe(&self, py: Python, bucket_id: i32, start_offset: i64) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - with_scanner!(&self.scanner, subscribe(bucket_id, start_offset)) + let state = self.state.lock().await; + with_scanner!(&state.kind, subscribe(bucket_id, start_offset)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -1935,7 +1946,8 @@ impl LogScanner { fn subscribe_buckets(&self, py: Python, bucket_offsets: HashMap) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - with_scanner!(&self.scanner, subscribe_buckets(&bucket_offsets)) + let state = self.state.lock().await; + with_scanner!(&state.kind, subscribe_buckets(&bucket_offsets)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -1956,8 +1968,9 @@ impl LogScanner { ) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { + let state = self.state.lock().await; with_scanner!( - &self.scanner, + &state.kind, subscribe_partition(partition_id, bucket_id, start_offset) ) .map_err(|e| FlussError::from_core_error(&e)) @@ -1976,8 +1989,9 @@ impl LogScanner { ) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { + let state = self.state.lock().await; with_scanner!( - &self.scanner, + &state.kind, subscribe_partition_buckets(&partition_bucket_offsets) ) .map_err(|e| FlussError::from_core_error(&e)) @@ -1992,7 +2006,8 @@ impl LogScanner { fn unsubscribe(&self, py: Python, bucket_id: i32) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - with_scanner!(&self.scanner, unsubscribe(bucket_id)) + let state = self.state.lock().await; + with_scanner!(&state.kind, unsubscribe(bucket_id)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -2006,11 +2021,9 @@ impl LogScanner { fn unsubscribe_partition(&self, py: Python, partition_id: i64, bucket_id: i32) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - with_scanner!( - &self.scanner, - unsubscribe_partition(partition_id, bucket_id) - ) - .map_err(|e| FlussError::from_core_error(&e)) + let state = self.state.lock().await; + with_scanner!(&state.kind, unsubscribe_partition(partition_id, bucket_id)) + .map_err(|e| FlussError::from_core_error(&e)) }) }) } @@ -2030,7 +2043,10 @@ impl LogScanner { /// - Returns an empty ScanRecords if no records are available /// - When timeout expires, returns an empty ScanRecords (NOT an error) fn poll(&self, py: Python, timeout_ms: i64) -> PyResult { - let scanner = self.scanner.as_record()?; + let scanner_ref = + unsafe { &*(&self.state as *const std::sync::Arc>) }; + let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); + let scanner = lock.kind.as_record()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2079,7 +2095,10 @@ impl LogScanner { /// - Returns an empty list if no batches are available /// - When timeout expires, returns an empty list (NOT an error) fn poll_record_batch(&self, py: Python, timeout_ms: i64) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner_ref = + unsafe { &*(&self.state as *const std::sync::Arc>) }; + let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); + let scanner = lock.kind.as_batch()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2114,7 +2133,10 @@ impl LogScanner { /// - Returns an empty table (with correct schema) if no records are available /// - When timeout expires, returns an empty table (NOT an error) fn poll_arrow(&self, py: Python, timeout_ms: i64) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner_ref = + unsafe { &*(&self.state as *const std::sync::Arc>) }; + let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); + let scanner = lock.kind.as_batch()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2167,7 +2189,10 @@ impl LogScanner { /// Returns: /// PyArrow Table containing all data from subscribed buckets fn to_arrow(&self, py: Python) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner_ref = + unsafe { &*(&self.state as *const std::sync::Arc>) }; + let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); + let scanner = lock.kind.as_batch()?; let subscribed = scanner.get_subscribed_buckets(); if subscribed.is_empty() { return Err(FlussError::new_err( @@ -2199,6 +2224,90 @@ impl LogScanner { Ok(df) } + fn __aiter__<'py>(slf: PyRef<'py, Self>) -> PyResult> { + let py = slf.py(); + let code = pyo3::ffi::c_str!( + r#" +async def _adapter(obj): + while True: + try: + yield await obj.__anext__() + except StopAsyncIteration: + break +"# + ); + let globals = pyo3::types::PyDict::new(py); + py.run(code, Some(&globals), None)?; + let adapter = globals.get_item("_adapter")?.unwrap(); + // Return adapt(self) + adapter.call1((slf.into_bound_py_any(py)?,)) + } + + fn __anext__<'py>(slf: PyRefMut<'py, Self>) -> PyResult>> { + let state_arc = slf.state.clone(); + let projected_row_type = slf.projected_row_type.clone(); + let py = slf.py(); + + let future = future_into_py(py, async move { + let mut state = state_arc.lock().await; + + // 1. If we already have buffered records, pop and return immediately + if let Some(record) = state.pending_records.pop_front() { + return Ok(record.into_any()); + } + + // 2. Buffer is empty, we must poll the network for the next batch + // The underlying kind must be a Record-based scanner. + let scanner = match state.kind.as_record() { + Ok(s) => s, + Err(_) => { + return Err(pyo3::exceptions::PyStopAsyncIteration::new_err( + "Stream Ended", + )); + } + }; + + // Poll with a reasonable internal timeout before unblocking the event loop + let timeout = core::time::Duration::from_millis(5000); + + let mut current_records = scanner + .poll(timeout) + .await + .map_err(|e| FlussError::from_core_error(&e))?; + + // If it's a real timeout with zero records, loop or throw StopAsyncIteration? + // Since it's a streaming log, we can yield None or block. Blocking requires a loop in the future. + while current_records.is_empty() { + current_records = scanner + .poll(timeout) + .await + .map_err(|e| FlussError::from_core_error(&e))?; + } + + // Now we have records. + Python::attach(|py| { + for (_, records) in current_records.into_records_by_buckets() { + for core_record in records { + let scan_record = + ScanRecord::from_core(py, &core_record, &projected_row_type)?; + state.pending_records.push_back(Py::new(py, scan_record)?); + } + } + + // Pop the very first one to return right now + if let Some(record) = state.pending_records.pop_front() { + Ok(record.into_any()) + } else { + Err(pyo3::exceptions::PyStopAsyncIteration::new_err( + "Stream Ended", + )) + } + }) + })?; + + Ok(Some(future)) + } + fn __repr__(&self) -> String { format!("LogScanner(table={})", self.table_info.table_path) } @@ -2213,7 +2322,10 @@ impl LogScanner { projected_row_type: fcore::metadata::RowType, ) -> Self { Self { - scanner, + state: std::sync::Arc::new(tokio::sync::Mutex::new(ScannerState { + kind: scanner, + pending_records: std::collections::VecDeque::new(), + })), admin, table_info, projected_schema, @@ -2264,7 +2376,10 @@ impl LogScanner { py: Python, subscribed: &[(fcore::metadata::TableBucket, i64)], ) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner_ref = + unsafe { &*(&self.state as *const std::sync::Arc>) }; + let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); + let scanner = lock.kind.as_batch()?; let is_partitioned = scanner.is_partitioned(); let table_path = &self.table_info.table_path; @@ -2367,7 +2482,10 @@ impl LogScanner { py: Python, mut stopping_offsets: HashMap, ) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner_ref = + unsafe { &*(&self.state as *const std::sync::Arc>) }; + let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); + let scanner = lock.kind.as_batch()?; let mut all_batches = Vec::new(); while !stopping_offsets.is_empty() { diff --git a/bindings/python/test/test_log_table.py b/bindings/python/test/test_log_table.py index dd1a4d4f..2f9588b0 100644 --- a/bindings/python/test/test_log_table.py +++ b/bindings/python/test/test_log_table.py @@ -729,6 +729,55 @@ async def test_scan_records_indexing_and_slicing(connection, admin): await admin.drop_table(table_path, ignore_if_not_exists=False) +async def test_async_iterator(connection, admin): + """Test the Python asynchronous iterator loop (`async for`) on LogScanner.""" + table_path = fluss.TablePath("fluss", "py_test_async_iterator") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + + # Write 5 records + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [pa.array(list(range(1, 6)), type=pa.int32()), + pa.array([f"async{i}" for i in range(1, 6)])], + schema=pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets({i: fluss.EARLIEST_OFFSET for i in range(num_buckets)}) + + collected = [] + + # Here is the magical Issue #424 async iterator logic at work: + async def consume_scanner(): + async for record in scanner: + collected.append(record) + if len(collected) == 5: + break + + # We must race the consumption against a timeout so the test doesn't hang if the iterator is broken + await asyncio.wait_for(consume_scanner(), timeout=10.0) + + assert len(collected) == 5, f"Expected 5 records, got {len(collected)}" + + collected.sort(key=lambda r: r.row["id"]) + for i, record in enumerate(collected): + assert record.row["id"] == i + 1 + assert record.row["val"] == f"async{i + 1}" + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- From 0e01b8b7a453a691f49d5382cb308b71fa83d650 Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Sun, 8 Mar 2026 21:22:53 -0700 Subject: [PATCH 02/14] chore: revert formatting changes to __init__.pyi --- bindings/python/fluss/__init__.pyi | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bindings/python/fluss/__init__.pyi b/bindings/python/fluss/__init__.pyi index 2534f638..417ac9b2 100644 --- a/bindings/python/fluss/__init__.pyi +++ b/bindings/python/fluss/__init__.pyi @@ -125,9 +125,7 @@ class ScanRecords: def __getitem__(self, index: slice) -> List[ScanRecord]: ... @overload def __getitem__(self, bucket: TableBucket) -> List[ScanRecord]: ... - def __getitem__( - self, key: Union[int, slice, TableBucket] - ) -> Union[ScanRecord, List[ScanRecord]]: ... + def __getitem__(self, key: Union[int, slice, TableBucket]) -> Union[ScanRecord, List[ScanRecord]]: ... def __contains__(self, bucket: TableBucket) -> bool: ... def __iter__(self) -> Iterator[ScanRecord]: ... def __str__(self) -> str: ... @@ -371,6 +369,7 @@ class FlussAdmin: ... def __repr__(self) -> str: ... + class DatabaseDescriptor: """Descriptor for a Fluss database (comment and custom properties).""" @@ -384,6 +383,7 @@ class DatabaseDescriptor: def get_custom_properties(self) -> Dict[str, str]: ... def __repr__(self) -> str: ... + class DatabaseInfo: """Information about a Fluss database.""" @@ -604,6 +604,7 @@ class UpsertWriter: ... def __repr__(self) -> str: ... + class WriteResultHandle: """Handle for a pending write (append/upsert/delete). Ignore for fire-and-forget, or await handle.wait() for ack.""" @@ -612,6 +613,7 @@ class WriteResultHandle: ... def __repr__(self) -> str: ... + class Lookuper: """Lookuper for performing primary key lookups on a Fluss table.""" From 3aa067b2fbda018c5b19ce972c4c03aafa01e1a4 Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Mon, 9 Mar 2026 07:12:28 -0700 Subject: [PATCH 03/14] fix: remove unused PyClassInitializer and PyErr imports --- bindings/python/src/table.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index b68492e1..10f77463 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -31,7 +31,7 @@ use pyo3::types::{ PyTzInfo, }; use pyo3::{ - Bound, IntoPyObjectExt, Py, PyAny, PyClassInitializer, PyErr, PyRef, PyRefMut, PyResult, Python, + Bound, IntoPyObjectExt, Py, PyAny, PyRef, PyRefMut, PyResult, Python, }; use pyo3_async_runtimes::tokio::future_into_py; use std::collections::HashMap; From 1065665717ec19494e17a13de1415c8606bf9e70 Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Mon, 9 Mar 2026 19:37:43 -0700 Subject: [PATCH 04/14] style: apply cargo fmt --- bindings/python/src/table.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 10f77463..1d66a6e3 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -30,9 +30,7 @@ use pyo3::types::{ PyDeltaAccess, PyDict, PyList, PySequence, PySlice, PyTime, PyTimeAccess, PyTuple, PyType, PyTzInfo, }; -use pyo3::{ - Bound, IntoPyObjectExt, Py, PyAny, PyRef, PyRefMut, PyResult, Python, -}; +use pyo3::{Bound, IntoPyObjectExt, Py, PyAny, PyRef, PyRefMut, PyResult, Python}; use pyo3_async_runtimes::tokio::future_into_py; use std::collections::HashMap; use std::sync::Arc; From 195ec7cfc1fbefa50e56ab1c760b345fdea4eaca Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Tue, 10 Mar 2026 10:34:43 -0700 Subject: [PATCH 05/14] refactor: release scanner lock earlier by cloning subscribed buckets within a local scope in `to_arrow` --- bindings/python/src/table.rs | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 1d66a6e3..28ccbc76 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -2187,16 +2187,20 @@ impl LogScanner { /// Returns: /// PyArrow Table containing all data from subscribed buckets fn to_arrow(&self, py: Python) -> PyResult> { - let scanner_ref = - unsafe { &*(&self.state as *const std::sync::Arc>) }; - let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); - let scanner = lock.kind.as_batch()?; - let subscribed = scanner.get_subscribed_buckets(); - if subscribed.is_empty() { - return Err(FlussError::new_err( - "No buckets subscribed. Call subscribe(), subscribe_buckets(), subscribe_partition(), or subscribe_partition_buckets() first.", - )); - } + let subscribed = { + let scanner_ref = unsafe { + &*(&self.state as *const std::sync::Arc>) + }; + let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); + let scanner = lock.kind.as_batch()?; + let subs = scanner.get_subscribed_buckets(); + if subs.is_empty() { + return Err(FlussError::new_err( + "No buckets subscribed. Call subscribe(), subscribe_buckets(), subscribe_partition(), or subscribe_partition_buckets() first.", + )); + } + subs.clone() + }; // 2. Query latest offsets for all subscribed buckets let stopping_offsets = self.query_latest_offsets(py, &subscribed)?; From 4ad2fd86fa710f3cd66f6e840ab2eef02168df6c Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Wed, 11 Mar 2026 16:45:51 -0700 Subject: [PATCH 06/14] refactor: Remove Mutex and utilize __aiter__ with _async_poll(timeout_ms) instead --- bindings/python/src/table.rs | 180 +++++-------- bindings/python/test/test_log_table.py | 341 +++++++++++++++++++++++++ 2 files changed, 412 insertions(+), 109 deletions(-) diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 28ccbc76..64c06d30 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -1864,13 +1864,6 @@ enum ScannerKind { Batch(fcore::client::RecordBatchLogScanner), } -/// The internal state of the scanner, protected by a Tokio Mutex for async cross-thread sharing -struct ScannerState { - kind: ScannerKind, - /// A buffer to hold records polled from the network before yielding them one-by-one to Python - pending_records: std::collections::VecDeque>, -} - impl ScannerKind { fn as_record(&self) -> PyResult<&fcore::client::LogScanner> { match self { @@ -1895,7 +1888,7 @@ impl ScannerKind { /// Both `LogScanner` and `RecordBatchLogScanner` share the same subscribe interface. macro_rules! with_scanner { ($scanner:expr, $method:ident($($arg:expr),*)) => { - match $scanner { + match $scanner.as_ref() { ScannerKind::Record(s) => s.$method($($arg),*).await, ScannerKind::Batch(s) => s.$method($($arg),*).await, } @@ -1909,7 +1902,7 @@ macro_rules! with_scanner { /// - Batch-based scanning via `poll_arrow()` / `poll_record_batch()` - returns Arrow batches #[pyclass] pub struct LogScanner { - state: Arc>, + kind: Arc, admin: fcore::client::FlussAdmin, table_info: fcore::metadata::TableInfo, /// The projected Arrow schema to use for empty table creation @@ -1930,8 +1923,7 @@ impl LogScanner { fn subscribe(&self, py: Python, bucket_id: i32, start_offset: i64) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - let state = self.state.lock().await; - with_scanner!(&state.kind, subscribe(bucket_id, start_offset)) + with_scanner!(&self.kind, subscribe(bucket_id, start_offset)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -1944,8 +1936,7 @@ impl LogScanner { fn subscribe_buckets(&self, py: Python, bucket_offsets: HashMap) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - let state = self.state.lock().await; - with_scanner!(&state.kind, subscribe_buckets(&bucket_offsets)) + with_scanner!(&self.kind, subscribe_buckets(&bucket_offsets)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -1966,9 +1957,8 @@ impl LogScanner { ) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - let state = self.state.lock().await; with_scanner!( - &state.kind, + &self.kind, subscribe_partition(partition_id, bucket_id, start_offset) ) .map_err(|e| FlussError::from_core_error(&e)) @@ -1987,9 +1977,8 @@ impl LogScanner { ) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - let state = self.state.lock().await; with_scanner!( - &state.kind, + &self.kind, subscribe_partition_buckets(&partition_bucket_offsets) ) .map_err(|e| FlussError::from_core_error(&e)) @@ -2004,8 +1993,7 @@ impl LogScanner { fn unsubscribe(&self, py: Python, bucket_id: i32) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - let state = self.state.lock().await; - with_scanner!(&state.kind, unsubscribe(bucket_id)) + with_scanner!(&self.kind, unsubscribe(bucket_id)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -2019,8 +2007,7 @@ impl LogScanner { fn unsubscribe_partition(&self, py: Python, partition_id: i64, bucket_id: i32) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - let state = self.state.lock().await; - with_scanner!(&state.kind, unsubscribe_partition(partition_id, bucket_id)) + with_scanner!(&self.kind, unsubscribe_partition(partition_id, bucket_id)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -2041,10 +2028,7 @@ impl LogScanner { /// - Returns an empty ScanRecords if no records are available /// - When timeout expires, returns an empty ScanRecords (NOT an error) fn poll(&self, py: Python, timeout_ms: i64) -> PyResult { - let scanner_ref = - unsafe { &*(&self.state as *const std::sync::Arc>) }; - let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); - let scanner = lock.kind.as_record()?; + let scanner = self.kind.as_record()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2093,10 +2077,7 @@ impl LogScanner { /// - Returns an empty list if no batches are available /// - When timeout expires, returns an empty list (NOT an error) fn poll_record_batch(&self, py: Python, timeout_ms: i64) -> PyResult> { - let scanner_ref = - unsafe { &*(&self.state as *const std::sync::Arc>) }; - let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); - let scanner = lock.kind.as_batch()?; + let scanner = self.kind.as_batch()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2131,10 +2112,7 @@ impl LogScanner { /// - Returns an empty table (with correct schema) if no records are available /// - When timeout expires, returns an empty table (NOT an error) fn poll_arrow(&self, py: Python, timeout_ms: i64) -> PyResult> { - let scanner_ref = - unsafe { &*(&self.state as *const std::sync::Arc>) }; - let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); - let scanner = lock.kind.as_batch()?; + let scanner = self.kind.as_batch()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2188,11 +2166,7 @@ impl LogScanner { /// PyArrow Table containing all data from subscribed buckets fn to_arrow(&self, py: Python) -> PyResult> { let subscribed = { - let scanner_ref = unsafe { - &*(&self.state as *const std::sync::Arc>) - }; - let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); - let scanner = lock.kind.as_batch()?; + let scanner = self.kind.as_batch()?; let subs = scanner.get_subscribed_buckets(); if subs.is_empty() { return Err(FlussError::new_err( @@ -2227,87 +2201,84 @@ impl LogScanner { } fn __aiter__<'py>(slf: PyRef<'py, Self>) -> PyResult> { + static ASYNC_GEN_FN: PyOnceLock> = PyOnceLock::new(); let py = slf.py(); - let code = pyo3::ffi::c_str!( - r#" -async def _adapter(obj): + let gen_fn = ASYNC_GEN_FN.get_or_init(py, || { + let code = pyo3::ffi::c_str!( + r#" +async def _async_scan(scanner, timeout_ms=1000): while True: - try: - yield await obj.__anext__() - except StopAsyncIteration: - break + batch = await scanner._async_poll(timeout_ms) + if batch: + for record in batch: + yield record "# - ); - let globals = pyo3::types::PyDict::new(py); - py.run(code, Some(&globals), None)?; - let adapter = globals.get_item("_adapter")?.unwrap(); - // Return adapt(self) - adapter.call1((slf.into_bound_py_any(py)?,)) + ); + let globals = pyo3::types::PyDict::new(py); + py.run(code, Some(&globals), None).unwrap(); + globals.get_item("_async_scan").unwrap().unwrap().unbind() + }); + gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,)) } - fn __anext__<'py>(slf: PyRefMut<'py, Self>) -> PyResult>> { - let state_arc = slf.state.clone(); - let projected_row_type = slf.projected_row_type.clone(); - let py = slf.py(); - - let future = future_into_py(py, async move { - let mut state = state_arc.lock().await; + /// Perform a single bounded poll and return a list of ScanRecord objects. + /// + /// This is the async building block used by `__aiter__` to implement + /// `async for`. Each call does exactly one network poll (bounded by + /// `timeout_ms`), converts any results to Python objects, and returns + /// them as a list. An empty list signals a timeout (no data yet), not + /// end-of-stream. + /// + /// Args: + /// timeout_ms: Timeout in milliseconds for the network poll (default: 1000) + /// + /// Returns: + /// Awaitable that resolves to a list of ScanRecord objects + fn _async_poll<'py>( + &self, + py: Python<'py>, + timeout_ms: Option, + ) -> PyResult> { + let timeout_ms = timeout_ms.unwrap_or(1000); + if timeout_ms < 0 { + return Err(FlussError::new_err(format!( + "timeout_ms must be non-negative, got: {timeout_ms}" + ))); + } - // 1. If we already have buffered records, pop and return immediately - if let Some(record) = state.pending_records.pop_front() { - return Ok(record.into_any()); - } + let scanner = Arc::clone(&self.kind); + let projected_row_type = self.projected_row_type.clone(); + let timeout = Duration::from_millis(timeout_ms as u64); - // 2. Buffer is empty, we must poll the network for the next batch - // The underlying kind must be a Record-based scanner. - let scanner = match state.kind.as_record() { - Ok(s) => s, - Err(_) => { - return Err(pyo3::exceptions::PyStopAsyncIteration::new_err( - "Stream Ended", + future_into_py(py, async move { + let core_scanner = match scanner.as_ref() { + ScannerKind::Record(s) => s, + ScannerKind::Batch(_) => { + return Err(PyTypeError::new_err( + "Async iteration is only supported for record scanners; \ + use create_log_scanner() instead.", )); } }; - // Poll with a reasonable internal timeout before unblocking the event loop - let timeout = core::time::Duration::from_millis(5000); - - let mut current_records = scanner + let scan_records = core_scanner .poll(timeout) .await .map_err(|e| FlussError::from_core_error(&e))?; - // If it's a real timeout with zero records, loop or throw StopAsyncIteration? - // Since it's a streaming log, we can yield None or block. Blocking requires a loop in the future. - while current_records.is_empty() { - current_records = scanner - .poll(timeout) - .await - .map_err(|e| FlussError::from_core_error(&e))?; - } - - // Now we have records. + // Convert to Python list Python::attach(|py| { - for (_, records) in current_records.into_records_by_buckets() { + let mut result: Vec> = Vec::new(); + for (_, records) in scan_records.into_records_by_buckets() { for core_record in records { let scan_record = ScanRecord::from_core(py, &core_record, &projected_row_type)?; - state.pending_records.push_back(Py::new(py, scan_record)?); + result.push(Py::new(py, scan_record)?); } } - - // Pop the very first one to return right now - if let Some(record) = state.pending_records.pop_front() { - Ok(record.into_any()) - } else { - Err(pyo3::exceptions::PyStopAsyncIteration::new_err( - "Stream Ended", - )) - } + Ok(result) }) - })?; - - Ok(Some(future)) + }) } fn __repr__(&self) -> String { @@ -2324,10 +2295,7 @@ impl LogScanner { projected_row_type: fcore::metadata::RowType, ) -> Self { Self { - state: std::sync::Arc::new(tokio::sync::Mutex::new(ScannerState { - kind: scanner, - pending_records: std::collections::VecDeque::new(), - })), + kind: Arc::new(scanner), admin, table_info, projected_schema, @@ -2378,10 +2346,7 @@ impl LogScanner { py: Python, subscribed: &[(fcore::metadata::TableBucket, i64)], ) -> PyResult> { - let scanner_ref = - unsafe { &*(&self.state as *const std::sync::Arc>) }; - let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); - let scanner = lock.kind.as_batch()?; + let scanner = self.kind.as_batch()?; let is_partitioned = scanner.is_partitioned(); let table_path = &self.table_info.table_path; @@ -2484,10 +2449,7 @@ impl LogScanner { py: Python, mut stopping_offsets: HashMap, ) -> PyResult> { - let scanner_ref = - unsafe { &*(&self.state as *const std::sync::Arc>) }; - let lock = TOKIO_RUNTIME.block_on(async { scanner_ref.lock().await }); - let scanner = lock.kind.as_batch()?; + let scanner = self.kind.as_batch()?; let mut all_batches = Vec::new(); while !stopping_offsets.is_empty() { diff --git a/bindings/python/test/test_log_table.py b/bindings/python/test/test_log_table.py index 2f9588b0..8cf43fb4 100644 --- a/bindings/python/test/test_log_table.py +++ b/bindings/python/test/test_log_table.py @@ -778,6 +778,347 @@ async def consume_scanner(): await admin.drop_table(table_path, ignore_if_not_exists=False) +async def test_async_iterator_break_no_leak(connection, admin): + """Verify that breaking out of `async for` does not leak resources. + + After breaking, the scanner must still be usable for synchronous + `poll()` calls. If the old implementation's tokio::spawn'd task + were still alive, it would hold the Mutex and cause `poll()` to + deadlock or error. + """ + table_path = fluss.TablePath("fluss", "py_test_async_break_leak") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 11)), type=pa.int32()), + pa.array([f"v{i}" for i in range(1, 11)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Phase 1: async for with early break (collect only 3 of 10) + collected_async = [] + + async def consume_and_break(): + async for record in scanner: + collected_async.append(record) + if len(collected_async) >= 3: + break + + await asyncio.wait_for(consume_and_break(), timeout=10.0) + assert len(collected_async) == 3, ( + f"Expected 3 records from async for, got {len(collected_async)}" + ) + + # Phase 2: sync poll() must still work — proves no leaked task / lock. + # With small data and few buckets, _async_poll may have fetched all + # records in one batch. After break, the un-yielded records from that + # batch are lost. So sync poll may return 0 records — the key assertion + # is that poll() completes without deadlock (returns within timeout). + remaining = scanner.poll(2000) + assert remaining is not None, "poll() should return (not deadlock)" + + # If we got records, verify no duplicates + async_ids = {r.row["id"] for r in collected_async} + sync_ids = {r.row["id"] for r in remaining} + assert async_ids.isdisjoint(sync_ids), ( + f"Duplicate IDs between async and sync: {async_ids & sync_ids}" + ) + + # All IDs must be from the original 1-10 range + all_ids = async_ids | sync_ids + assert all_ids.issubset(set(range(1, 11))), ( + f"Unexpected IDs: {all_ids - set(range(1, 11))}" + ) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_iterator_multiple_batches(connection, admin): + """Verify async iteration works across multiple network poll cycles. + + _async_poll does a single bounded poll per call. Writing 20 records + to multiple buckets ensures the Python generator must loop through + several _async_poll calls to collect them all. + """ + table_path = fluss.TablePath("fluss", "py_test_async_multi_batch") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + table_descriptor = fluss.TableDescriptor( + schema, bucket_count=3, bucket_keys=["id"] + ) + await admin.create_table( + table_path, table_descriptor, ignore_if_exists=False + ) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + + num_records = 20 + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, num_records + 1)), type=pa.int32()), + pa.array([f"multi{i}" for i in range(1, num_records + 1)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + collected = [] + + async def consume_all(): + async for record in scanner: + collected.append(record) + if len(collected) >= num_records: + break + + await asyncio.wait_for(consume_all(), timeout=15.0) + assert len(collected) == num_records, ( + f"Expected {num_records} records, got {len(collected)}" + ) + + # Verify all IDs are present (order may vary due to bucketing) + ids = sorted(r.row["id"] for r in collected) + assert ids == list(range(1, num_records + 1)) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_iterator_batch_scanner_raises_type_error( + connection, admin +): + """Verify that using `async for` on a batch scanner raises TypeError.""" + table_path = fluss.TablePath("fluss", "py_test_async_batch_error") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + + # Write some data so there's something to iterate + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array([1, 2, 3], type=pa.int32()), + pa.array(["a", "b", "c"]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + # Create a BATCH scanner (not a record scanner) + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + batch_scanner.subscribe(bucket_id=0, start_offset=0) + + # Attempting async for on a batch scanner must raise TypeError + import pytest + + with pytest.raises(TypeError): + + async def try_iterate(): + async for _ in batch_scanner: + pass + + await asyncio.wait_for(try_iterate(), timeout=5.0) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_negative_timeout(connection, admin): + """Verify _async_poll rejects a negative timeout_ms with an error.""" + table_path = fluss.TablePath("fluss", "py_test_async_poll_neg_timeout") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + scanner = await table.new_scan().create_log_scanner() + scanner.subscribe(bucket_id=0, start_offset=0) + + import pytest + + with pytest.raises(Exception, match="non-negative"): + await scanner._async_poll(-1) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_returns_list(connection, admin): + """Verify _async_poll returns a Python list of ScanRecord objects.""" + table_path = fluss.TablePath("fluss", "py_test_async_poll_returns_list") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array([1, 2, 3], type=pa.int32()), + pa.array(["x", "y", "z"]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Poll until we get a non-empty result + result = None + deadline = time.monotonic() + 10 + while time.monotonic() < deadline: + result = await scanner._async_poll(2000) + if result: + break + + assert result is not None, "Expected non-None result from _async_poll" + assert isinstance(result, list), ( + f"Expected list, got {type(result).__name__}" + ) + assert len(result) > 0, "Expected non-empty list" + + # Each element must be a ScanRecord with .row, .offset, .timestamp + for record in result: + assert hasattr(record, "row"), "ScanRecord should have .row" + assert hasattr(record, "offset"), "ScanRecord should have .offset" + assert hasattr(record, "timestamp"), ( + "ScanRecord should have .timestamp" + ) + assert "id" in record.row + + # An empty poll (no new data) should return an empty list, not None + empty_result = await scanner._async_poll(100) + assert isinstance(empty_result, list), ( + f"Empty poll should return list, got {type(empty_result).__name__}" + ) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_sync_methods_after_async_iteration(connection, admin): + """Verify sync poll() works correctly interleaved with async iteration. + + This proves there is no lock contention between the async and sync + code paths — the removed Mutex would have caused deadlocks here if + the lock were held across the async poll boundary. + """ + table_path = fluss.TablePath( + "fluss", "py_test_sync_after_async" + ) + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 9)), type=pa.int32()), + pa.array([f"s{i}" for i in range(1, 9)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Step 1: Collect 4 records via async for + async_records = [] + + async def partial_consume(): + async for record in scanner: + async_records.append(record) + if len(async_records) >= 4: + break + + await asyncio.wait_for(partial_consume(), timeout=10.0) + assert len(async_records) == 4 + + # Step 2: Collect remaining records via sync poll(). + # With small data, _async_poll may have fetched all records in one + # batch. After break, the un-yielded records are lost. The key + # assertion is that poll() works (no deadlock from a held lock). + sync_records = scanner.poll(2000) + assert sync_records is not None, "poll() should return (not deadlock)" + + # Step 3: Verify no duplicates and all IDs are valid + async_ids = {r.row["id"] for r in async_records} + sync_ids = {r.row["id"] for r in sync_records} + assert async_ids.isdisjoint(sync_ids), ( + f"Duplicate IDs: {async_ids & sync_ids}" + ) + all_ids = async_ids | sync_ids + assert all_ids.issubset(set(range(1, 9))), ( + f"Unexpected IDs: {all_ids - set(range(1, 9))}" + ) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- From 08eef133d39b0bccf076f5909042423833ff7250 Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Thu, 12 Mar 2026 22:31:03 -0700 Subject: [PATCH 07/14] feat: add create_record_batch_log_scanner() --- bindings/python/src/table.rs | 104 ++++- bindings/python/test/test_log_table.py | 556 ++++++++++++++++++++++++- 2 files changed, 636 insertions(+), 24 deletions(-) diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 64c06d30..1dddddbd 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -2201,11 +2201,14 @@ impl LogScanner { } fn __aiter__<'py>(slf: PyRef<'py, Self>) -> PyResult> { - static ASYNC_GEN_FN: PyOnceLock> = PyOnceLock::new(); let py = slf.py(); - let gen_fn = ASYNC_GEN_FN.get_or_init(py, || { - let code = pyo3::ffi::c_str!( - r#" + + match slf.kind.as_ref() { + ScannerKind::Record(_) => { + static RECORD_ASYNC_GEN_FN: PyOnceLock> = PyOnceLock::new(); + let gen_fn = RECORD_ASYNC_GEN_FN.get_or_init(py, || { + let code = pyo3::ffi::c_str!( + r#" async def _async_scan(scanner, timeout_ms=1000): while True: batch = await scanner._async_poll(timeout_ms) @@ -2213,12 +2216,37 @@ async def _async_scan(scanner, timeout_ms=1000): for record in batch: yield record "# - ); - let globals = pyo3::types::PyDict::new(py); - py.run(code, Some(&globals), None).unwrap(); - globals.get_item("_async_scan").unwrap().unwrap().unbind() - }); - gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,)) + ); + let globals = pyo3::types::PyDict::new(py); + py.run(code, Some(&globals), None).unwrap(); + globals.get_item("_async_scan").unwrap().unwrap().unbind() + }); + gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,)) + } + ScannerKind::Batch(_) => { + static BATCH_ASYNC_GEN_FN: PyOnceLock> = PyOnceLock::new(); + let gen_fn = BATCH_ASYNC_GEN_FN.get_or_init(py, || { + let code = pyo3::ffi::c_str!( + r#" +async def _async_batch_scan(scanner, timeout_ms=1000): + while True: + batches = await scanner._async_poll_batches(timeout_ms) + if batches: + for rb in batches: + yield rb +"# + ); + let globals = pyo3::types::PyDict::new(py); + py.run(code, Some(&globals), None).unwrap(); + globals + .get_item("_async_batch_scan") + .unwrap() + .unwrap() + .unbind() + }); + gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,)) + } + } } /// Perform a single bounded poll and return a list of ScanRecord objects. @@ -2281,6 +2309,62 @@ async def _async_scan(scanner, timeout_ms=1000): }) } + /// Perform a single bounded poll and return a list of RecordBatch objects. + /// + /// This is the async building block used by `__aiter__` (batch mode) to + /// implement `async for`. Each call does exactly one network poll (bounded + /// by `timeout_ms`), converts any results to Python RecordBatch objects, + /// and returns them as a list. An empty list signals a timeout (no data + /// yet), not end-of-stream. + /// + /// Args: + /// timeout_ms: Timeout in milliseconds for the network poll (default: 1000) + /// + /// Returns: + /// Awaitable that resolves to a list of RecordBatch objects + fn _async_poll_batches<'py>( + &self, + py: Python<'py>, + timeout_ms: Option, + ) -> PyResult> { + let timeout_ms = timeout_ms.unwrap_or(1000); + if timeout_ms < 0 { + return Err(FlussError::new_err(format!( + "timeout_ms must be non-negative, got: {timeout_ms}" + ))); + } + + let scanner = Arc::clone(&self.kind); + let timeout = Duration::from_millis(timeout_ms as u64); + + future_into_py(py, async move { + let core_scanner = match scanner.as_ref() { + ScannerKind::Batch(s) => s, + ScannerKind::Record(_) => { + return Err(PyTypeError::new_err( + "Batch async iteration is only supported for batch scanners; \ + use create_record_batch_log_scanner() instead.", + )); + } + }; + + let scan_batches = core_scanner + .poll(timeout) + .await + .map_err(|e| FlussError::from_core_error(&e))?; + + // Convert to Python list of RecordBatch objects + Python::attach(|py| { + let mut result: Vec> = Vec::new(); + for scan_batch in scan_batches { + let rb = RecordBatch::from_scan_batch(scan_batch); + result.push(Py::new(py, rb)?); + } + Ok(result) + }) + }) + } + fn __repr__(&self) -> String { format!("LogScanner(table={})", self.table_info.table_path) } diff --git a/bindings/python/test/test_log_table.py b/bindings/python/test/test_log_table.py index 8cf43fb4..970a516f 100644 --- a/bindings/python/test/test_log_table.py +++ b/bindings/python/test/test_log_table.py @@ -916,11 +916,14 @@ async def consume_all(): await admin.drop_table(table_path, ignore_if_not_exists=False) -async def test_async_iterator_batch_scanner_raises_type_error( - connection, admin -): - """Verify that using `async for` on a batch scanner raises TypeError.""" - table_path = fluss.TablePath("fluss", "py_test_async_batch_error") +async def test_batch_async_iterator(connection, admin): + """Test the Python asynchronous iterator loop (`async for`) on a batch LogScanner. + + With our __aiter__ dispatch, a batch-based scanner should yield RecordBatch + objects (not ScanRecord). Each yielded item has .batch (PyArrow RecordBatch), + .bucket, .base_offset, .last_offset. + """ + table_path = fluss.TablePath("fluss", "py_test_batch_async_iter") await admin.drop_table(table_path, ignore_if_not_exists=True) schema = fluss.Schema( @@ -929,14 +932,148 @@ async def test_async_iterator_batch_scanner_raises_type_error( await admin.create_table(table_path, fluss.TableDescriptor(schema)) table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 7)), type=pa.int32()), + pa.array([f"bv{i}" for i in range(1, 7)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + collected_batches = [] + total_rows = 0 + + async def consume_batches(): + nonlocal total_rows + async for rb in batch_scanner: + collected_batches.append(rb) + total_rows += rb.batch.num_rows + if total_rows >= 6: + break + + await asyncio.wait_for(consume_batches(), timeout=15.0) + + assert total_rows >= 6, f"Expected >=6 total rows, got {total_rows}" + assert len(collected_batches) > 0 + + # Verify each yielded item is a RecordBatch with expected attributes + for rb in collected_batches: + assert hasattr(rb, "batch"), "RecordBatch should have .batch" + assert hasattr(rb, "bucket"), "RecordBatch should have .bucket" + assert hasattr(rb, "base_offset"), "RecordBatch should have .base_offset" + assert hasattr(rb, "last_offset"), "RecordBatch should have .last_offset" + # .batch should be a PyArrow RecordBatch + arrow_batch = rb.batch + assert isinstance(arrow_batch, pa.RecordBatch), ( + f"Expected PyArrow RecordBatch, got {type(arrow_batch).__name__}" + ) + assert arrow_batch.num_columns == 2 + assert set(arrow_batch.schema.names) == {"id", "val"} + + # Verify all 6 IDs are present + all_ids = [] + for rb in collected_batches: + all_ids.extend(rb.batch.column("id").to_pylist()) + assert sorted(all_ids[:6]) == [1, 2, 3, 4, 5, 6] + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_break_no_leak(connection, admin): + """Verify that breaking out of batch `async for` does not leak resources. + + After breaking, the scanner must still be usable for synchronous + poll_record_batch() calls, proving no leaked task or lock. + """ + table_path = fluss.TablePath("fluss", "py_test_batch_async_break") + await admin.drop_table(table_path, ignore_if_not_exists=True) - # Write some data so there's something to iterate + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) writer = table.new_append().create_writer() writer.write_arrow_batch( pa.RecordBatch.from_arrays( [ - pa.array([1, 2, 3], type=pa.int32()), - pa.array(["a", "b", "c"]), + pa.array(list(range(1, 11)), type=pa.int32()), + pa.array([f"bl{i}" for i in range(1, 11)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Phase 1: async for with early break (collect just 1 batch) + first_batch = None + + async def consume_and_break(): + nonlocal first_batch + async for rb in batch_scanner: + first_batch = rb + break + + await asyncio.wait_for(consume_and_break(), timeout=10.0) + assert first_batch is not None, "Should have received at least 1 batch" + assert first_batch.batch.num_rows > 0 + + # Phase 2: sync poll_record_batch() must still work — proves no leak + remaining = batch_scanner.poll_record_batch(2000) + assert remaining is not None, "poll_record_batch() should return (not deadlock)" + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_multiple_batches(connection, admin): + """Verify batch async iteration works across multiple network poll cycles. + + Writing 20 records to 3 buckets ensures the generator must loop through + several _async_poll_batches calls to collect them all. + """ + table_path = fluss.TablePath("fluss", "py_test_batch_async_multi") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + table_descriptor = fluss.TableDescriptor( + schema, bucket_count=3, bucket_keys=["id"] + ) + await admin.create_table( + table_path, table_descriptor, ignore_if_exists=False + ) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + + num_records = 20 + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, num_records + 1)), type=pa.int32()), + pa.array([f"bm{i}" for i in range(1, num_records + 1)]), ], schema=pa.schema( [pa.field("id", pa.int32()), pa.field("val", pa.string())] @@ -945,20 +1082,410 @@ async def test_async_iterator_batch_scanner_raises_type_error( ) await writer.flush() - # Create a BATCH scanner (not a record scanner) + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + all_ids = [] + + async def consume_all(): + async for rb in batch_scanner: + all_ids.extend(rb.batch.column("id").to_pylist()) + if len(all_ids) >= num_records: + break + + await asyncio.wait_for(consume_all(), timeout=15.0) + assert len(all_ids) >= num_records, ( + f"Expected >={num_records} IDs, got {len(all_ids)}" + ) + assert sorted(all_ids[:num_records]) == list(range(1, num_records + 1)) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_batches_wrong_scanner_type(connection, admin): + """Verify _async_poll_batches raises TypeError on a record scanner.""" + table_path = fluss.TablePath("fluss", "py_test_apb_wrong_type") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + # Create a RECORD scanner (not batch) + record_scanner = await table.new_scan().create_log_scanner() + record_scanner.subscribe(bucket_id=0, start_offset=0) + + import pytest + + with pytest.raises(TypeError): + await record_scanner._async_poll_batches(1000) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_on_batch_scanner_raises_type_error( + connection, admin +): + """Verify _async_poll (record method) raises TypeError on a batch scanner. + + This is the inverse: _async_poll is for records only, _async_poll_batches + is for batches only. Calling the wrong one should raise TypeError. + """ + table_path = fluss.TablePath("fluss", "py_test_apoll_batch_err") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) batch_scanner = await table.new_scan().create_record_batch_log_scanner() batch_scanner.subscribe(bucket_id=0, start_offset=0) - # Attempting async for on a batch scanner must raise TypeError import pytest with pytest.raises(TypeError): + await batch_scanner._async_poll(1000) - async def try_iterate(): - async for _ in batch_scanner: - pass + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_batches_negative_timeout(connection, admin): + """Verify _async_poll_batches rejects a negative timeout_ms with an error.""" + table_path = fluss.TablePath("fluss", "py_test_apb_neg_timeout") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + batch_scanner.subscribe(bucket_id=0, start_offset=0) + + import pytest - await asyncio.wait_for(try_iterate(), timeout=5.0) + with pytest.raises(Exception, match="non-negative"): + await batch_scanner._async_poll_batches(-1) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_poll_batches_returns_list(connection, admin): + """Verify _async_poll_batches returns a Python list of RecordBatch objects.""" + table_path = fluss.TablePath("fluss", "py_test_apb_returns_list") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array([1, 2, 3], type=pa.int32()), + pa.array(["x", "y", "z"]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Poll until we get a non-empty result + result = None + deadline = time.monotonic() + 10 + while time.monotonic() < deadline: + result = await batch_scanner._async_poll_batches(2000) + if result: + break + + assert result is not None, "Expected non-None result from _async_poll_batches" + assert isinstance(result, list), ( + f"Expected list, got {type(result).__name__}" + ) + assert len(result) > 0, "Expected non-empty list" + + # Each element must be a RecordBatch with .batch, .bucket, .base_offset, .last_offset + for rb in result: + assert hasattr(rb, "batch"), "RecordBatch should have .batch" + assert hasattr(rb, "bucket"), "RecordBatch should have .bucket" + assert hasattr(rb, "base_offset"), "RecordBatch should have .base_offset" + assert hasattr(rb, "last_offset"), "RecordBatch should have .last_offset" + assert isinstance(rb.batch, pa.RecordBatch) + assert rb.base_offset >= 0 + assert rb.last_offset >= rb.base_offset + + # An empty poll (no new data) should return an empty list, not None + empty_result = await batch_scanner._async_poll_batches(100) + assert isinstance(empty_result, list), ( + f"Empty poll should return list, got {type(empty_result).__name__}" + ) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_record_batch_metadata(connection, admin): + """Verify that RecordBatch objects yielded by async for contain correct metadata. + + Each RecordBatch must have: + - .bucket with a valid bucket_id + - .base_offset >= 0 + - .last_offset = base_offset + num_rows - 1 (for non-empty batches) + - .batch.num_rows > 0 + """ + table_path = fluss.TablePath("fluss", "py_test_batch_async_meta") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 6)), type=pa.int32()), + pa.array([f"m{i}" for i in range(1, 6)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + collected_batches = [] + total_rows = 0 + + async def consume(): + nonlocal total_rows + async for rb in batch_scanner: + collected_batches.append(rb) + total_rows += rb.batch.num_rows + if total_rows >= 5: + break + + await asyncio.wait_for(consume(), timeout=15.0) + assert total_rows >= 5 + + for rb in collected_batches: + assert rb.batch.num_rows > 0, "Yielded batch should not be empty" + assert rb.base_offset >= 0, "base_offset should be non-negative" + expected_last = rb.base_offset + rb.batch.num_rows - 1 + assert rb.last_offset == expected_last, ( + f"last_offset should be {expected_last}, got {rb.last_offset}" + ) + assert rb.bucket.bucket_id >= 0, "bucket_id should be non-negative" + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_to_pandas(connection, admin): + """Verify end-to-end: async for → RecordBatch → .batch.to_pandas().""" + table_path = fluss.TablePath("fluss", "py_test_batch_async_pandas") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("name", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array([10, 20, 30], type=pa.int32()), + pa.array(["alice", "bob", "charlie"]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("name", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + all_dfs = [] + total_rows = 0 + + async def consume(): + nonlocal total_rows + async for rb in batch_scanner: + df = rb.batch.to_pandas() + all_dfs.append(df) + total_rows += len(df) + if total_rows >= 3: + break + + await asyncio.wait_for(consume(), timeout=15.0) + assert total_rows >= 3 + + import pandas as pd + combined = pd.concat(all_dfs, ignore_index=True).sort_values("id").reset_index(drop=True) + assert list(combined.columns) == ["id", "name"] + assert combined["id"].tolist()[:3] == [10, 20, 30] + assert combined["name"].tolist()[:3] == ["alice", "bob", "charlie"] + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_projected_columns(connection, admin): + """Verify batch async for respects column projection.""" + table_path = fluss.TablePath("fluss", "py_test_batch_async_proj") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema( + [ + pa.field("col_a", pa.int32()), + pa.field("col_b", pa.string()), + pa.field("col_c", pa.int32()), + ] + ) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array([1, 2, 3], type=pa.int32()), + pa.array(["x", "y", "z"]), + pa.array([10, 20, 30], type=pa.int32()), + ], + schema=pa.schema( + [ + pa.field("col_a", pa.int32()), + pa.field("col_b", pa.string()), + pa.field("col_c", pa.int32()), + ] + ), + ) + ) + await writer.flush() + + # Project only col_b and col_c + proj_scanner = ( + await table.new_scan() + .project_by_name(["col_b", "col_c"]) + .create_record_batch_log_scanner() + ) + num_buckets = (await admin.get_table_info(table_path)).num_buckets + proj_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + all_batches = [] + total_rows = 0 + + async def consume(): + nonlocal total_rows + async for rb in proj_scanner: + all_batches.append(rb) + total_rows += rb.batch.num_rows + if total_rows >= 3: + break + + await asyncio.wait_for(consume(), timeout=15.0) + assert total_rows >= 3 + + # Verify projected schema: only col_b and col_c, no col_a + for rb in all_batches: + assert set(rb.batch.schema.names) == {"col_b", "col_c"}, ( + f"Projected schema should only have col_b and col_c, " + f"got {rb.batch.schema.names}" + ) + assert rb.batch.num_columns == 2 + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_sync_methods_after_batch_async_iteration(connection, admin): + """Verify sync poll_record_batch() works after batch async iteration. + + This proves no lock contention between the batch async and sync paths. + """ + table_path = fluss.TablePath("fluss", "py_test_sync_after_batch_async") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 9)), type=pa.int32()), + pa.array([f"sv{i}" for i in range(1, 9)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Step 1: Collect 1 batch via async for then break + first_batch = None + + async def partial_consume(): + nonlocal first_batch + async for rb in batch_scanner: + first_batch = rb + break + + await asyncio.wait_for(partial_consume(), timeout=10.0) + assert first_batch is not None + + # Step 2: Sync poll_record_batch() must work (no deadlock) + sync_batches = batch_scanner.poll_record_batch(2000) + assert sync_batches is not None, "poll_record_batch() should return (not deadlock)" + + # Step 3: Sync poll_arrow() must also work + arrow_table = batch_scanner.poll_arrow(2000) + assert arrow_table is not None, "poll_arrow() should return (not deadlock)" await admin.drop_table(table_path, ignore_if_not_exists=False) @@ -1145,3 +1672,4 @@ def _poll_arrow_ids(scanner, expected_count, timeout_s=10): if arrow_table.num_rows > 0: all_ids.extend(arrow_table.column("id").to_pylist()) return all_ids + From 68426a073090a8982f0e440612def37c295a8ba2 Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Thu, 12 Mar 2026 22:38:03 -0700 Subject: [PATCH 08/14] chore: update error message for _async_poll and _async_poll_batches so that they match when talking about Record vs Batch --- bindings/python/src/table.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 1dddddbd..130c1cad 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -2283,8 +2283,8 @@ async def _async_batch_scan(scanner, timeout_ms=1000): ScannerKind::Record(s) => s, ScannerKind::Batch(_) => { return Err(PyTypeError::new_err( - "Async iteration is only supported for record scanners; \ - use create_log_scanner() instead.", + "This internal method only supports record-based scanners. \ + For batch-based scanners, use 'async for' or 'poll_record_batch' instead.", )); } }; @@ -2342,8 +2342,8 @@ async def _async_batch_scan(scanner, timeout_ms=1000): ScannerKind::Batch(s) => s, ScannerKind::Record(_) => { return Err(PyTypeError::new_err( - "Batch async iteration is only supported for batch scanners; \ - use create_record_batch_log_scanner() instead.", + "This internal method only supports batch-based scanners. \ + For record-based scanners, use 'async for' or 'poll' instead.", )); } }; From d619b13d12d51b1f36bbe895c8ca2422683ca43e Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Sat, 14 Mar 2026 15:57:03 -0700 Subject: [PATCH 09/14] refactor: separate `IntoPyObjectExt` import from grouped `pyo3` imports. --- bindings/python/src/table.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 130c1cad..896685cb 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -23,6 +23,7 @@ use arrow_schema::SchemaRef; use fluss::record::to_arrow_schema; use fluss::rpc::message::OffsetSpec; use indexmap::IndexMap; +use pyo3::IntoPyObjectExt; use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyTypeError}; use pyo3::sync::PyOnceLock; use pyo3::types::{ @@ -30,7 +31,6 @@ use pyo3::types::{ PyDeltaAccess, PyDict, PyList, PySequence, PySlice, PyTime, PyTimeAccess, PyTuple, PyType, PyTzInfo, }; -use pyo3::{Bound, IntoPyObjectExt, Py, PyAny, PyRef, PyRefMut, PyResult, Python}; use pyo3_async_runtimes::tokio::future_into_py; use std::collections::HashMap; use std::sync::Arc; From 134e56b543395d7f110ffa7b36edfe6fdab4e729 Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Sat, 14 Mar 2026 16:20:03 -0700 Subject: [PATCH 10/14] feat: Add asynchronous iteration support to `ScanIterator` with `__aiter__` and async polling methods. --- bindings/python/fluss/__init__.pyi | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/bindings/python/fluss/__init__.pyi b/bindings/python/fluss/__init__.pyi index 417ac9b2..8d46bc63 100644 --- a/bindings/python/fluss/__init__.pyi +++ b/bindings/python/fluss/__init__.pyi @@ -19,7 +19,17 @@ from enum import IntEnum from types import TracebackType -from typing import Dict, Iterator, List, Optional, Tuple, Union, overload +from typing import ( + Any, + AsyncIterator, + Dict, + Iterator, + List, + Optional, + Tuple, + Union, + overload, +) import pandas as pd import pyarrow as pa @@ -765,8 +775,12 @@ class LogScanner: You must call subscribe(), subscribe_buckets(), or subscribe_partition() first. """ - ... def __repr__(self) -> str: ... + def __aiter__(self) -> AsyncIterator[Union[ScanRecord, RecordBatch]]: ... + async def _async_poll(self, timeout_ms: Optional[int] = ...) -> List[ScanRecord]: ... + async def _async_poll_batches( + self, timeout_ms: Optional[int] = ... + ) -> List[RecordBatch]: ... class Schema: def __init__( From efbcb8cd3e390558fc6242f7a7aed184b5c2257b Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Sat, 14 Mar 2026 16:50:10 -0700 Subject: [PATCH 11/14] refactor: Remove extra tests that check for _async_poll / _async_poll_batches directly - they're internal methods, tested implicitly through async for and tests that re-verify existing features (already covered by sync tests). --- bindings/python/test/test_log_table.py | 541 ------------------------- 1 file changed, 541 deletions(-) diff --git a/bindings/python/test/test_log_table.py b/bindings/python/test/test_log_table.py index 970a516f..eb118748 100644 --- a/bindings/python/test/test_log_table.py +++ b/bindings/python/test/test_log_table.py @@ -1105,547 +1105,6 @@ async def consume_all(): await admin.drop_table(table_path, ignore_if_not_exists=False) -async def test_async_poll_batches_wrong_scanner_type(connection, admin): - """Verify _async_poll_batches raises TypeError on a record scanner.""" - table_path = fluss.TablePath("fluss", "py_test_apb_wrong_type") - await admin.drop_table(table_path, ignore_if_not_exists=True) - - schema = fluss.Schema( - pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) - ) - await admin.create_table(table_path, fluss.TableDescriptor(schema)) - - table = await connection.get_table(table_path) - # Create a RECORD scanner (not batch) - record_scanner = await table.new_scan().create_log_scanner() - record_scanner.subscribe(bucket_id=0, start_offset=0) - - import pytest - - with pytest.raises(TypeError): - await record_scanner._async_poll_batches(1000) - - await admin.drop_table(table_path, ignore_if_not_exists=False) - - -async def test_async_poll_on_batch_scanner_raises_type_error( - connection, admin -): - """Verify _async_poll (record method) raises TypeError on a batch scanner. - - This is the inverse: _async_poll is for records only, _async_poll_batches - is for batches only. Calling the wrong one should raise TypeError. - """ - table_path = fluss.TablePath("fluss", "py_test_apoll_batch_err") - await admin.drop_table(table_path, ignore_if_not_exists=True) - - schema = fluss.Schema( - pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) - ) - await admin.create_table(table_path, fluss.TableDescriptor(schema)) - - table = await connection.get_table(table_path) - batch_scanner = await table.new_scan().create_record_batch_log_scanner() - batch_scanner.subscribe(bucket_id=0, start_offset=0) - - import pytest - - with pytest.raises(TypeError): - await batch_scanner._async_poll(1000) - - await admin.drop_table(table_path, ignore_if_not_exists=False) - - -async def test_async_poll_batches_negative_timeout(connection, admin): - """Verify _async_poll_batches rejects a negative timeout_ms with an error.""" - table_path = fluss.TablePath("fluss", "py_test_apb_neg_timeout") - await admin.drop_table(table_path, ignore_if_not_exists=True) - - schema = fluss.Schema( - pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) - ) - await admin.create_table(table_path, fluss.TableDescriptor(schema)) - - table = await connection.get_table(table_path) - batch_scanner = await table.new_scan().create_record_batch_log_scanner() - batch_scanner.subscribe(bucket_id=0, start_offset=0) - - import pytest - - with pytest.raises(Exception, match="non-negative"): - await batch_scanner._async_poll_batches(-1) - - await admin.drop_table(table_path, ignore_if_not_exists=False) - - -async def test_async_poll_batches_returns_list(connection, admin): - """Verify _async_poll_batches returns a Python list of RecordBatch objects.""" - table_path = fluss.TablePath("fluss", "py_test_apb_returns_list") - await admin.drop_table(table_path, ignore_if_not_exists=True) - - schema = fluss.Schema( - pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) - ) - await admin.create_table(table_path, fluss.TableDescriptor(schema)) - - table = await connection.get_table(table_path) - writer = table.new_append().create_writer() - writer.write_arrow_batch( - pa.RecordBatch.from_arrays( - [ - pa.array([1, 2, 3], type=pa.int32()), - pa.array(["x", "y", "z"]), - ], - schema=pa.schema( - [pa.field("id", pa.int32()), pa.field("val", pa.string())] - ), - ) - ) - await writer.flush() - - batch_scanner = await table.new_scan().create_record_batch_log_scanner() - num_buckets = (await admin.get_table_info(table_path)).num_buckets - batch_scanner.subscribe_buckets( - {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} - ) - - # Poll until we get a non-empty result - result = None - deadline = time.monotonic() + 10 - while time.monotonic() < deadline: - result = await batch_scanner._async_poll_batches(2000) - if result: - break - - assert result is not None, "Expected non-None result from _async_poll_batches" - assert isinstance(result, list), ( - f"Expected list, got {type(result).__name__}" - ) - assert len(result) > 0, "Expected non-empty list" - - # Each element must be a RecordBatch with .batch, .bucket, .base_offset, .last_offset - for rb in result: - assert hasattr(rb, "batch"), "RecordBatch should have .batch" - assert hasattr(rb, "bucket"), "RecordBatch should have .bucket" - assert hasattr(rb, "base_offset"), "RecordBatch should have .base_offset" - assert hasattr(rb, "last_offset"), "RecordBatch should have .last_offset" - assert isinstance(rb.batch, pa.RecordBatch) - assert rb.base_offset >= 0 - assert rb.last_offset >= rb.base_offset - - # An empty poll (no new data) should return an empty list, not None - empty_result = await batch_scanner._async_poll_batches(100) - assert isinstance(empty_result, list), ( - f"Empty poll should return list, got {type(empty_result).__name__}" - ) - - await admin.drop_table(table_path, ignore_if_not_exists=False) - - -async def test_batch_async_iterator_record_batch_metadata(connection, admin): - """Verify that RecordBatch objects yielded by async for contain correct metadata. - - Each RecordBatch must have: - - .bucket with a valid bucket_id - - .base_offset >= 0 - - .last_offset = base_offset + num_rows - 1 (for non-empty batches) - - .batch.num_rows > 0 - """ - table_path = fluss.TablePath("fluss", "py_test_batch_async_meta") - await admin.drop_table(table_path, ignore_if_not_exists=True) - - schema = fluss.Schema( - pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) - ) - await admin.create_table(table_path, fluss.TableDescriptor(schema)) - - table = await connection.get_table(table_path) - writer = table.new_append().create_writer() - writer.write_arrow_batch( - pa.RecordBatch.from_arrays( - [ - pa.array(list(range(1, 6)), type=pa.int32()), - pa.array([f"m{i}" for i in range(1, 6)]), - ], - schema=pa.schema( - [pa.field("id", pa.int32()), pa.field("val", pa.string())] - ), - ) - ) - await writer.flush() - - batch_scanner = await table.new_scan().create_record_batch_log_scanner() - num_buckets = (await admin.get_table_info(table_path)).num_buckets - batch_scanner.subscribe_buckets( - {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} - ) - - collected_batches = [] - total_rows = 0 - - async def consume(): - nonlocal total_rows - async for rb in batch_scanner: - collected_batches.append(rb) - total_rows += rb.batch.num_rows - if total_rows >= 5: - break - - await asyncio.wait_for(consume(), timeout=15.0) - assert total_rows >= 5 - - for rb in collected_batches: - assert rb.batch.num_rows > 0, "Yielded batch should not be empty" - assert rb.base_offset >= 0, "base_offset should be non-negative" - expected_last = rb.base_offset + rb.batch.num_rows - 1 - assert rb.last_offset == expected_last, ( - f"last_offset should be {expected_last}, got {rb.last_offset}" - ) - assert rb.bucket.bucket_id >= 0, "bucket_id should be non-negative" - - await admin.drop_table(table_path, ignore_if_not_exists=False) - - -async def test_batch_async_iterator_to_pandas(connection, admin): - """Verify end-to-end: async for → RecordBatch → .batch.to_pandas().""" - table_path = fluss.TablePath("fluss", "py_test_batch_async_pandas") - await admin.drop_table(table_path, ignore_if_not_exists=True) - - schema = fluss.Schema( - pa.schema([pa.field("id", pa.int32()), pa.field("name", pa.string())]) - ) - await admin.create_table(table_path, fluss.TableDescriptor(schema)) - - table = await connection.get_table(table_path) - writer = table.new_append().create_writer() - writer.write_arrow_batch( - pa.RecordBatch.from_arrays( - [ - pa.array([10, 20, 30], type=pa.int32()), - pa.array(["alice", "bob", "charlie"]), - ], - schema=pa.schema( - [pa.field("id", pa.int32()), pa.field("name", pa.string())] - ), - ) - ) - await writer.flush() - - batch_scanner = await table.new_scan().create_record_batch_log_scanner() - num_buckets = (await admin.get_table_info(table_path)).num_buckets - batch_scanner.subscribe_buckets( - {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} - ) - - all_dfs = [] - total_rows = 0 - - async def consume(): - nonlocal total_rows - async for rb in batch_scanner: - df = rb.batch.to_pandas() - all_dfs.append(df) - total_rows += len(df) - if total_rows >= 3: - break - - await asyncio.wait_for(consume(), timeout=15.0) - assert total_rows >= 3 - - import pandas as pd - combined = pd.concat(all_dfs, ignore_index=True).sort_values("id").reset_index(drop=True) - assert list(combined.columns) == ["id", "name"] - assert combined["id"].tolist()[:3] == [10, 20, 30] - assert combined["name"].tolist()[:3] == ["alice", "bob", "charlie"] - - await admin.drop_table(table_path, ignore_if_not_exists=False) - - -async def test_batch_async_iterator_projected_columns(connection, admin): - """Verify batch async for respects column projection.""" - table_path = fluss.TablePath("fluss", "py_test_batch_async_proj") - await admin.drop_table(table_path, ignore_if_not_exists=True) - - schema = fluss.Schema( - pa.schema( - [ - pa.field("col_a", pa.int32()), - pa.field("col_b", pa.string()), - pa.field("col_c", pa.int32()), - ] - ) - ) - await admin.create_table(table_path, fluss.TableDescriptor(schema)) - - table = await connection.get_table(table_path) - writer = table.new_append().create_writer() - writer.write_arrow_batch( - pa.RecordBatch.from_arrays( - [ - pa.array([1, 2, 3], type=pa.int32()), - pa.array(["x", "y", "z"]), - pa.array([10, 20, 30], type=pa.int32()), - ], - schema=pa.schema( - [ - pa.field("col_a", pa.int32()), - pa.field("col_b", pa.string()), - pa.field("col_c", pa.int32()), - ] - ), - ) - ) - await writer.flush() - - # Project only col_b and col_c - proj_scanner = ( - await table.new_scan() - .project_by_name(["col_b", "col_c"]) - .create_record_batch_log_scanner() - ) - num_buckets = (await admin.get_table_info(table_path)).num_buckets - proj_scanner.subscribe_buckets( - {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} - ) - - all_batches = [] - total_rows = 0 - - async def consume(): - nonlocal total_rows - async for rb in proj_scanner: - all_batches.append(rb) - total_rows += rb.batch.num_rows - if total_rows >= 3: - break - - await asyncio.wait_for(consume(), timeout=15.0) - assert total_rows >= 3 - - # Verify projected schema: only col_b and col_c, no col_a - for rb in all_batches: - assert set(rb.batch.schema.names) == {"col_b", "col_c"}, ( - f"Projected schema should only have col_b and col_c, " - f"got {rb.batch.schema.names}" - ) - assert rb.batch.num_columns == 2 - - await admin.drop_table(table_path, ignore_if_not_exists=False) - - -async def test_sync_methods_after_batch_async_iteration(connection, admin): - """Verify sync poll_record_batch() works after batch async iteration. - - This proves no lock contention between the batch async and sync paths. - """ - table_path = fluss.TablePath("fluss", "py_test_sync_after_batch_async") - await admin.drop_table(table_path, ignore_if_not_exists=True) - - schema = fluss.Schema( - pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) - ) - await admin.create_table(table_path, fluss.TableDescriptor(schema)) - - table = await connection.get_table(table_path) - writer = table.new_append().create_writer() - writer.write_arrow_batch( - pa.RecordBatch.from_arrays( - [ - pa.array(list(range(1, 9)), type=pa.int32()), - pa.array([f"sv{i}" for i in range(1, 9)]), - ], - schema=pa.schema( - [pa.field("id", pa.int32()), pa.field("val", pa.string())] - ), - ) - ) - await writer.flush() - - batch_scanner = await table.new_scan().create_record_batch_log_scanner() - num_buckets = (await admin.get_table_info(table_path)).num_buckets - batch_scanner.subscribe_buckets( - {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} - ) - - # Step 1: Collect 1 batch via async for then break - first_batch = None - - async def partial_consume(): - nonlocal first_batch - async for rb in batch_scanner: - first_batch = rb - break - - await asyncio.wait_for(partial_consume(), timeout=10.0) - assert first_batch is not None - - # Step 2: Sync poll_record_batch() must work (no deadlock) - sync_batches = batch_scanner.poll_record_batch(2000) - assert sync_batches is not None, "poll_record_batch() should return (not deadlock)" - - # Step 3: Sync poll_arrow() must also work - arrow_table = batch_scanner.poll_arrow(2000) - assert arrow_table is not None, "poll_arrow() should return (not deadlock)" - - await admin.drop_table(table_path, ignore_if_not_exists=False) - - -async def test_async_poll_negative_timeout(connection, admin): - """Verify _async_poll rejects a negative timeout_ms with an error.""" - table_path = fluss.TablePath("fluss", "py_test_async_poll_neg_timeout") - await admin.drop_table(table_path, ignore_if_not_exists=True) - - schema = fluss.Schema( - pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) - ) - await admin.create_table(table_path, fluss.TableDescriptor(schema)) - - table = await connection.get_table(table_path) - scanner = await table.new_scan().create_log_scanner() - scanner.subscribe(bucket_id=0, start_offset=0) - - import pytest - - with pytest.raises(Exception, match="non-negative"): - await scanner._async_poll(-1) - - await admin.drop_table(table_path, ignore_if_not_exists=False) - - -async def test_async_poll_returns_list(connection, admin): - """Verify _async_poll returns a Python list of ScanRecord objects.""" - table_path = fluss.TablePath("fluss", "py_test_async_poll_returns_list") - await admin.drop_table(table_path, ignore_if_not_exists=True) - - schema = fluss.Schema( - pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) - ) - await admin.create_table(table_path, fluss.TableDescriptor(schema)) - - table = await connection.get_table(table_path) - writer = table.new_append().create_writer() - writer.write_arrow_batch( - pa.RecordBatch.from_arrays( - [ - pa.array([1, 2, 3], type=pa.int32()), - pa.array(["x", "y", "z"]), - ], - schema=pa.schema( - [pa.field("id", pa.int32()), pa.field("val", pa.string())] - ), - ) - ) - await writer.flush() - - scanner = await table.new_scan().create_log_scanner() - num_buckets = (await admin.get_table_info(table_path)).num_buckets - scanner.subscribe_buckets( - {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} - ) - - # Poll until we get a non-empty result - result = None - deadline = time.monotonic() + 10 - while time.monotonic() < deadline: - result = await scanner._async_poll(2000) - if result: - break - - assert result is not None, "Expected non-None result from _async_poll" - assert isinstance(result, list), ( - f"Expected list, got {type(result).__name__}" - ) - assert len(result) > 0, "Expected non-empty list" - - # Each element must be a ScanRecord with .row, .offset, .timestamp - for record in result: - assert hasattr(record, "row"), "ScanRecord should have .row" - assert hasattr(record, "offset"), "ScanRecord should have .offset" - assert hasattr(record, "timestamp"), ( - "ScanRecord should have .timestamp" - ) - assert "id" in record.row - - # An empty poll (no new data) should return an empty list, not None - empty_result = await scanner._async_poll(100) - assert isinstance(empty_result, list), ( - f"Empty poll should return list, got {type(empty_result).__name__}" - ) - - await admin.drop_table(table_path, ignore_if_not_exists=False) - - -async def test_sync_methods_after_async_iteration(connection, admin): - """Verify sync poll() works correctly interleaved with async iteration. - - This proves there is no lock contention between the async and sync - code paths — the removed Mutex would have caused deadlocks here if - the lock were held across the async poll boundary. - """ - table_path = fluss.TablePath( - "fluss", "py_test_sync_after_async" - ) - await admin.drop_table(table_path, ignore_if_not_exists=True) - - schema = fluss.Schema( - pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) - ) - await admin.create_table(table_path, fluss.TableDescriptor(schema)) - - table = await connection.get_table(table_path) - writer = table.new_append().create_writer() - writer.write_arrow_batch( - pa.RecordBatch.from_arrays( - [ - pa.array(list(range(1, 9)), type=pa.int32()), - pa.array([f"s{i}" for i in range(1, 9)]), - ], - schema=pa.schema( - [pa.field("id", pa.int32()), pa.field("val", pa.string())] - ), - ) - ) - await writer.flush() - - scanner = await table.new_scan().create_log_scanner() - num_buckets = (await admin.get_table_info(table_path)).num_buckets - scanner.subscribe_buckets( - {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} - ) - - # Step 1: Collect 4 records via async for - async_records = [] - - async def partial_consume(): - async for record in scanner: - async_records.append(record) - if len(async_records) >= 4: - break - - await asyncio.wait_for(partial_consume(), timeout=10.0) - assert len(async_records) == 4 - - # Step 2: Collect remaining records via sync poll(). - # With small data, _async_poll may have fetched all records in one - # batch. After break, the un-yielded records are lost. The key - # assertion is that poll() works (no deadlock from a held lock). - sync_records = scanner.poll(2000) - assert sync_records is not None, "poll() should return (not deadlock)" - - # Step 3: Verify no duplicates and all IDs are valid - async_ids = {r.row["id"] for r in async_records} - sync_ids = {r.row["id"] for r in sync_records} - assert async_ids.isdisjoint(sync_ids), ( - f"Duplicate IDs: {async_ids & sync_ids}" - ) - all_ids = async_ids | sync_ids - assert all_ids.issubset(set(range(1, 9))), ( - f"Unexpected IDs: {all_ids - set(range(1, 9))}" - ) - - await admin.drop_table(table_path, ignore_if_not_exists=False) - - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- From db23dd6833eef59e21e258bd80575c6d1c419a0a Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Mon, 16 Mar 2026 17:09:42 -0700 Subject: [PATCH 12/14] chore: Remove _async_poll and _async_poll_batches from .pyi as they are private implementation details --- bindings/python/fluss/__init__.pyi | 4 ---- 1 file changed, 4 deletions(-) diff --git a/bindings/python/fluss/__init__.pyi b/bindings/python/fluss/__init__.pyi index 8d46bc63..3b8e532d 100644 --- a/bindings/python/fluss/__init__.pyi +++ b/bindings/python/fluss/__init__.pyi @@ -777,10 +777,6 @@ class LogScanner: """ def __repr__(self) -> str: ... def __aiter__(self) -> AsyncIterator[Union[ScanRecord, RecordBatch]]: ... - async def _async_poll(self, timeout_ms: Optional[int] = ...) -> List[ScanRecord]: ... - async def _async_poll_batches( - self, timeout_ms: Optional[int] = ... - ) -> List[RecordBatch]: ... class Schema: def __init__( From 6ad8cab9171da38e07df03bf0cd7197d387bb93b Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Mon, 16 Mar 2026 19:22:43 -0700 Subject: [PATCH 13/14] refactor: Remove scoping block and subs.clone() as they are not needed with Mutex no longer being used --- bindings/python/src/table.rs | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 896685cb..999827d1 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -2165,16 +2165,14 @@ impl LogScanner { /// Returns: /// PyArrow Table containing all data from subscribed buckets fn to_arrow(&self, py: Python) -> PyResult> { - let subscribed = { - let scanner = self.kind.as_batch()?; - let subs = scanner.get_subscribed_buckets(); - if subs.is_empty() { - return Err(FlussError::new_err( - "No buckets subscribed. Call subscribe(), subscribe_buckets(), subscribe_partition(), or subscribe_partition_buckets() first.", - )); - } - subs.clone() - }; + let scanner = self.kind.as_batch()?; + let subscribed = scanner.get_subscribed_buckets(); + + if subscribed.is_empty() { + return Err(FlussError::new_err( + "No buckets subscribed. Call subscribe(), subscribe_buckets(), subscribe_partition(), or subscribe_partition_buckets() first.", + )); + } // 2. Query latest offsets for all subscribed buckets let stopping_offsets = self.query_latest_offsets(py, &subscribed)?; From 3981fff33714680629fed5503a165d9589c007fe Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Mon, 16 Mar 2026 20:01:02 -0700 Subject: [PATCH 14/14] refactor: Combine the two _async_scan and _async_batch_scan into a single PyOnceLock and generator that takes a callable --- bindings/python/src/table.rs | 69 ++++++++++++++---------------------- 1 file changed, 27 insertions(+), 42 deletions(-) diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 999827d1..9b21101d 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -2201,50 +2201,35 @@ impl LogScanner { fn __aiter__<'py>(slf: PyRef<'py, Self>) -> PyResult> { let py = slf.py(); - match slf.kind.as_ref() { - ScannerKind::Record(_) => { - static RECORD_ASYNC_GEN_FN: PyOnceLock> = PyOnceLock::new(); - let gen_fn = RECORD_ASYNC_GEN_FN.get_or_init(py, || { - let code = pyo3::ffi::c_str!( - r#" -async def _async_scan(scanner, timeout_ms=1000): + // Single lock for the generic async generator + static ASYNC_GEN_FN: PyOnceLock> = PyOnceLock::new(); + + let gen_fn = ASYNC_GEN_FN.get_or_init(py, || { + let code = pyo3::ffi::c_str!( + r#" +async def _async_scan_generic(scanner, method_name, timeout_ms=1000): + # Dynamically resolve the polling method (e.g., _async_poll or _async_poll_batches) + poll_method = getattr(scanner, method_name) while True: - batch = await scanner._async_poll(timeout_ms) - if batch: - for record in batch: - yield record + items = await poll_method(timeout_ms) + if items: + for item in items: + yield item "# - ); - let globals = pyo3::types::PyDict::new(py); - py.run(code, Some(&globals), None).unwrap(); - globals.get_item("_async_scan").unwrap().unwrap().unbind() - }); - gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,)) - } - ScannerKind::Batch(_) => { - static BATCH_ASYNC_GEN_FN: PyOnceLock> = PyOnceLock::new(); - let gen_fn = BATCH_ASYNC_GEN_FN.get_or_init(py, || { - let code = pyo3::ffi::c_str!( - r#" -async def _async_batch_scan(scanner, timeout_ms=1000): - while True: - batches = await scanner._async_poll_batches(timeout_ms) - if batches: - for rb in batches: - yield rb -"# - ); - let globals = pyo3::types::PyDict::new(py); - py.run(code, Some(&globals), None).unwrap(); - globals - .get_item("_async_batch_scan") - .unwrap() - .unwrap() - .unbind() - }); - gen_fn.bind(py).call1((slf.into_bound_py_any(py)?,)) - } - } + ); + let globals = pyo3::types::PyDict::new(py); + py.run(code, Some(&globals), None).unwrap(); + globals.get_item("_async_scan_generic").unwrap().unwrap().unbind() + }); + + // Determine which internal method to call based on the scanner kind + let method_name = match slf.kind.as_ref() { + ScannerKind::Record(_) => "_async_poll", + ScannerKind::Batch(_) => "_async_poll_batches", + }; + + // Instantiate the generator with the scanner instance and the target method name + gen_fn.bind(py).call1((slf.into_bound_py_any(py)?, method_name)) } /// Perform a single bounded poll and return a list of ScanRecord objects.