From 6fc66200894dd92937cc2c1b0cb62864a1499013 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 28 Feb 2026 04:52:39 +0000 Subject: [PATCH 1/2] Initial plan From 288aa998cefc417cbdef6cf437b1e99826a1c4b7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 28 Feb 2026 05:09:04 +0000 Subject: [PATCH 2/2] Add SGLang scheduler integration (RadixAttention + chunked prefill) Co-authored-by: tianhao909 <48342395+tianhao909@users.noreply.github.com> --- vidur-alibabacloud/README-vidur.md | 63 ++++ vidur-alibabacloud/README.md | 6 +- vidur-alibabacloud/vidur/config/config.py | 32 ++ .../replica_scheduler_registry.py | 6 + .../sglang_replica_scheduler.py | 338 ++++++++++++++++++ .../vidur/types/replica_scheduler_type.py | 1 + 6 files changed, 445 insertions(+), 1 deletion(-) create mode 100644 vidur-alibabacloud/vidur/scheduler/replica_scheduler/sglang_replica_scheduler.py diff --git a/vidur-alibabacloud/README-vidur.md b/vidur-alibabacloud/README-vidur.md index 6f1cd2d0..9e11f65f 100644 --- a/vidur-alibabacloud/README-vidur.md +++ b/vidur-alibabacloud/README-vidur.md @@ -125,6 +125,69 @@ or to get information on all parameters, python -m vidur.main -h ``` +### SGLang scheduler (RadixAttention + chunked prefill) + +SimAI supports simulating [SGLang](https://github.com/sgl-project/sglang)'s runtime scheduler, +which combines two key optimisations: + +| Feature | CLI parameter | Default | +|---------|--------------|---------| +| **Chunked prefill** – long prompts are split into fixed-size chunks, interleaving prefill and decode to reduce head-of-line blocking | `--sglang_scheduler_config_chunk_size` | `512` | +| **RadixAttention prefix caching** – KV-cache blocks for shared prefixes (e.g. 
a system prompt) are reused across requests, reducing both memory allocation and the number of prefill chunks | `--sglang_scheduler_config_enable_prefix_caching` | `True` | +| **Prefix cache hit rate** – fraction of prefill tokens satisfied by the prefix cache (set based on your workload; e.g. `0.7`–`0.9` for workloads with long shared system prompts) | `--sglang_scheduler_config_prefix_cache_hit_rate` | `0.0` | +| **Max tokens per batch** | `--sglang_scheduler_config_max_tokens_in_batch` | `4096` | + +Example command (Llama-3-8B, simulating a workload where 70% of prefill tokens hit the prefix cache): + +```sh +python -m vidur.main \ + --replica_config_device a100 \ + --replica_config_model_name meta-llama/Meta-Llama-3-8B \ + --cluster_config_num_replicas 1 \ + --replica_config_tensor_parallel_size 1 \ + --replica_config_num_pipeline_stages 1 \ + --request_generator_config_type synthetic \ + --synthetic_request_generator_config_num_requests 512 \ + --length_generator_config_type trace \ + --trace_request_length_generator_config_max_tokens 16384 \ + --trace_request_length_generator_config_trace_file ./data/processed_traces/splitwise_conv.csv \ + --interval_generator_config_type poisson \ + --poisson_request_interval_generator_config_qps 6.45 \ + --replica_scheduler_config_type sglang \ + --sglang_scheduler_config_chunk_size 512 \ + --sglang_scheduler_config_enable_prefix_caching \ + --sglang_scheduler_config_prefix_cache_hit_rate 0.7 \ + --sglang_scheduler_config_max_tokens_in_batch 4096 \ + --random_forrest_execution_time_predictor_config_prediction_max_prefill_chunk_size 16384 \ + --random_forrest_execution_time_predictor_config_prediction_max_batch_size 512 \ + --random_forrest_execution_time_predictor_config_prediction_max_tokens_per_request 16384 +``` + +**How the simulation models SGLang behaviour** + +* *Chunked prefill* – identical to the Sarathi-Serve scheduler already in SimAI; each + scheduling iteration processes at most `chunk_size` new prefill tokens. 
+ +* *Prefix-cache memory savings* – when `enable_prefix_caching=True` the scheduler allocates + only `ceil((1 − hit_rate) × num_prefill_tokens / block_size)` fresh KV blocks for each + new request. The remaining blocks are treated as shared cache entries that require no new + allocation. This allows more requests to fit in GPU memory concurrently, correctly + modelling SGLang's `RadixAttention` memory savings. + +* *Reduced prefill iterations* – the cached portion of a request's prompt is "fast-forwarded" + in the first scheduling iteration (the `num_processed_tokens` counter advances past the + cached portion at no extra execution cost), so the total number of chunked-prefill rounds + is reduced proportionally. + +**Choosing `prefix_cache_hit_rate`** + +| Workload characteristic | Suggested value | +|------------------------|----------------| +| No shared prefix (pure decode, random prompts) | `0.0` | +| Short system prompt (≤ 5 % of prompt length) | `0.05`–`0.15` | +| Medium system prompt / few-shot examples | `0.3`–`0.5` | +| Long shared system prompt (≥ 70 % of prompt length) | `0.7`–`0.95` | + ## Simulator Output * The metrics will be logged to wandb directly and a copy will be stored in the `simulator_output/` directory. __A description of all the logged metrics can be found [here](docs/metrics.md).__ diff --git a/vidur-alibabacloud/README.md b/vidur-alibabacloud/README.md index 244a1e47..10f44d59 100644 --- a/vidur-alibabacloud/README.md +++ b/vidur-alibabacloud/README.md @@ -254,7 +254,11 @@ python -m vidur.main \ | `--cluster_config_num_replicas` | 1 | Total number of replicas (i.e., data parallelism degree) | | `--replica_config_pd_node_ratio` | 0.5 | Ratio of P-nodes to (P-nodes + D-nodes) Fraction of replicas allocated as prefill (P) nodes. The remaining replicas are used as decode (D) nodes. For example, 0.5 means half of the replicas are prefill nodes and half are decode nodes (P:D = 1:1). 
| | `--global_scheduler_config_type` | round_robin | Global scheduler type (`split_wise`, `round_robin`, etc.) | -| `--replica_scheduler_config_type` | sarathi | Per-replica scheduler type | +| `--replica_scheduler_config_type` | sarathi | Per-replica scheduler type (`sarathi`, `vllm`, `orca`, `lightllm`, `sglang`, `split_wise`, `faster_transformer`) | +| `--sglang_scheduler_config_chunk_size` | 512 | Chunked-prefill chunk size for SGLang (tokens per chunk); only effective when `--replica_scheduler_config_type sglang` | +| `--sglang_scheduler_config_enable_prefix_caching` | True | Enable RadixAttention-based prefix caching for SGLang; only effective when `--replica_scheduler_config_type sglang` | +| `--sglang_scheduler_config_prefix_cache_hit_rate` | 0.0 | Fraction of prefill tokens that hit the prefix cache (0.0–1.0); see SGLang scheduler section for guidance; only effective when `--replica_scheduler_config_type sglang` | +| `--sglang_scheduler_config_max_tokens_in_batch` | 4096 | Maximum total tokens per batch for SGLang; only effective when `--replica_scheduler_config_type sglang` | | `--replica_config_model_name` | meta-llama/Llama-2-7b-hf | Model name (DeepSeek-671B, Qwen3-Moe-235B, Qwen3-Next-80B , etc.)
⚠️ **Note**: Vidur GPU Memory management module is still under adaptation for DeepSeek-671B, Qwen3-Moe-235B, Qwen3-Next-80B |
 | `--replica_config_tensor_parallel_size` | 1 | Tensor parallelism size (TP) |
 | `--replica_config_num_pipeline_stages` | 1 | Number of pipeline stages (PP) |
diff --git a/vidur-alibabacloud/vidur/config/config.py b/vidur-alibabacloud/vidur/config/config.py
index 14eb13db..36c45a7c 100644
--- a/vidur-alibabacloud/vidur/config/config.py
+++ b/vidur-alibabacloud/vidur/config/config.py
@@ -342,6 +342,59 @@ def get_type():
         return ReplicaSchedulerType.SPLIT_WISE
 
 
+@dataclass
+class SglangSchedulerConfig(BaseReplicaSchedulerConfig):
+    """Configuration for the SGLang replica scheduler (chunked prefill +
+    RadixAttention-style prefix caching)."""
+
+    chunk_size: int = field(
+        default=512,
+        metadata={"help": "Chunked-prefill chunk size for SGLang (tokens per chunk)."},
+    )
+    enable_prefix_caching: bool = field(
+        default=True,
+        metadata={"help": "Enable RadixAttention-based prefix caching (SGLang feature)."},
+    )
+    prefix_cache_hit_rate: float = field(
+        default=0.0,
+        metadata={
+            "help": (
+                "Fraction of prefill tokens that hit the prefix cache (0.0–1.0). "
+                "Set based on your workload's prefix-sharing characteristics "
+                "(e.g. 0.7–0.9 for workloads with long shared system prompts). "
+                "Cached tokens reduce KV-block allocation, modelling SGLang's "
+                "RadixAttention memory savings."
+            )
+        },
+    )
+    max_tokens_in_batch: int = field(
+        default=4096,
+        metadata={"help": "Maximum total tokens per batch for SGLang."},
+    )
+
+    def __post_init__(self):
+        # Run any validation defined by the parent config class first.
+        parent_post_init = getattr(super(), "__post_init__", None)
+        if parent_post_init is not None:
+            parent_post_init()
+        # Fail fast on nonsensical values instead of silently mis-simulating.
+        if not 0.0 <= self.prefix_cache_hit_rate <= 1.0:
+            raise ValueError(
+                f"prefix_cache_hit_rate must be in [0.0, 1.0], "
+                f"got {self.prefix_cache_hit_rate}"
+            )
+        if self.chunk_size <= 0:
+            raise ValueError(f"chunk_size must be positive, got {self.chunk_size}")
+        if self.max_tokens_in_batch <= 0:
+            raise ValueError(
+                f"max_tokens_in_batch must be positive, got {self.max_tokens_in_batch}"
+            )
+
+    @staticmethod
+    def get_type():
+        return ReplicaSchedulerType.SGLANG
+
+
 @dataclass
 class MetricsConfig:
     """Metric configuration."""
diff --git a/vidur-alibabacloud/vidur/scheduler/replica_scheduler/replica_scheduler_registry.py b/vidur-alibabacloud/vidur/scheduler/replica_scheduler/replica_scheduler_registry.py
index bf2bf1ec..9737454d 100644
--- a/vidur-alibabacloud/vidur/scheduler/replica_scheduler/replica_scheduler_registry.py
+++ b/vidur-alibabacloud/vidur/scheduler/replica_scheduler/replica_scheduler_registry.py
@@ -20,6 +20,9 @@ from vidur.scheduler.replica_scheduler.splitwise_replica_scheduler import (
     SplitwiseReplicaScheduler,
 )
+from vidur.scheduler.replica_scheduler.sglang_replica_scheduler import (
+    SglangReplicaScheduler,
+)
 
 
 class ReplicaSchedulerRegistry(BaseRegistry):
     pass
@@ -36,4 +39,7 @@ class ReplicaSchedulerRegistry(BaseRegistry):
 )
 ReplicaSchedulerRegistry.register(
     ReplicaSchedulerType.SPLIT_WISE, SplitwiseReplicaScheduler
+)
+ReplicaSchedulerRegistry.register(
+    ReplicaSchedulerType.SGLANG, SglangReplicaScheduler
 )
\ No newline at end of file
diff --git a/vidur-alibabacloud/vidur/scheduler/replica_scheduler/sglang_replica_scheduler.py b/vidur-alibabacloud/vidur/scheduler/replica_scheduler/sglang_replica_scheduler.py
new file mode 100644
index 00000000..2365f04b
--- /dev/null
+++ b/vidur-alibabacloud/vidur/scheduler/replica_scheduler/sglang_replica_scheduler.py
@@ -0,0 +1,338 @@
+"""
+SGLang replica scheduler for SimAI.
+
+Simulates two key SGLang features:
+
+1. **Chunked prefill** (similar to Sarathi-Serve)
+   Long prompts are split into fixed-size chunks so that prefill and decode
+   iterations can be interleaved, reducing head-of-line blocking.
+
+2. **RadixAttention prefix caching**
+   SGLang maintains a radix tree of KV-cache blocks indexed by token-prefix
+   hashes.
When a new request shares a common prefix (e.g. a system prompt) + with previously-cached content those KV blocks are reused without + recomputation. + + Because this simulator works at the token-count level (not with actual token + values) the prefix-cache benefit is approximated via a configurable + ``prefix_cache_hit_rate``. A hit rate of *r* means that the first + ``floor(r * num_prefill_tokens)`` tokens of every new request are already + present in the cache, so: + + * **Memory** – only ``ceil((1-r) * num_prefill_tokens / block_size)`` new + KV blocks are allocated (the remainder are shared cache blocks). + * **Scheduler latency** – the first scheduling iteration for that request + advances ``num_processed_tokens`` by the full cached amount in one step + (no extra attention computation is needed), reducing the number of + prefill chunks required. + +Usage example (via CLI):: + + python -m vidur.main \\ + --replica_scheduler_config_type sglang \\ + --sglang_scheduler_config_chunk_size 512 \\ + --sglang_scheduler_config_enable_prefix_caching True \\ + --sglang_scheduler_config_prefix_cache_hit_rate 0.7 \\ + --sglang_scheduler_config_max_tokens_in_batch 4096 \\ + ... 
+""" + +from math import ceil +from typing import Dict, List, Set + +from vidur.entities.batch import Batch, Request +from vidur.scheduler.replica_scheduler.base_replica_scheduler import ( + BaseReplicaScheduler, +) + + +class SglangReplicaScheduler(BaseReplicaScheduler): + """SGLang-style replica scheduler with chunked prefill and prefix caching.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._num_running_batches: int = 0 + self._preempted_requests: List[Request] = [] + # Loose per-stage cap (memory is tracked explicitly via block allocation) + self._max_micro_batch_size: int = self._config.batch_size_cap // self._num_stages + self._watermark_blocks: int = int( + self._config.watermark_blocks_fraction * self._config.num_blocks + ) + + # --- prefix-cache tracking --- + # Maps request_id → number of prefill tokens satisfied by the prefix cache. + # These tokens do not need new KV-block allocation. + self._prefix_cache_hit_tokens: Dict[int, int] = {} + # Set of request IDs whose cache-hit token "bump" has already been + # included in a batch (so we don't add it twice). 
+ self._cache_hits_advanced: Set[int] = set() + + # ------------------------------------------------------------------ + # Prefix-cache helpers + # ------------------------------------------------------------------ + + def _get_cache_hit_tokens(self, request: Request) -> int: + """Return the number of prefill tokens satisfied by the prefix cache.""" + if not self._config.enable_prefix_caching: + return 0 + return int(request.num_prefill_tokens * self._config.prefix_cache_hit_rate) + + def _effective_new_prefill_tokens(self, request: Request) -> int: + """Return the number of *new* (non-cached) prefill tokens for a request.""" + hit = self._prefix_cache_hit_tokens.get( + request.id, self._get_cache_hit_tokens(request) + ) + return max(0, request.num_prefill_tokens - hit) + + # ------------------------------------------------------------------ + # Memory allocation (overrides base class) + # ------------------------------------------------------------------ + + def _compute_required_blocks(self, effective_prefill_tokens: int) -> int: + """Return the number of fresh KV blocks needed for *effective_prefill_tokens*. + + At least 1 block is always required so that the decode phase can track + token capacity via ``_allocation_map``. + """ + if effective_prefill_tokens > 0: + return ceil(effective_prefill_tokens / self._config.block_size) + return 1 + + def _can_allocate_request(self, request: Request) -> bool: + if request.id not in self._allocation_map: + # New request: only non-cached tokens need fresh KV blocks. + num_required_blocks = self._compute_required_blocks( + self._effective_new_prefill_tokens(request) + ) + return ( + self._config.num_blocks + - self._num_allocated_blocks + - num_required_blocks + >= self._watermark_blocks + ) + # Existing (decode-phase) request: needs room for at most 1 more block. 
+ return self._config.num_blocks - self._num_allocated_blocks >= 1 + + def _allocate_request(self, request: Request) -> None: + if request.id not in self._allocation_map: + # Compute and store cache-hit tokens for this request. + hit = self._get_cache_hit_tokens(request) + self._prefix_cache_hit_tokens[request.id] = hit + effective = max(0, request.num_prefill_tokens - hit) + # Always allocate at least 1 block so the decode phase can track + # token capacity via _allocation_map. + self.allocate(request.id, self._compute_required_blocks(effective)) + return + + # Decode phase: determine how many new blocks (if any) are needed. + # + # The "virtual" token capacity includes both physically allocated blocks + # and cached tokens (which live in the shared prefix cache, not in new + # blocks allocated to this request). + hit = self._prefix_cache_hit_tokens.get(request.id, 0) + num_tokens_available = ( + self._allocation_map[request.id] * self._config.block_size + hit + ) + num_tokens_required = max(0, request.num_processed_tokens - num_tokens_available) + + assert ( + num_tokens_required == 0 or num_tokens_required == 1 + ), ( + f"Expected decode-phase allocation delta of 0 or 1, " + f"got {num_tokens_required} " + f"(processed={request.num_processed_tokens}, " + f"available={num_tokens_available})" + ) + + if num_tokens_required == 0: + return + + self.allocate(request.id, 1) + + # ------------------------------------------------------------------ + # Batch lifecycle callbacks + # ------------------------------------------------------------------ + + def on_batch_end(self, batch: Batch) -> None: + self._num_running_batches -= 1 + + for request in batch.requests: + if request.completed: + self.free(request.id) + self._prefix_cache_hit_tokens.pop(request.id, None) + self._cache_hits_advanced.discard(request.id) + else: + self._preempted_requests.append(request) + + # ------------------------------------------------------------------ + # Token-count helpers + # 
------------------------------------------------------------------ + + def _get_request_next_num_tokens( + self, + request: Request, + batch_contains_prefill: bool, + num_batch_tokens: int, + ) -> int: + """Return how many tokens this request should process in the next chunk. + + For decode, this is always 1. + + For prefill the chunk budget is ``chunk_size - num_batch_tokens``. On + the *first* iteration of a prefix-cached request the cached-token bump + is prepended so that ``num_processed_tokens`` jumps past the cached + portion in one step; subsequent chunks proceed normally. + """ + assert not request.completed + + if request.is_prefill_complete: + return 1 + + # Determine whether this is the first time this request is being + # scheduled (cache-hit bump not yet applied). + cache_hit_bump = 0 + if ( + self._config.enable_prefix_caching + and request.id in self._prefix_cache_hit_tokens + and request.id not in self._cache_hits_advanced + ): + cache_hit_bump = self._prefix_cache_hit_tokens[request.id] + + # Remaining *new* tokens (excluding the not-yet-applied cache-hit bump). + remaining_new = ( + request.num_prefill_tokens + - request.num_processed_tokens + - cache_hit_bump + ) + + next_new_tokens = min( + max(0, remaining_new), + max(0, self._config.chunk_size - num_batch_tokens), + ) + + total_tokens = cache_hit_bump + next_new_tokens + if total_tokens == 0: + return 0 + + # Mark the cache-hit bump as consumed so it is not added again. 
+ if cache_hit_bump > 0: + self._cache_hits_advanced.add(request.id) + + return total_tokens + + # ------------------------------------------------------------------ + # Core scheduling logic + # ------------------------------------------------------------------ + + def _restart_request(self, request: Request) -> None: + """Evict a request, freeing its blocks and resetting prefix-cache state.""" + request.restart() + self.free(request.id) + # A restarted request gets a new (possibly different) num_prefill_tokens, + # so discard stale prefix-cache bookkeeping. + self._prefix_cache_hit_tokens.pop(request.id, None) + self._cache_hits_advanced.discard(request.id) + + def _get_next_batch(self) -> Batch: # noqa: C901 + requests: List[Request] = [] + num_tokens: List[int] = [] + skipped_requests: List[Request] = [] + running_prefills: List[Request] = [] + contains_prefill = False + num_batch_tokens = 0 + + # ---------------------------------------------------------------- + # 1. Process preempted requests (may include partial prefills) + # ---------------------------------------------------------------- + while self._preempted_requests: + if len(requests) == self._max_micro_batch_size: + break + + request = self._preempted_requests.pop(0) + + if not request.is_prefill_complete: + # Still in prefill phase – handle separately below. + running_prefills.append(request) + continue + + # Decode-phase preempted request. + next_num_tokens = self._get_request_next_num_tokens( + request, contains_prefill, num_batch_tokens + ) + + if next_num_tokens == 0: + skipped_requests.append(request) + continue + + # Ensure there is enough memory; evict the youngest preempted + # request if necessary. 
+ while not self._can_allocate_request(request): + if self._preempted_requests: + victim = self._preempted_requests.pop(-1) + self._restart_request(victim) + self._request_queue = [victim] + self._request_queue + else: + self._restart_request(request) + self._request_queue = [request] + self._request_queue + break + else: + self._allocate_request(request) + assert request.is_prefill_complete + num_batch_tokens += next_num_tokens + requests.append(request) + num_tokens.append(next_num_tokens) + + # ---------------------------------------------------------------- + # 2. Continue in-flight partial prefills + # ---------------------------------------------------------------- + for request in running_prefills: + assert not request.is_prefill_complete + + next_num_tokens = self._get_request_next_num_tokens( + request, contains_prefill, num_batch_tokens + ) + + if next_num_tokens == 0: + skipped_requests.append(request) + continue + + contains_prefill = True + num_batch_tokens += next_num_tokens + requests.append(request) + num_tokens.append(next_num_tokens) + + # Restore skipped requests at the front (preserve FIFO ordering). + self._preempted_requests = skipped_requests + self._preempted_requests + self._preempted_requests = sorted( + self._preempted_requests, key=lambda r: r.arrived_at + ) + + # ---------------------------------------------------------------- + # 3. 
Admit new requests from the queue + # ---------------------------------------------------------------- + while self._request_queue: + if len(self._allocation_map) == self._config.batch_size_cap: + break + if len(requests) == self._max_micro_batch_size: + break + if not self._can_allocate_request(self._request_queue[0]): + break + + next_num_tokens = self._get_request_next_num_tokens( + self._request_queue[0], contains_prefill, num_batch_tokens + ) + if next_num_tokens == 0: + break + + request = self._request_queue.pop(0) + self._allocate_request(request) + contains_prefill = True + num_batch_tokens += next_num_tokens + requests.append(request) + num_tokens.append(next_num_tokens) + + if not requests: + return None + + return Batch(self._replica_id, requests, num_tokens) diff --git a/vidur-alibabacloud/vidur/types/replica_scheduler_type.py b/vidur-alibabacloud/vidur/types/replica_scheduler_type.py index ca69c937..fb12012a 100644 --- a/vidur-alibabacloud/vidur/types/replica_scheduler_type.py +++ b/vidur-alibabacloud/vidur/types/replica_scheduler_type.py @@ -8,3 +8,4 @@ class ReplicaSchedulerType(BaseIntEnum): VLLM = 4 LIGHTLLM = 5 SPLIT_WISE= 6 + SGLANG = 7