From 84632f252d0197def277cca18389c85247154874 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Mon, 23 Feb 2026 12:18:16 +0200 Subject: [PATCH 01/30] Add adaptive metadata upload batch sizing --- dagshub/common/config.py | 13 +++ dagshub/data_engine/model/datasource.py | 102 ++++++++++++++++++++++-- tests/data_engine/test_datasource.py | 45 +++++++++++ 3 files changed, 153 insertions(+), 7 deletions(-) diff --git a/dagshub/common/config.py b/dagshub/common/config.py index e82b0063..60084863 100644 --- a/dagshub/common/config.py +++ b/dagshub/common/config.py @@ -60,6 +60,19 @@ def set_host(new_host: str): DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE" dataengine_metadata_upload_batch_size = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, 15000)) +DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MIN" +dataengine_metadata_upload_batch_size_min = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY, 150)) + +DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_INITIAL_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_INITIAL" +dataengine_metadata_upload_batch_size_initial = int( + os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_INITIAL_KEY, dataengine_metadata_upload_batch_size_min) +) + +DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_KEY = "DAGSHUB_DE_METADATA_UPLOAD_TARGET_BATCH_TIME" +dataengine_metadata_upload_target_batch_time = float( + os.environ.get(DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_KEY, 5.0) +) + DISABLE_ANALYTICS_KEY = "DAGSHUB_DISABLE_ANALYTICS" disable_analytics = "DAGSHUB_DISABLE_ANALYTICS" in os.environ diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py index 255bb76d..491edcda 100644 --- a/dagshub/data_engine/model/datasource.py +++ b/dagshub/data_engine/model/datasource.py @@ -755,16 +755,104 @@ def _upload_metadata(self, metadata_entries: List[DatapointMetadataUpdateEntry]) progress = get_rich_progress(rich.progress.MofNCompleteColumn()) - upload_batch_size = dagshub.common.config.dataengine_metadata_upload_batch_size + max_batch_size = max(1, dagshub.common.config.dataengine_metadata_upload_batch_size) + min_batch_size = max( + 1, + min(dagshub.common.config.dataengine_metadata_upload_batch_size_min, max_batch_size), + ) + current_batch_size = max( + min_batch_size, + min(dagshub.common.config.dataengine_metadata_upload_batch_size_initial, max_batch_size), + ) + target_batch_time = max(dagshub.common.config.dataengine_metadata_upload_target_batch_time, 0.01) + + def _next_batch_after_success(batch_size: int, bad_batch_size: Optional[int]) -> int: + # Keep expanding quickly until we find an upper bound, then binary-search between good and bad. + if bad_batch_size is not None and batch_size < bad_batch_size: + next_batch_size = batch_size + max(1, (bad_batch_size - batch_size) // 2) + else: + next_batch_size = batch_size * 2 + + next_batch_size = min(max_batch_size, next_batch_size) + if next_batch_size <= batch_size and batch_size < max_batch_size: + next_batch_size = batch_size + 1 + return max(min_batch_size, next_batch_size) + + def _next_batch_after_bad( + batch_size: int, + good_batch_size: Optional[int], + bad_batch_size: Optional[int], + ) -> int: + upper_bound = bad_batch_size if bad_batch_size is not None else batch_size + + if good_batch_size is not None and good_batch_size < upper_bound: + next_batch_size = good_batch_size + max(1, (upper_bound - good_batch_size) // 2) + else: + next_batch_size = upper_bound // 2 + + next_batch_size = max(min_batch_size, min(max_batch_size, next_batch_size)) + if next_batch_size >= batch_size: + next_batch_size = max(min_batch_size, batch_size - 1) + return next_batch_size + total_entries = len(metadata_entries) - total_task = progress.add_task(f"Uploading metadata (batch size {upload_batch_size})...", total=total_entries) + total_task = progress.add_task( + f"Uploading metadata (adaptive batch {current_batch_size}-{max_batch_size})...", + total=total_entries, + ) + + last_good_batch_size: Optional[int] = None + last_bad_batch_size: Optional[int] = None with progress: - for start in range(0, total_entries, upload_batch_size): - entries = metadata_entries[start : start + upload_batch_size] - logger.debug(f"Uploading {len(entries)} metadata entries...") - self.source.client.update_metadata(self, entries) - progress.update(total_task, advance=upload_batch_size) + start = 0 + while start < total_entries: + entries_left = total_entries - start + batch_size = min(current_batch_size, entries_left) + entries = metadata_entries[start : start + batch_size] + + progress.update( + total_task, + description=f"Uploading metadata (batch size {batch_size})...", + ) + logger.debug(f"Uploading {batch_size} metadata entries...") + + start_time = time.monotonic() + try: + self.source.client.update_metadata(self, entries) + except Exception as exc: + if batch_size <= min_batch_size: + logger.error( + f"Metadata upload failed at minimum batch size ({min_batch_size}); aborting.", + exc_info=True, + ) + raise + + last_bad_batch_size = ( + batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) + ) + current_batch_size = _next_batch_after_bad(batch_size, last_good_batch_size, last_bad_batch_size) + logger.warning( + f"Metadata upload failed for batch size {batch_size} " + f"({exc.__class__.__name__}: {exc}). Retrying with batch size {current_batch_size}." + ) + continue + + elapsed = time.monotonic() - start_time + start += batch_size + progress.update(total_task, advance=batch_size) + + if elapsed <= target_batch_time: + last_good_batch_size = ( + batch_size if last_good_batch_size is None else max(last_good_batch_size, batch_size) + ) + current_batch_size = _next_batch_after_success(batch_size, last_bad_batch_size) + else: + last_bad_batch_size = ( + batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) + ) + current_batch_size = _next_batch_after_bad(batch_size, last_good_batch_size, last_bad_batch_size) + progress.update(total_task, completed=total_entries, refresh=True) # Update the status from dagshub, so we get back the new metadata columns diff --git a/tests/data_engine/test_datasource.py b/tests/data_engine/test_datasource.py index bd1f1912..49359f59 100644 --- a/tests/data_engine/test_datasource.py +++ b/tests/data_engine/test_datasource.py @@ -8,6 +8,7 @@ import pandas as pd import pytest +import dagshub.common.config from dagshub.common.util import wrap_bytes from dagshub.data_engine.annotation import MetadataAnnotations from dagshub.data_engine.client.models import MetadataFieldSchema @@ -19,6 +20,10 @@ from tests.data_engine.util import add_string_fields, add_document_fields, add_annotation_fields +def _uploaded_batch_sizes(ds: Datasource): + return [len(call.args[1]) for call in ds.source.client.update_metadata.call_args_list] + + @pytest.fixture def metadata_df(): data_dict = { @@ -142,6 +147,46 @@ def test_uploading_to_document_turns_into_blob(ds): client_mock.update_metadata.assert_called_with(ds, expected_data_upload) +def test_upload_metadata_starts_small_and_grows(ds, mocker): + entries = [ + DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(14) + ] + + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 16) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 2) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1000.0) + + ds._upload_metadata(entries) + + assert _uploaded_batch_sizes(ds) == [2, 4, 8] + + +def test_upload_metadata_retries_with_smaller_batch_after_failure(ds, mocker): + entries = [ + DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10) + ] + + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1000.0) + + has_failed = {"value": False} + + def _flaky_upload(_ds, upload_entries): + if len(upload_entries) == 8 and not has_failed["value"]: + has_failed["value"] = True + raise RuntimeError("simulated timeout") + + ds.source.client.update_metadata.side_effect = _flaky_upload + + ds._upload_metadata(entries) + + assert has_failed["value"] + assert _uploaded_batch_sizes(ds) == [8, 4, 6] + + def test_pandas_timestamp(ds): data_dict = { "file": ["test1", "test2"], From ded500751edf69b9f66d9dc0ca53a5426c66f211 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Tue, 3 Mar 2026 16:12:46 +0200 Subject: [PATCH 02/30] Handle non-retryable metadata upload failures --- dagshub/data_engine/model/datasource.py | 23 ++++++++++++++++ tests/data_engine/test_datasource.py | 35 ++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py index 491edcda..d9b20475 100644 --- a/dagshub/data_engine/model/datasource.py +++ b/dagshub/data_engine/model/datasource.py @@ -16,7 +16,10 @@ import rich.progress from dataclasses_json import DataClassJsonMixin, LetterCase, config +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from pathvalidate import sanitize_filepath +from requests.exceptions import ConnectionError as RequestsConnectionError +from requests.exceptions import Timeout as RequestsTimeout import dagshub.common.config from dagshub.common import rich_console @@ -42,6 +45,7 @@ from dagshub.data_engine.model.datapoint import Datapoint from dagshub.data_engine.model.datasource_state import DatasourceState from dagshub.data_engine.model.errors import ( + DataEngineGqlError, DatasetFieldComparisonError, DatasetNotFoundError, FieldNotFoundError, @@ -795,6 +799,21 @@ def _next_batch_after_bad( next_batch_size = max(min_batch_size, batch_size - 1) return next_batch_size + def _is_retryable_upload_error(exc: Exception) -> bool: + if isinstance(exc, DataEngineGqlError): + return isinstance(exc.original_exception, (TransportServerError, TransportConnectionFailed)) + return isinstance( + exc, + ( + TransportServerError, + TransportConnectionFailed, + TimeoutError, + ConnectionError, + RequestsConnectionError, + RequestsTimeout, + ), + ) + total_entries = len(metadata_entries) total_task = progress.add_task( f"Uploading metadata (adaptive batch {current_batch_size}-{max_batch_size})...", @@ -821,6 +840,10 @@ def _next_batch_after_bad( try: self.source.client.update_metadata(self, entries) except Exception as exc: + if not _is_retryable_upload_error(exc): + logger.error("Metadata upload failed with a non-retryable error; aborting.", exc_info=True) + raise + if batch_size <= min_batch_size: logger.error( f"Metadata upload failed at minimum batch size ({min_batch_size}); aborting.", diff --git a/tests/data_engine/test_datasource.py b/tests/data_engine/test_datasource.py index 49359f59..a25741fb 100644 --- a/tests/data_engine/test_datasource.py +++ b/tests/data_engine/test_datasource.py @@ -177,7 +177,7 @@ def test_upload_metadata_retries_with_smaller_batch_after_failure(ds, mocker): def _flaky_upload(_ds, upload_entries): if len(upload_entries) == 8 and not has_failed["value"]: has_failed["value"] = True - raise RuntimeError("simulated timeout") + raise TimeoutError("simulated timeout") ds.source.client.update_metadata.side_effect = _flaky_upload @@ -187,6 +187,39 @@ def _flaky_upload(_ds, upload_entries): assert _uploaded_batch_sizes(ds) == [8, 4, 6] +def test_upload_metadata_slow_success_reduces_batch_size(ds, mocker): + entries = [ + DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(12) + ] + + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1.0) + mocker.patch("dagshub.data_engine.model.datasource.time.monotonic", side_effect=[0.0, 2.0, 3.0, 3.1]) + + ds._upload_metadata(entries) + + assert _uploaded_batch_sizes(ds) == [8, 4] + + +def test_upload_metadata_non_retryable_error_does_not_retry(ds, mocker): + entries = [ + DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10) + ] + + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1000.0) + ds.source.client.update_metadata.side_effect = ValueError("simulated validation error") + + with pytest.raises(ValueError, match="simulated validation error"): + ds._upload_metadata(entries) + + assert _uploaded_batch_sizes(ds) == [8] + + def test_pandas_timestamp(ds): data_dict = { "file": ["test1", "test2"], From 567c8306f060d4d7cb1b56a3ce856fef0ac25b47 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Tue, 3 Mar 2026 16:20:13 +0200 Subject: [PATCH 03/30] Allow retry shrink on partial metadata batches --- dagshub/data_engine/model/datasource.py | 11 ++++++++--- tests/data_engine/test_datasource.py | 25 +++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py index d9b20475..da44c944 100644 --- a/dagshub/data_engine/model/datasource.py +++ b/dagshub/data_engine/model/datasource.py @@ -787,6 +787,11 @@ def _next_batch_after_bad( good_batch_size: Optional[int], bad_batch_size: Optional[int], ) -> int: + # If we're already below the configured minimum (for example, last partial chunk), + # keep shrinking until we reach 1. + if batch_size <= min_batch_size: + return max(1, batch_size - 1) + upper_bound = bad_batch_size if bad_batch_size is not None else batch_size if good_batch_size is not None and good_batch_size < upper_bound: @@ -796,7 +801,7 @@ def _next_batch_after_bad( next_batch_size = max(min_batch_size, min(max_batch_size, next_batch_size)) if next_batch_size >= batch_size: - next_batch_size = max(min_batch_size, batch_size - 1) + next_batch_size = max(1, batch_size - 1) return next_batch_size def _is_retryable_upload_error(exc: Exception) -> bool: @@ -844,9 +849,9 @@ def _is_retryable_upload_error(exc: Exception) -> bool: logger.error("Metadata upload failed with a non-retryable error; aborting.", exc_info=True) raise - if batch_size <= min_batch_size: + if batch_size <= 1: logger.error( - f"Metadata upload failed at minimum batch size ({min_batch_size}); aborting.", + "Metadata upload failed at minimum possible batch size (1); aborting.", exc_info=True, ) raise diff --git a/tests/data_engine/test_datasource.py b/tests/data_engine/test_datasource.py index a25741fb..f059df6c 100644 --- a/tests/data_engine/test_datasource.py +++ b/tests/data_engine/test_datasource.py @@ -220,6 +220,31 @@ def test_upload_metadata_non_retryable_error_does_not_retry(ds, mocker): assert _uploaded_batch_sizes(ds) == [8] +def test_upload_metadata_retries_partial_batch_below_min(ds, mocker): + entries = [ + DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10) + ] + + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 4) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1000.0) + + has_failed = {"value": False} + + def _flaky_upload(_ds, upload_entries): + if len(upload_entries) == 2 and not has_failed["value"]: + has_failed["value"] = True + raise TimeoutError("simulated timeout") + + ds.source.client.update_metadata.side_effect = _flaky_upload + + ds._upload_metadata(entries) + + assert has_failed["value"] + assert _uploaded_batch_sizes(ds) == [8, 2, 1, 1] + + def test_pandas_timestamp(ds): data_dict = { "file": ["test1", "test2"], From 4a767ab55ea32d3e09a1fb746ca079bfd2fc47d4 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Tue, 3 Mar 2026 16:40:04 +0200 Subject: [PATCH 04/30] Add retry backoff for metadata upload failures --- dagshub/data_engine/model/datasource.py | 6 ++++++ tests/data_engine/test_datasource.py | 25 +++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py index da44c944..9bf7b223 100644 --- a/dagshub/data_engine/model/datasource.py +++ b/dagshub/data_engine/model/datasource.py @@ -827,6 +827,7 @@ def _is_retryable_upload_error(exc: Exception) -> bool: last_good_batch_size: Optional[int] = None last_bad_batch_size: Optional[int] = None + consecutive_retryable_failures = 0 with progress: start = 0 @@ -856,6 +857,10 @@ def _is_retryable_upload_error(exc: Exception) -> bool: ) raise + consecutive_retryable_failures += 1 + retry_delay_sec = min(5.0, 0.25 * (2 ** min(consecutive_retryable_failures - 1, 4))) + time.sleep(retry_delay_sec) + last_bad_batch_size = ( batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) ) @@ -867,6 +872,7 @@ def _is_retryable_upload_error(exc: Exception) -> bool: continue elapsed = time.monotonic() - start_time + consecutive_retryable_failures = 0 start += batch_size progress.update(total_task, advance=batch_size) diff --git a/tests/data_engine/test_datasource.py b/tests/data_engine/test_datasource.py index f059df6c..63cc4962 100644 --- a/tests/data_engine/test_datasource.py +++ b/tests/data_engine/test_datasource.py @@ -245,6 +245,31 @@ def _flaky_upload(_ds, upload_entries): assert _uploaded_batch_sizes(ds) == [8, 2, 1, 1] +def test_upload_metadata_backoff_resets_after_success(ds, mocker): + entries = [ + DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(12) + ] + + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1000.0) + sleep_mock = mocker.patch("dagshub.data_engine.model.datasource.time.sleep") + + call_idx = {"value": 0} + + def _flaky_upload(_ds, _upload_entries): + call_idx["value"] += 1 + if call_idx["value"] in {1, 3}: + raise TimeoutError("simulated timeout") + + ds.source.client.update_metadata.side_effect = _flaky_upload + + ds._upload_metadata(entries) + + assert [c.args[0] for c in sleep_mock.call_args_list] == [0.25, 0.25] + + def test_pandas_timestamp(ds): data_dict = { "file": ["test1", "test2"], From 0f5e9d74dfc4dbfbb2355e5bd0aac25ab01070b2 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Tue, 3 Mar 2026 17:50:47 +0200 Subject: [PATCH 05/30] Honor min batch size on retry failures --- dagshub/data_engine/model/datasource.py | 8 ++++---- tests/data_engine/test_datasource.py | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py index 9bf7b223..af639b25 100644 --- a/dagshub/data_engine/model/datasource.py +++ b/dagshub/data_engine/model/datasource.py @@ -789,7 +789,7 @@ def _next_batch_after_bad( ) -> int: # If we're already below the configured minimum (for example, last partial chunk), # keep shrinking until we reach 1. - if batch_size <= min_batch_size: + if batch_size < min_batch_size: return max(1, batch_size - 1) upper_bound = bad_batch_size if bad_batch_size is not None else batch_size @@ -801,7 +801,7 @@ def _next_batch_after_bad( next_batch_size = max(min_batch_size, min(max_batch_size, next_batch_size)) if next_batch_size >= batch_size: - next_batch_size = max(1, batch_size - 1) + next_batch_size = max(min_batch_size, batch_size - 1) return next_batch_size def _is_retryable_upload_error(exc: Exception) -> bool: @@ -850,9 +850,9 @@ def _is_retryable_upload_error(exc: Exception) -> bool: logger.error("Metadata upload failed with a non-retryable error; aborting.", exc_info=True) raise - if batch_size <= 1: + if batch_size <= 1 or batch_size == min_batch_size: logger.error( - "Metadata upload failed at minimum possible batch size (1); aborting.", + f"Metadata upload failed at minimum batch size ({batch_size}); aborting.", exc_info=True, ) raise diff --git a/tests/data_engine/test_datasource.py b/tests/data_engine/test_datasource.py index 63cc4962..446928d3 100644 --- a/tests/data_engine/test_datasource.py +++ b/tests/data_engine/test_datasource.py @@ -270,6 +270,25 @@ def _flaky_upload(_ds, _upload_entries): assert [c.args[0] for c in sleep_mock.call_args_list] == [0.25, 0.25] +def test_upload_metadata_aborts_on_failure_at_min_batch(ds, mocker): + entries = [ + DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(6) + ] + + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 2) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1000.0) + sleep_mock = mocker.patch("dagshub.data_engine.model.datasource.time.sleep") + ds.source.client.update_metadata.side_effect = TimeoutError("simulated timeout") + + with pytest.raises(TimeoutError, match="simulated timeout"): + ds._upload_metadata(entries) + + assert _uploaded_batch_sizes(ds) == [2] + sleep_mock.assert_not_called() + + def test_pandas_timestamp(ds): data_dict = { "file": ["test1", "test2"], From 82e4503f4d01236818a76470cd9f07da8ca3bc97 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Wed, 4 Mar 2026 00:34:31 +0200 Subject: [PATCH 06/30] Mock retry backoff sleep in upload tests --- tests/data_engine/test_datasource.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/data_engine/test_datasource.py b/tests/data_engine/test_datasource.py index 446928d3..27b567e2 100644 --- a/tests/data_engine/test_datasource.py +++ b/tests/data_engine/test_datasource.py @@ -171,6 +171,7 @@ def test_upload_metadata_retries_with_smaller_batch_after_failure(ds, mocker): mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1000.0) + mocker.patch("dagshub.data_engine.model.datasource.time.sleep", return_value=None) has_failed = {"value": False} @@ -229,6 +230,7 @@ def test_upload_metadata_retries_partial_batch_below_min(ds, mocker): mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 4) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1000.0) + mocker.patch("dagshub.data_engine.model.datasource.time.sleep", return_value=None) has_failed = {"value": False} From 996b7fc44d690afde8aede5d5baaeaa43081c9aa Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Wed, 4 Mar 2026 00:57:07 +0200 Subject: [PATCH 07/30] Avoid reusing known-bad upload batch size --- dagshub/data_engine/model/datasource.py | 4 ++++ tests/data_engine/test_datasource.py | 26 +++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py index af639b25..edf31754 100644 --- a/dagshub/data_engine/model/datasource.py +++ b/dagshub/data_engine/model/datasource.py @@ -778,8 +778,12 @@ def _next_batch_after_success(batch_size: int, bad_batch_size: Optional[int]) -> next_batch_size = batch_size * 2 next_batch_size = min(max_batch_size, next_batch_size) + if bad_batch_size is not None and bad_batch_size > min_batch_size and next_batch_size >= bad_batch_size: + next_batch_size = bad_batch_size - 1 if next_batch_size <= batch_size and batch_size < max_batch_size: next_batch_size = batch_size + 1 + if bad_batch_size is not None and bad_batch_size > min_batch_size and next_batch_size >= bad_batch_size: + next_batch_size = bad_batch_size - 1 return max(min_batch_size, next_batch_size) def _next_batch_after_bad( diff --git a/tests/data_engine/test_datasource.py b/tests/data_engine/test_datasource.py index 27b567e2..1733c162 100644 --- a/tests/data_engine/test_datasource.py +++ b/tests/data_engine/test_datasource.py @@ -188,6 +188,32 @@ def _flaky_upload(_ds, upload_entries): assert _uploaded_batch_sizes(ds) == [8, 4, 6] +def test_upload_metadata_does_not_retry_known_bad_batch_size(ds, mocker): + entries = [ + DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(32) + ] + + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 16) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1000.0) + mocker.patch("dagshub.data_engine.model.datasource.time.sleep", return_value=None) + + has_failed = {"value": False} + + def _flaky_upload(_ds, upload_entries): + if len(upload_entries) == 8 and not has_failed["value"]: + has_failed["value"] = True + raise TimeoutError("simulated timeout") + + ds.source.client.update_metadata.side_effect = _flaky_upload + + ds._upload_metadata(entries) + + assert has_failed["value"] + assert _uploaded_batch_sizes(ds) == [8, 4, 6, 7, 7, 7, 1] + + def test_upload_metadata_slow_success_reduces_batch_size(ds, mocker): entries = [ DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(12) From 2338c949fd6b68c38417a3e1d6bac0a2ba8111db Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Wed, 4 Mar 2026 10:51:40 +0200 Subject: [PATCH 08/30] Extract metadata upload retry backoff constants --- dagshub/data_engine/model/datasource.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py index edf31754..553b068f 100644 --- a/dagshub/data_engine/model/datasource.py +++ b/dagshub/data_engine/model/datasource.py @@ -90,6 +90,10 @@ MLFLOW_DATASOURCE_TAG_NAME = "dagshub.datasets.datasource_id" MLFLOW_DATASET_TAG_NAME = "dagshub.datasets.dataset_id" +METADATA_UPLOAD_RETRY_BACKOFF_BASE_SECONDS = 0.25 +METADATA_UPLOAD_RETRY_BACKOFF_MAX_SECONDS = 5.0 +METADATA_UPLOAD_RETRY_BACKOFF_EXPONENT_CAP = 4 + @dataclass class DatapointDeleteMetadataEntry(DataClassJsonMixin): @@ -862,7 +866,12 @@ def _is_retryable_upload_error(exc: Exception) -> bool: raise consecutive_retryable_failures += 1 - retry_delay_sec = min(5.0, 0.25 * (2 ** min(consecutive_retryable_failures - 1, 4))) + # Bounded exponential backoff: 0.25s, 0.5s, 1s, 2s, 4s, then capped at 5s. + retry_delay_sec = min( + METADATA_UPLOAD_RETRY_BACKOFF_MAX_SECONDS, + METADATA_UPLOAD_RETRY_BACKOFF_BASE_SECONDS + * (2 ** min(consecutive_retryable_failures - 1, METADATA_UPLOAD_RETRY_BACKOFF_EXPONENT_CAP)), + ) time.sleep(retry_delay_sec) last_bad_batch_size = ( From b884e9ea3867711daeac64ad5c28a8e29e9b6a01 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Wed, 4 Mar 2026 11:02:51 +0200 Subject: [PATCH 09/30] Align metadata upload backoff cap with schedule --- dagshub/data_engine/model/datasource.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py index 553b068f..37491e51 100644 --- a/dagshub/data_engine/model/datasource.py +++ b/dagshub/data_engine/model/datasource.py @@ -91,7 +91,7 @@ MLFLOW_DATASET_TAG_NAME = "dagshub.datasets.dataset_id" METADATA_UPLOAD_RETRY_BACKOFF_BASE_SECONDS = 0.25 -METADATA_UPLOAD_RETRY_BACKOFF_MAX_SECONDS = 5.0 +METADATA_UPLOAD_RETRY_BACKOFF_MAX_SECONDS = 4.0 METADATA_UPLOAD_RETRY_BACKOFF_EXPONENT_CAP = 4 @@ -866,7 +866,7 @@ def _is_retryable_upload_error(exc: Exception) -> bool: raise consecutive_retryable_failures += 1 - # Bounded exponential backoff: 0.25s, 0.5s, 1s, 2s, 4s, then capped at 5s. + # Bounded exponential backoff: 0.25s, 0.5s, 1s, 2s, then capped at 4s. retry_delay_sec = min( METADATA_UPLOAD_RETRY_BACKOFF_MAX_SECONDS, METADATA_UPLOAD_RETRY_BACKOFF_BASE_SECONDS From 4c1132c69820588bc6e6eef73212e7884676bd9f Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Thu, 5 Mar 2026 16:20:41 +0200 Subject: [PATCH 10/30] Refactor adaptive metadata upload sizing and retries --- dagshub/common/config.py | 20 ++- dagshub/data_engine/model/datasource.py | 114 +++++------------ .../model/metadata/upload_batching.py | 116 ++++++++++++++++++ tests/data_engine/test_datasource.py | 26 ++-- 4 files changed, 176 insertions(+), 100 deletions(-) create mode 100644 dagshub/data_engine/model/metadata/upload_batching.py diff --git a/dagshub/common/config.py b/dagshub/common/config.py index 60084863..9a2e75f6 100644 --- a/dagshub/common/config.py +++ b/dagshub/common/config.py @@ -58,10 +58,16 @@ def set_host(new_host: str): recommended_annotate_limit = int(os.environ.get(RECOMMENDED_ANNOTATE_LIMIT_KEY, 1e5)) DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE" -dataengine_metadata_upload_batch_size = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, 15000)) +DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MAX" +dataengine_metadata_upload_batch_size = int( + os.environ.get( + DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY, + os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, 15000), + ) +) DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MIN" -dataengine_metadata_upload_batch_size_min = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY, 150)) +dataengine_metadata_upload_batch_size_min = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY, 1)) DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_INITIAL_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_INITIAL" dataengine_metadata_upload_batch_size_initial = int( @@ -69,9 +75,15 @@ def set_host(new_host: str): ) DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_KEY = "DAGSHUB_DE_METADATA_UPLOAD_TARGET_BATCH_TIME" -dataengine_metadata_upload_target_batch_time = float( - os.environ.get(DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_KEY, 5.0) +DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY = "DAGSHUB_DE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS" +dataengine_metadata_upload_target_batch_time_seconds = float( + os.environ.get( + DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY, + os.environ.get(DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_KEY, 5.0), + ) ) +# Backwards compatibility for code that imports the old module attribute name. +dataengine_metadata_upload_target_batch_time = dataengine_metadata_upload_target_batch_time_seconds DISABLE_ANALYTICS_KEY = "DAGSHUB_DISABLE_ANALYTICS" disable_analytics = "DAGSHUB_DISABLE_ANALYTICS" in os.environ diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py index 37491e51..447f58ab 100644 --- a/dagshub/data_engine/model/datasource.py +++ b/dagshub/data_engine/model/datasource.py @@ -16,10 +16,7 @@ import rich.progress from dataclasses_json import DataClassJsonMixin, LetterCase, config -from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from pathvalidate import sanitize_filepath -from requests.exceptions import ConnectionError as RequestsConnectionError -from requests.exceptions import Timeout as RequestsTimeout import dagshub.common.config from dagshub.common import rich_console @@ -45,7 +42,6 @@ from dagshub.data_engine.model.datapoint import Datapoint from dagshub.data_engine.model.datasource_state import DatasourceState from dagshub.data_engine.model.errors import ( - DataEngineGqlError, DatasetFieldComparisonError, DatasetNotFoundError, FieldNotFoundError, @@ -57,6 +53,13 @@ run_preupload_transforms, validate_uploading_metadata, ) +from dagshub.data_engine.model.metadata.upload_batching import ( + AdaptiveUploadBatchConfig, + get_retry_delay_seconds, + is_retryable_metadata_upload_error, + next_batch_after_retryable_failure, + next_batch_after_success, +) from dagshub.data_engine.model.metadata.dtypes import DatapointMetadataUpdateEntry from dagshub.data_engine.model.metadata.transforms import DatasourceFieldInfo, _add_metadata from dagshub.data_engine.model.metadata_field_builder import MetadataFieldBuilder @@ -90,10 +93,6 @@ MLFLOW_DATASOURCE_TAG_NAME = "dagshub.datasets.datasource_id" MLFLOW_DATASET_TAG_NAME = "dagshub.datasets.dataset_id" -METADATA_UPLOAD_RETRY_BACKOFF_BASE_SECONDS = 0.25 -METADATA_UPLOAD_RETRY_BACKOFF_MAX_SECONDS = 4.0 -METADATA_UPLOAD_RETRY_BACKOFF_EXPONENT_CAP = 4 - @dataclass class DatapointDeleteMetadataEntry(DataClassJsonMixin): @@ -763,73 +762,17 @@ def _upload_metadata(self, metadata_entries: List[DatapointMetadataUpdateEntry]) progress = get_rich_progress(rich.progress.MofNCompleteColumn()) - max_batch_size = max(1, dagshub.common.config.dataengine_metadata_upload_batch_size) - min_batch_size = max( - 1, - min(dagshub.common.config.dataengine_metadata_upload_batch_size_min, max_batch_size), - ) - current_batch_size = max( - min_batch_size, - min(dagshub.common.config.dataengine_metadata_upload_batch_size_initial, max_batch_size), + batch_config = AdaptiveUploadBatchConfig.from_values( + max_batch_size=dagshub.common.config.dataengine_metadata_upload_batch_size, + min_batch_size=dagshub.common.config.dataengine_metadata_upload_batch_size_min, + initial_batch_size=dagshub.common.config.dataengine_metadata_upload_batch_size_initial, + target_batch_time_seconds=dagshub.common.config.dataengine_metadata_upload_target_batch_time_seconds, ) - target_batch_time = max(dagshub.common.config.dataengine_metadata_upload_target_batch_time, 0.01) - - def _next_batch_after_success(batch_size: int, bad_batch_size: Optional[int]) -> int: - # Keep expanding quickly until we find an upper bound, then binary-search between good and bad. - if bad_batch_size is not None and batch_size < bad_batch_size: - next_batch_size = batch_size + max(1, (bad_batch_size - batch_size) // 2) - else: - next_batch_size = batch_size * 2 - - next_batch_size = min(max_batch_size, next_batch_size) - if bad_batch_size is not None and bad_batch_size > min_batch_size and next_batch_size >= bad_batch_size: - next_batch_size = bad_batch_size - 1 - if next_batch_size <= batch_size and batch_size < max_batch_size: - next_batch_size = batch_size + 1 - if bad_batch_size is not None and bad_batch_size > min_batch_size and next_batch_size >= bad_batch_size: - next_batch_size = bad_batch_size - 1 - return max(min_batch_size, next_batch_size) - - def _next_batch_after_bad( - batch_size: int, - good_batch_size: Optional[int], - bad_batch_size: Optional[int], - ) -> int: - # If we're already below the configured minimum (for example, last partial chunk), - # keep shrinking until we reach 1. - if batch_size < min_batch_size: - return max(1, batch_size - 1) - - upper_bound = bad_batch_size if bad_batch_size is not None else batch_size - - if good_batch_size is not None and good_batch_size < upper_bound: - next_batch_size = good_batch_size + max(1, (upper_bound - good_batch_size) // 2) - else: - next_batch_size = upper_bound // 2 - - next_batch_size = max(min_batch_size, min(max_batch_size, next_batch_size)) - if next_batch_size >= batch_size: - next_batch_size = max(min_batch_size, batch_size - 1) - return next_batch_size - - def _is_retryable_upload_error(exc: Exception) -> bool: - if isinstance(exc, DataEngineGqlError): - return isinstance(exc.original_exception, (TransportServerError, TransportConnectionFailed)) - return isinstance( - exc, - ( - TransportServerError, - TransportConnectionFailed, - TimeoutError, - ConnectionError, - RequestsConnectionError, - RequestsTimeout, - ), - ) + current_batch_size = batch_config.initial_batch_size total_entries = len(metadata_entries) total_task = progress.add_task( - f"Uploading metadata (adaptive batch {current_batch_size}-{max_batch_size})...", + f"Uploading metadata (adaptive batch {batch_config.min_batch_size}-{batch_config.max_batch_size})...", total=total_entries, ) @@ -854,11 +797,11 @@ def _is_retryable_upload_error(exc: Exception) -> bool: try: self.source.client.update_metadata(self, entries) except Exception as exc: - if not _is_retryable_upload_error(exc): + if not is_retryable_metadata_upload_error(exc): logger.error("Metadata upload failed with a non-retryable error; aborting.", exc_info=True) raise - if batch_size <= 1 or batch_size == min_batch_size: + if batch_size <= 1: logger.error( f"Metadata upload failed at minimum batch size ({batch_size}); aborting.", exc_info=True, @@ -866,18 +809,18 @@ def _is_retryable_upload_error(exc: Exception) -> bool: raise consecutive_retryable_failures += 1 - # Bounded exponential backoff: 0.25s, 0.5s, 1s, 2s, then capped at 4s. - retry_delay_sec = min( - METADATA_UPLOAD_RETRY_BACKOFF_MAX_SECONDS, - METADATA_UPLOAD_RETRY_BACKOFF_BASE_SECONDS - * (2 ** min(consecutive_retryable_failures - 1, METADATA_UPLOAD_RETRY_BACKOFF_EXPONENT_CAP)), - ) + retry_delay_sec = get_retry_delay_seconds(consecutive_retryable_failures) time.sleep(retry_delay_sec) last_bad_batch_size = ( batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) ) - current_batch_size = _next_batch_after_bad(batch_size, last_good_batch_size, last_bad_batch_size) + current_batch_size = next_batch_after_retryable_failure( + batch_size, + batch_config, + last_good_batch_size, + last_bad_batch_size, + ) logger.warning( f"Metadata upload failed for batch size {batch_size} " f"({exc.__class__.__name__}: {exc}). Retrying with batch size {current_batch_size}." @@ -889,16 +832,21 @@ def _is_retryable_upload_error(exc: Exception) -> bool: start += batch_size progress.update(total_task, advance=batch_size) - if elapsed <= target_batch_time: + if elapsed <= batch_config.target_batch_time_seconds: last_good_batch_size = ( batch_size if last_good_batch_size is None else max(last_good_batch_size, batch_size) ) - current_batch_size = _next_batch_after_success(batch_size, last_bad_batch_size) + current_batch_size = next_batch_after_success(batch_size, batch_config, last_bad_batch_size) else: last_bad_batch_size = ( batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) ) - current_batch_size = _next_batch_after_bad(batch_size, last_good_batch_size, last_bad_batch_size) + current_batch_size = next_batch_after_retryable_failure( + batch_size, + batch_config, + last_good_batch_size, + last_bad_batch_size, + ) progress.update(total_task, completed=total_entries, refresh=True) diff --git a/dagshub/data_engine/model/metadata/upload_batching.py b/dagshub/data_engine/model/metadata/upload_batching.py new file mode 100644 index 00000000..6394ff03 --- /dev/null +++ b/dagshub/data_engine/model/metadata/upload_batching.py @@ -0,0 +1,116 @@ +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Optional + +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError +from requests.exceptions import ConnectionError as RequestsConnectionError +from requests.exceptions import Timeout as RequestsTimeout +from tenacity import wait_exponential + +from dagshub.data_engine.model.errors import DataEngineGqlError + +MIN_TARGET_BATCH_TIME_SECONDS = 0.01 +BATCH_GROWTH_FACTOR = 10 +RETRY_BACKOFF_BASE_SECONDS = 0.25 +RETRY_BACKOFF_MAX_SECONDS = 4.0 + +_retry_delay_strategy = wait_exponential( + multiplier=RETRY_BACKOFF_BASE_SECONDS, + min=RETRY_BACKOFF_BASE_SECONDS, + max=RETRY_BACKOFF_MAX_SECONDS, +) + + +@dataclass(frozen=True) +class AdaptiveUploadBatchConfig: + max_batch_size: int + min_batch_size: int + initial_batch_size: int + target_batch_time_seconds: float + + @classmethod + def from_values( + cls, + max_batch_size: int, + min_batch_size: int, + initial_batch_size: int, + target_batch_time_seconds: float, + ) -> "AdaptiveUploadBatchConfig": + normalized_max_batch_size = max(1, max_batch_size) + normalized_min_batch_size = max(1, min(min_batch_size, normalized_max_batch_size)) + normalized_initial_batch_size = max( + normalized_min_batch_size, + min(initial_batch_size, normalized_max_batch_size), + ) + normalized_target_batch_time_seconds = max(target_batch_time_seconds, MIN_TARGET_BATCH_TIME_SECONDS) + return cls( + max_batch_size=normalized_max_batch_size, + min_batch_size=normalized_min_batch_size, + initial_batch_size=normalized_initial_batch_size, + target_batch_time_seconds=normalized_target_batch_time_seconds, + ) + + +def _midpoint(lower_bound: int, upper_bound: int) -> int: + return lower_bound + max(1, (upper_bound - lower_bound) // 2) + + +def next_batch_after_success( + batch_size: int, + config: AdaptiveUploadBatchConfig, + bad_batch_size: Optional[int], +) -> int: + if bad_batch_size is not None and batch_size < bad_batch_size: + next_batch_size = _midpoint(batch_size, bad_batch_size) + next_batch_size = min(next_batch_size, bad_batch_size - 1) + else: + next_batch_size = batch_size * BATCH_GROWTH_FACTOR + + next_batch_size = min(config.max_batch_size, next_batch_size) + if next_batch_size <= batch_size and batch_size < config.max_batch_size: + next_batch_size = min(config.max_batch_size, batch_size + 1) + if bad_batch_size is not None: + next_batch_size = min(next_batch_size, bad_batch_size - 1) + + return max(config.min_batch_size, next_batch_size) + + +def next_batch_after_retryable_failure( + batch_size: int, + config: AdaptiveUploadBatchConfig, + good_batch_size: Optional[int], + bad_batch_size: Optional[int], +) -> int: + if batch_size <= 1: + return 1 + + upper_bound = min(batch_size, bad_batch_size) if bad_batch_size is not None else batch_size + if good_batch_size is not None and good_batch_size < upper_bound: + next_batch_size = _midpoint(good_batch_size, upper_bound) + else: + next_batch_size = batch_size // 2 + + next_batch_size = min(next_batch_size, upper_bound - 1, batch_size - 1, config.max_batch_size) + return max(1, next_batch_size) + + +def is_retryable_metadata_upload_error(exc: Exception) -> bool: + if isinstance(exc, DataEngineGqlError): + return isinstance(exc.original_exception, (TransportServerError, TransportConnectionFailed)) + + return isinstance( + exc, + ( + TransportServerError, + TransportConnectionFailed, + TimeoutError, + ConnectionError, + RequestsConnectionError, + RequestsTimeout, + ), + ) + + +def get_retry_delay_seconds(consecutive_retryable_failures: int) -> float: + retry_state = SimpleNamespace(attempt_number=max(1, consecutive_retryable_failures)) + return float(_retry_delay_strategy(retry_state)) diff --git a/tests/data_engine/test_datasource.py b/tests/data_engine/test_datasource.py index 1733c162..7e09801b 100644 --- a/tests/data_engine/test_datasource.py +++ b/tests/data_engine/test_datasource.py @@ -155,11 +155,11 @@ def test_upload_metadata_starts_small_and_grows(ds, mocker): mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 16) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 2) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1000.0) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) ds._upload_metadata(entries) - assert _uploaded_batch_sizes(ds) == [2, 4, 8] + assert _uploaded_batch_sizes(ds) == [2, 12] def test_upload_metadata_retries_with_smaller_batch_after_failure(ds, mocker): @@ -170,7 +170,7 @@ def test_upload_metadata_retries_with_smaller_batch_after_failure(ds, mocker): mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1000.0) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) mocker.patch("dagshub.data_engine.model.datasource.time.sleep", return_value=None) has_failed = {"value": False} @@ -196,7 +196,7 @@ def test_upload_metadata_does_not_retry_known_bad_batch_size(ds, mocker): mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 16) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1000.0) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) mocker.patch("dagshub.data_engine.model.datasource.time.sleep", return_value=None) has_failed = {"value": False} @@ -211,7 +211,7 @@ def _flaky_upload(_ds, upload_entries): ds._upload_metadata(entries) assert has_failed["value"] - assert _uploaded_batch_sizes(ds) == [8, 4, 6, 7, 7, 7, 1] + assert _uploaded_batch_sizes(ds) == [8, 4, 6, 7, 7, 1] def test_upload_metadata_slow_success_reduces_batch_size(ds, mocker): @@ -222,7 +222,7 @@ def test_upload_metadata_slow_success_reduces_batch_size(ds, mocker): mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1.0) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1.0) mocker.patch("dagshub.data_engine.model.datasource.time.monotonic", side_effect=[0.0, 2.0, 3.0, 3.1]) ds._upload_metadata(entries) @@ -238,7 +238,7 @@ def test_upload_metadata_non_retryable_error_does_not_retry(ds, mocker): mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1000.0) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) ds.source.client.update_metadata.side_effect = ValueError("simulated validation error") with pytest.raises(ValueError, match="simulated validation error"): @@ -255,7 +255,7 @@ def test_upload_metadata_retries_partial_batch_below_min(ds, mocker): mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 4) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1000.0) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) mocker.patch("dagshub.data_engine.model.datasource.time.sleep", return_value=None) has_failed = {"value": False} @@ -281,7 +281,7 @@ def test_upload_metadata_backoff_resets_after_success(ds, mocker): mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1000.0) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) sleep_mock = mocker.patch("dagshub.data_engine.model.datasource.time.sleep") call_idx = {"value": 0} @@ -298,7 +298,7 @@ def _flaky_upload(_ds, _upload_entries): assert [c.args[0] for c in sleep_mock.call_args_list] == [0.25, 0.25] -def test_upload_metadata_aborts_on_failure_at_min_batch(ds, mocker): +def test_upload_metadata_retries_below_configured_min_before_aborting(ds, mocker): entries = [ DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(6) ] @@ -306,15 +306,15 @@ def test_upload_metadata_aborts_on_failure_at_min_batch(ds, mocker): mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 2) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time", 1000.0) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) sleep_mock = mocker.patch("dagshub.data_engine.model.datasource.time.sleep") ds.source.client.update_metadata.side_effect = TimeoutError("simulated timeout") with pytest.raises(TimeoutError, match="simulated timeout"): ds._upload_metadata(entries) - assert _uploaded_batch_sizes(ds) == [2] - sleep_mock.assert_not_called() + assert _uploaded_batch_sizes(ds) == [2, 1] + assert [c.args[0] for c in sleep_mock.call_args_list] == [0.25] def test_pandas_timestamp(ds): From 6ae0a5828a5b5f753f0c1984c6f93661ac38df8c Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Thu, 5 Mar 2026 16:34:43 +0200 Subject: [PATCH 11/30] Fix adaptive upload expected batch sequence test --- tests/data_engine/test_datasource.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data_engine/test_datasource.py b/tests/data_engine/test_datasource.py index 7e09801b..5fe2ea00 100644 --- a/tests/data_engine/test_datasource.py +++ b/tests/data_engine/test_datasource.py @@ -211,7 +211,7 @@ def _flaky_upload(_ds, upload_entries): ds._upload_metadata(entries) assert has_failed["value"] - assert _uploaded_batch_sizes(ds) == [8, 4, 6, 7, 7, 1] + assert _uploaded_batch_sizes(ds) == [8, 4, 6, 7, 7, 7, 1] def test_upload_metadata_slow_success_reduces_batch_size(ds, mocker): From 1b6356bab95955f02a726b8e026cfa69175b9141 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Tue, 10 Mar 2026 15:21:06 +0200 Subject: [PATCH 12/30] Extract AdaptiveBatcher into dagshub.common for reuse Move generic adaptive batching logic (config, size adaptation, retry, progress) out of data_engine.model.metadata into dagshub.common.adaptive_batching. Config defaults come from common.config so callers don't need boilerplate. Domain-specific is_retryable_metadata_upload_error stays in metadata.util. --- dagshub/common/adaptive_batching.py | 211 ++++++++++++++++++ dagshub/common/config.py | 10 +- dagshub/data_engine/model/datasource.py | 102 +-------- .../model/metadata/upload_batching.py | 116 ---------- dagshub/data_engine/model/metadata/util.py | 21 ++ 5 files changed, 239 insertions(+), 221 deletions(-) create mode 100644 dagshub/common/adaptive_batching.py delete mode 100644 dagshub/data_engine/model/metadata/upload_batching.py diff --git a/dagshub/common/adaptive_batching.py b/dagshub/common/adaptive_batching.py new file mode 100644 index 00000000..91ff4f90 --- /dev/null +++ b/dagshub/common/adaptive_batching.py @@ -0,0 +1,211 @@ +import logging +import time +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Callable, List, Optional, TypeVar + +import rich.progress +from tenacity import wait_exponential + +from dagshub.common.rich_util import get_rich_progress + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + +MIN_TARGET_BATCH_TIME_SECONDS = 0.01 +BATCH_GROWTH_FACTOR = 10 +RETRY_BACKOFF_BASE_SECONDS = 0.25 +RETRY_BACKOFF_MAX_SECONDS = 4.0 + +_retry_delay_strategy = wait_exponential( + multiplier=RETRY_BACKOFF_BASE_SECONDS, + min=RETRY_BACKOFF_BASE_SECONDS, + max=RETRY_BACKOFF_MAX_SECONDS, +) + + +@dataclass(frozen=True) +class AdaptiveBatchConfig: + max_batch_size: int + min_batch_size: int + initial_batch_size: int + target_batch_time_seconds: float + + @classmethod + def from_values( + cls, + max_batch_size: Optional[int] = None, + min_batch_size: Optional[int] = None, + initial_batch_size: Optional[int] = None, + target_batch_time_seconds: Optional[float] = None, + ) -> "AdaptiveBatchConfig": + import dagshub.common.config as dgs_config + + if max_batch_size is None: + max_batch_size = dgs_config.dataengine_metadata_upload_batch_size + if min_batch_size is None: + min_batch_size = dgs_config.dataengine_metadata_upload_batch_size_min + if initial_batch_size is None: + initial_batch_size = dgs_config.dataengine_metadata_upload_batch_size_initial + if target_batch_time_seconds is None: + target_batch_time_seconds = dgs_config.dataengine_metadata_upload_target_batch_time_seconds + + normalized_max_batch_size = max(1, max_batch_size) + normalized_min_batch_size = max(1, min(min_batch_size, normalized_max_batch_size)) + normalized_initial_batch_size = max( + normalized_min_batch_size, + min(initial_batch_size, normalized_max_batch_size), + ) + normalized_target_batch_time_seconds = max(target_batch_time_seconds, MIN_TARGET_BATCH_TIME_SECONDS) + return cls( + max_batch_size=normalized_max_batch_size, + min_batch_size=normalized_min_batch_size, + initial_batch_size=normalized_initial_batch_size, + target_batch_time_seconds=normalized_target_batch_time_seconds, + ) + + +def _midpoint(lower_bound: int, upper_bound: int) -> int: + return lower_bound + max(1, (upper_bound - lower_bound) // 2) + + +def _next_batch_after_success( + batch_size: int, + config: AdaptiveBatchConfig, + bad_batch_size: Optional[int], +) -> int: + if bad_batch_size is not None and batch_size < bad_batch_size: + next_batch_size = _midpoint(batch_size, bad_batch_size) + next_batch_size = min(next_batch_size, bad_batch_size - 1) + else: + next_batch_size = batch_size * BATCH_GROWTH_FACTOR + + next_batch_size = min(config.max_batch_size, next_batch_size) + if next_batch_size <= batch_size and batch_size < config.max_batch_size: + next_batch_size = min(config.max_batch_size, batch_size + 1) + if bad_batch_size is not None: + next_batch_size = min(next_batch_size, bad_batch_size - 1) + + return max(config.min_batch_size, next_batch_size) + + +def _next_batch_after_retryable_failure( + batch_size: int, + config: AdaptiveBatchConfig, + good_batch_size: Optional[int], + bad_batch_size: Optional[int], +) -> int: + if batch_size <= 1: + return 1 + + upper_bound = min(batch_size, bad_batch_size) if bad_batch_size is not None else batch_size + if good_batch_size is not None and good_batch_size < upper_bound: + next_batch_size = _midpoint(good_batch_size, upper_bound) + else: + next_batch_size = batch_size // 2 + + next_batch_size = min(next_batch_size, upper_bound - 1, batch_size - 1, config.max_batch_size) + return max(1, next_batch_size) + + +def _get_retry_delay_seconds(consecutive_retryable_failures: int) -> float: + retry_state = SimpleNamespace(attempt_number=max(1, consecutive_retryable_failures)) + return float(_retry_delay_strategy(retry_state)) + + +class AdaptiveBatcher: + """Sends items in adaptively-sized batches, growing on success and shrinking on failure.""" + + def __init__( + self, + is_retryable: Callable[[Exception], bool], + config: Optional[AdaptiveBatchConfig] = None, + progress_label: str = "Uploading", + ): + self._config = config if config is not None else AdaptiveBatchConfig.from_values() + self._is_retryable = is_retryable + self._progress_label = progress_label + + def run(self, items: List[T], operation: Callable[[List[T]], None]) -> None: + total = len(items) + if total == 0: + return + + config = self._config + current_batch_size = config.initial_batch_size + + progress = get_rich_progress(rich.progress.MofNCompleteColumn()) + total_task = progress.add_task( + f"{self._progress_label} (adaptive batch {config.min_batch_size}-{config.max_batch_size})...", + total=total, + ) + + last_good_batch_size: Optional[int] = None + last_bad_batch_size: Optional[int] = None + consecutive_retryable_failures = 0 + + with progress: + start = 0 + while start < total: + batch_size = min(current_batch_size, total - start) + batch = items[start : start + batch_size] + + progress.update( + total_task, + description=f"{self._progress_label} (batch size {batch_size})...", + ) + logger.debug(f"{self._progress_label}: {batch_size} entries...") + + start_time = time.monotonic() + try: + operation(batch) + except Exception as exc: + if not self._is_retryable(exc): + logger.error( + f"{self._progress_label} failed with a non-retryable error; aborting.", + exc_info=True, + ) + raise + + if batch_size <= 1: + logger.error( + f"{self._progress_label} failed at minimum batch size ({batch_size}); aborting.", + exc_info=True, + ) + raise + + consecutive_retryable_failures += 1 + time.sleep(_get_retry_delay_seconds(consecutive_retryable_failures)) + + last_bad_batch_size = ( + batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) + ) + current_batch_size = _next_batch_after_retryable_failure( + batch_size, config, last_good_batch_size, last_bad_batch_size + ) + logger.warning( + f"{self._progress_label} failed for batch size {batch_size} " + f"({exc.__class__.__name__}: {exc}). Retrying with batch size {current_batch_size}." + ) + continue + + elapsed = time.monotonic() - start_time + consecutive_retryable_failures = 0 + start += batch_size + progress.update(total_task, advance=batch_size) + + if elapsed <= config.target_batch_time_seconds: + last_good_batch_size = ( + batch_size if last_good_batch_size is None else max(last_good_batch_size, batch_size) + ) + current_batch_size = _next_batch_after_success(batch_size, config, last_bad_batch_size) + else: + last_bad_batch_size = ( + batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) + ) + current_batch_size = _next_batch_after_retryable_failure( + batch_size, config, last_good_batch_size, last_bad_batch_size + ) + + progress.update(total_task, completed=total, refresh=True) diff --git a/dagshub/common/config.py b/dagshub/common/config.py index 9a2e75f6..826d1f21 100644 --- a/dagshub/common/config.py +++ b/dagshub/common/config.py @@ -74,16 +74,8 @@ def set_host(new_host: str): os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_INITIAL_KEY, dataengine_metadata_upload_batch_size_min) ) -DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_KEY = "DAGSHUB_DE_METADATA_UPLOAD_TARGET_BATCH_TIME" DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY = "DAGSHUB_DE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS" -dataengine_metadata_upload_target_batch_time_seconds = float( - os.environ.get( - DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY, - os.environ.get(DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_KEY, 5.0), - ) -) -# Backwards compatibility for code that imports the old module attribute name. -dataengine_metadata_upload_target_batch_time = dataengine_metadata_upload_target_batch_time_seconds +dataengine_metadata_upload_target_batch_time_seconds = float(os.environ.get(DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY, 5.0)) DISABLE_ANALYTICS_KEY = "DAGSHUB_DISABLE_ANALYTICS" disable_analytics = "DAGSHUB_DISABLE_ANALYTICS" in os.environ diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py index 447f58ab..94f78522 100644 --- a/dagshub/data_engine/model/datasource.py +++ b/dagshub/data_engine/model/datasource.py @@ -14,7 +14,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Literal, Optional, Set, Tuple, Union -import rich.progress from dataclasses_json import DataClassJsonMixin, LetterCase, config from pathvalidate import sanitize_filepath @@ -53,13 +52,8 @@ run_preupload_transforms, validate_uploading_metadata, ) -from dagshub.data_engine.model.metadata.upload_batching import ( - AdaptiveUploadBatchConfig, - get_retry_delay_seconds, - is_retryable_metadata_upload_error, - next_batch_after_retryable_failure, - next_batch_after_success, -) +from dagshub.common.adaptive_batching import AdaptiveBatcher +from dagshub.data_engine.model.metadata.util import is_retryable_metadata_upload_error from dagshub.data_engine.model.metadata.dtypes import DatapointMetadataUpdateEntry from dagshub.data_engine.model.metadata.transforms import DatasourceFieldInfo, _add_metadata from dagshub.data_engine.model.metadata_field_builder import MetadataFieldBuilder @@ -760,95 +754,11 @@ def _upload_metadata(self, metadata_entries: List[DatapointMetadataUpdateEntry]) validate_uploading_metadata(precalculated_info) run_preupload_transforms(self, metadata_entries, precalculated_info) - progress = get_rich_progress(rich.progress.MofNCompleteColumn()) - - batch_config = AdaptiveUploadBatchConfig.from_values( - max_batch_size=dagshub.common.config.dataengine_metadata_upload_batch_size, - min_batch_size=dagshub.common.config.dataengine_metadata_upload_batch_size_min, - initial_batch_size=dagshub.common.config.dataengine_metadata_upload_batch_size_initial, - target_batch_time_seconds=dagshub.common.config.dataengine_metadata_upload_target_batch_time_seconds, - ) - current_batch_size = batch_config.initial_batch_size - - total_entries = len(metadata_entries) - total_task = progress.add_task( - f"Uploading metadata (adaptive batch {batch_config.min_batch_size}-{batch_config.max_batch_size})...", - total=total_entries, + batcher = AdaptiveBatcher( + is_retryable=is_retryable_metadata_upload_error, + progress_label="Uploading metadata", ) - - last_good_batch_size: Optional[int] = None - last_bad_batch_size: Optional[int] = None - consecutive_retryable_failures = 0 - - with progress: - start = 0 - while start < total_entries: - entries_left = total_entries - start - batch_size = min(current_batch_size, entries_left) - entries = metadata_entries[start : start + batch_size] - - progress.update( - total_task, - description=f"Uploading metadata (batch size {batch_size})...", - ) - logger.debug(f"Uploading {batch_size} metadata entries...") - - start_time = time.monotonic() - try: - self.source.client.update_metadata(self, entries) - except Exception as exc: - if not is_retryable_metadata_upload_error(exc): - logger.error("Metadata upload failed with a non-retryable error; aborting.", exc_info=True) - raise - - if batch_size <= 1: - logger.error( - f"Metadata upload failed at minimum batch size ({batch_size}); aborting.", - exc_info=True, - ) - raise - - consecutive_retryable_failures += 1 - retry_delay_sec = get_retry_delay_seconds(consecutive_retryable_failures) - time.sleep(retry_delay_sec) - - last_bad_batch_size = ( - batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) - ) - current_batch_size = next_batch_after_retryable_failure( - batch_size, - batch_config, - last_good_batch_size, - last_bad_batch_size, - ) - logger.warning( - f"Metadata upload failed for batch size {batch_size} " - f"({exc.__class__.__name__}: {exc}). Retrying with batch size {current_batch_size}." - ) - continue - - elapsed = time.monotonic() - start_time - consecutive_retryable_failures = 0 - start += batch_size - progress.update(total_task, advance=batch_size) - - if elapsed <= batch_config.target_batch_time_seconds: - last_good_batch_size = ( - batch_size if last_good_batch_size is None else max(last_good_batch_size, batch_size) - ) - current_batch_size = next_batch_after_success(batch_size, batch_config, last_bad_batch_size) - else: - last_bad_batch_size = ( - batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) - ) - current_batch_size = next_batch_after_retryable_failure( - batch_size, - batch_config, - last_good_batch_size, - last_bad_batch_size, - ) - - progress.update(total_task, completed=total_entries, refresh=True) + batcher.run(metadata_entries, lambda batch: self.source.client.update_metadata(self, batch)) # Update the status from dagshub, so we get back the new metadata columns self.source.get_from_dagshub() diff --git a/dagshub/data_engine/model/metadata/upload_batching.py b/dagshub/data_engine/model/metadata/upload_batching.py deleted file mode 100644 index 6394ff03..00000000 --- a/dagshub/data_engine/model/metadata/upload_batching.py +++ /dev/null @@ -1,116 +0,0 @@ -from dataclasses import dataclass -from types import SimpleNamespace -from typing import Optional - -from gql.transport.exceptions import TransportConnectionFailed, TransportServerError -from requests.exceptions import ConnectionError as RequestsConnectionError -from requests.exceptions import Timeout as RequestsTimeout -from tenacity import wait_exponential - -from dagshub.data_engine.model.errors import DataEngineGqlError - -MIN_TARGET_BATCH_TIME_SECONDS = 0.01 -BATCH_GROWTH_FACTOR = 10 -RETRY_BACKOFF_BASE_SECONDS = 0.25 -RETRY_BACKOFF_MAX_SECONDS = 4.0 - -_retry_delay_strategy = wait_exponential( - multiplier=RETRY_BACKOFF_BASE_SECONDS, - min=RETRY_BACKOFF_BASE_SECONDS, - max=RETRY_BACKOFF_MAX_SECONDS, -) - - -@dataclass(frozen=True) -class AdaptiveUploadBatchConfig: - max_batch_size: int - min_batch_size: int - initial_batch_size: int - target_batch_time_seconds: float - - @classmethod - def from_values( - cls, - max_batch_size: int, - min_batch_size: int, - initial_batch_size: int, - target_batch_time_seconds: float, - ) -> "AdaptiveUploadBatchConfig": - normalized_max_batch_size = max(1, max_batch_size) - normalized_min_batch_size = max(1, min(min_batch_size, normalized_max_batch_size)) - normalized_initial_batch_size = max( - normalized_min_batch_size, - min(initial_batch_size, normalized_max_batch_size), - ) - normalized_target_batch_time_seconds = max(target_batch_time_seconds, MIN_TARGET_BATCH_TIME_SECONDS) - return cls( - max_batch_size=normalized_max_batch_size, - min_batch_size=normalized_min_batch_size, - initial_batch_size=normalized_initial_batch_size, - target_batch_time_seconds=normalized_target_batch_time_seconds, - ) - - -def _midpoint(lower_bound: int, upper_bound: int) -> int: - return lower_bound + max(1, (upper_bound - lower_bound) // 2) - - -def next_batch_after_success( - batch_size: int, - config: AdaptiveUploadBatchConfig, - bad_batch_size: Optional[int], -) -> int: - if bad_batch_size is not None and batch_size < bad_batch_size: - next_batch_size = _midpoint(batch_size, bad_batch_size) - next_batch_size = min(next_batch_size, bad_batch_size - 1) - else: - next_batch_size = batch_size * BATCH_GROWTH_FACTOR - - next_batch_size = min(config.max_batch_size, next_batch_size) - if next_batch_size <= batch_size and batch_size < config.max_batch_size: - next_batch_size = min(config.max_batch_size, batch_size + 1) - if bad_batch_size is not None: - next_batch_size = min(next_batch_size, bad_batch_size - 1) - - return max(config.min_batch_size, next_batch_size) - - -def next_batch_after_retryable_failure( - batch_size: int, - config: AdaptiveUploadBatchConfig, - good_batch_size: Optional[int], - bad_batch_size: Optional[int], -) -> int: - if batch_size <= 1: - return 1 - - upper_bound = min(batch_size, bad_batch_size) if bad_batch_size is not None else batch_size - if good_batch_size is not None and good_batch_size < upper_bound: - next_batch_size = _midpoint(good_batch_size, upper_bound) - else: - next_batch_size = batch_size // 2 - - next_batch_size = min(next_batch_size, upper_bound - 1, batch_size - 1, config.max_batch_size) - return max(1, next_batch_size) - - -def is_retryable_metadata_upload_error(exc: Exception) -> bool: - if isinstance(exc, DataEngineGqlError): - return isinstance(exc.original_exception, (TransportServerError, TransportConnectionFailed)) - - return isinstance( - exc, - ( - TransportServerError, - TransportConnectionFailed, - TimeoutError, - ConnectionError, - RequestsConnectionError, - RequestsTimeout, - ), - ) - - -def get_retry_delay_seconds(consecutive_retryable_failures: int) -> float: - retry_state = SimpleNamespace(attempt_number=max(1, consecutive_retryable_failures)) - return float(_retry_delay_strategy(retry_state)) diff --git a/dagshub/data_engine/model/metadata/util.py b/dagshub/data_engine/model/metadata/util.py index ff660324..d366aea3 100644 --- a/dagshub/data_engine/model/metadata/util.py +++ b/dagshub/data_engine/model/metadata/util.py @@ -1,6 +1,10 @@ import datetime +from gql.transport.exceptions import TransportServerError, TransportConnectionFailed +from requests import ConnectionError as RequestsConnectionError, Timeout as RequestsTimeout from typing import Optional +from dagshub.data_engine.model.errors import DataEngineGqlError + def _get_datetime_utc_offset(t: datetime.datetime) -> Optional[str]: """ @@ -19,3 +23,20 @@ def _get_datetime_utc_offset(t: datetime.datetime) -> Optional[str]: offset_minutes = int((offset.total_seconds() % 3600) // 60) offset_str = f"{offset_hours:+03d}:{offset_minutes:02d}" return offset_str + + +def is_retryable_metadata_upload_error(exc: Exception) -> bool: + if isinstance(exc, DataEngineGqlError): + return isinstance(exc.original_exception, (TransportServerError, TransportConnectionFailed)) + + return isinstance( + exc, + ( + TransportServerError, + TransportConnectionFailed, + TimeoutError, + ConnectionError, + RequestsConnectionError, + RequestsTimeout, + ), + ) From 3a944aee9418d2b6887503094a0e445290218899 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Tue, 10 Mar 2026 15:24:11 +0200 Subject: [PATCH 13/30] Support unbounded iterables in AdaptiveBatcher Accept Iterable[T] instead of List[T], consuming via itertools.islice. When total is unknown (no __len__), progress shows a counter instead of a bar. Failed batches are re-prepended to the iterator for retry. --- dagshub/common/adaptive_batching.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/dagshub/common/adaptive_batching.py b/dagshub/common/adaptive_batching.py index 91ff4f90..b9e357f2 100644 --- a/dagshub/common/adaptive_batching.py +++ b/dagshub/common/adaptive_batching.py @@ -2,7 +2,8 @@ import time from dataclasses import dataclass from types import SimpleNamespace -from typing import Callable, List, Optional, TypeVar +import itertools +from typing import Callable, Iterable, List, Optional, Sized, TypeVar import rich.progress from tenacity import wait_exponential @@ -127,29 +128,32 @@ def __init__( self._is_retryable = is_retryable self._progress_label = progress_label - def run(self, items: List[T], operation: Callable[[List[T]], None]) -> None: - total = len(items) + def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: + total: Optional[int] = len(items) if isinstance(items, Sized) else None if total == 0: return config = self._config current_batch_size = config.initial_batch_size + it = iter(items) progress = get_rich_progress(rich.progress.MofNCompleteColumn()) total_task = progress.add_task( f"{self._progress_label} (adaptive batch {config.min_batch_size}-{config.max_batch_size})...", - total=total, + total=total if total is not None else float("inf"), ) last_good_batch_size: Optional[int] = None last_bad_batch_size: Optional[int] = None consecutive_retryable_failures = 0 + processed = 0 with progress: - start = 0 - while start < total: - batch_size = min(current_batch_size, total - start) - batch = items[start : start + batch_size] + while True: + batch = list(itertools.islice(it, current_batch_size)) + if not batch: + break + batch_size = len(batch) progress.update( total_task, @@ -188,11 +192,13 @@ def run(self, items: List[T], operation: Callable[[List[T]], None]) -> None: f"{self._progress_label} failed for batch size {batch_size} " f"({exc.__class__.__name__}: {exc}). Retrying with batch size {current_batch_size}." ) + # Re-prepend the failed batch to the iterator for retry + it = itertools.chain(batch, it) continue elapsed = time.monotonic() - start_time consecutive_retryable_failures = 0 - start += batch_size + processed += batch_size progress.update(total_task, advance=batch_size) if elapsed <= config.target_batch_time_seconds: @@ -208,4 +214,4 @@ def run(self, items: List[T], operation: Callable[[List[T]], None]) -> None: batch_size, config, last_good_batch_size, last_bad_batch_size ) - progress.update(total_task, completed=total, refresh=True) + progress.update(total_task, completed=processed, total=processed, refresh=True) From 6cea20f49b52f2e0f19a2a71a9472006daeaa500 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Tue, 10 Mar 2026 15:29:05 +0200 Subject: [PATCH 14/30] Format config.py line length (black) --- dagshub/common/config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dagshub/common/config.py b/dagshub/common/config.py index 826d1f21..8bd62d34 100644 --- a/dagshub/common/config.py +++ b/dagshub/common/config.py @@ -75,7 +75,9 @@ def set_host(new_host: str): ) DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY = "DAGSHUB_DE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS" -dataengine_metadata_upload_target_batch_time_seconds = float(os.environ.get(DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY, 5.0)) +dataengine_metadata_upload_target_batch_time_seconds = float( + os.environ.get(DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY, 5.0) +) DISABLE_ANALYTICS_KEY = "DAGSHUB_DISABLE_ANALYTICS" disable_analytics = "DAGSHUB_DISABLE_ANALYTICS" in os.environ From a52a0aa2e3cb28471fcce59b7fc5ae0bf56b91fd Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Tue, 10 Mar 2026 15:59:37 +0200 Subject: [PATCH 15/30] Fix review issues in AdaptiveBatcher - Pass None (not float('inf')) as progress total for unbounded iterables to avoid OverflowError in MofNCompleteColumn - Respect min_batch_size on failure path: floor at config.min_batch_size instead of hardcoded 1, and abort when batch_size <= min_batch_size - Replace nested itertools.chain retry with a pending list to avoid unbounded recursion depth under repeated failures --- dagshub/common/adaptive_batching.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/dagshub/common/adaptive_batching.py b/dagshub/common/adaptive_batching.py index b9e357f2..48aa990f 100644 --- a/dagshub/common/adaptive_batching.py +++ b/dagshub/common/adaptive_batching.py @@ -97,8 +97,8 @@ def _next_batch_after_retryable_failure( good_batch_size: Optional[int], bad_batch_size: Optional[int], ) -> int: - if batch_size <= 1: - return 1 + if batch_size <= config.min_batch_size: + return config.min_batch_size upper_bound = min(batch_size, bad_batch_size) if bad_batch_size is not None else batch_size if good_batch_size is not None and good_batch_size < upper_bound: @@ -107,7 +107,7 @@ def _next_batch_after_retryable_failure( next_batch_size = batch_size // 2 next_batch_size = min(next_batch_size, upper_bound - 1, batch_size - 1, config.max_batch_size) - return max(1, next_batch_size) + return max(config.min_batch_size, next_batch_size) def _get_retry_delay_seconds(consecutive_retryable_failures: int) -> float: @@ -136,11 +136,12 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: config = self._config current_batch_size = config.initial_batch_size it = iter(items) + pending: List[T] = [] progress = get_rich_progress(rich.progress.MofNCompleteColumn()) total_task = progress.add_task( f"{self._progress_label} (adaptive batch {config.min_batch_size}-{config.max_batch_size})...", - total=total if total is not None else float("inf"), + total=total, ) last_good_batch_size: Optional[int] = None @@ -150,7 +151,11 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: with progress: while True: - batch = list(itertools.islice(it, current_batch_size)) + # Draw from pending (failed-batch leftovers) first, then the source iterator + batch = pending[:current_batch_size] + pending = pending[current_batch_size:] + if len(batch) < current_batch_size: + batch.extend(itertools.islice(it, current_batch_size - len(batch))) if not batch: break batch_size = len(batch) @@ -172,7 +177,7 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: ) raise - if batch_size <= 1: + if batch_size <= config.min_batch_size: logger.error( f"{self._progress_label} failed at minimum batch size ({batch_size}); aborting.", exc_info=True, @@ -192,8 +197,8 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: f"{self._progress_label} failed for batch size {batch_size} " f"({exc.__class__.__name__}: {exc}). Retrying with batch size {current_batch_size}." ) - # Re-prepend the failed batch to the iterator for retry - it = itertools.chain(batch, it) + # Re-queue the failed batch items for retry with smaller batch size + pending = batch + pending continue elapsed = time.monotonic() - start_time From f6dfec9e598ccf62ae57ca4020e75799aacdd93b Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Tue, 10 Mar 2026 16:22:43 +0200 Subject: [PATCH 16/30] Make batch growth factor and retry backoff configurable, add tests Add batch_growth_factor, retry_backoff_base_seconds, retry_backoff_max_seconds to AdaptiveBatchConfig with env var defaults in common.config. Add 34 unit tests covering config normalization, batch sizing functions, retry delay, and full AdaptiveBatcher.run integration (lists, generators, retries). --- dagshub/common/adaptive_batching.py | 43 ++-- dagshub/common/config.py | 9 + tests/common/test_adaptive_batching.py | 330 +++++++++++++++++++++++++ 3 files changed, 367 insertions(+), 15 deletions(-) create mode 100644 tests/common/test_adaptive_batching.py diff --git a/dagshub/common/adaptive_batching.py b/dagshub/common/adaptive_batching.py index 48aa990f..8910f47a 100644 --- a/dagshub/common/adaptive_batching.py +++ b/dagshub/common/adaptive_batching.py @@ -1,8 +1,8 @@ import logging import time from dataclasses import dataclass -from types import SimpleNamespace import itertools +from types import SimpleNamespace from typing import Callable, Iterable, List, Optional, Sized, TypeVar import rich.progress @@ -15,15 +15,6 @@ T = TypeVar("T") MIN_TARGET_BATCH_TIME_SECONDS = 0.01 -BATCH_GROWTH_FACTOR = 10 -RETRY_BACKOFF_BASE_SECONDS = 0.25 -RETRY_BACKOFF_MAX_SECONDS = 4.0 - -_retry_delay_strategy = wait_exponential( - multiplier=RETRY_BACKOFF_BASE_SECONDS, - min=RETRY_BACKOFF_BASE_SECONDS, - max=RETRY_BACKOFF_MAX_SECONDS, -) @dataclass(frozen=True) @@ -32,6 +23,9 @@ class AdaptiveBatchConfig: min_batch_size: int initial_batch_size: int target_batch_time_seconds: float + batch_growth_factor: int + retry_backoff_base_seconds: float + retry_backoff_max_seconds: float @classmethod def from_values( @@ -40,6 +34,9 @@ def from_values( min_batch_size: Optional[int] = None, initial_batch_size: Optional[int] = None, target_batch_time_seconds: Optional[float] = None, + batch_growth_factor: Optional[int] = None, + retry_backoff_base_seconds: Optional[float] = None, + retry_backoff_max_seconds: Optional[float] = None, ) -> "AdaptiveBatchConfig": import dagshub.common.config as dgs_config @@ -51,6 +48,12 @@ def from_values( initial_batch_size = dgs_config.dataengine_metadata_upload_batch_size_initial if target_batch_time_seconds is None: target_batch_time_seconds = dgs_config.dataengine_metadata_upload_target_batch_time_seconds + if batch_growth_factor is None: + batch_growth_factor = dgs_config.adaptive_batch_growth_factor + if retry_backoff_base_seconds is None: + retry_backoff_base_seconds = dgs_config.adaptive_batch_retry_backoff_base_seconds + if retry_backoff_max_seconds is None: + retry_backoff_max_seconds = dgs_config.adaptive_batch_retry_backoff_max_seconds normalized_max_batch_size = max(1, max_batch_size) normalized_min_batch_size = max(1, min(min_batch_size, normalized_max_batch_size)) @@ -64,6 +67,9 @@ def from_values( min_batch_size=normalized_min_batch_size, initial_batch_size=normalized_initial_batch_size, target_batch_time_seconds=normalized_target_batch_time_seconds, + batch_growth_factor=max(2, batch_growth_factor), + retry_backoff_base_seconds=max(0.0, retry_backoff_base_seconds), + retry_backoff_max_seconds=max(0.0, retry_backoff_max_seconds), ) @@ -80,10 +86,10 @@ def _next_batch_after_success( next_batch_size = _midpoint(batch_size, bad_batch_size) next_batch_size = min(next_batch_size, bad_batch_size - 1) else: - next_batch_size = batch_size * BATCH_GROWTH_FACTOR + next_batch_size = batch_size * config.batch_growth_factor next_batch_size = min(config.max_batch_size, next_batch_size) - if next_batch_size <= batch_size and batch_size < config.max_batch_size: + if next_batch_size <= batch_size < config.max_batch_size: next_batch_size = min(config.max_batch_size, batch_size + 1) if bad_batch_size is not None: next_batch_size = min(next_batch_size, bad_batch_size - 1) @@ -110,9 +116,16 @@ def _next_batch_after_retryable_failure( return max(config.min_batch_size, next_batch_size) -def _get_retry_delay_seconds(consecutive_retryable_failures: int) -> float: +def _get_retry_delay_seconds(consecutive_retryable_failures: int, config: AdaptiveBatchConfig) -> float: + # SimpleNamespace duck-types the .attempt_number attribute that tenacity's + # wait strategies read, avoiding the heavier RetryCallState constructor. + strategy = wait_exponential( + multiplier=config.retry_backoff_base_seconds, + min=config.retry_backoff_base_seconds, + max=config.retry_backoff_max_seconds, + ) retry_state = SimpleNamespace(attempt_number=max(1, consecutive_retryable_failures)) - return float(_retry_delay_strategy(retry_state)) + return float(strategy(retry_state)) # type: ignore[arg-type] class AdaptiveBatcher: @@ -185,7 +198,7 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: raise consecutive_retryable_failures += 1 - time.sleep(_get_retry_delay_seconds(consecutive_retryable_failures)) + time.sleep(_get_retry_delay_seconds(consecutive_retryable_failures, config)) last_bad_batch_size = ( batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) diff --git a/dagshub/common/config.py b/dagshub/common/config.py index 8bd62d34..18b9e4e9 100644 --- a/dagshub/common/config.py +++ b/dagshub/common/config.py @@ -79,6 +79,15 @@ def set_host(new_host: str): os.environ.get(DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY, 5.0) ) +ADAPTIVE_BATCH_GROWTH_FACTOR_KEY = "DAGSHUB_ADAPTIVE_BATCH_GROWTH_FACTOR" +adaptive_batch_growth_factor = int(os.environ.get(ADAPTIVE_BATCH_GROWTH_FACTOR_KEY, 10)) + +ADAPTIVE_BATCH_RETRY_BACKOFF_BASE_KEY = "DAGSHUB_ADAPTIVE_BATCH_RETRY_BACKOFF_BASE" +adaptive_batch_retry_backoff_base_seconds = float(os.environ.get(ADAPTIVE_BATCH_RETRY_BACKOFF_BASE_KEY, 0.25)) + +ADAPTIVE_BATCH_RETRY_BACKOFF_MAX_KEY = "DAGSHUB_ADAPTIVE_BATCH_RETRY_BACKOFF_MAX" +adaptive_batch_retry_backoff_max_seconds = float(os.environ.get(ADAPTIVE_BATCH_RETRY_BACKOFF_MAX_KEY, 4.0)) + DISABLE_ANALYTICS_KEY = "DAGSHUB_DISABLE_ANALYTICS" disable_analytics = "DAGSHUB_DISABLE_ANALYTICS" in os.environ diff --git a/tests/common/test_adaptive_batching.py b/tests/common/test_adaptive_batching.py new file mode 100644 index 00000000..529d38b1 --- /dev/null +++ b/tests/common/test_adaptive_batching.py @@ -0,0 +1,330 @@ +from unittest.mock import patch + +import pytest + +from dagshub.common.adaptive_batching import ( + AdaptiveBatchConfig, + AdaptiveBatcher, + _get_retry_delay_seconds, + _midpoint, + _next_batch_after_retryable_failure, + _next_batch_after_success, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _cfg( + max_batch_size=1000, + min_batch_size=1, + initial_batch_size=10, + target_batch_time_seconds=5.0, + batch_growth_factor=10, + retry_backoff_base_seconds=0.25, + retry_backoff_max_seconds=4.0, +): + """Shortcut to build a config without going through from_values (avoids config import).""" + return AdaptiveBatchConfig( + max_batch_size=max_batch_size, + min_batch_size=min_batch_size, + initial_batch_size=initial_batch_size, + target_batch_time_seconds=target_batch_time_seconds, + batch_growth_factor=batch_growth_factor, + retry_backoff_base_seconds=retry_backoff_base_seconds, + retry_backoff_max_seconds=retry_backoff_max_seconds, + ) + + +# --------------------------------------------------------------------------- +# AdaptiveBatchConfig.from_values +# --------------------------------------------------------------------------- + + +class TestAdaptiveBatchConfigFromValues: + def test_normalizes_max_batch_size_to_at_least_1(self): + cfg = AdaptiveBatchConfig.from_values(max_batch_size=0, min_batch_size=1) + assert cfg.max_batch_size == 1 + + def test_clamps_min_to_max(self): + cfg = AdaptiveBatchConfig.from_values(max_batch_size=5, min_batch_size=100) + assert cfg.min_batch_size == 5 + + def test_clamps_initial_between_min_and_max(self): + cfg = AdaptiveBatchConfig.from_values( + max_batch_size=100, min_batch_size=10, initial_batch_size=5 + ) + assert cfg.initial_batch_size == 10 + + cfg2 = AdaptiveBatchConfig.from_values( + max_batch_size=100, min_batch_size=10, initial_batch_size=200 + ) + assert cfg2.initial_batch_size == 100 + + def test_batch_growth_factor_minimum_is_2(self): + cfg = AdaptiveBatchConfig.from_values(batch_growth_factor=1) + assert cfg.batch_growth_factor == 2 + + def test_backoff_seconds_non_negative(self): + cfg = AdaptiveBatchConfig.from_values( + retry_backoff_base_seconds=-1.0, retry_backoff_max_seconds=-5.0 + ) + assert cfg.retry_backoff_base_seconds == 0.0 + assert cfg.retry_backoff_max_seconds == 0.0 + + +# --------------------------------------------------------------------------- +# _midpoint +# --------------------------------------------------------------------------- + + +class TestMidpoint: + def test_basic(self): + assert _midpoint(10, 20) == 15 + + def test_adjacent_values_advance_by_at_least_1(self): + assert _midpoint(10, 11) == 11 + + def test_equal_values_advance_by_1(self): + assert _midpoint(5, 5) == 6 + + +# --------------------------------------------------------------------------- +# _next_batch_after_success +# --------------------------------------------------------------------------- + + +class TestNextBatchAfterSuccess: + def test_grows_by_growth_factor_when_no_bad_size(self): + cfg = _cfg(batch_growth_factor=10, max_batch_size=10000) + assert _next_batch_after_success(10, cfg, bad_batch_size=None) == 100 + + def test_capped_at_max_batch_size(self): + cfg = _cfg(batch_growth_factor=10, max_batch_size=50) + assert _next_batch_after_success(10, cfg, bad_batch_size=None) == 50 + + def test_binary_search_toward_bad_size(self): + cfg = _cfg(max_batch_size=10000) + result = _next_batch_after_success(10, cfg, bad_batch_size=20) + assert 10 < result < 20 + + def test_never_reaches_bad_size(self): + cfg = _cfg(max_batch_size=10000) + result = _next_batch_after_success(18, cfg, bad_batch_size=20) + assert result <= 19 # bad_batch_size - 1 + + def test_respects_min_batch_size(self): + cfg = _cfg(min_batch_size=5, max_batch_size=10) + result = _next_batch_after_success(1, cfg, bad_batch_size=None) + assert result >= 5 + + def test_makes_progress_when_growth_factor_would_not_increase(self): + cfg = _cfg(batch_growth_factor=2, max_batch_size=100) + # batch_size=99, growth gives 198 capped to 100, which is > 99 so it works. + # But let's test the edge: batch_size at max-1 should reach max. + result = _next_batch_after_success(99, cfg, bad_batch_size=None) + assert result == 100 + + +# --------------------------------------------------------------------------- +# _next_batch_after_retryable_failure +# --------------------------------------------------------------------------- + + +class TestNextBatchAfterRetryableFailure: + def test_halves_when_no_bounds(self): + cfg = _cfg(min_batch_size=1) + assert _next_batch_after_retryable_failure(100, cfg, None, None) == 50 + + def test_returns_min_when_at_min(self): + cfg = _cfg(min_batch_size=5) + assert _next_batch_after_retryable_failure(5, cfg, None, None) == 5 + + def test_returns_min_when_below_min(self): + cfg = _cfg(min_batch_size=10) + assert _next_batch_after_retryable_failure(3, cfg, None, None) == 10 + + def test_binary_search_between_good_and_bad(self): + cfg = _cfg(min_batch_size=1) + result = _next_batch_after_retryable_failure(100, cfg, good_batch_size=40, bad_batch_size=100) + assert 40 < result < 100 + + def test_never_returns_below_min_batch_size(self): + cfg = _cfg(min_batch_size=10) + result = _next_batch_after_retryable_failure(20, cfg, None, None) + assert result >= 10 + + def test_strictly_decreases_from_current(self): + cfg = _cfg(min_batch_size=1) + for batch_size in [2, 5, 10, 50, 100, 1000]: + result = _next_batch_after_retryable_failure(batch_size, cfg, None, None) + assert result < batch_size + + +# --------------------------------------------------------------------------- +# _get_retry_delay_seconds +# --------------------------------------------------------------------------- + + +class TestGetRetryDelaySeconds: + def test_returns_base_for_first_failure(self): + cfg = _cfg(retry_backoff_base_seconds=0.25, retry_backoff_max_seconds=4.0) + delay = _get_retry_delay_seconds(1, cfg) + assert delay == pytest.approx(0.25, abs=0.01) + + def test_increases_with_more_failures(self): + cfg = _cfg(retry_backoff_base_seconds=0.25, retry_backoff_max_seconds=4.0) + d1 = _get_retry_delay_seconds(1, cfg) + d2 = _get_retry_delay_seconds(2, cfg) + d3 = _get_retry_delay_seconds(3, cfg) + assert d1 < d2 < d3 + + def test_capped_at_max(self): + cfg = _cfg(retry_backoff_base_seconds=0.25, retry_backoff_max_seconds=4.0) + delay = _get_retry_delay_seconds(100, cfg) + assert delay <= 4.0 + + def test_zero_failures_treated_as_one(self): + cfg = _cfg(retry_backoff_base_seconds=0.5, retry_backoff_max_seconds=10.0) + assert _get_retry_delay_seconds(0, cfg) == _get_retry_delay_seconds(1, cfg) + + +# --------------------------------------------------------------------------- +# AdaptiveBatcher.run — integration tests +# --------------------------------------------------------------------------- + + +class TestAdaptiveBatcherRun: + @staticmethod + def _make_batcher(**config_overrides): + cfg = _cfg(**{**dict( + initial_batch_size=3, + max_batch_size=100, + min_batch_size=1, + target_batch_time_seconds=999, # fast enough to always grow + retry_backoff_base_seconds=0.0, + retry_backoff_max_seconds=0.0, + ), **config_overrides}) + return AdaptiveBatcher( + is_retryable=lambda exc: isinstance(exc, ValueError), + config=cfg, + ) + + def test_processes_all_items(self): + batcher = self._make_batcher() + received = [] + batcher.run(list(range(10)), lambda batch: received.extend(batch)) + assert received == list(range(10)) + + def test_empty_list(self): + batcher = self._make_batcher() + called = [] + batcher.run([], lambda batch: called.append(batch)) + assert called == [] + + def test_generator_input(self): + batcher = self._make_batcher() + received = [] + + def gen(): + for i in range(7): + yield i + + batcher.run(gen(), lambda batch: received.extend(batch)) + assert received == list(range(7)) + + def test_retries_on_retryable_error(self): + batcher = self._make_batcher(initial_batch_size=5, min_batch_size=1) + call_count = 0 + received = [] + + def op(batch): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ValueError("transient") + received.extend(batch) + + batcher.run(list(range(5)), op) + assert sorted(received) == list(range(5)) + assert call_count > 1 + + def test_aborts_on_non_retryable_error(self): + batcher = self._make_batcher() + + with pytest.raises(TypeError): + batcher.run(list(range(5)), lambda batch: (_ for _ in ()).throw(TypeError("fatal"))) + + def test_aborts_at_min_batch_size(self): + batcher = self._make_batcher(initial_batch_size=1, min_batch_size=1) + + with pytest.raises(ValueError, match="always fails"): + batcher.run([1], lambda batch: (_ for _ in ()).throw(ValueError("always fails"))) + + def test_no_items_lost_on_retry(self): + """All items from a failed batch must be retried.""" + batcher = self._make_batcher(initial_batch_size=4, min_batch_size=1) + fail_once = True + all_received = [] + + def op(batch): + nonlocal fail_once + if fail_once and len(batch) == 4: + fail_once = False + raise ValueError("fail big batch once") + all_received.extend(batch) + + items = list(range(8)) + batcher.run(items, op) + assert sorted(all_received) == items + + def test_generator_retry_no_items_lost(self): + """Items from a failed batch are retried even with generator input.""" + batcher = self._make_batcher(initial_batch_size=3, min_batch_size=1) + fail_once = True + all_received = [] + + def op(batch): + nonlocal fail_once + if fail_once: + fail_once = False + raise ValueError("transient") + all_received.extend(batch) + + def gen(): + for i in range(6): + yield i + + batcher.run(gen(), op) + assert sorted(all_received) == list(range(6)) + + def test_batch_size_shrinks_on_failure(self): + batcher = self._make_batcher(initial_batch_size=10, min_batch_size=1) + batch_sizes = [] + + def op(batch): + batch_sizes.append(len(batch)) + if batch_sizes[-1] == 10: + raise ValueError("too big") + + batcher.run(list(range(20)), op) + # First call is size 10 (fails), next should be smaller + assert batch_sizes[0] == 10 + assert batch_sizes[1] < 10 + + @patch("dagshub.common.adaptive_batching.time") + def test_batch_size_grows_on_fast_success(self, mock_time): + # Make monotonic() return increasing values, but elapsed always < target + mock_time.monotonic.side_effect = [0.0, 0.001] * 50 + mock_time.sleep = lambda _: None + + batcher = self._make_batcher( + initial_batch_size=2, + max_batch_size=100, + target_batch_time_seconds=5.0, + ) + batch_sizes = [] + batcher.run(list(range(100)), lambda batch: batch_sizes.append(len(batch))) + # Should grow from initial=2 + assert max(batch_sizes) > 2 From 349c25c6fee17e4c6e9ebe23da27026bc18c28d0 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Tue, 10 Mar 2026 16:29:02 +0200 Subject: [PATCH 17/30] Fix batch size stall at bad_batch_size - 1, clear stale bounds _next_batch_after_success could return the same batch_size forever when batch_size == bad_batch_size - 1. Fix by guaranteeing +1 progress in the stall guard. Also clear last_bad_batch_size when a fast success occurs at or above it, since the bound is stale. Add regression tests for convergence and the stall scenario, plus from_values() defaults. --- dagshub/common/adaptive_batching.py | 6 ++- tests/common/test_adaptive_batching.py | 64 ++++++++++++++++++-------- 2 files changed, 50 insertions(+), 20 deletions(-) diff --git a/dagshub/common/adaptive_batching.py b/dagshub/common/adaptive_batching.py index 8910f47a..a93bf8c1 100644 --- a/dagshub/common/adaptive_batching.py +++ b/dagshub/common/adaptive_batching.py @@ -90,9 +90,8 @@ def _next_batch_after_success( next_batch_size = min(config.max_batch_size, next_batch_size) if next_batch_size <= batch_size < config.max_batch_size: + # Guarantee forward progress: step up by 1, even past a stale bad_batch_size next_batch_size = min(config.max_batch_size, batch_size + 1) - if bad_batch_size is not None: - next_batch_size = min(next_batch_size, bad_batch_size - 1) return max(config.min_batch_size, next_batch_size) @@ -223,6 +222,9 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: last_good_batch_size = ( batch_size if last_good_batch_size is None else max(last_good_batch_size, batch_size) ) + # Clear stale bad bound if we succeeded fast at or above it + if last_bad_batch_size is not None and batch_size >= last_bad_batch_size: + last_bad_batch_size = None current_batch_size = _next_batch_after_success(batch_size, config, last_bad_batch_size) else: last_bad_batch_size = ( diff --git a/tests/common/test_adaptive_batching.py b/tests/common/test_adaptive_batching.py index 529d38b1..7f23448e 100644 --- a/tests/common/test_adaptive_batching.py +++ b/tests/common/test_adaptive_batching.py @@ -11,11 +11,11 @@ _next_batch_after_success, ) - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _cfg( max_batch_size=1000, min_batch_size=1, @@ -52,14 +52,10 @@ def test_clamps_min_to_max(self): assert cfg.min_batch_size == 5 def test_clamps_initial_between_min_and_max(self): - cfg = AdaptiveBatchConfig.from_values( - max_batch_size=100, min_batch_size=10, initial_batch_size=5 - ) + cfg = AdaptiveBatchConfig.from_values(max_batch_size=100, min_batch_size=10, initial_batch_size=5) assert cfg.initial_batch_size == 10 - cfg2 = AdaptiveBatchConfig.from_values( - max_batch_size=100, min_batch_size=10, initial_batch_size=200 - ) + cfg2 = AdaptiveBatchConfig.from_values(max_batch_size=100, min_batch_size=10, initial_batch_size=200) assert cfg2.initial_batch_size == 100 def test_batch_growth_factor_minimum_is_2(self): @@ -67,12 +63,22 @@ def test_batch_growth_factor_minimum_is_2(self): assert cfg.batch_growth_factor == 2 def test_backoff_seconds_non_negative(self): - cfg = AdaptiveBatchConfig.from_values( - retry_backoff_base_seconds=-1.0, retry_backoff_max_seconds=-5.0 - ) + cfg = AdaptiveBatchConfig.from_values(retry_backoff_base_seconds=-1.0, retry_backoff_max_seconds=-5.0) assert cfg.retry_backoff_base_seconds == 0.0 assert cfg.retry_backoff_max_seconds == 0.0 + def test_defaults_from_config(self): + cfg = AdaptiveBatchConfig.from_values() + assert cfg.max_batch_size >= 1 + assert cfg.min_batch_size >= 1 + assert cfg.min_batch_size <= cfg.max_batch_size + assert cfg.initial_batch_size >= cfg.min_batch_size + assert cfg.initial_batch_size <= cfg.max_batch_size + assert cfg.target_batch_time_seconds > 0 + assert cfg.batch_growth_factor >= 2 + assert cfg.retry_backoff_base_seconds >= 0 + assert cfg.retry_backoff_max_seconds >= 0 + # --------------------------------------------------------------------------- # _midpoint @@ -119,6 +125,23 @@ def test_respects_min_batch_size(self): result = _next_batch_after_success(1, cfg, bad_batch_size=None) assert result >= 5 + def test_no_stall_at_bad_batch_size_minus_one(self): + """Regression: batch_size == bad_batch_size - 1 must still make progress.""" + cfg = _cfg(max_batch_size=1000, batch_growth_factor=2) + result = _next_batch_after_success(9, cfg, bad_batch_size=10) + assert result > 9 + + def test_convergence_reaches_max(self): + """Iterating _next_batch_after_success must eventually reach max_batch_size.""" + cfg = _cfg(max_batch_size=100, batch_growth_factor=2) + batch_size = 1 + bad = 50 # initial bad bound + for _ in range(200): + batch_size = _next_batch_after_success(batch_size, cfg, bad_batch_size=bad) + if batch_size >= cfg.max_batch_size: + break + assert batch_size == cfg.max_batch_size + def test_makes_progress_when_growth_factor_would_not_increase(self): cfg = _cfg(batch_growth_factor=2, max_batch_size=100) # batch_size=99, growth gives 198 capped to 100, which is > 99 so it works. @@ -198,14 +221,19 @@ def test_zero_failures_treated_as_one(self): class TestAdaptiveBatcherRun: @staticmethod def _make_batcher(**config_overrides): - cfg = _cfg(**{**dict( - initial_batch_size=3, - max_batch_size=100, - min_batch_size=1, - target_batch_time_seconds=999, # fast enough to always grow - retry_backoff_base_seconds=0.0, - retry_backoff_max_seconds=0.0, - ), **config_overrides}) + cfg = _cfg( + **{ + **dict( + initial_batch_size=3, + max_batch_size=100, + min_batch_size=1, + target_batch_time_seconds=999, # fast enough to always grow + retry_backoff_base_seconds=0.0, + retry_backoff_max_seconds=0.0, + ), + **config_overrides, + } + ) return AdaptiveBatcher( is_retryable=lambda exc: isinstance(exc, ValueError), config=cfg, From 3e0ae687f1c2d2415e777ef077804929d770ee3a Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Tue, 10 Mar 2026 16:34:47 +0200 Subject: [PATCH 18/30] Simplify batch sizing functions with clear strategy comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the convoluted _midpoint + multi-conditional logic in _next_batch_after_success and _next_batch_after_retryable_failure with straightforward (a+b)//2 midpoints, a _clamp helper, and clear docstrings explaining each function's strategy. Remove _midpoint's defensive max(1,...) hack — callers now handle boundaries directly. --- dagshub/common/adaptive_batching.py | 49 +++++++++++++++++--------- tests/common/test_adaptive_batching.py | 21 ++++++----- 2 files changed, 45 insertions(+), 25 deletions(-) diff --git a/dagshub/common/adaptive_batching.py b/dagshub/common/adaptive_batching.py index a93bf8c1..c81a3eea 100644 --- a/dagshub/common/adaptive_batching.py +++ b/dagshub/common/adaptive_batching.py @@ -73,8 +73,8 @@ def from_values( ) -def _midpoint(lower_bound: int, upper_bound: int) -> int: - return lower_bound + max(1, (upper_bound - lower_bound) // 2) +def _clamp(value: int, lo: int, hi: int) -> int: + return max(lo, min(hi, value)) def _next_batch_after_success( @@ -82,18 +82,24 @@ def _next_batch_after_success( config: AdaptiveBatchConfig, bad_batch_size: Optional[int], ) -> int: + """Pick the next batch size after a successful (fast) batch. + + Strategy: + - If we have a known-bad upper bound, binary-search toward it. + - Otherwise, multiply by the growth factor. + - Always guarantee at least +1 progress (so we never stall). + """ if bad_batch_size is not None and batch_size < bad_batch_size: - next_batch_size = _midpoint(batch_size, bad_batch_size) - next_batch_size = min(next_batch_size, bad_batch_size - 1) + # Binary search: try the midpoint between current and bad + candidate = (batch_size + bad_batch_size) // 2 else: - next_batch_size = batch_size * config.batch_growth_factor + # No upper bound (or we've already passed it): grow aggressively + candidate = batch_size * config.batch_growth_factor - next_batch_size = min(config.max_batch_size, next_batch_size) - if next_batch_size <= batch_size < config.max_batch_size: - # Guarantee forward progress: step up by 1, even past a stale bad_batch_size - next_batch_size = min(config.max_batch_size, batch_size + 1) + # Must advance by at least 1 to avoid stalling + candidate = max(candidate, batch_size + 1) - return max(config.min_batch_size, next_batch_size) + return _clamp(candidate, config.min_batch_size, config.max_batch_size) def _next_batch_after_retryable_failure( @@ -102,17 +108,28 @@ def _next_batch_after_retryable_failure( good_batch_size: Optional[int], bad_batch_size: Optional[int], ) -> int: + """Pick the next batch size after a failed or slow batch. + + Strategy: + - If we have a known-good lower bound, binary-search between it and the + failing size. + - Otherwise, halve. + - Must be strictly less than the current size (so we converge downward). + """ if batch_size <= config.min_batch_size: return config.min_batch_size - upper_bound = min(batch_size, bad_batch_size) if bad_batch_size is not None else batch_size - if good_batch_size is not None and good_batch_size < upper_bound: - next_batch_size = _midpoint(good_batch_size, upper_bound) + ceiling = batch_size - 1 # must shrink + if bad_batch_size is not None: + ceiling = min(ceiling, bad_batch_size - 1) + + if good_batch_size is not None and good_batch_size < ceiling: + # Binary search: try the midpoint between good and failing + candidate = (good_batch_size + ceiling) // 2 else: - next_batch_size = batch_size // 2 + candidate = batch_size // 2 - next_batch_size = min(next_batch_size, upper_bound - 1, batch_size - 1, config.max_batch_size) - return max(config.min_batch_size, next_batch_size) + return _clamp(candidate, config.min_batch_size, ceiling) def _get_retry_delay_seconds(consecutive_retryable_failures: int, config: AdaptiveBatchConfig) -> float: diff --git a/tests/common/test_adaptive_batching.py b/tests/common/test_adaptive_batching.py index 7f23448e..6606abab 100644 --- a/tests/common/test_adaptive_batching.py +++ b/tests/common/test_adaptive_batching.py @@ -5,8 +5,8 @@ from dagshub.common.adaptive_batching import ( AdaptiveBatchConfig, AdaptiveBatcher, + _clamp, _get_retry_delay_seconds, - _midpoint, _next_batch_after_retryable_failure, _next_batch_after_success, ) @@ -81,19 +81,22 @@ def test_defaults_from_config(self): # --------------------------------------------------------------------------- -# _midpoint +# _clamp # --------------------------------------------------------------------------- -class TestMidpoint: - def test_basic(self): - assert _midpoint(10, 20) == 15 +class TestClamp: + def test_within_range(self): + assert _clamp(5, 1, 10) == 5 - def test_adjacent_values_advance_by_at_least_1(self): - assert _midpoint(10, 11) == 11 + def test_below_minimum(self): + assert _clamp(0, 3, 10) == 3 - def test_equal_values_advance_by_1(self): - assert _midpoint(5, 5) == 6 + def test_above_maximum(self): + assert _clamp(20, 1, 10) == 10 + + def test_equal_bounds(self): + assert _clamp(5, 7, 7) == 7 # --------------------------------------------------------------------------- From 3c1c8172c9dc851463a88a7d622640b50ecd6dec Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Tue, 10 Mar 2026 16:43:16 +0200 Subject: [PATCH 19/30] Clear stale good/bad bounds when they become incoherent When last_bad_batch_size drops below last_good_batch_size (from slow batches or errors), clear the good bound since it is no longer useful for binary search. Add integration test for the slow-batch path. --- dagshub/common/adaptive_batching.py | 4 ++++ tests/common/test_adaptive_batching.py | 27 ++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/dagshub/common/adaptive_batching.py b/dagshub/common/adaptive_batching.py index c81a3eea..5be109c5 100644 --- a/dagshub/common/adaptive_batching.py +++ b/dagshub/common/adaptive_batching.py @@ -219,6 +219,8 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: last_bad_batch_size = ( batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) ) + if last_good_batch_size is not None and last_good_batch_size >= last_bad_batch_size: + last_good_batch_size = None current_batch_size = _next_batch_after_retryable_failure( batch_size, config, last_good_batch_size, last_bad_batch_size ) @@ -247,6 +249,8 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: last_bad_batch_size = ( batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) ) + if last_good_batch_size is not None and last_good_batch_size >= last_bad_batch_size: + last_good_batch_size = None current_batch_size = _next_batch_after_retryable_failure( batch_size, config, last_good_batch_size, last_bad_batch_size ) diff --git a/tests/common/test_adaptive_batching.py b/tests/common/test_adaptive_batching.py index 6606abab..e5b8ba4b 100644 --- a/tests/common/test_adaptive_batching.py +++ b/tests/common/test_adaptive_batching.py @@ -359,3 +359,30 @@ def test_batch_size_grows_on_fast_success(self, mock_time): batcher.run(list(range(100)), lambda batch: batch_sizes.append(len(batch))) # Should grow from initial=2 assert max(batch_sizes) > 2 + + @patch("dagshub.common.adaptive_batching.time") + def test_batch_size_shrinks_on_slow_success(self, mock_time): + # Make elapsed always > target (slow batches) + mock_time.monotonic.side_effect = [0.0, 100.0] * 100 + mock_time.sleep = lambda _: None + + batcher = self._make_batcher( + initial_batch_size=20, + min_batch_size=1, + max_batch_size=100, + target_batch_time_seconds=1.0, + ) + received = [] + batch_sizes = [] + + def op(batch): + batch_sizes.append(len(batch)) + received.extend(batch) + + items = list(range(50)) + batcher.run(items, op) + # All items processed despite slow batches + assert sorted(received) == items + # Batch size should shrink from 20 + assert batch_sizes[0] == 20 + assert min(batch_sizes) < 20 From 232ebbfb5348419ebe60dd53e36bfd1ece7aeb53 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Tue, 10 Mar 2026 16:53:02 +0200 Subject: [PATCH 20/30] Improve failure fallback convergence, add bad-bound clearing test Use (min + ceiling) // 2 instead of batch_size // 2 in the failure fallback so we binary-search within the valid range rather than jumping to the ceiling when bad_batch_size is much smaller than batch_size. Add integration test verifying batch size grows past a cleared bad bound. --- dagshub/common/adaptive_batching.py | 3 ++- tests/common/test_adaptive_batching.py | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/dagshub/common/adaptive_batching.py b/dagshub/common/adaptive_batching.py index 5be109c5..3a90053e 100644 --- a/dagshub/common/adaptive_batching.py +++ b/dagshub/common/adaptive_batching.py @@ -127,7 +127,8 @@ def _next_batch_after_retryable_failure( # Binary search: try the midpoint between good and failing candidate = (good_batch_size + ceiling) // 2 else: - candidate = batch_size // 2 + # No good lower bound — probe midpoint of the valid range + candidate = (config.min_batch_size + ceiling) // 2 return _clamp(candidate, config.min_batch_size, ceiling) diff --git a/tests/common/test_adaptive_batching.py b/tests/common/test_adaptive_batching.py index e5b8ba4b..dcc7e660 100644 --- a/tests/common/test_adaptive_batching.py +++ b/tests/common/test_adaptive_batching.py @@ -386,3 +386,26 @@ def op(batch): # Batch size should shrink from 20 assert batch_sizes[0] == 20 assert min(batch_sizes) < 20 + + def test_grows_past_cleared_bad_bound(self): + """After a transient failure, fast successes must clear the bad bound and grow past it.""" + batcher = self._make_batcher( + initial_batch_size=10, + min_batch_size=1, + max_batch_size=1000, + ) + fail_once = True + batch_sizes = [] + + def op(batch): + nonlocal fail_once + batch_sizes.append(len(batch)) + if fail_once: + fail_once = False + raise ValueError("transient") + + items = list(range(200)) + batcher.run(items, op) + # First call at size 10 fails, shrinks, then should recover and grow past 10 + assert batch_sizes[0] == 10 + assert max(batch_sizes) > 10 From 2c44d834354d665d9ff3b7b180ead0cbccdb23c5 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Tue, 10 Mar 2026 16:57:19 +0200 Subject: [PATCH 21/30] Fix retryable error check for wrapped exceptions in DataEngineGqlError When a DataEngineGqlError wraps a TimeoutError, ConnectionError, or requests exception, the old code only checked for TransportServerError and TransportConnectionFailed, causing retryable errors to be misclassified as fatal. Recurse into the wrapped exception instead. --- dagshub/data_engine/model/metadata/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dagshub/data_engine/model/metadata/util.py b/dagshub/data_engine/model/metadata/util.py index d366aea3..418eb1bf 100644 --- a/dagshub/data_engine/model/metadata/util.py +++ b/dagshub/data_engine/model/metadata/util.py @@ -26,8 +26,8 @@ def _get_datetime_utc_offset(t: datetime.datetime) -> Optional[str]: def is_retryable_metadata_upload_error(exc: Exception) -> bool: - if isinstance(exc, DataEngineGqlError): - return isinstance(exc.original_exception, (TransportServerError, TransportConnectionFailed)) + if isinstance(exc, DataEngineGqlError) and isinstance(exc.original_exception, Exception): + return is_retryable_metadata_upload_error(exc.original_exception) return isinstance( exc, From 0453a330a19ec2aaacf86d51864256404042ea8e Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Tue, 10 Mar 2026 17:12:12 +0200 Subject: [PATCH 22/30] Align metadata upload tests with adaptive batching behavior --- tests/data_engine/test_datasource.py | 31 ++++++++++++++-------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/tests/data_engine/test_datasource.py b/tests/data_engine/test_datasource.py index 5fe2ea00..cc14ba63 100644 --- a/tests/data_engine/test_datasource.py +++ b/tests/data_engine/test_datasource.py @@ -171,7 +171,7 @@ def test_upload_metadata_retries_with_smaller_batch_after_failure(ds, mocker): mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) - mocker.patch("dagshub.data_engine.model.datasource.time.sleep", return_value=None) + mocker.patch("dagshub.common.adaptive_batching.time.sleep", return_value=None) has_failed = {"value": False} @@ -197,7 +197,7 @@ def test_upload_metadata_does_not_retry_known_bad_batch_size(ds, mocker): mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) - mocker.patch("dagshub.data_engine.model.datasource.time.sleep", return_value=None) + mocker.patch("dagshub.common.adaptive_batching.time.sleep", return_value=None) has_failed = {"value": False} @@ -211,7 +211,7 @@ def _flaky_upload(_ds, upload_entries): ds._upload_metadata(entries) assert has_failed["value"] - assert _uploaded_batch_sizes(ds) == [8, 4, 6, 7, 7, 7, 1] + assert _uploaded_batch_sizes(ds) == [8, 4, 6, 7, 8, 7] def test_upload_metadata_slow_success_reduces_batch_size(ds, mocker): @@ -223,7 +223,7 @@ def test_upload_metadata_slow_success_reduces_batch_size(ds, mocker): mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1.0) - mocker.patch("dagshub.data_engine.model.datasource.time.monotonic", side_effect=[0.0, 2.0, 3.0, 3.1]) + mocker.patch("dagshub.common.adaptive_batching.time.monotonic", side_effect=[0.0, 2.0, 3.0, 3.1]) ds._upload_metadata(entries) @@ -248,6 +248,8 @@ def test_upload_metadata_non_retryable_error_does_not_retry(ds, mocker): def test_upload_metadata_retries_partial_batch_below_min(ds, mocker): + """A tail batch smaller than min_batch_size that fails is aborted immediately, + since shrinking further is impossible.""" entries = [ DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10) ] @@ -256,21 +258,18 @@ def test_upload_metadata_retries_partial_batch_below_min(ds, mocker): mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 4) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) - mocker.patch("dagshub.data_engine.model.datasource.time.sleep", return_value=None) - - has_failed = {"value": False} + mocker.patch("dagshub.common.adaptive_batching.time.sleep", return_value=None) def _flaky_upload(_ds, upload_entries): - if len(upload_entries) == 2 and not has_failed["value"]: - has_failed["value"] = True + if len(upload_entries) == 2: raise TimeoutError("simulated timeout") ds.source.client.update_metadata.side_effect = _flaky_upload - ds._upload_metadata(entries) + with pytest.raises(TimeoutError, match="simulated timeout"): + ds._upload_metadata(entries) - assert has_failed["value"] - assert _uploaded_batch_sizes(ds) == [8, 2, 1, 1] + assert _uploaded_batch_sizes(ds) == [8, 2] def test_upload_metadata_backoff_resets_after_success(ds, mocker): @@ -282,7 +281,7 @@ def test_upload_metadata_backoff_resets_after_success(ds, mocker): mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) - sleep_mock = mocker.patch("dagshub.data_engine.model.datasource.time.sleep") + sleep_mock = mocker.patch("dagshub.common.adaptive_batching.time.sleep") call_idx = {"value": 0} @@ -307,14 +306,14 @@ def test_upload_metadata_retries_below_configured_min_before_aborting(ds, mocker mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) - sleep_mock = mocker.patch("dagshub.data_engine.model.datasource.time.sleep") + sleep_mock = mocker.patch("dagshub.common.adaptive_batching.time.sleep") ds.source.client.update_metadata.side_effect = TimeoutError("simulated timeout") with pytest.raises(TimeoutError, match="simulated timeout"): ds._upload_metadata(entries) - assert _uploaded_batch_sizes(ds) == [2, 1] - assert [c.args[0] for c in sleep_mock.call_args_list] == [0.25] + assert _uploaded_batch_sizes(ds) == [2] + sleep_mock.assert_not_called() def test_pandas_timestamp(ds): From 00294fe2744a7bde2d0595d09717dd7d473827f5 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Tue, 10 Mar 2026 17:32:53 +0200 Subject: [PATCH 23/30] Handle tail-batch retry edge case and docstring mismatch --- dagshub/common/adaptive_batching.py | 17 ++++++++++++----- tests/data_engine/test_datasource.py | 8 ++++---- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/dagshub/common/adaptive_batching.py b/dagshub/common/adaptive_batching.py index 3a90053e..f4cbf0c1 100644 --- a/dagshub/common/adaptive_batching.py +++ b/dagshub/common/adaptive_batching.py @@ -113,7 +113,8 @@ def _next_batch_after_retryable_failure( Strategy: - If we have a known-good lower bound, binary-search between it and the failing size. - - Otherwise, halve. + - Otherwise, probe the midpoint between config.min_batch_size and the + largest allowed size below the failing batch. - Must be strictly less than the current size (so we converge downward). """ if batch_size <= config.min_batch_size: @@ -207,7 +208,8 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: ) raise - if batch_size <= config.min_batch_size: + exhausted_shrink = batch_size <= config.min_batch_size and batch_size == current_batch_size + if exhausted_shrink: logger.error( f"{self._progress_label} failed at minimum batch size ({batch_size}); aborting.", exc_info=True, @@ -222,9 +224,14 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: ) if last_good_batch_size is not None and last_good_batch_size >= last_bad_batch_size: last_good_batch_size = None - current_batch_size = _next_batch_after_retryable_failure( - batch_size, config, last_good_batch_size, last_bad_batch_size - ) + if batch_size < config.min_batch_size: + # Tail batches below configured min cannot be split further. + # Retry that exact size once before treating it as exhausted. + current_batch_size = batch_size + else: + current_batch_size = _next_batch_after_retryable_failure( + batch_size, config, last_good_batch_size, last_bad_batch_size + ) logger.warning( f"{self._progress_label} failed for batch size {batch_size} " f"({exc.__class__.__name__}: {exc}). Retrying with batch size {current_batch_size}." diff --git a/tests/data_engine/test_datasource.py b/tests/data_engine/test_datasource.py index cc14ba63..13c869e4 100644 --- a/tests/data_engine/test_datasource.py +++ b/tests/data_engine/test_datasource.py @@ -248,8 +248,7 @@ def test_upload_metadata_non_retryable_error_does_not_retry(ds, mocker): def test_upload_metadata_retries_partial_batch_below_min(ds, mocker): - """A tail batch smaller than min_batch_size that fails is aborted immediately, - since shrinking further is impossible.""" + """A short tail batch below min_batch_size gets one retry before aborting.""" entries = [ DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10) ] @@ -258,7 +257,7 @@ def test_upload_metadata_retries_partial_batch_below_min(ds, mocker): mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 4) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) - mocker.patch("dagshub.common.adaptive_batching.time.sleep", return_value=None) + sleep_mock = mocker.patch("dagshub.common.adaptive_batching.time.sleep") def _flaky_upload(_ds, upload_entries): if len(upload_entries) == 2: @@ -269,7 +268,8 @@ def _flaky_upload(_ds, upload_entries): with pytest.raises(TimeoutError, match="simulated timeout"): ds._upload_metadata(entries) - assert _uploaded_batch_sizes(ds) == [8, 2] + assert _uploaded_batch_sizes(ds) == [8, 2, 2] + assert [c.args[0] for c in sleep_mock.call_args_list] == [0.25] def test_upload_metadata_backoff_resets_after_success(ds, mocker): From 5cd02a73efe0cf031ce936a4426243a53ed1b23c Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Wed, 11 Mar 2026 22:04:10 +0200 Subject: [PATCH 24/30] Raise adaptive upload max default and clarify max indirection --- dagshub/common/config.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dagshub/common/config.py b/dagshub/common/config.py index 18b9e4e9..4ff6cb2a 100644 --- a/dagshub/common/config.py +++ b/dagshub/common/config.py @@ -59,12 +59,16 @@ def set_host(new_host: str): DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE" DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MAX" -dataengine_metadata_upload_batch_size = int( +DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_DEFAULT_MAX = 50000 +dataengine_metadata_upload_batch_size_max = int( os.environ.get( DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY, - os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, 15000), + os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_DEFAULT_MAX), ) ) +# Keep the runtime config name consumed by existing code/tests, while making the +# source setting semantics explicit as a max bound. +dataengine_metadata_upload_batch_size = dataengine_metadata_upload_batch_size_max DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MIN" dataengine_metadata_upload_batch_size_min = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY, 1)) From dbf7a680520e81525f7b036b189dc53ba0ae72a2 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Wed, 11 Mar 2026 22:06:45 +0200 Subject: [PATCH 25/30] Use explicit max config name for adaptive upload sizing --- dagshub/common/adaptive_batching.py | 2 +- dagshub/common/config.py | 5 ++--- tests/data_engine/test_datasource.py | 16 ++++++++-------- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/dagshub/common/adaptive_batching.py b/dagshub/common/adaptive_batching.py index f4cbf0c1..39f2d38c 100644 --- a/dagshub/common/adaptive_batching.py +++ b/dagshub/common/adaptive_batching.py @@ -41,7 +41,7 @@ def from_values( import dagshub.common.config as dgs_config if max_batch_size is None: - max_batch_size = dgs_config.dataengine_metadata_upload_batch_size + max_batch_size = dgs_config.dataengine_metadata_upload_batch_size_max if min_batch_size is None: min_batch_size = dgs_config.dataengine_metadata_upload_batch_size_min if initial_batch_size is None: diff --git a/dagshub/common/config.py b/dagshub/common/config.py index 4ff6cb2a..11d053dc 100644 --- a/dagshub/common/config.py +++ b/dagshub/common/config.py @@ -60,15 +60,14 @@ def set_host(new_host: str): DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE" DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MAX" DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_DEFAULT_MAX = 50000 +# Read from the explicit MAX key first; if unset, honor the legacy key so existing +# env-based overrides keep working while config names are clarified. dataengine_metadata_upload_batch_size_max = int( os.environ.get( DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY, os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_DEFAULT_MAX), ) ) -# Keep the runtime config name consumed by existing code/tests, while making the -# source setting semantics explicit as a max bound. -dataengine_metadata_upload_batch_size = dataengine_metadata_upload_batch_size_max DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MIN" dataengine_metadata_upload_batch_size_min = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MIN_KEY, 1)) diff --git a/tests/data_engine/test_datasource.py b/tests/data_engine/test_datasource.py index 13c869e4..c210aa81 100644 --- a/tests/data_engine/test_datasource.py +++ b/tests/data_engine/test_datasource.py @@ -152,7 +152,7 @@ def test_upload_metadata_starts_small_and_grows(ds, mocker): DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(14) ] - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 16) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 16) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) @@ -167,7 +167,7 @@ def test_upload_metadata_retries_with_smaller_batch_after_failure(ds, mocker): DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10) ] - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) @@ -193,7 +193,7 @@ def test_upload_metadata_does_not_retry_known_bad_batch_size(ds, mocker): DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(32) ] - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 16) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 16) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) @@ -219,7 +219,7 @@ def test_upload_metadata_slow_success_reduces_batch_size(ds, mocker): DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(12) ] - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1.0) @@ -235,7 +235,7 @@ def test_upload_metadata_non_retryable_error_does_not_retry(ds, mocker): DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10) ] - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) @@ -253,7 +253,7 @@ def test_upload_metadata_retries_partial_batch_below_min(ds, mocker): DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10) ] - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 4) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) @@ -277,7 +277,7 @@ def test_upload_metadata_backoff_resets_after_success(ds, mocker): DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(12) ] - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) @@ -302,7 +302,7 @@ def test_upload_metadata_retries_below_configured_min_before_aborting(ds, mocker DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(6) ] - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size", 8) + mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) From 53ca65a6295f6f74446dbc8a4bf570a7d9557c15 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Thu, 12 Mar 2026 00:23:22 +0200 Subject: [PATCH 26/30] Handle missing TransportConnectionFailed in supported gql versions --- dagshub/data_engine/model/metadata/util.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dagshub/data_engine/model/metadata/util.py b/dagshub/data_engine/model/metadata/util.py index 418eb1bf..e0ae9180 100644 --- a/dagshub/data_engine/model/metadata/util.py +++ b/dagshub/data_engine/model/metadata/util.py @@ -1,10 +1,15 @@ import datetime -from gql.transport.exceptions import TransportServerError, TransportConnectionFailed +from gql.transport import exceptions as gql_transport_exceptions from requests import ConnectionError as RequestsConnectionError, Timeout as RequestsTimeout from typing import Optional from dagshub.data_engine.model.errors import DataEngineGqlError +TransportServerError = gql_transport_exceptions.TransportServerError +# Some supported gql versions (e.g. 3.4.x) do not expose this symbol. +# Fallback to an empty tuple so isinstance(...) still works without broad try/except imports. +TransportConnectionFailed = getattr(gql_transport_exceptions, "TransportConnectionFailed", tuple()) + def _get_datetime_utc_offset(t: datetime.datetime) -> Optional[str]: """ From 946134f88fd69a0c58b0123e1baa9261af423376 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Wed, 18 Mar 2026 23:25:03 +0200 Subject: [PATCH 27/30] Review fixes --- dagshub/common/adaptive_batching.py | 55 ++++---- dagshub/common/config.py | 44 +++++-- dagshub/data_engine/client/data_client.py | 30 ++--- dagshub/data_engine/model/metadata/util.py | 17 +-- tests/common/test_adaptive_batching.py | 75 ++++------- tests/data_engine/test_datasource.py | 142 +++------------------ 6 files changed, 115 insertions(+), 248 deletions(-) diff --git a/dagshub/common/adaptive_batching.py b/dagshub/common/adaptive_batching.py index 39f2d38c..5ec4ad72 100644 --- a/dagshub/common/adaptive_batching.py +++ b/dagshub/common/adaptive_batching.py @@ -1,13 +1,12 @@ +import itertools import logging import time from dataclasses import dataclass -import itertools -from types import SimpleNamespace from typing import Callable, Iterable, List, Optional, Sized, TypeVar import rich.progress -from tenacity import wait_exponential +import dagshub.common.config as dgs_config from dagshub.common.rich_util import get_rich_progress logger = logging.getLogger(__name__) @@ -17,7 +16,7 @@ MIN_TARGET_BATCH_TIME_SECONDS = 0.01 -@dataclass(frozen=True) +@dataclass class AdaptiveBatchConfig: max_batch_size: int min_batch_size: int @@ -38,8 +37,6 @@ def from_values( retry_backoff_base_seconds: Optional[float] = None, retry_backoff_max_seconds: Optional[float] = None, ) -> "AdaptiveBatchConfig": - import dagshub.common.config as dgs_config - if max_batch_size is None: max_batch_size = dgs_config.dataengine_metadata_upload_batch_size_max if min_batch_size is None: @@ -82,21 +79,23 @@ def _next_batch_after_success( config: AdaptiveBatchConfig, bad_batch_size: Optional[int], ) -> int: - """Pick the next batch size after a successful (fast) batch. + """Pick the next batch size after a fast successful batch. Strategy: - - If we have a known-bad upper bound, binary-search toward it. + - If we have a previous slow/failing size, binary-search toward it as a soft upper hint. - Otherwise, multiply by the growth factor. - - Always guarantee at least +1 progress (so we never stall). + - If the midpoint rounds back to the current size, advance by 1 so the search + keeps moving. That may revisit the previous failing size, because these hints + are soft signals rather than permanent bans. """ if bad_batch_size is not None and batch_size < bad_batch_size: - # Binary search: try the midpoint between current and bad + # Binary search: try the midpoint between current and the soft upper hint. candidate = (batch_size + bad_batch_size) // 2 else: - # No upper bound (or we've already passed it): grow aggressively + # No upper hint (or we've already reached it): grow aggressively. candidate = batch_size * config.batch_growth_factor - # Must advance by at least 1 to avoid stalling + # Always make forward progress in the search. candidate = max(candidate, batch_size + 1) return _clamp(candidate, config.min_batch_size, config.max_batch_size) @@ -135,15 +134,12 @@ def _next_batch_after_retryable_failure( def _get_retry_delay_seconds(consecutive_retryable_failures: int, config: AdaptiveBatchConfig) -> float: - # SimpleNamespace duck-types the .attempt_number attribute that tenacity's - # wait strategies read, avoiding the heavier RetryCallState constructor. - strategy = wait_exponential( - multiplier=config.retry_backoff_base_seconds, - min=config.retry_backoff_base_seconds, - max=config.retry_backoff_max_seconds, - ) - retry_state = SimpleNamespace(attempt_number=max(1, consecutive_retryable_failures)) - return float(strategy(retry_state)) # type: ignore[arg-type] + if config.retry_backoff_base_seconds <= 0.0 or config.retry_backoff_max_seconds <= 0.0: + return 0.0 + + attempt_number = max(1, consecutive_retryable_failures) + delay = config.retry_backoff_base_seconds * (2 ** (attempt_number - 1)) + return min(delay, config.retry_backoff_max_seconds) class AdaptiveBatcher: @@ -166,14 +162,12 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: config = self._config current_batch_size = config.initial_batch_size + # Consume the source iterable incrementally across retries and successes. it = iter(items) pending: List[T] = [] progress = get_rich_progress(rich.progress.MofNCompleteColumn()) - total_task = progress.add_task( - f"{self._progress_label} (adaptive batch {config.min_batch_size}-{config.max_batch_size})...", - total=total, - ) + total_task = progress.add_task(f"{self._progress_label}...", total=total) last_good_batch_size: Optional[int] = None last_bad_batch_size: Optional[int] = None @@ -191,10 +185,6 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: break batch_size = len(batch) - progress.update( - total_task, - description=f"{self._progress_label} (batch size {batch_size})...", - ) logger.debug(f"{self._progress_label}: {batch_size} entries...") start_time = time.monotonic() @@ -240,6 +230,7 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: pending = batch + pending continue + # On success. elapsed = time.monotonic() - start_time consecutive_retryable_failures = 0 processed += batch_size @@ -249,11 +240,15 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: last_good_batch_size = ( batch_size if last_good_batch_size is None else max(last_good_batch_size, batch_size) ) - # Clear stale bad bound if we succeeded fast at or above it + # Clear the soft upper hint if we succeeded fast at or above it. if last_bad_batch_size is not None and batch_size >= last_bad_batch_size: last_bad_batch_size = None current_batch_size = _next_batch_after_success(batch_size, config, last_bad_batch_size) else: + logger.debug( + f"{self._progress_label} batch size {batch_size} took {elapsed:.2f}s " + f"(target {config.target_batch_time_seconds:.2f}s); shrinking." + ) last_bad_batch_size = ( batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) ) diff --git a/dagshub/common/config.py b/dagshub/common/config.py index 11d053dc..22b09df3 100644 --- a/dagshub/common/config.py +++ b/dagshub/common/config.py @@ -1,11 +1,12 @@ import logging - -import appdirs import os from urllib.parse import urlparse -from dagshub import __version__ + +import appdirs from httpx._client import USER_AGENT +from dagshub import __version__ + logger = logging.getLogger(__name__) HOST_KEY = "DAGSHUB_CLIENT_HOST" @@ -59,13 +60,12 @@ def set_host(new_host: str): DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE" DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE_MAX" -DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_DEFAULT_MAX = 50000 -# Read from the explicit MAX key first; if unset, honor the legacy key so existing -# env-based overrides keep working while config names are clarified. +DEFAULT_DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX = 50000 +# Fall back to the old `DAGSHUB_DE_METADATA_UPLOAD_BATCH_SIZE` env var for backwards compatibility. dataengine_metadata_upload_batch_size_max = int( os.environ.get( DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX_KEY, - os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_DEFAULT_MAX), + os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_KEY, DEFAULT_DATAENGINE_METADATA_UPLOAD_BATCH_SIZE_MAX), ) ) @@ -82,14 +82,32 @@ def set_host(new_host: str): os.environ.get(DATAENGINE_METADATA_UPLOAD_TARGET_BATCH_TIME_SECONDS_KEY, 5.0) ) -ADAPTIVE_BATCH_GROWTH_FACTOR_KEY = "DAGSHUB_ADAPTIVE_BATCH_GROWTH_FACTOR" -adaptive_batch_growth_factor = int(os.environ.get(ADAPTIVE_BATCH_GROWTH_FACTOR_KEY, 10)) +DATAENGINE_METADATA_UPLOAD_BATCH_GROWTH_FACTOR_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_GROWTH_FACTOR" +LEGACY_ADAPTIVE_BATCH_GROWTH_FACTOR_KEY = "DAGSHUB_ADAPTIVE_BATCH_GROWTH_FACTOR" +adaptive_batch_growth_factor = int( + os.environ.get( + DATAENGINE_METADATA_UPLOAD_BATCH_GROWTH_FACTOR_KEY, + os.environ.get(LEGACY_ADAPTIVE_BATCH_GROWTH_FACTOR_KEY, 10), + ) +) -ADAPTIVE_BATCH_RETRY_BACKOFF_BASE_KEY = "DAGSHUB_ADAPTIVE_BATCH_RETRY_BACKOFF_BASE" -adaptive_batch_retry_backoff_base_seconds = float(os.environ.get(ADAPTIVE_BATCH_RETRY_BACKOFF_BASE_KEY, 0.25)) +DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_BASE_KEY = "DAGSHUB_DE_METADATA_UPLOAD_RETRY_BACKOFF_BASE" +LEGACY_ADAPTIVE_BATCH_RETRY_BACKOFF_BASE_KEY = "DAGSHUB_ADAPTIVE_BATCH_RETRY_BACKOFF_BASE" +adaptive_batch_retry_backoff_base_seconds = float( + os.environ.get( + DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_BASE_KEY, + os.environ.get(LEGACY_ADAPTIVE_BATCH_RETRY_BACKOFF_BASE_KEY, 0.25), + ) +) -ADAPTIVE_BATCH_RETRY_BACKOFF_MAX_KEY = "DAGSHUB_ADAPTIVE_BATCH_RETRY_BACKOFF_MAX" -adaptive_batch_retry_backoff_max_seconds = float(os.environ.get(ADAPTIVE_BATCH_RETRY_BACKOFF_MAX_KEY, 4.0)) +DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_MAX_KEY = "DAGSHUB_DE_METADATA_UPLOAD_RETRY_BACKOFF_MAX" +LEGACY_ADAPTIVE_BATCH_RETRY_BACKOFF_MAX_KEY = "DAGSHUB_ADAPTIVE_BATCH_RETRY_BACKOFF_MAX" +adaptive_batch_retry_backoff_max_seconds = float( + os.environ.get( + DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_MAX_KEY, + os.environ.get(LEGACY_ADAPTIVE_BATCH_RETRY_BACKOFF_MAX_KEY, 4.0), + ) +) DISABLE_ANALYTICS_KEY = "DAGSHUB_DISABLE_ANALYTICS" disable_analytics = "DAGSHUB_DISABLE_ANALYTICS" in os.environ diff --git a/dagshub/data_engine/client/data_client.py b/dagshub/data_engine/client/data_client.py index d5c61e2f..fa8bacb1 100644 --- a/dagshub/data_engine/client/data_client.py +++ b/dagshub/data_engine/client/data_client.py @@ -1,45 +1,43 @@ import datetime import logging -from typing import Any, Optional, List, Dict, Union, TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import dacite import gql import rich.progress -from gql.transport.exceptions import TransportQueryError, TransportServerError +from gql.transport.exceptions import TransportError from gql.transport.requests import RequestsHTTPTransport import dagshub.auth import dagshub.common.config from dagshub.common import config -from dagshub.common.tracing import build_traceparent from dagshub.common.analytics import send_analytics_event from dagshub.common.rich_util import get_rich_progress +from dagshub.common.tracing import build_traceparent from dagshub.data_engine.client.gql_introspections import GqlIntrospections, TypesIntrospection +from dagshub.data_engine.client.gql_mutations import GqlMutations +from dagshub.data_engine.client.gql_queries import GqlQueries from dagshub.data_engine.client.models import ( - DatasourceResult, + DatapointHistoryResult, DatasetResult, + DatasourceResult, MetadataFieldSchema, - DatapointHistoryResult, + ScanOption, ) -from dagshub.data_engine.client.models import ScanOption -from dagshub.data_engine.client.gql_mutations import GqlMutations -from dagshub.data_engine.client.gql_queries import GqlQueries from dagshub.data_engine.client.query_builder import GqlQuery from dagshub.data_engine.model.errors import DataEngineGqlError from dagshub.data_engine.model.query_result import QueryResult - - from dagshub.data_engine.model.schema_util import dacite_config if TYPE_CHECKING: from dagshub.data_engine.datasources import DatasourceState + from dagshub.data_engine.model.datapoint import Datapoint from dagshub.data_engine.model.datasource import ( - Datasource, - DatapointMetadataUpdateEntry, - DatapointDeleteMetadataEntry, DatapointDeleteEntry, + DatapointDeleteMetadataEntry, + DatapointMetadataUpdateEntry, + Datasource, ) - from dagshub.data_engine.model.datapoint import Datapoint logger = logging.getLogger(__name__) @@ -191,8 +189,8 @@ def _exec( traceparent = build_traceparent() headers["traceparent"] = traceparent try: - resp = self.client.execute(q, variable_values=params, extra_args={'headers': headers}) - except (TransportQueryError, TransportServerError) as e: + resp = self.client.execute(q, variable_values=params, extra_args={"headers": headers}) + except TransportError as e: support_id = self.client.transport.response_headers.get("X-DagsHub-Support-Id") if support_id is None: support_id = traceparent diff --git a/dagshub/data_engine/model/metadata/util.py b/dagshub/data_engine/model/metadata/util.py index e0ae9180..c24bb23b 100644 --- a/dagshub/data_engine/model/metadata/util.py +++ b/dagshub/data_engine/model/metadata/util.py @@ -1,14 +1,11 @@ import datetime -from gql.transport import exceptions as gql_transport_exceptions -from requests import ConnectionError as RequestsConnectionError, Timeout as RequestsTimeout from typing import Optional -from dagshub.data_engine.model.errors import DataEngineGqlError +from gql.transport.exceptions import TransportError, TransportQueryError +from requests import ConnectionError as RequestsConnectionError +from requests import Timeout as RequestsTimeout -TransportServerError = gql_transport_exceptions.TransportServerError -# Some supported gql versions (e.g. 3.4.x) do not expose this symbol. -# Fallback to an empty tuple so isinstance(...) still works without broad try/except imports. -TransportConnectionFailed = getattr(gql_transport_exceptions, "TransportConnectionFailed", tuple()) +from dagshub.data_engine.model.errors import DataEngineGqlError def _get_datetime_utc_offset(t: datetime.datetime) -> Optional[str]: @@ -31,14 +28,12 @@ def _get_datetime_utc_offset(t: datetime.datetime) -> Optional[str]: def is_retryable_metadata_upload_error(exc: Exception) -> bool: - if isinstance(exc, DataEngineGqlError) and isinstance(exc.original_exception, Exception): + if isinstance(exc, DataEngineGqlError): return is_retryable_metadata_upload_error(exc.original_exception) - return isinstance( + return (isinstance(exc, TransportError) and not isinstance(exc, TransportQueryError)) or isinstance( exc, ( - TransportServerError, - TransportConnectionFailed, TimeoutError, ConnectionError, RequestsConnectionError, diff --git a/tests/common/test_adaptive_batching.py b/tests/common/test_adaptive_batching.py index dcc7e660..11d45a9e 100644 --- a/tests/common/test_adaptive_batching.py +++ b/tests/common/test_adaptive_batching.py @@ -11,6 +11,11 @@ _next_batch_after_success, ) + +class RetryableTestError(Exception): + pass + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -67,23 +72,6 @@ def test_backoff_seconds_non_negative(self): assert cfg.retry_backoff_base_seconds == 0.0 assert cfg.retry_backoff_max_seconds == 0.0 - def test_defaults_from_config(self): - cfg = AdaptiveBatchConfig.from_values() - assert cfg.max_batch_size >= 1 - assert cfg.min_batch_size >= 1 - assert cfg.min_batch_size <= cfg.max_batch_size - assert cfg.initial_batch_size >= cfg.min_batch_size - assert cfg.initial_batch_size <= cfg.max_batch_size - assert cfg.target_batch_time_seconds > 0 - assert cfg.batch_growth_factor >= 2 - assert cfg.retry_backoff_base_seconds >= 0 - assert cfg.retry_backoff_max_seconds >= 0 - - -# --------------------------------------------------------------------------- -# _clamp -# --------------------------------------------------------------------------- - class TestClamp: def test_within_range(self): @@ -118,7 +106,7 @@ def test_binary_search_toward_bad_size(self): result = _next_batch_after_success(10, cfg, bad_batch_size=20) assert 10 < result < 20 - def test_never_reaches_bad_size(self): + def test_stays_below_soft_upper_hint_when_midpoint_advances(self): cfg = _cfg(max_batch_size=10000) result = _next_batch_after_success(18, cfg, bad_batch_size=20) assert result <= 19 # bad_batch_size - 1 @@ -128,22 +116,10 @@ def test_respects_min_batch_size(self): result = _next_batch_after_success(1, cfg, bad_batch_size=None) assert result >= 5 - def test_no_stall_at_bad_batch_size_minus_one(self): - """Regression: batch_size == bad_batch_size - 1 must still make progress.""" + def test_can_revisit_previous_failure_size_when_midpoint_rounds_down(self): cfg = _cfg(max_batch_size=1000, batch_growth_factor=2) result = _next_batch_after_success(9, cfg, bad_batch_size=10) - assert result > 9 - - def test_convergence_reaches_max(self): - """Iterating _next_batch_after_success must eventually reach max_batch_size.""" - cfg = _cfg(max_batch_size=100, batch_growth_factor=2) - batch_size = 1 - bad = 50 # initial bad bound - for _ in range(200): - batch_size = _next_batch_after_success(batch_size, cfg, bad_batch_size=bad) - if batch_size >= cfg.max_batch_size: - break - assert batch_size == cfg.max_batch_size + assert result == 10 def test_makes_progress_when_growth_factor_would_not_increase(self): cfg = _cfg(batch_growth_factor=2, max_batch_size=100) @@ -167,10 +143,6 @@ def test_returns_min_when_at_min(self): cfg = _cfg(min_batch_size=5) assert _next_batch_after_retryable_failure(5, cfg, None, None) == 5 - def test_returns_min_when_below_min(self): - cfg = _cfg(min_batch_size=10) - assert _next_batch_after_retryable_failure(3, cfg, None, None) == 10 - def test_binary_search_between_good_and_bad(self): cfg = _cfg(min_batch_size=1) result = _next_batch_after_retryable_failure(100, cfg, good_batch_size=40, bad_batch_size=100) @@ -238,7 +210,7 @@ def _make_batcher(**config_overrides): } ) return AdaptiveBatcher( - is_retryable=lambda exc: isinstance(exc, ValueError), + is_retryable=lambda exc: isinstance(exc, RetryableTestError), config=cfg, ) @@ -274,11 +246,11 @@ def op(batch): nonlocal call_count call_count += 1 if call_count == 1: - raise ValueError("transient") + raise RetryableTestError("transient") received.extend(batch) batcher.run(list(range(5)), op) - assert sorted(received) == list(range(5)) + assert received == list(range(5)) assert call_count > 1 def test_aborts_on_non_retryable_error(self): @@ -290,8 +262,8 @@ def test_aborts_on_non_retryable_error(self): def test_aborts_at_min_batch_size(self): batcher = self._make_batcher(initial_batch_size=1, min_batch_size=1) - with pytest.raises(ValueError, match="always fails"): - batcher.run([1], lambda batch: (_ for _ in ()).throw(ValueError("always fails"))) + with pytest.raises(RetryableTestError, match="always fails"): + batcher.run([1], lambda batch: (_ for _ in ()).throw(RetryableTestError("always fails"))) def test_no_items_lost_on_retry(self): """All items from a failed batch must be retried.""" @@ -303,12 +275,12 @@ def op(batch): nonlocal fail_once if fail_once and len(batch) == 4: fail_once = False - raise ValueError("fail big batch once") + raise RetryableTestError("fail big batch once") all_received.extend(batch) items = list(range(8)) batcher.run(items, op) - assert sorted(all_received) == items + assert all_received == items def test_generator_retry_no_items_lost(self): """Items from a failed batch are retried even with generator input.""" @@ -320,7 +292,7 @@ def op(batch): nonlocal fail_once if fail_once: fail_once = False - raise ValueError("transient") + raise RetryableTestError("transient") all_received.extend(batch) def gen(): @@ -328,7 +300,7 @@ def gen(): yield i batcher.run(gen(), op) - assert sorted(all_received) == list(range(6)) + assert all_received == list(range(6)) def test_batch_size_shrinks_on_failure(self): batcher = self._make_batcher(initial_batch_size=10, min_batch_size=1) @@ -337,7 +309,7 @@ def test_batch_size_shrinks_on_failure(self): def op(batch): batch_sizes.append(len(batch)) if batch_sizes[-1] == 10: - raise ValueError("too big") + raise RetryableTestError("too big") batcher.run(list(range(20)), op) # First call is size 10 (fails), next should be smaller @@ -382,13 +354,12 @@ def op(batch): items = list(range(50)) batcher.run(items, op) # All items processed despite slow batches - assert sorted(received) == items + assert received == items # Batch size should shrink from 20 assert batch_sizes[0] == 20 assert min(batch_sizes) < 20 - def test_grows_past_cleared_bad_bound(self): - """After a transient failure, fast successes must clear the bad bound and grow past it.""" + def test_can_revisit_previous_failure_size_after_fast_recovery(self): batcher = self._make_batcher( initial_batch_size=10, min_batch_size=1, @@ -402,10 +373,10 @@ def op(batch): batch_sizes.append(len(batch)) if fail_once: fail_once = False - raise ValueError("transient") + raise RetryableTestError("transient") items = list(range(200)) batcher.run(items, op) - # First call at size 10 fails, shrinks, then should recover and grow past 10 + # First call at size 10 fails, shrinks, then can probe 10 again after recovery. assert batch_sizes[0] == 10 - assert max(batch_sizes) > 10 + assert 10 in batch_sizes[1:] diff --git a/tests/data_engine/test_datasource.py b/tests/data_engine/test_datasource.py index c210aa81..e6f6e0dc 100644 --- a/tests/data_engine/test_datasource.py +++ b/tests/data_engine/test_datasource.py @@ -14,10 +14,11 @@ from dagshub.data_engine.client.models import MetadataFieldSchema from dagshub.data_engine.dtypes import MetadataFieldType, ReservedTags from dagshub.data_engine.model.datapoint import Datapoint -from dagshub.data_engine.model.datasource import Datasource, DatapointMetadataUpdateEntry, MetadataContextManager +from dagshub.data_engine.model.datasource import DatapointMetadataUpdateEntry, Datasource, MetadataContextManager +from dagshub.data_engine.model.errors import DataEngineGqlError from dagshub.data_engine.model.metadata import MultipleDataTypesUploadedError, StringFieldValueTooLongError from dagshub.data_engine.model.query_result import QueryResult -from tests.data_engine.util import add_string_fields, add_document_fields, add_annotation_fields +from tests.data_engine.util import add_annotation_fields, add_document_fields, add_string_fields def _uploaded_batch_sizes(ds: Datasource): @@ -148,24 +149,24 @@ def test_uploading_to_document_turns_into_blob(ds): def test_upload_metadata_starts_small_and_grows(ds, mocker): - entries = [ - DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(14) - ] + entries = [DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(14)] mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 16) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 2) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) + mocker.patch.object(dagshub.common.config, "adaptive_batch_growth_factor", 8) ds._upload_metadata(entries) - assert _uploaded_batch_sizes(ds) == [2, 12] + batch_sizes = _uploaded_batch_sizes(ds) + assert batch_sizes[0] == 2 + assert max(batch_sizes) > 2 + assert sum(batch_sizes) == len(entries) -def test_upload_metadata_retries_with_smaller_batch_after_failure(ds, mocker): - entries = [ - DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10) - ] +def test_upload_metadata_retries_wrapped_retryable_error_with_smaller_batch(ds, mocker): + entries = [DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10)] mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) @@ -178,62 +179,20 @@ def test_upload_metadata_retries_with_smaller_batch_after_failure(ds, mocker): def _flaky_upload(_ds, upload_entries): if len(upload_entries) == 8 and not has_failed["value"]: has_failed["value"] = True - raise TimeoutError("simulated timeout") + raise DataEngineGqlError(TimeoutError("simulated timeout"), support_id="test-support-id") ds.source.client.update_metadata.side_effect = _flaky_upload ds._upload_metadata(entries) + batch_sizes = _uploaded_batch_sizes(ds) assert has_failed["value"] - assert _uploaded_batch_sizes(ds) == [8, 4, 6] - - -def test_upload_metadata_does_not_retry_known_bad_batch_size(ds, mocker): - entries = [ - DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(32) - ] - - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 16) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) - mocker.patch("dagshub.common.adaptive_batching.time.sleep", return_value=None) - - has_failed = {"value": False} - - def _flaky_upload(_ds, upload_entries): - if len(upload_entries) == 8 and not has_failed["value"]: - has_failed["value"] = True - raise TimeoutError("simulated timeout") - - ds.source.client.update_metadata.side_effect = _flaky_upload - - ds._upload_metadata(entries) - - assert has_failed["value"] - assert _uploaded_batch_sizes(ds) == [8, 4, 6, 7, 8, 7] - - -def test_upload_metadata_slow_success_reduces_batch_size(ds, mocker): - entries = [ - DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(12) - ] - - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 8) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1.0) - mocker.patch("dagshub.common.adaptive_batching.time.monotonic", side_effect=[0.0, 2.0, 3.0, 3.1]) - - ds._upload_metadata(entries) - - assert _uploaded_batch_sizes(ds) == [8, 4] + assert batch_sizes[0] == 8 + assert any(size < 8 for size in batch_sizes[1:]) def test_upload_metadata_non_retryable_error_does_not_retry(ds, mocker): - entries = [ - DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10) - ] + entries = [DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10)] mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 8) mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) @@ -247,75 +206,6 @@ def test_upload_metadata_non_retryable_error_does_not_retry(ds, mocker): assert _uploaded_batch_sizes(ds) == [8] -def test_upload_metadata_retries_partial_batch_below_min(ds, mocker): - """A short tail batch below min_batch_size gets one retry before aborting.""" - entries = [ - DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(10) - ] - - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 8) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 4) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) - sleep_mock = mocker.patch("dagshub.common.adaptive_batching.time.sleep") - - def _flaky_upload(_ds, upload_entries): - if len(upload_entries) == 2: - raise TimeoutError("simulated timeout") - - ds.source.client.update_metadata.side_effect = _flaky_upload - - with pytest.raises(TimeoutError, match="simulated timeout"): - ds._upload_metadata(entries) - - assert _uploaded_batch_sizes(ds) == [8, 2, 2] - assert [c.args[0] for c in sleep_mock.call_args_list] == [0.25] - - -def test_upload_metadata_backoff_resets_after_success(ds, mocker): - entries = [ - DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(12) - ] - - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 8) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 8) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) - sleep_mock = mocker.patch("dagshub.common.adaptive_batching.time.sleep") - - call_idx = {"value": 0} - - def _flaky_upload(_ds, _upload_entries): - call_idx["value"] += 1 - if call_idx["value"] in {1, 3}: - raise TimeoutError("simulated timeout") - - ds.source.client.update_metadata.side_effect = _flaky_upload - - ds._upload_metadata(entries) - - assert [c.args[0] for c in sleep_mock.call_args_list] == [0.25, 0.25] - - -def test_upload_metadata_retries_below_configured_min_before_aborting(ds, mocker): - entries = [ - DatapointMetadataUpdateEntry(f"dp-{i}", "field", str(i), MetadataFieldType.INTEGER) for i in range(6) - ] - - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_max", 8) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_min", 2) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_batch_size_initial", 2) - mocker.patch.object(dagshub.common.config, "dataengine_metadata_upload_target_batch_time_seconds", 1000.0) - sleep_mock = mocker.patch("dagshub.common.adaptive_batching.time.sleep") - ds.source.client.update_metadata.side_effect = TimeoutError("simulated timeout") - - with pytest.raises(TimeoutError, match="simulated timeout"): - ds._upload_metadata(entries) - - assert _uploaded_batch_sizes(ds) == [2] - sleep_mock.assert_not_called() - - def test_pandas_timestamp(ds): data_dict = { "file": ["test1", "test2"], From 841972cbe13e1fb553a08e7a3d1bdacf9c3b048d Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Wed, 18 Mar 2026 23:37:08 +0200 Subject: [PATCH 28/30] removed LEGACY_ nonsense --- dagshub/common/config.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/dagshub/common/config.py b/dagshub/common/config.py index 22b09df3..1f3d4d4f 100644 --- a/dagshub/common/config.py +++ b/dagshub/common/config.py @@ -83,30 +83,16 @@ def set_host(new_host: str): ) DATAENGINE_METADATA_UPLOAD_BATCH_GROWTH_FACTOR_KEY = "DAGSHUB_DE_METADATA_UPLOAD_BATCH_GROWTH_FACTOR" -LEGACY_ADAPTIVE_BATCH_GROWTH_FACTOR_KEY = "DAGSHUB_ADAPTIVE_BATCH_GROWTH_FACTOR" -adaptive_batch_growth_factor = int( - os.environ.get( - DATAENGINE_METADATA_UPLOAD_BATCH_GROWTH_FACTOR_KEY, - os.environ.get(LEGACY_ADAPTIVE_BATCH_GROWTH_FACTOR_KEY, 10), - ) -) +adaptive_batch_growth_factor = int(os.environ.get(DATAENGINE_METADATA_UPLOAD_BATCH_GROWTH_FACTOR_KEY, 10)) DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_BASE_KEY = "DAGSHUB_DE_METADATA_UPLOAD_RETRY_BACKOFF_BASE" -LEGACY_ADAPTIVE_BATCH_RETRY_BACKOFF_BASE_KEY = "DAGSHUB_ADAPTIVE_BATCH_RETRY_BACKOFF_BASE" adaptive_batch_retry_backoff_base_seconds = float( - os.environ.get( - DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_BASE_KEY, - os.environ.get(LEGACY_ADAPTIVE_BATCH_RETRY_BACKOFF_BASE_KEY, 0.25), - ) + os.environ.get(DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_BASE_KEY, 0.25) ) DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_MAX_KEY = "DAGSHUB_DE_METADATA_UPLOAD_RETRY_BACKOFF_MAX" -LEGACY_ADAPTIVE_BATCH_RETRY_BACKOFF_MAX_KEY = "DAGSHUB_ADAPTIVE_BATCH_RETRY_BACKOFF_MAX" adaptive_batch_retry_backoff_max_seconds = float( - os.environ.get( - DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_MAX_KEY, - os.environ.get(LEGACY_ADAPTIVE_BATCH_RETRY_BACKOFF_MAX_KEY, 4.0), - ) + os.environ.get(DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_MAX_KEY, 4.0) ) DISABLE_ANALYTICS_KEY = "DAGSHUB_DISABLE_ANALYTICS" From ebde20c8b6c261aad498e9de91c7e392503df5eb Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Sun, 22 Mar 2026 16:26:43 +0200 Subject: [PATCH 29/30] Refine adaptive batching search behavior --- dagshub/common/adaptive_batching.py | 164 +++++++++++++++++-------- dagshub/common/config.py | 4 +- tests/common/test_adaptive_batching.py | 60 ++++++--- 3 files changed, 156 insertions(+), 72 deletions(-) diff --git a/dagshub/common/adaptive_batching.py b/dagshub/common/adaptive_batching.py index 5ec4ad72..1169b704 100644 --- a/dagshub/common/adaptive_batching.py +++ b/dagshub/common/adaptive_batching.py @@ -1,8 +1,9 @@ import itertools import logging +import math import time from dataclasses import dataclass -from typing import Callable, Iterable, List, Optional, Sized, TypeVar +from typing import Callable, Iterable, List, Optional, Sized, Tuple, TypeVar import rich.progress @@ -14,6 +15,19 @@ T = TypeVar("T") MIN_TARGET_BATCH_TIME_SECONDS = 0.01 +SOFT_UPPER_LIMIT_MIN_STEP_FRACTION = 0.05 +SOFT_UPPER_LIMIT_RETRY_AFTER_SUCCESSES = 3 + +# Overall strategy: +# - Grow aggressively on fast successes until we hit a slow or failing batch. +# - A slow or failing batch becomes last_bad_batch_size, and a fast batch becomes +# last_fast_batch_size. +# - last_bad_batch_size acts as a soft_upper_limit while the gap to +# last_fast_batch_size is still meaningful, and we probe within that range. +# - Once that gap is below the search resolution, hold the current fast batch size +# instead of micro-searching. +# - Several consecutive fast batches near last_bad_batch_size trigger one more +# probe at soft_upper_limit, since the earlier failure may have been transient. @dataclass @@ -77,35 +91,29 @@ def _clamp(value: int, lo: int, hi: int) -> int: def _next_batch_after_success( batch_size: int, config: AdaptiveBatchConfig, - bad_batch_size: Optional[int], + soft_upper_limit: Optional[int], ) -> int: """Pick the next batch size after a fast successful batch. Strategy: - If we have a previous slow/failing size, binary-search toward it as a soft upper hint. - Otherwise, multiply by the growth factor. - - If the midpoint rounds back to the current size, advance by 1 so the search - keeps moving. That may revisit the previous failing size, because these hints - are soft signals rather than permanent bans. """ - if bad_batch_size is not None and batch_size < bad_batch_size: - # Binary search: try the midpoint between current and the soft upper hint. - candidate = (batch_size + bad_batch_size) // 2 + if soft_upper_limit is not None and batch_size < soft_upper_limit: + # Binary search: try the midpoint between current and the soft upper limit. + candidate = (batch_size + soft_upper_limit) // 2 else: # No upper hint (or we've already reached it): grow aggressively. candidate = batch_size * config.batch_growth_factor - # Always make forward progress in the search. - candidate = max(candidate, batch_size + 1) - return _clamp(candidate, config.min_batch_size, config.max_batch_size) def _next_batch_after_retryable_failure( batch_size: int, config: AdaptiveBatchConfig, - good_batch_size: Optional[int], - bad_batch_size: Optional[int], + last_fast_batch_size: Optional[int], + soft_upper_limit: Optional[int], ) -> int: """Pick the next batch size after a failed or slow batch. @@ -120,12 +128,12 @@ def _next_batch_after_retryable_failure( return config.min_batch_size ceiling = batch_size - 1 # must shrink - if bad_batch_size is not None: - ceiling = min(ceiling, bad_batch_size - 1) + if soft_upper_limit is not None: + ceiling = min(ceiling, soft_upper_limit - 1) - if good_batch_size is not None and good_batch_size < ceiling: + if last_fast_batch_size is not None and last_fast_batch_size < ceiling: # Binary search: try the midpoint between good and failing - candidate = (good_batch_size + ceiling) // 2 + candidate = (last_fast_batch_size + ceiling) // 2 else: # No good lower bound — probe midpoint of the valid range candidate = (config.min_batch_size + ceiling) // 2 @@ -142,6 +150,28 @@ def _get_retry_delay_seconds(consecutive_retryable_failures: int, config: Adapti return min(delay, config.retry_backoff_max_seconds) +def _min_step_size(soft_upper_limit: int) -> int: + return max(1, math.ceil(soft_upper_limit * SOFT_UPPER_LIMIT_MIN_STEP_FRACTION)) + + +def _is_next_step_above_limit(batch_size: int, soft_upper_limit: Optional[int]) -> bool: + if soft_upper_limit is None or batch_size >= soft_upper_limit: + return False + + return soft_upper_limit - batch_size <= _min_step_size(soft_upper_limit) + + +def _update_bounds_after_bad_batch( + batch_size: int, + last_fast_batch_size: Optional[int], + last_bad_batch_size: Optional[int], +) -> Tuple[Optional[int], int]: + updated_last_bad_batch_size = batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) + if last_fast_batch_size is not None and last_fast_batch_size >= updated_last_bad_batch_size: + last_fast_batch_size = None + return last_fast_batch_size, updated_last_bad_batch_size + + class AdaptiveBatcher: """Sends items in adaptively-sized batches, growing on success and shrinking on failure.""" @@ -161,7 +191,7 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: return config = self._config - current_batch_size = config.initial_batch_size + desired_batch_size = config.initial_batch_size # Consume the source iterable incrementally across retries and successes. it = iter(items) pending: List[T] = [] @@ -169,23 +199,24 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: progress = get_rich_progress(rich.progress.MofNCompleteColumn()) total_task = progress.add_task(f"{self._progress_label}...", total=total) - last_good_batch_size: Optional[int] = None + last_fast_batch_size: Optional[int] = None last_bad_batch_size: Optional[int] = None consecutive_retryable_failures = 0 + consecutive_fast_successes_near_upper_limit = 0 processed = 0 with progress: while True: # Draw from pending (failed-batch leftovers) first, then the source iterator - batch = pending[:current_batch_size] - pending = pending[current_batch_size:] - if len(batch) < current_batch_size: - batch.extend(itertools.islice(it, current_batch_size - len(batch))) + batch = pending[:desired_batch_size] + pending = pending[desired_batch_size:] + if len(batch) < desired_batch_size: + batch.extend(itertools.islice(it, desired_batch_size - len(batch))) if not batch: break - batch_size = len(batch) + actual_batch_size = len(batch) - logger.debug(f"{self._progress_label}: {batch_size} entries...") + logger.debug(f"{self._progress_label}: {actual_batch_size} entries...") start_time = time.monotonic() try: @@ -198,33 +229,37 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: ) raise - exhausted_shrink = batch_size <= config.min_batch_size and batch_size == current_batch_size - if exhausted_shrink: + is_short_tail_batch = ( + actual_batch_size <= config.min_batch_size and actual_batch_size < desired_batch_size + ) + if not is_short_tail_batch and actual_batch_size <= config.min_batch_size: logger.error( - f"{self._progress_label} failed at minimum batch size ({batch_size}); aborting.", + f"{self._progress_label} failed at minimum batch size ({actual_batch_size}); aborting.", exc_info=True, ) raise + consecutive_fast_successes_near_upper_limit = 0 + + # Exponential backoff consecutive_retryable_failures += 1 time.sleep(_get_retry_delay_seconds(consecutive_retryable_failures, config)) - last_bad_batch_size = ( - batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) + last_fast_batch_size, last_bad_batch_size = _update_bounds_after_bad_batch( + actual_batch_size, last_fast_batch_size, last_bad_batch_size ) - if last_good_batch_size is not None and last_good_batch_size >= last_bad_batch_size: - last_good_batch_size = None - if batch_size < config.min_batch_size: - # Tail batches below configured min cannot be split further. + if is_short_tail_batch: + # A naturally short tail batch cannot be shrunk further in a useful way. # Retry that exact size once before treating it as exhausted. - current_batch_size = batch_size + desired_batch_size = actual_batch_size else: - current_batch_size = _next_batch_after_retryable_failure( - batch_size, config, last_good_batch_size, last_bad_batch_size + # Binary search downwards + desired_batch_size = _next_batch_after_retryable_failure( + actual_batch_size, config, last_fast_batch_size, last_bad_batch_size ) logger.warning( - f"{self._progress_label} failed for batch size {batch_size} " - f"({exc.__class__.__name__}: {exc}). Retrying with batch size {current_batch_size}." + f"{self._progress_label} failed for batch size {actual_batch_size} " + f"({exc.__class__.__name__}: {exc}). Retrying with batch size {desired_batch_size}." ) # Re-queue the failed batch items for retry with smaller batch size pending = batch + pending @@ -233,29 +268,50 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: # On success. elapsed = time.monotonic() - start_time consecutive_retryable_failures = 0 - processed += batch_size - progress.update(total_task, advance=batch_size) + processed += actual_batch_size + progress.update(total_task, advance=actual_batch_size) if elapsed <= config.target_batch_time_seconds: - last_good_batch_size = ( - batch_size if last_good_batch_size is None else max(last_good_batch_size, batch_size) - ) - # Clear the soft upper hint if we succeeded fast at or above it. - if last_bad_batch_size is not None and batch_size >= last_bad_batch_size: + if last_fast_batch_size is None or actual_batch_size > last_fast_batch_size: + last_fast_batch_size = actual_batch_size + if last_bad_batch_size is not None and actual_batch_size >= last_bad_batch_size: + # A fast success at the upper limit means the last_bad_batch_size is stale. + # We can resume unconstrained growth. last_bad_batch_size = None - current_batch_size = _next_batch_after_success(batch_size, config, last_bad_batch_size) + consecutive_fast_successes_near_upper_limit = 0 + desired_batch_size = _next_batch_after_success( + actual_batch_size, config, last_bad_batch_size + ) + elif _is_next_step_above_limit(actual_batch_size, last_bad_batch_size): + # Once the gap is smaller than our useful search resolution, + # hold the current known-good size and only re-probe the hint + # after a few stable fast successes. + consecutive_fast_successes_near_upper_limit += 1 + if consecutive_fast_successes_near_upper_limit >= SOFT_UPPER_LIMIT_RETRY_AFTER_SUCCESSES: + # We've had enough stable fast successes to re-probe the last_bad_batch_size. + desired_batch_size = last_bad_batch_size + consecutive_fast_successes_near_upper_limit = 0 + else: + # Hold current size for one more iteration + desired_batch_size = actual_batch_size + else: + # Binary search or unconstrained growth upwards + consecutive_fast_successes_near_upper_limit = 0 + desired_batch_size = _next_batch_after_success( + actual_batch_size, config, last_bad_batch_size + ) else: + # Binary search downwards due to a slow batch + consecutive_fast_successes_near_upper_limit = 0 logger.debug( - f"{self._progress_label} batch size {batch_size} took {elapsed:.2f}s " + f"{self._progress_label} batch size {actual_batch_size} took {elapsed:.2f}s " f"(target {config.target_batch_time_seconds:.2f}s); shrinking." ) - last_bad_batch_size = ( - batch_size if last_bad_batch_size is None else min(last_bad_batch_size, batch_size) + last_fast_batch_size, last_bad_batch_size = _update_bounds_after_bad_batch( + actual_batch_size, last_fast_batch_size, last_bad_batch_size ) - if last_good_batch_size is not None and last_good_batch_size >= last_bad_batch_size: - last_good_batch_size = None - current_batch_size = _next_batch_after_retryable_failure( - batch_size, config, last_good_batch_size, last_bad_batch_size + desired_batch_size = _next_batch_after_retryable_failure( + actual_batch_size, config, last_fast_batch_size, last_bad_batch_size ) progress.update(total_task, completed=processed, total=processed, refresh=True) diff --git a/dagshub/common/config.py b/dagshub/common/config.py index 1f3d4d4f..2a9d9bd4 100644 --- a/dagshub/common/config.py +++ b/dagshub/common/config.py @@ -91,9 +91,7 @@ def set_host(new_host: str): ) DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_MAX_KEY = "DAGSHUB_DE_METADATA_UPLOAD_RETRY_BACKOFF_MAX" -adaptive_batch_retry_backoff_max_seconds = float( - os.environ.get(DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_MAX_KEY, 4.0) -) +adaptive_batch_retry_backoff_max_seconds = float(os.environ.get(DATAENGINE_METADATA_UPLOAD_RETRY_BACKOFF_MAX_KEY, 60.0)) DISABLE_ANALYTICS_KEY = "DAGSHUB_DISABLE_ANALYTICS" disable_analytics = "DAGSHUB_DISABLE_ANALYTICS" in os.environ diff --git a/tests/common/test_adaptive_batching.py b/tests/common/test_adaptive_batching.py index 11d45a9e..d5da7171 100644 --- a/tests/common/test_adaptive_batching.py +++ b/tests/common/test_adaptive_batching.py @@ -7,6 +7,8 @@ AdaptiveBatcher, _clamp, _get_retry_delay_seconds, + _is_next_step_above_limit, + _min_step_size, _next_batch_after_retryable_failure, _next_batch_after_success, ) @@ -95,37 +97,37 @@ def test_equal_bounds(self): class TestNextBatchAfterSuccess: def test_grows_by_growth_factor_when_no_bad_size(self): cfg = _cfg(batch_growth_factor=10, max_batch_size=10000) - assert _next_batch_after_success(10, cfg, bad_batch_size=None) == 100 + assert _next_batch_after_success(10, cfg, soft_upper_limit=None) == 100 def test_capped_at_max_batch_size(self): cfg = _cfg(batch_growth_factor=10, max_batch_size=50) - assert _next_batch_after_success(10, cfg, bad_batch_size=None) == 50 + assert _next_batch_after_success(10, cfg, soft_upper_limit=None) == 50 def test_binary_search_toward_bad_size(self): cfg = _cfg(max_batch_size=10000) - result = _next_batch_after_success(10, cfg, bad_batch_size=20) + result = _next_batch_after_success(10, cfg, soft_upper_limit=20) assert 10 < result < 20 - def test_stays_below_soft_upper_hint_when_midpoint_advances(self): + def test_stays_below_soft_upper_limit_when_midpoint_advances(self): cfg = _cfg(max_batch_size=10000) - result = _next_batch_after_success(18, cfg, bad_batch_size=20) + result = _next_batch_after_success(18, cfg, soft_upper_limit=20) assert result <= 19 # bad_batch_size - 1 def test_respects_min_batch_size(self): cfg = _cfg(min_batch_size=5, max_batch_size=10) - result = _next_batch_after_success(1, cfg, bad_batch_size=None) + result = _next_batch_after_success(1, cfg, soft_upper_limit=None) assert result >= 5 - def test_can_revisit_previous_failure_size_when_midpoint_rounds_down(self): + def test_holds_near_soft_upper_limit_before_reprobing(self): cfg = _cfg(max_batch_size=1000, batch_growth_factor=2) - result = _next_batch_after_success(9, cfg, bad_batch_size=10) - assert result == 10 + result = _next_batch_after_success(9, cfg, soft_upper_limit=10) + assert result == 9 def test_makes_progress_when_growth_factor_would_not_increase(self): cfg = _cfg(batch_growth_factor=2, max_batch_size=100) # batch_size=99, growth gives 198 capped to 100, which is > 99 so it works. # But let's test the edge: batch_size at max-1 should reach max. - result = _next_batch_after_success(99, cfg, bad_batch_size=None) + result = _next_batch_after_success(99, cfg, soft_upper_limit=None) assert result == 100 @@ -145,7 +147,7 @@ def test_returns_min_when_at_min(self): def test_binary_search_between_good_and_bad(self): cfg = _cfg(min_batch_size=1) - result = _next_batch_after_retryable_failure(100, cfg, good_batch_size=40, bad_batch_size=100) + result = _next_batch_after_retryable_failure(100, cfg, last_fast_batch_size=40, soft_upper_limit=100) assert 40 < result < 100 def test_never_returns_below_min_batch_size(self): @@ -188,6 +190,18 @@ def test_zero_failures_treated_as_one(self): assert _get_retry_delay_seconds(0, cfg) == _get_retry_delay_seconds(1, cfg) +class TestSoftUpperHintResolution: + def test_min_step_size_scales_with_hint(self): + assert _min_step_size(10) == 1 + assert _min_step_size(1000) == 50 + + def test_holds_when_gap_is_within_resolution(self): + assert _is_next_step_above_limit(9, 10) + assert not _is_next_step_above_limit(8, 10) + assert not _is_next_step_above_limit(10, 10) + assert not _is_next_step_above_limit(9, None) + + # --------------------------------------------------------------------------- # AdaptiveBatcher.run — integration tests # --------------------------------------------------------------------------- @@ -265,6 +279,22 @@ def test_aborts_at_min_batch_size(self): with pytest.raises(RetryableTestError, match="always fails"): batcher.run([1], lambda batch: (_ for _ in ()).throw(RetryableTestError("always fails"))) + def test_short_tail_batch_retries_exact_size_once_before_aborting(self): + batcher = self._make_batcher(initial_batch_size=5, min_batch_size=4, max_batch_size=5) + attempts = 0 + + def op(batch): + nonlocal attempts + attempts += 1 + if attempts >= 3: + raise TypeError("retried too many times") + raise RetryableTestError("always fails") + + with pytest.raises(RetryableTestError, match="always fails"): + batcher.run([1, 2, 3], op) + + assert attempts == 2 + def test_no_items_lost_on_retry(self): """All items from a failed batch must be retried.""" batcher = self._make_batcher(initial_batch_size=4, min_batch_size=1) @@ -359,7 +389,7 @@ def op(batch): assert batch_sizes[0] == 20 assert min(batch_sizes) < 20 - def test_can_revisit_previous_failure_size_after_fast_recovery(self): + def test_reprobes_soft_upper_limit_after_stable_fast_successes(self): batcher = self._make_batcher( initial_batch_size=10, min_batch_size=1, @@ -377,6 +407,6 @@ def op(batch): items = list(range(200)) batcher.run(items, op) - # First call at size 10 fails, shrinks, then can probe 10 again after recovery. - assert batch_sizes[0] == 10 - assert 10 in batch_sizes[1:] + assert batch_sizes[:4] == [10, 5, 7, 8] + assert batch_sizes[4:7] == [9, 9, 9] + assert 10 in batch_sizes[7:] From e07343c5af8c07376503b8293de892fcf3c86ea0 Mon Sep 17 00:00:00 2001 From: Guy Smoilovsky Date: Mon, 23 Mar 2026 12:59:42 +0200 Subject: [PATCH 30/30] Show adaptive batch size progress again --- dagshub/common/adaptive_batching.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dagshub/common/adaptive_batching.py b/dagshub/common/adaptive_batching.py index 1169b704..5702d442 100644 --- a/dagshub/common/adaptive_batching.py +++ b/dagshub/common/adaptive_batching.py @@ -216,6 +216,7 @@ def run(self, items: Iterable[T], operation: Callable[[List[T]], None]) -> None: break actual_batch_size = len(batch) + progress.update(total_task, description=f"{self._progress_label} (batch size: {actual_batch_size})...") logger.debug(f"{self._progress_label}: {actual_batch_size} entries...") start_time = time.monotonic()