Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ dependencies = [
"opentelemetry-exporter-otlp>=1.28.0",
"h11>=0.16.0", # fix critical vulnerability GHSA-vqfr-h8mv-ghfj
"httpcore>=1.0.8", # required for h11>=0.16.0
"pyzmq>=27.0.0",
"msgspec>=0.19.0",
]

[project.scripts]
Expand Down
51 changes: 48 additions & 3 deletions src/vllm_router/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from vllm_router.routers.main_router import main_router
from vllm_router.routers.metrics_router import metrics_router
from vllm_router.routers.routing_logic import (
DisaggregatedPrefillRouter,
cleanup_routing_logic,
get_routing_logic,
initialize_routing_logic,
Expand All @@ -48,6 +49,7 @@
from vllm_router.services.request_service.rewriter import (
get_request_rewriter,
)
from vllm_router.services.request_service.zmq_proxy import NixlConfig, ZmqProxy
from vllm_router.stats.engine_stats import (
get_engine_stats_scraper,
initialize_engine_stats_scraper,
Expand Down Expand Up @@ -106,6 +108,7 @@
@asynccontextmanager
async def lifespan(app: FastAPI):
app.state.aiohttp_client_wrapper.start()

if hasattr(app.state, "batch_processor"):
await app.state.batch_processor.initialize()

Expand All @@ -125,7 +128,29 @@ async def lifespan(app: FastAPI):
logger.info("Validating external provider models against live provider APIs")
await app.state.external_provider_registry.validate_models()

yield
use_nixl = (
isinstance(app.state.router, DisaggregatedPrefillRouter)
and hasattr(app.state, "nixl_config")
and app.state.nixl_config is not None
)
if use_nixl:
logger.info(
"Starting ZMQ task because the routing logic is"
" RoutingLogic.DISAGGREGATED_PREFILL and nixl_proxy_host is configured"
)
nixl_config = app.state.nixl_config
app.state.zmq_proxy = ZmqProxy(
finished_req_ttl=nixl_config.finished_req_ttl,
cleanup_interval=nixl_config.cleanup_interval,
)
await app.state.zmq_proxy.start(nixl_config.proxy_host, nixl_config.proxy_port)

yield

await app.state.zmq_proxy.stop()
else:
yield

await app.state.aiohttp_client_wrapper.stop()

# Close the threaded-components
Expand Down Expand Up @@ -230,8 +255,16 @@ def initialize_all(app: FastAPI, args):
namespace=args.k8s_namespace,
port=args.k8s_port,
label_selector=args.k8s_label_selector,
prefill_model_labels=args.prefill_model_labels,
decode_model_labels=args.decode_model_labels,
prefill_model_labels=(
parse_comma_separated_args(args.prefill_model_labels)
if args.prefill_model_labels
else None
),
decode_model_labels=(
parse_comma_separated_args(args.decode_model_labels)
if args.decode_model_labels
else None
),
watcher_timeout_seconds=args.k8s_watcher_timeout_seconds,
health_check_timeout_seconds=args.backend_health_check_timeout_seconds,
)
Expand Down Expand Up @@ -363,6 +396,18 @@ def initialize_all(app: FastAPI, args):
app.state.router = get_routing_logic()
app.state.request_rewriter = get_request_rewriter()

# Build NixlConfig if disaggregated prefill with NIXL proxy is configured.
if hasattr(args, "nixl_proxy_host") and args.nixl_proxy_host is not None:
app.state.nixl_config = NixlConfig(
proxy_host=args.nixl_proxy_host,
proxy_port=args.nixl_proxy_port,
peer_host=args.nixl_peer_host,
peer_init_port=args.nixl_peer_init_port,
peer_alloc_port=args.nixl_peer_alloc_port,
finished_req_ttl=args.nixl_finished_req_ttl,
cleanup_interval=args.nixl_cleanup_interval,
)


app = FastAPI(lifespan=lifespan)
app.include_router(main_router)
Expand Down
47 changes: 47 additions & 0 deletions src/vllm_router/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,53 @@ def parse_args():
default=None,
help="Path to a YAML file defining external LLM provider configurations (startup-time only).",
)
# --- NIXL disaggregated-prefill options ---------------------------------
# These flags are only meaningful with DisaggregatedPrefillRouter.
# app.initialize_all builds a NixlConfig from them, gated on
# --nixl-proxy-host being set (it has no default, so it is None unless
# supplied on the command line).
parser.add_argument(
    "--nixl-peer-host",
    type=str,
    help="The hostname or IP address of the NIXL peer service. Only use for DisaggregatedPrefillRouter.",
)
parser.add_argument(
    "--nixl-peer-init-port",
    type=int,
    default=7300,
    help="The initialization port for the NIXL peer service. Only use for DisaggregatedPrefillRouter.",
)
parser.add_argument(
    "--nixl-peer-alloc-port",
    type=int,
    default=7400,
    help="The allocation port for the NIXL peer service. Only use for DisaggregatedPrefillRouter.",
)
# No default: presence of this flag is what enables the ZMQ proxy path.
parser.add_argument(
    "--nixl-proxy-host",
    type=str,
    help="The hostname or IP address for the NIXL proxy server. Only use for DisaggregatedPrefillRouter.",
)
parser.add_argument(
    "--nixl-proxy-port",
    type=int,
    default=7500,
    help="The port for the NIXL proxy server. Only use for DisaggregatedPrefillRouter.",
)
# TTL/interval pair consumed by ZmqProxy (see zmq_proxy.py); the TTL must
# outlast the worst-case decode latency or KV-ready entries are evicted
# while still needed.
parser.add_argument(
    "--nixl-finished-req-ttl",
    type=float,
    default=120.0,
    help=(
        "Seconds to retain a KV-ready entry in the ZMQ proxy before "
        "evicting it. Must be at least as long as the worst-case decode "
        "latency for a single request. Defaults to 120 s."
    ),
)
parser.add_argument(
    "--nixl-cleanup-interval",
    type=float,
    default=60.0,
    help=(
        "How often (seconds) the ZMQ proxy background task scans for "
        "stale KV-ready entries. Defaults to 60 s."
    ),
)

