diff --git a/airbyte/_local_sync_progress.py b/airbyte/_local_sync_progress.py new file mode 100644 index 000000000..3a79e7d9d --- /dev/null +++ b/airbyte/_local_sync_progress.py @@ -0,0 +1,258 @@ +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. +"""Local-sync progress tracking via direct observation of state messages. + +When PyAirbyte runs a sync locally (`Source.read()`, or `Source` -> `Destination` +via `tally_pending_writes` / `tally_confirmed_writes`), it acts as the in-process +intermediary that buffers every `AirbyteMessage` on its way from the source to the +cache/destination. This gives PyAirbyte the ability to directly observe both: + +- **Source-side cursors**: state messages emitted by the source as it advances + through its records (tracked in `ProgressTracker.tally_records_read`). +- **Destination-committed cursors**: state messages acknowledged by the + destination after batches are committed (tracked in + `ProgressTracker.tally_confirmed_writes`). + +This module provides small helpers used by `airbyte.progress.ProgressTracker` +to extract cursor values from state messages, compute a simple progress +percentage for datetime cursors, and serialize per-stream progress snapshots +for JSONL audit logging. + +This is intentionally distinct from `airbyte.cloud._sync_progress`, which +reconstructs progress from snapshots returned by the Airbyte Platform Config +API. The local-sync path has direct access to every state message and does +not require an external API call. +""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from airbyte_cdk.utils.datetime_helpers import ab_datetime_try_parse + + +if TYPE_CHECKING: + from airbyte_protocol.models import AirbyteStateMessage + + +# Field names commonly used as datetime cursors by Airbyte connectors. +# Checked in order when multiple keys exist in `stream_state`. +_COMMON_CURSOR_FIELDS: tuple[str, ...] = ( + "updatedAt", + "updated_at", + "createdAt", + "created_at", + "timestamp", + "cursor", + "date", + "modified", + "modified_at", + "last_modified", + "lastModified", +) + + +def _try_parse_datetime_cursor(value: str) -> datetime | None: + """Attempt to parse a string as a datetime. + + Delegates to the CDK's `ab_datetime_try_parse` and rejects pure numeric + strings (which the CDK parser would otherwise interpret as epoch timestamps). + Returns `None` when the value cannot be parsed as a datetime. + """ + stripped = value.strip() + if not stripped: + return None + + try: + float(stripped) + except ValueError: + pass + else: + return None + + dt = ab_datetime_try_parse(stripped) + if dt is None: + return None + + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt + + +def _normalize_stream_state(stream_state: object) -> dict[str, Any] | None: + """Normalize a `stream_state` value to a plain `dict`. + + The Airbyte protocol models typically represent `stream_state` as an + `AirbyteStateBlob` (a Pydantic model with `extra="allow"`) rather than a + raw dict. This helper coerces either form into a dict so downstream + cursor extraction can treat them uniformly. Returns `None` when the + value cannot be represented as a dict. 
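+
+    A minimal sketch of the normalization (plain-dict and `None` inputs only;
+    Pydantic state blobs go through `model_dump()`):
+
+        >>> _normalize_stream_state({"updated_at": "2024-05-01T00:00:00Z"})
+        {'updated_at': '2024-05-01T00:00:00Z'}
+        >>> _normalize_stream_state(None) is None
+        True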
+ """ + if stream_state is None: + return None + if isinstance(stream_state, dict): + return stream_state + if hasattr(stream_state, "model_dump"): + dumped = stream_state.model_dump() + if isinstance(dumped, dict): + return dumped + return None + + +_MAX_STATE_RECURSION_DEPTH = 4 + + +def _find_known_cursor( + state_dict: dict[str, Any], + depth: int, +) -> tuple[str | None, str | None]: + """Depth-first search for a `_COMMON_CURSOR_FIELDS` key in nested dicts. + + Some connectors (notably `source-github`) nest per-stream state under a + partition key (e.g. the repo name), so the cursor lives at + `stream_state[][updated_at]` rather than `stream_state[updated_at]`. + This helper walks up to `_MAX_STATE_RECURSION_DEPTH` levels, checking + known cursor field names at each level before descending. Non-`dict` and + `None` values are skipped. + """ + for candidate in _COMMON_CURSOR_FIELDS: + if candidate in state_dict: + raw = state_dict[candidate] + if raw is not None and not isinstance(raw, dict): + return candidate, str(raw) + + if depth >= _MAX_STATE_RECURSION_DEPTH: + return None, None + + for raw in state_dict.values(): + nested = _normalize_stream_state(raw) + if not nested: + continue + cursor_field, cursor_value = _find_known_cursor(nested, depth + 1) + if cursor_field is not None: + return cursor_field, cursor_value + + return None, None + + +def _find_datetime_fallback( + state_dict: dict[str, Any], + depth: int, +) -> tuple[str | None, str | None]: + """Depth-first search for the first datetime-parseable scalar value. + + Used only when no well-known cursor field name (`_COMMON_CURSOR_FIELDS`) + is present anywhere in the state tree. Recurses up to + `_MAX_STATE_RECURSION_DEPTH` levels into nested dicts. + """ + for key, raw in state_dict.items(): + if raw is None: + continue + if isinstance(raw, (str, int, float)): + value = str(raw) + if _try_parse_datetime_cursor(value) is not None: + return key, value + + if depth >= _MAX_STATE_RECURSION_DEPTH: + return None, None + + for raw in state_dict.values(): + nested = _normalize_stream_state(raw) + if not nested: + continue + cursor_field, cursor_value = _find_datetime_fallback(nested, depth + 1) + if cursor_field is not None: + return cursor_field, cursor_value + + return None, None + + +def _extract_cursor_from_stream_state( + stream_state: object, +) -> tuple[str | None, str | None]: + """Return `(cursor_field, cursor_value)` from a `stream_state` blob. + + Recursively searches nested state dicts so per-partition state (e.g. + `source-github`'s `{"": {"updated_at": "..."}}`) is handled. + + The search prefers well-known cursor field names (`updatedAt`, + `created_at`, …) at any depth up to `_MAX_STATE_RECURSION_DEPTH`. If + none of those are present anywhere in the tree, falls back to the first + datetime-parseable scalar. Returns `(None, None)` when no usable cursor + can be extracted. + """ + state_dict = _normalize_stream_state(stream_state) + if not state_dict: + return None, None + + cursor_field, cursor_value = _find_known_cursor(state_dict, depth=0) + if cursor_field is not None: + return cursor_field, cursor_value + + return _find_datetime_fallback(state_dict, depth=0) + + +def extract_cursor_from_state_message( + state_message: AirbyteStateMessage, +) -> tuple[str | None, str | None, str | None]: + """Return `(stream_name, cursor_field, cursor_value)` from a state message. + + Handles `STREAM`-type state messages. 
`GLOBAL` and `LEGACY` state + messages are not per-stream and return `(None, None, None)` -- callers + should fall back to other tracking strategies for those. + """ + stream = getattr(state_message, "stream", None) + if stream is None: + return None, None, None + + descriptor = getattr(stream, "stream_descriptor", None) + if descriptor is None: + return None, None, None + + stream_name = getattr(descriptor, "name", None) + if not stream_name: + return None, None, None + + cursor_field, cursor_value = _extract_cursor_from_stream_state( + getattr(stream, "stream_state", None) + ) + return stream_name, cursor_field, cursor_value + + +def compute_stream_progress_pct( + *, + baseline_cursor: str | None, + latest_cursor: str | None, + now: datetime | None = None, +) -> float | None: + """Compute a progress percentage for a single stream's datetime cursor. + + The formula is: + + progress = (latest - baseline) / (now - baseline) + + Returns a value clamped to `[0.0, 1.0]`, or `None` when either cursor is + missing, cannot be parsed as a datetime, or when the denominator is not + positive (e.g. a historical backfill where `now` equals the baseline). + """ + if baseline_cursor is None or latest_cursor is None: + return None + + baseline_dt = _try_parse_datetime_cursor(baseline_cursor) + latest_dt = _try_parse_datetime_cursor(latest_cursor) + if baseline_dt is None or latest_dt is None: + return None + + now_dt = now or datetime.now(timezone.utc) + if now_dt.tzinfo is None: + now_dt = now_dt.replace(tzinfo=timezone.utc) + + denominator = (now_dt - baseline_dt).total_seconds() + if denominator <= 0: + return None + + numerator = (latest_dt - baseline_dt).total_seconds() + if numerator < 0: + return 0.0 + + return round(max(0.0, min(1.0, numerator / denominator)), 4) diff --git a/airbyte/_util/api_util.py b/airbyte/_util/api_util.py index 650de345a..c45b3ac36 100644 --- a/airbyte/_util/api_util.py +++ b/airbyte/_util/api_util.py @@ -47,7 +47,7 @@ ) -JOB_WAIT_INTERVAL_SECS = 2.0 +JOB_WAIT_INTERVAL_SECS = 5.0 JOB_WAIT_TIMEOUT_SECS_DEFAULT = 60 * 60 # 1 hour # Job ordering constants for list_jobs API @@ -2104,6 +2104,51 @@ def get_workspace_organization_info( ) +def get_job_debug_info( + job_id: int, + *, + api_root: str, + client_id: SecretString | None, + client_secret: SecretString | None, + bearer_token: SecretString | None, +) -> dict[str, Any]: + """Get debug info for a job, including per-stream records/bytes stats. + + Uses the Config API endpoint: `POST /v1/jobs/get_debug_info`. + + The `streamStats` entries returned for the latest attempt update in + real-time during a running sync, making this call a useful + "proof-of-life" progress signal even when cursor-based progress + cannot be computed (e.g. because the state API returns frozen + cursors mid-sync). + + The response shape looks roughly like: + + ``` + { + "job": {...}, + "attempts": [ + {"attempt": {"streamStats": [ + {"streamName": "contacts", + "stats": {"recordsEmitted": N, "bytesEmitted": N, ...}} + ]}} + ] + } + ``` + + Returns: + The decoded JSON payload from the Config API. 
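+
+    Illustrative usage (the variables holding the API root and credentials are
+    placeholders, not part of this module):
+
+        info = get_job_debug_info(
+            job_id=12345,
+            api_root=api_root,
+            client_id=client_id,
+            client_secret=client_secret,
+            bearer_token=None,
+        )
+        latest_stream_stats = info["attempts"][-1]["attempt"]["streamStats"]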
+ """ + return _make_config_api_request( + path="/jobs/get_debug_info", + json={"id": job_id}, + api_root=api_root, + client_id=client_id, + client_secret=client_secret, + bearer_token=bearer_token, + ) + + def get_connection_state( connection_id: str, *, diff --git a/airbyte/cloud/_sync_progress.py b/airbyte/cloud/_sync_progress.py new file mode 100644 index 000000000..41ed7d702 --- /dev/null +++ b/airbyte/cloud/_sync_progress.py @@ -0,0 +1,398 @@ +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. +"""Sync progress estimation for datetime-cursor-based incremental streams. + +This module provides functions to estimate per-stream sync progress by +comparing the current cursor value against a known previous bookmark +and the current time (`now`), which serves as the estimated sync target. + +Formula (when previous bookmark is available): + + progress = (cursor_dt - previous_bookmark_dt) / (now - previous_bookmark_dt) + +Where: + +- `previous_bookmark_dt` is the cursor value from the last completed sync. +- `cursor_dt` is the latest committed cursor value parsed as a datetime. +- `now` is the current UTC time (estimated completion target). + +Because the Airbyte state API returns only the current (advancing) +cursor, the previous bookmark is not directly available from a single +snapshot. Callers should supply `previous_state_data` (the state +from the previous completed sync) when available. When it is not +supplied, the module falls back to using `sync_start_time` as the +range anchor, which works for real-time incremental syncs but yields +`progress_pct = None` for historical back-fills where the cursor is +behind `sync_start_time`. + +Each per-stream result always includes the raw factors that went into +the estimate so callers can compute their own progress or track it +across multiple calls: + +- `cursor_value` / `cursor_datetime` -- the current cursor position +- `previous_cursor_value` / `previous_cursor_datetime` -- the + baseline, if known +- `target_datetime` -- the estimated target (`now`) +- `sync_start_time` -- the wall-clock job start + +Only streams with datetime-based cursors are supported. Non-datetime +cursors (integers, opaque tokens, etc.) are skipped. +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import Any + +from airbyte_cdk.utils.datetime_helpers import ab_datetime_try_parse + +from airbyte.cloud._connection_state import ( + ConnectionStateResponse, + StreamState, + _get_stream_list, +) + + +logger = logging.getLogger(__name__) + + +def _try_parse_datetime_cursor(value: str) -> datetime | None: + """Attempt to parse a string as a datetime. + + Delegates to the CDK's `ab_datetime_try_parse` for the actual parsing. + Returns `None` if the value is numeric or cannot be parsed. + """ + stripped = value.strip() + if not stripped: + return None + + # Reject pure numeric strings — the CDK parser interprets them as + # epoch timestamps, but cursor values like "12345" are opaque tokens. + try: + float(stripped) + except ValueError: + pass + else: + return None + + dt = ab_datetime_try_parse(stripped) + if dt is None: + return None + + # Ensure timezone-aware (assume UTC if naive) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt + + +def _extract_cursor_field_from_catalog( + catalog: dict[str, Any], + stream_name: str, + stream_namespace: str | None, +) -> str | None: + """Extract the cursor field name for a stream from the configured catalog. 
+ + Returns the cursor field name, or `None` if the stream is not found + or does not have a cursor field configured. + """ + # Handle both raw catalog ({"streams": [...]}) and full connection + # response ({"syncCatalog": {"streams": [...]}}) from the Config API. + streams = catalog.get("streams", []) + if not streams and "syncCatalog" in catalog: + streams = catalog["syncCatalog"].get("streams", []) + + for stream_entry in streams: + config = stream_entry.get("config", {}) + stream_info = stream_entry.get("stream", {}) + + entry_name = stream_info.get("name") or config.get("aliasName") + entry_namespace = stream_info.get("namespace") + + # Normalize namespace comparison (None and "" are equivalent) + if entry_name != stream_name: + continue + if (entry_namespace or None) != (stream_namespace or None): + continue + + # cursor_field may be a list of path segments or a plain string + cursor_field = config.get("cursorField") + if isinstance(cursor_field, str) and cursor_field: + return cursor_field + if isinstance(cursor_field, list) and cursor_field: + return ".".join(str(segment) for segment in cursor_field) + + return None + + +def _find_cursor_value_in_state( + stream_state: dict[str, Any] | None, + cursor_field: str | None, +) -> str | None: + """Find a cursor value in a stream state blob. + + Requires `cursor_field` (from the configured catalog) to locate the + cursor. Returns `None` when the cursor field is unknown or absent. + """ + if not stream_state: + return None + + if not cursor_field: + return None + + # Traverse dot-delimited paths (e.g. "metadata.updated_at") + current: Any = stream_state + for segment in cursor_field.split("."): + if isinstance(current, dict) and segment in current: + current = current[segment] + else: + return None + + if current is not None: + return str(current) + + return None + + +def _build_previous_cursor_map( + previous_state_data: dict[str, Any], + catalog_data: dict[str, Any] | None, +) -> dict[tuple[str, str | None], str | None]: + """Build a map of `(stream_name, namespace)` to previous cursor value. + + Parses the previous state snapshot and extracts cursor values for + each stream, returning a lookup dict. + """ + result: dict[tuple[str, str | None], str | None] = {} + if not isinstance(previous_state_data, dict): + return result + + prev_state = ConnectionStateResponse(**previous_state_data) + prev_streams: list[StreamState] = _get_stream_list(prev_state) + + for stream in prev_streams: + s_name = stream.stream_descriptor.name + s_ns = stream.stream_descriptor.namespace + + cursor_field: str | None = None + if catalog_data: + cursor_field = _extract_cursor_field_from_catalog(catalog_data, s_name, s_ns) + + cursor_val = _find_cursor_value_in_state(stream.stream_state, cursor_field) + result[s_name, s_ns] = cursor_val + + return result + + +def _compute_single_stream_progress( + stream: StreamState, + catalog_data: dict[str, Any] | None, + sync_start_time: datetime, + now: datetime, + prev_cursor_map: dict[tuple[str, str | None], str | None], + first_seen_cursors: dict[tuple[str, str | None], str] | None = None, +) -> dict[str, Any]: + """Compute progress for a single stream. + + Returns a dict containing raw factors and the computed `progress_pct`. 
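+
+    An illustrative (made-up) result entry, abbreviated:
+
+        {
+            "stream_name": "contacts",
+            "cursor_field": "updated_at",
+            "cursor_value": "2024-06-01T00:00:00Z",
+            "previous_cursor_value": "2024-01-01T00:00:00Z",
+            "progress_pct": 0.41,
+            "reason": None,
+            ...
+        }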
+ """ + stream_name = stream.stream_descriptor.name + stream_namespace = stream.stream_descriptor.namespace + + # Look up cursor field from catalog + cursor_field: str | None = None + if catalog_data: + cursor_field = _extract_cursor_field_from_catalog( + catalog_data, stream_name, stream_namespace + ) + + # Find cursor value in current (advancing) state + cursor_value_str = _find_cursor_value_in_state(stream.stream_state, cursor_field) + + # Find previous cursor value + prev_cursor_str: str | None = prev_cursor_map.get((stream_name, stream_namespace)) + prev_cursor_dt: datetime | None = None + if prev_cursor_str: + prev_cursor_dt = _try_parse_datetime_cursor(prev_cursor_str) + + entry: dict[str, Any] = { + "stream_name": stream_name, + "stream_namespace": stream_namespace, + "cursor_field": cursor_field, + "cursor_value": cursor_value_str, + "cursor_datetime": None, + "previous_cursor_value": prev_cursor_str, + "previous_cursor_datetime": prev_cursor_dt.isoformat() if prev_cursor_dt else None, + "sync_start_time": sync_start_time.isoformat(), + "target_datetime": now.isoformat(), + "progress_pct": None, + "reason": None, + } + + if cursor_value_str is None: + entry["reason"] = "No cursor value found in state." + return entry + + cursor_dt = _try_parse_datetime_cursor(cursor_value_str) + if cursor_dt is None: + entry["reason"] = f"Cursor value '{cursor_value_str}' is not a recognized datetime format." + return entry + + entry["cursor_datetime"] = cursor_dt.isoformat() + + # Resolve the first-observed cursor for this stream (tier 2 baseline). + first_seen_dt: datetime | None = None + if first_seen_cursors: + first_seen_val = first_seen_cursors.get((stream_name, stream_namespace)) + if first_seen_val is not None: + first_seen_dt = _try_parse_datetime_cursor(first_seen_val) + + _compute_progress_pct( + entry=entry, + cursor_dt=cursor_dt, + prev_cursor_dt=prev_cursor_dt, + first_seen_dt=first_seen_dt, + sync_start_time=sync_start_time, + now=now, + ) + return entry + + +def _compute_progress_pct( + *, + entry: dict[str, Any], + cursor_dt: datetime, + prev_cursor_dt: datetime | None, + first_seen_dt: datetime | None, + sync_start_time: datetime, + now: datetime, +) -> None: + """Populate `progress_pct` and `reason` on `entry` in-place. + + Handles two modes: + + 1. *Historical backfill* — the cursor is at or behind the previous + bookmark (e.g. GA4 re-processing from an earlier date). Progress is + measured as `(cursor - first_seen) / (prev_bookmark - first_seen)`. + + 2. *Standard forward progress* — the cursor advances beyond the + baseline toward `now`. Progress is + `(cursor - range_start) / (now - range_start)`. + """ + # --- Historical backfill path --- + if ( + prev_cursor_dt is not None + and cursor_dt <= prev_cursor_dt + and first_seen_dt is not None + and first_seen_dt < prev_cursor_dt + ): + denominator = (prev_cursor_dt - first_seen_dt).total_seconds() + if denominator <= 0: + entry["reason"] = "Range anchor is not before target time." + return + numerator = (cursor_dt - first_seen_dt).total_seconds() + entry["progress_pct"] = round(max(0.0, min(1.0, numerator / denominator)), 4) + if entry["progress_pct"] == 0.0: + entry["reason"] = "Cursor has not yet advanced past the first observed value." 
+ return + + # --- Standard forward-progress path --- + range_start: datetime | None = prev_cursor_dt + if range_start is None and first_seen_dt is not None: + range_start = first_seen_dt + if range_start is None: + range_start = sync_start_time + + denominator = (now - range_start).total_seconds() + if denominator <= 0: + entry["reason"] = "Range anchor is not before target time." + return + + numerator = (cursor_dt - range_start).total_seconds() + if numerator < 0: + _set_negative_progress_reason(entry, prev_cursor_dt) + return + + entry["progress_pct"] = round(max(0.0, min(1.0, numerator / denominator)), 4) + + +def _set_negative_progress_reason( + entry: dict[str, Any], + prev_cursor_dt: datetime | None, +) -> None: + """Set progress and reason when cursor is behind the range anchor.""" + if prev_cursor_dt is not None: + entry["progress_pct"] = 0.0 + entry["reason"] = "Cursor has not advanced past the previous bookmark." + else: + entry["progress_pct"] = None + entry["reason"] = ( + "Cursor is behind sync start time (historical back-fill). " + "Previous bookmark not available; progress indeterminate. " + "Use the raw factor fields to compute progress across " + "multiple calls." + ) + + +def compute_stream_progress( + state_data: dict[str, Any], + catalog_data: dict[str, Any] | None, + sync_start_time: datetime, + now: datetime | None = None, + previous_state_data: dict[str, Any] | None = None, + first_seen_cursors: dict[tuple[str, str | None], str] | None = None, +) -> list[dict[str, Any]]: + """Compute per-stream sync progress for datetime-cursor-based streams. + + Returns a list of per-stream dicts, each containing the raw factors + that went into the progress estimate as well as the computed + `progress_pct` (or `None` when indeterminate). + + Progress formula (when previous bookmark is available): + + progress = (cursor_dt - previous_bookmark_dt) / (now - previous_bookmark_dt) + + Baseline selection uses a tiered fallback: + + 1. Previous completed sync's cursor from `previous_state_data`. + 2. First-observed cursor during polling (`first_seen_cursors`). + 3. `sync_start_time` (wall-clock job start). + + Tier 3 works for real-time incremental syncs where cursors advance + near wall-clock time, but yields `progress_pct = None` for + historical back-fills where the cursor is behind `sync_start_time`. 
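+
+    Worked example with hypothetical values: given a previous bookmark of
+    2024-01-01T00:00:00Z, a current cursor of 2024-01-08T00:00:00Z, and
+    `now` = 2024-01-11T00:00:00Z, the estimate is 7 days / 10 days = 0.7.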
+ """ + if now is None: + now = datetime.now(timezone.utc) + + # Ensure times are timezone-aware + if sync_start_time.tzinfo is None: + sync_start_time = sync_start_time.replace(tzinfo=timezone.utc) + if now.tzinfo is None: + now = now.replace(tzinfo=timezone.utc) + + # Build previous cursor lookup if previous state was provided + prev_cursor_map: dict[tuple[str, str | None], str | None] = {} + if previous_state_data: + prev_cursor_map = _build_previous_cursor_map(previous_state_data, catalog_data) + + if not isinstance(state_data, dict): + logger.warning( + "Expected dict for state_data, got %s; returning empty progress.", + type(state_data).__name__, + ) + return [] + + state = ConnectionStateResponse(**state_data) + streams: list[StreamState] = _get_stream_list(state) + + return [ + _compute_single_stream_progress( + stream=stream, + catalog_data=catalog_data, + sync_start_time=sync_start_time, + now=now, + prev_cursor_map=prev_cursor_map, + first_seen_cursors=first_seen_cursors, + ) + for stream in streams + ] diff --git a/airbyte/cloud/connections.py b/airbyte/cloud/connections.py index cfa3149a1..87376c2ed 100644 --- a/airbyte/cloud/connections.py +++ b/airbyte/cloud/connections.py @@ -15,7 +15,8 @@ _match_stream, ) from airbyte.cloud.connectors import CloudDestination, CloudSource -from airbyte.cloud.sync_results import SyncResult +from airbyte.cloud.constants import JobStatusEnum +from airbyte.cloud.sync_results import SyncResult, _extract_stream_stats from airbyte.exceptions import AirbyteWorkspaceMismatchError, PyAirbyteInputError @@ -262,8 +263,40 @@ def run_sync( *, wait: bool = True, wait_timeout: int = 300, + with_rich_status_updates: bool | int = False, + progress_log_path: str | None = None, ) -> SyncResult: - """Run a sync.""" + """Run a sync. + + When `with_rich_status_updates` is truthy, a Rich Live table + showing per-stream progress is displayed while waiting for + completion. Requires `wait=True`; passing `wait=False` with a + truthy `with_rich_status_updates` raises `ValueError`. + + When `progress_log_path` is set, each Rich polling iteration + appends a JSONL line to the given file with timestamped + per-stream progress data for auditing. + """ + if not wait and with_rich_status_updates: + raise ValueError( + "Cannot use `with_rich_status_updates` when `wait=False`. " + "Rich status updates require waiting for the sync to complete." + ) + + # Snapshot the current committed state *before* triggering the sync. + # This gives us the baseline (denominator) for progress calculation, + # since the state API only exposes the latest advancing cursor. + pre_sync_state: dict[str, Any] | None = None + pre_sync_stream_stats: dict[str, dict[str, int]] | None = None + if with_rich_status_updates: + pre_sync_state = self.dump_raw_state() + # Fetch the previous completed sync's per-stream records/bytes + # stats so we can use them as a rough denominator for the + # mid-sync records-based progress signal. Best-effort — a + # missing previous sync just means the Records % column will + # render as `--` during the new sync. 
+ pre_sync_stream_stats = self._fetch_latest_sync_stream_stats() + connection_response = api_util.run_connection( connection_id=self.connection_id, api_root=self.workspace.api_root, @@ -276,6 +309,8 @@ def run_sync( workspace=self.workspace, connection=self, job_id=connection_response.job_id, + _pre_sync_state=pre_sync_state, + _pre_sync_stream_stats=pre_sync_stream_stats, ) if wait: @@ -283,6 +318,8 @@ def run_sync( wait_timeout=wait_timeout, raise_failure=True, raise_timeout=True, + with_rich_status_updates=with_rich_status_updates, + progress_log_path=progress_log_path, ) return sync_result @@ -376,6 +413,77 @@ def get_sync_result( job_id=job_id, ) + def get_previous_sync_state( + self, + *, + current_job_id: int | None = None, + ) -> dict[str, Any] | None: + """Get the state from the most recent completed sync job. + + Fetches the previous completed (succeeded) sync job from job history + and extracts the final committed state from its last attempt's output. + This is useful for determining the baseline cursor values before + the current sync started advancing state. + + When `current_job_id` is provided, that job is skipped so the + returned state always comes from a *previous* job. + + Returns the state dict (same shape as `dump_raw_state()`), or `None` + if no previous completed sync is found or if state data is not + available in the job output. + """ + previous_jobs = self.get_previous_sync_logs(limit=5) + + for job in previous_jobs: + # Skip the current job if specified + if current_job_id is not None and job.job_id == current_job_id: + continue + + if job.get_job_status() != JobStatusEnum.SUCCEEDED: + continue + + # Fetch full job data including attempt output + job_data = job._fetch_job_with_attempts() # noqa: SLF001 + attempts = job_data.get("attempts", []) + if not attempts: + continue + + last_attempt = attempts[-1] + output = last_attempt.get("attempt", {}).get("output", {}) + state = output.get("state") + if isinstance(state, dict): + return state + + return None + + def _fetch_latest_sync_stream_stats(self) -> dict[str, dict[str, int]] | None: + """Fetch per-stream records/bytes stats from the most recent completed sync. + + Used as a rough denominator for the mid-sync records-based + progress signal (Records % column). Walks recent job history to + find the latest `SUCCEEDED` sync and pulls its `streamStats` via + the Config API's `jobs/get_debug_info` endpoint. + + Returns a mapping from stream name to a dict of integer stats + (`records_emitted`, `bytes_emitted`), or `None` if no suitable + previous sync is found. 
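+
+        Illustrative return value (stream name and counts are made up):
+
+            {"contacts": {"records_emitted": 125000, "bytes_emitted": 52428800}}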
+ """ + previous_jobs = self.get_previous_sync_logs(limit=5) + for job in previous_jobs: + if job.get_job_status() != JobStatusEnum.SUCCEEDED: + continue + debug_info = api_util.get_job_debug_info( + job_id=job.job_id, + api_root=self.workspace.api_root, + client_id=self.workspace.client_id, + client_secret=self.workspace.client_secret, + bearer_token=self.workspace.bearer_token, + ) + stats = _extract_stream_stats(debug_info) + if stats: + return stats + return None + # Artifacts @deprecated("Use 'dump_raw_state()' instead.") diff --git a/airbyte/cloud/sync_results.py b/airbyte/cloud/sync_results.py index 8920b9615..ae2230a4a 100644 --- a/airbyte/cloud/sync_results.py +++ b/airbyte/cloud/sync_results.py @@ -100,36 +100,314 @@ from __future__ import annotations +import json import time +import warnings from collections.abc import Iterator, Mapping from dataclasses import asdict, dataclass +from datetime import datetime, timezone +from pathlib import Path from typing import TYPE_CHECKING, Any +from pydantic import ValidationError +from rich.console import Console +from rich.live import Live as RichLive +from rich.table import Table from typing_extensions import final from airbyte_cdk.utils.datetime_helpers import ab_datetime_parse from airbyte._util import api_util from airbyte.caches._utils._dest_to_cache import destination_to_cache -from airbyte.cloud.constants import FAILED_STATUSES, FINAL_STATUSES +from airbyte.cloud._connection_state import ( + ConnectionStateResponse, + _get_stream_list, +) +from airbyte.cloud._sync_progress import ( + _extract_cursor_field_from_catalog, + _find_cursor_value_in_state, + compute_stream_progress, +) +from airbyte.cloud.constants import FAILED_STATUSES, FINAL_STATUSES, JobStatusEnum from airbyte.datasets import CachedDataset -from airbyte.exceptions import AirbyteConnectionSyncError, AirbyteConnectionSyncTimeoutError +from airbyte.exceptions import ( + AirbyteConnectionSyncError, + AirbyteConnectionSyncTimeoutError, + AirbyteError, +) DEFAULT_SYNC_TIMEOUT_SECONDS = 30 * 60 # 30 minutes """The default timeout for waiting for a sync job to complete, in seconds.""" -if TYPE_CHECKING: - from datetime import datetime +MIN_RICH_UPDATE_INTERVAL_SECS = 15 +"""Minimum polling interval when Rich status updates are enabled.""" +DEFAULT_RICH_UPDATE_INTERVAL_SECS = 15 +"""Default polling interval when `with_rich_status_updates=True`.""" + +if TYPE_CHECKING: import sqlalchemy - from airbyte._util.api_imports import ConnectionResponse, JobResponse, JobStatusEnum + from airbyte._util.api_imports import ConnectionResponse, JobResponse from airbyte.caches.base import CacheBase from airbyte.cloud.connections import CloudConnection from airbyte.cloud.workspaces import CloudWorkspace +def _resolve_rich_interval(*, with_rich_status_updates: bool | int) -> float: + """Normalize `with_rich_status_updates` to a polling interval in seconds. + + `True` maps to `DEFAULT_RICH_UPDATE_INTERVAL_SECS`. An `int` is + clamped to `MIN_RICH_UPDATE_INTERVAL_SECS` with a warning when the + caller-provided value is too low. + """ + if with_rich_status_updates is True: + return float(DEFAULT_RICH_UPDATE_INTERVAL_SECS) + + interval = int(with_rich_status_updates) + if interval < MIN_RICH_UPDATE_INTERVAL_SECS: + warnings.warn( + f"Rich status update interval {interval}s is below the minimum " + f"of {MIN_RICH_UPDATE_INTERVAL_SECS}s. 
Using {MIN_RICH_UPDATE_INTERVAL_SECS}s.", + UserWarning, + stacklevel=3, + ) + return float(MIN_RICH_UPDATE_INTERVAL_SECS) + + return float(interval) + + +def _format_bytes(num_bytes: int) -> str: + """Format a byte count as a human-readable string (e.g. ``1.2 MB``).""" + if num_bytes < 1_000: # noqa: PLR2004 # Byte thresholds are self-documenting + return f"{num_bytes} B" + if num_bytes < 1_000_000: # noqa: PLR2004 + return f"{num_bytes / 1_000:,.1f} KB" + if num_bytes < 1_000_000_000: # noqa: PLR2004 + return f"{num_bytes / 1_000_000:,.1f} MB" + return f"{num_bytes / 1_000_000_000:,.2f} GB" + + +def _extract_stream_stats(debug_info: dict[str, Any]) -> dict[str, dict[str, int]]: + """Extract per-stream `recordsEmitted` / `bytesEmitted` from a debug-info payload. + + The Config API's `jobs/get_debug_info` endpoint returns an `attempts` + list; the latest attempt's `streamStats` updates in real-time during + a running sync and is our primary mid-sync proof-of-life signal. + + Returns a mapping from stream name to a dict of integer stats + (`records_emitted`, `bytes_emitted`). Missing or malformed payloads + yield an empty dict. + """ + attempts = debug_info.get("attempts") or [] + if not attempts: + return {} + + latest_attempt = attempts[-1] + attempt_inner = latest_attempt.get("attempt") or latest_attempt + stream_stats_list = attempt_inner.get("streamStats") or [] + + result: dict[str, dict[str, int]] = {} + for entry in stream_stats_list: + stream_name = entry.get("streamName") or entry.get("stream_name") + if not stream_name: + continue + stats = entry.get("stats") or {} + result[stream_name] = { + "records_emitted": int(stats.get("recordsEmitted") or 0), + "bytes_emitted": int(stats.get("bytesEmitted") or 0), + } + return result + + +def _append_progress_log_entry( + log_path: Path, + elapsed_secs: float, + job_status: str, + records_synced: int, + bytes_synced: int, + stream_progress: list[dict[str, Any]], + current_stream_stats: dict[str, dict[str, int]], + previous_stream_stats: dict[str, dict[str, int]], +) -> None: + """Append a single JSONL entry describing per-stream sync progress. + + Factored into a module-level helper so the write happens in a scope + that does not reference any credential attributes (`client_secret`, + `bearer_token`). The values passed in are all non-sensitive progress + metrics: stream names, cursor positions, record counts, and timestamps. 
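+
+    An illustrative (abbreviated) JSONL line:
+
+        {"timestamp": "2024-06-01T12:00:00+00:00", "elapsed_secs": 45.0,
+         "job_status": "running", "records_synced": 1200, "bytes_synced": 84210,
+         "streams": [{"stream_name": "contacts", "progress_pct": 0.41,
+                      "records_emitted": 1200, ...}]}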
+ """ + log_streams: list[dict[str, Any]] = [] + progress_names: dict[str, dict[str, Any]] = { + str(e.get("stream_name", "")): e for e in stream_progress + } + all_names: list[str] = list(progress_names.keys()) + all_names.extend(name for name in current_stream_stats if name not in progress_names) + + for name in all_names: + entry = progress_names.get(name, {"stream_name": name}) + stats = current_stream_stats.get(name, {}) + prev_stats = previous_stream_stats.get(name, {}) + prev_recs = prev_stats.get("records_emitted", 0) + recs = stats.get("records_emitted", 0) + records_progress = min(recs / prev_recs, 1.0) if prev_recs > 0 else None + log_streams.append( + { + "stream_name": entry.get("stream_name", name), + "progress_pct": entry.get("progress_pct"), + "cursor_value": entry.get("cursor_value"), + "previous_cursor_value": entry.get("previous_cursor_value"), + "reason": entry.get("reason"), + "records_emitted": recs, + "bytes_emitted": stats.get("bytes_emitted", 0), + "records_progress": records_progress, + "previous_records_emitted": prev_recs, + } + ) + + log_entry = { + "timestamp": datetime.now(timezone.utc).isoformat(), + "elapsed_secs": round(elapsed_secs, 1), + "job_status": job_status, + "records_synced": records_synced, + "bytes_synced": bytes_synced, + "streams": log_streams, + } + with log_path.open("a") as f: + f.write(json.dumps(log_entry) + "\n") + + +def _build_rich_table( # noqa: PLR0913, PLR0914, PLR0915, PLR0917 + stream_progress: list[dict[str, Any]], + job_status: str, + elapsed_secs: float, + sync_start_time: datetime | None = None, + total_selected_streams: int | None = None, + records_synced: int = 0, + bytes_synced: int = 0, + stream_stats: dict[str, dict[str, int]] | None = None, + previous_stream_stats: dict[str, dict[str, int]] | None = None, +) -> Table: + """Build a Rich `Table` showing per-stream sync progress. + + `stream_stats` is the current mid-sync per-stream `recordsEmitted` / + `bytesEmitted` map (from the Config API's `jobs/get_debug_info` + endpoint). `previous_stream_stats` is the same map from the previous + completed sync, used as a rough denominator for the records-based + progress signal. + """ + elapsed_str = _format_elapsed(elapsed_secs) + stream_stats = stream_stats or {} + previous_stream_stats = previous_stream_stats or {} + + streams_with_pct = sum(1 for s in stream_progress if s.get("progress_pct") is not None) + # Use the catalog stream count as the denominator when available; + # fall back to the number of streams currently reporting state. 
+ total_streams = total_selected_streams or len(stream_progress) + + # Build throughput string (records/sec, bytes/sec) like Source.read() + throughput_parts: list[str] = [] + if records_synced: + throughput_parts.append(f"{records_synced:,} records") + if bytes_synced: + throughput_parts.append(_format_bytes(bytes_synced)) + if records_synced and elapsed_secs > 0: + rps = records_synced / elapsed_secs + throughput_parts.append(f"{rps:,.1f} records/s") + + title = ( + f"Sync Progress | Status: {job_status} | " + f"Elapsed: {elapsed_str} | " + f"Streams: {streams_with_pct}/{total_streams} reporting progress" + ) + if throughput_parts: + title += f" | {', '.join(throughput_parts)}" + + # Build a caption with start / end / elapsed timestamps + caption_parts: list[str] = [] + if sync_start_time is not None: + caption_parts.append(f"Start: {sync_start_time:%Y-%m-%d %H:%M:%S} UTC") + end_time = datetime.now(timezone.utc) + caption_parts.extend( + [ + f"Current: {end_time:%Y-%m-%d %H:%M:%S} UTC", + f"Elapsed: {elapsed_str}", + ] + ) + caption = " | ".join(caption_parts) + + table = Table(title=title, caption=caption, show_lines=False, expand=True) + table.add_column("Stream", style="cyan", no_wrap=True) + table.add_column("Records", justify="right", style="magenta") + table.add_column("Bytes", justify="right", style="magenta") + table.add_column("Records %", justify="right", style="yellow") + table.add_column("Cursor %", justify="right", style="green") + table.add_column("Stream Status", style="bold") + table.add_column("Reason", style="dim") + + # Union of stream names from cursor-based progress and records-based stats + # so proof-of-life shows up even when cursor progress can't compute a pct. + progress_by_name: dict[str, dict[str, Any]] = { + e.get("stream_name", "?"): e for e in stream_progress + } + all_names: list[str] = list(progress_by_name.keys()) + all_names.extend(name for name in stream_stats if name not in progress_by_name) + + for name in all_names: + entry = progress_by_name.get(name, {}) + pct = entry.get("progress_pct") + cursor_pct_str = f"{pct:.1%}" if pct is not None else "--" + reason = entry.get("reason") or "" + + stats = stream_stats.get(name, {}) + recs = stats.get("records_emitted", 0) + byts = stats.get("bytes_emitted", 0) + recs_str = f"{recs:,}" if recs else "0" + bytes_str = _format_bytes(byts) if byts else "0 B" + + prev_stats = previous_stream_stats.get(name, {}) + prev_recs = prev_stats.get("records_emitted", 0) + if prev_recs > 0: + records_pct = min(recs / prev_recs, 1.0) + records_pct_str = f"{records_pct:.1%}" + else: + records_pct_str = "--" + + if recs > 0: + stream_status = ( + "running" + if job_status.lower() not in {"succeeded", "failed"} + else ("complete" if job_status.lower() == "succeeded" else "failed") + ) + else: + stream_status = "pending" + + table.add_row( + name, + recs_str, + bytes_str, + records_pct_str, + cursor_pct_str, + stream_status, + reason, + ) + + return table + + +def _format_elapsed(seconds: float) -> str: + """Format elapsed seconds as `HH:MM:SS`.""" + total = int(seconds) + hours, remainder = divmod(total, 3600) + minutes, secs = divmod(remainder, 60) + if hours: + return f"{hours}h {minutes:02d}m {secs:02d}s" + if minutes: + return f"{minutes}m {secs:02d}s" + return f"{secs}s" + + @dataclass class SyncAttempt: """Represents a single attempt of a sync job. 
@@ -214,6 +492,41 @@ def get_full_log_text(self) -> str: return result +def _update_first_seen_cursors( + *, + first_seen_cursors: dict[tuple[str, str | None], str], + state_data: dict[str, Any], + catalog_data: dict[str, Any] | None, +) -> None: + """Record the first observed cursor value for each stream. + + On the first poll iteration where a stream appears in the state, + its cursor value is captured. This provides a fallback baseline + for first-ever syncs where no previous completed state exists. + """ + try: + state = ConnectionStateResponse(**state_data) + streams = _get_stream_list(state) + except (ValidationError, TypeError, KeyError): + return + + for stream in streams: + key = (stream.stream_descriptor.name, stream.stream_descriptor.namespace) + if key in first_seen_cursors: + continue # already recorded + + cursor_field: str | None = None + if catalog_data: + cursor_field = _extract_cursor_field_from_catalog( + catalog_data, + stream.stream_descriptor.name, + stream.stream_descriptor.namespace, + ) + cursor_val = _find_cursor_value_in_state(stream.stream_state, cursor_field) + if cursor_val is not None: + first_seen_cursors[key] = cursor_val + + @dataclass class SyncResult: """The result of a sync operation. @@ -231,6 +544,8 @@ class SyncResult: _connection_response: ConnectionResponse | None = None _cache: CacheBase | None = None _job_with_attempts_info: dict[str, Any] | None = None + _pre_sync_state: dict[str, Any] | None = None + _pre_sync_stream_stats: dict[str, dict[str, int]] | None = None @property def job_url(self) -> str: @@ -388,16 +703,79 @@ def wait_for_completion( wait_timeout: int = DEFAULT_SYNC_TIMEOUT_SECONDS, raise_timeout: bool = True, raise_failure: bool = False, + with_rich_status_updates: bool | int = False, + progress_log_path: str | Path | None = None, ) -> JobStatusEnum: - """Wait for a job to finish running.""" + """Wait for a job to finish running. + + When `with_rich_status_updates` is truthy, a Rich Live table is + rendered to stderr showing per-stream sync progress. Pass `True` + for 15-second polling, or an `int` for a custom interval in + seconds (minimum 15s -- values below 15 are clamped with a + warning). The rich polling interval replaces + `JOB_WAIT_INTERVAL_SECS` as the sole loop cadence. + + When `progress_log_path` is set (requires Rich updates enabled), + each polling iteration appends a JSONL line to the given file + with timestamped per-stream progress data for auditing. 
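+
+        Illustrative usage (the `connection` object is assumed to be a
+        `CloudConnection`):
+
+            sync_result = connection.run_sync(wait=False)
+            sync_result.wait_for_completion(
+                with_rich_status_updates=30,  # poll every 30 seconds
+                progress_log_path="./sync_progress.jsonl",
+            )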
+ """ + poll_interval: float = api_util.JOB_WAIT_INTERVAL_SECS + rich_enabled = bool(with_rich_status_updates) + + if rich_enabled: + poll_interval = _resolve_rich_interval( + with_rich_status_updates=with_rich_status_updates, + ) + + log_path: Path | None = Path(progress_log_path) if progress_log_path else None + start_time = time.time() + + if not rich_enabled: + return self._poll_until_complete( + start_time=start_time, + poll_interval=poll_interval, + wait_timeout=wait_timeout, + raise_timeout=raise_timeout, + raise_failure=raise_failure, + ) + + # Rich status updates path + console = Console(stderr=True) + live = RichLive(console=console, auto_refresh=False) + try: + live.start() + return self._poll_until_complete_with_rich( + live=live, + start_time=start_time, + poll_interval=poll_interval, + wait_timeout=wait_timeout, + raise_timeout=raise_timeout, + raise_failure=raise_failure, + progress_log_path=log_path, + ) + finally: + live.stop() + + # ------------------------------------------------------------------ + # Internal polling helpers + # ------------------------------------------------------------------ + + def _poll_until_complete( + self, + *, + start_time: float, + poll_interval: float, + wait_timeout: int, + raise_timeout: bool, + raise_failure: bool, + ) -> JobStatusEnum: + """Plain polling loop without Rich output.""" while True: latest_status = self.get_job_status() if latest_status in FINAL_STATUSES: if raise_failure: - # No-op if the job succeeded or is still running: self.raise_failure_status() - return latest_status if time.time() - start_time > wait_timeout: @@ -409,10 +787,165 @@ def wait_for_completion( job_status=latest_status, timeout=wait_timeout, ) + return latest_status - return latest_status # This will be a non-final status + time.sleep(poll_interval) + + def _poll_until_complete_with_rich( # noqa: PLR0912, PLR0914 + self, + *, + live: RichLive, + start_time: float, + poll_interval: float, + wait_timeout: int, + raise_timeout: bool, + raise_failure: bool, + progress_log_path: Path | None = None, + ) -> JobStatusEnum: + """Polling loop with Rich Live table showing per-stream progress.""" + previous_state: dict[str, Any] | None = self._pre_sync_state + previous_stream_stats: dict[str, dict[str, int]] = self._pre_sync_stream_stats or {} + catalog_data: dict[str, Any] | None = None + catalog_fetched = False + + # Track first-observed cursors as a fallback baseline when no + # previous sync state is available. Keys are (stream_name, namespace) + # and values are the raw cursor string captured on first sighting. + first_seen_cursors: dict[tuple[str, str | None], str] = {} + + # Total selected streams from catalog (resolved after first fetch). + catalog_stream_count: int | None = None + + while True: + latest_status = self.get_job_status() + + # Lazy-fetch catalog on first iteration + if not catalog_fetched: + catalog_data = api_util.get_connection_catalog( + connection_id=self.connection.connection_id, + api_root=self.workspace.api_root, + client_id=self.workspace.client_id, + client_secret=self.workspace.client_secret, + bearer_token=self.workspace.bearer_token, + ) + catalog_fetched = True + + # Resolve total *selected* streams from the catalog. 
+ if catalog_data: + cat_streams = catalog_data.get("streams", []) + if not cat_streams and "syncCatalog" in catalog_data: + cat_streams = catalog_data["syncCatalog"].get("streams", []) + if cat_streams: + selected = [ + s for s in cat_streams if s.get("config", {}).get("selected", False) + ] + catalog_stream_count = len(selected) if selected else len(cat_streams) + + # Fetch current state and compute progress + state_data = api_util.get_connection_state( + connection_id=self.connection.connection_id, + api_root=self.workspace.api_root, + client_id=self.workspace.client_id, + client_secret=self.workspace.client_secret, + bearer_token=self.workspace.bearer_token, + ) + + # Record first-observed cursors for streams we haven't seen yet + _update_first_seen_cursors( + first_seen_cursors=first_seen_cursors, + state_data=state_data, + catalog_data=catalog_data, + ) + + sync_start_time_dt: datetime + try: + sync_start_time_dt = self.start_time + except (ValueError, TypeError): + sync_start_time_dt = datetime.now(timezone.utc) + + stream_progress = compute_stream_progress( + state_data=state_data, + catalog_data=catalog_data, + sync_start_time=sync_start_time_dt, + previous_state_data=previous_state, + first_seen_cursors=first_seen_cursors, + ) + + # Override progress to 100% for successful syncs. The formula + # compares cursors against `now`, so a source whose data stops + # before "today" (e.g. GA4 data through 2025-12-27 when today + # is 2026-04-04) would otherwise show <100% even after the job + # completes successfully. + if latest_status == JobStatusEnum.SUCCEEDED: + for entry in stream_progress: + if entry.get("progress_pct") is not None: + entry["progress_pct"] = 1.0 + + # Fetch live per-stream records/bytes from the Config API's + # `jobs/get_debug_info` endpoint. These counters update + # mid-sync and serve as the primary proof-of-life signal when + # cursor-based progress can't compute a percentage. + current_stream_stats: dict[str, dict[str, int]] = {} + try: + debug_info = api_util.get_job_debug_info( + job_id=self.job_id, + api_root=self.workspace.api_root, + client_id=self.workspace.client_id, + client_secret=self.workspace.client_secret, + bearer_token=self.workspace.bearer_token, + ) + except AirbyteError: + # Progress signal is best-effort; swallow API errors and + # render `--` rather than failing the sync polling loop. + debug_info = {} + if debug_info: + current_stream_stats = _extract_stream_stats(debug_info) + + elapsed = time.time() - start_time + job_info = self._latest_job_info + table = _build_rich_table( + stream_progress=stream_progress, + job_status=str(latest_status), + elapsed_secs=elapsed, + sync_start_time=sync_start_time_dt, + total_selected_streams=catalog_stream_count, + records_synced=(job_info.rows_synced or 0) if job_info else 0, + bytes_synced=(job_info.bytes_synced or 0) if job_info else 0, + stream_stats=current_stream_stats, + previous_stream_stats=previous_stream_stats, + ) + live.update(table, refresh=True) + + # Write JSONL progress log entry when a log path is configured. 
+ if progress_log_path is not None: + _append_progress_log_entry( + log_path=progress_log_path, + elapsed_secs=elapsed, + job_status=str(latest_status), + records_synced=(job_info.rows_synced or 0) if job_info else 0, + bytes_synced=(job_info.bytes_synced or 0) if job_info else 0, + stream_progress=stream_progress, + current_stream_stats=current_stream_stats, + previous_stream_stats=previous_stream_stats, + ) + + if latest_status in FINAL_STATUSES: + if raise_failure: + self.raise_failure_status() + return latest_status + + if time.time() - start_time > wait_timeout: + if raise_timeout: + raise AirbyteConnectionSyncTimeoutError( + workspace=self.workspace, + connection_id=self.connection.connection_id, + job_id=self.job_id, + job_status=latest_status, + timeout=wait_timeout, + ) + return latest_status - time.sleep(api_util.JOB_WAIT_INTERVAL_SECS) + time.sleep(poll_interval) def get_sql_cache(self) -> CacheBase: """Return a SQL Cache object for working with the data in a SQL-based destination's.""" diff --git a/airbyte/mcp/cloud.py b/airbyte/mcp/cloud.py index 594bee0db..c0cabd579 100644 --- a/airbyte/mcp/cloud.py +++ b/airbyte/mcp/cloud.py @@ -11,6 +11,7 @@ from airbyte import cloud, get_destination, get_source from airbyte._util import api_util +from airbyte.cloud._sync_progress import compute_stream_progress from airbyte.cloud.connectors import CustomCloudSourceDefinition from airbyte.cloud.constants import FAILED_STATUSES from airbyte.cloud.workspaces import CloudOrganization, CloudWorkspace @@ -647,6 +648,19 @@ def get_cloud_sync_status( default=False, ), ], + with_stream_progress: Annotated[ + bool, + Field( + description=( + "Whether to include per-stream sync progress estimates. " + "When enabled, fetches current connection state and catalog to " + "compute an estimated progress percentage for each stream with " + "a datetime-based cursor. This adds latency from additional API " + "calls. Progress is approximate and advances at checkpoint intervals." 
+ ), + default=False, + ), + ], ) -> dict[str, Any]: """Get the status of a sync job from the Airbyte Cloud.""" workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) @@ -682,6 +696,19 @@ def get_cloud_sync_status( for attempt in attempts ] + if with_stream_progress: + state_data = connection.dump_raw_state() + catalog_data = connection.dump_raw_catalog() + previous_state_data = connection.get_previous_sync_state( + current_job_id=sync_result.job_id, + ) + result["stream_progress"] = compute_stream_progress( + state_data=state_data, + catalog_data=catalog_data, + sync_start_time=sync_result.start_time, + previous_state_data=previous_state_data, + ) + return result diff --git a/airbyte/progress.py b/airbyte/progress.py index ad984f8a3..6777b26c5 100644 --- a/airbyte/progress.py +++ b/airbyte/progress.py @@ -28,6 +28,7 @@ from enum import Enum, auto from typing import IO, TYPE_CHECKING, Any, Literal, cast +import ulid from rich.console import Console from rich.errors import LiveError from rich.live import Live as RichLive @@ -40,6 +41,10 @@ ) from airbyte import logs +from airbyte._local_sync_progress import ( + compute_stream_progress_pct, + extract_cursor_from_state_message, +) from airbyte._message_iterators import _new_stream_success_message from airbyte._util import meta from airbyte._util.telemetry import ( @@ -47,16 +52,19 @@ EventType, send_telemetry, ) -from airbyte.logs import get_global_file_logger +from airbyte.logs import AIRBYTE_LOGGING_ROOT, get_global_file_logger if TYPE_CHECKING: import logging from collections.abc import Generator, Iterable + from pathlib import Path from types import ModuleType from structlog import BoundLogger + from airbyte_protocol.models import AirbyteStateMessage + from airbyte._message_iterators import AirbyteMessageIterator from airbyte.caches.base import CacheBase from airbyte.destinations.base import Destination @@ -120,6 +128,21 @@ def _to_time_str(timestamp: float) -> str: return datetime_obj.strftime("%H:%M:%S") +def _cursor_date_str(cursor_value: str | None) -> str | None: + """Format a cursor value as a concise `yyyy-mm-dd` date string. + + Returns `None` if `cursor_value` is falsy or cannot be parsed as a + datetime. ISO-8601 timestamps are truncated to their date portion. + """ + if not cursor_value: + return None + try: + parsed = datetime.datetime.fromisoformat(cursor_value.replace("Z", "+00:00")) + except (ValueError, AttributeError): + return None + return parsed.strftime("%Y-%m-%d") + + def _get_elapsed_time_str(seconds: float) -> str: """Return duration as a string. @@ -214,11 +237,27 @@ def __init__( self.destination_stream_records_delivered: dict[str, int] = defaultdict(int) self.destination_stream_records_confirmed: dict[str, int] = defaultdict(int) + # State-message cursor tracking (local sync progress). + # `baseline` is the first cursor observed from the source; `latest` is + # updated every time a STATE message is observed from the source. + # `committed` mirrors the cursor on state messages acknowledged by the + # destination. See `airbyte._local_sync_progress` for formula details. 
+ self.stream_baseline_cursors: dict[str, str] = {} + self.stream_latest_cursors: dict[str, str] = {} + self.stream_committed_cursors: dict[str, str] = {} + self.stream_cursor_fields: dict[str, str] = {} + self.stream_latest_cursor_times: dict[str, float] = {} + self.stream_committed_cursor_times: dict[str, float] = {} + # Progress bar properties self._last_update_time: float | None = None self._stderr_console: Console | None = None self._rich_view: RichLive | None = None + # JSONL audit log path -- set lazily on first write. + self._progress_audit_log_path: Path | None = None + self._progress_audit_log_resolved: bool = False + self.reset_progress_style(style) def _print_info_message( @@ -256,7 +295,13 @@ def tally_records_read( *, auto_close_streams: bool = False, ) -> Generator[AirbyteMessage, Any, None]: - """This method simply tallies the number of records processed and yields the messages.""" + """This method simply tallies the number of records processed and yields the messages. + + STATE messages emitted by the source are observed here to extract + per-stream cursor values. The original messages are passed through + unchanged so downstream consumers (caches, destinations) still receive + them. + """ # Update the display before we start. self._log_sync_start() self._start_rich_view() @@ -268,6 +313,9 @@ def tally_records_read( # Yield the message immediately. yield message + if message.state: + self._observe_state_message(message.state, is_committed=False) + if message.record: # If this is the first record, set the start time. if self.first_record_received_time is None: @@ -331,6 +379,9 @@ def tally_pending_writes( # For now at least, we don't need to pay the cost of parsing it. continue + if message.state: + self._observe_state_message(message.state, is_committed=False) + if message.record and message.record.stream: self.destination_stream_records_delivered[message.record.stream] += 1 @@ -358,6 +409,8 @@ def tally_confirmed_writes( self._start_rich_view() # Start Rich's live view if not already running. for message in messages: if message.state: + # Observe the destination-acknowledged cursor before anything else. + self._observe_state_message(message.state, is_committed=True) # This is a state message from the destination. Tally the records written. if message.state.stream and message.state.destinationStats: stream_name = message.state.stream.stream_descriptor.name @@ -377,6 +430,165 @@ def tally_bytes_read(self, bytes_read: int, stream_name: str) -> None: """ self.stream_bytes_read[stream_name] += bytes_read + # Local-sync state observation + + def _observe_state_message( + self, + state_message: AirbyteStateMessage, + *, + is_committed: bool, + ) -> None: + """Observe a STATE message flowing through the local pipeline. + + When `is_committed` is `False`, the message came directly from the + source and advances the `latest` cursor. When `True`, the message was + acknowledged by the destination and advances the `committed` cursor. + + Only per-stream (`STREAM`-type) state messages contribute a + cursor. `GLOBAL` and `LEGACY` state messages are still passed through + the pipeline but are skipped here because they have no single + per-stream cursor value to record. 
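+
+        For example (hypothetical values): a source-side STREAM state message
+        for `contacts` whose `stream_state` is
+        `{"updated_at": "2024-06-01T00:00:00Z"}` records that value as the
+        stream's baseline cursor on first sight and as its latest cursor on
+        every later sighting; the same message arriving with
+        `is_committed=True` updates the committed cursor instead.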
+ """ + stream_name, cursor_field, cursor_value = extract_cursor_from_state_message(state_message) + if stream_name is None or cursor_value is None: + return + + now = time.time() + if cursor_field and stream_name not in self.stream_cursor_fields: + self.stream_cursor_fields[stream_name] = cursor_field + + if is_committed: + self.stream_committed_cursors[stream_name] = cursor_value + self.stream_committed_cursor_times[stream_name] = now + return + + # Source-side: record baseline the first time we see a cursor. + if stream_name not in self.stream_baseline_cursors: + self.stream_baseline_cursors[stream_name] = cursor_value + self.stream_latest_cursors[stream_name] = cursor_value + self.stream_latest_cursor_times[stream_name] = now + + @property + def stream_progress_pcts(self) -> dict[str, float]: + """Per-stream datetime-cursor progress estimates. + + Returns values in `[0.0, 1.0]`. A stream is reported as `1.0` once the + source has emitted a `STREAM_STATUS=COMPLETE` trace (or the sync has + otherwise been finalized via `log_success()`); otherwise the value is + derived from the datetime-cursor formula and only streams with a + datetime-parseable baseline + latest cursor are included. + """ + now = datetime.datetime.now(datetime.timezone.utc) + result: dict[str, float] = {} + for stream_name, latest in self.stream_latest_cursors.items(): + baseline = self.stream_baseline_cursors.get(stream_name) + pct = compute_stream_progress_pct( + baseline_cursor=baseline, + latest_cursor=latest, + now=now, + ) + if pct is not None: + result[stream_name] = pct + for stream_name in self.stream_read_end_times: + result[stream_name] = 1.0 + return result + + @property + def stream_write_progress_pcts(self) -> dict[str, float]: + """Per-stream datetime-cursor write-progress estimates. + + Uses the same formula as `stream_progress_pcts` but with the + destination-acknowledged (committed) cursor instead of the latest + source-emitted cursor. Only populated for streams with a + datetime-parseable baseline + committed cursor. + """ + now = datetime.datetime.now(datetime.timezone.utc) + result: dict[str, float] = {} + for stream_name, committed in self.stream_committed_cursors.items(): + baseline = self.stream_baseline_cursors.get(stream_name) + pct = compute_stream_progress_pct( + baseline_cursor=baseline, + latest_cursor=committed, + now=now, + ) + if pct is not None: + result[stream_name] = pct + return result + + def _build_progress_snapshot(self) -> dict[str, Any]: + """Build a point-in-time progress snapshot for JSONL audit logging. + + Includes per-stream record counts, cursor fields/values, and + computed progress percentages when available. 
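+
+        An illustrative (abbreviated) snapshot:
+
+            {"timestamp": "2024-06-01T12:00:00+00:00", "elapsed_secs": 45.0,
+             "total_records_read": 1200, "streams": [
+                 {"stream_name": "contacts", "records_read": 1200,
+                  "cursor_field": "updated_at", "progress_pct": 0.41, ...}]}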
+ """ + now_dt = datetime.datetime.now(datetime.timezone.utc) + progress_pcts = self.stream_progress_pcts + write_progress_pcts = self.stream_write_progress_pcts + stream_names = ( + set(self.stream_read_counts) + | set(self.written_stream_names) + | set(self.stream_latest_cursors) + | set(self.stream_committed_cursors) + ) + streams: list[dict[str, Any]] = [ + { + "stream_name": stream_name, + "records_read": self.stream_read_counts.get(stream_name, 0), + "records_delivered": self.destination_stream_records_delivered.get(stream_name, 0), + "records_confirmed": self.destination_stream_records_confirmed.get(stream_name, 0), + "cursor_field": self.stream_cursor_fields.get(stream_name), + "baseline_cursor": self.stream_baseline_cursors.get(stream_name), + "latest_cursor": self.stream_latest_cursors.get(stream_name), + "committed_cursor": self.stream_committed_cursors.get(stream_name), + "progress_pct": progress_pcts.get(stream_name), + "write_progress_pct": write_progress_pcts.get(stream_name), + } + for stream_name in sorted(stream_names) + ] + return { + "timestamp": now_dt.isoformat(), + "elapsed_secs": round(self.elapsed_seconds, 3), + "total_records_read": self.total_records_read, + "total_records_written": self.total_records_written, + "total_records_delivered": self.total_destination_records_delivered, + "total_records_confirmed": self.total_destination_records_confirmed, + "streams": streams, + } + + def _get_progress_audit_log_path(self) -> Path | None: + """Resolve the JSONL audit log path for this tracker, once per instance.""" + if self._progress_audit_log_resolved: + return self._progress_audit_log_path + + self._progress_audit_log_resolved = True + if AIRBYTE_LOGGING_ROOT is None: + return None + + yyyy_mm_dd = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d") + folder = AIRBYTE_LOGGING_ROOT / yyyy_mm_dd + try: + folder.mkdir(parents=True, exist_ok=True) + except OSError: + return None + + self._progress_audit_log_path = folder / f"progress-{ulid.ULID()}.jsonl" + return self._progress_audit_log_path + + def _append_progress_audit_log(self) -> None: + """Append a single JSONL snapshot to the audit log, if available.""" + path = self._get_progress_audit_log_path() + if path is None: + return + + snapshot = self._build_progress_snapshot() + try: + with path.open("a", encoding="utf-8") as fh: + fh.write(json.dumps(snapshot, default=str)) + fh.write("\n") + except OSError: + # Audit logging is best-effort; do not interfere with the sync. 
+ return + # Logging methods @property @@ -608,7 +820,20 @@ def log_success( self.end_time = time.time() + now = time.time() + all_known_streams = ( + set(self.stream_read_counts) + | set(self.stream_latest_cursors) + | set(self.stream_committed_cursors) + | set(self.written_stream_names) + ) + for stream_name in all_known_streams: + self.stream_read_end_times.setdefault(stream_name, now) + self._update_display(force_refresh=True) + audit_path = self._get_progress_audit_log_path() + if audit_path is not None: + self._print_info_message(f"Progress audit log written to `{audit_path}`.") self._stop_rich_view() streams = list(self.stream_read_start_times.keys()) if not streams: @@ -862,9 +1087,10 @@ def _update_display(self, *, force_refresh: bool = False) -> None: elif self.style in {ProgressStyle.PLAIN, ProgressStyle.NONE}: pass + self._append_progress_audit_log() self._last_update_time = time.time() - def _get_status_message(self) -> str: + def _get_status_message(self) -> str: # noqa: PLR0912, PLR0914, PLR0915 """Compile and return a status message.""" # Format start time as a friendly string in local timezone: start_time_str = _to_time_str(self.read_start_time) @@ -893,23 +1119,53 @@ def join_streams_strings(streams_list: list[str]) -> str: f"({records_per_second:,.1f} records/s{mb_per_second_str}).\n\n" ) - if self.stream_read_counts: - status_message += ( - f"- Received records for {len(self.stream_read_counts)}" - + ( - f" out of {self.num_streams_expected} expected" - if self.num_streams_expected - else "" - ) - + " streams:\n - " - + join_streams_strings( - [ - f"{self.stream_read_counts[stream_name]:,} {stream_name}" - for stream_name in self.stream_read_counts - ] - ) - + "\n\n" + # Consolidated per-stream status: records + read/write progress per line. 
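+        # Illustrative rendering of one such line (example values only):
+        #   - `contacts`: 1,234 records | Read Progress: 87.5% (dates: `2024-01-01`-`2024-06-15`)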
+        read_progress_pcts = self.stream_progress_pcts
+        write_progress_pcts = self.stream_write_progress_pcts
+        per_stream_names = sorted(
+            set(self.stream_read_counts)
+            | set(self.stream_latest_cursors)
+            | set(self.stream_read_end_times)
+            | set(self.stream_committed_cursors)
+        )
+        if per_stream_names:
+            header = (
+                f"- Streams ({len(self.stream_read_counts)}"
+                + (f" of {self.num_streams_expected} expected" if self.num_streams_expected else "")
+                + "):\n"
             )
+            lines: list[str] = []
+            for stream_name in per_stream_names:
+                done_marker = "✓ " if stream_name in self.stream_read_end_times else ""
+                records = self.stream_read_counts.get(stream_name, 0)
+                baseline = self.stream_baseline_cursors.get(stream_name)
+                latest = self.stream_latest_cursors.get(stream_name)
+                committed = self.stream_committed_cursors.get(stream_name)
+                read_pct = read_progress_pcts.get(stream_name)
+                write_pct = write_progress_pcts.get(stream_name)
+
+                baseline_date = _cursor_date_str(baseline)
+                latest_date = _cursor_date_str(latest)
+                committed_date = _cursor_date_str(committed)
+
+                if read_pct is not None and baseline_date and latest_date:
+                    read_frag = (
+                        f"Read Progress: {read_pct * 100:.1f}% "
+                        f"(dates: `{baseline_date}`-`{latest_date}`)"
+                    )
+                elif stream_name in self.stream_read_end_times:
+                    read_frag = "Read Progress: 100.0% ✓"
+                else:
+                    read_frag = "Read Progress: n/a"
+
+                segments = [f"{records:,} records", read_frag]
+                if write_pct is not None and baseline_date and committed_date:
+                    segments.append(
+                        f"Write Progress: {write_pct * 100:.1f}% "
+                        f"(dates: `{baseline_date}`-`{committed_date}`)"
+                    )
+                lines.append(f" - `{stream_name}`: {done_marker}" + " | ".join(segments))
+            status_message += header + "\n".join(lines) + "\n\n"

         # Source cache writes
         if self.total_records_written > 0:
diff --git a/examples/run_cloud_sync_with_rich_progress.py b/examples/run_cloud_sync_with_rich_progress.py
new file mode 100644
index 000000000..61e14c557
--- /dev/null
+++ b/examples/run_cloud_sync_with_rich_progress.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+"""Run a cloud sync with Rich Live progress tracking.
+
+Demonstrates the `with_rich_status_updates` feature, which renders a
+real-time Rich table to stderr showing per-stream sync progress while
+waiting for a Cloud sync to complete.
+
+Prerequisites:
+  - Airbyte Cloud client credentials exported as ``AIRBYTE_CLOUD_CLIENT_ID``
+    and ``AIRBYTE_CLOUD_CLIENT_SECRET`` (read via ``ab.get_secret`` in
+    ``main()`` below).
+  - The workspace and connection IDs below must be accessible with
+    those credentials. The defaults point to the ``@devin-ai-sandbox``
+    workspace (Google Analytics 4 connection).
+ +Usage (from the PyAirbyte root directory): + uv run python examples/run_cloud_sync_with_rich_progress.py + +You can also pass a custom refresh interval (in seconds) instead of +``True`` to control how often the table refreshes: + + connection.run_sync(with_rich_status_updates=30) # refresh every 30s +""" + +from __future__ import annotations + +import airbyte as ab +from airbyte.cloud import CloudWorkspace + +# --------------------------------------------------------------------------- +# Configuration – hard-coded to the @devin-ai-sandbox workspace +# --------------------------------------------------------------------------- + +WORKSPACE_ID = "266ebdfe-0d7b-4540-9817-de7e4505ba61" +CONNECTION_ID = "d9c752fe-515a-4066-9234-096b101ea16e" # GA4 -> dev-null + + +def main() -> None: + """Trigger a Cloud sync and display a Rich Live progress table.""" + workspace = CloudWorkspace( + workspace_id=WORKSPACE_ID, + client_id=ab.get_secret("AIRBYTE_CLOUD_CLIENT_ID"), + client_secret=ab.get_secret("AIRBYTE_CLOUD_CLIENT_SECRET"), + # api_key=ab.get_secret("AIRBYTE_CLOUD_API_KEY"), + ) + + connection = workspace.get_connection(connection_id=CONNECTION_ID) + print(f"Connection ID: {connection.connection_id}") + print(f"Connection URL: {connection.connection_url}") + print(f"Streams: {connection.stream_names}") + print() + print("Starting sync with Rich status updates (15 s refresh) ...") + print("(The Rich table renders to stderr; watch your terminal.)") + print() + + sync_result = connection.run_sync( + wait_timeout=60 * 60, # 1 hour + with_rich_status_updates=True, # 15s default refresh + progress_log_path="sync_progress.jsonl", # audit log + ) + + print() + print(f"Sync complete! Status: {sync_result.get_job_status()}") + print(f" Records synced : {sync_result.records_synced:,}") + print(f" Bytes synced : {sync_result.bytes_synced:,}") + print(f" Job URL : {sync_result.job_url}") + + +if __name__ == "__main__": + main() diff --git a/tests/unit_tests/test_local_sync_progress.py b/tests/unit_tests/test_local_sync_progress.py new file mode 100644 index 000000000..481ef9d01 --- /dev/null +++ b/tests/unit_tests/test_local_sync_progress.py @@ -0,0 +1,308 @@ +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
+"""Tests for local-sync state-message observation in `ProgressTracker`.""" + +from __future__ import annotations + +import json + +import pytest +from airbyte._local_sync_progress import ( + _try_parse_datetime_cursor, + compute_stream_progress_pct, + extract_cursor_from_state_message, +) +from airbyte.progress import ProgressStyle, ProgressTracker +from airbyte_protocol.models import ( + AirbyteMessage, + AirbyteStateMessage, + AirbyteStateType, + AirbyteStreamState, + AirbyteStreamStatus, + AirbyteStreamStatusTraceMessage, + AirbyteTraceMessage, + StreamDescriptor, + TraceType, + Type, +) + + +def _stream_complete_msg(stream: str) -> AirbyteMessage: + return AirbyteMessage( + type=Type.TRACE, + trace=AirbyteTraceMessage( + type=TraceType.STREAM_STATUS, + emitted_at=0.0, + stream_status=AirbyteStreamStatusTraceMessage( + stream_descriptor=StreamDescriptor(name=stream), + status=AirbyteStreamStatus.COMPLETE, + ), + ), + ) + + +def _state_msg( + stream: str, cursor_value: str, cursor_field: str = "updatedAt" +) -> AirbyteMessage: + return AirbyteMessage( + type=Type.STATE, + state=AirbyteStateMessage( + type=AirbyteStateType.STREAM, + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(name=stream), + stream_state={cursor_field: cursor_value}, + ), + ), + ) + + +def test_try_parse_datetime_cursor_rejects_numeric_strings() -> None: + assert _try_parse_datetime_cursor("12345") is None + assert _try_parse_datetime_cursor("2024-01-01T00:00:00Z") is not None + + +def test_extract_cursor_from_state_message_prefers_known_fields() -> None: + msg = _state_msg("contacts", "2024-06-15T10:30:00Z", cursor_field="updatedAt") + name, field, value = extract_cursor_from_state_message(msg.state) + assert name == "contacts" + assert field == "updatedAt" + assert value == "2024-06-15T10:30:00Z" + + +def test_extract_cursor_from_nested_partition_state() -> None: + """`source-github`-style per-partition state: `{repo: {updated_at: ...}}`.""" + msg = AirbyteMessage( + type=Type.STATE, + state=AirbyteStateMessage( + type=AirbyteStateType.STREAM, + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(name="issues"), + stream_state={ + "airbytehq/airbyte": {"updated_at": "2026-04-23T22:20:47Z"}, + }, + ), + ), + ) + name, field, value = extract_cursor_from_state_message(msg.state) + assert name == "issues" + assert field == "updated_at" + assert value == "2026-04-23T22:20:47Z" + + +def test_extract_cursor_from_deeply_nested_state() -> None: + msg = AirbyteMessage( + type=Type.STATE, + state=AirbyteStateMessage( + type=AirbyteStateType.STREAM, + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(name="commits"), + stream_state={ + "airbytehq/airbyte": { + "main": {"created_at": "2026-04-23T22:20:47Z"}, + }, + }, + ), + ), + ) + name, field, value = extract_cursor_from_state_message(msg.state) + assert name == "commits" + assert field == "created_at" + assert value == "2026-04-23T22:20:47Z" + + +def test_extract_cursor_falls_back_to_first_datetime_value() -> None: + msg = AirbyteMessage( + type=Type.STATE, + state=AirbyteStateMessage( + type=AirbyteStateType.STREAM, + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(name="events"), + stream_state={"unknown_field": "2024-01-01T00:00:00Z"}, + ), + ), + ) + name, field, value = extract_cursor_from_state_message(msg.state) + assert name == "events" + assert field == "unknown_field" + assert value == "2024-01-01T00:00:00Z" + + +def test_compute_stream_progress_pct_basic() -> None: + pct = compute_stream_progress_pct( + 
baseline_cursor="2024-01-01T00:00:00Z", + latest_cursor="2024-07-01T00:00:00Z", + now=None, + ) + # We're past 2024, so from Jan 1 -> Jul 1 -> today should be a valid [0, 1] + assert pct is None or 0.0 <= pct <= 1.0 + + +def test_compute_stream_progress_pct_returns_none_for_non_datetime() -> None: + assert ( + compute_stream_progress_pct( + baseline_cursor="not-a-date", + latest_cursor="also-not-a-date", + ) + is None + ) + + +def test_tally_records_read_observes_source_state() -> None: + tracker = ProgressTracker( + ProgressStyle.NONE, + source=None, + cache=None, + destination=None, + ) + messages = [ + _state_msg("contacts", "2024-01-01T00:00:00Z"), + _state_msg("contacts", "2024-06-15T10:30:00Z"), + _state_msg("companies", "2024-03-15T00:00:00Z"), + ] + list(tracker.tally_records_read(iter(messages))) + + assert tracker.stream_baseline_cursors == { + "contacts": "2024-01-01T00:00:00Z", + "companies": "2024-03-15T00:00:00Z", + } + assert tracker.stream_latest_cursors == { + "contacts": "2024-06-15T10:30:00Z", + "companies": "2024-03-15T00:00:00Z", + } + assert tracker.stream_cursor_fields == { + "contacts": "updatedAt", + "companies": "updatedAt", + } + + +def test_tally_confirmed_writes_observes_committed_state() -> None: + tracker = ProgressTracker( + ProgressStyle.NONE, + source=None, + cache=None, + destination=None, + ) + messages = [_state_msg("contacts", "2024-06-15T10:30:00Z")] + list(tracker.tally_confirmed_writes(iter(messages))) + + assert tracker.stream_committed_cursors == {"contacts": "2024-06-15T10:30:00Z"} + # Source-side fields should NOT be populated from committed state alone. + assert tracker.stream_baseline_cursors == {} + assert tracker.stream_latest_cursors == {} + + +def test_tally_pending_writes_observes_source_state() -> None: + tracker = ProgressTracker( + ProgressStyle.NONE, + source=None, + cache=None, + destination=None, + ) + messages = [_state_msg("contacts", "2024-02-01T00:00:00Z")] + list(tracker.tally_pending_writes(iter(messages))) + + assert tracker.stream_latest_cursors == {"contacts": "2024-02-01T00:00:00Z"} + + +def test_build_progress_snapshot_shape() -> None: + tracker = ProgressTracker( + ProgressStyle.NONE, + source=None, + cache=None, + destination=None, + ) + list( + tracker.tally_records_read( + iter([ + _state_msg("contacts", "2024-01-01T00:00:00Z"), + _state_msg("contacts", "2024-06-15T10:30:00Z"), + ]) + ) + ) + snapshot = tracker._build_progress_snapshot() # noqa: SLF001 + assert "timestamp" in snapshot + assert "streams" in snapshot + contacts = next(s for s in snapshot["streams"] if s["stream_name"] == "contacts") + assert contacts["baseline_cursor"] == "2024-01-01T00:00:00Z" + assert contacts["latest_cursor"] == "2024-06-15T10:30:00Z" + assert contacts["cursor_field"] == "updatedAt" + + +def test_progress_audit_log_written(tmp_path, monkeypatch) -> None: + monkeypatch.setenv("AIRBYTE_LOGGING_ROOT", str(tmp_path)) + # Override the module-level AIRBYTE_LOGGING_ROOT (resolved at import time). 
+ from airbyte import progress as progress_module + + monkeypatch.setattr(progress_module, "AIRBYTE_LOGGING_ROOT", tmp_path) + + tracker = ProgressTracker( + ProgressStyle.NONE, + source=None, + cache=None, + destination=None, + ) + list( + tracker.tally_records_read( + iter([ + _state_msg("contacts", "2024-01-01T00:00:00Z"), + _state_msg("contacts", "2024-06-15T10:30:00Z"), + ]) + ) + ) + tracker._update_display(force_refresh=True) # noqa: SLF001 + + audit_path = tracker._get_progress_audit_log_path() # noqa: SLF001 + assert audit_path is not None + assert audit_path.exists() + + lines = audit_path.read_text().strip().splitlines() + assert lines, "expected at least one JSONL line" + last = json.loads(lines[-1]) + assert "streams" in last + contacts = next(s for s in last["streams"] if s["stream_name"] == "contacts") + assert contacts["latest_cursor"] == "2024-06-15T10:30:00Z" + + +def test_stream_complete_trace_forces_progress_to_100pct() -> None: + tracker = ProgressTracker( + ProgressStyle.NONE, + source=None, + cache=None, + destination=None, + ) + list( + tracker.tally_records_read( + iter([ + _state_msg("contacts", "2024-01-01T00:00:00Z"), + _state_msg("contacts", "2024-06-15T10:30:00Z"), + _stream_complete_msg("contacts"), + ]) + ) + ) + assert "contacts" in tracker.stream_read_end_times + assert tracker.stream_progress_pcts["contacts"] == 1.0 + + +def test_log_success_forces_all_streams_to_100pct() -> None: + tracker = ProgressTracker( + ProgressStyle.NONE, + source=None, + cache=None, + destination=None, + ) + list( + tracker.tally_records_read( + iter([ + _state_msg("contacts", "2024-01-01T00:00:00Z"), + _state_msg("companies", "2024-03-15T00:00:00Z"), + ]) + ) + ) + tracker.log_success() + assert tracker.stream_progress_pcts == { + "contacts": 1.0, + "companies": 1.0, + } + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])
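
For reference, a minimal sketch of consuming the JSONL progress audit log written by `_append_progress_audit_log` (one snapshot per display refresh; `log_success()` prints the actual file location). The path below is a hypothetical placeholder, and the keys mirror `_build_progress_snapshot` above:

```python
import json
from pathlib import Path

# Hypothetical path -- use the location printed by `log_success()`.
audit_log = Path("logs/2024-06-15/progress-01HZ0EXAMPLE.jsonl")

# Each line is one point-in-time snapshot; the last line reflects the final state.
last_snapshot = json.loads(audit_log.read_text(encoding="utf-8").splitlines()[-1])

print(
    f"elapsed: {last_snapshot['elapsed_secs']}s, "
    f"records read: {last_snapshot['total_records_read']:,}"
)
for stream in last_snapshot["streams"]:
    pct = stream["progress_pct"]
    pct_str = f"{pct * 100:.1f}%" if pct is not None else "n/a"
    print(
        f"  {stream['stream_name']}: {stream['records_read']:,} records, "
        f"read progress {pct_str} (cursor field: {stream['cursor_field']})"
    )
```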