Skip to content
Merged
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from collections.abc import Mapping
from dataclasses import dataclass, field

from pyiceberg.expressions import AlwaysTrue, BooleanExpression
from pyiceberg.io import FileIO, load_file_io
Expand All @@ -17,6 +18,7 @@ def _unpickle_scan_context(
row_filter: BooleanExpression,
table_id: TableIdentifier,
worker_jvm_args: str | None = None,
metric_attributes: Mapping[str, str] | None = None,
) -> TableScanContext:
return TableScanContext(
table_metadata=table_metadata,
Expand All @@ -25,6 +27,7 @@ def _unpickle_scan_context(
row_filter=row_filter,
table_id=table_id,
worker_jvm_args=worker_jvm_args,
metric_attributes=metric_attributes if metric_attributes is not None else {},
)


Expand All @@ -42,6 +45,7 @@ class TableScanContext:
table_id: Identifier for the table being scanned
row_filter: Row-level filter expression pushed down to the scan
worker_jvm_args: JVM arguments applied when the JNI JVM is created in worker processes
metric_attributes: Attributes attached to every metric emitted while iterating splits.
"""

table_metadata: TableMetadata
Expand All @@ -50,6 +54,7 @@ class TableScanContext:
table_id: TableIdentifier
row_filter: BooleanExpression = AlwaysTrue()
worker_jvm_args: str | None = None
metric_attributes: Mapping[str, str] = field(default_factory=dict)

def __reduce__(self) -> tuple:
return (
Expand All @@ -61,5 +66,6 @@ def __reduce__(self) -> tuple:
self.row_filter,
self.table_id,
self.worker_jvm_args,
dict(self.metric_attributes),
),
)
103 changes: 91 additions & 12 deletions integrations/python/dataloader/src/openhouse/dataloader/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import time
import uuid
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from dataclasses import dataclass
Expand All @@ -9,6 +10,7 @@
from types import MappingProxyType
from typing import TypeVar

from opentelemetry.metrics import Counter, Histogram, get_meter
from pyiceberg.catalog import Catalog
from pyiceberg.table import Table
from pyiceberg.table.snapshots import Snapshot
Expand All @@ -28,13 +30,47 @@
_to_pyiceberg,
always_true,
)
from openhouse.dataloader.metrics import METER_NAME
from openhouse.dataloader.scan_optimizer import optimize_scan
from openhouse.dataloader.table_identifier import TableIdentifier
from openhouse.dataloader.table_transformer import TableTransformer
from openhouse.dataloader.udf_registry import UDFRegistry

logger = logging.getLogger(__name__)

_meter = get_meter(METER_NAME)

_load_table_duration = _meter.create_histogram(
name="OpenHouse.DataLoader.LoadTableTime",
unit="s",
description="Time spent loading the Iceberg table from the catalog.",
)
_load_table_success = _meter.create_counter(
name="OpenHouse.DataLoader.LoadTableSuccess",
unit="1",
description="Successful loads of the Iceberg table from the catalog.",
)
_load_table_failure = _meter.create_counter(
name="OpenHouse.DataLoader.LoadTableFailure",
unit="1",
description="Failed loads of the Iceberg table from the catalog.",
)
_plan_files_duration = _meter.create_histogram(
name="OpenHouse.DataLoader.PlanFilesTime",
unit="s",
description="Time spent planning which files to scan.",
)
_plan_files_success = _meter.create_counter(
name="OpenHouse.DataLoader.PlanFilesSuccess",
unit="1",
description="Successful file-planning operations for the scan.",
)
_plan_files_failure = _meter.create_counter(
name="OpenHouse.DataLoader.PlanFilesFailure",
unit="1",
description="Failed file-planning operations for the scan.",
)


def _is_transient(exc: BaseException) -> bool:
"""Return True if the exception is transient and worth retrying."""
Expand All @@ -52,22 +88,38 @@ def _batched(iterable: Iterable[_T], n: int) -> Iterator[tuple[_T, ...]]:
yield batch


def _retry(fn: Callable[[], _T], label: str, max_attempts: int) -> _T:
"""Call *fn* with retry logic, logging duration of each attempt.
def _retry(
fn: Callable[[], _T],
label: str,
max_attempts: int,
duration_histogram: Histogram,
success_counter: Counter,
failure_counter: Counter,
attributes: Mapping[str, str],
) -> _T:
"""Call *fn* with retry logic, logging the duration and recording the outcome.

Retries on ``OSError`` (transient network/storage I/O failures),
except ``HTTPError`` which is only retried for 5xx status codes.
Uses exponential backoff with up to *max_attempts* total attempts.
"""
for attempt in Retrying(
retry=retry_if_exception(_is_transient),
stop=stop_after_attempt(max_attempts),
wait=wait_exponential(),
reraise=True,
):
with attempt, log_duration(logger, "%s (attempt %d)", label, attempt.retry_state.attempt_number):
return fn()
raise AssertionError("unreachable") # pragma: no cover
overall_start = time.monotonic()
succeeded = False
try:
for attempt in Retrying(
retry=retry_if_exception(_is_transient),
stop=stop_after_attempt(max_attempts),
wait=wait_exponential(),
reraise=True,
):
with attempt, log_duration(logger, "%s (attempt %d)", label, attempt.retry_state.attempt_number):
result = fn()
succeeded = True
return result
raise AssertionError("unreachable") # pragma: no cover
finally:
duration_histogram.record(time.monotonic() - overall_start, attributes)
(success_counter if succeeded else failure_counter).add(1, attributes)


@dataclass(frozen=True)
Expand Down Expand Up @@ -102,13 +154,15 @@ class DataLoaderContext:

Args:
execution_context: Dictionary of execution context information (e.g. tenant, environment)
metric_attribute_keys: Keys from ``execution_context`` to attach as dimensions on emitted metrics.
table_transformer: Transformation to apply to the table before loading (e.g. column masking)
udf_registry: UDFs required for the table transformation
jvm_config: JVM configuration for JNI-based storage access. Currently only HDFS is supported
via the ``LIBHDFS_OPTS`` environment variable. See :class:`JvmConfig`.
"""

execution_context: Mapping[str, str] | None = None
metric_attribute_keys: Sequence[str] | None = None
table_transformer: TableTransformer | None = None
udf_registry: UDFRegistry | None = None
jvm_config: JvmConfig | None = None
Expand Down Expand Up @@ -169,12 +223,30 @@ def __init__(
if self._context.jvm_config is not None and self._context.jvm_config.planner_args is not None:
apply_libhdfs_opts(self._context.jvm_config.planner_args)

@cached_property
def _resolved_metric_attributes(self) -> Mapping[str, str]:
attrs: dict[str, str] = {
"OpenHouse.Database": self._table_id.database,
"OpenHouse.Table": self._table_id.table,
}
keys = self._context.metric_attribute_keys
if keys:
execution_context = self._context.execution_context or {}
for k in keys:
if k in execution_context:
attrs[k] = execution_context[k]
return attrs

@cached_property
def _iceberg_table(self) -> Table:
return _retry(
lambda: self._catalog.load_table((self._table_id.database, self._table_id.table)),
label=f"load_table {self._table_id}",
max_attempts=self._max_attempts,
duration_histogram=_load_table_duration,
success_counter=_load_table_success,
failure_counter=_load_table_failure,
attributes=self._resolved_metric_attributes,
)

@property
Expand Down Expand Up @@ -283,12 +355,19 @@ def __iter__(self) -> Iterator[DataLoaderSplit]:
row_filter=row_filter,
table_id=self._table_id,
worker_jvm_args=self._context.jvm_config.worker_args if self._context.jvm_config else None,
metric_attributes=self._resolved_metric_attributes,
)

# plan_files() materializes all tasks at once (PyIceberg doesn't support streaming)
# Manifests are read in parallel with one thread per manifest
scan_tasks = _retry(
lambda: scan.plan_files(), label=f"plan_files {self._table_id}", max_attempts=self._max_attempts
lambda: scan.plan_files(),
label=f"plan_files {self._table_id}",
max_attempts=self._max_attempts,
duration_histogram=_plan_files_duration,
success_counter=_plan_files_success,
failure_counter=_plan_files_failure,
attributes=self._resolved_metric_attributes,
)

for chunk in _batched(scan_tasks, self._files_per_split):
Expand Down
Loading