diff --git a/integrations/python/dataloader/src/openhouse/dataloader/_table_scan_context.py b/integrations/python/dataloader/src/openhouse/dataloader/_table_scan_context.py index ae20bd9c5..1b7d3ac33 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/_table_scan_context.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/_table_scan_context.py @@ -1,6 +1,7 @@ from __future__ import annotations -from dataclasses import dataclass +from collections.abc import Mapping +from dataclasses import dataclass, field from pyiceberg.expressions import AlwaysTrue, BooleanExpression from pyiceberg.io import FileIO, load_file_io @@ -17,6 +18,7 @@ def _unpickle_scan_context( row_filter: BooleanExpression, table_id: TableIdentifier, worker_jvm_args: str | None = None, + metric_attributes: Mapping[str, str] | None = None, ) -> TableScanContext: return TableScanContext( table_metadata=table_metadata, @@ -25,6 +27,7 @@ def _unpickle_scan_context( row_filter=row_filter, table_id=table_id, worker_jvm_args=worker_jvm_args, + metric_attributes=metric_attributes if metric_attributes is not None else {}, ) @@ -42,6 +45,7 @@ class TableScanContext: table_id: Identifier for the table being scanned row_filter: Row-level filter expression pushed down to the scan worker_jvm_args: JVM arguments applied when the JNI JVM is created in worker processes + metric_attributes: Attributes attached to every metric emitted while iterating splits. 
""" table_metadata: TableMetadata @@ -50,6 +54,7 @@ class TableScanContext: table_id: TableIdentifier row_filter: BooleanExpression = AlwaysTrue() worker_jvm_args: str | None = None + metric_attributes: Mapping[str, str] = field(default_factory=dict) def __reduce__(self) -> tuple: return ( @@ -61,5 +66,6 @@ def __reduce__(self) -> tuple: self.row_filter, self.table_id, self.worker_jvm_args, + dict(self.metric_attributes), ), ) diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py index 4d1310e17..da9a3e943 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import time import uuid from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from dataclasses import dataclass @@ -9,6 +10,7 @@ from types import MappingProxyType from typing import TypeVar +from opentelemetry.metrics import Counter, Histogram, get_meter from pyiceberg.catalog import Catalog from pyiceberg.table import Table from pyiceberg.table.snapshots import Snapshot @@ -28,6 +30,7 @@ _to_pyiceberg, always_true, ) +from openhouse.dataloader.metrics import METER_NAME from openhouse.dataloader.scan_optimizer import optimize_scan from openhouse.dataloader.table_identifier import TableIdentifier from openhouse.dataloader.table_transformer import TableTransformer @@ -35,6 +38,39 @@ logger = logging.getLogger(__name__) +_meter = get_meter(METER_NAME) + +_load_table_duration = _meter.create_histogram( + name="OpenHouse.DataLoader.LoadTableTime", + unit="s", + description="Time spent loading the Iceberg table from the catalog.", +) +_load_table_success = _meter.create_counter( + name="OpenHouse.DataLoader.LoadTableSuccess", + unit="1", + description="Successful loads of the Iceberg table from the catalog.", +) 
+_load_table_failure = _meter.create_counter( + name="OpenHouse.DataLoader.LoadTableFailure", + unit="1", + description="Failed loads of the Iceberg table from the catalog.", +) +_plan_files_duration = _meter.create_histogram( + name="OpenHouse.DataLoader.PlanFilesTime", + unit="s", + description="Time spent planning which files to scan.", +) +_plan_files_success = _meter.create_counter( + name="OpenHouse.DataLoader.PlanFilesSuccess", + unit="1", + description="Successful file-planning operations for the scan.", +) +_plan_files_failure = _meter.create_counter( + name="OpenHouse.DataLoader.PlanFilesFailure", + unit="1", + description="Failed file-planning operations for the scan.", +) + def _is_transient(exc: BaseException) -> bool: """Return True if the exception is transient and worth retrying.""" @@ -52,22 +88,38 @@ def _batched(iterable: Iterable[_T], n: int) -> Iterator[tuple[_T, ...]]: yield batch -def _retry(fn: Callable[[], _T], label: str, max_attempts: int) -> _T: - """Call *fn* with retry logic, logging duration of each attempt. +def _retry( + fn: Callable[[], _T], + label: str, + max_attempts: int, + duration_histogram: Histogram, + success_counter: Counter, + failure_counter: Counter, + attributes: Mapping[str, str], +) -> _T: + """Call *fn* with retry logic, logging the duration and recording the outcome. Retries on ``OSError`` (transient network/storage I/O failures), except ``HTTPError`` which is only retried for 5xx status codes. Uses exponential backoff with up to *max_attempts* total attempts. 
""" - for attempt in Retrying( - retry=retry_if_exception(_is_transient), - stop=stop_after_attempt(max_attempts), - wait=wait_exponential(), - reraise=True, - ): - with attempt, log_duration(logger, "%s (attempt %d)", label, attempt.retry_state.attempt_number): - return fn() - raise AssertionError("unreachable") # pragma: no cover + overall_start = time.monotonic() + succeeded = False + try: + for attempt in Retrying( + retry=retry_if_exception(_is_transient), + stop=stop_after_attempt(max_attempts), + wait=wait_exponential(), + reraise=True, + ): + with attempt, log_duration(logger, "%s (attempt %d)", label, attempt.retry_state.attempt_number): + result = fn() + succeeded = True + return result + raise AssertionError("unreachable") # pragma: no cover + finally: + duration_histogram.record(time.monotonic() - overall_start, attributes) + (success_counter if succeeded else failure_counter).add(1, attributes) @dataclass(frozen=True) @@ -102,6 +154,7 @@ class DataLoaderContext: Args: execution_context: Dictionary of execution context information (e.g. tenant, environment) + metric_attribute_keys: Keys from ``execution_context`` to attach as dimensions on emitted metrics. table_transformer: Transformation to apply to the table before loading (e.g. column masking) udf_registry: UDFs required for the table transformation jvm_config: JVM configuration for JNI-based storage access. 
Currently only HDFS is supported @@ -109,6 +162,7 @@ class DataLoaderContext: """ execution_context: Mapping[str, str] | None = None + metric_attribute_keys: Sequence[str] | None = None table_transformer: TableTransformer | None = None udf_registry: UDFRegistry | None = None jvm_config: JvmConfig | None = None @@ -169,12 +223,30 @@ def __init__( if self._context.jvm_config is not None and self._context.jvm_config.planner_args is not None: apply_libhdfs_opts(self._context.jvm_config.planner_args) + @cached_property + def _resolved_metric_attributes(self) -> Mapping[str, str]: + attrs: dict[str, str] = { + "OpenHouse.Database": self._table_id.database, + "OpenHouse.Table": self._table_id.table, + } + keys = self._context.metric_attribute_keys + if keys: + execution_context = self._context.execution_context or {} + for k in keys: + if k in execution_context: + attrs[k] = execution_context[k] + return attrs + @cached_property def _iceberg_table(self) -> Table: return _retry( lambda: self._catalog.load_table((self._table_id.database, self._table_id.table)), label=f"load_table {self._table_id}", max_attempts=self._max_attempts, + duration_histogram=_load_table_duration, + success_counter=_load_table_success, + failure_counter=_load_table_failure, + attributes=self._resolved_metric_attributes, ) @property @@ -283,12 +355,19 @@ def __iter__(self) -> Iterator[DataLoaderSplit]: row_filter=row_filter, table_id=self._table_id, worker_jvm_args=self._context.jvm_config.worker_args if self._context.jvm_config else None, + metric_attributes=self._resolved_metric_attributes, ) # plan_files() materializes all tasks at once (PyIceberg doesn't support streaming) # Manifests are read in parallel with one thread per manifest scan_tasks = _retry( - lambda: scan.plan_files(), label=f"plan_files {self._table_id}", max_attempts=self._max_attempts + lambda: scan.plan_files(), + label=f"plan_files {self._table_id}", + max_attempts=self._max_attempts, + duration_histogram=_plan_files_duration, 
+ success_counter=_plan_files_success, + failure_counter=_plan_files_failure, + attributes=self._resolved_metric_attributes, ) for chunk in _batched(scan_tasks, self._files_per_split): diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py index 77e8c1d81..e9bd9ae82 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py @@ -10,6 +10,7 @@ from datafusion import SessionConfig from datafusion.context import SessionContext +from opentelemetry.metrics import get_meter from pyarrow import RecordBatch from pyiceberg.io.pyarrow import ArrowScan from pyiceberg.table import ArrivalOrder, FileScanTask @@ -18,11 +19,70 @@ from openhouse.dataloader._table_scan_context import TableScanContext from openhouse.dataloader._timer import log_duration from openhouse.dataloader.filters import _quote_identifier +from openhouse.dataloader.metrics import METER_NAME from openhouse.dataloader.table_identifier import TableIdentifier from openhouse.dataloader.udf_registry import NoOpRegistry, UDFRegistry logger = logging.getLogger(__name__) +_meter = get_meter(METER_NAME) + +_split_duration = _meter.create_histogram( + name="OpenHouse.DataLoader.SplitTime", + unit="s", + description="Time spent iterating a split.", +) +_split_files = _meter.create_histogram( + name="OpenHouse.DataLoader.SplitFiles", + unit="1", + description="Number of files in a split.", +) +_split_rows = _meter.create_histogram( + name="OpenHouse.DataLoader.SplitRows", + unit="1", + description="Rows yielded by a split.", +) +_split_bytes = _meter.create_histogram( + name="OpenHouse.DataLoader.SplitBytes", + unit="By", + description="Bytes yielded by a split.", +) +_split_batches = _meter.create_histogram( + name="OpenHouse.DataLoader.SplitBatches", + unit="1", + description="Record batches 
yielded by a split.", +) +_split_errors = _meter.create_counter( + name="OpenHouse.DataLoader.SplitErrors", + unit="1", + description="Errors raised while iterating a split.", +) +_batch_duration = _meter.create_histogram( + name="OpenHouse.DataLoader.BatchTime", + unit="s", + description="Time spent reading a record batch.", +) +_batch_rows = _meter.create_histogram( + name="OpenHouse.DataLoader.BatchRows", + unit="1", + description="Rows in a record batch.", +) +_batch_bytes = _meter.create_histogram( + name="OpenHouse.DataLoader.BatchBytes", + unit="By", + description="Bytes in a record batch.", +) +_batch_errors = _meter.create_counter( + name="OpenHouse.DataLoader.BatchErrors", + unit="1", + description="Errors raised while reading a record batch.", +) +_transform_duration = _meter.create_histogram( + name="OpenHouse.DataLoader.TransformTime", + unit="s", + description="Time spent applying the transform to a record batch.", +) + def to_sql_identifier(table_id: TableIdentifier) -> str: """Return the quoted DataFusion SQL identifier, e.g. 
``"db"."tbl"``.""" @@ -57,12 +117,21 @@ def _bind_batch_table(session: SessionContext, table_id: TableIdentifier, batch: class _TimedBatchIter: - """Wraps a RecordBatch iterator to log the wall-clock time of each ``next()`` call.""" + """Wraps a RecordBatch iterator to log and emit metrics for each ``next()`` call.""" - def __init__(self, inner: Iterator[RecordBatch], split_id: str) -> None: + def __init__( + self, + inner: Iterator[RecordBatch], + split_id: str, + attributes: Mapping[str, str], + ) -> None: self._inner = inner self._split_id = split_id + self._attributes = attributes self._idx = 0 + self.total_rows = 0 + self.total_bytes = 0 + self.batch_count = 0 def __iter__(self) -> _TimedBatchIter: return self @@ -74,11 +143,20 @@ def __next__(self) -> RecordBatch: except StopIteration: raise except Exception: - logger.warning( - "record_batch %s [%d] failed after %.3fs", self._split_id, self._idx, time.monotonic() - start - ) + elapsed = time.monotonic() - start + logger.warning("record_batch %s [%d] failed after %.3fs", self._split_id, self._idx, elapsed) + _batch_errors.add(1, self._attributes) raise - logger.info("record_batch %s [%d] in %.3fs", self._split_id, self._idx, time.monotonic() - start) + elapsed = time.monotonic() - start + logger.info("record_batch %s [%d] in %.3fs", self._split_id, self._idx, elapsed) + rows = batch.num_rows + nbytes = batch.nbytes + _batch_duration.record(elapsed, self._attributes) + _batch_rows.record(rows, self._attributes) + _batch_bytes.record(nbytes, self._attributes) + self.total_rows += rows + self.total_bytes += nbytes + self.batch_count += 1 self._idx += 1 return batch @@ -88,11 +166,16 @@ def _timed_transform( split_id: str, session: SessionContext, apply_fn: Callable[[SessionContext, RecordBatch], Iterator[RecordBatch]], + attributes: Mapping[str, str], ) -> Iterator[RecordBatch]: - """Apply a transform to each batch, logging the wall-clock time of each.""" + """Apply a transform to each batch, logging and 
recording the wall-clock time of each.""" for idx, batch in enumerate(batches): - with log_duration(logger, "transform_batch %s [%d]", split_id, idx): - transformed = list(apply_fn(session, batch)) + transform_start = time.monotonic() + try: + with log_duration(logger, "transform_batch %s [%d]", split_id, idx): + transformed = list(apply_fn(session, batch)) + finally: + _transform_duration.record(time.monotonic() - transform_start, attributes) yield from transformed @@ -140,34 +223,48 @@ def __iter__(self) -> Iterator[RecordBatch]: ctx = self._scan_context if ctx.worker_jvm_args is not None: apply_libhdfs_opts(ctx.worker_jvm_args) - arrow_scan = ArrowScan( - table_metadata=ctx.table_metadata, - io=ctx.io, - projected_schema=ctx.projected_schema, - row_filter=ctx.row_filter, - ) - - split_id = self.id[:12] - - with log_duration(logger, "setup_scan %s", split_id): - batches = arrow_scan.to_record_batches( - self._file_scan_tasks, - order=ArrivalOrder(concurrent_streams=len(self._file_scan_tasks), batch_size=self._batch_size), + attributes = ctx.metric_attributes + split_start = time.monotonic() + timed: _TimedBatchIter | None = None + try: + arrow_scan = ArrowScan( + table_metadata=ctx.table_metadata, + io=ctx.io, + projected_schema=ctx.projected_schema, + row_filter=ctx.row_filter, ) - timed = _TimedBatchIter(iter(batches), split_id) - - if self._transform_sql is None: - yield from timed - else: - # Materialize the first batch before creating the transform session - # so that the HDFS JVM starts (and picks up worker_jvm_args) before - # any UDF registration code can trigger JNI. 
-            first = next(timed, None)
-            if first is None:
-                return
-            session = _create_transform_session(self._scan_context.table_id, self._udf_registry, self._batch_size)
-            yield from _timed_transform(chain([first], timed), split_id, session, self._apply_transform)
+            split_id = self.id[:12]
+
+            with log_duration(logger, "setup_scan %s", split_id):
+                batches = arrow_scan.to_record_batches(
+                    self._file_scan_tasks,
+                    order=ArrivalOrder(concurrent_streams=len(self._file_scan_tasks), batch_size=self._batch_size),
+                )
+
+            timed = _TimedBatchIter(iter(batches), split_id, attributes)
+
+            if self._transform_sql is None:
+                yield from timed
+            else:
+                # Materialize the first batch before creating the transform session
+                # so that the HDFS JVM starts (and picks up worker_jvm_args) before
+                # any UDF registration code can trigger JNI.
+                first = next(timed, None)
+                if first is None:
+                    return
+                session = _create_transform_session(self._scan_context.table_id, self._udf_registry, self._batch_size)
+                yield from _timed_transform(chain([first], timed), split_id, session, self._apply_transform, attributes)
+        except Exception:  # GeneratorExit (consumer stopped early) is not a split error
+            _split_errors.add(1, attributes)
+            raise
+        finally:
+            _split_duration.record(time.monotonic() - split_start, attributes)
+            _split_files.record(len(self._file_scan_tasks), attributes)
+            if timed is not None:
+                _split_rows.record(timed.total_rows, attributes)
+                _split_bytes.record(timed.total_bytes, attributes)
+                _split_batches.record(timed.batch_count, attributes)
 
     def _apply_transform(self, session: SessionContext, batch: RecordBatch) -> Iterator[RecordBatch]:
         """Execute the transform SQL against a single RecordBatch."""
diff --git a/integrations/python/dataloader/src/openhouse/dataloader/metrics/__init__.py b/integrations/python/dataloader/src/openhouse/dataloader/metrics/__init__.py
index f180e5d27..d41e2aff1 100644
--- a/integrations/python/dataloader/src/openhouse/dataloader/metrics/__init__.py
+++ 
b/integrations/python/dataloader/src/openhouse/dataloader/metrics/__init__.py @@ -12,4 +12,6 @@ meter = get_meter(METER_NAME) """ -METER_NAME = "openhouse.dataloader" +METER_NAME = "OpenHouse.DataLoader" + +__all__ = ["METER_NAME"] diff --git a/integrations/python/dataloader/tests/test_metrics.py b/integrations/python/dataloader/tests/test_metrics.py index 00273ceda..6f074f127 100644 --- a/integrations/python/dataloader/tests/test_metrics.py +++ b/integrations/python/dataloader/tests/test_metrics.py @@ -1,13 +1,379 @@ -"""Tests for the OpenTelemetry metrics infrastructure.""" +"""Tests for the OpenTelemetry metrics emitted by the dataloader.""" +from __future__ import annotations + +import os +import pickle +from collections.abc import Iterator +from unittest.mock import MagicMock + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest +from opentelemetry import metrics as otel_metrics from opentelemetry.metrics import Meter, get_meter +from opentelemetry.metrics import _internal as otel_metrics_internal +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import InMemoryMetricReader +from pyiceberg.io import load_file_io +from pyiceberg.manifest import DataFile, FileFormat +from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC +from pyiceberg.schema import Schema +from pyiceberg.table import FileScanTask +from pyiceberg.table.metadata import new_table_metadata +from pyiceberg.table.sorting import UNSORTED_SORT_ORDER +from pyiceberg.types import LongType, NestedField +from openhouse.dataloader import DataLoaderContext, OpenHouseDataLoader +from openhouse.dataloader._table_scan_context import TableScanContext +from openhouse.dataloader.data_loader import ( + _load_table_duration, + _load_table_failure, + _load_table_success, + _plan_files_duration, + _plan_files_failure, + _plan_files_success, + _retry, +) +from openhouse.dataloader.data_loader_split import DataLoaderSplit from 
openhouse.dataloader.metrics import METER_NAME +from openhouse.dataloader.table_identifier import TableIdentifier + +# --- Meter / METER_NAME basics --- def test_meter_name_is_stable(): - assert METER_NAME == "openhouse.dataloader" + assert METER_NAME == "OpenHouse.DataLoader" def test_get_meter_with_meter_name_returns_a_meter(): assert isinstance(get_meter(METER_NAME), Meter) + + +# --- DataLoaderContext.metric_attribute_keys resolution --- + + +def _loader(context: DataLoaderContext) -> OpenHouseDataLoader: + return OpenHouseDataLoader(catalog=MagicMock(), database="db", table="tbl", context=context) + + +_BASE_ATTRS = {"OpenHouse.Database": "db", "OpenHouse.Table": "tbl"} + + +def test_resolved_metric_attributes_includes_table_identifier_only_by_default(): + loader = _loader(DataLoaderContext()) + assert dict(loader._resolved_metric_attributes) == _BASE_ATTRS + + +def test_resolved_metric_attributes_picks_whitelisted_keys(): + loader = _loader( + DataLoaderContext( + execution_context={"tenant": "t1", "env": "prod", "user_id": "u-42"}, + metric_attribute_keys=["tenant", "env"], + ) + ) + assert dict(loader._resolved_metric_attributes) == {**_BASE_ATTRS, "tenant": "t1", "env": "prod"} + + +def test_resolved_metric_attributes_skips_missing_keys(): + loader = _loader( + DataLoaderContext( + execution_context={"tenant": "t1"}, + metric_attribute_keys=["tenant", "env"], + ) + ) + assert dict(loader._resolved_metric_attributes) == {**_BASE_ATTRS, "tenant": "t1"} + + +def test_resolved_metric_attributes_no_extras_when_no_keys_configured(): + loader = _loader(DataLoaderContext(execution_context={"tenant": "t1"})) + assert dict(loader._resolved_metric_attributes) == _BASE_ATTRS + + +def test_resolved_metric_attributes_no_extras_when_execution_context_missing(): + loader = _loader(DataLoaderContext(metric_attribute_keys=["tenant"])) + assert dict(loader._resolved_metric_attributes) == _BASE_ATTRS + + +# --- InMemoryMetricReader harness --- + + +@pytest.fixture +def 
metrics_reader() -> Iterator[InMemoryMetricReader]: + """Install an SDK MeterProvider with an InMemoryMetricReader for the test. + + Resets the one-shot ``_METER_PROVIDER_SET_ONCE`` guard and restores the + prior MeterProvider on exit so other tests are not affected. + """ + reader = InMemoryMetricReader() + provider = MeterProvider(metric_readers=[reader]) + once = otel_metrics_internal._METER_PROVIDER_SET_ONCE + prior_provider = otel_metrics_internal._METER_PROVIDER + prior_done = once._done + once._done = False + otel_metrics.set_meter_provider(provider) + try: + yield reader + finally: + otel_metrics_internal._METER_PROVIDER = prior_provider + once._done = prior_done + + +def _data_points(reader: InMemoryMetricReader, metric_name: str) -> list: + """Collect and return all data points for *metric_name* across scopes. + + ``metric_name`` must be the lowercase form stored by the SDK — the + OpenTelemetry SDK lowercases instrument names at registration time + (``opentelemetry/sdk/metrics/_internal/instrument.py``), even though + the declared names are PascalCase. 
+ """ + data = reader.get_metrics_data() + points: list = [] + if data is None: + return points + for resource_metric in data.resource_metrics: + for scope_metric in resource_metric.scope_metrics: + for metric in scope_metric.metrics: + if metric.name == metric_name: + points.extend(metric.data.data_points) + return points + + +def _attrs(point) -> dict: + return dict(point.attributes) + + +# --- _retry success / failure / duration --- + + +def test_retry_emits_success_and_duration_on_first_try(metrics_reader): + attrs = {"OpenHouse.Database": "db", "OpenHouse.Table": "tbl"} + result = _retry( + lambda: "ok", + label="load_table db.tbl", + max_attempts=3, + duration_histogram=_load_table_duration, + success_counter=_load_table_success, + failure_counter=_load_table_failure, + attributes=attrs, + ) + assert result == "ok" + + successes = _data_points(metrics_reader, "openhouse.dataloader.loadtablesuccess") + assert len(successes) == 1 + assert _attrs(successes[0]) == attrs + assert successes[0].value == 1 + + assert _data_points(metrics_reader, "openhouse.dataloader.loadtablefailure") == [] + + durations = _data_points(metrics_reader, "openhouse.dataloader.loadtabletime") + assert len(durations) == 1 + assert _attrs(durations[0]) == attrs + + +def test_retry_emits_single_success_after_transient_retry(metrics_reader): + attrs = {"OpenHouse.Database": "db", "OpenHouse.Table": "tbl", "Tenant": "t1"} + calls = {"n": 0} + + def fn(): + calls["n"] += 1 + if calls["n"] == 1: + raise OSError("transient") + return "ok" + + result = _retry( + fn, + label="plan_files db.tbl", + max_attempts=3, + duration_histogram=_plan_files_duration, + success_counter=_plan_files_success, + failure_counter=_plan_files_failure, + attributes=attrs, + ) + assert result == "ok" + assert calls["n"] == 2 + + successes = _data_points(metrics_reader, "openhouse.dataloader.planfilessuccess") + assert len(successes) == 1 + assert successes[0].value == 1 + assert _attrs(successes[0])["Tenant"] == "t1" 
+ + assert _data_points(metrics_reader, "openhouse.dataloader.planfilesfailure") == [] + + durations = _data_points(metrics_reader, "openhouse.dataloader.planfilestime") + assert len(durations) == 1 + + +def test_retry_emits_failure_and_duration_on_permanent_failure(metrics_reader): + attrs = {"OpenHouse.Database": "db", "OpenHouse.Table": "tbl"} + + class _NonTransient(Exception): + pass + + def fn(): + raise _NonTransient("nope") + + with pytest.raises(_NonTransient): + _retry( + fn, + label="load_table", + max_attempts=3, + duration_histogram=_load_table_duration, + success_counter=_load_table_success, + failure_counter=_load_table_failure, + attributes=attrs, + ) + + failures = _data_points(metrics_reader, "openhouse.dataloader.loadtablefailure") + assert len(failures) == 1 + assert failures[0].value == 1 + + assert _data_points(metrics_reader, "openhouse.dataloader.loadtablesuccess") == [] + + durations = _data_points(metrics_reader, "openhouse.dataloader.loadtabletime") + assert len(durations) == 1 + + +# --- DataLoaderSplit instrumentation --- + +_SPLIT_SCHEMA = Schema(NestedField(field_id=1, name="id", field_type=LongType(), required=False)) +_SPLIT_TABLE_ID = TableIdentifier("db", "tbl") + + +def _make_split( + tmp_path, + metric_attributes: dict | None = None, + transform_sql: str | None = None, +) -> DataLoaderSplit: + file_path = str(tmp_path / "data.parquet") + table = pa.table({"id": pa.array([1, 2, 3], type=pa.int64())}) + fields = [field.with_metadata({b"PARQUET:field_id": str(i + 1).encode()}) for i, field in enumerate(table.schema)] + pq.write_table(table.cast(pa.schema(fields)), file_path) + + metadata = new_table_metadata( + schema=_SPLIT_SCHEMA, + partition_spec=UNPARTITIONED_PARTITION_SPEC, + sort_order=UNSORTED_SORT_ORDER, + location=str(tmp_path), + ) + scan_context = TableScanContext( + table_metadata=metadata, + io=load_file_io(properties={}, location=file_path), + projected_schema=_SPLIT_SCHEMA, + table_id=_SPLIT_TABLE_ID, + 
metric_attributes=metric_attributes or {}, + ) + data_file = DataFile.from_args( + file_path=file_path, + file_format=FileFormat.PARQUET, + record_count=table.num_rows, + file_size_in_bytes=os.path.getsize(file_path), + ) + data_file._spec_id = 0 + task = FileScanTask(data_file=data_file) + return DataLoaderSplit(file_scan_tasks=[task], scan_context=scan_context, transform_sql=transform_sql) + + +def test_split_emits_per_split_and_per_batch_metrics(tmp_path, metrics_reader): + expected_attrs = {**_BASE_ATTRS, "Tenant": "t1"} + split = _make_split(tmp_path, metric_attributes=expected_attrs) + batches = list(split) + assert sum(b.num_rows for b in batches) == 3 + + split_duration = _data_points(metrics_reader, "openhouse.dataloader.splittime") + assert len(split_duration) == 1 + assert _attrs(split_duration[0]) == expected_attrs + + split_files = _data_points(metrics_reader, "openhouse.dataloader.splitfiles") + assert len(split_files) == 1 + assert split_files[0].sum == 1 + + split_rows = _data_points(metrics_reader, "openhouse.dataloader.splitrows") + assert len(split_rows) == 1 + assert split_rows[0].sum == 3 + + split_bytes = _data_points(metrics_reader, "openhouse.dataloader.splitbytes") + assert len(split_bytes) == 1 + assert split_bytes[0].sum > 0 + + split_batches = _data_points(metrics_reader, "openhouse.dataloader.splitbatches") + assert len(split_batches) == 1 + assert split_batches[0].sum >= 1 + + batch_duration = _data_points(metrics_reader, "openhouse.dataloader.batchtime") + assert len(batch_duration) == 1 + assert _attrs(batch_duration[0]) == expected_attrs + + batch_rows = _data_points(metrics_reader, "openhouse.dataloader.batchrows") + assert len(batch_rows) == 1 + assert batch_rows[0].sum == 3 + + +def test_batch_read_failure_bumps_error_counters(tmp_path, monkeypatch, metrics_reader): + split = _make_split(tmp_path) + + class _ReaderError(Exception): + pass + + def _fake_to_record_batches(self, scan_tasks, **kwargs): + def _gen(): + raise 
_ReaderError("boom") + yield # pragma: no cover -- makes this a generator + + return _gen() + + monkeypatch.setattr( + "openhouse.dataloader.data_loader_split.ArrowScan.to_record_batches", + _fake_to_record_batches, + ) + + with pytest.raises(_ReaderError): + list(split) + + batch_errors = _data_points(metrics_reader, "openhouse.dataloader.batcherrors") + assert len(batch_errors) == 1 + assert batch_errors[0].value == 1 + + split_errors = _data_points(metrics_reader, "openhouse.dataloader.spliterrors") + assert len(split_errors) == 1 + assert split_errors[0].value == 1 + + # split.duration is still recorded on failure + split_duration = _data_points(metrics_reader, "openhouse.dataloader.splittime") + assert len(split_duration) == 1 + + +def test_split_with_transform_emits_transform_time(tmp_path, metrics_reader): + expected_attrs = {**_BASE_ATTRS, "Tenant": "t1"} + split = _make_split( + tmp_path, + metric_attributes=expected_attrs, + transform_sql='SELECT id FROM "db"."tbl"', + ) + list(split) + + transform_times = _data_points(metrics_reader, "openhouse.dataloader.transformtime") + assert len(transform_times) == 1 + assert _attrs(transform_times[0]) == expected_attrs + assert transform_times[0].sum > 0 + + +def test_split_without_transform_does_not_emit_transform_time(tmp_path, metrics_reader): + split = _make_split(tmp_path) + list(split) + + assert _data_points(metrics_reader, "openhouse.dataloader.transformtime") == [] + + +# --- TableScanContext.metric_attributes --- + + +def test_table_scan_context_default_metric_attributes_is_empty(tmp_path): + split = _make_split(tmp_path) + assert dict(split._scan_context.metric_attributes) == {} + + +def test_table_scan_context_pickle_preserves_metric_attributes(tmp_path): + split = _make_split(tmp_path, metric_attributes={"Tenant": "t1"}) + restored = pickle.loads(pickle.dumps(split._scan_context)) + assert dict(restored.metric_attributes) == {"Tenant": "t1"}