diff --git a/dagshub/common/adaptive_batching.py b/dagshub/common/adaptive_batching.py new file mode 100644 index 00000000..5702d442 --- /dev/null +++ b/dagshub/common/adaptive_batching.py @@ -0,0 +1,318 @@ +import itertools +import logging +import math +import time +from dataclasses import dataclass +from typing import Callable, Iterable, List, Optional, Sized, Tuple, TypeVar + +import rich.progress + +import dagshub.common.config as dgs_config +from dagshub.common.rich_util import get_rich_progress + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + +MIN_TARGET_BATCH_TIME_SECONDS = 0.01 +SOFT_UPPER_LIMIT_MIN_STEP_FRACTION = 0.05 +SOFT_UPPER_LIMIT_RETRY_AFTER_SUCCESSES = 3 + +# Overall strategy: +# - Grow aggressively on fast successes until we hit a slow or failing batch. +# - A slow or failing batch becomes last_bad_batch_size, and a fast batch becomes +# last_fast_batch_size. +# - last_bad_batch_size acts as a soft_upper_limit while the gap to +# last_fast_batch_size is still meaningful, and we probe within that range. +# - Once that gap is below the search resolution, hold the current fast batch size +# instead of micro-searching. +# - Several consecutive fast batches near last_bad_batch_size trigger one more +# probe at soft_upper_limit, since the earlier failure may have been transient. + + +@dataclass +class AdaptiveBatchConfig: + max_batch_size: int + min_batch_size: int + initial_batch_size: int + target_batch_time_seconds: float + batch_growth_factor: int + retry_backoff_base_seconds: float + retry_backoff_max_seconds: float + + @classmethod + def from_values( + cls, + max_batch_size: Optional[int] = None, + min_batch_size: Optional[int] = None, + initial_batch_size: Optional[int] = None, + target_batch_time_seconds: Optional[float] = None, + batch_growth_factor: Optional[int] = None, + retry_backoff_base_seconds: Optional[float] = None, + retry_backoff_max_seconds: Optional[float] = None, + ) -> "AdaptiveBatchConfig": + if max_batch_size is None: + max_batch_size = dgs_config.dataengine_metadata_upload_batch_size_max + if min_batch_size is None: + min_batch_size = dgs_config.dataengine_metadata_upload_batch_size_min + if initial_batch_size is None: + initial_batch_size = dgs_config.dataengine_metadata_upload_batch_size_initial + if target_batch_time_seconds is None: + target_batch_time_seconds = dgs_config.dataengine_metadata_upload_target_batch_time_seconds + if batch_growth_factor is None: + batch_growth_factor = dgs_config.adaptive_batch_growth_factor + if retry_backoff_base_seconds is None: + retry_backoff_base_seconds = dgs_config.adaptive_batch_retry_backoff_base_seconds + if retry_backoff_max_seconds is None: + retry_backoff_max_seconds = dgs_config.adaptive_batch_retry_backoff_max_seconds + + normalized_max_batch_size = max(1, max_batch_size) + normalized_min_batch_size = max(1, min(min_batch_size, normalized_max_batch_size)) + normalized_initial_batch_size = max( + normalized_min_batch_size, + min(initial_batch_size, normalized_max_batch_size), + ) + normalized_target_batch_time_seconds = max(target_batch_time_seconds, MIN_TARGET_BATCH_TIME_SECONDS) + return cls( + max_batch_size=normalized_max_batch_size, + min_batch_size=normalized_min_batch_size, + initial_batch_size=normalized_initial_batch_size, + target_batch_time_seconds=normalized_target_batch_time_seconds, + batch_growth_factor=max(2, batch_growth_factor), + retry_backoff_base_seconds=max(0.0, retry_backoff_base_seconds), + retry_backoff_max_seconds=max(0.0, retry_backoff_max_seconds), + ) + + +def _clamp(value: int, lo: int, hi: int) -> int: + return max(lo, min(hi, value)) + + +def _next_batch_after_success( + batch_size: int, + config: AdaptiveBatchConfig, + soft_upper_limit: Optional[int], +) -> int: + """Pick the next batch size after a fast successful batch. + + Strategy: + - If we have a previous slow/failing size, binary-search toward it as a soft upper hint. + - Otherwise, multiply by the growth factor. + """ + if soft_upper_limit is not None and batch_size < soft_upper_limit: + # Binary search: try the midpoint between current and the soft upper limit. + candidate = (batch_size + soft_upper_limit) // 2 + else: + # No upper hint (or we've already reached it): grow aggressively. + candidate = batch_size * config.batch_growth_factor + + return _clamp(candidate, config.min_batch_size, config.max_batch_size) + + +def _next_batch_after_retryable_failure( + batch_size: int, + config: AdaptiveBatchConfig, + last_fast_batch_size: Optional[int], + soft_upper_limit: Optional[int], +) -> int: + """Pick the next batch size after a failed or slow batch. + + Strategy: + - If we have a known-good lower bound, binary-search between it and the + failing size. + - Otherwise, probe the midpoint between config.min_batch_size and the + largest allowed size below the failing batch. + - Must be strictly less than the current size (so we converge downward). + """ + if batch_size <= config.min_batch_size: + return config.min_batch_size + + ceiling = batch_size - 1 # must shrink + if soft_upper_limit is not None: + ceiling = min(ceiling, soft_upper_limit - 1) + + if last_fast_batch_size is not None and last_fast_batch_size < ceiling: + # Binary search: try the midpoint between good and failing + candidate = (last_fast_batch_size + ceiling) // 2 + else: + # No good lower bound — probe midpoint of the valid range + candidate = (config.min_batch_size + ceiling) // 2 + + return _clamp(candidate, config.min_batch_size, ceiling) + + +def _get_retry_delay_seconds(consecutive_retryable_failures: int, config: AdaptiveBatchConfig) -> float: + if config.retry_backoff_base_seconds <= 0.0 or config.retry_backoff_max_seconds <= 0.0: + return 0.0 + + attempt_number = max(1, consecutive_retryable_failures) + delay = config.retry_backoff_base_seconds * (2 ** (attempt_number - 1)) + return min(delay, config.retry_backoff_max_seconds) + + +def _min_step_size(soft_upper_limit: int) -> int: + return max(1, math.ceil(soft_upper_limit * SOFT_UPPER_LIMIT_MIN_STEP_FRACTION)) + + +def _is_next_step_above_limit(batch_size: int, soft_upper_limit: Optional[int]) -> bool: + if soft_upper_limit is None or batch_size >= soft_upper_limit: + return False + + return soft_upper_limit - batch_size <= _min_step_size(soft_upper_limit) + + +def _update_bounds_after_bad_batch( + batch_size: int, + last_fast_batch_size: Optional[int], + last_bad_batch_size: Optional[int], +) -> Tuple[Optional[int], int]: + updated_last_bad_batch_size = batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) + if last_fast_batch_size is not None and last_fast_batch_size >= updated_last_bad_batch_size: + last_fast_batch_size = None + return last_fast_batch_size, updated_last_bad_batch_size + + +class AdaptiveBatcher: + """Sends items in adaptively-sized batches, growing on success and shrinking on failure.""" + + def __init__( + self, + is_retryable: Callable[[Exception], bool], + config: Optional[AdaptiveBatchConfig] = None, + progress_label: str = "Uploading", + ): + self._config = config if config is not None else AdaptiveBatchConfig.from_values() + self._is_retryable = is_retryable + self._progress_label = progress_label + + def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: + total: Optional[int] = len(items) if isinstance(items, Sized) else None + if total == 0: + return + + config = self._config + desired_batch_size = config.initial_batch_size + # Consume the source iterable incrementally across retries and successes. + it = iter(items) + pending: List[T] = [] + + progress = get_rich_progress(rich.progress.MofNCompleteColumn()) + total_task = progress.add_task(f"{self._progress_label}...", total=total) + + last_fast_batch_size: Optional[int] = None + last_bad_batch_size: Optional[int] = None + consecutive_retryable_failures = 0 + consecutive_fast_successes_near_upper_limit = 0 + processed = 0 + + with progress: + while True: + # Draw from pending (failed-batch leftovers) first, then the source iterator + batch = pending[:desired_batch_size] + pending = pending[desired_batch_size:] + if len(batch) < desired_batch_size: + batch.extend(itertools.islice(it, desired_batch_size - len(batch))) + if not batch: + break + actual_batch_size = len(batch) + + progress.update(total_task, description=f"{self._progress_label} (batch size: {actual_batch_size})...") + logger.debug(f"{self._progress_label}: {actual_batch_size} entries...") + + start_time = time.monotonic() + try: + operation(batch) + except Exception as exc: + if not self._is_retryable(exc): + logger.error( + f"{self._progress_label} failed with a non-retryable error; aborting.", + exc_info=True, + ) + raise + + is_short_tail_batch = ( + actual_batch_size <= config.min_batch_size and actual_batch_size < desired_batch_size + ) + if not is_short_tail_batch and actual_batch_size <= config.min_batch_size: + logger.error( + f"{self._progress_label} failed at minimum batch size ({actual_batch_size}); aborting.", + exc_info=True, + ) + raise + + consecutive_fast_successes_near_upper_limit = 0 + + # Exponential backoff + consecutive_retryable_failures += 1 + time.sleep(_get_retry_delay_seconds(consecutive_retryable_failures, config)) + + last_fast_batch_size, last_bad_batch_size = _update_bounds_after_bad_batch( + actual_batch_size, last_fast_batch_size, last_bad_batch_size + ) + if is_short_tail_batch: + # A naturally short tail batch cannot be shrunk further in a useful way. + # Retry that exact size once before treating it as exhausted. + desired_batch_size = actual_batch_size + else: + # Binary search downwards + desired_batch_size = _next_batch_after_retryable_failure( + actual_batch_size, config, last_fast_batch_size, last_bad_batch_size + ) + logger.warning( + f"{self._progress_label} failed for batch size {actual_batch_size} " + f"({exc.__class__.__name__}: {exc}). Retrying with batch size {desired_batch_size}." + ) + # Re-queue the failed batch items for retry with smaller batch size + pending = batch + pending + continue + + # On success. + elapsed = time.monotonic() - start_time + consecutive_retryable_failures = 0 + processed += actual_batch_size + progress.update(total_task, advance=actual_batch_size) + + if elapsed <= config.target_batch_time_seconds: + if last_fast_batch_size is None or actual_batch_size > last_fast_batch_size: + last_fast_batch_size = actual_batch_size + if last_bad_batch_size is not None and actual_batch_size >= last_bad_batch_size: + # A fast success at the upper limit means the last_bad_batch_size is stale. + # We can resume unconstrained growth. + last_bad_batch_size = None + consecutive_fast_successes_near_upper_limit = 0 + desired_batch_size = _next_batch_after_success( + actual_batch_size, config, last_bad_batch_size + ) + elif _is_next_step_above_limit(actual_batch_size, last_bad_batch_size): + # Once the gap is smaller than our useful search resolution, + # hold the current known-good size and only re-probe the hint + # after a few stable fast successes. + consecutive_fast_successes_near_upper_limit += 1 + if consecutive_fast_successes_near_upper_limit >= SOFT_UPPER_LIMIT_RETRY_AFTER_SUCCESSES: + # We've had enough stable fast successes to re-probe the last_bad_batch_size. + desired_batch_size = last_bad_batch_size + consecutive_fast_successes_near_upper_limit = 0 + else: + # Hold current size for one more iteration + desired_batch_size = actual_batch_size + else: + # Binary search or unconstrained growth upwards + consecutive_fast_successes_near_upper_limit = 0 + desired_batch_size = _next_batch_after_success( + actual_batch_size, config, last_bad_batch_size + ) + else: + # Binary search downwards due to a slow batch + consecutive_fast_successes_near_upper_limit = 0 + logger.debug( + f"{self._progress_label} batch size {actual_batch_size} took {elapsed:.2f}s " + f"(target {config.target_batch_time_seconds:.2f}s); shrinking." + ) + last_fast_batch_size, last_bad_batch_size = _update_bounds_after_bad_batch( + actual_batch_size, last_fast_batch_size, last_bad_batch_size + ) + desired_batch_size = _next_batch_after_retryable_failure( + actual_batch_size, config, last_fast_batch_size, last_bad_batch_size + ) + + progress.update(total_task, completed=processed, total=processed, refresh=True) diff --git a/dagshub/common/config.py b/dagshub/common/config.py index e82b0063..2a9d9bd4 100644 --- a/dagshub/common/config.py +++ b/dagshub/common/config.py @@ -1,11 +1,12 @@ import logging - -import appdirs import os from urllib.parse import urlparse -from dagshub import __version__ + +import appdirs from httpx._client import USER_AGENT +from dagshub import __version__ + logger = logging.getLogger(__name__) HOST_KEY = "DAGSHUB_CLIENT_HOST" @@ -58,7 +59,39 @@ def set_host(new_host: str): recommended_annotate_limit = int(os.environ.get(RECOMMENDED_ANNOTATE_LIMIT_KEY, 1e5)) DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE" -dataengine_metadata_upload_batch_size = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, 15000)) +DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MAX" +DEFAULT_DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX = 50000 +# Fall back to the old `DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE` env var for backwards compatibility. +dataengine_metadata_upload_batch_size_max = int( + os.environ.get( + DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY, + os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, DEFAULT_DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX), + ) +) + +DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MIN" +dataengine_metadata_upload_batch_size_min = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY, 1)) + +DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_INITIAL_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_INITIAL" +dataengine_metadata_upload_batch_size_initial = int( + os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_INITIAL_KEY, dataengine_metadata_upload_batch_size_min) +) + +DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY = "DAGSHUB_DE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS" +dataengine_metadata_upload_target_batch_time_seconds = float( + os.environ.get(DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY, 5.0) +) + +DATAENGINE_METADATA_UPLOAD_BATCH_GROWTH_FACTOR_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_GROWTH_FACTOR" +adaptive_batch_growth_factor = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_GROWTH_FACTOR_KEY, 10)) + +DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_BASE_KEY = "DAGSHUB_DE_METADATA_UPLOAD_RETRY_BACKOFF_BASE" +adaptive_batch_retry_backoff_base_seconds = float( + os.environ.get(DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_BASE_KEY, 0.25) +) + +DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_MAX_KEY = "DAGSHUB_DE_METADATA_UPLOAD_RETRY_BACKOFF_MAX" +adaptive_batch_retry_backoff_max_seconds = float(os.environ.get(DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_MAX_KEY, 60.0)) DISABLE_ANALYTICS_KEY = "DAGSHUB_DISABLE_ANALYTICS" disable_analytics = "DAGSHUB_DISABLE_ANALYTICS" in os.environ diff --git a/dagshub/data_engine/client/data_client.py b/dagshub/data_engine/client/data_client.py index d5c61e2f..fa8bacb1 100644 --- a/dagshub/data_engine/client/data_client.py +++ b/dagshub/data_engine/client/data_client.py @@ -1,45 +1,43 @@ import datetime import logging -from typing import Any, Optional, List, Dict, Union, TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import dacite import gql import rich.progress -from gql.transport.exceptions import TransportQueryError, TransportServerError +from gql.transport.exceptions import TransportError from gql.transport.requests import RequestsHTTPTransport import dagshub.auth import dagshub.common.config from dagshub.common import config -from dagshub.common.tracing import build_traceparent from dagshub.common.analytics import send_analytics_event from dagshub.common.rich_util import get_rich_progress +from dagshub.common.tracing import build_traceparent from dagshub.data_engine.client.gql_introspections import GqlIntrospections, TypesIntrospection +from dagshub.data_engine.client.gql_mutations import GqlMutations +from dagshub.data_engine.client.gql_queries import GqlQueries from dagshub.data_engine.client.models import ( - DatasourceResult, + DatapointHistoryResult, DatasetResult, + DatasourceResult, MetadataFieldSchema, - DatapointHistoryResult, + ScanOption, ) -from dagshub.data_engine.client.models import ScanOption -from dagshub.data_engine.client.gql_mutations import GqlMutations -from dagshub.data_engine.client.gql_queries import GqlQueries from dagshub.data_engine.client.query_builder import GqlQuery from dagshub.data_engine.model.errors import DataEngineGqlError from dagshub.data_engine.model.query_result import QueryResult - - from dagshub.data_engine.model.schema_util import dacite_config if TYPE_CHECKING: from dagshub.data_engine.datasources import DatasourceState + from dagshub.data_engine.model.datapoint import Datapoint from dagshub.data_engine.model.datasource import ( - Datasource, - DatapointMetadataUpdateEntry, - DatapointDeleteMetadataEntry, DatapointDeleteEntry, + DatapointDeleteMetadataEntry, + DatapointMetadataUpdateEntry, + Datasource, ) - from dagshub.data_engine.model.datapoint import Datapoint logger = logging.getLogger(__name__) @@ -191,8 +189,8 @@ def _exec( traceparent = build_traceparent() headers["traceparent"] = traceparent try: - resp = self.client.execute(q, variable_values=params, extra_args={'headers': headers}) - except (TransportQueryError, TransportServerError) as e: + resp = self.client.execute(q, variable_values=params, extra_args={"headers": headers}) + except TransportError as e: support_id = self.client.transport.response_headers.get("X-DagsHub-Support-Id") if support_id is None: support_id = traceparent diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py index 255bb76d..94f78522 100644 --- a/dagshub/data_engine/model/datasource.py +++ b/dagshub/data_engine/model/datasource.py @@ -14,7 +14,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Literal, Optional, Set, Tuple, Union -import rich.progress from dataclasses_json import DataClassJsonMixin, LetterCase, config from pathvalidate import sanitize_filepath @@ -53,6 +52,8 @@ run_preupload_transforms, validate_uploading_metadata, ) +from dagshub.common.adaptive_batching import AdaptiveBatcher +from dagshub.data_engine.model.metadata.util import is_retryable_metadata_upload_error from dagshub.data_engine.model.metadata.dtypes import DatapointMetadataUpdateEntry from dagshub.data_engine.model.metadata.transforms import DatasourceFieldInfo, _add_metadata from dagshub.data_engine.model.metadata_field_builder import MetadataFieldBuilder @@ -753,19 +754,11 @@ def _upload_metadata(self, metadata_entries: List[DatapointMetadataUpdateEntry]) validate_uploading_metadata(precalculated_info) run_preupload_transforms(self, metadata_entries, precalculated_info) - progress = get_rich_progress(rich.progress.MofNCompleteColumn()) - - upload_batch_size = dagshub.common.config.dataengine_metadata_upload_batch_size - total_entries = len(metadata_entries) - total_task = progress.add_task(f"Uploading metadata (batch size {upload_batch_size})...", total=total_entries) - - with progress: - for start in range(0, total_entries, upload_batch_size): - entries = metadata_entries[start : start + upload_batch_size] - logger.debug(f"Uploading {len(entries)} metadata entries...") - self.source.client.update_metadata(self, entries) - progress.update(total_task, advance=upload_batch_size) - progress.update(total_task, completed=total_entries, refresh=True) + batcher = AdaptiveBatcher( + is_retryable=is_retryable_metadata_upload_error, + progress_label="Uploading metadata", + ) + batcher.run(metadata_entries, lambda batch: self.source.client.update_metadata(self, batch)) # Update the status from dagshub, so we get back the new metadata columns self.source.get_from_dagshub() diff --git a/dagshub/data_engine/model/metadata/util.py b/dagshub/data_engine/model/metadata/util.py index ff660324..c24bb23b 100644 --- a/dagshub/data_engine/model/metadata/util.py +++ b/dagshub/data_engine/model/metadata/util.py @@ -1,6 +1,12 @@ import datetime from typing import Optional +from gql.transport.exceptions import TransportError, TransportQueryError +from requests import ConnectionError as RequestsConnectionError +from requests import Timeout as RequestsTimeout + +from dagshub.data_engine.model.errors import DataEngineGqlError + def _get_datetime_utc_offset(t: datetime.datetime) -> Optional[str]: """ @@ -19,3 +25,18 @@ def _get_datetime_utc_offset(t: datetime.datetime) -> Optional[str]: offset_minutes = int((offset.total_seconds() % 3600) // 60) offset_str = f"{offset_hours:+03d}:{offset_minutes:02d}" return offset_str + + +def is_retryable_metadata_upload_error(exc: Exception) -> bool: + if isinstance(exc, DataEngineGqlError): + return is_retryable_metadata_upload_error(exc.original_exception) + + return (isinstance(exc, TransportError) and not isinstance(exc, TransportQueryError)) or isinstance( + exc, + ( + TimeoutError, + ConnectionError, + RequestsConnectionError, + RequestsTimeout, + ), + ) diff --git a/tests/common/test_adaptive_batching.py b/tests/common/test_adaptive_batching.py new file mode 100644 index 00000000..d5da7171 --- /dev/null +++ b/tests/common/test_adaptive_batching.py @@ -0,0 +1,412 @@ +from unittest.mock import patch + +import pytest + +from dagshub.common.adaptive_batching import ( + AdaptiveBatchConfig, + AdaptiveBatcher, + _clamp, + _get_retry_delay_seconds, + _is_next_step_above_limit, + _min_step_size, + _next_batch_after_retryable_failure, + _next_batch_after_success, +) + + +class RetryableTestError(Exception): + pass + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _cfg( + max_batch_size=1000, + min_batch_size=1, + initial_batch_size=10, + target_batch_time_seconds=5.0, + batch_growth_factor=10, + retry_backoff_base_seconds=0.25, + retry_backoff_max_seconds=4.0, +): + """Shortcut to build a config without going through from_values (avoids config import).""" + return AdaptiveBatchConfig( + max_batch_size=max_batch_size, + min_batch_size=min_batch_size, + initial_batch_size=initial_batch_size, + target_batch_time_seconds=target_batch_time_seconds, + batch_growth_factor=batch_growth_factor, + retry_backoff_base_seconds=retry_backoff_base_seconds, + retry_backoff_max_seconds=retry_backoff_max_seconds, + ) + + +# --------------------------------------------------------------------------- +# AdaptiveBatchConfig.from_values +# --------------------------------------------------------------------------- + + +class TestAdaptiveBatchConfigFromValues: + def test_normalizes_max_batch_size_to_at_least_1(self): + cfg = AdaptiveBatchConfig.from_values(max_batch_size=0, min_batch_size=1) + assert cfg.max_batch_size == 1 + + def test_clamps_min_to_max(self): + cfg = AdaptiveBatchConfig.from_values(max_batch_size=5, min_batch_size=100) + assert cfg.min_batch_size == 5 + + def test_clamps_initial_between_min_and_max(self): + cfg = AdaptiveBatchConfig.from_values(max_batch_size=100, min_batch_size=10, initial_batch_size=5) + assert cfg.initial_batch_size == 10 + + cfg2 = AdaptiveBatchConfig.from_values(max_batch_size=100, min_batch_size=10, initial_batch_size=200) + assert cfg2.initial_batch_size == 100 + + def test_batch_growth_factor_minimum_is_2(self): + cfg = AdaptiveBatchConfig.from_values(batch_growth_factor=1) + assert cfg.batch_growth_factor == 2 + + def test_backoff_seconds_non_negative(self): + cfg = AdaptiveBatchConfig.from_values(retry_backoff_base_seconds=-1.0, retry_backoff_max_seconds=-5.0) + assert cfg.retry_backoff_base_seconds == 0.0 + assert cfg.retry_backoff_max_seconds == 0.0 + + +class TestClamp: + def test_within_range(self): + assert _clamp(5, 1, 10) == 5 + + def test_below_minimum(self): + assert _clamp(0, 3, 10) == 3 + + def test_above_maximum(self): + assert _clamp(20, 1, 10) == 10 + + def test_equal_bounds(self): + assert _clamp(5, 7, 7) == 7 + + +# --------------------------------------------------------------------------- +# _next_batch_after_success +# --------------------------------------------------------------------------- + + +class TestNextBatchAfterSuccess: + def test_grows_by_growth_factor_when_no_bad_size(self): + cfg = _cfg(batch_growth_factor=10, max_batch_size=10000) + assert _next_batch_after_success(10, cfg, soft_upper_limit=None) == 100 + + def test_capped_at_max_batch_size(self): + cfg = _cfg(batch_growth_factor=10, max_batch_size=50) + assert _next_batch_after_success(10, cfg, soft_upper_limit=None) == 50 + + def test_binary_search_toward_bad_size(self): + cfg = _cfg(max_batch_size=10000) + result = _next_batch_after_success(10, cfg, soft_upper_limit=20) + assert 10 < result < 20 + + def test_stays_below_soft_upper_limit_when_midpoint_advances(self): + cfg = _cfg(max_batch_size=10000) + result = _next_batch_after_success(18, cfg, soft_upper_limit=20) + assert result <= 19 # bad_batch_size - 1 + + def test_respects_min_batch_size(self): + cfg = _cfg(min_batch_size=5, max_batch_size=10) + result = _next_batch_after_success(1, cfg, soft_upper_limit=None) + assert result >= 5 + + def test_holds_near_soft_upper_limit_before_reprobing(self): + cfg = _cfg(max_batch_size=1000, batch_growth_factor=2) + result = _next_batch_after_success(9, cfg, soft_upper_limit=10) + assert result == 9 + + def test_makes_progress_when_growth_factor_would_not_increase(self): + cfg = _cfg(batch_growth_factor=2, max_batch_size=100) + # batch_size=99, growth gives 198 capped to 100, which is > 99 so it works. + # But let's test the edge: batch_size at max-1 should reach max. + result = _next_batch_after_success(99, cfg, soft_upper_limit=None) + assert result == 100 + + +# --------------------------------------------------------------------------- +# _next_batch_after_retryable_failure +# --------------------------------------------------------------------------- + + +class TestNextBatchAfterRetryableFailure: + def test_halves_when_no_bounds(self): + cfg = _cfg(min_batch_size=1) + assert _next_batch_after_retryable_failure(100, cfg, None, None) == 50 + + def test_returns_min_when_at_min(self): + cfg = _cfg(min_batch_size=5) + assert _next_batch_after_retryable_failure(5, cfg, None, None) == 5 + + def test_binary_search_between_good_and_bad(self): + cfg = _cfg(min_batch_size=1) + result = _next_batch_after_retryable_failure(100, cfg, last_fast_batch_size=40, soft_upper_limit=100) + assert 40 < result < 100 + + def test_never_returns_below_min_batch_size(self): + cfg = _cfg(min_batch_size=10) + result = _next_batch_after_retryable_failure(20, cfg, None, None) + assert result >= 10 + + def test_strictly_decreases_from_current(self): + cfg = _cfg(min_batch_size=1) + for batch_size in [2, 5, 10, 50, 100, 1000]: + result = _next_batch_after_retryable_failure(batch_size, cfg, None, None) + assert result < batch_size + + +# --------------------------------------------------------------------------- +# _get_retry_delay_seconds +# --------------------------------------------------------------------------- + + +class TestGetRetryDelaySeconds: + def test_returns_base_for_first_failure(self): + cfg = _cfg(retry_backoff_base_seconds=0.25, retry_backoff_max_seconds=4.0) + delay = _get_retry_delay_seconds(1, cfg) + assert delay == pytest.approx(0.25, abs=0.01) + + def test_increases_with_more_failures(self): + cfg = _cfg(retry_backoff_base_seconds=0.25, retry_backoff_max_seconds=4.0) + d1 = _get_retry_delay_seconds(1, cfg) + d2 = _get_retry_delay_seconds(2, cfg) + d3 = _get_retry_delay_seconds(3, cfg) + assert d1 < d2 < d3 + + def test_capped_at_max(self): + cfg = _cfg(retry_backoff_base_seconds=0.25, retry_backoff_max_seconds=4.0) + delay = _get_retry_delay_seconds(100, cfg) + assert delay <= 4.0 + + def test_zero_failures_treated_as_one(self): + cfg = _cfg(retry_backoff_base_seconds=0.5, retry_backoff_max_seconds=10.0) + assert _get_retry_delay_seconds(0, cfg) == _get_retry_delay_seconds(1, cfg) + + +class TestSoftUpperHintResolution: + def test_min_step_size_scales_with_hint(self): + assert _min_step_size(10) == 1 + assert _min_step_size(1000) == 50 + + def test_holds_when_gap_is_within_resolution(self): + assert _is_next_step_above_limit(9, 10) + assert not _is_next_step_above_limit(8, 10) + assert not _is_next_step_above_limit(10, 10) + assert not _is_next_step_above_limit(9, None) + + +# --------------------------------------------------------------------------- +# AdaptiveBatcher.run — integration tests +# --------------------------------------------------------------------------- + + +class TestAdaptiveBatcherRun: + @staticmethod + def _make_batcher(**config_overrides): + cfg = _cfg( + **{ + **dict( + initial_batch_size=3, + max_batch_size=100, + min_batch_size=1, + target_batch_time_seconds=999, # fast enough to always grow + retry_backoff_base_seconds=0.0, + retry_backoff_max_seconds=0.0, + ), + **config_overrides, + } + ) + return AdaptiveBatcher( + is_retryable=lambda exc: isinstance(exc, RetryableTestError), + config=cfg, + ) + + def test_processes_all_items(self): + batcher = self._make_batcher() + received = [] + batcher.run(list(range(10)), lambda batch: received.extend(batch)) + assert received == list(range(10)) + + def test_empty_list(self): + batcher = self._make_batcher() + called = [] + batcher.run([], lambda batch: called.append(batch)) + assert called == [] + + def test_generator_input(self): + batcher = self._make_batcher() + received = [] + + def gen(): + for i in range(7): + yield i + + batcher.run(gen(), lambda batch: received.extend(batch)) + assert received == list(range(7)) + + def test_retries_on_retryable_error(self): + batcher = self._make_batcher(initial_batch_size=5, min_batch_size=1) + call_count = 0 + received = [] + + def op(batch): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RetryableTestError("transient") + received.extend(batch) + + batcher.run(list(range(5)), op) + assert received == list(range(5)) + assert call_count > 1 + + def test_aborts_on_non_retryable_error(self): + batcher = self._make_batcher() + + with pytest.raises(TypeError): + batcher.run(list(range(5)), lambda batch: (_ for _ in ()).throw(TypeError("fatal"))) + + def test_aborts_at_min_batch_size(self): + batcher = self._make_batcher(initial_batch_size=1, min_batch_size=1) + + with pytest.raises(RetryableTestError, match="always fails"): + batcher.run([1], lambda batch: (_ for _ in ()).throw(RetryableTestError("always fails"))) + + def test_short_tail_batch_retries_exact_size_once_before_aborting(self): + batcher = self._make_batcher(initial_batch_size=5, min_batch_size=4, max_batch_size=5) + attempts = 0 + + def op(batch): + nonlocal attempts + attempts += 1 + if attempts >= 3: + raise TypeError("retried too many times") + raise RetryableTestError("always fails") + + with pytest.raises(RetryableTestError, match="always fails"): + batcher.run([1, 2, 3], op) + + assert attempts == 2 + + def test_no_items_lost_on_retry(self): + """All items from a failed batch must be retried.""" + batcher = self._make_batcher(initial_batch_size=4, min_batch_size=1) + fail_once = True + all_received = [] + + def op(batch): + nonlocal fail_once + if fail_once and len(batch) == 4: + fail_once = False + raise RetryableTestError("fail big batch once") + all_received.extend(batch) + + items = list(range(8)) + batcher.run(items, op) + assert all_received == items + + def test_generator_retry_no_items_lost(self): + """Items from a failed batch are retried even with generator input.""" + batcher = self._make_batcher(initial_batch_size=3, min_batch_size=1) + fail_once = True + all_received = [] + + def op(batch): + nonlocal fail_once + if fail_once: + fail_once = False + raise RetryableTestError("transient") + all_received.extend(batch) + + def gen(): + for i in range(6): + yield i + + batcher.run(gen(), op) + assert all_received == list(range(6)) + + def test_batch_size_shrinks_on_failure(self): + batcher = self._make_batcher(initial_batch_size=10, min_batch_size=1) + batch_sizes = [] + + def op(batch): + batch_sizes.append(len(batch)) + if batch_sizes[-1] == 10: + raise RetryableTestError("too big") + + batcher.run(list(range(20)), op) + # First call is size 10 (fails), next should be smaller + assert batch_sizes[0] == 10 + assert batch_sizes[1] < 10 + + @patch("dagshub.common.adaptive_batching.time") + def test_batch_size_grows_on_fast_success(self, mock_time): + # Make monotonic() return increasing values, but elapsed always < target + mock_time.monotonic.side_effect = [0.0, 0.001] * 50 + mock_time.sleep = lambda _: None + + batcher = self._make_batcher( + initial_batch_size=2, + max_batch_size=100, + target_batch_time_seconds=5.0, + ) + batch_sizes = [] + batcher.run(list(range(100)), lambda batch: batch_sizes.append(len(batch))) + # Should grow from initial=2 + assert max(batch_sizes) > 2 + + @patch("dagshub.common.adaptive_batching.time") + def test_batch_size_shrinks_on_slow_success(self, mock_time): + # Make elapsed always > target (slow batches) + mock_time.monotonic.side_effect = [0.0, 100.0] * 100 + mock_time.sleep = lambda _: None + + batcher = self._make_batcher( + initial_batch_size=20, + min_batch_size=1, + max_batch_size=100, + target_batch_time_seconds=1.0, + ) + received = [] + batch_sizes = [] + + def op(batch): + batch_sizes.append(len(batch)) + received.extend(batch) + + items = list(range(50)) + batcher.run(items, op) + # All items processed despite slow batches + assert received == items + # Batch size should shrink from 20 + assert batch_sizes[0] == 20 + assert min(batch_sizes) < 20 + + def test_reprobes_soft_upper_limit_after_stable_fast_successes(self): + batcher = self._make_batcher( + initial_batch_size=10, + min_batch_size=1, + max_batch_size=1000, + ) + fail_once = True + batch_sizes = [] + + def op(batch): + nonlocal fail_once + batch_sizes.append(len(batch)) + if fail_once: + fail_once = False + raise RetryableTestError("transient") + + items = list(range(200)) + batcher.run(items, op) + assert batch_sizes[:4] == [10, 5, 7, 8] + assert batch_sizes[4:7] == [9, 9, 9] + assert 10 in batch_sizes[7:] diff --git a/tests/data_engine/test_datasource.py b/tests/data_engine/test_datasource.py index bd1f1912..e6f6e0dc 100644 --- a/tests/data_engine/test_datasource.py +++ b/tests/data_engine/test_datasource.py @@ -8,15 +8,21 @@ import pandas as pd import pytest +import dagshub.common.config from dagshub.common.util import wrap_bytes from dagshub.data_engine.annotation import MetadataAnnotations from dagshub.data_engine.client.models import MetadataFieldSchema from dagshub.data_engine.dtypes import MetadataFieldType, ReservedTags from dagshub.data_engine.model.datapoint import Datapoint -from dagshub.data_engine.model.datasource import Datasource, DatapointMetadataUpdateEntry, MetadataContextManager +from dagshub.data_engine.model.datasource import DatapointMetadataUpdateEntry, Datasource, MetadataContextManager +from dagshub.data_engine.model.errors import DataEngineGqlError from dagshub.data_engine.model.metadata import MultipleDataTypesUploadedError, StringFieldValueTooLongError from dagshub.data_engine.model.query_result import QueryResult -from tests.data_engine.util import add_string_fields, add_document_fields, add_annotation_fields +from tests.data_engine.util import add_annotation_fields, add_document_fields, add_string_fields + + +def _uploaded_batch_sizes(ds: Datasource): + return [len(call.args[1]) for call in ds.source.client.update_metadata.call_args_list] @pytest.fixture @@ -142,6 +148,64 @@ def test_uploading_to_document_turns_into_blob(ds): client_mock.update_metadata.assert_called_with(ds, expected_data_upload) +def test_upload_metadata_starts_small_and_grows(ds, mocker): + entries = [DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(14)] + + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 16) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 2) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) + mocker.patch.object(dagshub.common.config, "adaptive_batch_growth_factor", 8) + + ds._upload_metadata(entries) + + batch_sizes = _uploaded_batch_sizes(ds) + assert batch_sizes[0] == 2 + assert max(batch_sizes) > 2 + assert sum(batch_sizes) == len(entries) + + +def test_upload_metadata_retries_wrapped_retryable_error_with_smaller_batch(ds, mocker): + entries = [DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10)] + + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) + mocker.patch("dagshub.common.adaptive_batching.time.sleep", return_value=None) + + has_failed = {"value": False} + + def _flaky_upload(_ds, upload_entries): + if len(upload_entries) == 8 and not has_failed["value"]: + has_failed["value"] = True + raise DataEngineGqlError(TimeoutError("simulated timeout"), support_id="test-support-id") + + ds.source.client.update_metadata.side_effect = _flaky_upload + + ds._upload_metadata(entries) + + batch_sizes = _uploaded_batch_sizes(ds) + assert has_failed["value"] + assert batch_sizes[0] == 8 + assert any(size < 8 for size in batch_sizes[1:]) + + +def test_upload_metadata_non_retryable_error_does_not_retry(ds, mocker): + entries = [DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10)] + + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) + ds.source.client.update_metadata.side_effect = ValueError("simulated validation error") + + with pytest.raises(ValueError, match="simulated validation error"): + ds._upload_metadata(entries) + + assert _uploaded_batch_sizes(ds) == [8] + + def test_pandas_timestamp(ds): data_dict = { "file": ["test1", "test2"],