Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion integrations/python/dataloader/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"

[tool.hatch.metadata]
allow-direct-references = true

[project]
name = "openhouse.dataloader"
dynamic = ["version"]
Expand All @@ -10,7 +13,7 @@ readme = "README.md"
requires-python = ">=3.12"
license = {text = "BSD-2-Clause"}
keywords = ["openhouse", "data-loader", "lakehouse", "iceberg", "datafusion"]
dependencies = ["datafusion==51.0.0", "pyiceberg~=0.11.0", "requests>=2.31.0", "tenacity>=8.0.0"]
dependencies = ["datafusion==51.0.0", "pyiceberg @ git+https://github.com/sumedhsakdeo/iceberg-python@75ba28bfc6d8bbeac398357c6db80327632a2dc8", "requests>=2.31.0", "tenacity>=8.0.0"]

[project.optional-dependencies]
dev = ["responses>=0.25.0", "ruff>=0.9.0", "pytest>=8.0.0", "twine>=6.0.0", "mypy>=1.14.0", "types-requests>=2.31.0"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
filters: Filter | None = None,
context: DataLoaderContext | None = None,
max_attempts: int = 3,
batch_size: int | None = None,
):
"""
Args:
Expand All @@ -90,6 +91,10 @@ def __init__(
filters: Row filter expression, defaults to always_true() (all rows)
context: Data loader context
max_attempts: Total number of attempts including the initial try (default 3)
batch_size: Maximum number of rows per RecordBatch yielded by each split.
Passed to PyArrow's Scanner which produces batches of at most this many
rows. Smaller values reduce peak memory but increase per-batch overhead.
None uses the PyArrow default (~131K rows).
"""
self._catalog = catalog
self._table_id = TableIdentifier(database, table, branch)
Expand All @@ -98,6 +103,7 @@ def __init__(
self._filters = filters if filters is not None else always_true()
self._context = context or DataLoaderContext()
self._max_attempts = max_attempts
self._batch_size = batch_size
Comment thread
cbb330 marked this conversation as resolved.

@cached_property
def _iceberg_table(self) -> Table:
Expand Down Expand Up @@ -163,4 +169,5 @@ def __iter__(self) -> Iterator[DataLoaderSplit]:
yield DataLoaderSplit(
file_scan_task=scan_task,
scan_context=scan_context,
batch_size=self._batch_size,
)
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from datafusion.substrait import Producer
from pyarrow import RecordBatch
from pyiceberg.io.pyarrow import ArrowScan
from pyiceberg.table import FileScanTask
from pyiceberg.table import ArrivalOrder, FileScanTask

from openhouse.dataloader._table_scan_context import TableScanContext
from openhouse.dataloader.udf_registry import NoOpRegistry, UDFRegistry
Expand All @@ -25,10 +25,12 @@ def __init__(
plan: LogicalPlan | None = None,
session_context: SessionContext | None = None,
udf_registry: UDFRegistry | None = None,
batch_size: int | None = None,
):
self._file_scan_task = file_scan_task
self._udf_registry = udf_registry or NoOpRegistry()
self._scan_context = scan_context
self._batch_size = batch_size

if (plan is None) != (session_context is None):
raise ValueError("plan and session_context must both be provided or both be None")
Expand Down Expand Up @@ -59,7 +61,8 @@ def __iter__(self) -> Iterator[RecordBatch]:
"""Reads the file scan task and yields Arrow RecordBatches.

Uses PyIceberg's ArrowScan to handle format dispatch, schema resolution,
delete files, and partition spec lookups.
delete files, and partition spec lookups. The number of batches loaded
into memory at once is bounded to prevent using too much memory at once.
"""
ctx = self._scan_context
arrow_scan = ArrowScan(
Expand All @@ -68,4 +71,7 @@ def __iter__(self) -> Iterator[RecordBatch]:
projected_schema=ctx.projected_schema,
row_filter=ctx.row_filter,
)
yield from arrow_scan.to_record_batches([self._file_scan_task])
yield from arrow_scan.to_record_batches(
[self._file_scan_task],
order=ArrivalOrder(concurrent_streams=1, batch_size=self._batch_size),
)
12 changes: 8 additions & 4 deletions integrations/python/dataloader/tests/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,18 @@ def read_token() -> str:
snap1 = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID).snapshot_id
assert snap1 is not None

