From 3407747483a8674a16b31c5e77fb0a4e17606ea8 Mon Sep 17 00:00:00 2001 From: howard989 Date: Mon, 25 May 2026 00:12:48 -0700 Subject: [PATCH 1/2] fix(rlix): gate MILES wake on residual SGLang memory --- rlix/pipeline/miles_coordinator.py | 31 ++++++- rlix/pipeline/miles_pipeline.py | 87 +++++++------------ rlix/utils/env.py | 18 ++++ tests/test_env_utils.py | 56 ++++++++++++ tests/test_miles_residual_threshold_wiring.py | 68 +++++++++++++++ 5 files changed, 199 insertions(+), 61 deletions(-) create mode 100644 tests/test_env_utils.py create mode 100644 tests/test_miles_residual_threshold_wiring.py diff --git a/rlix/pipeline/miles_coordinator.py b/rlix/pipeline/miles_coordinator.py index 92b7108..e9f1c69 100644 --- a/rlix/pipeline/miles_coordinator.py +++ b/rlix/pipeline/miles_coordinator.py @@ -18,6 +18,7 @@ import asyncio import logging import math +import os import threading import time from copy import deepcopy @@ -37,7 +38,7 @@ get_pipeline_namespace, ) from rlix.protocol.validation import validate_pipeline_id -from rlix.utils.env import pipeline_identity_env_vars +from rlix.utils.env import parse_env_positive_float, pipeline_identity_env_vars from rlix.utils.ray import get_actor_or_raise logger = logging.getLogger(__name__) @@ -60,9 +61,13 @@ def _build_pipeline_env_vars(*, pipeline_id: str, ray_namespace: str) -> Dict[st runtime_env. Reads ``RLIX_CONTROL_PLANE`` from the environment so actors inside an existing pipeline preserve the inherited value. """ - return pipeline_identity_env_vars( + env_vars = pipeline_identity_env_vars( pipeline_id=str(pipeline_id), ray_namespace=str(ray_namespace) ) + for key in ("MILES_MAX_RESIDUAL_GPU_MEM_GB",): + if (value := os.environ.get(key)) is not None: + env_vars[key] = value + return env_vars class MilesCoordinator(Coordinator): @@ -430,8 +435,26 @@ def _shrink_workers(self, engine_indices: Set[int]) -> None: rollout_manager = self._model_update_resources.get("rollout_manager") if rollout_manager is None: raise RuntimeError("resource registration missing for shrink") - # RPC outside the lock. - ray.get(rollout_manager.shrink_engines.remote(sorted(engine_indices))) + # RPC outside the lock. Use SGLang's server-side residual allocation + # check (weight + kvcache + graph) after release_memory_occupation; it + # is narrower than raw nvidia-smi used memory and avoids counting CUDA / + # Ray / process runtime overhead as model residue. + residual_threshold_gb = parse_env_positive_float( + "MILES_MAX_RESIDUAL_GPU_MEM_GB", 2.0 + ) + shrunk = ray.get( + rollout_manager.shrink_engines.remote( + sorted(engine_indices), + post_sleep_vram_threshold_gb=residual_threshold_gb, + ) + ) + logger.info( + "[MilesCoordinator] shrink_engines residual allocation check passed " + "pipeline_id=%s engine_indices=%s threshold=%.1f GB", + self._pipeline_id, + sorted(shrunk), + residual_threshold_gb, + ) # Commit under lock. with self._resize_sync_lock: self._active_engine_indices -= engine_indices diff --git a/rlix/pipeline/miles_pipeline.py b/rlix/pipeline/miles_pipeline.py index c23faef..c806706 100644 --- a/rlix/pipeline/miles_pipeline.py +++ b/rlix/pipeline/miles_pipeline.py @@ -503,16 +503,14 @@ def _init_phase_b_infer(self) -> None: def _wait_for_overlap_engines_offloaded(self, allocated_train_gpus, *, timeout_s: float = 60.0) -> None: """After scheduler grants actor_train, poll the rollout manager - until the engines on overlap GPUs have transitioned to ``offloaded`` - AND the OS-reported GPU memory is actually free. SGLang's HTTP - ``/release_memory_occupation`` 200 OK + state="offloaded" do not - by themselves guarantee the CUDA driver has returned the memory - to the OS pool — the wake_up in the next-process train actor - would then OOM. Verify actual GPU mem free by parsing - ``nvidia-smi --query-gpu=memory.free`` on the same node, since - miles' single-node smoke topology has driver+actors+engines all - on the head node and ``CUDA_VISIBLE_DEVICES`` is the per-actor - slice of the shared physical pool. + until the engines on overlap GPUs have transitioned to ``offloaded``. + + The hard residual-allocation safety check runs during + ``RolloutManager.shrink_engines`` via SGLang ``/server_info`` + (weight + kvcache + graph). This method only waits for the state + transition and logs raw OS-level ``nvidia-smi memory.used`` as a + diagnostic, because process-level GPU usage includes CUDA / Ray / + runtime overhead beyond SGLang's offloadable allocations. """ rollout_manager = getattr(self, "_rollout_manager", None) if rollout_manager is None: @@ -573,52 +571,27 @@ def _wait_for_overlap_engines_offloaded(self, allocated_train_gpus, *, timeout_s timeout_s, target_indices, uniq, ) - # Phase 2: probe nvidia-smi for OS-level free memory on the - # overlap GPU IDs. The train actor will need ~3.7 GB for the - # 0.5B model + a few GB for activations; aim for ≥20 GB free - # before we let _before_training proceed to wake_up. - target_free_gb = 20.0 - deadline2 = time.time() + float(timeout_s) - last_min_free_gb: Optional[float] = None - nvidia_smi_unavail_count = 0 - while time.time() < deadline2: - min_free_gb = self._probe_min_free_gpu_mem_gb(target_gpu_ids) - if min_free_gb is None: - # F5 (m11-review.review-report.md §2): nvidia-smi unavailable - # or unparseable. Was logged at DEBUG only — promoted to INFO - # so operators see the fallback without flipping log levels. - # If this fires repeatedly across sessions, it's a hardware - # / image regression worth investigating (driver missing, - # nvidia-smi path changed, etc.). - nvidia_smi_unavail_count += 1 - logger.info( - "_wait_for_overlap_engines_offloaded: nvidia-smi probe " - "unavailable (count=%d); falling back to 3s grace sleep", - nvidia_smi_unavail_count, - ) - time.sleep(3.0) - return - last_min_free_gb = min_free_gb - if min_free_gb >= target_free_gb: - logger.info( - "_wait_for_overlap_engines_offloaded: OS-level GPU mem free " - "min=%.2f GB across overlap GPUs %s (target=%.1f GB)", - min_free_gb, target_gpu_ids, target_free_gb, - ) - return - time.sleep(0.5) - logger.warning( - "_wait_for_overlap_engines_offloaded: free-mem timeout after %.1fs; " - "min_free_gb=%.2f below %.1f GB target on GPUs %s — wake_up may OOM", - timeout_s, - last_min_free_gb if last_min_free_gb is not None else float("nan"), - target_free_gb, + # Phase 2: log raw nvidia-smi used memory as diagnostics only. + # The hard safety check now runs inside RolloutManager.shrink_engines + # via SGLang /server_info (weight + kvcache + graph), which is a + # narrower residual-allocation signal than process-level GPU usage. + max_used_gb = self._probe_max_used_gpu_mem_gb(target_gpu_ids) + if max_used_gb is None: + logger.info( + "_wait_for_overlap_engines_offloaded: nvidia-smi probe unavailable; " + "server-side SGLang residual check already ran during shrink" + ) + return + logger.info( + "_wait_for_overlap_engines_offloaded: OS-level GPU mem used max=%.2f GB " + "across overlap GPUs %s (diagnostic; SGLang residual assert is the gate)", + max_used_gb, target_gpu_ids, ) @staticmethod - def _probe_min_free_gpu_mem_gb(gpu_ids: list[int]) -> Optional[float]: - """Return the minimum free GPU memory (GB) across ``gpu_ids`` as + def _probe_max_used_gpu_mem_gb(gpu_ids: list[int]) -> Optional[float]: + """Return the maximum used GPU memory (GB) across ``gpu_ids`` as reported by ``nvidia-smi``. Returns ``None`` if nvidia-smi is not available or output cannot be parsed. """ @@ -634,7 +607,7 @@ def _probe_min_free_gpu_mem_gb(gpu_ids: list[int]) -> Optional[float]: [ "nvidia-smi", f"--id={','.join(str(g) for g in gpu_ids)}", - "--query-gpu=memory.free", + "--query-gpu=memory.used", "--format=csv,noheader,nounits", ], stderr=subprocess.STDOUT, @@ -643,18 +616,18 @@ def _probe_min_free_gpu_mem_gb(gpu_ids: list[int]) -> Optional[float]: except (subprocess.SubprocessError, OSError) as exc: logger.debug("nvidia-smi probe failed: %r", exc) return None - free_mibs: list[float] = [] + used_mibs: list[float] = [] for line in out.strip().splitlines(): line = line.strip() if not line: continue try: - free_mibs.append(float(line)) + used_mibs.append(float(line)) except ValueError: continue - if not free_mibs: + if not used_mibs: return None - return min(free_mibs) / 1024.0 + return max(used_mibs) / 1024.0 def _before_training(self, step: int) -> None: if not self._initialized: diff --git a/rlix/utils/env.py b/rlix/utils/env.py index ddacd6d..29ca678 100644 --- a/rlix/utils/env.py +++ b/rlix/utils/env.py @@ -52,3 +52,21 @@ def parse_env_timeout_s(env_key: str, default_s: Optional[float] = None) -> Opti except ValueError as exc: raise RuntimeError(f"{env_key} must be a number, got: {raw!r}") from exc return None if value <= 0 else value + + +def parse_env_positive_float(env_key: str, default: float) -> float: + """Read a positive float from an env var; fail-fast on invalid values. + + Returns *default* when the env var is unset. Raises RuntimeError if the + value cannot be parsed as a number, or if the parsed value is <= 0. + """ + raw = os.environ.get(env_key) + if raw is None: + return float(default) + try: + value = float(raw) + except ValueError as exc: + raise RuntimeError(f"{env_key} must be a number, got: {raw!r}") from exc + if value <= 0.0: + raise RuntimeError(f"{env_key} must be > 0, got: {value!r}") + return value diff --git a/tests/test_env_utils.py b/tests/test_env_utils.py new file mode 100644 index 0000000..12a5bfd --- /dev/null +++ b/tests/test_env_utils.py @@ -0,0 +1,56 @@ +import importlib +import sys +import types +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] +RLIX_ROOT = REPO_ROOT / "rlix" + + +def _load_env_module(monkeypatch): + for module_name in list(sys.modules): + if module_name == "rlix" or module_name.startswith("rlix."): + monkeypatch.delitem(sys.modules, module_name, raising=False) + + package_roots = { + "rlix": RLIX_ROOT, + "rlix.utils": RLIX_ROOT / "utils", + } + for module_name, module_path in package_roots.items(): + package_module = types.ModuleType(module_name) + package_module.__path__ = [str(module_path)] # type: ignore[attr-defined] + monkeypatch.setitem(sys.modules, module_name, package_module) + + return importlib.import_module("rlix.utils.env") + + +def test_parse_env_positive_float_uses_default_when_unset(monkeypatch): + env = _load_env_module(monkeypatch) + monkeypatch.delenv("MILES_MAX_RESIDUAL_GPU_MEM_GB", raising=False) + + assert env.parse_env_positive_float("MILES_MAX_RESIDUAL_GPU_MEM_GB", 2.0) == 2.0 + + +def test_parse_env_positive_float_reads_override(monkeypatch): + env = _load_env_module(monkeypatch) + monkeypatch.setenv("MILES_MAX_RESIDUAL_GPU_MEM_GB", "40.5") + + assert env.parse_env_positive_float("MILES_MAX_RESIDUAL_GPU_MEM_GB", 2.0) == 40.5 + + +def test_parse_env_positive_float_rejects_non_positive(monkeypatch): + env = _load_env_module(monkeypatch) + monkeypatch.setenv("MILES_MAX_RESIDUAL_GPU_MEM_GB", "0") + + with pytest.raises(RuntimeError, match="must be > 0"): + env.parse_env_positive_float("MILES_MAX_RESIDUAL_GPU_MEM_GB", 2.0) + + +def test_parse_env_positive_float_rejects_non_numeric(monkeypatch): + env = _load_env_module(monkeypatch) + monkeypatch.setenv("MILES_MAX_RESIDUAL_GPU_MEM_GB", "not-a-number") + + with pytest.raises(RuntimeError, match="must be a number"): + env.parse_env_positive_float("MILES_MAX_RESIDUAL_GPU_MEM_GB", 2.0) diff --git a/tests/test_miles_residual_threshold_wiring.py b/tests/test_miles_residual_threshold_wiring.py new file mode 100644 index 0000000..e743d97 --- /dev/null +++ b/tests/test_miles_residual_threshold_wiring.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import ast +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[1] + + +def _is_name(node: ast.AST, name: str) -> bool: + return isinstance(node, ast.Name) and node.id == name + + +def _is_attr(node: ast.AST, attr: str) -> bool: + return isinstance(node, ast.Attribute) and node.attr == attr + + +def test_miles_shrink_uses_server_side_residual_threshold() -> None: + source = (REPO_ROOT / "rlix" / "pipeline" / "miles_coordinator.py").read_text( + encoding="utf-8" + ) + tree = ast.parse(source) + + shrink_fn = next( + node + for node in ast.walk(tree) + if isinstance(node, ast.FunctionDef) and node.name == "_shrink_workers" + ) + + assert any( + isinstance(node, ast.Call) + and _is_name(node.func, "parse_env_positive_float") + and len(node.args) >= 2 + and isinstance(node.args[0], ast.Constant) + and node.args[0].value == "MILES_MAX_RESIDUAL_GPU_MEM_GB" + and isinstance(node.args[1], ast.Constant) + and node.args[1].value == 2.0 + for node in ast.walk(shrink_fn) + ), "_shrink_workers must parse the residual threshold env var with 2GB default" + + assert any( + isinstance(node, ast.Call) + and _is_attr(node.func, "remote") + and any( + kw.arg == "post_sleep_vram_threshold_gb" + and _is_name(kw.value, "residual_threshold_gb") + for kw in node.keywords + ) + for node in ast.walk(shrink_fn) + ), "shrink_engines must receive post_sleep_vram_threshold_gb" + + +def test_miles_coordinator_forwards_residual_threshold_env_var() -> None: + source = (REPO_ROOT / "rlix" / "pipeline" / "miles_coordinator.py").read_text( + encoding="utf-8" + ) + tree = ast.parse(source) + + build_env_fn = next( + node + for node in ast.walk(tree) + if isinstance(node, ast.FunctionDef) and node.name == "_build_pipeline_env_vars" + ) + + assert any( + isinstance(node, ast.Constant) + and node.value == "MILES_MAX_RESIDUAL_GPU_MEM_GB" + for node in ast.walk(build_env_fn) + ), "_build_pipeline_env_vars must forward MILES_MAX_RESIDUAL_GPU_MEM_GB" From ac9312d537028bb5d2d37c6351f0ac8b48856ff6 Mon Sep 17 00:00:00 2001 From: howard989 Date: Mon, 25 May 2026 01:10:44 -0700 Subject: [PATCH 2/2] fix(rlix): use 3GB per-process residual threshold default --- rlix/pipeline/miles_coordinator.py | 20 ++++++++++++------- tests/test_miles_residual_threshold_wiring.py | 4 ++-- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/rlix/pipeline/miles_coordinator.py b/rlix/pipeline/miles_coordinator.py index e9f1c69..f6db4c7 100644 --- a/rlix/pipeline/miles_coordinator.py +++ b/rlix/pipeline/miles_coordinator.py @@ -435,12 +435,17 @@ def _shrink_workers(self, engine_indices: Set[int]) -> None: rollout_manager = self._model_update_resources.get("rollout_manager") if rollout_manager is None: raise RuntimeError("resource registration missing for shrink") - # RPC outside the lock. Use SGLang's server-side residual allocation - # check (weight + kvcache + graph) after release_memory_occupation; it - # is narrower than raw nvidia-smi used memory and avoids counting CUDA / - # Ray / process runtime overhead as model residue. + # RPC outside the lock. + # Per-engine PROCESS-resident GPU memory threshold (GiB) passed to + # MILES shrink_engines -> assert_post_sleep_process_vram_below_threshold. + # Default 3.0: an offloaded SGLang scheduler process measures ~1.8 GiB + # resident (mostly non-offloadable CUDA context), so 3.0 leaves margin + # over that baseline while still catching large residuals such as an + # unoffloaded KV pool. (A 0.5B weight-only offload miss adds only ~1 GiB + # and may not trip it; the gate targets large KV/full-offload failures.) + # This is NOT whole-GPU used and NOT /server_info accounting. residual_threshold_gb = parse_env_positive_float( - "MILES_MAX_RESIDUAL_GPU_MEM_GB", 2.0 + "MILES_MAX_RESIDUAL_GPU_MEM_GB", 3.0 ) shrunk = ray.get( rollout_manager.shrink_engines.remote( @@ -449,8 +454,9 @@ def _shrink_workers(self, engine_indices: Set[int]) -> None: ) ) logger.info( - "[MilesCoordinator] shrink_engines residual allocation check passed " - "pipeline_id=%s engine_indices=%s threshold=%.1f GB", + "[MilesCoordinator] shrink_engines complete pipeline_id=%s " + "engine_indices=%s per_process_residual_threshold=%.1f GB " + "(per-engine resident gate ran inside shrink_engines; fail-open if unmeasurable)", self._pipeline_id, sorted(shrunk), residual_threshold_gb, diff --git a/tests/test_miles_residual_threshold_wiring.py b/tests/test_miles_residual_threshold_wiring.py index e743d97..2c10d7d 100644 --- a/tests/test_miles_residual_threshold_wiring.py +++ b/tests/test_miles_residual_threshold_wiring.py @@ -33,9 +33,9 @@ def test_miles_shrink_uses_server_side_residual_threshold() -> None: and isinstance(node.args[0], ast.Constant) and node.args[0].value == "MILES_MAX_RESIDUAL_GPU_MEM_GB" and isinstance(node.args[1], ast.Constant) - and node.args[1].value == 2.0 + and node.args[1].value == 3.0 for node in ast.walk(shrink_fn) - ), "_shrink_workers must parse the residual threshold env var with 2GB default" + ), "_shrink_workers must parse the residual threshold env var with 3GB default" assert any( isinstance(node, ast.Call)