From 518fbe2d8d5f1068a4297f53d8ee15dad7bb05ab Mon Sep 17 00:00:00 2001 From: Rob Reeves Date: Wed, 8 Apr 2026 23:00:29 +0000 Subject: [PATCH 1/8] [DataLoader] Re-add ArrivalOrder API and batch_size support via li-pyiceberg Re-introduce the ArrivalOrder scan order and batch_size parameter that were removed in #504. The original removal was necessary because the fork dependency (sumedhsakdeo/iceberg-python) could not pass ELR. Now that li-pyiceberg 0.11.3 includes the ArrivalOrder API from upstream (apache/iceberg-python#3046), we can restore the functionality using an approved registry dependency. --- integrations/python/dataloader/pyproject.toml | 2 +- .../src/openhouse/dataloader/data_loader.py | 7 + .../openhouse/dataloader/data_loader_split.py | 12 +- .../dataloader/tests/integration_tests.py | 11 +- .../dataloader/tests/test_arrival_order.py | 137 ++++++++++++++++++ .../dataloader/tests/test_data_loader.py | 27 ++++ .../tests/test_data_loader_split.py | 46 ++++++ integrations/python/dataloader/uv.lock | 34 ++--- 8 files changed, 252 insertions(+), 24 deletions(-) create mode 100644 integrations/python/dataloader/tests/test_arrival_order.py diff --git a/integrations/python/dataloader/pyproject.toml b/integrations/python/dataloader/pyproject.toml index 2fa65ec1a..31314c4b5 100644 --- a/integrations/python/dataloader/pyproject.toml +++ b/integrations/python/dataloader/pyproject.toml @@ -10,7 +10,7 @@ readme = "README.md" requires-python = ">=3.12" license = {text = "BSD-2-Clause"} keywords = ["openhouse", "data-loader", "lakehouse", "iceberg", "datafusion"] -dependencies = ["datafusion==51.0.0", "li-pyiceberg==0.11.2", "requests>=2.31.0", "sqlglot>=29.0.0", "tenacity>=8.0.0"] +dependencies = ["datafusion==51.0.0", "li-pyiceberg==0.11.3", "requests>=2.31.0", "sqlglot>=29.0.0", "tenacity>=8.0.0"] [[tool.uv.index]] url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/simple/" diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py index f2118d30b..062bdbd55 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py @@ -114,6 +114,7 @@ def __init__( filters: Filter | None = None, context: DataLoaderContext | None = None, max_attempts: int = 3, + batch_size: int | None = None, ): """ Args: @@ -126,6 +127,10 @@ def __init__( filters: Row filter expression, defaults to always_true() (all rows) context: Data loader context max_attempts: Total number of attempts including the initial try (default 3) + batch_size: Maximum number of rows per RecordBatch yielded by each split. + Passed to PyArrow's Scanner which produces batches of at most this many + rows. Smaller values reduce peak memory but increase per-batch overhead. + None uses the PyArrow default (~131K rows). """ if branch is not None and branch.strip() == "": raise ValueError("branch must not be empty or whitespace") @@ -138,6 +143,7 @@ def __init__( self._filters = filters if filters is not None else always_true() self._context = context or DataLoaderContext() self._max_attempts = max_attempts + self._batch_size = batch_size if self._context.jvm_config is not None and self._context.jvm_config.planner_args is not None: apply_libhdfs_opts(self._context.jvm_config.planner_args) @@ -260,4 +266,5 @@ def __iter__(self) -> Iterator[DataLoaderSplit]: scan_context=scan_context, transform_sql=optimized_sql, udf_registry=self._context.udf_registry, + batch_size=self._batch_size, ) diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py index 38331bbe4..59649536e 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py @@ -7,7 +7,7 @@ from datafusion.context import SessionContext from pyarrow import RecordBatch from pyiceberg.io.pyarrow import ArrowScan -from pyiceberg.table import FileScanTask +from pyiceberg.table import ArrivalOrder, FileScanTask from openhouse.dataloader._jvm import apply_libhdfs_opts from openhouse.dataloader._table_scan_context import TableScanContext @@ -53,11 +53,13 @@ def __init__( scan_context: TableScanContext, transform_sql: str | None = None, udf_registry: UDFRegistry | None = None, + batch_size: int | None = None, ): self._file_scan_task = file_scan_task self._scan_context = scan_context self._transform_sql = transform_sql self._udf_registry = udf_registry or NoOpRegistry() + self._batch_size = batch_size @property def id(self) -> str: @@ -76,7 +78,8 @@ def __iter__(self) -> Iterator[RecordBatch]: """Reads the file scan task and yields Arrow RecordBatches. Uses PyIceberg's ArrowScan to handle format dispatch, schema resolution, - delete files, and partition spec lookups. + delete files, and partition spec lookups. The number of batches loaded + into memory at once is bounded to prevent using too much memory at once. """ ctx = self._scan_context if ctx.worker_jvm_args is not None: @@ -88,7 +91,10 @@ def __iter__(self) -> Iterator[RecordBatch]: row_filter=ctx.row_filter, ) - batches = arrow_scan.to_record_batches([self._file_scan_task]) + batches = arrow_scan.to_record_batches( + [self._file_scan_task], + order=ArrivalOrder(concurrent_streams=1, batch_size=self._batch_size), + ) if self._transform_sql is None: yield from batches diff --git a/integrations/python/dataloader/tests/integration_tests.py b/integrations/python/dataloader/tests/integration_tests.py index 3ccfe4f48..538e4dd62 100644 --- a/integrations/python/dataloader/tests/integration_tests.py +++ b/integrations/python/dataloader/tests/integration_tests.py @@ -228,8 +228,13 @@ def read_token() -> str: snap1 = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID).snapshot_id assert snap1 is not None - # 4. Read all data - result = _read_all(OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID)) + # 4. Read all data with batch_size and verify batch count + loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID, batch_size=2) + batches = [batch for split in loader for batch in split] + assert len(batches) == 2, f"Expected 2 batches (3 rows, batch_size=2), got {len(batches)}" + for batch in batches: + assert batch.num_rows <= 2 + result = pa.concat_tables([pa.Table.from_batches([b]) for b in batches]).sort_by(COL_ID) finally: os.dup2(saved_stdout, 1) os.close(saved_stdout) @@ -240,7 +245,7 @@ def read_token() -> str: assert result.column(COL_ID).to_pylist() == [1, 2, 3] assert result.column(COL_NAME).to_pylist() == ["alice", "bob", "charlie"] assert result.column(COL_SCORE).to_pylist() == [1.1, 2.2, 3.3] - print(f"PASS: read all {result.num_rows} rows") + print(f"PASS: read all {result.num_rows} rows in {len(batches)} batches (batch_size=2)") # 5a. Row filter loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID, filters=col(COL_ID) > 1) diff --git a/integrations/python/dataloader/tests/test_arrival_order.py b/integrations/python/dataloader/tests/test_arrival_order.py new file mode 100644 index 000000000..21e3870f2 --- /dev/null +++ b/integrations/python/dataloader/tests/test_arrival_order.py @@ -0,0 +1,137 @@ +"""Tests verifying the ArrivalOrder API from pyiceberg PR #3046 is available and functional. + +These tests confirm that the openhouse dataloader can access the new ScanOrder class hierarchy +added upstream (apache/iceberg-python#3046) and that ArrowScan.to_record_batches accepts the +order parameter. +""" + +import os + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest +from pyiceberg.expressions import AlwaysTrue +from pyiceberg.io import load_file_io +from pyiceberg.io.pyarrow import ArrowScan +from pyiceberg.manifest import DataFile, FileFormat +from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC +from pyiceberg.schema import Schema +from pyiceberg.table import ArrivalOrder, FileScanTask, ScanOrder, TaskOrder +from pyiceberg.table.metadata import new_table_metadata +from pyiceberg.table.sorting import UNSORTED_SORT_ORDER +from pyiceberg.types import LongType, NestedField, StringType + +_SCHEMA = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="name", field_type=StringType(), required=False), +) + + +def _write_parquet(tmp_path: object, table: pa.Table) -> str: + """Write a parquet file with Iceberg field IDs and return its path.""" + file_path = str(tmp_path / "test.parquet") # type: ignore[operator] + fields = [field.with_metadata({b"PARQUET:field_id": str(i + 1).encode()}) for i, field in enumerate(table.schema)] + pq.write_table(table.cast(pa.schema(fields)), file_path) + return file_path + + +def _make_arrow_scan(tmp_path: object, file_path: str) -> ArrowScan: + metadata = new_table_metadata( + schema=_SCHEMA, + partition_spec=UNPARTITIONED_PARTITION_SPEC, + sort_order=UNSORTED_SORT_ORDER, + location=str(tmp_path), + properties={}, + ) + return ArrowScan( + table_metadata=metadata, + io=load_file_io(properties={}, location=file_path), + projected_schema=_SCHEMA, + row_filter=AlwaysTrue(), + ) + + +def _make_file_scan_task(file_path: str, table: pa.Table) -> FileScanTask: + data_file = DataFile.from_args( + file_path=file_path, + file_format=FileFormat.PARQUET, + record_count=table.num_rows, + file_size_in_bytes=os.path.getsize(file_path), + ) + data_file._spec_id = 0 + return FileScanTask(data_file=data_file) + + +def _sample_table() -> pa.Table: + return pa.table( + { + "id": pa.array([1, 2, 3], type=pa.int64()), + "name": pa.array(["alice", "bob", "charlie"], type=pa.string()), + } + ) + + +class TestScanOrderImports: + """Verify the ScanOrder class hierarchy is importable from pyiceberg.table.""" + + def test_scan_order_base_class_exists(self) -> None: + assert ScanOrder is not None + + def test_task_order_is_scan_order(self) -> None: + assert issubclass(TaskOrder, ScanOrder) + + def test_arrival_order_is_scan_order(self) -> None: + assert issubclass(ArrivalOrder, ScanOrder) + + def test_arrival_order_default_params(self) -> None: + ao = ArrivalOrder() + assert ao.concurrent_streams == 8 + assert ao.batch_size is None + assert ao.max_buffered_batches == 16 + + def test_arrival_order_custom_params(self) -> None: + ao = ArrivalOrder(concurrent_streams=4, batch_size=32768, max_buffered_batches=8) + assert ao.concurrent_streams == 4 + assert ao.batch_size == 32768 + assert ao.max_buffered_batches == 8 + + def test_arrival_order_rejects_invalid_concurrent_streams(self) -> None: + with pytest.raises(ValueError, match="concurrent_streams"): + ArrivalOrder(concurrent_streams=0) + + def test_arrival_order_rejects_invalid_max_buffered_batches(self) -> None: + with pytest.raises(ValueError, match="max_buffered_batches"): + ArrivalOrder(max_buffered_batches=0) + + +class TestToRecordBatchesOrder: + """Verify ArrowScan.to_record_batches accepts the order parameter and returns correct data.""" + + def test_default_order_returns_all_rows(self, tmp_path: object) -> None: + """Default (TaskOrder) still works — backward compatible.""" + table = _sample_table() + file_path = _write_parquet(tmp_path, table) + arrow_scan = _make_arrow_scan(tmp_path, file_path) + task = _make_file_scan_task(file_path, table) + batches = list(arrow_scan.to_record_batches([task])) + result = pa.Table.from_batches(batches).sort_by("id") + assert result.column("id").to_pylist() == [1, 2, 3] + + def test_explicit_task_order_returns_all_rows(self, tmp_path: object) -> None: + table = _sample_table() + file_path = _write_parquet(tmp_path, table) + arrow_scan = _make_arrow_scan(tmp_path, file_path) + task = _make_file_scan_task(file_path, table) + batches = list(arrow_scan.to_record_batches([task], order=TaskOrder())) + result = pa.Table.from_batches(batches).sort_by("id") + assert result.column("id").to_pylist() == [1, 2, 3] + + def test_arrival_order_returns_all_rows(self, tmp_path: object) -> None: + table = _sample_table() + file_path = _write_parquet(tmp_path, table) + arrow_scan = _make_arrow_scan(tmp_path, file_path) + task = _make_file_scan_task(file_path, table) + batches = list(arrow_scan.to_record_batches([task], order=ArrivalOrder(concurrent_streams=2))) + result = pa.Table.from_batches(batches).sort_by("id") + assert result.column("id").to_pylist() == [1, 2, 3] + assert result.column("name").to_pylist() == ["alice", "bob", "charlie"] diff --git a/integrations/python/dataloader/tests/test_data_loader.py b/integrations/python/dataloader/tests/test_data_loader.py index 4f07ecced..8816ab1c0 100644 --- a/integrations/python/dataloader/tests/test_data_loader.py +++ b/integrations/python/dataloader/tests/test_data_loader.py @@ -567,6 +567,33 @@ def fake_scan(**kwargs): assert branch_splits[0]._file_scan_task.file.file_path == "branch.parquet" +# --- batch_size tests --- + + +def test_batch_size_forwarded_to_splits(tmp_path): + """batch_size is correctly passed through to each DataLoaderSplit.""" + catalog = _make_real_catalog(tmp_path) + + loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", batch_size=32768) + splits = list(loader) + + assert len(splits) >= 1 + for split in splits: + assert split._batch_size == 32768 + + +def test_batch_size_default_is_none(tmp_path): + """Omitting batch_size defaults to None in each split.""" + catalog = _make_real_catalog(tmp_path) + + loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl") + splits = list(loader) + + assert len(splits) >= 1 + for split in splits: + assert split._batch_size is None + + # --- Predicate pushdown with transformer tests --- diff --git a/integrations/python/dataloader/tests/test_data_loader_split.py b/integrations/python/dataloader/tests/test_data_loader_split.py index 306eb3f84..afdea471d 100644 --- a/integrations/python/dataloader/tests/test_data_loader_split.py +++ b/integrations/python/dataloader/tests/test_data_loader_split.py @@ -43,6 +43,7 @@ def _create_test_split( transform_sql: str | None = None, table_id: TableIdentifier = _DEFAULT_TABLE_ID, udf_registry: UDFRegistry | None = None, + batch_size: int | None = None, ) -> DataLoaderSplit: """Create a DataLoaderSplit for testing by writing data to disk. @@ -103,6 +104,7 @@ def _create_test_split( scan_context=scan_context, transform_sql=transform_sql, udf_registry=udf_registry, + batch_size=batch_size, ) @@ -422,3 +424,47 @@ def test_worker_jvm_args_sets_libhdfs_opts(tmp_path, monkeypatch): list(split) assert os.environ[LIBHDFS_OPTS_ENV] == "-Xmx512m" + + +# --- batch_size tests --- + +_BATCH_SCHEMA = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), +) + + +def _make_table(num_rows: int) -> pa.Table: + return pa.table({"id": pa.array(list(range(num_rows)), type=pa.int64())}) + + +def test_split_batch_size_limits_rows_per_batch(tmp_path): + """When batch_size is set, each RecordBatch has at most that many rows.""" + table = _make_table(100) + split = _create_test_split(tmp_path, table, FileFormat.PARQUET, _BATCH_SCHEMA, batch_size=10) + + batches = list(split) + + assert len(batches) >= 2, "Expected multiple batches with batch_size=10 and 100 rows" + for batch in batches: + assert batch.num_rows <= 10 + assert sum(b.num_rows for b in batches) == 100 + + +def test_split_batch_size_none_returns_all_rows(tmp_path): + """Default batch_size (None) returns all data correctly.""" + table = _make_table(50) + split = _create_test_split(tmp_path, table, FileFormat.PARQUET, _BATCH_SCHEMA) + + result = pa.Table.from_batches(list(split)) + assert result.num_rows == 50 + assert sorted(result.column("id").to_pylist()) == list(range(50)) + + +def test_split_batch_size_preserves_data(tmp_path): + """batch_size controls chunking but all data is preserved.""" + table = _make_table(25) + split = _create_test_split(tmp_path, table, FileFormat.PARQUET, _BATCH_SCHEMA, batch_size=7) + + result = pa.Table.from_batches(list(split)) + assert result.num_rows == 25 + assert sorted(result.column("id").to_pylist()) == list(range(25)) diff --git a/integrations/python/dataloader/uv.lock b/integrations/python/dataloader/uv.lock index cb01c81a9..1f2eca409 100644 --- a/integrations/python/dataloader/uv.lock +++ b/integrations/python/dataloader/uv.lock @@ -304,7 +304,7 @@ wheels = [ [[package]] name = "li-pyiceberg" -version = "0.11.2" +version = "0.11.3" source = { registry = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/simple/" } dependencies = [ { name = "cachetools" }, @@ -320,22 +320,22 @@ dependencies = [ { name = "tenacity" }, { name = "zstandard" }, ] -sdist = { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.2/li_pyiceberg-0.11.2.tar.gz", hash = "sha256:6d73600d862c097143edaebd0480c491a6b682f0ff5e82412318b3c49727ddf4" } +sdist = { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.3/li_pyiceberg-0.11.3.tar.gz", hash = "sha256:13558096c793ecd64eaeee5440440d406184b2693761a0c4b189a24453a26f4b" } wheels = [ - { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.2/li_pyiceberg-0.11.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6b9cd3983816a7f6ea5db59e76369dded31594ee108f4dec097800aa3175612d" }, - { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.2/li_pyiceberg-0.11.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2e1e5544c1218af47c8ac60bc64a374c07d839678fd9ec9303b13ca0a79540e6" }, - { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.2/li_pyiceberg-0.11.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f43f331134943d3a73b2a6e7fdd3582ac437e4757a1a47f94ea104b8b5321e9b" }, - { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.2/li_pyiceberg-0.11.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d672d5d5eeb321e287cf40b8b787448a24b6ce9a8f7ec34604cb3a40dcf3c9e7" }, - { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.2/li_pyiceberg-0.11.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b5c2182c7b99e7e53b1e4604f75fa6b4783c8dd48cd5936fdadfc4de84c654ce" }, - { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.2/li_pyiceberg-0.11.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a9cfa4c517923e7c4b1886642813ef3ec9fe5fe771869e1fcafee7920d38561f" }, - { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.2/li_pyiceberg-0.11.2-cp312-cp312-win_amd64.whl", hash = "sha256:4e1fc561cb4953694de5600de7bf6fb084e3e7aad58775ee7f057a737f58d0a0" }, - { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.2/li_pyiceberg-0.11.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:cbd2e2aa77c34700685bc532440f6112383a887766af77d89a7b3fa521755934" }, - { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.2/li_pyiceberg-0.11.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3c637b7d9f1376229b3a06a8b3b3a6f826043be721212107114b3e54f14c1629" }, - { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.2/li_pyiceberg-0.11.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fdcfc4fe6587173b5501c0d02a2d912ad9522282677892015b54fb80986dd72a" }, - { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.2/li_pyiceberg-0.11.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:546025112098048774eab4ae056c248c7d00daed632258c8b4565a405d098c03" }, - { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.2/li_pyiceberg-0.11.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e695d2d1532f1525b8963e0703a516fff9a19e392c13157daedb1e875faf633e" }, - { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.2/li_pyiceberg-0.11.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5d20cb2636e0ef3af1a6407aa728944f0e54b6c6dbe916ddde667999957b37af" }, - { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.2/li_pyiceberg-0.11.2-cp313-cp313-win_amd64.whl", hash = "sha256:7cc8c418b0ddfafe4b8b4c4100f520f9a63794aecd8be74e3cbb651e4cb55af1" }, + { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.3/li_pyiceberg-0.11.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:448af31363788d888fa436fb3d9913cb46c7e01409d6e41c6f878b8565da57ec" }, + { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.3/li_pyiceberg-0.11.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:831f85d6d1abe221df33d1e2a1d8a5bb17821f03ca6d8ddbd2932527498fb1c7" }, + { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.3/li_pyiceberg-0.11.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:567f4976eee9425b04324a09419d0daa28535d4ae5408f58575fb80a1564fc85" }, + { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.3/li_pyiceberg-0.11.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b09979c347d81fba3fa7de83918979fbd406ffa8ec90ec825394415aa4063ac1" }, + { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.3/li_pyiceberg-0.11.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:33a649304ab689bddbb72d69ee9f2e063c75c5bb5e311941c72b88117296f27a" }, + { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.3/li_pyiceberg-0.11.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ced2768668d1b178b4120e460bb5cbe524833933fe82a989253cf89ddd1abebe" }, + { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.3/li_pyiceberg-0.11.3-cp312-cp312-win_amd64.whl", hash = "sha256:9baee52cc5a0ebd129efe98046edacca44b5f4bf286c04874756177907332e35" }, + { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.3/li_pyiceberg-0.11.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e503e918487f268a36fae60d23901dda89c3295c5dfa1ed3e5ee9fed0dc883bb" }, + { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.3/li_pyiceberg-0.11.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:de30e53dea782549445063d66d9b308555bdcbc65a24aa01d65dde667663b2de" }, + { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.3/li_pyiceberg-0.11.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0d5f8e9fbf8c17508097520fed8e1572b92c48bd4c510a226355f173df9d31b" }, + { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.3/li_pyiceberg-0.11.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b7a5aaf3b95e7cdba86a1f494f787abaaaa69d56ba9abd130edafe4442de6679" }, + { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.3/li_pyiceberg-0.11.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bf722ad6a68139f3a2a878de661419377d22e5c2d50edeae94c2d2ece3aba2f7" }, + { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.3/li_pyiceberg-0.11.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4faef4e74d5fdda80fad796deb2f79f8e6beef798d0311ee0928c955d14832cd" }, + { url = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/li-pyiceberg/0.11.3/li_pyiceberg-0.11.3-cp313-cp313-win_amd64.whl", hash = "sha256:98719cd3f6802cf1d2b19da929bb2befc254c28c24597df1d6dada806d1ff6b4" }, ] [[package]] @@ -607,7 +607,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "datafusion", specifier = "==51.0.0" }, - { name = "li-pyiceberg", specifier = "==0.11.2", index = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/simple/" }, + { name = "li-pyiceberg", specifier = "==0.11.3", index = "https://linkedin.jfrog.io/artifactory/api/pypi/openhouse-pypi/simple/" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.14.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, { name = "requests", specifier = ">=2.31.0" }, From ec8ec80dc35486e1df481aa8c612a3daa60fa4e3 Mon Sep 17 00:00:00 2001 From: Rob Reeves Date: Thu, 9 Apr 2026 23:25:38 +0000 Subject: [PATCH 2/8] [DataLoader] Support concurrent multi-file reads per split via files_per_split Add files_per_split parameter to OpenHouseDataLoader that controls how many files each DataLoaderSplit reads concurrently. DataLoaderSplit now accepts a list of FileScanTasks and sets concurrent_streams to match, enabling parallel I/O within a single split. --- .../src/openhouse/dataloader/data_loader.py | 12 +- .../openhouse/dataloader/data_loader_split.py | 25 ++-- .../dataloader/tests/test_data_loader.py | 116 +++++++++++++++++- .../tests/test_data_loader_split.py | 102 ++++++++++++++- 4 files changed, 238 insertions(+), 17 deletions(-) diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py index 062bdbd55..703ed8751 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py @@ -2,6 +2,7 @@ from collections.abc import Callable, Iterator, Mapping, Sequence from dataclasses import dataclass from functools import cached_property +from itertools import islice from types import MappingProxyType from pyiceberg.catalog import Catalog @@ -115,6 +116,7 @@ def __init__( context: DataLoaderContext | None = None, max_attempts: int = 3, batch_size: int | None = None, + files_per_split: int = 1, ): """ Args: @@ -131,11 +133,15 @@ def __init__( Passed to PyArrow's Scanner which produces batches of at most this many rows. Smaller values reduce peak memory but increase per-batch overhead. None uses the PyArrow default (~131K rows). + files_per_split: Number of files each split reads concurrently. + Default is 1 (one file per split). """ if branch is not None and branch.strip() == "": raise ValueError("branch must not be empty or whitespace") if branch is not None and snapshot_id is not None: raise ValueError("Cannot specify both branch and snapshot_id") + if files_per_split < 1: + raise ValueError("files_per_split must be at least 1") self._catalog = catalog self._table_id = TableIdentifier(database, table, branch) self._snapshot_id = snapshot_id @@ -144,6 +150,7 @@ def __init__( self._context = context or DataLoaderContext() self._max_attempts = max_attempts self._batch_size = batch_size + self._files_per_split = files_per_split if self._context.jvm_config is not None and self._context.jvm_config.planner_args is not None: apply_libhdfs_opts(self._context.jvm_config.planner_args) @@ -260,9 +267,10 @@ def __iter__(self) -> Iterator[DataLoaderSplit]: lambda: scan.plan_files(), label=f"plan_files {self._table_id}", max_attempts=self._max_attempts ) - for scan_task in scan_tasks: + task_iter = iter(scan_tasks) + for chunk in iter(lambda: list(islice(task_iter, self._files_per_split)), []): yield DataLoaderSplit( - file_scan_task=scan_task, + file_scan_tasks=chunk, scan_context=scan_context, transform_sql=optimized_sql, udf_registry=self._context.udf_registry, diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py index 59649536e..b07de881a 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py @@ -1,7 +1,7 @@ from __future__ import annotations import hashlib -from collections.abc import Iterator, Mapping +from collections.abc import Iterator, Mapping, Sequence from types import MappingProxyType from datafusion.context import SessionContext @@ -45,17 +45,19 @@ def _bind_batch_table(session: SessionContext, table_id: TableIdentifier, batch: class DataLoaderSplit: - """A single data split""" + """A data split that reads one or more files.""" def __init__( self, - file_scan_task: FileScanTask, + file_scan_tasks: Sequence[FileScanTask], scan_context: TableScanContext, transform_sql: str | None = None, udf_registry: UDFRegistry | None = None, batch_size: int | None = None, ): - self._file_scan_task = file_scan_task + self._file_scan_tasks = list(file_scan_tasks) + if not self._file_scan_tasks: + raise ValueError("file_scan_tasks must not be empty") self._scan_context = scan_context self._transform_sql = transform_sql self._udf_registry = udf_registry or NoOpRegistry() @@ -66,8 +68,9 @@ def id(self) -> str: """Unique ID for the split. This is stable across executions for a given snapshot and split size. """ - file_path = self._file_scan_task.file.file_path - return hashlib.sha256(file_path.encode("utf-8")).hexdigest() + paths = sorted(t.file.file_path for t in self._file_scan_tasks) + combined = "\0".join(paths) + return hashlib.sha256(combined.encode("utf-8")).hexdigest() @property def table_properties(self) -> Mapping[str, str]: @@ -75,11 +78,9 @@ def table_properties(self) -> Mapping[str, str]: return MappingProxyType(self._scan_context.table_metadata.properties) def __iter__(self) -> Iterator[RecordBatch]: - """Reads the file scan task and yields Arrow RecordBatches. + """Reads the file scan tasks and yields Arrow RecordBatches. - Uses PyIceberg's ArrowScan to handle format dispatch, schema resolution, - delete files, and partition spec lookups. The number of batches loaded - into memory at once is bounded to prevent using too much memory at once. + When the split contains multiple files, they are read concurrently. """ ctx = self._scan_context if ctx.worker_jvm_args is not None: @@ -92,8 +93,8 @@ def __iter__(self) -> Iterator[RecordBatch]: ) batches = arrow_scan.to_record_batches( - [self._file_scan_task], - order=ArrivalOrder(concurrent_streams=1, batch_size=self._batch_size), + self._file_scan_tasks, + order=ArrivalOrder(concurrent_streams=len(self._file_scan_tasks), batch_size=self._batch_size), ) if self._transform_sql is None: diff --git a/integrations/python/dataloader/tests/test_data_loader.py b/integrations/python/dataloader/tests/test_data_loader.py index 8816ab1c0..9f2905e04 100644 --- a/integrations/python/dataloader/tests/test_data_loader.py +++ b/integrations/python/dataloader/tests/test_data_loader.py @@ -559,12 +559,12 @@ def fake_scan(**kwargs): # Without branch: splits come from main snapshot main_splits = list(OpenHouseDataLoader(catalog=catalog, database="db", table="tbl")) assert len(main_splits) == 1 - assert main_splits[0]._file_scan_task.file.file_path == "main.parquet" + assert main_splits[0]._file_scan_tasks[0].file.file_path == "main.parquet" # With branch: splits come from branch snapshot branch_splits = list(OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", branch="my-branch")) assert len(branch_splits) == 1 - assert branch_splits[0]._file_scan_task.file.file_path == "branch.parquet" + assert branch_splits[0]._file_scan_tasks[0].file.file_path == "branch.parquet" # --- batch_size tests --- @@ -594,6 +594,118 @@ def test_batch_size_default_is_none(tmp_path): assert split._batch_size is None +# --- files_per_split tests --- + + +def _make_multi_file_catalog(tmp_path, num_files: int, rows_per_file: int = 3): + """Create a mock catalog backed by multiple real Parquet files.""" + schema = Schema( + NestedField(field_id=1, name=COL_ID, field_type=LongType(), required=False), + NestedField(field_id=2, name=COL_NAME, field_type=StringType(), required=False), + ) + tasks = [] + for i in range(num_files): + data = { + COL_ID: list(range(i * rows_per_file, (i + 1) * rows_per_file)), + COL_NAME: [f"row_{j}" for j in range(i * rows_per_file, (i + 1) * rows_per_file)], + } + file_path = _write_parquet(tmp_path, data, filename=f"file_{i}.parquet") + data_file = DataFile.from_args( + file_path=file_path, + file_format=FileFormat.PARQUET, + record_count=rows_per_file, + file_size_in_bytes=os.path.getsize(file_path), + ) + data_file._spec_id = 0 + tasks.append(FileScanTask(data_file=data_file)) + + metadata = new_table_metadata( + schema=schema, + partition_spec=UNPARTITIONED_PARTITION_SPEC, + sort_order=UNSORTED_SORT_ORDER, + location=str(tmp_path), + ) + io = load_file_io(properties={}, location=str(tmp_path)) + + def fake_scan(**kwargs): + selected = kwargs.get("selected_fields") + projected = Schema(*[f for f in schema.fields if f.name in selected]) if selected else schema + scan = MagicMock() + scan.projection.return_value = projected + scan.plan_files.return_value = tasks + return scan + + mock_table = MagicMock() + mock_table.metadata = metadata + mock_table.io = io + mock_table.scan.side_effect = fake_scan + + catalog = MagicMock() + catalog.load_table.return_value = mock_table + return catalog + + +def test_files_per_split_default_one_file_per_split(tmp_path): + """Default files_per_split=1 produces one split per file.""" + catalog = _make_multi_file_catalog(tmp_path, num_files=4) + loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl") + splits = list(loader) + + assert len(splits) == 4 + for split in splits: + assert len(split._file_scan_tasks) == 1 + + +def test_files_per_split_groups_tasks(tmp_path): + """files_per_split=2 groups 4 files into 2 splits of 2 files each.""" + catalog = _make_multi_file_catalog(tmp_path, num_files=4) + loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", files_per_split=2) + splits = list(loader) + + assert len(splits) == 2 + for split in splits: + assert len(split._file_scan_tasks) == 2 + + +def test_files_per_split_remainder_split(tmp_path): + """When files don't divide evenly, the last split gets the remainder.""" + catalog = _make_multi_file_catalog(tmp_path, num_files=5) + loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", files_per_split=3) + splits = list(loader) + + assert len(splits) == 2 + assert len(splits[0]._file_scan_tasks) == 3 + assert len(splits[1]._file_scan_tasks) == 2 + + +def test_files_per_split_larger_than_total(tmp_path): + """When files_per_split exceeds file count, one split gets all files.""" + catalog = _make_multi_file_catalog(tmp_path, num_files=3) + loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", files_per_split=10) + splits = list(loader) + + assert len(splits) == 1 + assert len(splits[0]._file_scan_tasks) == 3 + + +def test_files_per_split_preserves_all_data(tmp_path): + """All rows from all files are returned regardless of files_per_split.""" + catalog = _make_multi_file_catalog(tmp_path, num_files=4, rows_per_file=3) + + loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", files_per_split=2) + result = _materialize(loader) + + assert result.num_rows == 12 + assert sorted(result.column(COL_ID).to_pylist()) == list(range(12)) + + +def test_files_per_split_invalid_raises(): + """files_per_split < 1 raises ValueError.""" + catalog = MagicMock() + with pytest.raises(ValueError, match="files_per_split must be at least 1"): + OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", files_per_split=0) + + # --- Predicate pushdown with transformer tests --- diff --git a/integrations/python/dataloader/tests/test_data_loader_split.py b/integrations/python/dataloader/tests/test_data_loader_split.py index afdea471d..1458b0281 100644 --- a/integrations/python/dataloader/tests/test_data_loader_split.py +++ b/integrations/python/dataloader/tests/test_data_loader_split.py @@ -100,7 +100,7 @@ def _create_test_split( task = FileScanTask(data_file=data_file) return DataLoaderSplit( - file_scan_task=task, + file_scan_tasks=[task], scan_context=scan_context, transform_sql=transform_sql, udf_registry=udf_registry, @@ -468,3 +468,103 @@ def test_split_batch_size_preserves_data(tmp_path): result = pa.Table.from_batches(list(split)) assert result.num_rows == 25 assert sorted(result.column("id").to_pylist()) == list(range(25)) + + +# --- multi-file split tests --- + + +def _create_multi_file_split( + tmp_path, + tables: list[pa.Table], + iceberg_schema: Schema, + transform_sql: str | None = None, + table_id: TableIdentifier = _DEFAULT_TABLE_ID, +) -> DataLoaderSplit: + """Create a DataLoaderSplit backed by multiple files.""" + tasks = [] + for i, table in enumerate(tables): + file_path = str(tmp_path / f"file_{i}.parquet") + fields = [ + field.with_metadata({b"PARQUET:field_id": str(j + 1).encode()}) for j, field in enumerate(table.schema) + ] + pq.write_table(table.cast(pa.schema(fields)), file_path) + + data_file = DataFile.from_args( + file_path=file_path, + file_format=FileFormat.PARQUET, + record_count=table.num_rows, + file_size_in_bytes=os.path.getsize(file_path), + ) + data_file._spec_id = 0 + tasks.append(FileScanTask(data_file=data_file)) + + metadata = new_table_metadata( + schema=iceberg_schema, + partition_spec=UNPARTITIONED_PARTITION_SPEC, + sort_order=UNSORTED_SORT_ORDER, + location=str(tmp_path), + ) + scan_context = TableScanContext( + table_metadata=metadata, + io=load_file_io(properties={}, location=str(tmp_path)), + projected_schema=iceberg_schema, + table_id=table_id, + ) + return DataLoaderSplit( + file_scan_tasks=tasks, + scan_context=scan_context, + transform_sql=transform_sql, + ) + + +def test_multi_file_split_returns_all_rows(tmp_path): + """A split with multiple files yields rows from all files.""" + schema = _BATCH_SCHEMA + tables = [ + pa.table({"id": pa.array([1, 2, 3], type=pa.int64())}), + pa.table({"id": pa.array([4, 5, 6], type=pa.int64())}), + ] + split = _create_multi_file_split(tmp_path, tables, schema) + result = pa.Table.from_batches(list(split)) + + assert result.num_rows == 6 + assert sorted(result.column("id").to_pylist()) == [1, 2, 3, 4, 5, 6] + + +def test_multi_file_split_id_is_deterministic(tmp_path): + """Two splits with the same files produce the same id.""" + schema = _BATCH_SCHEMA + tables = [_make_table(1), _make_table(1)] + split_a = _create_multi_file_split(tmp_path, tables, schema) + split_b = _create_multi_file_split(tmp_path, tables, schema) + assert split_a.id == split_b.id + + +def test_multi_file_split_id_differs_from_single_file(tmp_path): + """A multi-file split has a different id than a single-file split.""" + schema = _BATCH_SCHEMA + table = _make_table(1) + single = _create_test_split(tmp_path, table, FileFormat.PARQUET, schema, filename="file_0.parquet") + multi = _create_multi_file_split(tmp_path, [table, table], schema) + assert single.id != multi.id + + +def test_multi_file_split_with_transform(tmp_path): + """Transform SQL is applied across all files in a multi-file split.""" + schema = _TRANSFORM_SCHEMA + tables = [ + pa.table({"id": pa.array([1], type=pa.int64()), "name": pa.array(["alice"], type=pa.string())}), + pa.table({"id": pa.array([2], type=pa.int64()), "name": pa.array(["bob"], type=pa.string())}), + ] + split = _create_multi_file_split(tmp_path, tables, schema, transform_sql=_MASKING_SQL, table_id=_TABLE_ID) + result = pa.Table.from_batches(list(split)).sort_by("id") + + assert result.num_rows == 2 + assert result.column("id").to_pylist() == [1, 2] + assert result.column("name").to_pylist() == ["MASKED", "MASKED"] + + +def test_empty_file_scan_tasks_raises(): + """Constructing a split with no file scan tasks raises ValueError.""" + with pytest.raises(ValueError, match="must not be empty"): + DataLoaderSplit(file_scan_tasks=[], scan_context=MagicMock()) From cdc7e591fc3f4fff92c72e06b2d3f4b0e02ec935 Mon Sep 17 00:00:00 2001 From: Rob Reeves Date: Sat, 11 Apr 2026 05:57:27 +0000 Subject: [PATCH 3/8] Simplify task chunking with itertools.batched --- .../dataloader/src/openhouse/dataloader/data_loader.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py index 703ed8751..c64f5d690 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py @@ -2,7 +2,7 @@ from collections.abc import Callable, Iterator, Mapping, Sequence from dataclasses import dataclass from functools import cached_property -from itertools import islice +from itertools import batched from types import MappingProxyType from pyiceberg.catalog import Catalog @@ -267,8 +267,7 @@ def __iter__(self) -> Iterator[DataLoaderSplit]: lambda: scan.plan_files(), label=f"plan_files {self._table_id}", max_attempts=self._max_attempts ) - task_iter = iter(scan_tasks) - for chunk in iter(lambda: list(islice(task_iter, self._files_per_split)), []): + for chunk in batched(scan_tasks, self._files_per_split): yield DataLoaderSplit( file_scan_tasks=chunk, scan_context=scan_context, From a6cf28345212b451e06e23e7b5251ec3d2495dc6 Mon Sep 17 00:00:00 2001 From: Rob Reeves Date: Sat, 11 Apr 2026 06:07:19 +0000 Subject: [PATCH 4/8] Remove redundant files_per_split tests --- .../dataloader/tests/test_data_loader.py | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/integrations/python/dataloader/tests/test_data_loader.py b/integrations/python/dataloader/tests/test_data_loader.py index 9f2905e04..b20457efa 100644 --- a/integrations/python/dataloader/tests/test_data_loader.py +++ b/integrations/python/dataloader/tests/test_data_loader.py @@ -645,17 +645,6 @@ def fake_scan(**kwargs): return catalog -def test_files_per_split_default_one_file_per_split(tmp_path): - """Default files_per_split=1 produces one split per file.""" - catalog = _make_multi_file_catalog(tmp_path, num_files=4) - loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl") - splits = list(loader) - - assert len(splits) == 4 - for split in splits: - assert len(split._file_scan_tasks) == 1 - - def test_files_per_split_groups_tasks(tmp_path): """files_per_split=2 groups 4 files into 2 splits of 2 files each.""" catalog = _make_multi_file_catalog(tmp_path, num_files=4) @@ -678,16 +667,6 @@ def test_files_per_split_remainder_split(tmp_path): assert len(splits[1]._file_scan_tasks) == 2 -def test_files_per_split_larger_than_total(tmp_path): - """When files_per_split exceeds file count, one split gets all files.""" - catalog = _make_multi_file_catalog(tmp_path, num_files=3) - loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", files_per_split=10) - splits = list(loader) - - assert len(splits) == 1 - assert len(splits[0]._file_scan_tasks) == 3 - - def test_files_per_split_preserves_all_data(tmp_path): """All rows from all files are returned regardless of files_per_split.""" catalog = _make_multi_file_catalog(tmp_path, num_files=4, rows_per_file=3) From a185b5d47cad59e14b713d21b5206452c8c960d9 Mon Sep 17 00:00:00 2001 From: Rob Reeves Date: Sat, 11 Apr 2026 06:10:37 +0000 Subject: [PATCH 5/8] Reuse _make_real_catalog in files_per_split tests --- .../dataloader/tests/test_data_loader.py | 66 +++++-------------- 1 file changed, 17 insertions(+), 49 deletions(-) diff --git a/integrations/python/dataloader/tests/test_data_loader.py b/integrations/python/dataloader/tests/test_data_loader.py index b20457efa..b33578996 100644 --- a/integrations/python/dataloader/tests/test_data_loader.py +++ b/integrations/python/dataloader/tests/test_data_loader.py @@ -597,57 +597,25 @@ def test_batch_size_default_is_none(tmp_path): # --- files_per_split tests --- -def _make_multi_file_catalog(tmp_path, num_files: int, rows_per_file: int = 3): - """Create a mock catalog backed by multiple real Parquet files.""" - schema = Schema( - NestedField(field_id=1, name=COL_ID, field_type=LongType(), required=False), - NestedField(field_id=2, name=COL_NAME, field_type=StringType(), required=False), - ) - tasks = [] - for i in range(num_files): - data = { - COL_ID: list(range(i * rows_per_file, (i + 1) * rows_per_file)), - COL_NAME: [f"row_{j}" for j in range(i * rows_per_file, (i + 1) * rows_per_file)], - } - file_path = _write_parquet(tmp_path, data, filename=f"file_{i}.parquet") - data_file = DataFile.from_args( - file_path=file_path, - file_format=FileFormat.PARQUET, - record_count=rows_per_file, - file_size_in_bytes=os.path.getsize(file_path), - ) - data_file._spec_id = 0 - tasks.append(FileScanTask(data_file=data_file)) - - metadata = new_table_metadata( - schema=schema, - partition_spec=UNPARTITIONED_PARTITION_SPEC, - sort_order=UNSORTED_SORT_ORDER, - location=str(tmp_path), - ) - io = load_file_io(properties={}, location=str(tmp_path)) +def _add_file_tasks(catalog, num_tasks: int) -> None: + """Override plan_files on a catalog from _make_real_catalog to return multiple mock tasks.""" + mock_table = catalog.load_table.return_value + original_scan = mock_table.scan.side_effect - def fake_scan(**kwargs): - selected = kwargs.get("selected_fields") - projected = Schema(*[f for f in schema.fields if f.name in selected]) if selected else schema - scan = MagicMock() - scan.projection.return_value = projected - scan.plan_files.return_value = tasks + def multi_file_scan(**kwargs): + scan = original_scan(**kwargs) + scan.plan_files.return_value = [ + MagicMock(file=MagicMock(file_path=f"file_{i}.parquet")) for i in range(num_tasks) + ] return scan - mock_table = MagicMock() - mock_table.metadata = metadata - mock_table.io = io - mock_table.scan.side_effect = fake_scan - - catalog = MagicMock() - catalog.load_table.return_value = mock_table - return catalog + mock_table.scan.side_effect = multi_file_scan def test_files_per_split_groups_tasks(tmp_path): """files_per_split=2 groups 4 files into 2 splits of 2 files each.""" - catalog = _make_multi_file_catalog(tmp_path, num_files=4) + catalog = _make_real_catalog(tmp_path) + _add_file_tasks(catalog, 4) loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", files_per_split=2) splits = list(loader) @@ -658,7 +626,8 @@ def test_files_per_split_groups_tasks(tmp_path): def test_files_per_split_remainder_split(tmp_path): """When files don't divide evenly, the last split gets the remainder.""" - catalog = _make_multi_file_catalog(tmp_path, num_files=5) + catalog = _make_real_catalog(tmp_path) + _add_file_tasks(catalog, 5) loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", files_per_split=3) splits = list(loader) @@ -669,13 +638,12 @@ def test_files_per_split_remainder_split(tmp_path): def test_files_per_split_preserves_all_data(tmp_path): """All rows from all files are returned regardless of files_per_split.""" - catalog = _make_multi_file_catalog(tmp_path, num_files=4, rows_per_file=3) - + catalog = _make_real_catalog(tmp_path) loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", files_per_split=2) result = _materialize(loader) - assert result.num_rows == 12 - assert sorted(result.column(COL_ID).to_pylist()) == list(range(12)) + assert result.num_rows == 3 + assert sorted(result.column(COL_ID).to_pylist()) == TEST_DATA[COL_ID] def test_files_per_split_invalid_raises(): From 035b4a261b5bfa039b90f6eb8fe770e2327225c9 Mon Sep 17 00:00:00 2001 From: Rob Reeves Date: Sat, 11 Apr 2026 06:16:51 +0000 Subject: [PATCH 6/8] Trim redundant multi-file split tests --- .../dataloader/tests/test_data_loader.py | 10 ----- .../tests/test_data_loader_split.py | 40 +------------------ 2 files changed, 1 insertion(+), 49 deletions(-) diff --git a/integrations/python/dataloader/tests/test_data_loader.py b/integrations/python/dataloader/tests/test_data_loader.py index b33578996..cf10fa2c1 100644 --- a/integrations/python/dataloader/tests/test_data_loader.py +++ b/integrations/python/dataloader/tests/test_data_loader.py @@ -636,16 +636,6 @@ def test_files_per_split_remainder_split(tmp_path): assert len(splits[1]._file_scan_tasks) == 2 -def test_files_per_split_preserves_all_data(tmp_path): - """All rows from all files are returned regardless of files_per_split.""" - catalog = _make_real_catalog(tmp_path) - loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", files_per_split=2) - result = _materialize(loader) - - assert result.num_rows == 3 - assert sorted(result.column(COL_ID).to_pylist()) == TEST_DATA[COL_ID] - - def test_files_per_split_invalid_raises(): """files_per_split < 1 raises ValueError.""" catalog = MagicMock() diff --git a/integrations/python/dataloader/tests/test_data_loader_split.py b/integrations/python/dataloader/tests/test_data_loader_split.py index 1458b0281..163939fb3 100644 --- a/integrations/python/dataloader/tests/test_data_loader_split.py +++ b/integrations/python/dataloader/tests/test_data_loader_split.py @@ -529,42 +529,4 @@ def test_multi_file_split_returns_all_rows(tmp_path): assert result.num_rows == 6 assert sorted(result.column("id").to_pylist()) == [1, 2, 3, 4, 5, 6] - - -def test_multi_file_split_id_is_deterministic(tmp_path): - """Two splits with the same files produce the same id.""" - schema = _BATCH_SCHEMA - tables = [_make_table(1), _make_table(1)] - split_a = _create_multi_file_split(tmp_path, tables, schema) - split_b = _create_multi_file_split(tmp_path, tables, schema) - assert split_a.id == split_b.id - - -def test_multi_file_split_id_differs_from_single_file(tmp_path): - """A multi-file split has a different id than a single-file split.""" - schema = _BATCH_SCHEMA - table = _make_table(1) - single = _create_test_split(tmp_path, table, FileFormat.PARQUET, schema, filename="file_0.parquet") - multi = _create_multi_file_split(tmp_path, [table, table], schema) - assert single.id != multi.id - - -def test_multi_file_split_with_transform(tmp_path): - """Transform SQL is applied across all files in a multi-file split.""" - schema = _TRANSFORM_SCHEMA - tables = [ - pa.table({"id": pa.array([1], type=pa.int64()), "name": pa.array(["alice"], type=pa.string())}), - pa.table({"id": pa.array([2], type=pa.int64()), "name": pa.array(["bob"], type=pa.string())}), - ] - split = _create_multi_file_split(tmp_path, tables, schema, transform_sql=_MASKING_SQL, table_id=_TABLE_ID) - result = pa.Table.from_batches(list(split)).sort_by("id") - - assert result.num_rows == 2 - assert result.column("id").to_pylist() == [1, 2] - assert result.column("name").to_pylist() == ["MASKED", "MASKED"] - - -def test_empty_file_scan_tasks_raises(): - """Constructing a split with no file scan tasks raises ValueError.""" - with pytest.raises(ValueError, match="must not be empty"): - DataLoaderSplit(file_scan_tasks=[], scan_context=MagicMock()) + assert len(split.id) == 64 # SHA256 hex digest From 1e79b36e52d69a0ba4ed38568cfe70cdb46e8923 Mon Sep 17 00:00:00 2001 From: Rob Reeves Date: Sat, 11 Apr 2026 06:20:21 +0000 Subject: [PATCH 7/8] Simplify multi-file test to reuse _create_test_split --- .../tests/test_data_loader_split.py | 62 ++++--------------- 1 file changed, 11 insertions(+), 51 deletions(-) diff --git a/integrations/python/dataloader/tests/test_data_loader_split.py b/integrations/python/dataloader/tests/test_data_loader_split.py index 163939fb3..6c0a1be34 100644 --- a/integrations/python/dataloader/tests/test_data_loader_split.py +++ b/integrations/python/dataloader/tests/test_data_loader_split.py @@ -473,60 +473,20 @@ def test_split_batch_size_preserves_data(tmp_path): # --- multi-file split tests --- -def _create_multi_file_split( - tmp_path, - tables: list[pa.Table], - iceberg_schema: Schema, - transform_sql: str | None = None, - table_id: TableIdentifier = _DEFAULT_TABLE_ID, -) -> DataLoaderSplit: - """Create a DataLoaderSplit backed by multiple files.""" - tasks = [] - for i, table in enumerate(tables): - file_path = str(tmp_path / f"file_{i}.parquet") - fields = [ - field.with_metadata({b"PARQUET:field_id": str(j + 1).encode()}) for j, field in enumerate(table.schema) - ] - pq.write_table(table.cast(pa.schema(fields)), file_path) - - data_file = DataFile.from_args( - file_path=file_path, - file_format=FileFormat.PARQUET, - record_count=table.num_rows, - file_size_in_bytes=os.path.getsize(file_path), - ) - data_file._spec_id = 0 - tasks.append(FileScanTask(data_file=data_file)) - - metadata = new_table_metadata( - schema=iceberg_schema, - partition_spec=UNPARTITIONED_PARTITION_SPEC, - sort_order=UNSORTED_SORT_ORDER, - location=str(tmp_path), - ) - scan_context = TableScanContext( - table_metadata=metadata, - io=load_file_io(properties={}, location=str(tmp_path)), - projected_schema=iceberg_schema, - table_id=table_id, - ) - return DataLoaderSplit( - file_scan_tasks=tasks, - scan_context=scan_context, - transform_sql=transform_sql, - ) - - def test_multi_file_split_returns_all_rows(tmp_path): """A split with multiple files yields rows from all files.""" schema = _BATCH_SCHEMA - tables = [ - pa.table({"id": pa.array([1, 2, 3], type=pa.int64())}), - pa.table({"id": pa.array([4, 5, 6], type=pa.int64())}), - ] - split = _create_multi_file_split(tmp_path, tables, schema) - result = pa.Table.from_batches(list(split)) + table_a = pa.table({"id": pa.array([1, 2, 3], type=pa.int64())}) + table_b = pa.table({"id": pa.array([4, 5, 6], type=pa.int64())}) + split_a = _create_test_split(tmp_path, table_a, FileFormat.PARQUET, schema, filename="a.parquet") + split_b = _create_test_split(tmp_path, table_b, FileFormat.PARQUET, schema, filename="b.parquet") + + combined = DataLoaderSplit( + file_scan_tasks=split_a._file_scan_tasks + split_b._file_scan_tasks, + scan_context=split_a._scan_context, + ) + result = pa.Table.from_batches(list(combined)) assert result.num_rows == 6 assert sorted(result.column("id").to_pylist()) == [1, 2, 3, 4, 5, 6] - assert len(split.id) == 64 # SHA256 hex digest + assert len(combined.id) == 64 # SHA256 hex digest From 12f00fd1ef5547fdd286903a2a20cbb6a4804f68 Mon Sep 17 00:00:00 2001 From: Rob Reeves Date: Sat, 11 Apr 2026 06:23:28 +0000 Subject: [PATCH 8/8] Test that split id is stable regardless of file order --- .../python/dataloader/tests/test_data_loader_split.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/integrations/python/dataloader/tests/test_data_loader_split.py b/integrations/python/dataloader/tests/test_data_loader_split.py index 6c0a1be34..eca583bff 100644 --- a/integrations/python/dataloader/tests/test_data_loader_split.py +++ b/integrations/python/dataloader/tests/test_data_loader_split.py @@ -489,4 +489,9 @@ def test_multi_file_split_returns_all_rows(tmp_path): assert result.num_rows == 6 assert sorted(result.column("id").to_pylist()) == [1, 2, 3, 4, 5, 6] - assert len(combined.id) == 64 # SHA256 hex digest + + reversed_split = DataLoaderSplit( + file_scan_tasks=split_b._file_scan_tasks + split_a._file_scan_tasks, + scan_context=split_a._scan_context, + ) + assert reversed_split.id == combined.id