Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integrations/python/dataloader/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ readme = "README.md"
requires-python = ">=3.12"
license = {text = "BSD-2-Clause"}
keywords = ["openhouse", "data-loader", "lakehouse", "iceberg", "datafusion"]
dependencies = ["datafusion==51.0.0", "li-pyiceberg==0.11.2", "requests>=2.31.0", "sqlglot>=29.0.0", "tenacity>=8.0.0"]
dependencies = ["datafusion==51.0.0", "li-pyiceberg==0.11.3", "requests>=2.31.0", "sqlglot>=29.0.0", "tenacity>=8.0.0"]

[[tool.uv.index]]
url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/simple/"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
filters: Filter | None = None,
context: DataLoaderContext | None = None,
max_attempts: int = 3,
batch_size: int | None = None,
):
"""
Args:
Expand All @@ -126,6 +127,10 @@ def __init__(
filters: Row filter expression, defaults to always_true() (all rows)
context: Data loader context
max_attempts: Total number of attempts including the initial try (default 3)
batch_size: Maximum number of rows per RecordBatch yielded by each split.
Passed to PyArrow's Scanner which produces batches of at most this many
rows. Smaller values reduce peak memory but increase per-batch overhead.
None uses the PyArrow default (~131K rows).
"""
if branch is not None and branch.strip() == "":
raise ValueError("branch must not be empty or whitespace")
Expand All @@ -138,6 +143,7 @@ def __init__(
self._filters = filters if filters is not None else always_true()
self._context = context or DataLoaderContext()
self._max_attempts = max_attempts
self._batch_size = batch_size

if self._context.jvm_config is not None and self._context.jvm_config.planner_args is not None:
apply_libhdfs_opts(self._context.jvm_config.planner_args)
Expand Down Expand Up @@ -260,4 +266,5 @@ def __iter__(self) -> Iterator[DataLoaderSplit]:
scan_context=scan_context,
transform_sql=optimized_sql,
udf_registry=self._context.udf_registry,
batch_size=self._batch_size,
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datafusion.context import SessionContext
from pyarrow import RecordBatch
from pyiceberg.io.pyarrow import ArrowScan
from pyiceberg.table import FileScanTask
from pyiceberg.table import ArrivalOrder, FileScanTask

from openhouse.dataloader._jvm import apply_libhdfs_opts
from openhouse.dataloader._table_scan_context import TableScanContext
Expand Down Expand Up @@ -53,11 +53,13 @@ def __init__(
scan_context: TableScanContext,
transform_sql: str | None = None,
udf_registry: UDFRegistry | None = None,
batch_size: int | None = None,
):
self._file_scan_task = file_scan_task
self._scan_context = scan_context
self._transform_sql = transform_sql
self._udf_registry = udf_registry or NoOpRegistry()
self._batch_size = batch_size

@property
def id(self) -> str:
Expand All @@ -76,7 +78,8 @@ def __iter__(self) -> Iterator[RecordBatch]:
"""Reads the file scan task and yields Arrow RecordBatches.

Uses PyIceberg's ArrowScan to handle format dispatch, schema resolution,
delete files, and partition spec lookups.
delete files, and partition spec lookups. The number of batches loaded
into memory at once is bounded to prevent using too much memory at once.
"""
ctx = self._scan_context
if ctx.worker_jvm_args is not None:
Expand All @@ -88,7 +91,10 @@ def __iter__(self) -> Iterator[RecordBatch]:
row_filter=ctx.row_filter,
)

batches = arrow_scan.to_record_batches([self._file_scan_task])
batches = arrow_scan.to_record_batches(
[self._file_scan_task],
order=ArrivalOrder(concurrent_streams=1, batch_size=self._batch_size),
)

if self._transform_sql is None:
yield from batches
Expand Down
11 changes: 8 additions & 3 deletions integrations/python/dataloader/tests/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,13 @@ def read_token() -> str:
snap1 = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID).snapshot_id
assert snap1 is not None

# 4. Read all data
result = _read_all(OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID))
# 4. Read all data with batch_size and verify batch count
loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID, batch_size=2)
batches = [batch for split in loader for batch in split]
assert len(batches) == 2, f"Expected 2 batches (3 rows, batch_size=2), got {len(batches)}"
for batch in batches:
assert batch.num_rows <= 2
result = pa.concat_tables([pa.Table.from_batches([b]) for b in batches]).sort_by(COL_ID)
finally:
os.dup2(saved_stdout, 1)
os.close(saved_stdout)
Expand All @@ -240,7 +245,7 @@ def read_token() -> str:
assert result.column(COL_ID).to_pylist() == [1, 2, 3]
assert result.column(COL_NAME).to_pylist() == ["alice", "bob", "charlie"]
assert result.column(COL_SCORE).to_pylist() == [1.1, 2.2, 3.3]
print(f"PASS: read all {result.num_rows} rows")
print(f"PASS: read all {result.num_rows} rows in {len(batches)} batches (batch_size=2)")

