Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions bindings/python/fluss/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,17 @@

from enum import IntEnum
from types import TracebackType
from typing import Dict, Iterator, List, Optional, Tuple, Union, overload
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Tuple,
Union,
overload,
)

import pandas as pd
import pyarrow as pa
Expand Down Expand Up @@ -765,8 +775,8 @@ class LogScanner:

You must call subscribe(), subscribe_buckets(), or subscribe_partition() first.
"""
...
def __repr__(self) -> str: ...
def __aiter__(self) -> AsyncIterator[Union[ScanRecord, RecordBatch]]: ...

class Schema:
def __init__(
Expand Down
187 changes: 168 additions & 19 deletions bindings/python/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use arrow_schema::SchemaRef;
use fluss::record::to_arrow_schema;
use fluss::rpc::message::OffsetSpec;
use indexmap::IndexMap;
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyTypeError};
use pyo3::sync::PyOnceLock;
use pyo3::types::{
Expand Down Expand Up @@ -1887,7 +1888,7 @@ impl ScannerKind {
/// Both `LogScanner` and `RecordBatchLogScanner` share the same subscribe interface.
macro_rules! with_scanner {
($scanner:expr, $method:ident($($arg:expr),*)) => {
match $scanner {
match $scanner.as_ref() {
ScannerKind::Record(s) => s.$method($($arg),*).await,
ScannerKind::Batch(s) => s.$method($($arg),*).await,
}
Expand All @@ -1901,7 +1902,7 @@ macro_rules! with_scanner {
/// - Batch-based scanning via `poll_arrow()` / `poll_record_batch()` - returns Arrow batches
#[pyclass]
pub struct LogScanner {
scanner: ScannerKind,
kind: Arc<ScannerKind>,
admin: fcore::client::FlussAdmin,
table_info: fcore::metadata::TableInfo,
/// The projected Arrow schema to use for empty table creation
Expand All @@ -1922,7 +1923,7 @@ impl LogScanner {
fn subscribe(&self, py: Python, bucket_id: i32, start_offset: i64) -> PyResult<()> {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
with_scanner!(&self.scanner, subscribe(bucket_id, start_offset))
with_scanner!(&self.kind, subscribe(bucket_id, start_offset))
.map_err(|e| FlussError::from_core_error(&e))
})
})
Expand All @@ -1935,7 +1936,7 @@ impl LogScanner {
fn subscribe_buckets(&self, py: Python, bucket_offsets: HashMap<i32, i64>) -> PyResult<()> {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
with_scanner!(&self.scanner, subscribe_buckets(&bucket_offsets))
with_scanner!(&self.kind, subscribe_buckets(&bucket_offsets))
.map_err(|e| FlussError::from_core_error(&e))
})
})
Expand All @@ -1957,7 +1958,7 @@ impl LogScanner {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
with_scanner!(
&self.scanner,
&self.kind,
subscribe_partition(partition_id, bucket_id, start_offset)
)
.map_err(|e| FlussError::from_core_error(&e))
Expand All @@ -1977,7 +1978,7 @@ impl LogScanner {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
with_scanner!(
&self.scanner,
&self.kind,
subscribe_partition_buckets(&partition_bucket_offsets)
)
.map_err(|e| FlussError::from_core_error(&e))
Expand All @@ -1992,7 +1993,7 @@ impl LogScanner {
fn unsubscribe(&self, py: Python, bucket_id: i32) -> PyResult<()> {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
with_scanner!(&self.scanner, unsubscribe(bucket_id))
with_scanner!(&self.kind, unsubscribe(bucket_id))
.map_err(|e| FlussError::from_core_error(&e))
})
})
Expand All @@ -2006,11 +2007,8 @@ impl LogScanner {
fn unsubscribe_partition(&self, py: Python, partition_id: i64, bucket_id: i32) -> PyResult<()> {
py.detach(|| {
TOKIO_RUNTIME.block_on(async {
with_scanner!(
&self.scanner,
unsubscribe_partition(partition_id, bucket_id)
)
.map_err(|e| FlussError::from_core_error(&e))
with_scanner!(&self.kind, unsubscribe_partition(partition_id, bucket_id))
.map_err(|e| FlussError::from_core_error(&e))
})
})
}
Expand All @@ -2030,7 +2028,7 @@ impl LogScanner {
/// - Returns an empty ScanRecords if no records are available
/// - When timeout expires, returns an empty ScanRecords (NOT an error)
fn poll(&self, py: Python, timeout_ms: i64) -> PyResult<ScanRecords> {
let scanner = self.scanner.as_record()?;
let scanner = self.kind.as_record()?;

if timeout_ms < 0 {
return Err(FlussError::new_err(format!(
Expand Down Expand Up @@ -2079,7 +2077,7 @@ impl LogScanner {
/// - Returns an empty list if no batches are available
/// - When timeout expires, returns an empty list (NOT an error)
fn poll_record_batch(&self, py: Python, timeout_ms: i64) -> PyResult<Vec<RecordBatch>> {
let scanner = self.scanner.as_batch()?;
let scanner = self.kind.as_batch()?;

if timeout_ms < 0 {
return Err(FlussError::new_err(format!(
Expand Down Expand Up @@ -2114,7 +2112,7 @@ impl LogScanner {
/// - Returns an empty table (with correct schema) if no records are available
/// - When timeout expires, returns an empty table (NOT an error)
fn poll_arrow(&self, py: Python, timeout_ms: i64) -> PyResult<Py<PyAny>> {
let scanner = self.scanner.as_batch()?;
let scanner = self.kind.as_batch()?;

if timeout_ms < 0 {
return Err(FlussError::new_err(format!(
Expand Down Expand Up @@ -2167,8 +2165,9 @@ impl LogScanner {
/// Returns:
/// PyArrow Table containing all data from subscribed buckets
fn to_arrow(&self, py: Python) -> PyResult<Py<PyAny>> {
let scanner = self.scanner.as_batch()?;
let scanner = self.kind.as_batch()?;
let subscribed = scanner.get_subscribed_buckets();

if subscribed.is_empty() {
return Err(FlussError::new_err(
"No buckets subscribed. Call subscribe(), subscribe_buckets(), subscribe_partition(), or subscribe_partition_buckets() first.",
Expand Down Expand Up @@ -2199,6 +2198,156 @@ impl LogScanner {
Ok(df)
}

fn __aiter__<'py>(slf: PyRef<'py, Self>) -> PyResult<Bound<'py, PyAny>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to add this method to .pyi stubs

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @fresh-borzoni, just added __aiter__ to __init__.pyi here 134e56b, along with _async_poll and _async_poll_batches.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to leave _async_poll and _async_poll_batches out of the .pyi, because these methods should ideally remain private implementation details.
Exposing __aiter__ makes sense, to signal to the IDE that we support `async for`; but for the rest of the underscore methods added, we don't want to encourage users to call them directly.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @fresh-borzoni, thanks for the clarification, removed those two entries in the .pyi file in db23dd6.

let py = slf.py();

// Single lock for the generic async generator
static ASYNC_GEN_FN: PyOnceLock<Py<PyAny>> = PyOnceLock::new();

let gen_fn = ASYNC_GEN_FN.get_or_init(py, || {
let code = pyo3::ffi::c_str!(
r#"
async def _async_scan_generic(scanner, method_name, timeout_ms=1000):
# Dynamically resolve the polling method (e.g., _async_poll or _async_poll_batches)
poll_method = getattr(scanner, method_name)
while True:
items = await poll_method(timeout_ms)
if items:
for item in items:
yield item
"#
);
let globals = pyo3::types::PyDict::new(py);
py.run(code, Some(&globals), None).unwrap();
globals.get_item("_async_scan_generic").unwrap().unwrap().unbind()
});

// Determine which internal method to call based on the scanner kind
let method_name = match slf.kind.as_ref() {
ScannerKind::Record(_) => "_async_poll",
ScannerKind::Batch(_) => "_async_poll_batches",
};

// Instantiate the generator with the scanner instance and the target method name
gen_fn.bind(py).call1((slf.into_bound_py_any(py)?, method_name))
}

/// Perform a single bounded poll and return a list of ScanRecord objects.
///
/// This is the async building block used by `__aiter__` to implement
/// `async for`. Each call does exactly one network poll (bounded by
/// `timeout_ms`), converts any results to Python objects, and returns
/// them as a list. An empty list signals a timeout (no data yet), not
/// end-of-stream.
///
/// Args:
///     timeout_ms: Timeout in milliseconds for the network poll (default: 1000)
///
/// Returns:
///     Awaitable that resolves to a list of ScanRecord objects
///
/// Errors:
///     - FlussError if `timeout_ms` is negative or the underlying poll fails
///     - PyTypeError if this scanner is batch-based (records are only
///       produced by `ScannerKind::Record`)
fn _async_poll<'py>(
&self,
py: Python<'py>,
timeout_ms: Option<i64>,
) -> PyResult<Bound<'py, PyAny>> {
// Default to a 1 s poll window when the caller omits a timeout.
let timeout_ms = timeout_ms.unwrap_or(1000);
// Validate before spawning the future so the error surfaces synchronously.
if timeout_ms < 0 {
return Err(FlussError::new_err(format!(
"timeout_ms must be non-negative, got: {timeout_ms}"
)));
}

// Clone the Arc'd scanner and the row type so the spawned future can own
// them independently of `&self`'s lifetime.
let scanner = Arc::clone(&self.kind);
let projected_row_type = self.projected_row_type.clone();
// Cast is safe: negative values were rejected above.
let timeout = Duration::from_millis(timeout_ms as u64);

future_into_py(py, async move {
// This entry point only supports record-based scanning; reject the
// batch variant with a type error that points at the right API.
let core_scanner = match scanner.as_ref() {
ScannerKind::Record(s) => s,
ScannerKind::Batch(_) => {
return Err(PyTypeError::new_err(
"This internal method only supports record-based scanners. \
For batch-based scanners, use 'async for' or 'poll_record_batch' instead.",
));
}
};

// Single bounded network poll; a timeout yields empty records, not an error.
let scan_records = core_scanner
.poll(timeout)
.await
.map_err(|e| FlussError::from_core_error(&e))?;

// Convert to Python list
// Re-acquire the GIL only for the conversion phase; the poll above ran
// without holding it.
Python::attach(|py| {
let mut result: Vec<Py<ScanRecord>> = Vec::new();
// Flatten records across all buckets into one list; per-bucket
// grouping is not exposed through this internal API.
for (_, records) in scan_records.into_records_by_buckets() {
for core_record in records {
let scan_record =
ScanRecord::from_core(py, &core_record, &projected_row_type)?;
result.push(Py::new(py, scan_record)?);
}
}
Ok(result)
})
})
}

/// Perform a single bounded poll and return a list of RecordBatch objects.
///
/// This is the async building block used by `__aiter__` (batch mode) to
/// implement `async for`. Each call does exactly one network poll (bounded
/// by `timeout_ms`), converts any results to Python RecordBatch objects,
/// and returns them as a list. An empty list signals a timeout (no data
/// yet), not end-of-stream.
///
/// Args:
///     timeout_ms: Timeout in milliseconds for the network poll (default: 1000)
///
/// Returns:
///     Awaitable that resolves to a list of RecordBatch objects
///
/// Errors:
///     - FlussError if `timeout_ms` is negative or the underlying poll fails
///     - PyTypeError if this scanner is record-based (batches are only
///       produced by `ScannerKind::Batch`)
fn _async_poll_batches<'py>(
&self,
py: Python<'py>,
timeout_ms: Option<i64>,
) -> PyResult<Bound<'py, PyAny>> {
// Default to a 1 s poll window when the caller omits a timeout.
let timeout_ms = timeout_ms.unwrap_or(1000);
// Validate before spawning the future so the error surfaces synchronously.
if timeout_ms < 0 {
return Err(FlussError::new_err(format!(
"timeout_ms must be non-negative, got: {timeout_ms}"
)));
}

// Clone the Arc'd scanner so the spawned future can own it independently
// of `&self`'s lifetime.
let scanner = Arc::clone(&self.kind);
// Cast is safe: negative values were rejected above.
let timeout = Duration::from_millis(timeout_ms as u64);

future_into_py(py, async move {
// This entry point only supports batch-based scanning; reject the
// record variant with a type error that points at the right API.
let core_scanner = match scanner.as_ref() {
ScannerKind::Batch(s) => s,
ScannerKind::Record(_) => {
return Err(PyTypeError::new_err(
"This internal method only supports batch-based scanners. \
For record-based scanners, use 'async for' or 'poll' instead.",
));
}
};

// Single bounded network poll; a timeout yields no batches, not an error.
let scan_batches = core_scanner
.poll(timeout)
.await
.map_err(|e| FlussError::from_core_error(&e))?;

// Convert to Python list of RecordBatch objects
// Re-acquire the GIL only for the conversion phase; the poll above ran
// without holding it.
Python::attach(|py| {
let mut result: Vec<Py<RecordBatch>> = Vec::new();
for scan_batch in scan_batches {
let rb = RecordBatch::from_scan_batch(scan_batch);
result.push(Py::new(py, rb)?);
}
Ok(result)
})
})
}

/// Python `repr()` implementation: identifies the scanner by the table
/// path it was created for, e.g. `LogScanner(table=db.tbl)`.
fn __repr__(&self) -> String {
format!("LogScanner(table={})", self.table_info.table_path)
}
Expand All @@ -2213,7 +2362,7 @@ impl LogScanner {
projected_row_type: fcore::metadata::RowType,
) -> Self {
Self {
scanner,
kind: Arc::new(scanner),
admin,
table_info,
projected_schema,
Expand Down Expand Up @@ -2264,7 +2413,7 @@ impl LogScanner {
py: Python,
subscribed: &[(fcore::metadata::TableBucket, i64)],
) -> PyResult<HashMap<fcore::metadata::TableBucket, i64>> {
let scanner = self.scanner.as_batch()?;
let scanner = self.kind.as_batch()?;
let is_partitioned = scanner.is_partitioned();
let table_path = &self.table_info.table_path;

Expand Down Expand Up @@ -2367,7 +2516,7 @@ impl LogScanner {
py: Python,
mut stopping_offsets: HashMap<fcore::metadata::TableBucket, i64>,
) -> PyResult<Py<PyAny>> {
let scanner = self.scanner.as_batch()?;
let scanner = self.kind.as_batch()?;
let mut all_batches = Vec::new();

while !stopping_offsets.is_empty() {
Expand Down
Loading