Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion miles/rollout/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class EnginePreemptedError(Exception):
class RLixRouterMetadataError(Exception):
"""Raised when an RLix-mode generate response is missing router-injected metadata.

The MILES router injects ``meta_info["miles_engine_index", "miles_admission_disabled"]``
The MILES router injects ``meta_info["miles_admission_disabled"]``
into every ``/generate`` JSON response in RLix mode. Absence is treated as a fatal
misconfiguration rather than allowing turn-level redispatch to silently degrade.
"""
48 changes: 13 additions & 35 deletions miles/router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ def __init__(self, args, verbose=False):
# - enabled_workers: URLs admitted for routing. Source of truth for
# dispatch, NOT metadata. shrink/disable removes; expand/enable
# adds. _use_url selects from `enabled_workers - dead_workers`.
# - worker_engine_index_map: URL → engine_index, populated at
# add_worker and consumed by F3 metadata injection (iter 8).
# URL -> Active Request Count (load state)
self.worker_request_counts: dict[str, int] = {}
# URL -> Consecutive Failures
Expand All @@ -78,7 +76,6 @@ def __init__(self, args, verbose=False):
# set on add_worker (preserves legacy behavior); RLix-mode flow uses
# disable/enable to flip admission without removing from the registry.
self.enabled_workers: set[str] = set()
self.worker_engine_index_map: dict[str, int] = {}
self.max_weight_version = None