# 4. Read all data and verify
loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID)
result = _read_all(loader)
# 4. Read all data with batch_size and verify batch count
loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID, batch_size=2)
batches = [batch for split in loader for batch in split]
assert len(batches) == 2, f"Expected 2 batches (3 rows, batch_size=2), got {len(batches)}"
for batch in batches:
assert batch.num_rows <= 2
result = pa.concat_tables([pa.Table.from_batches([b]) for b in batches]).sort_by(COL_ID)
assert result.num_rows == 3
assert result.column(COL_ID).to_pylist() == [1, 2, 3]
assert result.column(COL_NAME).to_pylist() == ["alice", "bob", "charlie"]
assert result.column(COL_SCORE).to_pylist() == [1.1, 2.2, 3.3]
print(f"PASS: read all {result.num_rows} rows")
print(f"PASS: read all {result.num_rows} rows in {len(batches)} batches (batch_size=2)")

# 5a. Row filter
loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID, filters=col(COL_ID) > 1)
Expand Down
137 changes: 137 additions & 0 deletions integrations/python/dataloader/tests/test_arrival_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""Tests verifying the ArrivalOrder API from pyiceberg PR #3046 is available and functional.

These tests confirm that the openhouse dataloader can access the new ScanOrder class hierarchy
added upstream (apache/iceberg-python#3046) and that ArrowScan.to_record_batches accepts the
order parameter.
"""

import os

import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from pyiceberg.expressions import AlwaysTrue
from pyiceberg.io import load_file_io
from pyiceberg.io.pyarrow import ArrowScan
from pyiceberg.manifest import DataFile, FileFormat
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC
from pyiceberg.schema import Schema
from pyiceberg.table import ArrivalOrder, FileScanTask, ScanOrder, TaskOrder
from pyiceberg.table.metadata import new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER
from pyiceberg.types import LongType, NestedField, StringType

_SCHEMA = Schema(
NestedField(field_id=1, name="id", field_type=LongType(), required=False),
NestedField(field_id=2, name="name", field_type=StringType(), required=False),
)


def _write_parquet(tmp_path: object, table: pa.Table) -> str:
"""Write a parquet file with Iceberg field IDs and return its path."""
file_path = str(tmp_path / "test.parquet") # type: ignore[operator]
fields = [field.with_metadata({b"PARQUET:field_id": str(i + 1).encode()}) for i, field in enumerate(table.schema)]
pq.write_table(table.cast(pa.schema(fields)), file_path)
return file_path


def _make_arrow_scan(tmp_path: object, file_path: str) -> ArrowScan:
metadata = new_table_metadata(
schema=_SCHEMA,
partition_spec=UNPARTITIONED_PARTITION_SPEC,
sort_order=UNSORTED_SORT_ORDER,
location=str(tmp_path),
properties={},
)
return ArrowScan(
table_metadata=metadata,
io=load_file_io(properties={}, location=file_path),
projected_schema=_SCHEMA,
row_filter=AlwaysTrue(),
)


def _make_file_scan_task(file_path: str, table: pa.Table) -> FileScanTask:
data_file = DataFile.from_args(
file_path=file_path,
file_format=FileFormat.PARQUET,
record_count=table.num_rows,
file_size_in_bytes=os.path.getsize(file_path),
)
data_file._spec_id = 0
return FileScanTask(data_file=data_file)


def _sample_table() -> pa.Table:
return pa.table(
{
"id": pa.array([1, 2, 3], type=pa.int64()),
"name": pa.array(["alice", "bob", "charlie"], type=pa.string()),
}
)


class TestScanOrderImports:
"""Verify the ScanOrder class hierarchy is importable from pyiceberg.table."""

def test_scan_order_base_class_exists(self) -> None:
assert ScanOrder is not None

def test_task_order_is_scan_order(self) -> None:
assert issubclass(TaskOrder, ScanOrder)

def test_arrival_order_is_scan_order(self) -> None:
assert issubclass(ArrivalOrder, ScanOrder)

def test_arrival_order_default_params(self) -> None:
ao = ArrivalOrder()
assert ao.concurrent_streams == 8
assert ao.batch_size is None
assert ao.max_buffered_batches == 16

def test_arrival_order_custom_params(self) -> None:
ao = ArrivalOrder(concurrent_streams=4, batch_size=32768, max_buffered_batches=8)
assert ao.concurrent_streams == 4
assert ao.batch_size == 32768
assert ao.max_buffered_batches == 8

def test_arrival_order_rejects_invalid_concurrent_streams(self) -> None:
with pytest.raises(ValueError, match="concurrent_streams"):
ArrivalOrder(concurrent_streams=0)

def test_arrival_order_rejects_invalid_max_buffered_batches(self) -> None:
with pytest.raises(ValueError, match="max_buffered_batches"):
ArrivalOrder(max_buffered_batches=0)


class TestToRecordBatchesOrder:
Comment thread
cbb330 marked this conversation as resolved.
"""Verify ArrowScan.to_record_batches accepts the order parameter and returns correct data."""

def test_default_order_returns_all_rows(self, tmp_path: object) -> None:
"""Default (TaskOrder) still works — backward compatible."""
table = _sample_table()
file_path = _write_parquet(tmp_path, table)
arrow_scan = _make_arrow_scan(tmp_path, file_path)
task = _make_file_scan_task(file_path, table)
batches = list(arrow_scan.to_record_batches([task]))
result = pa.Table.from_batches(batches).sort_by("id")
assert result.column("id").to_pylist() == [1, 2, 3]

def test_explicit_task_order_returns_all_rows(self, tmp_path: object) -> None:
table = _sample_table()
file_path = _write_parquet(tmp_path, table)
arrow_scan = _make_arrow_scan(tmp_path, file_path)
task = _make_file_scan_task(file_path, table)
batches = list(arrow_scan.to_record_batches([task], order=TaskOrder()))
result = pa.Table.from_batches(batches).sort_by("id")
assert result.column("id").to_pylist() == [1, 2, 3]

def test_arrival_order_returns_all_rows(self, tmp_path: object) -> None:
table = _sample_table()
file_path = _write_parquet(tmp_path, table)
arrow_scan = _make_arrow_scan(tmp_path, file_path)
task = _make_file_scan_task(file_path, table)
batches = list(arrow_scan.to_record_batches([task], order=ArrivalOrder(concurrent_streams=2)))
result = pa.Table.from_batches(batches).sort_by("id")
assert result.column("id").to_pylist() == [1, 2, 3]
assert result.column("name").to_pylist() == ["alice", "bob", "charlie"]
27 changes: 27 additions & 0 deletions integrations/python/dataloader/tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,30 @@ def test_snapshot_id_with_columns_and_filters(tmp_path):
assert scan_kwargs["snapshot_id"] == 99
assert scan_kwargs["selected_fields"] == (COL_ID,)
assert "row_filter" in scan_kwargs


# --- batch_size tests ---


def test_batch_size_forwarded_to_splits(tmp_path):
"""batch_size is correctly passed through to each DataLoaderSplit."""
catalog = _make_real_catalog(tmp_path)

loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", batch_size=32768)
splits = list(loader)

assert len(splits) >= 1
for split in splits:
assert split._batch_size == 32768


def test_batch_size_default_is_none(tmp_path):
"""Omitting batch_size defaults to None in each split."""
catalog = _make_real_catalog(tmp_path)

loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl")
splits = list(loader)

assert len(splits) >= 1
for split in splits:
assert split._batch_size is None
46 changes: 46 additions & 0 deletions integrations/python/dataloader/tests/test_data_loader_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _create_test_split(
iceberg_schema: Schema,
io_properties: dict[str, str] | None = None,
filename: str | None = None,
batch_size: int | None = None,
) -> DataLoaderSplit:
"""Create a DataLoaderSplit for testing by writing data to disk.

Expand Down Expand Up @@ -91,6 +92,7 @@ def _create_test_split(
session_context=ctx,
file_scan_task=task,
scan_context=scan_context,
batch_size=batch_size,
)


Expand Down Expand Up @@ -345,3 +347,47 @@ def _to_substrait(plan, ctx):
mock_udf_registry.register_udfs.assert_called_once_with(session_context)
producer.assert_called_once_with(mock_plan, session_context)
assert split._plan_substrait_bytes == b"serialized-plan"


# --- batch_size tests ---

_BATCH_SCHEMA = Schema(
NestedField(field_id=1, name="id", field_type=LongType(), required=False),
)


def _make_table(num_rows: int) -> pa.Table:
return pa.table({"id": pa.array(list(range(num_rows)), type=pa.int64())})


def test_split_batch_size_limits_rows_per_batch(tmp_path):
"""When batch_size is set, each RecordBatch has at most that many rows."""
table = _make_table(100)
split = _create_test_split(tmp_path, table, FileFormat.PARQUET, _BATCH_SCHEMA, batch_size=10)

batches = list(split)

assert len(batches) >= 2, "Expected multiple batches with batch_size=10 and 100 rows"
for batch in batches:
assert batch.num_rows <= 10
assert sum(b.num_rows for b in batches) == 100


def test_split_batch_size_none_returns_all_rows(tmp_path):
"""Default batch_size (None) returns all data correctly."""
table = _make_table(50)
split = _create_test_split(tmp_path, table, FileFormat.PARQUET, _BATCH_SCHEMA)

result = pa.Table.from_batches(list(split))
assert result.num_rows == 50
assert sorted(result.column("id").to_pylist()) == list(range(50))


def test_split_batch_size_preserves_data(tmp_path):
"""batch_size controls chunking but all data is preserved."""
table = _make_table(25)
split = _create_test_split(tmp_path, table, FileFormat.PARQUET, _BATCH_SCHEMA, batch_size=7)

result = pa.Table.from_batches(list(split))
assert result.num_rows == 25
assert sorted(result.column("id").to_pylist()) == list(range(25))
Loading