63 changes: 63 additions & 0 deletions vidur-alibabacloud/README-vidur.md
or to get information on all parameters,

```sh
python -m vidur.main -h
```

### SGLang scheduler (RadixAttention + chunked prefill)

SimAI supports simulating [SGLang](https://github.com/sgl-project/sglang)'s runtime scheduler,
which combines two key optimisations, configured via the parameters below:

| Feature | CLI parameter | Default |
|---------|--------------|---------|
| **Chunked prefill** – long prompts are split into fixed-size chunks, interleaving prefill and decode to reduce head-of-line blocking | `--sglang_scheduler_config_chunk_size` | `512` |
| **RadixAttention prefix caching** – KV-cache blocks for shared prefixes (e.g. a system prompt) are reused across requests, reducing both memory allocation and the number of prefill chunks | `--sglang_scheduler_config_enable_prefix_caching` | `True` |
| **Prefix cache hit rate** – fraction of prefill tokens satisfied by the prefix cache (set based on your workload; e.g. `0.7`–`0.9` for workloads with long shared system prompts) | `--sglang_scheduler_config_prefix_cache_hit_rate` | `0.0` |
| **Max tokens per batch** | `--sglang_scheduler_config_max_tokens_in_batch` | `4096` |

Example command (Llama-3-8B, simulating a workload where 70% of prefill tokens hit the prefix cache):

```sh
python -m vidur.main \
--replica_config_device a100 \
--replica_config_model_name meta-llama/Meta-Llama-3-8B \
--cluster_config_num_replicas 1 \
--replica_config_tensor_parallel_size 1 \
--replica_config_num_pipeline_stages 1 \
--request_generator_config_type synthetic \
--synthetic_request_generator_config_num_requests 512 \
--length_generator_config_type trace \
--trace_request_length_generator_config_max_tokens 16384 \
--trace_request_length_generator_config_trace_file ./data/processed_traces/splitwise_conv.csv \
--interval_generator_config_type poisson \
--poisson_request_interval_generator_config_qps 6.45 \
--replica_scheduler_config_type sglang \
--sglang_scheduler_config_chunk_size 512 \
--sglang_scheduler_config_enable_prefix_caching \
--sglang_scheduler_config_prefix_cache_hit_rate 0.7 \
--sglang_scheduler_config_max_tokens_in_batch 4096 \
--random_forrest_execution_time_predictor_config_prediction_max_prefill_chunk_size 16384 \
--random_forrest_execution_time_predictor_config_prediction_max_batch_size 512 \
--random_forrest_execution_time_predictor_config_prediction_max_tokens_per_request 16384
```

**How the simulation models SGLang behaviour**

* *Chunked prefill* – identical to the Sarathi-Serve scheduler already in SimAI; each
scheduling iteration processes at most `chunk_size` new prefill tokens.

* *Prefix-cache memory savings* – when `enable_prefix_caching=True` the scheduler allocates
only `ceil((1 − hit_rate) × num_prefill_tokens / block_size)` fresh KV blocks for each
new request. The remaining blocks are treated as shared cache entries that require no new
allocation. This allows more requests to fit in GPU memory concurrently, correctly
modelling SGLang's `RadixAttention` memory savings.

* *Reduced prefill iterations* – the cached portion of a request's prompt is "fast-forwarded"
in the first scheduling iteration (the `num_processed_tokens` counter advances past the
cached portion at no extra execution cost), so the total number of chunked-prefill rounds
is reduced proportionally.
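
The two prefix-caching effects above can be sketched numerically. This is a minimal illustration of the formulas, not the simulator's actual code; the function names and the `block_size`/`chunk_size` defaults are assumptions for the example:

```python
import math

def fresh_kv_blocks(num_prefill_tokens: int, hit_rate: float, block_size: int = 16) -> int:
    # Only the uncached portion of the prompt needs newly allocated KV blocks;
    # the cached prefix is served from shared RadixAttention cache entries.
    return math.ceil((1 - hit_rate) * num_prefill_tokens / block_size)

def prefill_rounds(num_prefill_tokens: int, hit_rate: float, chunk_size: int = 512) -> int:
    # Cached tokens are fast-forwarded in the first scheduling iteration at no
    # execution cost, so only uncached tokens are processed in chunk_size pieces.
    uncached = num_prefill_tokens - int(hit_rate * num_prefill_tokens)
    return max(1, math.ceil(uncached / chunk_size))

# 4096-token prompt with a 70% prefix-cache hit rate:
print(fresh_kv_blocks(4096, 0.7))  # 77 fresh blocks instead of 256
print(prefill_rounds(4096, 0.7))   # 3 chunked-prefill rounds instead of 8
```

At `hit_rate=0.0` both functions reduce to the plain Sarathi-style behaviour, which is a quick sanity check when experimenting with the parameter.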

**Choosing `prefix_cache_hit_rate`**

| Workload characteristic | Suggested value |
|------------------------|----------------|
| No shared prefix (pure decode, random prompts) | `0.0` |
| Short system prompt (≤ 5 % of prompt length) | `0.05`–`0.15` |
| Medium system prompt / few-shot examples | `0.3`–`0.5` |
| Long shared system prompt (≥ 70 % of prompt length) | `0.7`–`0.95` |
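
If you know your workload's average shared-prefix length, a back-of-envelope estimate is simply shared tokens over mean prompt length. A tiny sketch (the function name is illustrative, and it assumes nearly every request shares the same prefix):

```python
def estimate_hit_rate(shared_prefix_tokens: int, mean_prompt_tokens: int) -> float:
    # Fraction of prefill tokens expected to be served from the prefix cache.
    return min(1.0, shared_prefix_tokens / mean_prompt_tokens)

# e.g. a 1500-token system prompt with ~2000-token prompts on average:
print(round(estimate_hit_rate(1500, 2000), 2))  # 0.75, in the 0.7-0.95 band
```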

## Simulator Output

* The metrics will be logged to wandb directly and a copy will be stored in the `simulator_output/<TIMESTAMP>` directory. __A description of all the logged metrics can be found [here](docs/metrics.md).__
6 changes: 5 additions & 1 deletion vidur-alibabacloud/README.md
| `--cluster_config_num_replicas` | 1 | Total number of replicas (i.e., data parallelism degree) |
| `--replica_config_pd_node_ratio` | 0.5 | Fraction of replicas allocated as prefill (P) nodes; the remaining replicas are used as decode (D) nodes. For example, 0.5 means half of the replicas are prefill nodes and half are decode nodes (P:D = 1:1). |
| `--global_scheduler_config_type` | round_robin | Global scheduler type (`split_wise`, `round_robin`, etc.) |
| `--replica_scheduler_config_type` | sarathi | Per-replica scheduler type (`sarathi`, `vllm`, `orca`, `lightllm`, `sglang`, `split_wise`, `faster_transformer`) |
| `--sglang_scheduler_config_chunk_size` | 512 | Chunked-prefill chunk size for SGLang (tokens per chunk); only effective when `--replica_scheduler_config_type sglang` |
| `--sglang_scheduler_config_enable_prefix_caching` | True | Enable RadixAttention-based prefix caching for SGLang; only effective when `--replica_scheduler_config_type sglang` |
| `--sglang_scheduler_config_prefix_cache_hit_rate` | 0.0 | Fraction of prefill tokens that hit the prefix cache (0.0–1.0); see SGLang scheduler section for guidance; only effective when `--replica_scheduler_config_type sglang` |
| `--sglang_scheduler_config_max_tokens_in_batch` | 4096 | Maximum total tokens per batch for SGLang; only effective when `--replica_scheduler_config_type sglang` |
| `--replica_config_model_name` | meta-llama/Llama-2-7b-hf | Model name (DeepSeek-671B, Qwen3-Moe-235B, Qwen3-Next-80B, etc.)<br/>⚠️ **Note**: Vidur's GPU memory management module is still being adapted for DeepSeek-671B, Qwen3-Moe-235B, Qwen3-Next-80B |
| `--replica_config_tensor_parallel_size` | 1 | Tensor parallelism size (TP) |
| `--replica_config_num_pipeline_stages` | 1 | Number of pipeline stages (PP) |
32 changes: 32 additions & 0 deletions vidur-alibabacloud/vidur/config/config.py
    def get_type():
        return ReplicaSchedulerType.SPLIT_WISE


@dataclass
class SglangSchedulerConfig(BaseReplicaSchedulerConfig):
    chunk_size: int = field(
        default=512,
        metadata={"help": "Chunked-prefill chunk size for SGLang (tokens per chunk)."},
    )
    enable_prefix_caching: bool = field(
        default=True,
        metadata={"help": "Enable RadixAttention-based prefix caching (SGLang feature)."},
    )
    prefix_cache_hit_rate: float = field(
        default=0.0,
        metadata={
            "help": (
                "Fraction of prefill tokens that hit the prefix cache (0.0–1.0). "
                "Set based on your workload's prefix-sharing characteristics "
                "(e.g. 0.7–0.9 for workloads with long shared system prompts). "
                "Cached tokens reduce KV-block allocation, modelling SGLang's "
                "RadixAttention memory savings."
            )
        },
    )
    max_tokens_in_batch: int = field(
        default=4096,
        metadata={"help": "Maximum total tokens per batch for SGLang."},
    )

    @staticmethod
    def get_type():
        return ReplicaSchedulerType.SGLANG


@dataclass
class MetricsConfig:
"""Metric configuration."""
Expand Down
from vidur.scheduler.replica_scheduler.sglang_replica_scheduler import (
    SglangReplicaScheduler,
)
from vidur.scheduler.replica_scheduler.splitwise_replica_scheduler import (
    SplitwiseReplicaScheduler,
)


class ReplicaSchedulerRegistry(BaseRegistry):
    pass
)
ReplicaSchedulerRegistry.register(
    ReplicaSchedulerType.SPLIT_WISE, SplitwiseReplicaScheduler
)
ReplicaSchedulerRegistry.register(
    ReplicaSchedulerType.SGLANG, SglangReplicaScheduler
)
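
For reference, the registry pattern relied on above can be sketched in a few lines. This is illustrative only; the real `BaseRegistry` in vidur may differ, and `MiniRegistry`/`DummyScheduler` are made-up names:

```python
class MiniRegistry:
    # Maps a key (e.g. a ReplicaSchedulerType enum member) to an implementation class.
    _registry: dict = {}

    @classmethod
    def register(cls, key, impl):
        cls._registry[key] = impl

    @classmethod
    def get(cls, key, *args, **kwargs):
        # Instantiate the class registered under `key`.
        return cls._registry[key](*args, **kwargs)

class DummyScheduler:
    pass

MiniRegistry.register("sglang", DummyScheduler)
print(type(MiniRegistry.get("sglang")).__name__)  # DummyScheduler
```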