Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/rlix/run_miles_dual.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ def _build_pipeline(
pipeline_runtime_env_vars["PYTHONPATH"] = pythonpath
for _k in (
"MILES_TMS_HOOK_MODE",
"MILES_MAX_RESIDUAL_GPU_MEM_GB",
"MILES_SKIP_TMS_PAUSE",
"MILES_SKIP_NODE_PG_PIN",
"TMS_INIT_ENABLE_CPU_BACKUP",
Expand Down
1 change: 1 addition & 0 deletions examples/rlix/run_miles_rlix.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ class MilesPipelineConfig:
# the parent driver's env by default).
for _k in (
"MILES_TMS_HOOK_MODE",
"MILES_MAX_RESIDUAL_GPU_MEM_GB",
"MILES_SKIP_TMS_PAUSE",
"MILES_SKIP_NODE_PG_PIN",
"TMS_INIT_ENABLE_CPU_BACKUP",
Expand Down
65 changes: 65 additions & 0 deletions miles/backends/sglang_utils/sglang_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME, convert_target_modules_to_hf, is_lora_enabled
from miles.ray.ray_actor import RayActor
from miles.utils.gpu_probe import query_process_tree_gpu_used_gb
from miles.utils.env_report import collect_and_print_node_env_report
from miles.utils.http_utils import get_host_info

Expand Down Expand Up @@ -756,6 +757,70 @@ def assert_post_sleep_vram_below_threshold(
)
return observed_max_gb

def _server_info_residual_gb(self, timeout_s: float = 5.0):
    """SGLang /server_info weight+kvcache+graph, max across DPs (GiB).

    For every entry of ``internal_states`` that carries a ``memory_usage``
    dict, the ``weight``, ``kvcache`` and ``graph`` fields are summed and the
    largest per-state sum is returned.

    This is *accounting* (KV static-pool size). It does NOT drop after a
    torch_memory_saver pause, so it is logged for diagnostics only and is
    never used as a hard gate.

    Args:
        timeout_s: Kept for signature compatibility; it is not forwarded to
            ``get_server_info``.

    Returns:
        The largest per-state total, or ``None`` when nothing could be
        measured: the server info call raised, the payload has no non-empty
        ``internal_states`` list, or no state exposed a ``memory_usage`` dict
        with numeric fields. ``None`` means "unavailable", never "0 GiB".
    """
    try:
        body = self.get_server_info()
    except Exception:
        # Best-effort diagnostic: an unreachable server is "unavailable".
        return None
    internal_states = body.get("internal_states") if isinstance(body, dict) else None
    if not isinstance(internal_states, list) or not internal_states:
        return None
    observed_max_gb = 0.0
    measured = False
    for state in internal_states:
        mem = state.get("memory_usage") if isinstance(state, dict) else None
        if not isinstance(mem, dict):
            continue
        try:
            total_gb = sum(float(mem.get(k, 0.0) or 0.0) for k in ("weight", "kvcache", "graph"))
        except (TypeError, ValueError):
            # A malformed field must not turn a diagnostic into a failure;
            # skip this state and keep looking at the others.
            continue
        observed_max_gb = max(observed_max_gb, total_gb)
        measured = True
    # Without at least one usable state, 0.0 would read as a real measurement.
    return observed_max_gb if measured else None

def log_post_sleep_residual_diagnostics(
self, threshold_gb: float | None = None, timeout_s: float = 5.0
):
"""Log attribution diagnostics after ``release_memory_occupation``.

The hard residual gate is whole-GPU ``memory.used`` in RLix. This
engine-side diagnostic still records this SGLang process tree's real
resident GPU memory and ``/server_info`` accounting so high whole-GPU
residual can be attributed to SGLang vs non-SGLang co-tenants.

Returns the measured process-resident GiB, or ``None`` when
unmeasurable (nvidia-smi missing / PID-namespace mismatch).
"""
if self.node_rank != 0:
return None
_log = logging.getLogger(__name__)
account_gb = self._server_info_residual_gb(timeout_s=timeout_s)
root = getattr(self, "process", None)
resident_gb = query_process_tree_gpu_used_gb(
getattr(root, "pid", None), timeout_s=timeout_s
)
_log.info(
"post-sleep residual diagnostic engine=%s:%s "
"process_resident=%s GiB "
"server_info_accounting(weight+kvcache+graph)=%s GiB "
"whole_gpu_threshold=%s GiB",
self.server_host,
self.server_port,
("%.3f" % resident_gb) if resident_gb is not None else "n/a",
("%.3f" % account_gb) if account_gb is not None else "n/a",
("%.3f" % float(threshold_gb)) if threshold_gb is not None else "n/a",
)
if resident_gb is None:
_log.warning(
"post-sleep process-resident diagnostic unavailable on engine "
"%s:%s (nvidia-smi missing or PID-namespace mismatch).",
self.server_host,
self.server_port,
)
return resident_gb

def resume_memory_occupation(self, tags: list[str] = None):
"""
Available tags for multi-stage resume: weights, kv_cache
Expand Down
21 changes: 18 additions & 3 deletions miles/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,16 +1031,31 @@ def shrink_engines(
)
# Step 4: release memory.
ray.get([h.release_memory_occupation.remote(tags=None) for h in handles])
# Step 5: optional post-sleep VRAM assert.
# Step 5: attribution diagnostics. The hard residual gate is
# whole-GPU memory.used in RLix; this logs each SGLang engine's
# process-resident memory and /server_info accounting so a high
# whole-GPU residual can be attributed to SGLang vs non-SGLang
# co-tenants (Megatron/Miles/vLLM/orphan processes).
if post_sleep_vram_threshold_gb is not None:
ray.get(
observed_resident_gbs = ray.get(
[
h.assert_post_sleep_vram_below_threshold.remote(
h.log_post_sleep_residual_diagnostics.remote(
threshold_gb=post_sleep_vram_threshold_gb
)
for h in handles
]
)
measured = [v for v in observed_resident_gbs if v is not None]
logger.info(
"shrink_engines: post-sleep SGLang residual diagnostics "
"process_resident_max=%s GiB per_engine=%s "
"whole_gpu_threshold=%.3f GiB engine_indices=%s "
"(whole-GPU hard gate runs in RLix)",
("%.3f" % max(measured)) if measured else "n/a",
[None if v is None else round(float(v), 3) for v in observed_resident_gbs],
float(post_sleep_vram_threshold_gb),
indices,
)
except Exception:
# Reset the abort cache on failure so retry re-aborts new
# in-flights that arrived during the failed cycle.
Expand Down
165 changes: 165 additions & 0 deletions miles/utils/gpu_probe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""GPU residual probing helpers — pure, dependency-free, unit-testable.

Used by :class:`SGLangEngine` to measure a SGLang server's REAL per-process
resident GPU memory after offload, via ``nvidia-smi`` compute-apps over the
engine's process tree. Kept free of sglang/torch imports so the parsing and
fail-open logic can be unit-tested without a GPU.

Semantics: ``MILES_MAX_RESIDUAL_GPU_MEM_GB`` is the **max per-GPU** resident
residual for an engine — for each GPU, sum the engine's process-tree usage on
that GPU, then take the max across GPUs. This supports TP>1 without summing
across cards (which would over-count and falsely trip the gate).
"""
from __future__ import annotations

import logging
import shutil
import subprocess

logger = logging.getLogger(__name__)


def build_process_tree(root_pid, proc_root: str = "/proc") -> set:
    """Collect ``root_pid`` together with all of its descendant PIDs.

    Parent links are taken from ``<proc_root>/<pid>/stat`` for every numeric
    entry of ``proc_root``; no psutil is involved. The whole tree is walked
    because the PID handed in may only be a spawn parent while the process
    that actually holds GPU memory is one of its children.

    Args:
        root_pid: PID at the top of the tree; always part of the result.
        proc_root: Directory laid out like ``/proc`` (overridable for tests).

    Returns:
        A set with ``root_pid`` and every PID reachable through ppid links.
        If ``proc_root`` cannot be listed the set holds ``root_pid`` only.
    """
    import os

    try:
        candidates = [int(name) for name in os.listdir(proc_root) if name.isdigit()]
    except OSError:
        return {root_pid}

    kids_by_parent: dict = {}
    for pid in candidates:
        stat_path = os.path.join(proc_root, str(pid), "stat")
        try:
            with open(stat_path, "rb") as handle:
                raw = handle.read()
        except OSError:
            # Process vanished or is unreadable; ignore it.
            continue
        try:
            # The comm field is wrapped in parentheses and may itself contain
            # spaces or parentheses, so anchor on the LAST ')'. What follows
            # is "<state> <ppid> ...", making ppid the second token.
            tail = raw[raw.rindex(b")") + 2:]
            parent = int(tail.split()[1])
        except (ValueError, IndexError):
            continue
        if parent in kids_by_parent:
            kids_by_parent[parent].append(pid)
        else:
            kids_by_parent[parent] = [pid]

    seen = {root_pid}
    frontier = [root_pid]
    # Level-by-level expansion; ``seen`` guards against revisiting a PID.
    while frontier:
        next_level = []
        for parent in frontier:
            for child in kids_by_parent.get(parent, ()):
                if child in seen:
                    continue
                seen.add(child)
                next_level.append(child)
        frontier = next_level
    return seen


def parse_compute_apps_per_gpu_max_gb(nvidia_csv: str, tree_pids: set):
    """Reduce a ``gpu_bus_id,pid,used_memory`` listing to one GiB figure.

    Input is the text produced by
    ``nvidia-smi --query-compute-apps=gpu_bus_id,pid,used_memory``. Rows whose
    PID belongs to ``tree_pids`` have their ``used_memory`` (MiB) added up per
    ``gpu_bus_id``; the largest per-GPU sum is returned converted to GiB.
    This matches the ``MILES_MAX_RESIDUAL_GPU_MEM_GB`` semantics: the worst
    single-GPU residual, with no summing across cards.

    Rows with fewer than three fields, or with a non-numeric PID or memory
    value, are ignored.

    Returns:
        The maximum per-GPU usage in GiB, or ``None`` (fail-open) when no row
        belongs to ``tree_pids``.
    """
    mib_by_gpu: dict = {}
    for row in nvidia_csv.strip().splitlines():
        fields = [cell.strip() for cell in row.split(",")]
        if len(fields) < 3:
            continue
        gpu, pid_text, mem_text = fields[:3]
        try:
            pid = int(pid_text)
            mib = float(mem_text)
        except ValueError:
            continue
        if pid not in tree_pids:
            continue
        mib_by_gpu[gpu] = mib_by_gpu.get(gpu, 0.0) + mib
    if not mib_by_gpu:
        # No row matched the process tree: cannot measure.
        return None
    return max(mib_by_gpu.values()) / 1024.0


def parse_compute_apps_used_gb(nvidia_csv: str, tree_pids: set):
    """Sum a legacy two-column ``pid,used_memory`` listing into GiB.

    Fallback for when the listing carries no ``gpu_bus_id`` column. Every row
    whose PID is in ``tree_pids`` contributes its ``used_memory`` (MiB) to a
    single total. Because rows cannot be assigned to a GPU, the result
    over-estimates for an engine spread over several GPUs.

    Rows with fewer than two fields, or with a non-numeric PID or memory
    value, are ignored.

    Returns:
        The summed usage in GiB, or ``None`` (fail-open) when no row belongs
        to ``tree_pids``.
    """
    mib_total = 0.0
    hits = 0
    for row in nvidia_csv.strip().splitlines():
        cells = row.split(",")
        if len(cells) < 2:
            continue
        try:
            pid = int(cells[0].strip())
            mib = float(cells[1].strip())
        except ValueError:
            continue
        if pid not in tree_pids:
            continue
        mib_total += mib
        hits += 1
    if hits == 0:
        return None
    return mib_total / 1024.0


def _run_nvidia_smi(args, timeout_s: float):
    """Run ``nvidia-smi`` with ``args`` and hand back its decoded output.

    stderr is merged into stdout, and undecodable bytes are replaced rather
    than raising.

    Args:
        args: List of command-line arguments appended after ``nvidia-smi``.
        timeout_s: Seconds to wait before giving up on the call.

    Returns:
        The combined output as text, or ``None`` if the command could not be
        started, timed out, or exited with a non-zero status.
    """
    try:
        completed = subprocess.run(
            ["nvidia-smi"] + args,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            timeout=timeout_s,
            check=True,
        )
    except (subprocess.SubprocessError, OSError):
        return None
    return completed.stdout.decode("utf-8", errors="replace")


def query_process_tree_gpu_used_gb(root_pid, timeout_s: float = 5.0,
                                   proc_root: str = "/proc"):
    """Measure the max per-GPU resident memory (GiB) of a process tree.

    The GPU-aware listing (``gpu_bus_id,pid,used_memory``) is tried first:
    usage is summed per GPU and the maximum across GPUs is returned. If that
    nvidia-smi call fails, a warning is logged and the legacy two-column
    listing (``pid,used_memory``) is used instead, which can only be summed
    as a whole.

    Args:
        root_pid: PID at the top of the tree to measure; ``None`` short-cuts
            to "cannot measure".
        timeout_s: Timeout applied to each nvidia-smi invocation.
        proc_root: ``/proc``-style directory used to discover descendants.

    Returns:
        GiB as a float, or ``None`` (fail-open) when nvidia-smi is missing,
        both calls fail, or no PID of the tree shows up in the listing (for
        instance a PID-namespace mismatch inside a container). Callers MUST
        treat ``None`` as "cannot measure", never as 0.
    """
    if root_pid is None:
        return None
    if shutil.which("nvidia-smi") is None:
        return None

    pids = build_process_tree(root_pid, proc_root=proc_root)
    if not pids:
        return None

    csv_format = "--format=csv,noheader,nounits"
    per_gpu_csv = _run_nvidia_smi(
        ["--query-compute-apps=gpu_bus_id,pid,used_memory", csv_format],
        timeout_s,
    )
    if per_gpu_csv is not None:
        return parse_compute_apps_per_gpu_max_gb(per_gpu_csv, pids)

    # The GPU-aware call failed: retry with the legacy two-column listing.
    logger.warning(
        "nvidia-smi gpu_bus_id query unavailable; falling back to "
        "pid,used_memory (summed; cannot distinguish per-GPU)"
    )
    legacy_csv = _run_nvidia_smi(
        ["--query-compute-apps=pid,used_memory", csv_format],
        timeout_s,
    )
    return None if legacy_csv is None else parse_compute_apps_used_gb(legacy_csv, pids)
Loading