max_connections = getattr(args, "miles_router_max_connections", None)
Expand Down Expand Up @@ -235,8 +232,7 @@ async def do_proxy(
"""Core proxy logic. Returns dict with request_body, response_body, status_code, headers.

F32 router metadata injection: when ``path == "generate"``,
rewrite the JSON response body to inject
``meta_info["miles_engine_index", "miles_admission_disabled"]``
rewrite the JSON response body to inject RLix admission metadata
based on the dispatched worker URL. The injection is path-guarded
(only ``/generate``) to avoid breaking
``/model_info`` / ``/v1/loads`` / ``/health`` etc. Header
Expand Down Expand Up @@ -310,10 +306,6 @@ def _inject_generate_metadata(
# Some SGLang versions return meta_info=None on errors; promote.
meta_info = {}
payload["meta_info"] = meta_info
engine_index = self.worker_engine_index_map.get(worker_url)
if engine_index is not None:
meta_info["miles_engine_index"] = engine_index
meta_info["miles_worker_url"] = worker_url
# Only report scheduler-preempt when admission has actually been
# declared. In the pre-admission legacy fallback mode the dispatch
# source of truth is worker_request_counts - dead_workers (not
Expand Down Expand Up @@ -348,17 +340,16 @@ async def add_worker(self, request: Request):
Supports providing the URL via query string or JSON body.
Examples:
- POST /add_worker?url=http://127.0.0.1:10090
- POST /add_worker?url=http://127.0.0.1:10090&engine_index=0
- POST /add_worker with body {"url": "...", "engine_index": 0}
- POST /add_worker with body {"url": "..."}
"""
worker_url, engine_index = self._extract_worker_params(
worker_url = self._extract_worker_params(
request.query_params, await self._safe_json_body(request)
)
if not worker_url:
return JSONResponse(
status_code=400, content={"error": "worker_url is required (use query ?url=... or JSON body)"}
)
self._add_worker_internal(worker_url, engine_index)
self._add_worker_internal(worker_url)
await self._notify_workers_changed()
return {"status": "success", "worker_urls": self.worker_request_counts}

Expand All @@ -370,7 +361,7 @@ async def disable_worker(self, request: Request):
consistent with later _finish_url calls); only enabled_workers
loses the URL.
"""
worker_url, _ = self._extract_worker_params(
worker_url = self._extract_worker_params(
request.query_params, await self._safe_json_body(request)
)
if not worker_url:
Expand All @@ -385,7 +376,7 @@ async def enable_worker(self, request: Request):
Reset failure_count to 0 (per F68 — invariant: re-admit must not
carry over a stale failure count from a prior disable cycle).
"""
worker_url, _ = self._extract_worker_params(
worker_url = self._extract_worker_params(
request.query_params, await self._safe_json_body(request)
)
if not worker_url:
Expand All @@ -398,11 +389,10 @@ async def remove_worker(self, request: Request):
"""Permanently drop a worker from every registry.

Distinct from disable_worker: removes from worker_request_counts /
worker_failure_counts / enabled_workers / dead_workers /
worker_engine_index_map. Use only on actor death; routing-time
worker_failure_counts / enabled_workers / dead_workers. Use only on actor death; routing-time
shrink uses disable_worker to preserve in-flight balance.
"""
worker_url, _ = self._extract_worker_params(
worker_url = self._extract_worker_params(
request.query_params, await self._safe_json_body(request)
)
if not worker_url:
Expand Down Expand Up @@ -435,7 +425,6 @@ async def admission_state(self, request: Request):
"dead_workers": sorted(self.dead_workers),
"worker_request_counts": dict(self.worker_request_counts),
"worker_failure_counts": dict(self.worker_failure_counts),
"worker_engine_index_map": dict(self.worker_engine_index_map),
}

# ------------------------------------------------------------------
Expand All @@ -445,20 +434,12 @@ async def admission_state(self, request: Request):
# ------------------------------------------------------------------

@staticmethod
def _extract_worker_params(
query_params, body_payload: dict | None
) -> tuple[str | None, int | None]:
"""Pull `worker_url` (and optional `engine_index`) from query or body."""
def _extract_worker_params(query_params, body_payload: dict | None) -> str | None:
"""Pull `worker_url` from query or body."""
worker_url = query_params.get("url") or query_params.get("worker_url")
engine_index_raw = query_params.get("engine_index")
if not worker_url and body_payload:
worker_url = body_payload.get("url") or body_payload.get("worker_url")
if engine_index_raw is None and body_payload:
engine_index_raw = body_payload.get("engine_index")
engine_index: int | None = None
if engine_index_raw is not None:
engine_index = int(engine_index_raw)
return worker_url, engine_index
return worker_url

@staticmethod
async def _safe_json_body(request: Request) -> dict | None:
Expand All @@ -470,7 +451,7 @@ async def _safe_json_body(request: Request) -> dict | None:
except (ValueError, TypeError):
return None

def _add_worker_internal(self, url: str, engine_index: int | None) -> None:
def _add_worker_internal(self, url: str) -> None:
"""Register a worker, default-admit it, clear stale dead-state.

F68 invariants: ``setdefault`` so re-add doesn't zero an in-flight
Expand All @@ -481,19 +462,16 @@ def _add_worker_internal(self, url: str, engine_index: int | None) -> None:
self.worker_failure_counts.setdefault(url, 0)
self.dead_workers.discard(url)
self.enabled_workers.add(url)
if engine_index is not None:
self.worker_engine_index_map[url] = engine_index
self._admission_declared = True
if self.verbose:
print(f"[miles-router] Added worker: {url} (engine_index={engine_index})")
print(f"[miles-router] Added worker: {url}")

def _remove_worker_internal(self, url: str) -> None:
"""Drop the URL from every per-worker registry."""
self.worker_request_counts.pop(url, None)
self.worker_failure_counts.pop(url, None)
self.dead_workers.discard(url)
self.enabled_workers.discard(url)
self.worker_engine_index_map.pop(url, None)
self._admission_declared = True
if self.verbose:
print(f"[miles-router] Removed worker: {url}")
Expand Down
7 changes: 3 additions & 4 deletions tests/test_partial_sleep_wake.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ def _build_router(self):

def test_add_worker_admits_by_default(self):
router = self._build_router()
router._add_worker_internal("http://w1:8000", engine_index=0)
router._add_worker_internal("http://w1:8000")
self.assertIn("http://w1:8000", router.enabled_workers)
self.assertEqual(router.worker_request_counts["http://w1:8000"], 0)
self.assertTrue(router._admission_declared)

def test_disable_worker_keeps_request_counts(self):
router = self._build_router()
router._add_worker_internal("http://w1:8000", engine_index=0)
router._add_worker_internal("http://w1:8000")
router._disable_worker_internal("http://w1:8000")
self.assertNotIn("http://w1:8000", router.enabled_workers)
# Preserved so in-flight balance accounting stays consistent.
Expand All @@ -90,11 +90,10 @@ def test_disable_worker_keeps_request_counts(self):

def test_remove_worker_drops_all_state(self):
router = self._build_router()
router._add_worker_internal("http://w1:8000", engine_index=0)
router._add_worker_internal("http://w1:8000")
router._remove_worker_internal("http://w1:8000")
self.assertNotIn("http://w1:8000", router.worker_request_counts)
self.assertNotIn("http://w1:8000", router.enabled_workers)
self.assertNotIn("http://w1:8000", router.worker_engine_index_map)


class TestSchedulerPreemptClassification(unittest.TestCase):
Expand Down