args = parser.parse_args()
args = load_initial_config_from_config_file_if_required(parser, args)
Expand Down
140 changes: 110 additions & 30 deletions src/vllm_router/service_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def get_endpoint_info(self) -> List[EndpointInfo]:

async def initialize_client_sessions(self) -> None:
"""
Initialize aiohttp ClientSession objects for prefill and decode endpoints.
Initialize aiohttp client sessions for prefill and decode endpoints.
This must be called from an async context during app startup.
"""
if (
Expand Down Expand Up @@ -756,18 +756,22 @@ def _add_engine(
# Store model information in the endpoint info
self.available_engines[engine_name].model_info = model_info

if self.event_loop_ready.is_set() and self.event_loop is not None:
try:
# Initialize client sessions only if event_loop is available
try:
if hasattr(self.app.state, "event_loop") and self.app.state.event_loop:
fut = asyncio.run_coroutine_threadsafe(
self.initialize_client_sessions(),
self.event_loop,
self.initialize_client_sessions(), self.app.state.event_loop
)
fut.result()
except Exception as e:
logger.error(f"Error initializing client sessions: {e}")
else:
logger.debug(
"Event loop not ready; deferring client session initialization"
logger.info("Client sessions initialized successfully in _add_engine")
else:
# Event loop not ready yet, client sessions will be initialized in lifespan
logger.debug(
"Event loop not ready in _add_engine, client sessions will be initialized later"
)
except Exception as e:
logger.error(
f"Error initializing client sessions in _add_engine: {e}", exc_info=True
)

# Track all models we've ever seen
Expand Down Expand Up @@ -850,35 +854,63 @@ def close(self):

async def initialize_client_sessions(self) -> None:
    """
    Create aiohttp client sessions for the prefill and decode endpoints.

    Must be awaited from an async context during app startup. Only the
    first discovered endpoint per role is bound; an already-created
    session is never replaced, so in-flight requests are not disrupted
    (xPyD / multiple prefill nodes is not supported here).
    """
    logger.debug(
        f"initialize_client_sessions called. prefill_model_labels={self.prefill_model_labels}, decode_model_labels={self.decode_model_labels}"
    )
    # Guard clause: without both label sets there is nothing to match.
    if self.prefill_model_labels is None or self.decode_model_labels is None:
        logger.warning(
            "prefill_model_labels or decode_model_labels is None, skipping client session initialization"
        )
        return

    infos = self.get_endpoint_info()
    logger.debug(f"Got {len(infos)} endpoints")
    for info in infos:
        logger.debug(
            f"Checking endpoint: url={info.url}, model_label={info.model_label}"
        )
        if info.model_label in self.prefill_model_labels:
            if getattr(self.app.state, "prefill_client", None) is not None:
                # Session already initialised; skip so in-flight requests
                # keep their connection. Only the first discovered prefill
                # endpoint is used.
                logger.debug(f"prefill_client already set, skipping {info.url}")
                continue
            # No total timeout: decode/prefill streams can be long-lived.
            self.app.state.prefill_client = aiohttp.ClientSession(
                base_url=info.url,
                timeout=aiohttp.ClientTimeout(total=None),
            )
            logger.info(f"Created prefill_client for {info.url} with timeout=None")
        elif info.model_label in self.decode_model_labels:
            if getattr(self.app.state, "decode_client", None) is not None:
                logger.debug(f"decode_client already set, skipping {info.url}")
                continue
            self.app.state.decode_client = aiohttp.ClientSession(
                base_url=info.url,
                timeout=aiohttp.ClientTimeout(total=None),
            )
            logger.info(f"Created decode_client for {info.url} with timeout=None")

def has_ever_seen_model(self, model_name: str) -> bool:
"""Check if we've ever seen this model, even if currently scaled to zero."""
Expand Down Expand Up @@ -1212,6 +1244,21 @@ def _add_engine(self, engine_name: str, model_names: List[str], model_label: str
# Store model information in the endpoint info
self.available_engines[engine_name].model_info = model_info

try:
# Only initialize client sessions if event_loop is available
if hasattr(self.app.state, "event_loop") and self.app.state.event_loop:
fut = asyncio.run_coroutine_threadsafe(
self.initialize_client_sessions(), self.app.state.event_loop
)
fut.result()
else:
# Event loop not ready yet, client sessions will be initialized in lifespan
logger.debug(
"Event loop not ready, client sessions will be initialized later"
)
except Exception as e:
logger.error(f"Error initializing client sessions: {e}")

def _delete_engine(self, engine_name: str):
logger.info(f"Serving engine {engine_name} is deleted")
with self.available_engines_lock:
Expand Down Expand Up @@ -1287,25 +1334,58 @@ def close(self):

async def initialize_client_sessions(self) -> None:
    """
    Create aiohttp client sessions for the prefill and decode endpoints.

    Must be awaited from an async context during app startup. An existing
    session is left untouched (first discovered endpoint per role wins),
    so repeated discovery callbacks cannot disrupt in-flight requests.
    """
    logger.debug(
        f"K8sServiceNameServiceDiscovery.initialize_client_sessions called. prefill_model_labels={self.prefill_model_labels}, decode_model_labels={self.decode_model_labels}"
    )
    # Both label sets are required to classify endpoints; bail out early.
    if self.prefill_model_labels is None or self.decode_model_labels is None:
        logger.warning(
            "K8sServiceNameServiceDiscovery: prefill_model_labels or decode_model_labels is None, skipping client session initialization"
        )
        return

    infos = self.get_endpoint_info()
    logger.debug(f"Got {len(infos)} endpoints")
    for info in infos:
        logger.debug(
            f"Checking endpoint: url={info.url}, model_label={info.model_label}"
        )
        if info.model_label in self.prefill_model_labels:
            if getattr(self.app.state, "prefill_client", None) is not None:
                # Already bound to an earlier endpoint; keep it.
                logger.debug(f"prefill_client already set, skipping {info.url}")
                continue
            # No total timeout: requests can stream for a long time.
            self.app.state.prefill_client = aiohttp.ClientSession(
                base_url=info.url,
                timeout=aiohttp.ClientTimeout(total=None),
            )
            logger.info(f"Created prefill_client for {info.url} with timeout=None")
        elif info.model_label in self.decode_model_labels:
            if getattr(self.app.state, "decode_client", None) is not None:
                logger.debug(f"decode_client already set, skipping {info.url}")
                continue
            self.app.state.decode_client = aiohttp.ClientSession(
                base_url=info.url,
                timeout=aiohttp.ClientTimeout(total=None),
            )
            logger.info(f"Created decode_client for {info.url} with timeout=None")


def _create_service_discovery(
Expand Down
Loading