# 5a. Row filter
loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID, filters=col(COL_ID) > 1)
Expand Down
137 changes: 137 additions & 0 deletions integrations/python/dataloader/tests/test_arrival_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""Tests verifying the ArrivalOrder API from pyiceberg PR #3046 is available and functional.

These tests confirm that the openhouse dataloader can access the new ScanOrder class hierarchy
added upstream (apache/iceberg-python#3046) and that ArrowScan.to_record_batches accepts the
order parameter.
"""

import os

import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from pyiceberg.expressions import AlwaysTrue
from pyiceberg.io import load_file_io
from pyiceberg.io.pyarrow import ArrowScan
from pyiceberg.manifest import DataFile, FileFormat
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC
from pyiceberg.schema import Schema
from pyiceberg.table import ArrivalOrder, FileScanTask, ScanOrder, TaskOrder
from pyiceberg.table.metadata import new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER
from pyiceberg.types import LongType, NestedField, StringType

_SCHEMA = Schema(
NestedField(field_id=1, name="id", field_type=LongType(), required=False),
NestedField(field_id=2, name="name", field_type=StringType(), required=False),
)


def _write_parquet(tmp_path: object, table: pa.Table) -> str:
    """Write *table* to a parquet file tagged with Iceberg field IDs.

    Returns the path of the written file as a string.
    """
    out_path = str(tmp_path / "test.parquet")  # type: ignore[operator]
    tagged = []
    for idx, field in enumerate(table.schema):
        # Iceberg readers resolve columns through the PARQUET:field_id metadata key.
        tagged.append(field.with_metadata({b"PARQUET:field_id": str(idx + 1).encode()}))
    pq.write_table(table.cast(pa.schema(tagged)), out_path)
    return out_path


def _make_arrow_scan(tmp_path: object, file_path: str) -> ArrowScan:
    """Build an ArrowScan over an unpartitioned, unsorted table rooted at *tmp_path*."""
    table_metadata = new_table_metadata(
        schema=_SCHEMA,
        partition_spec=UNPARTITIONED_PARTITION_SPEC,
        sort_order=UNSORTED_SORT_ORDER,
        location=str(tmp_path),
        properties={},
    )
    file_io = load_file_io(properties={}, location=file_path)
    return ArrowScan(
        table_metadata=table_metadata,
        io=file_io,
        projected_schema=_SCHEMA,
        row_filter=AlwaysTrue(),
    )


def _make_file_scan_task(file_path: str, table: pa.Table) -> FileScanTask:
    """Wrap *file_path* in a FileScanTask describing the parquet data it holds."""
    size_in_bytes = os.path.getsize(file_path)
    data_file = DataFile.from_args(
        file_path=file_path,
        file_format=FileFormat.PARQUET,
        record_count=table.num_rows,
        file_size_in_bytes=size_in_bytes,
    )
    # Point the file at partition spec 0 via the private attribute so the scan
    # resolves it against the table's unpartitioned spec.
    data_file._spec_id = 0
    return FileScanTask(data_file=data_file)


def _sample_table() -> pa.Table:
    """Return the three-row (id, name) table used by the scan tests."""
    ids = pa.array([1, 2, 3], type=pa.int64())
    names = pa.array(["alice", "bob", "charlie"], type=pa.string())
    return pa.table({"id": ids, "name": names})


class TestScanOrderImports:
    """Verify the ScanOrder class hierarchy is importable from pyiceberg.table."""

    def test_scan_order_base_class_exists(self) -> None:
        assert ScanOrder is not None

    def test_task_order_is_scan_order(self) -> None:
        assert issubclass(TaskOrder, ScanOrder)

    def test_arrival_order_is_scan_order(self) -> None:
        assert issubclass(ArrivalOrder, ScanOrder)

    def test_arrival_order_default_params(self) -> None:
        # Defaults documented upstream: 8 streams, pyarrow-default batch size, 16 buffered batches.
        order = ArrivalOrder()
        observed = (order.concurrent_streams, order.batch_size, order.max_buffered_batches)
        assert observed == (8, None, 16)

    def test_arrival_order_custom_params(self) -> None:
        order = ArrivalOrder(concurrent_streams=4, batch_size=32768, max_buffered_batches=8)
        observed = (order.concurrent_streams, order.batch_size, order.max_buffered_batches)
        assert observed == (4, 32768, 8)

    def test_arrival_order_rejects_invalid_concurrent_streams(self) -> None:
        with pytest.raises(ValueError, match="concurrent_streams"):
            ArrivalOrder(concurrent_streams=0)

    def test_arrival_order_rejects_invalid_max_buffered_batches(self) -> None:
        with pytest.raises(ValueError, match="max_buffered_batches"):
            ArrivalOrder(max_buffered_batches=0)


class TestToRecordBatchesOrder:
    """Verify ArrowScan.to_record_batches accepts the order parameter and returns correct data."""

    @staticmethod
    def _scan_sorted(tmp_path: object, **scan_kwargs: object) -> pa.Table:
        """Write the sample table, scan it with *scan_kwargs*, and return rows sorted by id."""
        table = _sample_table()
        file_path = _write_parquet(tmp_path, table)
        arrow_scan = _make_arrow_scan(tmp_path, file_path)
        task = _make_file_scan_task(file_path, table)
        batches = list(arrow_scan.to_record_batches([task], **scan_kwargs))
        return pa.Table.from_batches(batches).sort_by("id")

    def test_default_order_returns_all_rows(self, tmp_path: object) -> None:
        """Default (TaskOrder) still works — backward compatible."""
        result = self._scan_sorted(tmp_path)
        assert result.column("id").to_pylist() == [1, 2, 3]

    def test_explicit_task_order_returns_all_rows(self, tmp_path: object) -> None:
        result = self._scan_sorted(tmp_path, order=TaskOrder())
        assert result.column("id").to_pylist() == [1, 2, 3]

    def test_arrival_order_returns_all_rows(self, tmp_path: object) -> None:
        result = self._scan_sorted(tmp_path, order=ArrivalOrder(concurrent_streams=2))
        assert result.column("id").to_pylist() == [1, 2, 3]
        assert result.column("name").to_pylist() == ["alice", "bob", "charlie"]
27 changes: 27 additions & 0 deletions integrations/python/dataloader/tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,33 @@ def fake_scan(**kwargs):
assert branch_splits[0]._file_scan_task.file.file_path == "branch.parquet"


# --- batch_size tests ---


def test_batch_size_forwarded_to_splits(tmp_path):
    """batch_size is correctly passed through to each DataLoaderSplit."""
    catalog = _make_real_catalog(tmp_path)
    loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", batch_size=32768)

    splits = list(loader)

    assert len(splits) >= 1
    assert all(split._batch_size == 32768 for split in splits)


def test_batch_size_default_is_none(tmp_path):
    """Omitting batch_size defaults to None in each split."""
    catalog = _make_real_catalog(tmp_path)
    loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl")

    splits = list(loader)

    assert len(splits) >= 1
    assert all(split._batch_size is None for split in splits)


# --- Predicate pushdown with transformer tests ---


Expand Down
46 changes: 46 additions & 0 deletions integrations/python/dataloader/tests/test_data_loader_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def _create_test_split(
transform_sql: str | None = None,
table_id: TableIdentifier = _DEFAULT_TABLE_ID,
udf_registry: UDFRegistry | None = None,
batch_size: int | None = None,
) -> DataLoaderSplit:
"""Create a DataLoaderSplit for testing by writing data to disk.

Expand Down Expand Up @@ -103,6 +104,7 @@ def _create_test_split(
scan_context=scan_context,
transform_sql=transform_sql,
udf_registry=udf_registry,
batch_size=batch_size,
)


Expand Down Expand Up @@ -422,3 +424,47 @@ def test_worker_jvm_args_sets_libhdfs_opts(tmp_path, monkeypatch):
list(split)

assert os.environ[LIBHDFS_OPTS_ENV] == "-Xmx512m"


# --- batch_size tests ---

_BATCH_SCHEMA = Schema(
NestedField(field_id=1, name="id", field_type=LongType(), required=False),
)


def _make_table(num_rows: int) -> pa.Table:
    """Build a one-column table whose "id" column holds 0..num_rows-1 as int64."""
    ids = pa.array(list(range(num_rows)), type=pa.int64())
    return pa.table({"id": ids})


def test_split_batch_size_limits_rows_per_batch(tmp_path):
    """When batch_size is set, each RecordBatch has at most that many rows."""
    source = _make_table(100)
    split = _create_test_split(tmp_path, source, FileFormat.PARQUET, _BATCH_SCHEMA, batch_size=10)

    batches = list(split)

    assert len(batches) >= 2, "Expected multiple batches with batch_size=10 and 100 rows"
    # No single batch exceeds the cap, and no rows are lost across batches.
    assert max(batch.num_rows for batch in batches) <= 10
    assert sum(batch.num_rows for batch in batches) == 100


def test_split_batch_size_none_returns_all_rows(tmp_path):
    """Default batch_size (None) returns all data correctly."""
    split = _create_test_split(tmp_path, _make_table(50), FileFormat.PARQUET, _BATCH_SCHEMA)

    result = pa.Table.from_batches(list(split))

    assert result.num_rows == 50
    assert sorted(result.column("id").to_pylist()) == list(range(50))


def test_split_batch_size_preserves_data(tmp_path):
    """batch_size controls chunking but all data is preserved."""
    # 7 does not divide 25, so the final batch is a partial one.
    split = _create_test_split(tmp_path, _make_table(25), FileFormat.PARQUET, _BATCH_SCHEMA, batch_size=7)

    result = pa.Table.from_batches(list(split))

    assert result.num_rows == 25
    assert sorted(result.column("id").to_pylist()) == list(range(25))
Loading