diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py index 062bdbd55..c64f5d690 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py @@ -2,6 +2,7 @@ from collections.abc import Callable, Iterator, Mapping, Sequence from dataclasses import dataclass from functools import cached_property +from itertools import batched from types import MappingProxyType from pyiceberg.catalog import Catalog @@ -115,6 +116,7 @@ def __init__( context: DataLoaderContext | None = None, max_attempts: int = 3, batch_size: int | None = None, + files_per_split: int = 1, ): """ Args: @@ -131,11 +133,15 @@ def __init__( Passed to PyArrow's Scanner which produces batches of at most this many rows. Smaller values reduce peak memory but increase per-batch overhead. None uses the PyArrow default (~131K rows). + files_per_split: Number of files each split reads concurrently. + Default is 1 (one file per split). """ if branch is not None and branch.strip() == "": raise ValueError("branch must not be empty or whitespace") if branch is not None and snapshot_id is not None: raise ValueError("Cannot specify both branch and snapshot_id") + if files_per_split < 1: + raise ValueError("files_per_split must be at least 1") self._catalog = catalog self._table_id = TableIdentifier(database, table, branch) self._snapshot_id = snapshot_id @@ -144,6 +150,7 @@ def __init__( self._context = context or DataLoaderContext() self._max_attempts = max_attempts self._batch_size = batch_size + self._files_per_split = files_per_split if self._context.jvm_config is not None and self._context.jvm_config.planner_args is not None: apply_libhdfs_opts(self._context.jvm_config.planner_args) @@ -260,9 +267,9 @@ def __iter__(self) -> Iterator[DataLoaderSplit]: lambda: scan.plan_files(), label=f"plan_files {self._table_id}", max_attempts=self._max_attempts ) - for scan_task in scan_tasks: + for chunk in batched(scan_tasks, self._files_per_split): yield DataLoaderSplit( - file_scan_task=scan_task, + file_scan_tasks=chunk, scan_context=scan_context, transform_sql=optimized_sql, udf_registry=self._context.udf_registry, diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py index 59649536e..b07de881a 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py @@ -1,7 +1,7 @@ from __future__ import annotations import hashlib -from collections.abc import Iterator, Mapping +from collections.abc import Iterator, Mapping, Sequence from types import MappingProxyType from datafusion.context import SessionContext @@ -45,17 +45,19 @@ def _bind_batch_table(session: SessionContext, table_id: TableIdentifier, batch: class DataLoaderSplit: - """A single data split""" + """A data split that reads one or more files.""" def __init__( self, - file_scan_task: FileScanTask, + file_scan_tasks: Sequence[FileScanTask], scan_context: TableScanContext, transform_sql: str | None = None, udf_registry: UDFRegistry | None = None, batch_size: int | None = None, ): - self._file_scan_task = file_scan_task + self._file_scan_tasks = list(file_scan_tasks) + if not self._file_scan_tasks: + raise ValueError("file_scan_tasks must not be empty") self._scan_context = scan_context self._transform_sql = transform_sql self._udf_registry = udf_registry or NoOpRegistry() @@ -66,8 +68,9 @@ def id(self) -> str: """Unique ID for the split. This is stable across executions for a given snapshot and split size. """ - file_path = self._file_scan_task.file.file_path - return hashlib.sha256(file_path.encode("utf-8")).hexdigest() + paths = sorted(t.file.file_path for t in self._file_scan_tasks) + combined = "\0".join(paths) + return hashlib.sha256(combined.encode("utf-8")).hexdigest() @property def table_properties(self) -> Mapping[str, str]: @@ -75,11 +78,9 @@ def table_properties(self) -> Mapping[str, str]: return MappingProxyType(self._scan_context.table_metadata.properties) def __iter__(self) -> Iterator[RecordBatch]: - """Reads the file scan task and yields Arrow RecordBatches. + """Reads the file scan tasks and yields Arrow RecordBatches. - Uses PyIceberg's ArrowScan to handle format dispatch, schema resolution, - delete files, and partition spec lookups. The number of batches loaded - into memory at once is bounded to prevent using too much memory at once. + When the split contains multiple files, they are read concurrently. """ ctx = self._scan_context if ctx.worker_jvm_args is not None: @@ -92,8 +93,8 @@ def __iter__(self) -> Iterator[RecordBatch]: ) batches = arrow_scan.to_record_batches( - [self._file_scan_task], - order=ArrivalOrder(concurrent_streams=1, batch_size=self._batch_size), + self._file_scan_tasks, + order=ArrivalOrder(concurrent_streams=len(self._file_scan_tasks), batch_size=self._batch_size), ) if self._transform_sql is None: diff --git a/integrations/python/dataloader/tests/test_data_loader.py b/integrations/python/dataloader/tests/test_data_loader.py index 8816ab1c0..cf10fa2c1 100644 --- a/integrations/python/dataloader/tests/test_data_loader.py +++ b/integrations/python/dataloader/tests/test_data_loader.py @@ -559,12 +559,12 @@ def fake_scan(**kwargs): # Without branch: splits come from main snapshot main_splits = list(OpenHouseDataLoader(catalog=catalog, database="db", table="tbl")) assert len(main_splits) == 1 - assert main_splits[0]._file_scan_task.file.file_path == "main.parquet" + assert main_splits[0]._file_scan_tasks[0].file.file_path == "main.parquet" # With branch: splits come from branch snapshot branch_splits = list(OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", branch="my-branch")) assert len(branch_splits) == 1 - assert branch_splits[0]._file_scan_task.file.file_path == "branch.parquet" + assert branch_splits[0]._file_scan_tasks[0].file.file_path == "branch.parquet" # --- batch_size tests --- @@ -594,6 +594,55 @@ def test_batch_size_default_is_none(tmp_path): assert split._batch_size is None +# --- files_per_split tests --- + + +def _add_file_tasks(catalog, num_tasks: int) -> None: + """Override plan_files on a catalog from _make_real_catalog to return multiple mock tasks.""" + mock_table = catalog.load_table.return_value + original_scan = mock_table.scan.side_effect + + def multi_file_scan(**kwargs): + scan = original_scan(**kwargs) + scan.plan_files.return_value = [ + MagicMock(file=MagicMock(file_path=f"file_{i}.parquet")) for i in range(num_tasks) + ] + return scan + + mock_table.scan.side_effect = multi_file_scan + + +def test_files_per_split_groups_tasks(tmp_path): + """files_per_split=2 groups 4 files into 2 splits of 2 files each.""" + catalog = _make_real_catalog(tmp_path) + _add_file_tasks(catalog, 4) + loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", files_per_split=2) + splits = list(loader) + + assert len(splits) == 2 + for split in splits: + assert len(split._file_scan_tasks) == 2 + + +def test_files_per_split_remainder_split(tmp_path): + """When files don't divide evenly, the last split gets the remainder.""" + catalog = _make_real_catalog(tmp_path) + _add_file_tasks(catalog, 5) + loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", files_per_split=3) + splits = list(loader) + + assert len(splits) == 2 + assert len(splits[0]._file_scan_tasks) == 3 + assert len(splits[1]._file_scan_tasks) == 2 + + +def test_files_per_split_invalid_raises(): + """files_per_split < 1 raises ValueError.""" + catalog = MagicMock() + with pytest.raises(ValueError, match="files_per_split must be at least 1"): + OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", files_per_split=0) + + # --- Predicate pushdown with transformer tests --- diff --git a/integrations/python/dataloader/tests/test_data_loader_split.py b/integrations/python/dataloader/tests/test_data_loader_split.py index afdea471d..eca583bff 100644 --- a/integrations/python/dataloader/tests/test_data_loader_split.py +++ b/integrations/python/dataloader/tests/test_data_loader_split.py @@ -100,7 +100,7 @@ def _create_test_split( task = FileScanTask(data_file=data_file) return DataLoaderSplit( - file_scan_task=task, + file_scan_tasks=[task], scan_context=scan_context, transform_sql=transform_sql, udf_registry=udf_registry, @@ -468,3 +468,30 @@ def test_split_batch_size_preserves_data(tmp_path): result = pa.Table.from_batches(list(split)) assert result.num_rows == 25 assert sorted(result.column("id").to_pylist()) == list(range(25)) + + +# --- multi-file split tests --- + + +def test_multi_file_split_returns_all_rows(tmp_path): + """A split with multiple files yields rows from all files.""" + schema = _BATCH_SCHEMA + table_a = pa.table({"id": pa.array([1, 2, 3], type=pa.int64())}) + table_b = pa.table({"id": pa.array([4, 5, 6], type=pa.int64())}) + split_a = _create_test_split(tmp_path, table_a, FileFormat.PARQUET, schema, filename="a.parquet") + split_b = _create_test_split(tmp_path, table_b, FileFormat.PARQUET, schema, filename="b.parquet") + + combined = DataLoaderSplit( + file_scan_tasks=split_a._file_scan_tasks + split_b._file_scan_tasks, + scan_context=split_a._scan_context, + ) + result = pa.Table.from_batches(list(combined)) + + assert result.num_rows == 6 + assert sorted(result.column("id").to_pylist()) == [1, 2, 3, 4, 5, 6] + + reversed_split = DataLoaderSplit( + file_scan_tasks=split_b._file_scan_tasks + split_a._file_scan_tasks, + scan_context=split_a._scan_context, + ) + assert reversed_split.id == combined.id