diff --git a/pyproject.toml b/pyproject.toml index 5bf6725b3..fab8e5ef1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,8 @@ dependencies = [ "opentelemetry-exporter-otlp>=1.28.0", "h11>=0.16.0", # fix critical vulnerability GHSA-vqfr-h8mv-ghfj "httpcore>=1.0.8", # required for h11>=0.16.0 + "pyzmq>=27.0.0", + "msgspec>=0.19.0", ] [project.scripts] diff --git a/src/vllm_router/app.py b/src/vllm_router/app.py index 4bb8823e2..e6981409e 100644 --- a/src/vllm_router/app.py +++ b/src/vllm_router/app.py @@ -33,6 +33,7 @@ from vllm_router.routers.main_router import main_router from vllm_router.routers.metrics_router import metrics_router from vllm_router.routers.routing_logic import ( + DisaggregatedPrefillRouter, cleanup_routing_logic, get_routing_logic, initialize_routing_logic, @@ -48,6 +49,7 @@ from vllm_router.services.request_service.rewriter import ( get_request_rewriter, ) +from vllm_router.services.request_service.zmq_proxy import NixlConfig, ZmqProxy from vllm_router.stats.engine_stats import ( get_engine_stats_scraper, initialize_engine_stats_scraper, @@ -106,6 +108,7 @@ @asynccontextmanager async def lifespan(app: FastAPI): app.state.aiohttp_client_wrapper.start() + if hasattr(app.state, "batch_processor"): await app.state.batch_processor.initialize() @@ -125,7 +128,29 @@ async def lifespan(app: FastAPI): logger.info("Validating external provider models against live provider APIs") await app.state.external_provider_registry.validate_models() - yield + use_nixl = ( + isinstance(app.state.router, DisaggregatedPrefillRouter) + and hasattr(app.state, "nixl_config") + and app.state.nixl_config is not None + ) + if use_nixl: + logger.info( + "Starting ZMQ task because the routing logic is" + " RoutingLogic.DISAGGREGATED_PREFILL and nixl_proxy_host is configured" + ) + nixl_config = app.state.nixl_config + app.state.zmq_proxy = ZmqProxy( + finished_req_ttl=nixl_config.finished_req_ttl, + cleanup_interval=nixl_config.cleanup_interval, + ) + await app.state.zmq_proxy.start(nixl_config.proxy_host, nixl_config.proxy_port) + + yield + + await app.state.zmq_proxy.stop() + else: + yield + await app.state.aiohttp_client_wrapper.stop() # Close the threaded-components @@ -230,8 +255,16 @@ def initialize_all(app: FastAPI, args): namespace=args.k8s_namespace, port=args.k8s_port, label_selector=args.k8s_label_selector, - prefill_model_labels=args.prefill_model_labels, - decode_model_labels=args.decode_model_labels, + prefill_model_labels=( + parse_comma_separated_args(args.prefill_model_labels) + if args.prefill_model_labels + else None + ), + decode_model_labels=( + parse_comma_separated_args(args.decode_model_labels) + if args.decode_model_labels + else None + ), watcher_timeout_seconds=args.k8s_watcher_timeout_seconds, health_check_timeout_seconds=args.backend_health_check_timeout_seconds, ) @@ -363,6 +396,18 @@ def initialize_all(app: FastAPI, args): app.state.router = get_routing_logic() app.state.request_rewriter = get_request_rewriter() + # Build NixlConfig if disaggregated prefill with NIXL proxy is configured. 
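+    # The resulting config is stored on app.state and read back in lifespan(),
+    # which starts the ZmqProxy only when the routing logic is
+    # DisaggregatedPrefillRouter and nixl_proxy_host is set.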
+ if hasattr(args, "nixl_proxy_host") and args.nixl_proxy_host is not None: + app.state.nixl_config = NixlConfig( + proxy_host=args.nixl_proxy_host, + proxy_port=args.nixl_proxy_port, + peer_host=args.nixl_peer_host, + peer_init_port=args.nixl_peer_init_port, + peer_alloc_port=args.nixl_peer_alloc_port, + finished_req_ttl=args.nixl_finished_req_ttl, + cleanup_interval=args.nixl_cleanup_interval, + ) + app = FastAPI(lifespan=lifespan) app.include_router(main_router) diff --git a/src/vllm_router/parsers/parser.py b/src/vllm_router/parsers/parser.py index 4a7c222f7..03e78dc9b 100644 --- a/src/vllm_router/parsers/parser.py +++ b/src/vllm_router/parsers/parser.py @@ -474,6 +474,53 @@ def parse_args(): default=None, help="Path to a YAML file defining external LLM provider configurations (startup-time only).", ) + parser.add_argument( + "--nixl-peer-host", + type=str, + help="The hostname or IP address of the NIXL peer service. Only use for DisaggregatedPrefillRouter.", + ) + parser.add_argument( + "--nixl-peer-init-port", + type=int, + default=7300, + help="The initialization port for the NIXL peer service. Only use for DisaggregatedPrefillRouter.", + ) + parser.add_argument( + "--nixl-peer-alloc-port", + type=int, + default=7400, + help="The allocation port for the NIXL peer service. Only use for DisaggregatedPrefillRouter.", + ) + parser.add_argument( + "--nixl-proxy-host", + type=str, + help="The hostname or IP address for the NIXL proxy server. Only use for DisaggregatedPrefillRouter.", + ) + parser.add_argument( + "--nixl-proxy-port", + type=int, + default=7500, + help="The port for the NIXL proxy server. Only use for DisaggregatedPrefillRouter.", + ) + parser.add_argument( + "--nixl-finished-req-ttl", + type=float, + default=120.0, + help=( + "Seconds to retain a KV-ready entry in the ZMQ proxy before " + "evicting it. Must be at least as long as the worst-case decode " + "latency for a single request. Defaults to 120 s." + ), + ) + parser.add_argument( + "--nixl-cleanup-interval", + type=float, + default=60.0, + help=( + "How often (seconds) the ZMQ proxy background task scans for " + "stale KV-ready entries. Defaults to 60 s." + ), + ) args = parser.parse_args() args = load_initial_config_from_config_file_if_required(parser, args) diff --git a/src/vllm_router/service_discovery.py b/src/vllm_router/service_discovery.py index 67c72c75b..4ab2268ad 100644 --- a/src/vllm_router/service_discovery.py +++ b/src/vllm_router/service_discovery.py @@ -347,7 +347,7 @@ def get_endpoint_info(self) -> List[EndpointInfo]: async def initialize_client_sessions(self) -> None: """ - Initialize aiohttp ClientSession objects for prefill and decode endpoints. + Initialize aiohttp client sessions for prefill and decode endpoints. This must be called from an async context during app startup. 
""" if ( @@ -756,18 +756,22 @@ def _add_engine( # Store model information in the endpoint info self.available_engines[engine_name].model_info = model_info - if self.event_loop_ready.is_set() and self.event_loop is not None: - try: + # Initialize client sessions only if event_loop is available + try: + if hasattr(self.app.state, "event_loop") and self.app.state.event_loop: fut = asyncio.run_coroutine_threadsafe( - self.initialize_client_sessions(), - self.event_loop, + self.initialize_client_sessions(), self.app.state.event_loop ) fut.result() - except Exception as e: - logger.error(f"Error initializing client sessions: {e}") - else: - logger.debug( - "Event loop not ready; deferring client session initialization" + logger.info("Client sessions initialized successfully in _add_engine") + else: + # Event loop not ready yet, client sessions will be initialized in lifespan + logger.debug( + "Event loop not ready in _add_engine, client sessions will be initialized later" + ) + except Exception as e: + logger.error( + f"Error initializing client sessions in _add_engine: {e}", exc_info=True ) # Track all models we've ever seen @@ -850,35 +854,63 @@ def close(self): async def initialize_client_sessions(self) -> None: """ - Initialize aiohttp ClientSession objects for prefill and decode endpoints. + Initialize aiohttp client sessions for prefill and decode endpoints. This must be called from an async context during app startup. """ + logger.debug( + f"initialize_client_sessions called. prefill_model_labels={self.prefill_model_labels}, decode_model_labels={self.decode_model_labels}" + ) if ( self.prefill_model_labels is not None and self.decode_model_labels is not None ): endpoint_infos = self.get_endpoint_info() + logger.debug(f"Got {len(endpoint_infos)} endpoints") for endpoint_info in endpoint_infos: + logger.debug( + f"Checking endpoint: url={endpoint_info.url}, model_label={endpoint_info.model_label}" + ) if endpoint_info.model_label in self.prefill_model_labels: if ( hasattr(self.app.state, "prefill_client") and self.app.state.prefill_client is not None ): - await self.app.state.prefill_client.close() - self.app.state.prefill_client = aiohttp.ClientSession( - base_url=endpoint_info.url, - timeout=aiohttp.ClientTimeout(total=None), - ) + # Session already initialised; skip to avoid disrupting + # in-flight requests. xPyD (multiple prefill nodes) is + # not supported in this PR — only the first discovered + # prefill endpoint is used. 
+ logger.debug( + f"prefill_client already set, skipping {endpoint_info.url}" + ) + else: + self.app.state.prefill_client = aiohttp.ClientSession( + base_url=endpoint_info.url, + timeout=aiohttp.ClientTimeout(total=None), + ) + logger.info( + f"Created prefill_client for {endpoint_info.url} with timeout=None" + ) + elif endpoint_info.model_label in self.decode_model_labels: if ( hasattr(self.app.state, "decode_client") and self.app.state.decode_client is not None ): - await self.app.state.decode_client.close() - self.app.state.decode_client = aiohttp.ClientSession( - base_url=endpoint_info.url, - timeout=aiohttp.ClientTimeout(total=None), - ) + logger.debug( + f"decode_client already set, skipping {endpoint_info.url}" + ) + else: + self.app.state.decode_client = aiohttp.ClientSession( + base_url=endpoint_info.url, + timeout=aiohttp.ClientTimeout(total=None), + ) + logger.info( + f"Created decode_client for {endpoint_info.url} with timeout=None" + ) + else: + logger.warning( + "prefill_model_labels or decode_model_labels is None, skipping client session initialization" + ) def has_ever_seen_model(self, model_name: str) -> bool: """Check if we've ever seen this model, even if currently scaled to zero.""" @@ -1212,6 +1244,21 @@ def _add_engine(self, engine_name: str, model_names: List[str], model_label: str # Store model information in the endpoint info self.available_engines[engine_name].model_info = model_info + try: + # Only initialize client sessions if event_loop is available + if hasattr(self.app.state, "event_loop") and self.app.state.event_loop: + fut = asyncio.run_coroutine_threadsafe( + self.initialize_client_sessions(), self.app.state.event_loop + ) + fut.result() + else: + # Event loop not ready yet, client sessions will be initialized in lifespan + logger.debug( + "Event loop not ready, client sessions will be initialized later" + ) + except Exception as e: + logger.error(f"Error initializing client sessions: {e}") + def _delete_engine(self, engine_name: str): logger.info(f"Serving engine {engine_name} is deleted") with self.available_engines_lock: @@ -1287,25 +1334,58 @@ def close(self): async def initialize_client_sessions(self) -> None: """ - Initialize aiohttp ClientSession objects for prefill and decode endpoints. + Initialize aiohttp client sessions for prefill and decode endpoints. This must be called from an async context during app startup. """ + logger.debug( + f"K8sServiceNameServiceDiscovery.initialize_client_sessions called. 
prefill_model_labels={self.prefill_model_labels}, decode_model_labels={self.decode_model_labels}" + ) if ( self.prefill_model_labels is not None and self.decode_model_labels is not None ): endpoint_infos = self.get_endpoint_info() + logger.debug(f"Got {len(endpoint_infos)} endpoints") for endpoint_info in endpoint_infos: + logger.debug( + f"Checking endpoint: url={endpoint_info.url}, model_label={endpoint_info.model_label}" + ) if endpoint_info.model_label in self.prefill_model_labels: - self.app.state.prefill_client = aiohttp.ClientSession( - base_url=endpoint_info.url, - timeout=aiohttp.ClientTimeout(total=None), - ) + if ( + hasattr(self.app.state, "prefill_client") + and self.app.state.prefill_client is not None + ): + logger.debug( + f"prefill_client already set, skipping {endpoint_info.url}" + ) + else: + self.app.state.prefill_client = aiohttp.ClientSession( + base_url=endpoint_info.url, + timeout=aiohttp.ClientTimeout(total=None), + ) + logger.info( + f"Created prefill_client for {endpoint_info.url} with timeout=None" + ) elif endpoint_info.model_label in self.decode_model_labels: - self.app.state.decode_client = aiohttp.ClientSession( - base_url=endpoint_info.url, - timeout=aiohttp.ClientTimeout(total=None), - ) + if ( + hasattr(self.app.state, "decode_client") + and self.app.state.decode_client is not None + ): + logger.debug( + f"decode_client already set, skipping {endpoint_info.url}" + ) + else: + self.app.state.decode_client = aiohttp.ClientSession( + base_url=endpoint_info.url, + timeout=aiohttp.ClientTimeout(total=None), + ) + logger.info( + f"Created decode_client for {endpoint_info.url} with timeout=None" + ) + else: + logger.warning( + "K8sServiceNameServiceDiscovery: prefill_model_labels or decode_model_labels is None, skipping client session initialization" + ) def _create_service_discovery( diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py index 961bf8c15..943746ab4 100644 --- a/src/vllm_router/services/request_service/request.py +++ b/src/vllm_router/services/request_service/request.py @@ -14,6 +14,7 @@ import json import os +import re import time import uuid from typing import Optional @@ -369,9 +370,19 @@ async def route_general_request( StreamingResponse: A response object that streams data from the backend server to the client. """ if isinstance(request.app.state.router, DisaggregatedPrefillRouter): - response = await route_disaggregated_prefill_request( - request, endpoint, background_tasks + use_nixl = ( + hasattr(request.app.state, "zmq_proxy") + and hasattr(request.app.state, "nixl_config") + and request.app.state.nixl_config.peer_host is not None ) + if use_nixl: + response = await route_disaggregated_prefill_nixl_request( + request, endpoint, background_tasks + ) + else: + response = await route_disaggregated_prefill_request( + request, endpoint, background_tasks + ) return response # Handle orchestrated disaggregated inference (NxDI pattern) @@ -663,6 +674,22 @@ async def send_request_to_prefiller( return await response.json() +async def send_request_to_tokenizer( + client: aiohttp.ClientSession, endpoint: str, req_data: dict, request_id: str +): + """ + Send a request to a tokenizer service using aiohttp. 
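+
+    Used by the NIXL disaggregated-prefill path to call the prefiller's
+    /tokenize endpoint; returns the parsed JSON response, whose "tokens"
+    field carries the prompt token IDs.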
+ """ + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + } + + async with client.post(endpoint, json=req_data, headers=headers) as response: + response.raise_for_status() + return await response.json() + + async def send_request_to_decode( client: aiohttp.ClientSession, endpoint: str, req_data: dict, request_id: str ): @@ -883,8 +910,8 @@ async def route_disaggregated_prefill_request( endpoint: str, background_tasks: BackgroundTasks, ): + """Route disaggregated prefill request using LMCache shared storage mode.""" in_router_time = time.time() - # Same as vllm, Get request_id from X-Request-Id header if available request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4()) request_json = await request.json() @@ -942,6 +969,285 @@ async def generate_stream(): request.app.state.decode_client, endpoint, request_json, request_id ): yield chunk + except aiohttp.ClientResponseError as e: + logger.error(f"HTTP error in decoder: {e}", exc_info=True) + try: + error_text = e.message + except Exception: + error_text = f"HTTP {e.status}" + error_response = { + "error": { + "message": f"Decoder error: {error_text}", + "type": "decoder_error", + "code": e.status, + } + } + yield json.dumps(error_response).encode("utf-8") + except Exception as e: + logger.error(f"Unexpected error in decoder: {e}", exc_info=True) + error_response = { + "error": { + "message": f"Decoder error: {str(e)}", + "type": "decoder_error", + "code": 500, + } + } + yield json.dumps(error_response).encode("utf-8") + + curr_time = time.time() + logger.info( + f"Routing request {request_id} with session id None to {request.app.state.decode_client._base_url} at {curr_time}, process time = {curr_time - et:.4f}" + ) + + return StreamingResponse( + generate_stream(), + media_type="application/json", + headers={"X-Request-Id": request_id}, + ) + + +async def _prepare_nixl_prefill_request(request, request_json, request_id): + """Handle tokenization, build disagg_spec, inject kv_transfer_params. + + Mutates request_json in place: replaces prompt with tokens, sets + max_tokens=1, injects kv_transfer_params, and sets stream=False. + Returns the tokenize output (not used by caller today, but available). 
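+
+    max_tokens=1 limits the prefiller to producing only the first token; the
+    decode request later appends that token to the prompt and asks for the
+    remaining max_tokens - 1 tokens.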
+ """ + # Tokenize the prompt + if "messages" in request_json: + tokenize_payload = {"messages": request_json["messages"]} + else: + tokenize_payload = {"prompt": request_json["prompt"]} + tokenize_output = await send_request_to_tokenizer( + request.app.state.prefill_client, + "/tokenize", + tokenize_payload, + request_id, + ) + # Update request with tokenized prompt + request_json.pop("messages", None) + request_json["prompt"] = tokenize_output["tokens"] + request_json["max_tokens"] = 1 + + # Create disagg_spec for KV transfer + decode_base_url = ( + request.app.state.decode_client._base_url + if hasattr(request.app.state.decode_client, "_base_url") + else str(request.app.state.decode_client.base_url) + ) + ip_match = re.search(r"://([^:]+)", str(decode_base_url)) + receiver_host = ( + ip_match.group(1) if ip_match else request.app.state.nixl_config.peer_host + ) + + disagg_spec = { + "req_id": request_id, + "receiver_host": receiver_host, + "receiver_init_port": [request.app.state.nixl_config.peer_init_port], + "receiver_alloc_port": [request.app.state.nixl_config.peer_alloc_port], + } + + request_json["kv_transfer_params"] = { + "ret_first_tok": True, + "disagg_spec": disagg_spec, + } + request_json["stream"] = False + + return tokenize_output + + +def _convert_completion_chunk_to_chat(chunk_data): + """Convert a single /v1/completions chunk dict to chat.completion.chunk format. + + Returns the converted dict. + """ + return { + "id": chunk_data["id"], + "object": "chat.completion.chunk", + "created": chunk_data["created"], + "model": chunk_data["model"], + "choices": [ + { + "index": 0, + "delta": {"content": chunk_data["choices"][0]["text"]}, + "logprobs": chunk_data["choices"][0].get("logprobs"), + "finish_reason": chunk_data["choices"][0].get("finish_reason"), + } + ], + } + + +def _clean_completion_chunk(chunk_data): + """Strip extra fields (prompt_token_ids, token_ids) from a completion chunk. + + Returns a cleaned dict containing only the standard completion fields. 
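+
+    Chunks streamed back from the decode node may carry prompt_token_ids and
+    token_ids; dropping them keeps the stream OpenAI-compatible.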
+ """ + return { + "id": chunk_data["id"], + "object": "text_completion", + "created": chunk_data["created"], + "model": chunk_data["model"], + "choices": [ + { + "index": 0, + "text": chunk_data["choices"][0]["text"], + "logprobs": chunk_data["choices"][0].get("logprobs"), + "finish_reason": chunk_data["choices"][0].get("finish_reason"), + "stop_reason": chunk_data["choices"][0].get("stop_reason"), + } + ], + "usage": chunk_data.get("usage"), + } + + +async def route_disaggregated_prefill_nixl_request( + request: Request, + endpoint: str, + background_tasks: BackgroundTasks, +): + in_router_time = time.time() + # Same as vllm, Get request_id from X-Request-Id header if available + request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4()) + request_json = await request.json() + + orig_max_tokens = request_json.get("max_tokens", 0) + stream_options = request_json.pop("stream_options", None) + is_chat_completion = "messages" in request_json + + try: + await _prepare_nixl_prefill_request(request, request_json, request_id) + + # Send to prefiller + prefill_output = await send_request_to_prefiller( + request.app.state.prefill_client, + "/v1/completions", + request_json, + request_id, + ) + et = time.time() + logger.info( + f"Routing request {request_id} with session id None to " + f"{request.app.state.prefill_client._base_url} at {et}, " + f"process time = {et - in_router_time:.4f}" + ) + # Prepare decode request + request_json["max_tokens"] = orig_max_tokens - 1 + request_json["prompt"].append(prefill_output["kv_transfer_params"]["first_tok"]) + request_json.pop("kv_transfer_params") + request_json["stream"] = True + if stream_options is not None: + request_json["stream_options"] = stream_options + + except Exception as e: + logger.error( + f"Error in prefiller stage: {type(e).__name__}: {e}", exc_info=True + ) + return JSONResponse( + status_code=500, + content={ + "error": { + "message": f"Prefiller error: {str(e)}", + "type": "prefiller_error", + "code": 500, + } + }, + headers={"X-Request-Id": request_id}, + ) + + async def generate_stream(): + try: + # Check if original request was chat completions + + if is_chat_completion: + # For chat completions, yield initial chunk with role + initial_chunk = { + "id": prefill_output["id"], + "object": "chat.completion.chunk", + "created": prefill_output["created"], + "model": prefill_output["model"], + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""}, + "logprobs": None, + "finish_reason": None, + } + ], + } + yield ( + "data: " + json.dumps(initial_chunk, separators=(",", ":")) + "\n\n" + ).encode() + + # Then yield head chunk with content + head_chunk = { + "id": prefill_output["id"], + "object": "chat.completion.chunk", + "created": prefill_output["created"], + "model": prefill_output["model"], + "choices": [ + { + "index": 0, + "delta": {"content": prefill_output["choices"][0]["text"]}, + "logprobs": None, + "finish_reason": None, + } + ], + } + else: + # For completions, use original format (clean, without extra fields) + head_chunk = { + "id": prefill_output["id"], + "object": "text_completion", + "created": prefill_output["created"], + "model": prefill_output["model"], + "choices": [ + { + "index": 0, + "text": prefill_output["choices"][0]["text"], + "logprobs": None, + "finish_reason": None, + "stop_reason": None, + } + ], + "usage": None, + } + + yield ( + "data: " + json.dumps(head_chunk, separators=(",", ":")) + "\n\n" + ).encode() + + await request.app.state.zmq_proxy.wait_kv_ready(request_id) + 
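+            # wait_kv_ready() suspends until the ZMQ proxy receives the
+            # prefiller's notification that the KV cache has been transferred
+            # (or until its timeout expires, in which case decode recomputes
+            # the prefill).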
+ # Stream the rest from decode service + async for chunk in send_request_to_decode( + request.app.state.decode_client, + "/v1/completions", + request_json, + request_id, + ): + chunk_str = chunk.decode("utf-8") + if chunk_str.startswith("data: ") and not chunk_str.startswith( + "data: [DONE]" + ): + try: + json_str = chunk_str[6:].strip() # Remove 'data: ' prefix + if json_str: + completion_data = json.loads(json_str) + if is_chat_completion: + converted = _convert_completion_chunk_to_chat( + completion_data + ) + else: + converted = _clean_completion_chunk(completion_data) + yield ( + "data: " + + json.dumps(converted, separators=(",", ":")) + + "\n\n" + ).encode() + except (json.JSONDecodeError, KeyError): + yield chunk + else: + yield chunk except aiohttp.ClientResponseError as e: logger.error(f"HTTP error in decoder: {e}", exc_info=True) try: diff --git a/src/vllm_router/services/request_service/zmq_proxy.py b/src/vllm_router/services/request_service/zmq_proxy.py new file mode 100644 index 000000000..f22b3e965 --- /dev/null +++ b/src/vllm_router/services/request_service/zmq_proxy.py @@ -0,0 +1,193 @@ +# Copyright 2024-2025 The vLLM Production Stack Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ZMQ proxy server for PD disaggregated prefill KV transfer notifications.""" + +import asyncio +import time +from dataclasses import dataclass + +import msgspec +import zmq +import zmq.asyncio + +from vllm_router.log import init_logger + +try: + from lmcache.v1.storage_backend.connector.nixl_connector_v3 import ( + NixlMsg, + ) +except ImportError: + try: + from lmcache.v1.storage_backend.pd_backend import ProxyNotif as NixlMsg + except ImportError: + + class NixlMsg(msgspec.Struct): + req_id: str + + +logger = init_logger(__name__) + + +@dataclass +class NixlConfig: + """NIXL-specific configuration for disaggregated prefill routing.""" + + proxy_host: str + proxy_port: int + peer_host: str + peer_init_port: int + peer_alloc_port: int + finished_req_ttl: float = 120.0 + cleanup_interval: float = 60.0 + + +class ZmqProxy: + """Manages a ZMQ PULL server for KV transfer completion notifications.""" + + def __init__( + self, + finished_req_ttl: float = 120.0, + cleanup_interval: float = 60.0, + ): + """ + Args: + finished_req_ttl: Seconds to keep a KV-ready entry before evicting + it. Should be at least as long as the longest expected decode + latency so that a slow decoder can still find its entry. + Defaults to 120 s (2× a typical 60 s worst-case decode). + cleanup_interval: How often the background cleanup task runs. + Defaults to 60 s; tune down if memory is a concern. 
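+
+        Usage (mirroring app.lifespan and the NIXL request path):
+            proxy = ZmqProxy(finished_req_ttl=120.0, cleanup_interval=60.0)
+            await proxy.start(host, port)       # bind the PULL socket
+            await proxy.wait_kv_ready(req_id)   # per request, before decode
+            await proxy.stop()                  # on shutdown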
+ """ + self._pending: dict[str, asyncio.Event] = {} + self._finished_ts: dict[str, float] = {} + self._finished_req_ttl = finished_req_ttl + self._cleanup_interval = cleanup_interval + self._run_proxy: bool = True + self._zmq_ctx = zmq.asyncio.Context() + self._task: asyncio.Task | None = None + self._cleanup_task: asyncio.Task | None = None + + async def _pull_server(self, proxy_host: str, proxy_port: int): + """ZMQ PULL server that receives KV transfer completion notifications.""" + try: + socket = self._zmq_ctx.socket(zmq.PULL) + proxy_url = f"{proxy_host}:{proxy_port}" + socket.bind(f"tcp://{proxy_url}") + logger.info(f"ZMQ proxy server started on {proxy_url}") + except Exception as e: + logger.error(f"Failed to bind ZMQ socket to {proxy_url}: {e}") + socket.close() + return + + while self._run_proxy: + try: + msg_bytes = await socket.recv() + # Decode without strict type checking — LMCache may send + # ProxyNotif while router expects NixlMsg. Both have req_id. + try: + msg = msgspec.msgpack.decode(msg_bytes, type=NixlMsg) + except Exception: + # Fallback: decode as generic dict + msg_dict = msgspec.msgpack.decode(msg_bytes) + if isinstance(msg_dict, dict) and "req_id" in msg_dict: + msg = type("Msg", (), {"req_id": msg_dict["req_id"]})() + else: + logger.warning(f"ZMQ: unknown message format: {msg_dict}") + continue + req_id = msg.req_id + self._finished_ts[req_id] = time.time() + # Wake up any coroutine waiting on this request. + event = self._pending.get(req_id) + if event is not None: + event.set() + logger.debug(f"Prefill of req {req_id} done.") + except zmq.Again: + await asyncio.sleep(0.01) + except Exception as e: + logger.error(f"ZMQ Error in message processing: {e}") + # Don't break — continue processing messages + await asyncio.sleep(0.1) + + socket.close() + logger.info("ZMQ PULL server stopped.") + + async def _cleanup_loop(self): + """Periodically evict stale entries from _finished_ts and _pending.""" + while self._run_proxy: + await asyncio.sleep(self._cleanup_interval) + now = time.time() + stale = [ + req_id + for req_id, ts in self._finished_ts.items() + if now - ts > self._finished_req_ttl + ] + for req_id in stale: + del self._finished_ts[req_id] + self._pending.pop(req_id, None) + if stale: + logger.debug(f"ZMQ cleanup: evicted {len(stale)} stale req entries.") + + async def start(self, proxy_host: str = "0.0.0.0", proxy_port: int = 7500): + """Start the ZMQ pull server task.""" + if self._task is None: + self._task = asyncio.create_task(self._pull_server(proxy_host, proxy_port)) + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + logger.info("ZMQ task started") + await asyncio.sleep(0.1) + + async def stop(self): + """Stop the ZMQ pull server task.""" + if self._task is not None: + self._run_proxy = False + self._task.cancel() + if self._cleanup_task is not None: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + self._cleanup_task = None + try: + await self._task + except asyncio.CancelledError: + pass + self._task = None + logger.info("ZMQ task stopped") + + async def wait_kv_ready(self, req_id: str, timeout: float = 10.0): + """Wait for ZMQ notification that KV transfer is done, with timeout. + + Suspends the coroutine until the prefill node signals completion via + an asyncio.Event, avoiding a busy-wait loop. If timeout expires, + proceed anyway — decode will fallback to recompute via + kv_load_failure_policy='recompute'. 
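+
+        Args:
+            req_id: Request ID carried in the prefiller's ZMQ notification.
+            timeout: Maximum seconds to wait before proceeding without the
+                signal.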
+ """ + # If the signal already arrived before we start waiting, skip the wait. + if req_id not in self._finished_ts: + event = self._pending.setdefault(req_id, asyncio.Event()) + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + except asyncio.TimeoutError: + logger.warning( + f"Timeout ({timeout}s) waiting for KV ready signal for req" + f" {req_id}. Proceeding to decode (will recompute if KV" + " not available)." + ) + return + finally: + self._pending.pop(req_id, None) + + logger.debug(f"Prefill node signaled kv ready for req {req_id}") + self._finished_ts.pop(req_id, None)