[Feat][Router] Add load-balanced KV-aware routing strategy#884

Open
notTyche wants to merge 2 commits into vllm-project:main from notTyche:feat/load-balanced-kvaware-router
Conversation


@notTyche notTyche commented Mar 12, 2026

Add load_balanced_kvaware routing strategy

The existing kvaware router always routes based on KV cache locality (via LMCache), which can cause uneven load distribution when some replicas accumulate a large backlog while others sit idle. This PR introduces a new load_balanced_kvaware routing strategy that adds a load-balance check as a first tier before the KV-aware lookup, ensuring that heavily overloaded replicas are relieved before cache locality is considered.

How it works

LoadBalancedKvawareRouter inherits from KvawareRouter and operates in three tiers:

  1. Tier 1 — Load balance check: Compute queue_length = num_running_requests + num_queuing_requests for each replica. If max(queue_lengths) - min(queue_lengths) >= imbalanced_threshold, route immediately to the least-loaded replica.
  2. Tier 2 — KV-aware routing: If load is balanced, delegate to the parent KvawareRouter.route_request() which performs the standard LMCache prefix lookup.
  3. Tier 3 — Fallback: If no KV cache hit is found, fall back to session-based consistent hashing (if a session ID header is present) or QPS-based routing.

The default value for --imbalanced-threshold is infinity, which means the router behaves identically to kvaware unless the threshold is explicitly set.
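
The three tiers can be sketched as follows. This is a simplified, self-contained sketch, not the PR's actual code: `ReplicaStats` and the free-function shapes are illustrative stand-ins for the router's internal helpers, and the real implementation delegates Tiers 2–3 to `KvawareRouter.route_request()`.

```python
import math
from dataclasses import dataclass

@dataclass
class ReplicaStats:
    """Hypothetical stand-in for the per-endpoint engine stats the router sees."""
    url: str
    num_running_requests: int
    num_queuing_requests: int

def get_queue_length(stats: ReplicaStats) -> int:
    # Tier 1 metric: total outstanding work on a replica.
    return stats.num_running_requests + stats.num_queuing_requests

def is_load_balanced(replicas: list[ReplicaStats], threshold: float) -> bool:
    # Balanced when the queue-length spread stays strictly below the threshold.
    lengths = [get_queue_length(r) for r in replicas]
    return max(lengths) - min(lengths) < threshold

def pick_least_loaded(replicas: list[ReplicaStats]) -> str:
    return min(replicas, key=get_queue_length).url

def route(replicas, threshold=math.inf, kvaware_route=None):
    # Tier 1: relieve a heavily loaded replica before considering cache locality.
    if not is_load_balanced(replicas, threshold):
        return pick_least_loaded(replicas)
    # Tiers 2-3: defer to the KV-aware lookup (which has its own fallbacks).
    return kvaware_route(replicas) if kvaware_route else replicas[0].url
```

With the default `threshold=math.inf` the spread is never large enough to trigger Tier 1, so every request falls through to the KV-aware path, matching the "behaves identically to kvaware" default described above.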

New CLI argument

| Argument | Type | Default | Description |
|---|---|---|---|
| `--imbalanced-threshold` | float | `inf` | Queue-length spread above which load balancing takes priority over KV cache locality. Lower values prioritize load balancing. |
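
A quick worked example of the threshold semantics, with made-up queue lengths:

```python
# Illustrative arithmetic for --imbalanced-threshold (queue lengths are made up).
queue_lengths = [14, 9, 2]                        # running + queuing per replica
spread = max(queue_lengths) - min(queue_lengths)  # 14 - 2 = 12

rebalance_at_10 = spread >= 10              # True: Tier 1 routes to the lightest replica
rebalance_default = spread >= float("inf")  # False: default inf behaves like kvaware
```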

Example usage

```shell
# Always KV-aware (same as kvaware)
--routing-logic load_balanced_kvaware

# Rebalance when any replica is 10+ requests ahead of the lightest
--routing-logic load_balanced_kvaware --imbalanced-threshold 10
```

Helm values

```yaml
routerSpec:
  routingLogic: "load_balanced_kvaware"
  lmcacheControllerPort: 9000
  sessionKey: "x-user-id"
  extraArgs:
    - "--imbalanced-threshold"
    - "10"
```

Files changed

| File | Change |
|---|---|
| `src/vllm_router/routers/routing_logic.py` | New `LoadBalancedKvawareRouter` class inheriting from `KvawareRouter`. Adds a load-balance tier and delegates KV-aware logic to the parent implementation. Registered in `initialize_routing_logic`, `get_routing_logic`, and `cleanup_routing_logic`. |
| `src/vllm_router/parsers/parser.py` | Added `load_balanced_kvaware` to the `--routing-logic` choices and introduced the `--imbalanced-threshold` argument. |
| `src/vllm_router/app.py` | Passes `imbalanced_threshold` during router initialization. |
| `src/vllm_router/services/request_service/request.py` | Added `LoadBalancedKvawareRouter` to the async routing and failover-retry `isinstance` checks. |
| `src/tests/test_load_balanced_kvaware_router.py` | Added 21 unit tests covering `_get_queue_length`, `_is_load_balanced`, `_route_to_least_loaded`, and end-to-end `route_request` behavior for load-balancing and KV-aware fallback paths. |
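
Tests in the spirit of the new test file might look like the pytest-style sketch below. These are illustrative stand-ins, not the PR's actual tests; the helper functions here are simplified local versions of the `_get_queue_length` and `_is_load_balanced` behavior described above.

```python
import math

def queue_length(running: int, queuing: int) -> int:
    # Simplified stand-in for the router's _get_queue_length helper.
    return running + queuing

def is_load_balanced(lengths: list[int], threshold: float) -> bool:
    # Simplified stand-in for _is_load_balanced: spread must stay below threshold.
    return max(lengths) - min(lengths) < threshold

def test_default_threshold_never_rebalances():
    # With the default inf threshold the router behaves exactly like kvaware.
    assert is_load_balanced([100, 0], math.inf)

def test_spread_at_threshold_triggers_rebalance():
    # A spread equal to the threshold counts as imbalanced (>= comparison).
    assert not is_load_balanced([queue_length(12, 0), 2], 10)
```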

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a sophisticated load_balanced_kvaware routing strategy designed to optimize request distribution in the router. Unlike the purely KV-aware approach, this new strategy intelligently assesses the load across available replicas. If a significant imbalance in queue lengths is detected, it prioritizes routing requests to the least-loaded replica to prevent bottlenecks and ensure smoother operation. When loads are balanced, it gracefully falls back to the existing KV-aware logic, maintaining cache locality benefits. This enhancement provides a more robust and adaptive routing mechanism, improving overall system efficiency and responsiveness.

Highlights

  • New Routing Strategy: Introduced load_balanced_kvaware routing, which combines load balancing with KV-aware routing to optimize request distribution.
  • Load Balance Prioritization: The new strategy first checks for load imbalance across replicas and routes to the least-loaded one if a configurable threshold is exceeded, preventing bottlenecks.
  • Configurable Threshold: Added a new CLI argument, --imbalanced-threshold, to control when load balancing takes precedence over KV cache locality, with a default of infinity.


Changelog
  • src/tests/test_load_balanced_kvaware_router.py
    • Added a new test file containing comprehensive unit tests for the LoadBalancedKvawareRouter, covering its core functionalities like queue length calculation, load balance detection, least-loaded routing, and end-to-end request routing scenarios including fallbacks.
  • src/vllm_router/app.py
    • Modified the initialize_all function to accept and pass the imbalanced_threshold argument to the routing logic initialization.
  • src/vllm_router/parsers/parser.py
    • Added load_balanced_kvaware as a valid choice for the --routing-logic argument.
    • Introduced the --imbalanced-threshold argument to configure the load balancing behavior of the new router.
  • src/vllm_router/routers/routing_logic.py
    • Implemented the LoadBalancedKvawareRouter class, which includes logic for queue length calculation, load balance checks, and routing to the least-loaded endpoint.
    • Registered the new router in the RoutingLogic enum, initialize_routing_logic, get_routing_logic, and cleanup_routing_logic functions.
  • src/vllm_router/services/request_service/request.py
    • Updated isinstance checks in route_general_request to include LoadBalancedKvawareRouter for both initial routing and failover retry paths.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a new load_balanced_kvaware routing strategy, which is a great addition for improving load distribution. The implementation correctly adds a load-balancing tier before falling back to the existing KV-aware logic. The new --imbalanced-threshold argument is well-documented and provides good control over the routing behavior. The unit tests are comprehensive and cover many edge cases.

My main feedback is on the significant code duplication between the new LoadBalancedKvawareRouter and the existing KvawareRouter. Refactoring this to use inheritance would greatly improve the code's maintainability and reduce redundancy. I've also included a suggestion to simplify one of the helper methods for better readability.

Comment thread src/vllm_router/routers/routing_logic.py Outdated
@notTyche notTyche force-pushed the feat/load-balanced-kvaware-router branch from f9334a7 to 8bdb482 on March 12, 2026 at 18:39
Signed-off-by: Matteo Perfidio <perfidiomatteo7@gmail.com>
@notTyche notTyche force-pushed the feat/load-balanced-kvaware-router branch from 8bdb482 to 0e74e54 on March 12, 2026 at 18:42
@notTyche
Author

Documentation for this new routing strategy can be added in a follow-up PR. Happy to create a tutorial similar to 17-kv-aware-routing.md

Collaborator

@ruizhang0101 ruizhang0101 left a comment


Could you provide some benchmarking compared to the kv-aware routing?
