diff --git a/bindings/python/fluss/__init__.pyi b/bindings/python/fluss/__init__.pyi index 417ac9b2..3b8e532d 100644 --- a/bindings/python/fluss/__init__.pyi +++ b/bindings/python/fluss/__init__.pyi @@ -19,7 +19,17 @@ from enum import IntEnum from types import TracebackType -from typing import Dict, Iterator, List, Optional, Tuple, Union, overload +from typing import ( + Any, + AsyncIterator, + Dict, + Iterator, + List, + Optional, + Tuple, + Union, + overload, +) import pandas as pd import pyarrow as pa @@ -765,8 +775,8 @@ class LogScanner: You must call subscribe(), subscribe_buckets(), or subscribe_partition() first. """ - ... def __repr__(self) -> str: ... + def __aiter__(self) -> AsyncIterator[Union[ScanRecord, RecordBatch]]: ... class Schema: def __init__( diff --git a/bindings/python/src/table.rs b/bindings/python/src/table.rs index 660cd6be..9b21101d 100644 --- a/bindings/python/src/table.rs +++ b/bindings/python/src/table.rs @@ -23,6 +23,7 @@ use arrow_schema::SchemaRef; use fluss::record::to_arrow_schema; use fluss::rpc::message::OffsetSpec; use indexmap::IndexMap; +use pyo3::IntoPyObjectExt; use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyTypeError}; use pyo3::sync::PyOnceLock; use pyo3::types::{ @@ -1887,7 +1888,7 @@ impl ScannerKind { /// Both `LogScanner` and `RecordBatchLogScanner` share the same subscribe interface. macro_rules! with_scanner { ($scanner:expr, $method:ident($($arg:expr),*)) => { - match $scanner { + match $scanner.as_ref() { ScannerKind::Record(s) => s.$method($($arg),*).await, ScannerKind::Batch(s) => s.$method($($arg),*).await, } @@ -1901,7 +1902,7 @@ macro_rules! with_scanner { /// - Batch-based scanning via `poll_arrow()` / `poll_record_batch()` - returns Arrow batches #[pyclass] pub struct LogScanner { - scanner: ScannerKind, + kind: Arc, admin: fcore::client::FlussAdmin, table_info: fcore::metadata::TableInfo, /// The projected Arrow schema to use for empty table creation @@ -1922,7 +1923,7 @@ impl LogScanner { fn subscribe(&self, py: Python, bucket_id: i32, start_offset: i64) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - with_scanner!(&self.scanner, subscribe(bucket_id, start_offset)) + with_scanner!(&self.kind, subscribe(bucket_id, start_offset)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -1935,7 +1936,7 @@ impl LogScanner { fn subscribe_buckets(&self, py: Python, bucket_offsets: HashMap) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - with_scanner!(&self.scanner, subscribe_buckets(&bucket_offsets)) + with_scanner!(&self.kind, subscribe_buckets(&bucket_offsets)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -1957,7 +1958,7 @@ impl LogScanner { py.detach(|| { TOKIO_RUNTIME.block_on(async { with_scanner!( - &self.scanner, + &self.kind, subscribe_partition(partition_id, bucket_id, start_offset) ) .map_err(|e| FlussError::from_core_error(&e)) @@ -1977,7 +1978,7 @@ impl LogScanner { py.detach(|| { TOKIO_RUNTIME.block_on(async { with_scanner!( - &self.scanner, + &self.kind, subscribe_partition_buckets(&partition_bucket_offsets) ) .map_err(|e| FlussError::from_core_error(&e)) @@ -1992,7 +1993,7 @@ impl LogScanner { fn unsubscribe(&self, py: Python, bucket_id: i32) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - with_scanner!(&self.scanner, unsubscribe(bucket_id)) + with_scanner!(&self.kind, unsubscribe(bucket_id)) .map_err(|e| FlussError::from_core_error(&e)) }) }) @@ -2006,11 +2007,8 @@ impl LogScanner { fn unsubscribe_partition(&self, py: Python, partition_id: i64, bucket_id: i32) -> PyResult<()> { py.detach(|| { TOKIO_RUNTIME.block_on(async { - with_scanner!( - &self.scanner, - unsubscribe_partition(partition_id, bucket_id) - ) - .map_err(|e| FlussError::from_core_error(&e)) + with_scanner!(&self.kind, unsubscribe_partition(partition_id, bucket_id)) + .map_err(|e| FlussError::from_core_error(&e)) }) }) } @@ -2030,7 +2028,7 @@ impl LogScanner { /// - Returns an empty ScanRecords if no records are available /// - When timeout expires, returns an empty ScanRecords (NOT an error) fn poll(&self, py: Python, timeout_ms: i64) -> PyResult { - let scanner = self.scanner.as_record()?; + let scanner = self.kind.as_record()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2079,7 +2077,7 @@ impl LogScanner { /// - Returns an empty list if no batches are available /// - When timeout expires, returns an empty list (NOT an error) fn poll_record_batch(&self, py: Python, timeout_ms: i64) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner = self.kind.as_batch()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2114,7 +2112,7 @@ impl LogScanner { /// - Returns an empty table (with correct schema) if no records are available /// - When timeout expires, returns an empty table (NOT an error) fn poll_arrow(&self, py: Python, timeout_ms: i64) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner = self.kind.as_batch()?; if timeout_ms < 0 { return Err(FlussError::new_err(format!( @@ -2167,8 +2165,9 @@ impl LogScanner { /// Returns: /// PyArrow Table containing all data from subscribed buckets fn to_arrow(&self, py: Python) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner = self.kind.as_batch()?; let subscribed = scanner.get_subscribed_buckets(); + if subscribed.is_empty() { return Err(FlussError::new_err( "No buckets subscribed. Call subscribe(), subscribe_buckets(), subscribe_partition(), or subscribe_partition_buckets() first.", @@ -2199,6 +2198,156 @@ impl LogScanner { Ok(df) } + fn __aiter__<'py>(slf: PyRef<'py, Self>) -> PyResult> { + let py = slf.py(); + + // Single lock for the generic async generator + static ASYNC_GEN_FN: PyOnceLock> = PyOnceLock::new(); + + let gen_fn = ASYNC_GEN_FN.get_or_init(py, || { + let code = pyo3::ffi::c_str!( + r#" +async def _async_scan_generic(scanner, method_name, timeout_ms=1000): + # Dynamically resolve the polling method (e.g., _async_poll or _async_poll_batches) + poll_method = getattr(scanner, method_name) + while True: + items = await poll_method(timeout_ms) + if items: + for item in items: + yield item +"# + ); + let globals = pyo3::types::PyDict::new(py); + py.run(code, Some(&globals), None).unwrap(); + globals.get_item("_async_scan_generic").unwrap().unwrap().unbind() + }); + + // Determine which internal method to call based on the scanner kind + let method_name = match slf.kind.as_ref() { + ScannerKind::Record(_) => "_async_poll", + ScannerKind::Batch(_) => "_async_poll_batches", + }; + + // Instantiate the generator with the scanner instance and the target method name + gen_fn.bind(py).call1((slf.into_bound_py_any(py)?, method_name)) + } + + /// Perform a single bounded poll and return a list of ScanRecord objects. + /// + /// This is the async building block used by `__aiter__` to implement + /// `async for`. Each call does exactly one network poll (bounded by + /// `timeout_ms`), converts any results to Python objects, and returns + /// them as a list. An empty list signals a timeout (no data yet), not + /// end-of-stream. + /// + /// Args: + /// timeout_ms: Timeout in milliseconds for the network poll (default: 1000) + /// + /// Returns: + /// Awaitable that resolves to a list of ScanRecord objects + fn _async_poll<'py>( + &self, + py: Python<'py>, + timeout_ms: Option, + ) -> PyResult> { + let timeout_ms = timeout_ms.unwrap_or(1000); + if timeout_ms < 0 { + return Err(FlussError::new_err(format!( + "timeout_ms must be non-negative, got: {timeout_ms}" + ))); + } + + let scanner = Arc::clone(&self.kind); + let projected_row_type = self.projected_row_type.clone(); + let timeout = Duration::from_millis(timeout_ms as u64); + + future_into_py(py, async move { + let core_scanner = match scanner.as_ref() { + ScannerKind::Record(s) => s, + ScannerKind::Batch(_) => { + return Err(PyTypeError::new_err( + "This internal method only supports record-based scanners. \ + For batch-based scanners, use 'async for' or 'poll_record_batch' instead.", + )); + } + }; + + let scan_records = core_scanner + .poll(timeout) + .await + .map_err(|e| FlussError::from_core_error(&e))?; + + // Convert to Python list + Python::attach(|py| { + let mut result: Vec> = Vec::new(); + for (_, records) in scan_records.into_records_by_buckets() { + for core_record in records { + let scan_record = + ScanRecord::from_core(py, &core_record, &projected_row_type)?; + result.push(Py::new(py, scan_record)?); + } + } + Ok(result) + }) + }) + } + + /// Perform a single bounded poll and return a list of RecordBatch objects. + /// + /// This is the async building block used by `__aiter__` (batch mode) to + /// implement `async for`. Each call does exactly one network poll (bounded + /// by `timeout_ms`), converts any results to Python RecordBatch objects, + /// and returns them as a list. An empty list signals a timeout (no data + /// yet), not end-of-stream. + /// + /// Args: + /// timeout_ms: Timeout in milliseconds for the network poll (default: 1000) + /// + /// Returns: + /// Awaitable that resolves to a list of RecordBatch objects + fn _async_poll_batches<'py>( + &self, + py: Python<'py>, + timeout_ms: Option, + ) -> PyResult> { + let timeout_ms = timeout_ms.unwrap_or(1000); + if timeout_ms < 0 { + return Err(FlussError::new_err(format!( + "timeout_ms must be non-negative, got: {timeout_ms}" + ))); + } + + let scanner = Arc::clone(&self.kind); + let timeout = Duration::from_millis(timeout_ms as u64); + + future_into_py(py, async move { + let core_scanner = match scanner.as_ref() { + ScannerKind::Batch(s) => s, + ScannerKind::Record(_) => { + return Err(PyTypeError::new_err( + "This internal method only supports batch-based scanners. \ + For record-based scanners, use 'async for' or 'poll' instead.", + )); + } + }; + + let scan_batches = core_scanner + .poll(timeout) + .await + .map_err(|e| FlussError::from_core_error(&e))?; + + // Convert to Python list of RecordBatch objects + Python::attach(|py| { + let mut result: Vec> = Vec::new(); + for scan_batch in scan_batches { + let rb = RecordBatch::from_scan_batch(scan_batch); + result.push(Py::new(py, rb)?); + } + Ok(result) + }) + }) + } + fn __repr__(&self) -> String { format!("LogScanner(table={})", self.table_info.table_path) } @@ -2213,7 +2362,7 @@ impl LogScanner { projected_row_type: fcore::metadata::RowType, ) -> Self { Self { - scanner, + kind: Arc::new(scanner), admin, table_info, projected_schema, @@ -2264,7 +2413,7 @@ impl LogScanner { py: Python, subscribed: &[(fcore::metadata::TableBucket, i64)], ) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner = self.kind.as_batch()?; let is_partitioned = scanner.is_partitioned(); let table_path = &self.table_info.table_path; @@ -2367,7 +2516,7 @@ impl LogScanner { py: Python, mut stopping_offsets: HashMap, ) -> PyResult> { - let scanner = self.scanner.as_batch()?; + let scanner = self.kind.as_batch()?; let mut all_batches = Vec::new(); while !stopping_offsets.is_empty() { diff --git a/bindings/python/test/test_log_table.py b/bindings/python/test/test_log_table.py index dd1a4d4f..eb118748 100644 --- a/bindings/python/test/test_log_table.py +++ b/bindings/python/test/test_log_table.py @@ -729,6 +729,382 @@ async def test_scan_records_indexing_and_slicing(connection, admin): await admin.drop_table(table_path, ignore_if_not_exists=False) +async def test_async_iterator(connection, admin): + """Test the Python asynchronous iterator loop (`async for`) on LogScanner.""" + table_path = fluss.TablePath("fluss", "py_test_async_iterator") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + + # Write 5 records + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [pa.array(list(range(1, 6)), type=pa.int32()), + pa.array([f"async{i}" for i in range(1, 6)])], + schema=pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets({i: fluss.EARLIEST_OFFSET for i in range(num_buckets)}) + + collected = [] + + # Here is the magical Issue #424 async iterator logic at work: + async def consume_scanner(): + async for record in scanner: + collected.append(record) + if len(collected) == 5: + break + + # We must race the consumption against a timeout so the test doesn't hang if the iterator is broken + await asyncio.wait_for(consume_scanner(), timeout=10.0) + + assert len(collected) == 5, f"Expected 5 records, got {len(collected)}" + + collected.sort(key=lambda r: r.row["id"]) + for i, record in enumerate(collected): + assert record.row["id"] == i + 1 + assert record.row["val"] == f"async{i + 1}" + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_iterator_break_no_leak(connection, admin): + """Verify that breaking out of `async for` does not leak resources. + + After breaking, the scanner must still be usable for synchronous + `poll()` calls. If the old implementation's tokio::spawn'd task + were still alive, it would hold the Mutex and cause `poll()` to + deadlock or error. + """ + table_path = fluss.TablePath("fluss", "py_test_async_break_leak") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 11)), type=pa.int32()), + pa.array([f"v{i}" for i in range(1, 11)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Phase 1: async for with early break (collect only 3 of 10) + collected_async = [] + + async def consume_and_break(): + async for record in scanner: + collected_async.append(record) + if len(collected_async) >= 3: + break + + await asyncio.wait_for(consume_and_break(), timeout=10.0) + assert len(collected_async) == 3, ( + f"Expected 3 records from async for, got {len(collected_async)}" + ) + + # Phase 2: sync poll() must still work — proves no leaked task / lock. + # With small data and few buckets, _async_poll may have fetched all + # records in one batch. After break, the un-yielded records from that + # batch are lost. So sync poll may return 0 records — the key assertion + # is that poll() completes without deadlock (returns within timeout). + remaining = scanner.poll(2000) + assert remaining is not None, "poll() should return (not deadlock)" + + # If we got records, verify no duplicates + async_ids = {r.row["id"] for r in collected_async} + sync_ids = {r.row["id"] for r in remaining} + assert async_ids.isdisjoint(sync_ids), ( + f"Duplicate IDs between async and sync: {async_ids & sync_ids}" + ) + + # All IDs must be from the original 1-10 range + all_ids = async_ids | sync_ids + assert all_ids.issubset(set(range(1, 11))), ( + f"Unexpected IDs: {all_ids - set(range(1, 11))}" + ) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_async_iterator_multiple_batches(connection, admin): + """Verify async iteration works across multiple network poll cycles. + + _async_poll does a single bounded poll per call. Writing 20 records + to multiple buckets ensures the Python generator must loop through + several _async_poll calls to collect them all. + """ + table_path = fluss.TablePath("fluss", "py_test_async_multi_batch") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + table_descriptor = fluss.TableDescriptor( + schema, bucket_count=3, bucket_keys=["id"] + ) + await admin.create_table( + table_path, table_descriptor, ignore_if_exists=False + ) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + + num_records = 20 + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, num_records + 1)), type=pa.int32()), + pa.array([f"multi{i}" for i in range(1, num_records + 1)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + scanner = await table.new_scan().create_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + collected = [] + + async def consume_all(): + async for record in scanner: + collected.append(record) + if len(collected) >= num_records: + break + + await asyncio.wait_for(consume_all(), timeout=15.0) + assert len(collected) == num_records, ( + f"Expected {num_records} records, got {len(collected)}" + ) + + # Verify all IDs are present (order may vary due to bucketing) + ids = sorted(r.row["id"] for r in collected) + assert ids == list(range(1, num_records + 1)) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator(connection, admin): + """Test the Python asynchronous iterator loop (`async for`) on a batch LogScanner. + + With our __aiter__ dispatch, a batch-based scanner should yield RecordBatch + objects (not ScanRecord). Each yielded item has .batch (PyArrow RecordBatch), + .bucket, .base_offset, .last_offset. + """ + table_path = fluss.TablePath("fluss", "py_test_batch_async_iter") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 7)), type=pa.int32()), + pa.array([f"bv{i}" for i in range(1, 7)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + collected_batches = [] + total_rows = 0 + + async def consume_batches(): + nonlocal total_rows + async for rb in batch_scanner: + collected_batches.append(rb) + total_rows += rb.batch.num_rows + if total_rows >= 6: + break + + await asyncio.wait_for(consume_batches(), timeout=15.0) + + assert total_rows >= 6, f"Expected >=6 total rows, got {total_rows}" + assert len(collected_batches) > 0 + + # Verify each yielded item is a RecordBatch with expected attributes + for rb in collected_batches: + assert hasattr(rb, "batch"), "RecordBatch should have .batch" + assert hasattr(rb, "bucket"), "RecordBatch should have .bucket" + assert hasattr(rb, "base_offset"), "RecordBatch should have .base_offset" + assert hasattr(rb, "last_offset"), "RecordBatch should have .last_offset" + # .batch should be a PyArrow RecordBatch + arrow_batch = rb.batch + assert isinstance(arrow_batch, pa.RecordBatch), ( + f"Expected PyArrow RecordBatch, got {type(arrow_batch).__name__}" + ) + assert arrow_batch.num_columns == 2 + assert set(arrow_batch.schema.names) == {"id", "val"} + + # Verify all 6 IDs are present + all_ids = [] + for rb in collected_batches: + all_ids.extend(rb.batch.column("id").to_pylist()) + assert sorted(all_ids[:6]) == [1, 2, 3, 4, 5, 6] + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_break_no_leak(connection, admin): + """Verify that breaking out of batch `async for` does not leak resources. + + After breaking, the scanner must still be usable for synchronous + poll_record_batch() calls, proving no leaked task or lock. + """ + table_path = fluss.TablePath("fluss", "py_test_batch_async_break") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + await admin.create_table(table_path, fluss.TableDescriptor(schema)) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, 11)), type=pa.int32()), + pa.array([f"bl{i}" for i in range(1, 11)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + # Phase 1: async for with early break (collect just 1 batch) + first_batch = None + + async def consume_and_break(): + nonlocal first_batch + async for rb in batch_scanner: + first_batch = rb + break + + await asyncio.wait_for(consume_and_break(), timeout=10.0) + assert first_batch is not None, "Should have received at least 1 batch" + assert first_batch.batch.num_rows > 0 + + # Phase 2: sync poll_record_batch() must still work — proves no leak + remaining = batch_scanner.poll_record_batch(2000) + assert remaining is not None, "poll_record_batch() should return (not deadlock)" + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + +async def test_batch_async_iterator_multiple_batches(connection, admin): + """Verify batch async iteration works across multiple network poll cycles. + + Writing 20 records to 3 buckets ensures the generator must loop through + several _async_poll_batches calls to collect them all. + """ + table_path = fluss.TablePath("fluss", "py_test_batch_async_multi") + await admin.drop_table(table_path, ignore_if_not_exists=True) + + schema = fluss.Schema( + pa.schema([pa.field("id", pa.int32()), pa.field("val", pa.string())]) + ) + table_descriptor = fluss.TableDescriptor( + schema, bucket_count=3, bucket_keys=["id"] + ) + await admin.create_table( + table_path, table_descriptor, ignore_if_exists=False + ) + + table = await connection.get_table(table_path) + writer = table.new_append().create_writer() + + num_records = 20 + writer.write_arrow_batch( + pa.RecordBatch.from_arrays( + [ + pa.array(list(range(1, num_records + 1)), type=pa.int32()), + pa.array([f"bm{i}" for i in range(1, num_records + 1)]), + ], + schema=pa.schema( + [pa.field("id", pa.int32()), pa.field("val", pa.string())] + ), + ) + ) + await writer.flush() + + batch_scanner = await table.new_scan().create_record_batch_log_scanner() + num_buckets = (await admin.get_table_info(table_path)).num_buckets + batch_scanner.subscribe_buckets( + {i: fluss.EARLIEST_OFFSET for i in range(num_buckets)} + ) + + all_ids = [] + + async def consume_all(): + async for rb in batch_scanner: + all_ids.extend(rb.batch.column("id").to_pylist()) + if len(all_ids) >= num_records: + break + + await asyncio.wait_for(consume_all(), timeout=15.0) + assert len(all_ids) >= num_records, ( + f"Expected >={num_records} IDs, got {len(all_ids)}" + ) + assert sorted(all_ids[:num_records]) == list(range(1, num_records + 1)) + + await admin.drop_table(table_path, ignore_if_not_exists=False) + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -755,3 +1131,4 @@ def _poll_arrow_ids(scanner, expected_count, timeout_s=10): if arrow_table.num_rows > 0: all_ids.extend(arrow_table.column("id").to_pylist()) return all_ids +