diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
index ad70994cc21c..66591fb10957 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
@@ -367,6 +367,26 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
             nb::call_guard<nb::gil_scoped_release>())
         .def("get_remaining_blocks_to_completion", &BaseKVCacheManager::getRemainingBlocksToCompletion,
             nb::call_guard<nb::gil_scoped_release>())
+        .def("get_remaining_blocks_to_completion_batch",
+            [](tbk::BaseKVCacheManager& self, nb::list const& pyRequests, SizeType32 windowSize)
+            {
+                // Extract C++ request pointers while the GIL is held
+                std::vector<tb::LlmRequest const*> requests;
+                requests.reserve(nb::len(pyRequests));
+                for (auto const& item : pyRequests)
+                {
+                    requests.push_back(&nb::cast<tb::LlmRequest&>(item));
+                }
+                // Release the GIL for the C++ computation
+                nb::gil_scoped_release release;
+                std::vector<SizeType32> result;
+                result.reserve(requests.size());
+                for (auto const* req : requests)
+                {
+                    result.push_back(self.getRemainingBlocksToCompletion(*req, windowSize));
+                }
+                return result;
+            })
         .def("add_token", &BaseKVCacheManager::addToken, nb::call_guard<nb::gil_scoped_release>())
         .def("add_sequence", &BaseKVCacheManager::addSequence, nb::call_guard<nb::gil_scoped_release>())
         .def("remove_sequence", &BaseKVCacheManager::removeSequence, nb::call_guard<nb::gil_scoped_release>())
diff --git a/scripts/build_wheel.py b/scripts/build_wheel.py
index db8931c3e621..13419d07e4ff 100755
--- a/scripts/build_wheel.py
+++ b/scripts/build_wheel.py
@@ -474,6 +474,36 @@ def build_kv_cache_manager_v2(project_dir, venv_python, use_mypyc=False):
     print("-- Done building kv_cache_manager_v2.")
 
 
+def build_pyexecutor_scheduler(project_dir, venv_python, use_mypyc=False):
+    print("-- Building pyexecutor scheduler...")
+    scheduler_dir = project_dir / "tensorrt_llm/_torch/pyexecutor/scheduler"
+    pyexecutor_dir = project_dir / "tensorrt_llm/_torch/pyexecutor"
+
+    # Clean up any existing mypyc artifacts to prevent stale inclusion
+    if not use_mypyc:
+        for so_file in pyexecutor_dir.glob("*__mypyc*.so"):
+            print(f"Removing stale mypyc artifact: {so_file}")
+            so_file.unlink()
+
+        for so_file in scheduler_dir.glob("*.so"):
+            print(f"Removing stale artifact: {so_file}")
+            so_file.unlink()
+
+    if use_mypyc:
+        print("-- Building scheduler mypyc extensions...", end=" ")
+        setup_mypyc = scheduler_dir / "setup_mypyc.py"
+        build_run(f'"{venv_python}" "{setup_mypyc}" build_ext --inplace',
+                  cwd=pyexecutor_dir)
+
+        # Verify that the shared library was generated
+        if not list(pyexecutor_dir.glob("*__mypyc*.so")) and not list(
+                scheduler_dir.glob("*.so")):
+            raise RuntimeError(
+                "Failed to build scheduler: no shared library generated.")
+        print("Done")
+    print("-- Done building pyexecutor scheduler.")
+
+
 def main(*,
          build_type: str = "Release",
          generator: str = "",
@@ -969,6 +999,7 @@ def get_binding_lib(subdirectory, name):
                            binding_lib_file_name)
 
     build_kv_cache_manager_v2(project_dir, venv_python, use_mypyc=mypyc)
+    build_pyexecutor_scheduler(project_dir, venv_python, use_mypyc=mypyc)
 
     if not skip_building_wheel:
         if dist_dir is None:
diff --git a/setup.py b/setup.py
index b26af764f83f..94bff4080828 100644
--- a/setup.py
+++ b/setup.py
@@ -159,6 +159,9 @@ def has_ext_modules(self):
         'runtime/kv_cache_manager_v2/rawref/*.py',
         'runtime/kv_cache_manager_v2/rawref/*.pyi',
         'runtime/*__mypyc*.so',
+        '_torch/pyexecutor/scheduler/*.so',
+        '_torch/pyexecutor/scheduler/*.pyi',
+        '_torch/pyexecutor/*__mypyc*.so',
     ]
 
     package_data += [
@@ -372,7 +375,10 @@ def
extract_from_precompiled(precompiled_location: str, package_data: List[str], package_data = [ p for p in package_data if p not in [ 'runtime/kv_cache_manager_v2/*.so', - 'runtime/kv_cache_manager_v2/**/*.so', 'runtime/*__mypyc*.so' + 'runtime/kv_cache_manager_v2/**/*.so', + 'runtime/*__mypyc*.so', + '_torch/pyexecutor/scheduler/*.so', + '_torch/pyexecutor/*__mypyc*.so', ] ] # Ensure rawref is included diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index b5aeeb2a1f98..a5c30320728a 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -44,8 +44,8 @@ from .sampler import (EarlyStopSampler, EarlyStopWithMMResult, TorchSampler, TRTLLMSampler) from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler, - KVCacheV2Scheduler, SimpleScheduler, - SimpleUnifiedScheduler) + KVCacheV2Scheduler, ScheduleStepConfig, SimpleScheduler, + UnifiedScheduler) from .seq_slot_manager import SeqSlotManager GB = 1 << 30 @@ -1268,6 +1268,21 @@ def create_py_executor_instance( if scheduler_capacity == 1 and mapping.enable_attention_dp and kv_cache_manager: scheduler_capacity += 1 + adp_cfg = llm_args.attention_dp_config + adp_enable_balance = adp_cfg is not None and adp_cfg.enable_balance + schedule_step_config = ScheduleStepConfig( + enable_attention_dp=mapping.enable_attention_dp, + attention_dp_enable_balance=adp_enable_balance, + attention_dp_time_out_iters=adp_cfg.timeout_iters + if adp_enable_balance else 0, + attention_dp_batching_wait_iters=(adp_cfg.batching_wait_iters + if adp_enable_balance else 0), + batch_wait_timeout_iters=llm_args.batch_wait_timeout_iters, + batch_wait_max_tokens_ratio=llm_args.batch_wait_max_tokens_ratio, + max_batch_size=max_batch_size, + max_num_tokens=max_num_tokens, + ) + if isinstance(kv_cache_manager, KVCacheManagerV2): # V2: interleaved scheduler handles both capacity and budget draft_kv_cache_manager = resources.get( @@ -1285,31 +1300,45 @@ def create_py_executor_instance( if peft_cache_manager is not None else None, scheduler_capacity=scheduler_capacity, draft_kv_cache_manager=draft_kv_cache_manager, + schedule_step_config=schedule_step_config, + dist=dist, ) - elif (scheduler_config is not None - and scheduler_config.use_python_scheduler): - scheduler = SimpleUnifiedScheduler( - max_batch_size=max_batch_size, - max_num_tokens=max_num_tokens, - kv_cache_manager=kv_cache_manager.impl - if kv_cache_manager is not None else None, - peft_cache_manager=peft_cache_manager.impl - if peft_cache_manager is not None else None, - scheduler_policy=scheduler_config.capacity_scheduler_policy, - ctx_chunk_config=ctx_chunk_config, - two_step_lookahead=mapping.has_pp(), - scheduler_capacity=scheduler_capacity) else: - capacity_scheduler = BindCapacityScheduler( - scheduler_capacity, - kv_cache_manager.impl if kv_cache_manager is not None else None, - peft_cache_manager.impl if peft_cache_manager is not None else None, - scheduler_config.capacity_scheduler_policy, - two_step_lookahead=mapping.has_pp()) - - mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens, - ctx_chunk_config) - scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler) + use_python_scheduler = (scheduler_config.use_python_scheduler + if scheduler_config is not None else False) + if use_python_scheduler: + scheduler = UnifiedScheduler( + max_batch_size=max_batch_size, + max_num_tokens=max_num_tokens, + kv_cache_manager=kv_cache_manager.impl + if kv_cache_manager is not None else None, + 
peft_cache_manager=peft_cache_manager.impl + if peft_cache_manager is not None else None, + scheduler_policy=scheduler_config.capacity_scheduler_policy, + ctx_chunk_config=ctx_chunk_config, + two_step_lookahead=mapping.has_pp(), + scheduler_capacity=scheduler_capacity, + schedule_step_config=schedule_step_config, + dist=dist, + ) + else: + capacity_scheduler = BindCapacityScheduler( + scheduler_capacity, + kv_cache_manager.impl if kv_cache_manager is not None else None, + peft_cache_manager.impl + if peft_cache_manager is not None else None, + scheduler_config.capacity_scheduler_policy, + two_step_lookahead=mapping.has_pp()) + + mb_scheduler = BindMicroBatchScheduler(max_batch_size, + max_num_tokens, + ctx_chunk_config) + scheduler = SimpleScheduler( + capacity_scheduler, + mb_scheduler, + schedule_step_config=schedule_step_config, + dist=dist, + ) config = model_engine.model.model_config.pretrained_config attention_type = AttentionTypeCpp.MLA if is_mla( diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 45014d37febe..dad8e439780f 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -336,16 +336,7 @@ def __init__( self.stream_interval = self.llm_args.stream_interval self.perf_manager = PerfMetricsManager( enabled=getattr(self.llm_args, 'return_perf_metrics', False)) - self.attention_dp_enable_balance = ( - self.llm_args.attention_dp_config is not None - and self.llm_args.attention_dp_config.enable_balance) - if self.attention_dp_enable_balance: - self.attention_dp_time_out_iters = self.llm_args.attention_dp_config.timeout_iters - self.attention_dp_batching_wait_iters = self.llm_args.attention_dp_config.batching_wait_iters self.batch_wait_timeout_ms = self.llm_args.batch_wait_timeout_ms - self.batch_wait_timeout_iters = self.llm_args.batch_wait_timeout_iters - self.batch_wait_max_tokens_ratio = self.llm_args.batch_wait_max_tokens_ratio - self.enable_batch_waiting = self.batch_wait_timeout_iters > 0 or self.batch_wait_max_tokens_ratio > 0 self.num_fetch_requests_cur_rank = 0 self.num_fetch_requests = 0 @@ -451,9 +442,6 @@ def __init__( self.is_shutdown = False self.max_batch_size = max_batch_size - self.adp_ctx_waiting_iters_count = 0 - self.adp_ctx_batching_wait_iters_count = 0 - self.batch_wait_iters_count = 0 def on_detected(): self._handle_errors( @@ -1173,8 +1161,11 @@ def _pp_schedule_and_propagate(self, microbatch_id: int): is_dp_broadcast = self.dist.tp_size > 1 and self.enable_attention_dp if self.dist.rank == 0 or (self.dist.is_first_pp_rank and is_dp_broadcast): - scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( - ) + step_result = self.scheduler.schedule_step(self.active_requests, + self.inflight_req_ids) + scheduled_batch = step_result.scheduled_requests + fitting_disagg_gen_init_requests = step_result.fitting_disagg_gen_init_requests + num_fitting_reqs = step_result.num_fitting_requests serializable_schedule = SerializableSchedulerOutput.from_scheduler_result( scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs) @@ -1291,7 +1282,7 @@ def _executor_loop_pp(self): if self.dist.rank != 0: # Retry until current rank can run first PP's schedule result. self._pp_retry_until_can_schedule(scheduled_batch) - # Run scheduler locally because scheduler may change llm requests' state. + # Replay scheduler-local request state changes only. 
self.scheduler.schedule_request(self.active_requests, self.inflight_req_ids) @@ -1780,8 +1771,11 @@ def _prepare_and_schedule_batch(self): # that speculation is about to happen. self._prepare_draft_requests() - scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( - ) + step_result = self.scheduler.schedule_step(self.active_requests, + self.inflight_req_ids) + scheduled_batch = step_result.scheduled_requests + fitting_disagg_gen_init_requests = step_result.fitting_disagg_gen_init_requests + num_fitting_reqs = step_result.num_fitting_requests if self.drafter is not None and not self.use_spec_decode: for request in scheduled_batch.all_requests(): @@ -2662,108 +2656,6 @@ def _add_kv_cache_events(self): # to be transferred to main thread when user needs them. kv_cache_manager.flush_iteration_events() - def _balance_adp_requests(self, context_requests: list[LlmRequest], - generation_requests: list[LlmRequest]): - balanced_context_requests = context_requests - num_scheduled_context_requests = len(context_requests) - num_scheduled_generation_requests = len(generation_requests) - num_scheduled_tokens = sum( - [len(req.get_tokens(0)) - for req in context_requests]) + num_scheduled_generation_requests - # Note: We use tp_allgather instead of tp_cp_allgather because we want to - # balance the requests across DP ranks; not CP ranks within those DP ranks. - responses_list = self.dist.tp_allgather([ - num_scheduled_context_requests, num_scheduled_generation_requests, - num_scheduled_tokens - ]) - all_ranks_num_scheduled_context_requests = [ - response[0] for response in responses_list - ] - all_ranks_num_scheduled_generation_requests = [ - response[1] for response in responses_list - ] - all_ranks_have_free_ctx_slots = all([ - num_gen < self.max_batch_size - for num_gen in all_ranks_num_scheduled_generation_requests - ]) - all_ranks_have_ctx_requests = all([ - num_ctx > 0 for num_ctx in all_ranks_num_scheduled_context_requests - ]) - all_ranks_have_gen_requests = all([ - num_gen > 0 - for num_gen in all_ranks_num_scheduled_generation_requests - ]) - - if self.attention_dp_enable_balance: - # wait for all ranks have context requests - if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests: - self.adp_ctx_waiting_iters_count = 0 - # balance number of context requests across ranks - if all_ranks_have_gen_requests: - if self.adp_ctx_batching_wait_iters_count < self.attention_dp_batching_wait_iters: - self.adp_ctx_batching_wait_iters_count += 1 - balanced_context_requests = [] - else: - self.adp_ctx_batching_wait_iters_count = 0 - else: - self.adp_ctx_waiting_iters_count += 1 - balanced_context_requests = [] - timeout_reached = self.adp_ctx_waiting_iters_count >= self.attention_dp_time_out_iters - if timeout_reached or not all_ranks_have_gen_requests: - self.adp_ctx_waiting_iters_count = 0 - balanced_context_requests = context_requests - return balanced_context_requests - - def _waiting_requests(self, context_requests: list[LlmRequest], - generation_requests: list[LlmRequest]): - """ - Return an empty list if scheduled requests fulfill the waiting conditions, otherwise return the original context requests. - Waiting conditions: - - The number of scheduled tokens (both context and generation) is smaller than `self.batch_wait_max_tokens_ratio * self.max_num_tokens` - - The number of waiting iterations is smaller than `self.batch_wait_timeout_iters`. 
- """ - - num_scheduled_ctx_tokens = sum( - len(ctx_req.get_tokens(0)) for ctx_req in context_requests) - num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens - for gen_req in generation_requests) - num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens - - should_waiting = self.batch_wait_iters_count < self.batch_wait_timeout_iters and num_scheduled_tokens < self.batch_wait_max_tokens_ratio * self.max_num_tokens - if should_waiting: - self.batch_wait_iters_count += 1 - return [] - - self.batch_wait_iters_count = 0 - return context_requests - - @nvtx_range("_schedule") - def _schedule(self): - scheduler_output = self.scheduler.schedule_request( - self.active_requests, self.inflight_req_ids) - - scheduled_context_requests = scheduler_output.context_requests - if self.enable_attention_dp and self.attention_dp_enable_balance: - scheduled_context_requests = self._balance_adp_requests( - scheduler_output.context_requests, - scheduler_output.generation_requests) - - # If no generation requests, no need to wait, to avoid dead waiting - should_check_waiting = not self.enable_attention_dp and self.enable_batch_waiting and len( - scheduler_output.context_requests) > 0 and len( - scheduler_output.generation_requests) > 0 - if should_check_waiting: - scheduled_context_requests = self._waiting_requests( - scheduler_output.context_requests, - scheduler_output.generation_requests) - - scheduled_requests = ScheduledRequests() - scheduled_requests.reset_context_requests(scheduled_context_requests) - scheduled_requests.generation_requests = scheduler_output.generation_requests - scheduled_requests.paused_requests = scheduler_output.paused_requests - - return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests - @nvtx_range("_check_disagg_gen_transfer_status") def _check_disagg_gen_transfer_status(self): diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/__init__.py b/tensorrt_llm/_torch/pyexecutor/scheduler/__init__.py index b5e1a94361aa..4f57406a96d0 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler/__init__.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler/__init__.py @@ -20,24 +20,20 @@ - Waiting queues (FCFS) """ -# Re-export from scheduler.py +# Re-export from scheduler.py (interfaces and data structures) from .adp_router import ADPRouter, DefaultADPRouter, RankState from .scheduler import ( - BindCapacityScheduler, - BindMicroBatchScheduler, - CapacityScheduler, - MicroBatchScheduler, - PyCapacityScheduler, - PyMicroBatchScheduler, RequestList, RequestScheduler, ScheduledRequests, SchedulerOutput, + ScheduleStepConfig, + ScheduleStepResult, SerializableSchedulerOutput, - SimpleScheduler, - SimpleUnifiedScheduler, ) from .scheduler_v2 import KVCacheV2Scheduler +from .simple_scheduler import BindCapacityScheduler, BindMicroBatchScheduler, SimpleScheduler +from .unified_scheduler import PyCapacityScheduler, UnifiedScheduler # Re-export from waiting_queue.py from .waiting_queue import FCFSWaitingQueue, WaitingQueue, create_waiting_queue @@ -46,18 +42,17 @@ # Schedulers "BindCapacityScheduler", "BindMicroBatchScheduler", - "CapacityScheduler", "KVCacheV2Scheduler", - "MicroBatchScheduler", "PyCapacityScheduler", - "PyMicroBatchScheduler", "RequestList", "RequestScheduler", + "ScheduleStepConfig", + "ScheduleStepResult", "ScheduledRequests", "SchedulerOutput", "SerializableSchedulerOutput", "SimpleScheduler", - "SimpleUnifiedScheduler", + "UnifiedScheduler", # ADP "ADPRouter", "DefaultADPRouter", diff --git 
a/tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py b/tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py
index 344247ee18ca..fe2e30f6b481 100644
--- a/tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py
+++ b/tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py
@@ -20,10 +20,9 @@
 if TYPE_CHECKING:
     from tensorrt_llm._torch.distributed.communicator import Distributed
+    from tensorrt_llm._torch.pyexecutor.executor_request_queue import RequestQueueItem
     from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
-    from ..executor_request_queue import RequestQueueItem
-
 
 HeapVal = namedtuple("HeapVal",
                      ["num_tokens", "num_requests", "rank", "request_list"])
diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/design.md b/tensorrt_llm/_torch/pyexecutor/scheduler/design.md
new file mode 100644
index 000000000000..6e848eeff0bb
--- /dev/null
+++ b/tensorrt_llm/_torch/pyexecutor/scheduler/design.md
@@ -0,0 +1,437 @@
+# UnifiedScheduler Refactor: Design Document
+
+## 1. Background
+
+TensorRT-LLM has two scheduler implementations:
+
+- **SimpleScheduler** (C++ bindings): The default scheduler. Uses the C++ capacity and
+  micro-batch schedulers via the `BindCapacityScheduler` and `BindMicroBatchScheduler`
+  nanobind wrappers.
+- **UnifiedScheduler** (pure Python): A Python mirror of SimpleScheduler, introduced
+  for extensibility and experimentation. On the main branch it implements the same
+  two-pass structure as SimpleScheduler, but in Python.
+
+The original two-pass Python implementation was slower due to Python interpreter overhead
+and excessive Python→C++ boundary crossings. This refactor optimizes
+`UnifiedScheduler` with a fused single-pass design, keeping scheduling intent and
+major outputs aligned with SimpleScheduler; the intentional semantic differences are
+documented in Section 4.
+
+## 2. Optimizations
+
+### 2.1 Fused Single-Pass Scheduling
+
+**Old**: Two sequential passes — capacity first, then microbatch (token budget + chunking).
+
+```
+PyCapacityScheduler.schedule_request(active_requests)
+    → fitting_requests
+PyMicroBatchScheduler.schedule(fitting_requests, inflight_ids)
+    → context_requests, generation_requests
+```
+
+**New**: `TokenBudgetTracker` is passed into the capacity policy loop. Each request is
+checked for both KV-block capacity AND token budget in one iteration. Chunking and sorting
+are still performed in `tracker.finalize()`, but the separate microbatch iteration over
+`fitting_requests` is eliminated.
+
+**Impact**: Eliminates one full iteration over the fitting list.
+
+### 2.2 Batched Block Decrements
+
+**Old**: `decrement_reserved_blocks(req)` called per-request in the first-pass loop →
+O(N × W) C++ calls (N requests, W window sizes).
+
+**New**: Deferred to `batch_decrement_list(scheduled_requests)` after the loop →
+O(W) batched C++ calls using `get_remaining_blocks_to_completion_batch()`.
+
+**Correctness**: `available_blocks` is not read during the first pass. `sync_to_dict()`
+is called before the second pass starts.
+
+### 2.3 Preview/Commit Block Reservation
+
+**Old**: Second pass calls `enough_available_blocks(req)` then
+`decrement_reserved_blocks(req)` → 2 × W C++ calls per request.
+
+**New**: `preview_reserve(req)` checks AND caches the needed blocks (1 × W C++ calls).
+`commit_preview()` applies the cached decrement in pure Python.
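+
+A minimal sketch of the preview/commit pattern, for illustration only: the class,
+the field `_previewed`, and the getter `get_num_free_blocks` are assumptions, not
+the exact API of `unified_scheduler.py`; only `preview_reserve`, `commit_preview`,
+and `get_remaining_blocks_to_completion` are names used above.
+
+```python
+class PreviewCommitSketch:
+    """Illustrative check-and-cache reservation over per-window block counts."""
+
+    def __init__(self, kv_cache_manager, window_sizes):
+        self.kv_cache_manager = kv_cache_manager
+        # Python-side mirror of free blocks per attention-window size
+        # (hypothetical getter name).
+        self.available_blocks = {
+            w: kv_cache_manager.get_num_free_blocks(w) for w in window_sizes
+        }
+        self._previewed = None
+
+    def preview_reserve(self, req) -> bool:
+        # One C++ call per window: fetch the needed block counts once...
+        needed = {
+            w: self.kv_cache_manager.get_remaining_blocks_to_completion(req, w)
+            for w in self.available_blocks
+        }
+        if any(n > self.available_blocks[w] for w, n in needed.items()):
+            return False
+        self._previewed = needed  # ...and cache them for the commit.
+        return True
+
+    def commit_preview(self) -> None:
+        # Pure-Python decrement using the cached preview; no extra C++ call.
+        for w, n in self._previewed.items():
+            self.available_blocks[w] -= n
+        self._previewed = None
+```
+
+Callers pair the two calls, `if mgr.preview_reserve(req): mgr.commit_preview()`,
+halving the per-request C++ traffic relative to check-then-decrement.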
+
+### 2.4 Cached C++ Property Calls
+
+| Property | Old (per request) | New |
+|----------|------------------|-----|
+| `req.is_disagg_generation_init_state` | Called 2× (guard + elif) | Cached as `is_disagg` once |
+| `req.state_value` | Called each pass | Cached as `sv` once |
+| `req.is_generation_in_progress_state` | 1 C++ call | `sv == _gen_in_progress` (Python int compare) |
+
+### 2.5 Split Second-Pass Loops
+
+**Old**: Combined loop over `[disagg_requests, context_requests]` with per-request
+`is_disagg_generation_init_state` checks and routing.
+
+**New**: Two typed loops — the disagg loop skips `beneficial_to_skip` (it never applies to
+disagg) and routes directly to `fitting_disagg`; the context loop skips the disagg checks.
+
+### 2.6 Single-Window Fast Path
+
+`NoEvictScheduledBlocksManager` and `MaxUtilizationScheduledBlocksManager` detect
+the common single-window case and use scalar arithmetic instead of dict iteration.
+
+## 3. mypyc Compilation
+
+### 3.1 Overview
+
+`unified_scheduler.py` can be compiled with [mypyc](https://mypyc.readthedocs.io/) to
+produce a native C extension (`.so`), eliminating Python interpreter overhead (attribute
+lookups, frame creation, type dispatch) from the scheduling hot path.
+
+mypyc compilation is optional and controlled by the `--mypyc` flag in `build_wheel.py`.
+When not compiled, the module runs as normal Python.
+
+### 3.2 What Gets Compiled
+
+Only `unified_scheduler.py` is compiled — it contains all hot-path classes:
+- `TokenBudgetTracker`
+- `GuaranteedNoEvictPolicy`, `MaxUtilizationPolicy`
+- `NoEvictScheduledBlocksManager`, `MaxUtilizationScheduledBlocksManager`
+- `PyCapacityScheduler`
+- `UnifiedScheduler`
+
+Other scheduler files (`scheduler.py`, `adp_router.py`, `waiting_queue.py`) are thin
+wrappers or C++ bindings that don't benefit from compilation.
+
+### 3.3 Type Annotation Fixes for mypyc
+
+mypyc enforces type annotations at runtime (unlike CPython, which ignores them). Several
+annotations were widened for compatibility:
+
+| Change | Reason |
+|--------|--------|
+| `inflight_request_ids: set[int]` → `object = None` | Callers pass the C++ `ReqIdsSet` (a nanobind type), not a Python `set` |
+| `uniq_task_ids: set[int]` → `Optional[set[int]]` | Assigned `None` when PEFT is disabled |
+
+### 3.4 Build Integration
+
+```bash
+# Standalone build (from the pyexecutor/ directory):
+python scheduler/setup_mypyc.py build_ext --inplace
+
+# Via build_wheel.py:
+python scripts/build_wheel.py --mypyc
+```
+
+`build_wheel.py` calls `build_pyexecutor_scheduler()`, which invokes `setup_mypyc.py`.
+When `--mypyc` is not set, stale `.so` artifacts are cleaned up to prevent accidental
+use.
+
+### 3.5 Profiling mypyc-Compiled Code
+
+mypyc-compiled functions lack `__code__`, so `line_profiler` cannot hook them.
+
+## 4. Behavior Changes vs Main Branch
+
+### 4.1 Intentional Semantic Changes
+
+#### 4.1.1 Fused first-pass break produces a lighter resource state
+
+When the token budget is exhausted in the first pass, the fused path breaks out of the
+loop immediately. Requests after the break point — generation, context,
+and disagg alike — are never evaluated. This affects both
+`MaxUtilizationPolicy` (token failure returns `None` → break) and
+`GuaranteedNoEvictPolicy` (generation token failure → break out of the
+classification loop).
+
+Because the failing generation request is never admitted, it does not consume
+KV blocks, request slots, or PEFT pages.
+The second pass (in GuaranteedNoEvict) and downstream scheduling therefore
+see a **lighter resource state** than the old two-pass path, where capacity
+admitted all generation unconditionally and microbatch dropped the excess
+afterward.
+
+This produces two kinds of differences vs the old path:
+
+**a) `paused_requests` may have fewer entries (MaxUtilization only).**
+The old path could pause requests to make room for later requests that
+microbatch would then drop anyway — wasted work. The fused path avoids this.
+
+**b) Second-pass requests may see more available resources
+(GuaranteedNoEvict).** Because the failing generation request is never
+admitted, it does not consume KV blocks, PEFT pages, or token budget. The
+second pass — which processes both context and disagg-init requests — sees
+a lighter state than the old path. This can admit context or disagg that the
+old path would have blocked. (Disagg/context after the break point is still
+never reached — only those classified before the break benefit.)
+
+The two request types have different practical thresholds:
+- **Extra context** requires speculative decoding or beam search, where each
+  generation request consumes multiple tokens (e.g., `beam_width +
+  draft_tokens`), creating enough token headroom for context. With standard
+  beam=1 and no speculation, generation requests are 1 token each, leaving
+  near-zero headroom.
+- **Extra disagg-init** can happen with any configuration, because disagg
+  bypasses token accounting — it only needs KV blocks and PEFT pages. The
+  lighter KV/PEFT state from the unadmitted generation is sufficient.
+
+These differences all result in **equal or better token budget utilization**
+than the old path. The old path's behavior was an artifact of the two-pass
+ordering (capacity admits everything, microbatch iterates generation-first),
+not a deliberate scheduling priority.
+
+**Example — MaxUtilization pause avoidance (token_budget=100):**
+
+```
+Old two-pass pipeline:
+
+  Capacity (MaxUtil): iterates ALL requests, admits/pauses based on KV blocks only
+    → Request A: KV ok → admit
+    → Request B: KV ok → admit
+    → Request C: KV fail → pause older request, retry → admit
+    → Request D: KV ok → admit
+  Result: fitting_requests = [A, B, C, D], paused = [old_req]
+
+  Microbatch: iterates fitting_requests with token budget
+    → A: 30 tokens → ok (30/100)
+    → B: 80 tokens → 30+80=110 > 100 → break
+  Result: scheduled = [A], B/C/D dropped silently
+
+Fused single-pass pipeline:
+
+  Capacity + Token (MaxUtil): iterates with the fused check
+    → Request A: KV ok, 30 tokens ok → admit
+    → Request B: KV ok, 30+80=110 > 100 → token fail → None → BREAK
+    → Request C: NEVER REACHED
+    → Request D: NEVER REACHED
+  Result: fitting = [A], paused = []
+```
+
+Paused requests differ ([] vs [old_req]). Scheduled output is the same.
+
+**Example — GuaranteedNoEvict second pass benefits from lighter state:**
+
+Setup: speculative decoding with 7 draft tokens → each generation request
+consumes `beam_width(1) + draft_tokens(7) = 8 tokens`. Token budget = 100.
+active_requests = [Gen_1..Gen_12, Ctx_X(4 tokens, chunked), Disagg_Y, Gen_13, Disagg_Z].
+ +``` +Old two-pass pipeline: + + Capacity first pass: no token budget — classifies ALL requests + → Gen_1..Gen_12: generation → scheduled, blocks decremented (12 requests) + → Ctx_X: context → pending_requests + → Disagg_Y: disagg → pending_dis_gen_init + → Gen_13: generation → scheduled, blocks decremented + → Disagg_Z: disagg → pending_dis_gen_init + + batch_decrement_list([Gen_1..Gen_13]) → all 13 consume KV blocks + + Capacity second pass: evaluates pending against remaining blocks + → Disagg_Y: blocks ok (after 13 gen consumed) → fitting_disagg + → Disagg_Z: blocks ok → fitting_disagg + → Ctx_X: blocks ok → added to scheduled + Result: fittingRequests = [Gen_1..Gen_13, Ctx_X] + fitting_disagg = [Disagg_Y, Disagg_Z] + + Microbatch: iterates fittingRequests (generation-first order) + → Gen_1..Gen_12: 12×8 = 96 tokens → ok (96/100) + → Gen_13: 96+8 = 104 > 100 → break + → Ctx_X: NEVER REACHED + Result: scheduled = [Gen_1..Gen_12] + fitting_disagg = [Disagg_Y, Disagg_Z] (unchanged) + +Fused single-pass pipeline: + + First pass: token budget checked inline + → Gen_1..Gen_12: 96 tokens → admitted, blocks decremented + → Ctx_X: context → pending_requests (classified, not token-checked) + → Disagg_Y: disagg → pending_dis_gen_init (classified) + → Gen_13: 96+8 = 104 > 100 → break + → Disagg_Z: NEVER REACHED (after break point) + + batch_decrement_list([Gen_1..Gen_12]) → only 12 gen consume KV blocks + (Gen_13's blocks NOT consumed) + + Second pass: processes pending against remaining budget + lighter blocks + → Disagg_Y: blocks ok (lighter state) → fitting_disagg + → Ctx_X: 96+4 = 100 ≤ 100, blocks ok → admitted + Result: scheduled = [Gen_1..Gen_12, Ctx_X] + fitting_disagg = [Disagg_Y] +``` + +Differences vs the old path: +- Ctx_X admitted (100/100 tokens) vs dropped (96/100). Gen_13 not scheduled + in either path. +- Disagg_Y evaluated against lighter block state (Gen_13's blocks not + consumed). May fit where the old path would have blocked it. +- Disagg_Z deferred (after break point). Retried next iteration. + +All differences are benign — better utilization for context/disagg that fit, +deferred requests retry next iteration. Extra context requires speculative +decoding or beam search (needs token headroom from multi-token generation). +Extra disagg can happen with any configuration — disagg bypasses token +accounting and only needs the lighter KV/PEFT state. + +#### 4.1.2 `num_fitting_requests` semantics + +Now counts requests admitted by the fused capacity + token-budget path +(`TokenBudgetTracker._num_fitting`), which is **more accurate** than the old +value. In `SimpleScheduler`, `num_fitting_requests` was +`len(fitting_requests)` from the capacity pass only — it included requests +that capacity admitted but microbatch would later drop for exceeding the +token budget. The new count reflects requests that passed both KV-block +capacity AND token-budget checks. + +Note: this count is still computed before late pruning, so it can overcount +in two edge cases: + +1. **Chunking**: `_num_fitting` is incremented when `try_add_context()` + accepts a request, but `finalize()` may later drop requests with + `context_chunk_size == 0` without decrementing. +2. **Post-scheduler filters**: `py_executor._schedule()` passes + `num_fitting_requests` through unchanged after ADP balance or batch + waiting may have shrunk the context batch. 
+
+### 4.2 Bug Fixes vs Main
+
+#### 4.2.1 MaxUtilization PEFT page accumulation
+
+Fixes a pre-existing bug in main's Python `MaxUtilizationPolicy` where
+`num_scheduled_peft_pages` was passed by value to `_try_scheduling_request()`
+and never accumulated across requests. Every request saw
+`num_scheduled_peft_pages = 0`, so cumulative PEFT page limits were not
+enforced. The same bug exists in `scheduler.py` on main.
+
+The policy now returns the updated total from `_try_scheduling_request()`,
+matching the C++ reference (`capacityScheduler.cpp`,
+`trySchedulingRequestMaxUtilization`), which passes it by reference.
+`GuaranteedNoEvictPolicy` was already correct (it accumulates
+`claimed_peft_pages` locally).
+
+This can change `context_requests` and `generation_requests` vs main on
+workloads that use LoRA with MaxUtilization scheduling, because the old path
+would over-admit requests that exceed the cumulative PEFT page budget.
+
+### 4.3 Internal Refactoring (no external semantic change)
+
+| Change | Details |
+|--------|---------|
+| **Disagg request return path** | The capacity policy returns a 3-tuple `(scheduled, fitting_disagg, paused)` instead of a 2-tuple. `fitting_disagg` was already a separate output in `SchedulerOutput` — this is an internal plumbing change, not new external behavior. |
+| **Drop-in interface** | `UnifiedScheduler.schedule_request()` returns the same `SchedulerOutput` as `SimpleScheduler`. `py_executor.py` uses a single code path for both schedulers. |
+
+### 4.4 Preserved Behavior
+
+| Area | Why Equivalent |
+|------|----------------|
+| State range check | Same conditions: disagg bypasses the range, others check `_until <= sv < _after` |
+| Block reservation | Same check-then-decrement logic, batched/cached |
+| PEFT checks (`GuaranteedNoEvictPolicy`) | Identical to main (accumulates `claimed_peft_pages` locally) |
+| `beneficial_to_skip` | Disagg always skipped it (the old code had a `not req.is_disagg` guard) |
+| Context chunking | Same `EQUAL_PROGRESS` / `FIRST_COME_FIRST_SERVED` policies |
+| Request sorting | Same LoRA-based sort in `finalize()` |
+
+Note: `MaxUtilizationPolicy` PEFT behavior changed vs main — see Section 4.2.1.
+
+### 4.5 KV Allocation Semantics
+
+`prepare_resources()` runs on the final scheduled batch only — requests
+filtered by the token budget never allocate real KV blocks in either path.
+
+However, the fused path's lighter resource state (Section 4.1.1) means:
+- The main scheduled batch may contain additional context requests that the
+  old path would have dropped (GuaranteedNoEvict). KV allocation for these
+  extra requests is correct — they passed the KV block checks in the second pass.
+- `fitting_disagg_gen_init_requests` may differ. Those requests are fed into
+  `_prepare_disagg_gen_init()`, which prepares KV resources outside the main
+  `prepare_resources()` batch.
+
+## 5. Performance Results
+
+**Experiment setting**: Llama 8B, B200 single GPU, 411 scheduling iterations.
+Measured with the host profiler.
+
+| Configuration | Total | Per-Iteration | vs Main |
+|--------------|-------|---------------|---------|
+| main branch | 7.16s | 17.4ms | baseline |
+| Refactored (Python) | 4.33s | 10.5ms | **-39.6%** |
+| Refactored (mypyc)* | 1.19s | 2.89ms | **-83.4%** |
+
+\* The mypyc measurement covers `schedule_request` only (capacity + token budget
+scheduling). Fetch, validate, and drafter setup run in py_executor (same path
+for both schedulers).
+### Speedup Attribution (rough hypothesis, not precisely measured)
+
+| Source | Estimated Contribution | Mechanism |
+|--------|----------------------|-----------|
+| Eliminate the separate microbatch pass | Major | One fewer O(N) iteration; chunking/sorting still runs in `finalize()` |
+| Reduce C++ boundary crossings | Moderate | Caching, batching, preview/commit |
+| Python micro-optimizations | Minor | Local variable caching, int counters, `__slots__` |
+
+## 6. Files Changed
+
+| File | Change |
+|------|--------|
+| `scheduler/unified_scheduler.py` | Refactored TokenBudgetTracker, capacity policies, NoEvictScheduledBlocksManager, UnifiedScheduler |
+| `scheduler/scheduler.py` | Removed the old Python scheduling classes (moved to unified_scheduler.py) |
+| `pyexecutor/py_executor.py` | No scheduler-specific code paths — same `_prepare_and_schedule_batch()` for both |
+| `pyexecutor/_util.py` | Instantiation gate: `UnifiedScheduler` when `SchedulerConfig(use_python_scheduler=True)` |
+| `scheduler/setup_mypyc.py` | mypyc build script for compiling `unified_scheduler.py` to a native C extension |
+| `scheduler/mypy_mypyc.ini` | mypy configuration for mypyc compilation (error suppressions for external types) |
+| `scripts/build_wheel.py` | Added `build_pyexecutor_scheduler()` for mypyc integration via the `--mypyc` flag |
+| `tools/profiler/host_profile_tools/host_profiler.py` | Added the `TLLM_LINE_PROFILER_NO_DEFAULTS` env var to disable default profiler targets |
+
+## 7. Validation
+
+### Recommended correctness checks
+
+Compare scheduling outputs between `SimpleScheduler` (the default) and
+`UnifiedScheduler` on the same workload.
+
+**All policies — must match:**
+- Request ordering (verify that the LoRA sort and chunk partitioning match)
+- Key state transitions (requests entering/leaving the scheduled batch)
+
+**All policies — expected to differ:**
+- `num_fitting_requests`: now counts capacity + token-budget admissions, not
+  just capacity (Section 4.1.2)
+
+**GuaranteedNoEvict — must match:**
+- `len(generation_requests)`
+- `len(paused_requests)` (this policy does not pause)
+
+**GuaranteedNoEvict — expected to differ when the token budget is the bottleneck:**
+- `len(context_requests)`: may have more entries with speculative decoding or
+  beam search (needs the token headroom from the lighter state) (Section 4.1.1b)
+- `len(fitting_disagg_gen_init_requests)`: may have more entries with any
+  configuration (disagg bypasses token accounting and only needs the lighter
+  KV/PEFT state) (Section 4.1.1b). Requests after the break point are deferred
+  (fewer).
+**MaxUtilization — must match:**
+- `len(context_requests)`, `len(generation_requests)` (the single loop breaks;
+  there is no second pass to admit extra work)
+
+**MaxUtilization — expected to differ when the token budget is the bottleneck:**
+- `len(paused_requests)`: fewer — the fused path avoids the wasted
+  pause/backtrack (Section 4.1.1a)
+- `len(fitting_disagg_gen_init_requests)`: may have fewer entries — disagg
+  after the break point is deferred (Section 4.1.1b)
+
+**LoRA + MaxUtilization — expected to differ:**
+- `len(context_requests)`, `len(generation_requests)`: may differ due to the
+  PEFT page accumulation bug fix (Section 4.2.1) — the old path over-admitted
+  requests that exceed the cumulative PEFT page budget
+
+### Enable the refactored scheduler
+```python
+from tensorrt_llm.llmapi import LLM, SchedulerConfig
+
+llm = LLM(model, scheduler_config=SchedulerConfig(use_python_scheduler=True))
+```
+
+### Profile with trtllm-serve (no default profiler targets)
+```yaml
+# config.yaml
+scheduler_config:
+  use_python_scheduler: true
+```
+
+```bash
+TLLM_LINE_PROFILER_PATH=./profile.txt \
+TLLM_LINE_PROFILER_NO_DEFAULTS=1 \
+TLLM_LINE_PROFILER_FUNCTIONS="tensorrt_llm._torch.pyexecutor.scheduler.unified_scheduler.UnifiedScheduler.schedule_request" \
+trtllm-serve --config config.yaml
+```
diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/mypy_mypyc.ini b/tensorrt_llm/_torch/pyexecutor/scheduler/mypy_mypyc.ini
new file mode 100644
index 000000000000..f0d62173b524
--- /dev/null
+++ b/tensorrt_llm/_torch/pyexecutor/scheduler/mypy_mypyc.ini
@@ -0,0 +1,55 @@
+; SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+; SPDX-License-Identifier: Apache-2.0
+;
+; Licensed under the Apache License, Version 2.0 (the "License");
+; you may not use this file except in compliance with the License.
+; You may obtain a copy of the License at
+;
+;     http://www.apache.org/licenses/LICENSE-2.0
+;
+; Unless required by applicable law or agreed to in writing, software
+; distributed under the License is distributed on an "AS IS" BASIS,
+; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+; See the License for the specific language governing permissions and
+; limitations under the License.
+ +[mypy] +; Only check files explicitly listed - don't follow any imports +follow_imports = skip +follow_imports_for_stubs = False + +; Ignore all missing imports +ignore_missing_imports = True + +; Allow untyped code in dependencies +allow_untyped_calls = True +allow_untyped_defs = True +check_untyped_defs = False + +; Disable various warnings to reduce noise +warn_return_any = False +warn_unused_ignores = False +warn_unreachable = False +no_implicit_optional = False + +; Don't check .pyi files outside our target +exclude = (?x)( + ^(?!tensorrt_llm/_torch/pyexecutor/scheduler/) +) + +; Ignore errors in any imported modules +[mypy-tensorrt_llm.llmapi.*] +ignore_errors = True +follow_imports = skip + +[mypy-tensorrt_llm.bindings.*] +ignore_errors = True +follow_imports = skip + +[mypy-tensorrt_llm.logger.*] +ignore_errors = True +follow_imports = skip + +[mypy-strenum.*] +ignore_errors = True +follow_imports = skip diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler.py index 540a0788e2ff..d721fa9c1b72 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler.py @@ -1,20 +1,81 @@ -import dataclasses +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Scheduler interfaces, shared data structures, and helper utilities. + +This module defines the abstract base classes and data types used by all +scheduler implementations, along with small backend-agnostic helpers. 
+ +Implementations: + simple_scheduler.py — SimpleScheduler (C++ binding wrappers) + unified_scheduler.py — UnifiedScheduler (pure-Python) + scheduler_v2.py — KVCacheV2Scheduler +""" + from abc import ABC, abstractmethod from collections import namedtuple -from dataclasses import dataclass -from enum import Enum -from typing import Optional, Set +from dataclasses import dataclass, field +from typing import Optional -from strenum import StrEnum +from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest +from tensorrt_llm._utils import nvtx_range -from tensorrt_llm.bindings import internal as tb_internal -from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy -from tensorrt_llm.logger import logger +RequestList = list[LlmRequest] -# Assuming these imports exist in your environment -from ..llm_request import LlmRequest, LlmRequestState -RequestList = list[LlmRequest] +def sort_requests_by_lora( + context_requests: list[LlmRequest], + generation_requests: list[LlmRequest], + chunks_present: bool, +) -> None: + def sort_key(req: LlmRequest) -> tuple[int, int]: + lora_id = getattr(req, "lora_task_id", None) + if lora_id is None: + return (0, 0) + return (1, lora_id) + + if chunks_present: + not_last_chunk = [req for req in context_requests if not req.is_last_context_chunk] + last_chunk = [req for req in context_requests if req.is_last_context_chunk] + not_last_chunk.sort(key=sort_key) + last_chunk.sort(key=sort_key) + context_requests.clear() + context_requests.extend(not_last_chunk) + context_requests.extend(last_chunk) + else: + context_requests.sort(key=sort_key) + + generation_requests.sort(key=sort_key) + + +def compute_fcfs_context_chunk_size( + context_remaining_length: int, + capacity: Optional[int], + max_context_length: Optional[int], + unit_size: int, +) -> int: + actual_size = context_remaining_length + if capacity is not None and capacity < actual_size: + actual_size = capacity + if max_context_length is not None: + actual_size = min(max_context_length, actual_size) + if actual_size <= 0: + return 0 + if actual_size < context_remaining_length: + actual_size = (int(actual_size) // unit_size) * unit_size + return int(actual_size) + SchedulerOutput = namedtuple( "SchedulerOutput", @@ -100,7 +161,37 @@ def reset_context_requests(self, context_requests: RequestList | None = None) -> self.append_context_request(req) +@dataclass +class ScheduleStepConfig: + """Configuration for executor-facing scheduler step post-processing.""" + + enable_attention_dp: bool = False + attention_dp_enable_balance: bool = False + attention_dp_time_out_iters: int = 0 + attention_dp_batching_wait_iters: int = 0 + batch_wait_timeout_iters: int = 0 + batch_wait_max_tokens_ratio: float = 0.0 + max_batch_size: int = 0 + max_num_tokens: int = 0 + + +@dataclass +class ScheduleStepResult: + """Finalized scheduling result consumed by the executor.""" + + scheduled_requests: ScheduledRequests = field(default_factory=ScheduledRequests) + fitting_disagg_gen_init_requests: RequestList = field(default_factory=list) + num_fitting_requests: int = 0 + + class RequestScheduler(ABC): + def __init__(self, schedule_step_config: Optional[ScheduleStepConfig] = None, dist=None): + self._schedule_step_config = schedule_step_config or ScheduleStepConfig() + self._dist = dist + self._adp_ctx_waiting_iters_count = 0 + self._adp_ctx_batching_wait_iters_count = 0 + self._batch_wait_iters_count = 0 + @abstractmethod def schedule_request( self, active_requests: RequestList, inflight_request_ids: set[int] @@ -110,7 +201,6 @@ 
def schedule_request( :param inflight_request_ids: set of request ids that are inflight (of all micro batches) :return: SchedulerOutput """ - # to be aligned with RequestScheduler::scheduleRequests in cpp/tensorrt_llm/batch_manager/requestScheduler.h raise NotImplementedError @abstractmethod @@ -122,6 +212,116 @@ def can_schedule(self, requests: RequestList) -> bool: """ raise NotImplementedError + @nvtx_range("_schedule") + def schedule_step( + self, active_requests: RequestList, inflight_request_ids: set[int] + ) -> ScheduleStepResult: + scheduler_output = self.schedule_request(active_requests, inflight_request_ids) + + scheduled_context_requests = scheduler_output.context_requests + cfg = self._schedule_step_config + + if cfg.enable_attention_dp and cfg.attention_dp_enable_balance: + scheduled_context_requests = self._balance_adp_context_requests( + scheduler_output.context_requests, + scheduler_output.generation_requests, + ) + + enable_batch_waiting = ( + cfg.batch_wait_timeout_iters > 0 or cfg.batch_wait_max_tokens_ratio > 0 + ) + should_check_waiting = ( + not cfg.enable_attention_dp + and enable_batch_waiting + and len(scheduler_output.context_requests) > 0 + and len(scheduler_output.generation_requests) > 0 + ) + if should_check_waiting: + scheduled_context_requests = self._apply_batch_waiting( + scheduled_context_requests, scheduler_output.generation_requests + ) + + scheduled_requests = ScheduledRequests() + scheduled_requests.reset_context_requests(scheduled_context_requests) + scheduled_requests.generation_requests = scheduler_output.generation_requests + scheduled_requests.paused_requests = scheduler_output.paused_requests + + return ScheduleStepResult( + scheduled_requests=scheduled_requests, + fitting_disagg_gen_init_requests=scheduler_output.fitting_disagg_gen_init_requests, + num_fitting_requests=scheduler_output.num_fitting_requests, + ) + + def _balance_adp_context_requests( + self, context_requests: RequestList, generation_requests: RequestList + ) -> RequestList: + if self._dist is None: + raise RuntimeError( + "RequestScheduler.schedule_step requires dist for attention-dp balancing" + ) + + cfg = self._schedule_step_config + balanced_context_requests = context_requests + num_scheduled_context_requests = len(context_requests) + num_scheduled_generation_requests = len(generation_requests) + responses_list = self._dist.tp_allgather( + [ + num_scheduled_context_requests, + num_scheduled_generation_requests, + ] + ) + all_ranks_num_scheduled_context_requests = [response[0] for response in responses_list] + all_ranks_num_scheduled_generation_requests = [response[1] for response in responses_list] + all_ranks_have_free_ctx_slots = all( + num_gen < cfg.max_batch_size for num_gen in all_ranks_num_scheduled_generation_requests + ) + all_ranks_have_ctx_requests = all( + num_ctx > 0 for num_ctx in all_ranks_num_scheduled_context_requests + ) + all_ranks_have_gen_requests = all( + num_gen > 0 for num_gen in all_ranks_num_scheduled_generation_requests + ) + + if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests: + self._adp_ctx_waiting_iters_count = 0 + if all_ranks_have_gen_requests: + if self._adp_ctx_batching_wait_iters_count < cfg.attention_dp_batching_wait_iters: + self._adp_ctx_batching_wait_iters_count += 1 + balanced_context_requests = [] + else: + self._adp_ctx_batching_wait_iters_count = 0 + else: + self._adp_ctx_waiting_iters_count += 1 + balanced_context_requests = [] + timeout_reached = self._adp_ctx_waiting_iters_count >= cfg.attention_dp_time_out_iters + 
if timeout_reached or not all_ranks_have_gen_requests: + self._adp_ctx_waiting_iters_count = 0 + balanced_context_requests = context_requests + + return balanced_context_requests + + def _apply_batch_waiting( + self, context_requests: RequestList, generation_requests: RequestList + ) -> RequestList: + cfg = self._schedule_step_config + num_scheduled_ctx_tokens = self._get_num_scheduled_context_tokens(context_requests) + num_scheduled_gen_tokens = sum(1 + req.num_draft_tokens for req in generation_requests) + num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens + + should_waiting = ( + self._batch_wait_iters_count < cfg.batch_wait_timeout_iters + and num_scheduled_tokens < cfg.batch_wait_max_tokens_ratio * cfg.max_num_tokens + ) + if should_waiting: + self._batch_wait_iters_count += 1 + return [] + + self._batch_wait_iters_count = 0 + return context_requests + + def _get_num_scheduled_context_tokens(self, context_requests: RequestList) -> int: + return sum(len(req.get_tokens(0)) for req in context_requests) + @dataclass class SerializableSchedulerOutput: @@ -184,1197 +384,3 @@ def to_scheduler_result( id_to_request[req_id] for req_id in self.fitting_disagg_gen_init_requests ] return scheduled_requests, fitting_disagg_gen_init_requests, self.num_fitting_requests - - -class CapacityScheduler(ABC): - @abstractmethod - def schedule_request( - self, active_requests: RequestList - ) -> tuple[list[LlmRequest], list[LlmRequest], list[LlmRequest]]: - """ - :param active_requests: list of active requests, up to maximum number of sequences - :return: (scheduledRequests, pausedRequests) - """ - # to be aligned with CapacityScheduler::scheduleRequests in cpp/tensorrt_llm/batch_manager/capacityScheduler.h - raise NotImplementedError - - -class BindCapacityScheduler(CapacityScheduler): - def __init__( - self, - max_num_requests: int, - kv_cache_manager, - peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None, - scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, - two_step_lookahead: bool = False, - ): - super(BindCapacityScheduler, self).__init__() - self.kv_cache_manager = kv_cache_manager - self.peft_cache_manager = peft_cache_manager - - self.impl = tb_internal.algorithms.CapacityScheduler( - max_num_requests=max_num_requests, - capacity_scheduler_policy=scheduler_policy._to_pybind(), - has_kv_cache_manager=kv_cache_manager is not None, - two_step_lookahead=two_step_lookahead, - no_schedule_until_state=LlmRequestState.CONTEXT_INIT, - no_schedule_after_state=LlmRequestState.GENERATION_COMPLETE, - ) - - def schedule_request( - self, active_requests: RequestList - ) -> tuple[list[LlmRequest], list[LlmRequest], list[LlmRequest]]: - return self.impl(active_requests, self.kv_cache_manager, self.peft_cache_manager) - - -class MicroBatchScheduler(ABC): - @abstractmethod - def schedule( - self, active_requests: RequestList, inflight_request_ids: set[int] - ) -> tuple[list[LlmRequest], list[LlmRequest]]: - """ - :param active_requests: list of active requests, up to maximum number of sequences - :param inflight_request_ids: set of request ids that are inflight (of all micro batches) - :return: (contextRequests, generationRequests) - """ - # to be aligned with MicroBatchScheduler::scheduleRequests - # in cpp/tensorrt_llm/batch_manager/microBatchScheduler.h - raise NotImplementedError - - -class BindMicroBatchScheduler(MicroBatchScheduler): - def __init__( - self, - max_batch_size: int, - max_num_tokens: int = None, - ctx_chunk_config: 
Optional[tuple[StrEnum, int]] = None, - ) -> None: - super(BindMicroBatchScheduler, self).__init__() - self.max_batch_size = max_batch_size - self.max_num_tokens = max_num_tokens - - ctx_chunk_config_cpp = None - if ctx_chunk_config is not None: - ctx_chunk_config_cpp = tb_internal.batch_manager.ContextChunkingConfig( - ctx_chunk_config[0]._to_pybind(), ctx_chunk_config[1] - ) - - self.impl = tb_internal.algorithms.MicroBatchScheduler(ctx_chunk_config_cpp, max_num_tokens) - - def schedule( - self, active_requests: RequestList, inflight_request_ids: set[int] - ) -> tuple[list[LlmRequest], list[LlmRequest]]: - return self.impl( - active_requests, inflight_request_ids, self.max_batch_size, self.max_num_tokens - ) - - -class SimpleScheduler(RequestScheduler): - def __init__( - self, capacity_scheduler: CapacityScheduler, micro_batch_scheduler: MicroBatchScheduler - ): - super(SimpleScheduler, self).__init__() - self.capacity_scheduler = capacity_scheduler - self.micro_batch_scheduler = micro_batch_scheduler - - def schedule_request( - self, active_requests: RequestList, inflight_request_ids: set[int] - ) -> SchedulerOutput: - fitting_requests, fitting_disagg_gen_init_requests, paused_requests = ( - self.capacity_scheduler.schedule_request(active_requests) - ) - - context_requests, generation_requests = self.micro_batch_scheduler.schedule( - fitting_requests, inflight_request_ids - ) - # Convert from binding type RequestVector to list[LlmRequest], - # so Python fields on LlmRequest won't be stripped away - return SchedulerOutput( - list(context_requests), - list(generation_requests), - list(paused_requests), - list(fitting_disagg_gen_init_requests), - len(fitting_requests), - ) - - def can_schedule(self, requests: RequestList) -> bool: - fitting_requests, _, _ = self.capacity_scheduler.schedule_request(requests) - return len(fitting_requests) == len(requests) - - -class ChunkingPolicy(Enum): - EQUAL_PROGRESS = 1 - FIRST_COME_FIRST_SERVED = 2 - - -@dataclasses.dataclass -class ContextChunkingConfig: - chunking_policy: ChunkingPolicy - chunk_unit_size: int - - -class MicroBatchScheduler: - """Base class to match structure.""" - - -class PyMicroBatchScheduler(MicroBatchScheduler): - def __init__( - self, - max_batch_size: int, - max_num_tokens: Optional[int] = None, - ctx_chunk_config: Optional[ContextChunkingConfig] = None, - no_schedule_until_state: LlmRequestState = LlmRequestState.CONTEXT_INIT, - no_schedule_after_state: LlmRequestState = LlmRequestState.GENERATION_TO_COMPLETE, - ): - super().__init__() - self.max_batch_size = max_batch_size - self.max_num_tokens = max_num_tokens - self.ctx_chunk_config = ctx_chunk_config - self.max_context_length = max_num_tokens - # Match C++ MicroBatchScheduler defaults (see algorithms.cpp line 68-70) - self.no_schedule_until_state = no_schedule_until_state - self.no_schedule_after_state = no_schedule_after_state - # Cache state values to avoid repeated .value access (optimization) - self._no_schedule_until_state_value = no_schedule_until_state.value - self._no_schedule_after_state_value = no_schedule_after_state.value - self._context_init_state_value = LlmRequestState.CONTEXT_INIT.value - self._encoder_init_state_value = LlmRequestState.ENCODER_INIT.value - - def _can_be_scheduled(self, req: LlmRequest) -> bool: - """ - Check if request is within the schedulable state range. 
- C++ reference: microBatchScheduler.cpp line 192-195 - Optimized: use state_value property to avoid enum object creation - """ - # Use state_value property (returns int directly, avoids enum object creation) - state_value = req.state_value - # Inline comparison: must have reached until_state but not after_state - return ( - state_value >= self._no_schedule_until_state_value - and state_value < self._no_schedule_after_state_value - ) - - def schedule( - self, active_requests: RequestList, inflight_request_ids: set[int] - ) -> tuple[RequestList, RequestList]: - context_requests: RequestList = [] - generation_requests: RequestList = [] - - # Current total tokens in the scheduled batch (Generation + Context) - batch_num_tokens = 0 - scheduled_req_size = 0 - scheduled_beam_width = 0 - - contexts_to_be_chunked: RequestList = [] - # Total tokens required by chunked requests (calculated tentatively) - num_chunked_tokens = 0 - all_context_requests_fit = True - - # Cache instance attributes as locals for faster access in loop - max_batch_size = self.max_batch_size - max_num_tokens = self.max_num_tokens - max_context_length = self.max_context_length - ctx_chunk_config = self.ctx_chunk_config - - # 1. Main Scheduling Loop - for req in active_requests: - req_state_value = req.state_value - # Skip requests already in flight (should be filtered by caller, but C++ checks) - if req.request_id in inflight_request_ids: - continue - - # Skip if request cannot be scheduled yet or should no longer be scheduled, - # manually inline the condition to reuse req.state_value - if not ( - req_state_value >= self._no_schedule_until_state_value - and req_state_value < self._no_schedule_after_state_value - ): - continue - - req_num_tokens = 0 - - # --- A. Encoder Request Handling --- - if req_state_value == self._encoder_init_state_value: - req_num_tokens = req.encoder_output_len - - assert max_context_length is None or req_num_tokens <= max_context_length, ( - f"The number of encoder tokens ({req_num_tokens}) exceeds the limit value ({max_context_length})" - ) - - if max_num_tokens is not None and ( - batch_num_tokens + req_num_tokens > max_num_tokens - ): - break - - logger.debug(f"encoder request scheduled: ID {req.request_id}") - context_requests.append(req) - batch_num_tokens += req_num_tokens - - # --- B. 
Context Request Handling --- - elif req_state_value == self._context_init_state_value: - if not ctx_chunk_config: - # No Chunking: Schedule full context - # C++ uses getNumTokens(beam=0) which is tokens.size() - numPreDecodedTokens - base_tokens = req.get_num_tokens(0) - draft_tokens = req.num_draft_tokens if req.has_draft_tokens else 0 - req_num_tokens = base_tokens + draft_tokens - - assert max_context_length is None or req_num_tokens <= max_context_length, ( - f"Context tokens ({req_num_tokens}) exceeds limit ({max_context_length})" - ) - - if max_num_tokens is not None and ( - batch_num_tokens + req_num_tokens > max_num_tokens - ): - break - - logger.debug(f"context request scheduled: ID {req.request_id}") - context_requests.append(req) - batch_num_tokens += req_num_tokens - else: - # Chunking Enabled: Tentative schedule - req.context_chunk_size = req.context_remaining_length - - draft_tokens = ( - req.num_draft_tokens - if (req.is_last_context_chunk and req.has_draft_tokens) - else 0 - ) - req_num_tokens = req.context_chunk_size + draft_tokens - - if max_context_length is not None: - if max_context_length < req_num_tokens: - req_num_tokens = max_context_length - all_context_requests_fit = False - - logger.debug(f"contexts-to-be-chunked request scheduled: ID {req.request_id}") - contexts_to_be_chunked.append(req) - num_chunked_tokens += req_num_tokens - - # --- C. Generation Request Handling --- - else: - # C++ uses getBeamWidthByIter() which returns dynamic beam width - # during beam search (1->2->3->...->beamWidth) - beam_width = req.get_beam_width_by_iter(for_next_iteration=False) - req_num_tokens = beam_width + req.num_draft_tokens - - if max_num_tokens is not None and ( - batch_num_tokens + req_num_tokens > max_num_tokens - ): - break - - # Beam Width Consistency Check - if scheduled_beam_width == 0: - scheduled_beam_width = beam_width - elif scheduled_beam_width != beam_width: - logger.debug( - f"generation request skipped: ID {req.request_id} since its " - f"beam width ({beam_width}) is different from scheduled ones " - f"({scheduled_beam_width})" - ) - continue - generation_requests.append(req) - batch_num_tokens += req_num_tokens - - # --- Batch Size Limit Check --- - scheduled_req_size += 1 - if scheduled_req_size >= max_batch_size: - break - - # 2. Verify Chunking Fits - if max_num_tokens is not None and num_chunked_tokens > (max_num_tokens - batch_num_tokens): - all_context_requests_fit = False - - # 3. Apply Chunking Strategy if needed - if not all_context_requests_fit and contexts_to_be_chunked: - assert ctx_chunk_config is not None, ( - "If chunking is not enabled, context scheduling should be completed." - ) - remaining_capacity = ( - (max_num_tokens - batch_num_tokens) if max_num_tokens is not None else None - ) - - self._set_ctx_requests_chunk_size(contexts_to_be_chunked, remaining_capacity) - - # 4. 
Finalize Chunked Requests - for req in contexts_to_be_chunked: - if req.context_chunk_size > 0: - context_requests.append(req) - batch_num_tokens += req.context_chunk_size - logger.debug( - f"context request scheduled: ID {req.request_id}, " - f"chunk size {req.context_chunk_size}" - ) - - # Sort requests for consistency with C++ - # C++ reference: utils::sortRequests in inflightBatchingUtils.cpp - self._sort_requests(context_requests, generation_requests, not all_context_requests_fit) - - # Summary logs - logger.debug( - f"batchSize (num ctx/enc requests + num gen requests): " - f"{len(context_requests) + len(generation_requests)}" - ) - logger.debug(f"batchNumTokens / maxNumTokens: {batch_num_tokens} / {max_num_tokens or 0}") - - return context_requests, generation_requests - - def _sort_requests( - self, context_requests: RequestList, generation_requests: RequestList, chunks_present: bool - ) -> None: - """ - Sort requests for consistency with C++. - C++ reference: utils::sortRequests in inflightBatchingUtils.cpp - - 1. If chunks are present, move context requests that reached the last - context chunk to the end of the vector. - 2. Sort all requests by lora task id for performance. - """ - - def get_lora_task_id(req: LlmRequest): - # C++ uses std::optional comparison where nullopt < any_value - # So requests without LoRA (nullopt) should come first - lora_id = getattr(req, "lora_task_id", None) - if lora_id is None: - return (0, 0) # (has_value=False, value=0) - comes first - return (1, lora_id) # (has_value=True, value) - sorted by value - - if chunks_present: - # Partition: non-last-chunk first, last-chunk at end - not_last_chunk = [r for r in context_requests if not r.is_last_context_chunk] - last_chunk = [r for r in context_requests if r.is_last_context_chunk] - # Sort each group by lora_task_id - not_last_chunk.sort(key=get_lora_task_id) - last_chunk.sort(key=get_lora_task_id) - # Rebuild the list in-place - context_requests.clear() - context_requests.extend(not_last_chunk) - context_requests.extend(last_chunk) - else: - context_requests.sort(key=get_lora_task_id) - - generation_requests.sort(key=get_lora_task_id) - - def _set_ctx_requests_chunk_size(self, requests: RequestList, capacity: Optional[int]): - # C++: Resets all chunk sizes to 0 at start - for req in requests: - req.context_chunk_size = 0 - - policy = self.ctx_chunk_config.chunking_policy - unit_size = self.ctx_chunk_config.chunk_unit_size - - if policy == ChunkingPolicy.EQUAL_PROGRESS: - self._chunk_equal_progress(requests, capacity, unit_size) - elif policy == ChunkingPolicy.FIRST_COME_FIRST_SERVED: - self._chunk_fcfs(requests, capacity, unit_size) - else: - raise ValueError(f"Invalid chunking policy: {policy}") - - self._fit_draft_tokens(requests, capacity, unit_size) - - def _chunk_equal_progress(self, requests: RequestList, capacity: Optional[int], unit_size: int): - num_ctx_tokens = 0 - num_tokens_single_loop = 1 - - # C++ Loop: while ((!capacity || numCtxTokens < capacity) && numTokensSingleLoop) - while (capacity is None or num_ctx_tokens < capacity) and num_tokens_single_loop > 0: - num_tokens_single_loop = 0 - for req in requests: - past_size = req.context_chunk_size - - # C++ logic: suggested = past + unit - suggested_size = past_size + unit_size - - # Ensure we don't exceed what the request actually needs - remaining_total = req.context_remaining_length - suggested_size = min(suggested_size, remaining_total) - - req.context_chunk_size = suggested_size - - actual_size = req.context_chunk_size - 
actual_increment = actual_size - past_size - - # Check Constraints - # 1. Capacity - if capacity is not None and (num_ctx_tokens + actual_increment > capacity): - req.context_chunk_size = past_size # Revert - continue - - # 2. Max Context Length - if self.max_context_length is not None and actual_size > self.max_context_length: - req.context_chunk_size = past_size # Revert - continue - - num_ctx_tokens += actual_increment - num_tokens_single_loop += actual_increment - - def _chunk_fcfs(self, requests: RequestList, capacity: Optional[int], unit_size: int): - current_capacity = capacity if capacity is not None else float("inf") - - for req in requests: - suggested_size = req.context_remaining_length - actual_size = suggested_size - - if current_capacity < actual_size: - actual_size = current_capacity - - if self.max_context_length is not None: - actual_size = min(self.max_context_length, actual_size) - - # Round down to unit size if we had to truncate - if actual_size < suggested_size: - actual_size = (int(actual_size) // unit_size) * unit_size - - req.context_chunk_size = int(actual_size) - - # C++: ctxTokensCapacity = ctxTokensCapacity - actualChunkSize - if capacity is not None: - current_capacity -= req.context_chunk_size - - def _fit_draft_tokens(self, requests: RequestList, capacity: Optional[int], unit_size: int): - # Calculate tokens already taken by the batch so far - num_ctx_tokens = sum(req.context_chunk_size for req in requests) - - for req in requests: - if req.is_last_context_chunk and req.has_draft_tokens: - remainder = req.context_chunk_size % unit_size - remaining_space = 0 if remainder == 0 else unit_size - remainder - - if self.max_context_length is not None: - remaining_context_len = self.max_context_length - req.context_chunk_size - remaining_space = min(remaining_space, remaining_context_len) - - if capacity is not None: - remaining_space = min(remaining_space, capacity - num_ctx_tokens) - num_ctx_tokens += remaining_space - - draft_discard = req.num_draft_tokens - remaining_space - if draft_discard > 0: - logger.debug(f"Discarding {draft_discard} draft tokens") - req.discard_draft_tokens(draft_discard) - - -class SchedulerPolicyBase(ABC): - """ - Abstract base class for capacity scheduler policies. - Each policy implements its own scheduling logic. - """ - - @abstractmethod - def schedule( - self, scheduler: "PyCapacityScheduler", active_requests: RequestList - ) -> tuple[RequestList, RequestList]: - """ - Schedule requests according to the policy. - - Args: - scheduler: The capacity scheduler instance (for accessing shared state) - active_requests: List of active requests to schedule - - Returns: - Tuple of (scheduled_requests, paused_requests) - """ - raise NotImplementedError - - -class MaxRequestsPolicy(SchedulerPolicyBase): - """ - MaxRequestsScheduler: Simple request count limiting without KV cache checks. 
- C++ reference: capacityScheduler.cpp:154-176 - """ - - def schedule( - self, scheduler: "PyCapacityScheduler", active_requests: RequestList - ) -> tuple[RequestList, RequestList]: - scheduled_requests: RequestList = [] - - for req in active_requests: - if not scheduler._can_be_scheduled(req): - continue - - if len(scheduled_requests) >= scheduler.max_num_requests: - break - - if ( - req.is_encoder_init_state - or req.is_context_init_state - or req.is_generation_in_progress_state - ): - scheduled_requests.append(req) - - return scheduled_requests, [] - - -class GuaranteedNoEvictPolicy(SchedulerPolicyBase): - """ - GuaranteedNoEvictScheduler: Reserve blocks for requests to complete without eviction. - C++ reference: capacityScheduler.cpp:194-331 - """ - - def __init__(self, static_batch: bool = False): - self.static_batch = static_batch - - def schedule( - self, scheduler: "PyCapacityScheduler", active_requests: RequestList - ) -> tuple[RequestList, RequestList]: - scheduled_requests: RequestList = [] - has_peft = scheduler.peft_cache_manager is not None - - skipping_is_relevant = scheduler._is_skipping_relevant() - - newly_contributed_context_blocks: Set = set() - newly_contributed_cross_context_blocks: Set = set() - if not self.static_batch and skipping_is_relevant: - newly_contributed_context_blocks, newly_contributed_cross_context_blocks = ( - scheduler._prefill_contributed_blocks(active_requests) - ) - - reserved_blocks = NoEvictScheduledBlocksManager(scheduler.kv_cache_manager) - reserved_cross_blocks: Optional[NoEvictScheduledBlocksManager] = None - if scheduler.cross_kv_cache_manager is not None: - reserved_cross_blocks = NoEvictScheduledBlocksManager(scheduler.cross_kv_cache_manager) - - # PEFT state - only used when has_peft - claimed_peft_pages = 0 - available_peft_pages = scheduler._get_max_peft_pages() if has_peft else 0 - uniq_task_ids: set[int] = set() if has_peft else None - - pending_requests: RequestList = [] - pending_dis_gen_init_requests: RequestList = [] - - # First pass: process in-progress generation and classify requests - for req in active_requests: - if not scheduler._can_be_scheduled_with_disagg_exception(req): - continue - - if len(scheduled_requests) >= scheduler.max_num_requests: - break - - if req.is_generation_in_progress_state: - scheduled_requests.append(req) - reserved_blocks.decrement_reserved_blocks(req) - if reserved_cross_blocks is not None: - reserved_cross_blocks.decrement_reserved_blocks(req) - - if has_peft: - lora_task_id, is_new_task, peft_pages = scheduler._get_peft_task_info( - req, uniq_task_ids - ) - if is_new_task: - claimed_peft_pages += peft_pages - uniq_task_ids.add(lora_task_id) - - elif req.is_disagg_generation_init_state: - pending_dis_gen_init_requests.append(req) - else: - pending_requests.append(req) - - # Second pass: process pending requests - if not self.static_batch or len(scheduled_requests) == 0: - if has_peft: - available_peft_pages -= claimed_peft_pages - - for requests in [pending_dis_gen_init_requests, pending_requests]: - for req in requests: - if ( - not self.static_batch - and skipping_is_relevant - and not req.is_disagg_generation_init_state - and scheduler._beneficial_to_skip( - req, - newly_contributed_context_blocks, - newly_contributed_cross_context_blocks, - ) - ): - continue - - if len(scheduled_requests) >= scheduler.max_num_requests: - break - - if req.is_context_init_state or req.is_disagg_generation_init_state: - enough_blocks = reserved_blocks.enough_available_blocks(req) - enough_cross_blocks = True - if 
reserved_cross_blocks is not None: - enough_cross_blocks = reserved_cross_blocks.enough_available_blocks(req) - - if not enough_blocks or not enough_cross_blocks: - break - - # PEFT check only when needed - if has_peft: - lora_task_id, is_new_task, needed_peft_pages = ( - scheduler._get_peft_task_info(req, uniq_task_ids) - ) - if needed_peft_pages > available_peft_pages: - continue - available_peft_pages -= needed_peft_pages - if is_new_task: - uniq_task_ids.add(lora_task_id) - - scheduled_requests.append(req) - reserved_blocks.decrement_reserved_blocks(req) - if reserved_cross_blocks is not None: - reserved_cross_blocks.decrement_reserved_blocks(req) - - return scheduled_requests, [] - - -class MaxUtilizationPolicy(SchedulerPolicyBase): - """ - MaxUtilizationScheduler: Maximize utilization, may pause started requests. - C++ reference: capacityScheduler.cpp:341-425 - """ - - def schedule( - self, scheduler: "PyCapacityScheduler", active_requests: RequestList - ) -> tuple[RequestList, RequestList]: - scheduler.kv_cache_manager.start_scheduling() - - skipping_is_relevant = scheduler._is_skipping_relevant() - - scheduled_blocks_manager = MaxUtilizationScheduledBlocksManager( - scheduler.kv_cache_manager, scheduler.two_step_lookahead - ) - - num_scheduled_peft_pages = 0 - seen_task_ids: set[int] = set() - - newly_contributed_context_blocks, _ = scheduler._prefill_contributed_blocks(active_requests) - - def is_started_request(req: LlmRequest) -> bool: - if not scheduler._can_be_scheduled(req): - return False - return ( - req.is_context_init_state and not req.is_first_context_chunk - ) or req.is_generation_in_progress_state - - scheduled_requests: RequestList = [] - paused_requests: RequestList = [] - - requests_list = list(active_requests) - req_it_end = len(requests_list) - req_it = 0 - - while req_it < req_it_end: - req = requests_list[req_it] - logger.debug(f"MaxUtilizationScheduler: scheduling request ID {req.request_id}") - - if not scheduler._can_be_scheduled_with_disagg_exception(req): - logger.debug( - f"MaxUtilizationScheduler: request ID {req.request_id} " - "cannot / should not be scheduled" - ) - req_it += 1 - continue - - if skipping_is_relevant and scheduler._beneficial_to_skip( - req, newly_contributed_context_blocks, set() - ): - req_it += 1 - continue - - was_scheduled = self._try_scheduling_request( - scheduler, - req, - scheduled_requests, - scheduled_blocks_manager, - num_scheduled_peft_pages, - seen_task_ids, - ) - - if was_scheduled: - logger.debug(f"MaxUtilizationScheduler: request ID {req.request_id} -> start") - req_it += 1 - else: - last_started_idx = None - for i in range(req_it_end - 1, req_it - 1, -1): - if is_started_request(requests_list[i]): - last_started_idx = i - break - - if last_started_idx is not None: - paused_req = requests_list[last_started_idx] - scheduler.kv_cache_manager.scheduling_remove_sequence(paused_req.py_request_id) - paused_requests.append(paused_req) - logger.debug( - f"MaxUtilizationScheduler: request ID {paused_req.request_id} -> pause" - ) - req_it_end = last_started_idx - else: - break - - return scheduled_requests, paused_requests - - def _try_scheduling_request( - self, - scheduler: "PyCapacityScheduler", - req: LlmRequest, - scheduled_requests: RequestList, - scheduled_blocks_manager: "MaxUtilizationScheduledBlocksManager", - num_scheduled_peft_pages: int, - seen_task_ids: set[int], - ) -> bool: - if len(scheduled_requests) >= scheduler.max_num_requests: - return False - - blocks_if_scheduled = 
scheduled_blocks_manager.prepare_blocks_if_schedulable(req) - if blocks_if_scheduled is None: - return False - - # PEFT check only when needed - if scheduler.peft_cache_manager is not None: - lora_task_id, is_new_task, num_required_peft_pages = scheduler._get_peft_task_info( - req, seen_task_ids - ) - logger.debug( - f"MaxUtilizationScheduler: request ID {req.request_id} " - f"required peft pages: {num_required_peft_pages}" - ) - max_peft_pages = scheduler._get_max_peft_pages() - if num_required_peft_pages + num_scheduled_peft_pages > max_peft_pages: - return False - logger.debug( - f"MaxUtilizationScheduler: scheduled peft pages: {num_required_peft_pages}" - ) - if is_new_task: - seen_task_ids.add(lora_task_id) - - scheduled_blocks_manager.update_scheduled_blocks(blocks_if_scheduled) - scheduled_requests.append(req) - return True - - -class NoEvictScheduledBlocksManager: - """ - Python equivalent of C++ kv_cache_manager::NoEvictScheduledBlocksManager. - Tracks available blocks per window size for GUARANTEED_NO_EVICT scheduling. - - Reference: cpp/tensorrt_llm/batch_manager/scheduledBlocksManager.h:29-62 - """ - - def __init__(self, kv_cache_manager): - """ - Initialize with free blocks from KVCacheManager. - C++ equivalent: mAvailableBlocks = mKvCacheManager.getBlockManager().getNumFreeBlocksPerWindowSize() - """ - self.kv_cache_manager = kv_cache_manager - stats = kv_cache_manager.get_kv_cache_stats() - self.available_blocks: dict[int, int] = dict(stats.num_free_blocks_per_window_size) - - def decrement_reserved_blocks(self, req: LlmRequest) -> None: - """ - Decrement available blocks by the blocks needed to complete this request. - C++ reference: scheduledBlocksManager.h:40-46 - """ - for window_size in self.available_blocks: - needed = self.kv_cache_manager.get_remaining_blocks_to_completion(req, window_size) - self.available_blocks[window_size] -= needed - - def enough_available_blocks(self, req: LlmRequest) -> bool: - """ - Check if there are enough available blocks for this request across all window sizes. - C++ reference: scheduledBlocksManager.h:48-57 - """ - return all( - self.kv_cache_manager.get_remaining_blocks_to_completion(req, ws) <= avail - for ws, avail in self.available_blocks.items() - ) - - -class MaxUtilizationScheduledBlocksManager: - """ - Python equivalent of C++ kv_cache_manager::MaxUtilizationScheduledBlocksManager. - Tracks scheduled blocks per window size for MAX_UTILIZATION scheduling. - - Reference: cpp/tensorrt_llm/batch_manager/scheduledBlocksManager.h:64-117 - """ - - def __init__(self, kv_cache_manager, two_steps_look_ahead: bool): - """ - Initialize scheduled blocks count per window size. - C++ equivalent: iterate windowSizes and set mNumScheduledBlocks[windowSize] = 0 - """ - self.kv_cache_manager = kv_cache_manager - self.two_steps_look_ahead = two_steps_look_ahead - window_sizes = set(kv_cache_manager.max_attention_window_vec) - self.num_scheduled_blocks: dict[int, int] = {ws: 0 for ws in window_sizes} - - def prepare_blocks_if_schedulable(self, req: LlmRequest) -> Optional[dict[int, int]]: - """ - Check if request can be scheduled and return new block counts if so. - Returns None if request cannot fit. 
- C++ reference: scheduledBlocksManager.h:80-100 - """ - blocks_if_scheduled = {} - for window_size, num_scheduled in self.num_scheduled_blocks.items(): - required = self.kv_cache_manager.get_needed_blocks_one_step( - req, self.two_steps_look_ahead, window_size - ) - logger.debug( - f"MaxUtilizationScheduler: request ID {req.request_id} " - f"required blocks {required} for {window_size} window size" - ) - scheduled_total = num_scheduled + required - has_free = self.kv_cache_manager.scheduling_has_free_blocks( - scheduled_total, window_size - ) - if not has_free: - return None - blocks_if_scheduled[window_size] = scheduled_total - return blocks_if_scheduled - - def update_scheduled_blocks(self, blocks: dict[int, int]) -> None: - """ - Update the scheduled blocks after successfully scheduling a request. - C++ reference: scheduledBlocksManager.h:102-110 - """ - assert len(blocks) == len(self.num_scheduled_blocks), ( - f"Block count mismatch: {len(blocks)} vs {len(self.num_scheduled_blocks)}" - ) - for window_size, blocks_if_scheduled in blocks.items(): - logger.debug( - f"MaxUtilizationScheduler: scheduled blocks {blocks_if_scheduled} " - f"for window size {window_size}" - ) - self.num_scheduled_blocks[window_size] = blocks_if_scheduled - - -class PyCapacityScheduler: - """ - Python implementation of the C++ CapacityScheduler. - Aligned 1:1 with C++ logic in cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp. - Supports Multiple Window Sizes (VSWA), block reuse optimization, and all policies. - - Policies: - - MaxRequestsScheduler: No KV cache manager, simple request count limit - - GuaranteedNoEvictScheduler: Reserve blocks for completion, no eviction - - StaticBatchScheduler: Only schedule when no requests are active - - MaxUtilizationScheduler: Maximize utilization, may pause requests - - Reference: cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h - """ - - def __init__( - self, - max_num_requests: int, - kv_cache_manager=None, - peft_cache_manager=None, - scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, - cross_kv_cache_manager=None, - two_step_lookahead: bool = False, - no_schedule_until_state: LlmRequestState = LlmRequestState.CONTEXT_INIT, - no_schedule_after_state: LlmRequestState = LlmRequestState.GENERATION_COMPLETE, - ): - """ - Initialize the capacity scheduler. 
- - Args: - max_num_requests: Maximum number of requests to schedule - kv_cache_manager: KV cache manager (None for MaxRequestsScheduler) - peft_cache_manager: PEFT/LoRA cache manager (optional) - scheduler_policy: Scheduling policy - cross_kv_cache_manager: Cross-attention KV cache manager for encoder-decoder - two_step_lookahead: Enable two-step lookahead for MAX_UTILIZATION - no_schedule_until_state: Don't schedule until this state is reached - no_schedule_after_state: Don't schedule after this state is reached - """ - self.max_num_requests = max_num_requests - self.kv_cache_manager = kv_cache_manager - self.peft_cache_manager = peft_cache_manager - self.cross_kv_cache_manager = cross_kv_cache_manager - self.scheduler_policy = scheduler_policy - self.two_step_lookahead = two_step_lookahead - self.no_schedule_until_state = no_schedule_until_state - self.no_schedule_after_state = no_schedule_after_state - # Cache state values to avoid repeated .value access (optimization) - self._no_schedule_until_state_value = no_schedule_until_state.value - self._no_schedule_after_state_value = no_schedule_after_state.value - - # Initialize the appropriate policy - self._policy = self._create_policy() - - def _create_policy(self) -> SchedulerPolicyBase: - """Create the appropriate policy based on configuration.""" - if self.kv_cache_manager is None: - return MaxRequestsPolicy() - elif self.scheduler_policy == CapacitySchedulerPolicy.MAX_UTILIZATION: - return MaxUtilizationPolicy() - elif self.scheduler_policy == CapacitySchedulerPolicy.GUARANTEED_NO_EVICT: - return GuaranteedNoEvictPolicy(static_batch=False) - elif self.scheduler_policy == CapacitySchedulerPolicy.STATIC_BATCH: - return GuaranteedNoEvictPolicy(static_batch=True) - else: - raise ValueError(f"Unsupported scheduler policy: {self.scheduler_policy}") - - def _can_be_scheduled(self, req: LlmRequest) -> bool: - """ - Check if request is within the schedulable state range. - Returns True if request has reached no_schedule_until_state - but has not yet reached no_schedule_after_state. - Optimized: use state_value property to avoid enum object creation - """ - # Use state_value property (returns int directly, avoids enum object creation) - state_value = req.state_value - # Inline comparison: must have reached until_state but not after_state - return ( - state_value >= self._no_schedule_until_state_value - and state_value < self._no_schedule_after_state_value - ) - - def _is_skipping_relevant(self) -> bool: - """ - Check if block reuse skip optimization is relevant. - Disabled for VSWA (Variable Sliding Window Attention). - C++ reference: capacityScheduler.cpp:207-208, 348 - """ - if self.kv_cache_manager is None: - return False - if self.kv_cache_manager.is_variable_window: - return False - if ( - self.cross_kv_cache_manager is not None - and self.cross_kv_cache_manager.is_variable_window - ): - return False - return True - - def _prefill_contributed_blocks(self, active_requests: RequestList) -> tuple[set, set]: - """ - Collect blocks contributed by chunked context requests already executing. - These blocks can be reused by later requests. 
- - C++ reference: capacityScheduler.cpp:34-68 (prefillWithChunkedContextsAlreadyExecuting) - """ - newly_contributed_context_blocks: Set = set() - newly_contributed_cross_context_blocks: Set = set() - - if self.kv_cache_manager is None: - return newly_contributed_context_blocks, newly_contributed_cross_context_blocks - - enable_block_reuse = self.kv_cache_manager.enable_block_reuse - cross_enable_reuse = ( - self.cross_kv_cache_manager is not None - and self.cross_kv_cache_manager.enable_block_reuse - ) - - for req in active_requests: - # Check: isContextInitState() && !isFirstContextChunk() - if req.is_context_init_state and not req.is_first_context_chunk: - # Chunked context request already executing - if enable_block_reuse: - unique_tokens = req.get_unique_tokens(0) - block_key = self.kv_cache_manager.find_new_context_block(unique_tokens, req) - if block_key is not None: - newly_contributed_context_blocks.add(block_key) - - if cross_enable_reuse: - encoder_unique_tokens = req.get_encoder_unique_tokens() - if encoder_unique_tokens is not None: - block_key = self.cross_kv_cache_manager.find_new_context_block( - encoder_unique_tokens, req - ) - if block_key is not None: - newly_contributed_cross_context_blocks.add(block_key) - - return newly_contributed_context_blocks, newly_contributed_cross_context_blocks - - def _one_manager_beneficial_to_skip( - self, kv_cache_manager, unique_tokens, req: LlmRequest, newly_contributed_blocks: set - ) -> bool: - """ - Check if skipping is beneficial for one KV cache manager. - C++ reference: capacityScheduler.cpp:70-92 (oneManagerBeneficialToSkip) - """ - new_context_block = kv_cache_manager.find_new_context_block(unique_tokens, req) - if new_context_block is not None: - if new_context_block in newly_contributed_blocks: - return True - newly_contributed_blocks.add(new_context_block) - return False - - def _beneficial_to_skip( - self, - req: LlmRequest, - newly_contributed_context_blocks: set, - newly_contributed_cross_context_blocks: set, - ) -> bool: - """ - Check if it's beneficial to skip this request. - A request should be skipped if it can reuse blocks contributed by - already scheduled context requests. 
- - C++ reference: capacityScheduler.cpp:97-123 (beneficialToSkip) - """ - if not (req.is_context_init_state and req.is_first_context_chunk): - return False - - if self.kv_cache_manager is not None and self.kv_cache_manager.enable_block_reuse: - unique_tokens = req.get_unique_tokens(0) - if self._one_manager_beneficial_to_skip( - self.kv_cache_manager, unique_tokens, req, newly_contributed_context_blocks - ): - return True - - if ( - self.cross_kv_cache_manager is not None - and self.cross_kv_cache_manager.enable_block_reuse - ): - encoder_unique_tokens = req.get_encoder_unique_tokens() - if encoder_unique_tokens is not None: - if self._one_manager_beneficial_to_skip( - self.cross_kv_cache_manager, - encoder_unique_tokens, - req, - newly_contributed_cross_context_blocks, - ): - return True - - return False - - def _get_max_peft_pages(self) -> int: - """Get maximum PEFT cache pages.""" - if self.peft_cache_manager is None: - return 2**31 - 1 # INT_MAX equivalent - return self.peft_cache_manager.max_device_pages - - def _get_peft_pages_for_request(self, req: LlmRequest) -> int: - """Get PEFT pages needed for a request.""" - if self.peft_cache_manager is None: - return 0 - return self.peft_cache_manager.determine_num_pages(req) - - def _get_peft_task_info( - self, req: LlmRequest, seen_task_ids: set[int] - ) -> tuple[Optional[int], bool, int]: - """ - Get PEFT task information for a request. - Returns (lora_task_id, is_new_task, required_pages). - """ - lora_task_id = getattr(req, "lora_task_id", None) - is_new_task = lora_task_id is not None and lora_task_id not in seen_task_ids - required_pages = self._get_peft_pages_for_request(req) if is_new_task else 0 - return lora_task_id, is_new_task, required_pages - - def _can_be_scheduled_with_disagg_exception(self, req: LlmRequest) -> bool: - """ - Check if request can be scheduled, with exception for disagg generation init state. - Disagg generation init requests bypass the normal state gating. - """ - if req.is_disagg_generation_init_state: - return True - return self._can_be_scheduled(req) - - def schedule_request( - self, active_requests: RequestList - ) -> tuple[RequestList, RequestList, RequestList]: - """ - Schedule requests based on the configured policy. - - Args: - active_requests: List of active requests to consider - - Returns: - Tuple of (fitting_requests, fitting_disagg_gen_init_requests, paused_requests) - - C++ reference: capacityScheduler.cpp:488-539 (CapacityScheduler::operator()) - """ - scheduled, paused = self._policy.schedule(self, active_requests) - - fitting_requests, fitting_disagg_gen_init_requests = self._classify_output(scheduled) - - logger.debug( - f"[Summary] Capacity scheduler allows {len(fitting_requests)} requests, " - f"pauses {len(paused)} requests" - ) - - return fitting_requests, fitting_disagg_gen_init_requests, paused - - def _classify_output(self, scheduled_requests: RequestList) -> tuple[RequestList, RequestList]: - """ - Separate scheduled requests into normal requests and disagg gen init requests. 
- C++ reference: capacityScheduler.cpp:522-534 - """ - fitting_requests: RequestList = [] - fitting_disagg_gen_init_requests: RequestList = [] - for req in scheduled_requests: - if req.is_disagg_generation_init_state: - fitting_disagg_gen_init_requests.append(req) - else: - fitting_requests.append(req) - return fitting_requests, fitting_disagg_gen_init_requests - - -class SimpleUnifiedScheduler(RequestScheduler): - def __init__( - self, - max_batch_size: int, - max_num_tokens: int, - kv_cache_manager, - peft_cache_manager, - scheduler_policy: CapacitySchedulerPolicy, - ctx_chunk_config: Optional[tuple[StrEnum, int]] = None, - cross_kv_cache_manager=None, - two_step_lookahead: bool = False, - scheduler_capacity: Optional[int] = None, - ): - # Use scheduler_capacity if provided, otherwise fall back to max_batch_size - # scheduler_capacity may differ from max_batch_size (e.g., adjusted for attention_dp + disagg) - capacity = scheduler_capacity if scheduler_capacity is not None else max_batch_size - - # 1. Initialize Python Capacity Scheduler - # Now fully aligned with C++ CapacityScheduler - self.capacity_scheduler = PyCapacityScheduler( - max_num_requests=capacity, - kv_cache_manager=kv_cache_manager, - peft_cache_manager=peft_cache_manager, - scheduler_policy=scheduler_policy, - cross_kv_cache_manager=cross_kv_cache_manager, - two_step_lookahead=two_step_lookahead, - ) - - # 2. Initialize Python MicroBatch Scheduler - py_chunk_config = None - if ctx_chunk_config: - # Fix: Use string comparison to identify the policy. - # This works regardless of whether the input is a Python Enum, C++ Binding Enum, or String. - input_policy = ctx_chunk_config[0] - - if "EQUAL_PROGRESS" in str(input_policy): - policy_enum = ChunkingPolicy.EQUAL_PROGRESS - else: - # Default to FCFS for FIRST_COME_FIRST_SERVED or others - policy_enum = ChunkingPolicy.FIRST_COME_FIRST_SERVED - - py_chunk_config = ContextChunkingConfig(policy_enum, ctx_chunk_config[1]) - - self.micro_batch_scheduler = PyMicroBatchScheduler( - max_batch_size=max_batch_size, - max_num_tokens=max_num_tokens, - ctx_chunk_config=py_chunk_config, - ) - - def schedule_request( - self, active_requests: RequestList, inflight_request_ids: set[int] - ) -> SchedulerOutput: - # Step 1: Capacity Check (Who fits in memory?) - fitting_requests, fitting_disagg_gen_init, paused_requests = ( - self.capacity_scheduler.schedule_request(active_requests) - ) - - # Step 2: MicroBatch Check (Who fits in token budget? 
+ Chunking) - context_requests, generation_requests = self.micro_batch_scheduler.schedule( - fitting_requests, inflight_request_ids - ) - - return SchedulerOutput( - context_requests=context_requests, - generation_requests=generation_requests, - paused_requests=paused_requests, - fitting_disagg_gen_init_requests=fitting_disagg_gen_init, - num_fitting_requests=len(fitting_requests), - ) - - def can_schedule(self, requests: RequestList) -> bool: - # Dry run capacity check - fitting, _, _ = self.capacity_scheduler.schedule_request(requests) - return len(fitting) == len(requests) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler_v2.py b/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler_v2.py index 369661ee13eb..92347d125080 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler_v2.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler_v2.py @@ -20,7 +20,14 @@ from tensorrt_llm.logger import logger from ..llm_request import LlmRequest, LlmRequestState, get_draft_token_length -from .scheduler import RequestList, RequestScheduler, SchedulerOutput +from .scheduler import ( + RequestList, + RequestScheduler, + SchedulerOutput, + ScheduleStepConfig, + compute_fcfs_context_chunk_size, + sort_requests_by_lora, +) class ScheduleAction(enum.Enum): @@ -122,7 +129,12 @@ def __init__( no_schedule_until_state: LlmRequestState = LlmRequestState.CONTEXT_INIT, no_schedule_after_state: LlmRequestState = LlmRequestState.GENERATION_TO_COMPLETE, draft_kv_cache_manager=None, # KVCacheManagerV2 for MTP draft layers + schedule_step_config: Optional[ScheduleStepConfig] = None, + dist=None, ): + super(KVCacheV2Scheduler, self).__init__( + schedule_step_config=schedule_step_config, dist=dist + ) self.max_num_tokens = max_num_tokens self.max_num_requests = ( scheduler_capacity if scheduler_capacity is not None else max_batch_size @@ -178,7 +190,7 @@ def schedule_request( ) # Sort by LoRA task ID - self._sort_requests(scheduled_ctx, scheduled_gen, has_chunking) + sort_requests_by_lora(scheduled_ctx, scheduled_gen, has_chunking) return SchedulerOutput( context_requests=scheduled_ctx, @@ -379,19 +391,13 @@ def _try_schedule_context_chunked( # Calculate chunk size from remaining budget # (context_remaining_length is now correct after block reuse) context_remaining = req.context_remaining_length - chunk_size = ( - min(remaining_budget, context_remaining) - if remaining_budget is not None - else context_remaining + chunk_size = compute_fcfs_context_chunk_size( + context_remaining, + remaining_budget, + self.max_context_length, + self.chunk_unit_size, ) - if self.max_context_length is not None: - chunk_size = min(chunk_size, self.max_context_length) - - # Round down to chunk_unit_size boundary (unless last chunk). - if chunk_size < context_remaining: - chunk_size = (chunk_size // self.chunk_unit_size) * self.chunk_unit_size - if chunk_size <= 0: # TODO: consider suspending first-chunk KVCache to release # GPU pages. Currently we skip without suspend to avoid @@ -535,29 +541,6 @@ def _try_evict_for_gen(self, req, requests_list, req_it, req_it_end, evicted): return req_it_end, False - # ---- Sorting ---- - - @staticmethod - def _lora_key(req: LlmRequest): - lora_id = getattr(req, "lora_task_id", None) - if lora_id is None: - return (0, 0) - return (1, lora_id) - - def _sort_requests(self, context_requests, generation_requests, has_chunks): - """Sort by LoRA task ID. 
Non-last chunks before last chunks."""
-        if has_chunks:
-            not_last = [r for r in context_requests if not r.is_last_context_chunk]
-            last = [r for r in context_requests if r.is_last_context_chunk]
-            not_last.sort(key=self._lora_key)
-            last.sort(key=self._lora_key)
-            context_requests.clear()
-            context_requests.extend(not_last)
-            context_requests.extend(last)
-        else:
-            context_requests.sort(key=self._lora_key)
-        generation_requests.sort(key=self._lora_key)
-
     # ---- can_schedule (PP dry-run) ----
 
     def can_schedule(self, requests: RequestList) -> bool:
diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/setup_mypyc.py b/tensorrt_llm/_torch/pyexecutor/scheduler/setup_mypyc.py
new file mode 100644
index 000000000000..10afa44786b8
--- /dev/null
+++ b/tensorrt_llm/_torch/pyexecutor/scheduler/setup_mypyc.py
@@ -0,0 +1,131 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Setup script for compiling the scheduler module with mypyc.
+
+Usage (from tensorrt_llm/_torch/pyexecutor/):
+    python scheduler/setup_mypyc.py build_ext --inplace
+"""
+
+import glob
+import os
+import shutil
+import sys
+
+# Set environment variables BEFORE importing mypyc
+os.environ["MYPY_FORCE_COLOR"] = "0"
+
+from mypyc.build import mypycify
+from setuptools import setup
+
+# Write mypy.ini in the cwd (pyexecutor/) so mypy finds it before climbing to
+# the repo-root pyproject.toml (whose [tool.mypy] lacks our error suppressions).
+# This changes module resolution to the full path, so we manually copy .so files
+# to scheduler/ in the finally block.
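
At a glance, the file boils down to this sequence (a condensed sketch;
write_mypy_ini() and copy_artifacts_and_cleanup() are shorthand for the blocks
below and at the tail of this file, not real helpers):

    os.environ["MYPY_FORCE_COLOR"] = "0"    # set before mypyc is imported
    write_mypy_ini()                        # the [mypy] config written below
    ext_modules = mypycify(["scheduler/unified_scheduler.py"], opt_level="3")
    setup(name="scheduler_compiled", ext_modules=ext_modules)
    copy_artifacts_and_cleanup()            # .so copy + tree cleanup below
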
+mypy_config_path = os.path.abspath("mypy.ini")
+with open(mypy_config_path, "w") as f:
+    f.write("""[mypy]
+# Critical: Don't follow any imports outside the specified files
+follow_imports = skip
+follow_imports_for_stubs = False
+
+# Ignore missing imports completely
+ignore_missing_imports = True
+
+# Allow all untyped code
+allow_untyped_calls = True
+allow_untyped_defs = True
+allow_incomplete_defs = True
+allow_untyped_globals = True
+check_untyped_defs = False
+
+# Disable all warnings that might cause errors
+disallow_untyped_calls = False
+disallow_untyped_defs = False
+disallow_incomplete_defs = False
+warn_return_any = False
+warn_unused_ignores = False
+
+# Disable type errors that are safe at runtime:
+# - valid-type: external types (nanobind objects)
+# - union-attr: Optional[X].method() guarded by cached bool locals
+# - arg-type: int|None passed to set.add() after None-checked getattr
+# - attr-defined: external module attributes (StrEnum._to_pybind)
+# - misc: relative imports beyond mypyc's resolution scope
+# - operator: None arithmetic guarded by runtime checks
+# - assignment: conditional None assignment to typed vars
+# - annotation-unchecked: untyped function bodies (safe at runtime)
+disable_error_code = valid-type, union-attr, arg-type, attr-defined, misc, operator, assignment, annotation-unchecked
+""")
+
+# Compile only the unified_scheduler module (the hot path).
+# Other scheduler files (scheduler.py, adp_router.py, waiting_queue.py)
+# are thin wrappers or C++ bindings that don't benefit from compilation.
+modules = [
+    "scheduler/unified_scheduler.py",
+]
+
+print(f"Compiling {len(modules)} modules with mypyc...")
+print()
+
+try:
+    ext_modules = mypycify(
+        modules,
+        opt_level="3",  # Maximum optimization
+        multi_file=False,  # Single module, no cross-file references needed
+        verbose=True,  # Show what's being compiled
+        separate=False,  # Compile into single .so
+        strip_asserts=False,  # Keep assertions for debugging
+    )
+
+except Exception as e:
+    print(f"Error during mypyc compilation: {e}")
+    sys.exit(1)
+finally:
+    # Cleanup temp config
+    if os.path.exists(mypy_config_path):
+        try:
+            os.remove(mypy_config_path)
+        except OSError:
+            pass
+
+    # Remove --config-file arguments from sys.argv before calling setup()
+    while "--config-file" in sys.argv:
+        idx = sys.argv.index("--config-file")
+        sys.argv.pop(idx)  # Remove '--config-file'
+        if idx < len(sys.argv):  # Remove the path that follows it
+            sys.argv.pop(idx)
+
+# mypy.ini in cwd causes mypyc to resolve the full module path, so --inplace
+# tries to copy .so files to tensorrt_llm/_torch/pyexecutor/scheduler/ relative
+# to cwd. Create that directory so --inplace succeeds, then copy to scheduler/.
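
A quick way to confirm the compiled artifact actually landed (an illustrative
check, run from tensorrt_llm/_torch/pyexecutor/ after the copy step at the end
of this file; without the .so, imports silently fall back to the pure-Python
scheduler/unified_scheduler.py):

    import glob
    assert glob.glob("scheduler/unified_scheduler*.so"), \
        "mypyc build produced no scheduler/*.so artifact"
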
+_full_path_dir = os.path.join("tensorrt_llm", "_torch", "pyexecutor", "scheduler")
+os.makedirs(_full_path_dir, exist_ok=True)
+
+setup(
+    name="scheduler_compiled",
+    ext_modules=ext_modules,
+    package_data={
+        "scheduler": ["*.pyi", "**/*.pyi"],
+    },
+    python_requires=">=3.8",
+)
+
+# Copy .so files from the full path dir to scheduler/ and clean up
+for so_file in glob.glob(os.path.join(_full_path_dir, "*.so")):
+    dest = os.path.join("scheduler", os.path.basename(so_file))
+    shutil.copy2(so_file, dest)
+    print(f"Copied {so_file} -> {dest}")
+shutil.rmtree("tensorrt_llm", ignore_errors=True)
diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/simple_scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler/simple_scheduler.py
new file mode 100644
index 000000000000..e1c000510272
--- /dev/null
+++ b/tensorrt_llm/_torch/pyexecutor/scheduler/simple_scheduler.py
@@ -0,0 +1,117 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""SimpleScheduler: two-pass scheduling (capacity -> microbatch) using C++ bindings."""
+
+from typing import Optional
+
+from strenum import StrEnum
+
+from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest, LlmRequestState
+from tensorrt_llm.bindings import internal as tb_internal
+from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy
+
+from .scheduler import RequestList, RequestScheduler, SchedulerOutput, ScheduleStepConfig
+
+
+class BindCapacityScheduler:
+    def __init__(
+        self,
+        max_num_requests: int,
+        kv_cache_manager,
+        peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None,
+        scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
+        two_step_lookahead: bool = False,
+    ):
+        self.kv_cache_manager = kv_cache_manager
+        self.peft_cache_manager = peft_cache_manager
+
+        self.impl = tb_internal.algorithms.CapacityScheduler(
+            max_num_requests=max_num_requests,
+            capacity_scheduler_policy=scheduler_policy._to_pybind(),
+            has_kv_cache_manager=kv_cache_manager is not None,
+            two_step_lookahead=two_step_lookahead,
+            no_schedule_until_state=LlmRequestState.CONTEXT_INIT,
+            no_schedule_after_state=LlmRequestState.GENERATION_COMPLETE,
+        )
+
+    def schedule_request(
+        self, active_requests: RequestList
+    ) -> tuple[list[LlmRequest], list[LlmRequest], list[LlmRequest]]:
+        return self.impl(active_requests, self.kv_cache_manager, self.peft_cache_manager)
+
+
+class BindMicroBatchScheduler:
+    def __init__(
+        self,
+        max_batch_size: int,
+        max_num_tokens: Optional[int] = None,
+        ctx_chunk_config: Optional[tuple[StrEnum, int]] = None,
+    ) -> None:
+        self.max_batch_size = max_batch_size
+        self.max_num_tokens = max_num_tokens
+
+        ctx_chunk_config_cpp = None
+        if ctx_chunk_config is not None:
+            policy = ctx_chunk_config[0]
+            ctx_chunk_config_cpp = tb_internal.batch_manager.ContextChunkingConfig(
+                policy._to_pybind(),
+                ctx_chunk_config[1],  # type: ignore[attr-defined]
+            )
+
+        self.impl = tb_internal.algorithms.MicroBatchScheduler(ctx_chunk_config_cpp, max_num_tokens)
+
+    def schedule(
+        self, active_requests: RequestList, inflight_request_ids: set[int]
+    ) -> tuple[list[LlmRequest], list[LlmRequest]]:
+        return self.impl(
+            active_requests, inflight_request_ids, self.max_batch_size, self.max_num_tokens
+        )
+
+
+class SimpleScheduler(RequestScheduler):
+    def __init__(
+        self,
+        capacity_scheduler,
+        micro_batch_scheduler,
+        schedule_step_config: Optional[ScheduleStepConfig] = None,
+        dist=None,
+    ):
+        super(SimpleScheduler, self).__init__(schedule_step_config=schedule_step_config, dist=dist)
+        self.capacity_scheduler = capacity_scheduler
+        self.micro_batch_scheduler = micro_batch_scheduler
+
+    def schedule_request(
+        self, active_requests: RequestList, inflight_request_ids: set[int]
+    ) -> SchedulerOutput:
+        fitting_requests, fitting_disagg_gen_init_requests, paused_requests = (
+            self.capacity_scheduler.schedule_request(active_requests)
+        )
+
+        context_requests, generation_requests = self.micro_batch_scheduler.schedule(
+            fitting_requests, inflight_request_ids
+        )
+        # Convert from binding type RequestVector to list[LlmRequest],
+        # so Python fields on LlmRequest won't be stripped away
+        return SchedulerOutput(
+            list(context_requests),
+            list(generation_requests),
+            list(paused_requests),
+            list(fitting_disagg_gen_init_requests),
+            len(fitting_requests),
+        )
+
+    def can_schedule(self, requests: RequestList) -> bool:
+        fitting_requests, _, _ = self.capacity_scheduler.schedule_request(requests)
+        return len(fitting_requests) == len(requests)
diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/unified_scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler/unified_scheduler.py
new file mode 100644
index 000000000000..a516c94bf53d
--- /dev/null
+++ b/tensorrt_llm/_torch/pyexecutor/scheduler/unified_scheduler.py
@@ -0,0 +1,1302 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Optional, Set
+
+from strenum import StrEnum
+
+from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy
+from tensorrt_llm.logger import logger
+
+from ..llm_request import LlmRequest, LlmRequestState
+from .scheduler import (
+    RequestList,
+    RequestScheduler,
+    SchedulerOutput,
+    ScheduleStepConfig,
+    compute_fcfs_context_chunk_size,
+    sort_requests_by_lora,
+)
+
+
+class ChunkingPolicy(Enum):
+    EQUAL_PROGRESS = 1
+    FIRST_COME_FIRST_SERVED = 2
+
+
+@dataclasses.dataclass
+class ContextChunkingConfig:
+    chunking_policy: ChunkingPolicy
+    chunk_unit_size: int
+
+
+class TokenBudgetTracker:
+    """Fused capacity + microbatch token budget tracker.
+
+    Integrates token-budget checks, request type classification, context
+    chunking, and sorting into the capacity policy loop — replacing the
+    old two-pass (capacity → microbatch) pipeline with a single pass.
+ + Provides type-specialized methods for hot paths: + - try_add_generation(): fast path for generation requests (no type checks) + - try_add_context(): fast path for context requests (no beam/gen checks) + - try_add(): generic path for mixed/unknown request types + """ + + __slots__ = [ + "max_batch_size", + "max_num_tokens", + "max_context_length", + "_has_token_limit", + "ctx_chunk_config", + "_inflight_ids", + "_batch_num_tokens", + "_scheduled_req_size", + "_scheduled_beam_width", + "_context_requests", + "_generation_requests", + "_contexts_to_be_chunked", + "_num_chunked_tokens", + "_all_context_requests_fit", + "_num_fitting", + "batch_full", + ] + + # Cache state values once at class level to avoid repeated .value access + _CONTEXT_INIT_VALUE = LlmRequestState.CONTEXT_INIT.value + _ENCODER_INIT_VALUE = LlmRequestState.ENCODER_INIT.value + _GEN_IN_PROGRESS_VALUE = LlmRequestState.GENERATION_IN_PROGRESS.value + _NO_SCHEDULE_UNTIL_VALUE = LlmRequestState.CONTEXT_INIT.value + _NO_SCHEDULE_AFTER_VALUE = LlmRequestState.GENERATION_TO_COMPLETE.value + + def __init__( + self, + max_batch_size: int, + max_num_tokens: Optional[int], + ctx_chunk_config: Optional[ContextChunkingConfig], + inflight_request_ids: object = None, + ): + self.max_batch_size = max_batch_size + self.max_num_tokens = max_num_tokens + self.max_context_length = max_num_tokens + self._has_token_limit = max_num_tokens is not None + self.ctx_chunk_config = ctx_chunk_config + # Accepts set[int] or C++ ReqIdsSet — only `in` operator is used. + self._inflight_ids = inflight_request_ids + self._batch_num_tokens = 0 + self._scheduled_req_size = 0 + self._scheduled_beam_width = 0 + self._context_requests: RequestList = [] + self._generation_requests: RequestList = [] + self._contexts_to_be_chunked: RequestList = [] + self._num_chunked_tokens = 0 + self._all_context_requests_fit = True + self._num_fitting = 0 + self.batch_full = False + + def try_add(self, req: LlmRequest) -> bool: + """Try to add a request to the batch. Returns False if token budget exceeded.""" + # Skip if in flight + if self._inflight_ids is not None and req.request_id in self._inflight_ids: + return True # don't reject, just skip token accounting + + # Disagg gen init requests bypass token accounting — they are + # classified separately by the capacity policy. 
+ if req.is_disagg_generation_init_state: + return True + + req_state_value = req.state_value + + # Skip if not in schedulable state range + if not ( + req_state_value >= self._NO_SCHEDULE_UNTIL_VALUE + and req_state_value < self._NO_SCHEDULE_AFTER_VALUE + ): + return True # don't reject from capacity, just skip + + req_num_tokens = 0 + + # --- Encoder --- + if req_state_value == self._ENCODER_INIT_VALUE: + req_num_tokens = req.encoder_output_len + if self.max_context_length is not None: + assert req_num_tokens <= self.max_context_length, ( + f"The number of encoder tokens ({req_num_tokens}) exceeds " + f"the limit value ({self.max_context_length})" + ) + if self._has_token_limit and ( + self._batch_num_tokens + req_num_tokens > self.max_num_tokens + ): + return False + self._context_requests.append(req) + self._batch_num_tokens += req_num_tokens + + # --- Context --- + elif req_state_value == self._CONTEXT_INIT_VALUE: + if not self.ctx_chunk_config: + base_tokens = req.get_num_tokens(0) + draft_tokens = req.num_draft_tokens if req.has_draft_tokens else 0 + req_num_tokens = base_tokens + draft_tokens + if self.max_context_length is not None: + assert req_num_tokens <= self.max_context_length, ( + f"Context tokens ({req_num_tokens}) exceeds " + f"limit ({self.max_context_length})" + ) + if self._has_token_limit and ( + self._batch_num_tokens + req_num_tokens > self.max_num_tokens + ): + return False + self._context_requests.append(req) + self._batch_num_tokens += req_num_tokens + else: + # Chunking: tentative schedule (finalized later) + req.context_chunk_size = req.context_remaining_length + draft_tokens = ( + req.num_draft_tokens + if (req.is_last_context_chunk and req.has_draft_tokens) + else 0 + ) + req_num_tokens = req.context_chunk_size + draft_tokens + if self.max_context_length is not None: + if self.max_context_length < req_num_tokens: + req_num_tokens = self.max_context_length + self._all_context_requests_fit = False + self._contexts_to_be_chunked.append(req) + self._num_chunked_tokens += req_num_tokens + + # --- Generation --- + else: + beam_width = req.get_beam_width_by_iter(for_next_iteration=False) + req_num_tokens = beam_width + req.num_draft_tokens + if self._has_token_limit and ( + self._batch_num_tokens + req_num_tokens > self.max_num_tokens + ): + return False + # Beam width consistency + if self._scheduled_beam_width == 0: + self._scheduled_beam_width = beam_width + elif self._scheduled_beam_width != beam_width: + return True # skip this request, don't reject from capacity + self._generation_requests.append(req) + self._batch_num_tokens += req_num_tokens + + self._scheduled_req_size += 1 + self._num_fitting += 1 + if self._scheduled_req_size >= self.max_batch_size: + self.batch_full = True + return True + + def try_add_generation(self, req: LlmRequest) -> int: + """Fast path for generation-in-progress requests. 
+
+        Returns:
+            1: accepted, continue scheduling
+            0: rejected (token budget exceeded)
+            -1: accepted but batch is full (stop scheduling)
+        """
+        if self._inflight_ids is not None and req.request_id in self._inflight_ids:
+            return 1  # skip token accounting for inflight requests
+        beam_width = req.get_beam_width_by_iter(for_next_iteration=False)
+        req_num_tokens = beam_width + req.num_draft_tokens
+        if self._has_token_limit and (
+            self._batch_num_tokens + req_num_tokens > self.max_num_tokens
+        ):
+            return 0
+        # Beam width consistency
+        if self._scheduled_beam_width == 0:
+            self._scheduled_beam_width = beam_width
+        elif self._scheduled_beam_width != beam_width:
+            return 1  # skip, don't reject from capacity
+        self._generation_requests.append(req)
+        self._batch_num_tokens += req_num_tokens
+        self._scheduled_req_size += 1
+        self._num_fitting += 1
+        if self._scheduled_req_size >= self.max_batch_size:
+            self.batch_full = True
+            return -1
+        return 1
+
+    def try_add_context(self, req: LlmRequest) -> bool:
+        """Fast path for context-init requests.
+
+        Skips state-range, encoder, generation, and beam checks.
+        Called from the second pass of capacity policies where requests
+        are known to be context_init or disagg_generation_init.
+        """
+        if self._inflight_ids is not None and req.request_id in self._inflight_ids:
+            return True  # skip token accounting for inflight requests
+        if req.is_disagg_generation_init_state:
+            return True  # classified separately by the capacity policy
+
+        if not self.ctx_chunk_config:
+            base_tokens = req.get_num_tokens(0)
+            draft_tokens = req.num_draft_tokens if req.has_draft_tokens else 0
+            req_num_tokens = base_tokens + draft_tokens
+            if self.max_context_length is not None:
+                assert req_num_tokens <= self.max_context_length, (
+                    f"Context tokens ({req_num_tokens}) exceeds limit ({self.max_context_length})"
+                )
+            if self._has_token_limit and (
+                self._batch_num_tokens + req_num_tokens > self.max_num_tokens
+            ):
+                return False
+            self._context_requests.append(req)
+            self._batch_num_tokens += req_num_tokens
+        else:
+            # Chunking: tentative schedule (finalized later)
+            req.context_chunk_size = req.context_remaining_length
+            draft_tokens = (
+                req.num_draft_tokens if (req.is_last_context_chunk and req.has_draft_tokens) else 0
+            )
+            req_num_tokens = req.context_chunk_size + draft_tokens
+            if self.max_context_length is not None:
+                if self.max_context_length < req_num_tokens:
+                    req_num_tokens = self.max_context_length
+                    self._all_context_requests_fit = False
+            self._contexts_to_be_chunked.append(req)
+            self._num_chunked_tokens += req_num_tokens
+
+        self._scheduled_req_size += 1
+        self._num_fitting += 1
+        if self._scheduled_req_size >= self.max_batch_size:
+            self.batch_full = True
+        return True
+
+    def finalize(self) -> tuple[RequestList, RequestList, int]:
+        """Apply chunking and sorting. Returns (context, generation, num_fitting).
+
+        Note: num_fitting reflects requests admitted by try_add/try_add_*
+        before chunking. Requests with context_chunk_size == 0 are dropped
+        from context_requests below but num_fitting is not decremented.
+        See the class docstring for the full semantics.
+ """ + # Verify chunking fits + if self._has_token_limit and self._num_chunked_tokens > ( + self.max_num_tokens - self._batch_num_tokens + ): + self._all_context_requests_fit = False + + # Apply chunking strategy + if not self._all_context_requests_fit and self._contexts_to_be_chunked: + remaining_capacity = ( + (self.max_num_tokens - self._batch_num_tokens) if self._has_token_limit else None + ) + self._set_ctx_requests_chunk_size(self._contexts_to_be_chunked, remaining_capacity) + + # Finalize chunked requests + for req in self._contexts_to_be_chunked: + if req.context_chunk_size > 0: + self._context_requests.append(req) + self._batch_num_tokens += req.context_chunk_size + + # Sort requests for consistency with C++ + sort_requests_by_lora( + self._context_requests, self._generation_requests, not self._all_context_requests_fit + ) + + return (self._context_requests, self._generation_requests, self._num_fitting) + + # ------------------------------------------------------------------ + # Context chunking and sorting helpers + # ------------------------------------------------------------------ + + def _set_ctx_requests_chunk_size(self, requests: RequestList, capacity: Optional[int]): + for req in requests: + req.context_chunk_size = 0 + + policy = self.ctx_chunk_config.chunking_policy + unit_size = self.ctx_chunk_config.chunk_unit_size + + if policy == ChunkingPolicy.EQUAL_PROGRESS: + self._chunk_equal_progress(requests, capacity, unit_size) + elif policy == ChunkingPolicy.FIRST_COME_FIRST_SERVED: + self._chunk_fcfs(requests, capacity, unit_size) + else: + raise ValueError(f"Invalid chunking policy: {policy}") + + self._fit_draft_tokens(requests, capacity, unit_size) + + def _chunk_equal_progress(self, requests: RequestList, capacity: Optional[int], unit_size: int): + num_ctx_tokens = 0 + num_tokens_single_loop = 1 + + while (capacity is None or num_ctx_tokens < capacity) and num_tokens_single_loop > 0: + num_tokens_single_loop = 0 + for req in requests: + past_size = req.context_chunk_size + suggested_size = min(past_size + unit_size, req.context_remaining_length) + req.context_chunk_size = suggested_size + actual_size = req.context_chunk_size + actual_increment = actual_size - past_size + if capacity is not None and (num_ctx_tokens + actual_increment > capacity): + req.context_chunk_size = past_size + continue + if self.max_context_length is not None and actual_size > self.max_context_length: + req.context_chunk_size = past_size + continue + num_ctx_tokens += actual_increment + num_tokens_single_loop += actual_increment + + def _chunk_fcfs(self, requests: RequestList, capacity: Optional[int], unit_size: int): + current_capacity = capacity if capacity is not None else float("inf") + + for req in requests: + req.context_chunk_size = compute_fcfs_context_chunk_size( + req.context_remaining_length, + current_capacity if capacity is not None else None, + self.max_context_length, + unit_size, + ) + if capacity is not None: + current_capacity -= req.context_chunk_size + + def _fit_draft_tokens(self, requests: RequestList, capacity: Optional[int], unit_size: int): + num_ctx_tokens = sum(req.context_chunk_size for req in requests) + + for req in requests: + if req.is_last_context_chunk and req.has_draft_tokens: + remainder = req.context_chunk_size % unit_size + remaining_space = 0 if remainder == 0 else unit_size - remainder + if self.max_context_length is not None: + remaining_context_len = self.max_context_length - req.context_chunk_size + remaining_space = min(remaining_space, 
remaining_context_len) + if capacity is not None: + remaining_space = min(remaining_space, capacity - num_ctx_tokens) + num_ctx_tokens += remaining_space + draft_discard = req.num_draft_tokens - remaining_space + if draft_discard > 0: + logger.debug(f"Discarding {draft_discard} draft tokens") + if hasattr(req, "discard_draft_tokens"): + req.discard_draft_tokens(draft_discard) + + +class SchedulerPolicyBase(ABC): + """ + Abstract base class for capacity scheduler policies. + Each policy implements its own scheduling logic. + """ + + @abstractmethod + def schedule( + self, + scheduler: "PyCapacityScheduler", + active_requests: RequestList, + token_tracker: Optional[TokenBudgetTracker] = None, + ) -> tuple[RequestList, RequestList, RequestList]: + """ + Schedule requests according to the policy. + + Args: + scheduler: The capacity scheduler instance (for accessing shared state) + active_requests: List of active requests to schedule + token_tracker: If provided, fuses token-budget checks into the + capacity loop for single-pass scheduling. + + Returns: + Tuple of (fitting_requests, fitting_disagg_gen_init_requests, + paused_requests) + """ + raise NotImplementedError + + +class MaxRequestsPolicy(SchedulerPolicyBase): + """ + MaxRequestsScheduler: Simple request count limiting without KV cache checks. + C++ reference: capacityScheduler.cpp:154-176 + """ + + def schedule( + self, + scheduler: "PyCapacityScheduler", + active_requests: RequestList, + token_tracker: Optional[TokenBudgetTracker] = None, + ) -> tuple[RequestList, RequestList, RequestList]: + scheduled_requests: RequestList = [] + + for req in active_requests: + if not scheduler._can_be_scheduled(req): + continue + + if len(scheduled_requests) >= scheduler.max_num_requests: + break + + if ( + req.is_encoder_init_state + or req.is_context_init_state + or req.is_generation_in_progress_state + ): + if token_tracker is not None: + if not token_tracker.try_add(req): + break + if token_tracker.batch_full: + scheduled_requests.append(req) + break + scheduled_requests.append(req) + + return scheduled_requests, [], [] + + +class GuaranteedNoEvictPolicy(SchedulerPolicyBase): + """ + GuaranteedNoEvictScheduler: Reserve blocks for requests to complete without eviction. 
+ C++ reference: capacityScheduler.cpp:194-331 + """ + + def __init__(self, static_batch: bool = False): + self.static_batch = static_batch + + def schedule( + self, + scheduler: "PyCapacityScheduler", + active_requests: RequestList, + token_tracker: Optional[TokenBudgetTracker] = None, + ) -> tuple[RequestList, RequestList, RequestList]: + scheduled_requests: RequestList = [] + fitting_disagg: RequestList = [] + has_peft = scheduler.peft_cache_manager is not None + + skipping_is_relevant = scheduler._is_skipping_relevant() + + newly_contributed_context_blocks: Set = set() + newly_contributed_cross_context_blocks: Set = set() + if not self.static_batch and skipping_is_relevant: + newly_contributed_context_blocks, newly_contributed_cross_context_blocks = ( + scheduler._prefill_contributed_blocks(active_requests) + ) + + reserved_blocks = NoEvictScheduledBlocksManager(scheduler.kv_cache_manager) + reserved_cross_blocks: Optional[NoEvictScheduledBlocksManager] = None + if scheduler.cross_kv_cache_manager is not None: + reserved_cross_blocks = NoEvictScheduledBlocksManager(scheduler.cross_kv_cache_manager) + + # PEFT state - only used when has_peft + claimed_peft_pages = 0 + available_peft_pages = scheduler._get_max_peft_pages() if has_peft else 0 + uniq_task_ids: Optional[set[int]] = set() if has_peft else None + + pending_requests: RequestList = [] + pending_dis_gen_init_requests: RequestList = [] + + # Cache hot-path locals to avoid repeated attribute/method lookups + _has_tracker = token_tracker is not None + _until = scheduler._no_schedule_until_state_value + _after = scheduler._no_schedule_after_state_value + _max_num = scheduler.max_num_requests + _has_cross = reserved_cross_blocks is not None + _gen_in_progress = scheduler._gen_in_progress_state_value + _sched_append = scheduled_requests.append + _pending_append = pending_requests.append + _disagg_pending_append = pending_dis_gen_init_requests.append + num_scheduled = 0 + + # First pass: process in-progress generation and classify requests. + # Block decrements are deferred to batch_decrement_list after the loop + # for a single Python→C++ call (available_blocks is not read in first pass). + # At this point scheduled_requests contains only generation requests + # (context → pending_requests, disagg → pending_dis_gen_init_requests), + # so we pass it directly to batch_decrement_list instead of a separate list. + for req in active_requests: + # Inlined _can_be_scheduled_with_disagg_exception + is_disagg = req.is_disagg_generation_init_state + if not is_disagg: + sv = req.state_value + if not (_until <= sv < _after): + continue + + if num_scheduled >= _max_num: + break + + # sv is defined (set above on the non-disagg branch); + # replaces req.is_generation_in_progress_state to reuse cached state_value. 
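The deferred bookkeeping described in the first-pass comment above is worth seeing in isolation: the loop only classifies requests, and a single batched query settles the block accounting afterwards, so N Python→C++ crossings collapse into one. A minimal sketch under stated assumptions; FakeKVCacheManager and the dict-shaped requests are illustrative stand-ins, not the real bindings:

from dataclasses import dataclass

@dataclass
class FakeKVCacheManager:
    # Stand-in for the bound C++ manager; answers one batched query.
    blocks_per_request: int = 2

    def get_remaining_blocks_to_completion_batch(self, requests, window_size):
        return [self.blocks_per_request for _ in requests]

def first_pass(active, mgr, window_size=64):
    generation, pending = [], []
    for req in active:
        # Classify only; no per-request boundary crossing inside the loop.
        (generation if req["gen_in_progress"] else pending).append(req)
    # One batched call covers every generation request at once.
    reserved = sum(mgr.get_remaining_blocks_to_completion_batch(generation, window_size))
    return generation, pending, reserved

requests = [{"gen_in_progress": i % 2 == 0} for i in range(4)]
gen, pend, reserved = first_pass(requests, FakeKVCacheManager())
assert (len(gen), len(pend), reserved) == (2, 2, 4)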
+ if not is_disagg and sv == _gen_in_progress: + # rc: 1=continue, 0=token budget exceeded, -1=batch full + rc = 1 + if _has_tracker: + rc = token_tracker.try_add_generation(req) + if rc == 0: + break # token budget exceeded + + _sched_append(req) + num_scheduled += 1 + + if has_peft: + lora_task_id, is_new_task, peft_pages = scheduler._get_peft_task_info( + req, uniq_task_ids + ) + if is_new_task: + claimed_peft_pages += peft_pages + uniq_task_ids.add(lora_task_id) + + if rc < 0: + break # batch full (after all bookkeeping) + + elif is_disagg: + _disagg_pending_append(req) + else: + _pending_append(req) + + # Batch-decrement blocks using C++ batch API (single boundary crossing) + reserved_blocks.batch_decrement_list(scheduled_requests) + if _has_cross: + reserved_cross_blocks.batch_decrement_list(scheduled_requests) + # Sync single-window scalar back to dict before second pass reads it + reserved_blocks.sync_to_dict() + if _has_cross: + reserved_cross_blocks.sync_to_dict() + + # Second pass: process pending requests + # Skip entirely if first pass already filled the batch + if (_has_tracker and token_tracker.batch_full) or num_scheduled >= _max_num: + return scheduled_requests, fitting_disagg, [] + if not self.static_batch or num_scheduled == 0: + if has_peft: + available_peft_pages -= claimed_peft_pages + + # Disagg requests: all are disagg_generation_init — skip + # beneficial_to_skip (never applies) and route directly to + # fitting_disagg. + _fitting_disagg_append = fitting_disagg.append + for req in pending_dis_gen_init_requests: + if num_scheduled >= _max_num: + break + + if not reserved_blocks.preview_reserve(req): + break + if _has_cross: + if not reserved_cross_blocks.preview_reserve(req): + break + + if has_peft: + lora_task_id, is_new_task, needed_peft_pages = scheduler._get_peft_task_info( + req, uniq_task_ids + ) + if needed_peft_pages > available_peft_pages: + continue + available_peft_pages -= needed_peft_pages + if is_new_task: + uniq_task_ids.add(lora_task_id) + + if _has_tracker: + if not token_tracker.try_add_context(req): + break + _fitting_disagg_append(req) + num_scheduled += 1 + reserved_blocks.commit_preview() + if _has_cross: + reserved_cross_blocks.commit_preview() + if _has_tracker and token_tracker.batch_full: + break + + # Context/encoder requests: none are disagg — skip disagg + # checks and route directly to scheduled_requests. 
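The preview/commit protocol used in these pending passes generalizes beyond KV blocks: preview answers "does it fit?" and caches the cost, and commit applies the cached cost only once the later gates (PEFT, token budget) have also passed, avoiding a second query. A self-contained sketch of the pattern; the Reservation class is illustrative:

class Reservation:
    """Check-then-commit over one counter, mirroring the
    preview_reserve()/commit_preview() protocol used above."""

    def __init__(self, available: int):
        self.available = available
        self._pending = None

    def preview(self, cost: int) -> bool:
        self._pending = None
        if cost > self.available:
            return False
        self._pending = cost  # cache the cost; nothing is mutated yet
        return True

    def commit(self) -> None:
        assert self._pending is not None, "commit without successful preview"
        self.available -= self._pending
        self._pending = None

r = Reservation(available=5)
assert r.preview(3)       # fits; cost cached
r.commit()                # apply the cached decrement
assert r.available == 2
assert not r.preview(4)   # would exceed; state untouched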
+ _skip_check = not self.static_batch and skipping_is_relevant + for req in pending_requests: + if _skip_check and scheduler._beneficial_to_skip( + req, + newly_contributed_context_blocks, + newly_contributed_cross_context_blocks, + ): + continue + + if num_scheduled >= _max_num: + break + + if req.is_context_init_state: + if not reserved_blocks.preview_reserve(req): + break + if _has_cross: + if not reserved_cross_blocks.preview_reserve(req): + break + + if has_peft: + lora_task_id, is_new_task, needed_peft_pages = ( + scheduler._get_peft_task_info(req, uniq_task_ids) + ) + if needed_peft_pages > available_peft_pages: + continue + available_peft_pages -= needed_peft_pages + if is_new_task: + uniq_task_ids.add(lora_task_id) + + if _has_tracker: + if not token_tracker.try_add_context(req): + break + _sched_append(req) + num_scheduled += 1 + reserved_blocks.commit_preview() + if _has_cross: + reserved_cross_blocks.commit_preview() + if _has_tracker and token_tracker.batch_full: + break + + return scheduled_requests, fitting_disagg, [] + + +class MaxUtilizationPolicy(SchedulerPolicyBase): + """ + MaxUtilizationScheduler: Maximize utilization, may pause started requests. + C++ reference: capacityScheduler.cpp:341-425 + """ + + def schedule( + self, + scheduler: "PyCapacityScheduler", + active_requests: RequestList, + token_tracker: Optional[TokenBudgetTracker] = None, + ) -> tuple[RequestList, RequestList, RequestList]: + scheduler.kv_cache_manager.start_scheduling() + + skipping_is_relevant = scheduler._is_skipping_relevant() + + scheduled_blocks_manager = MaxUtilizationScheduledBlocksManager( + scheduler.kv_cache_manager, scheduler.two_step_lookahead + ) + + num_scheduled_peft_pages = 0 + seen_task_ids: set[int] = set() + _max_peft_pages = scheduler._get_max_peft_pages() + + newly_contributed_context_blocks, _ = scheduler._prefill_contributed_blocks(active_requests) + + # Cache hot-path locals + _until = scheduler._no_schedule_until_state_value + _after = scheduler._no_schedule_after_state_value + _max_num = scheduler.max_num_requests + _has_tracker = token_tracker is not None + # MaxUtilization doesn't pre-compute cross-context blocks (line 678 + # discards the second return). Use a reusable set cleared per + # iteration to avoid per-call allocation while matching the C++ + # semantics of not accumulating cross-context blocks across requests. 
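The reuse idiom the preceding comment describes is general: clear one container per iteration instead of allocating a fresh one on every pass. A tiny self-contained sketch with illustrative names:

scratch: set = set()

def distinct_counts(groups):
    """Count distinct values per group, reusing a single scratch set."""
    counts = []
    for values in groups:
        scratch.clear()       # reset in place; no new set object per group
        scratch.update(values)
        counts.append(len(scratch))
    return counts

assert distinct_counts([[1, 1, 2], [3], []]) == [2, 1, 0]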
+ _cross_ctx_blocks: Set = set() + + def is_started_request(req: LlmRequest) -> bool: + sv = req.state_value + if not (_until <= sv < _after): + return False + return ( + req.is_context_init_state and not req.is_first_context_chunk + ) or req.is_generation_in_progress_state + + scheduled_requests: RequestList = [] + fitting_disagg: RequestList = [] + paused_requests: RequestList = [] + num_scheduled = 0 + + requests_list = list(active_requests) + req_it_end = len(requests_list) + req_it = 0 + + while req_it < req_it_end: + req = requests_list[req_it] + + # Inlined _can_be_scheduled_with_disagg_exception + if not req.is_disagg_generation_init_state: + sv = req.state_value + if not (_until <= sv < _after): + req_it += 1 + continue + + _cross_ctx_blocks.clear() + if skipping_is_relevant and scheduler._beneficial_to_skip( + req, newly_contributed_context_blocks, _cross_ctx_blocks + ): + req_it += 1 + continue + + result, num_scheduled_peft_pages = self._try_scheduling_request( + scheduler, + req, + scheduled_requests, + fitting_disagg, + scheduled_blocks_manager, + num_scheduled_peft_pages, + _max_peft_pages, + num_scheduled, + _max_num, + seen_task_ids, + token_tracker, + ) + + if result is True: + num_scheduled += 1 + if _has_tracker and token_tracker.batch_full: + break + req_it += 1 + elif result is None: + # Token budget exhausted — pausing won't help, stop. + break + else: + # Capacity failure — try pausing an older request. + last_started_idx = None + for i in range(req_it_end - 1, req_it - 1, -1): + if is_started_request(requests_list[i]): + last_started_idx = i + break + + if last_started_idx is not None: + paused_req = requests_list[last_started_idx] + scheduler.kv_cache_manager.scheduling_remove_sequence(paused_req.py_request_id) + paused_requests.append(paused_req) + # Don't decrement num_scheduled: the paused request is at + # index >= req_it (unprocessed), so it was never counted. + req_it_end = last_started_idx + else: + break + + return scheduled_requests, fitting_disagg, paused_requests + + def _try_scheduling_request( + self, + scheduler: "PyCapacityScheduler", + req: LlmRequest, + scheduled_requests: RequestList, + fitting_disagg: RequestList, + scheduled_blocks_manager: "MaxUtilizationScheduledBlocksManager", + num_scheduled_peft_pages: int, + max_peft_pages: int, + num_scheduled: int, + max_num_requests: int, + seen_task_ids: set[int], + token_tracker: Optional[TokenBudgetTracker] = None, + ) -> tuple[Optional[bool], int]: + """Try to schedule a request. + + Returns a tuple of (result, num_scheduled_peft_pages): + result is True: request scheduled successfully. + result is False: capacity failure (KV blocks, PEFT, max_num_requests) + — caller may pause an older request and retry. + result is None: token budget exhausted — caller should stop scheduling + (pausing won't help because it doesn't free token budget). + num_scheduled_peft_pages: updated running total of scheduled PEFT + pages (mirrors C++ pass-by-reference semantics in + capacityScheduler.cpp:429). + """ + if num_scheduled >= max_num_requests: + return False, num_scheduled_peft_pages + + if scheduled_blocks_manager.prepare_blocks_if_schedulable(req) is None: + return False, num_scheduled_peft_pages + + # PEFT check only when needed — compute required pages but do NOT + # commit (add to seen_task_ids / accumulate pages) until all checks + # pass. Matches C++ which commits atomically on success only. 
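The comment above states an atomic-commit rule: compute every requirement first, and mutate nothing until all gates pass. Reduced to a small sketch, with admit and its two resource pools as illustrative stand-ins for the KV-block and PEFT-page checks:

def admit(cost_blocks: int, cost_pages: int, free_blocks: int, free_pages: int):
    """Admit only if every resource fits; decrement both pools atomically.

    Returns (admitted, new_free_blocks, new_free_pages). A failed gate
    leaves both pools unchanged: there is no partial decrement to unwind.
    """
    if cost_blocks > free_blocks:
        return False, free_blocks, free_pages
    if cost_pages > free_pages:
        return False, free_blocks, free_pages
    # Every check passed: apply all decrements together.
    return True, free_blocks - cost_blocks, free_pages - cost_pages

ok, blocks, pages = admit(2, 1, free_blocks=4, free_pages=1)
assert ok and (blocks, pages) == (2, 0)
ok, blocks, pages = admit(2, 1, blocks, pages)   # page gate fails now
assert not ok and (blocks, pages) == (2, 0)      # nothing was decremented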
+ _peft_lora_task_id = 0 + _peft_is_new_task = False + _peft_required_pages = 0 + if scheduler.peft_cache_manager is not None: + _peft_lora_task_id, _peft_is_new_task, _peft_required_pages = ( + scheduler._get_peft_task_info(req, seen_task_ids) + ) + if _peft_required_pages + num_scheduled_peft_pages > max_peft_pages: + return False, num_scheduled_peft_pages + + # Token budget check — return None (not False) so the caller + # does NOT enter the pause/backtrack path. Pausing frees KV + # blocks but not token budget, so retrying would fail again. + if token_tracker is not None: + if not token_tracker.try_add(req): + return None, num_scheduled_peft_pages + + # All checks passed — commit all state atomically. + scheduled_blocks_manager.update_scheduled_blocks() + if _peft_is_new_task: + seen_task_ids.add(_peft_lora_task_id) + num_scheduled_peft_pages += _peft_required_pages + if req.is_disagg_generation_init_state: + fitting_disagg.append(req) + else: + scheduled_requests.append(req) + return True, num_scheduled_peft_pages + + +class NoEvictScheduledBlocksManager: + """ + Python equivalent of C++ kv_cache_manager::NoEvictScheduledBlocksManager. + Tracks available blocks per window size for GUARANTEED_NO_EVICT scheduling. + + Includes single-window fast path: when only one window size exists + (the common case), avoids dict iteration overhead. + + Reference: cpp/tensorrt_llm/batch_manager/scheduledBlocksManager.h:29-62 + """ + + def __init__(self, kv_cache_manager): + self.kv_cache_manager = kv_cache_manager + stats = kv_cache_manager.get_kv_cache_stats() + self.available_blocks: dict[int, int] = dict(stats.num_free_blocks_per_window_size) + self._preview_valid = False + # Single-window fast path: avoid dict iteration when only one window + if len(self.available_blocks) == 1: + ws, avail = next(iter(self.available_blocks.items())) + self._single_ws = ws + self._single_avail = avail + else: + self._single_ws = None + self._single_avail = 0 + + def batch_decrement_list(self, requests: RequestList) -> None: + """Batch-decrement blocks for a list of requests using C++ batch API. + + Uses get_remaining_blocks_to_completion_batch for a single Python→C++ + call instead of N individual calls. + """ + if not requests: + return + if self._single_ws is not None: + needed_list = self.kv_cache_manager.get_remaining_blocks_to_completion_batch( + requests, self._single_ws + ) + self._single_avail -= sum(needed_list) + else: + for ws in self.available_blocks: + needed_list = self.kv_cache_manager.get_remaining_blocks_to_completion_batch( + requests, ws + ) + self.available_blocks[ws] -= sum(needed_list) + + def sync_to_dict(self) -> None: + """Write single-window scalar back to dict. Call before dict is read.""" + if self._single_ws is not None: + self.available_blocks[self._single_ws] = self._single_avail + + def preview_reserve(self, req: LlmRequest) -> bool: + """Check if request fits (no mutation). Caches needed blocks for commit_preview. + + Call commit_preview() after all intermediate checks (PEFT, token) pass + to apply the cached decrement. This avoids a second C++ call. 
+ """ + self._preview_valid = False + if self._single_ws is not None: + needed = self.kv_cache_manager.get_remaining_blocks_to_completion(req, self._single_ws) + if needed > self._single_avail: + return False + self._preview_needed_single = needed + self._preview_valid = True + return True + needed_per_ws = {} + for ws, avail in self.available_blocks.items(): + needed = self.kv_cache_manager.get_remaining_blocks_to_completion(req, ws) + if needed > avail: + return False + needed_per_ws[ws] = needed + self._preview_needed_multi = needed_per_ws + self._preview_valid = True + return True + + def commit_preview(self) -> None: + """Apply the cached decrement from the last preview_reserve call.""" + assert self._preview_valid, "commit_preview called without a successful preview_reserve" + self._preview_valid = False + if self._single_ws is not None: + self._single_avail -= self._preview_needed_single + return + for ws, needed in self._preview_needed_multi.items(): + self.available_blocks[ws] -= needed + + +class MaxUtilizationScheduledBlocksManager: + """ + Python equivalent of C++ kv_cache_manager::MaxUtilizationScheduledBlocksManager. + Tracks scheduled blocks per window size for MAX_UTILIZATION scheduling. + + Includes single-window fast path for the common case. + + Reference: cpp/tensorrt_llm/batch_manager/scheduledBlocksManager.h:64-117 + """ + + def __init__(self, kv_cache_manager, two_steps_look_ahead: bool): + self.kv_cache_manager = kv_cache_manager + self.two_steps_look_ahead = two_steps_look_ahead + window_sizes = set(kv_cache_manager.max_attention_window_vec) + self.num_scheduled_blocks: dict[int, int] = {ws: 0 for ws in window_sizes} + self._pending_total: int = 0 + self._pending_blocks: dict[int, int] = {} + # Single-window fast path + if len(self.num_scheduled_blocks) == 1: + self._single_ws: Optional[int] = next(iter(self.num_scheduled_blocks)) + self._single_scheduled: int = 0 + else: + self._single_ws = None + self._single_scheduled = 0 + + def prepare_blocks_if_schedulable(self, req: LlmRequest) -> Optional[bool]: + """Check if request can be scheduled. Returns True or None (can't fit). + + For single-window: returns True and caches the scheduled_total + internally. Call update_scheduled_blocks() to commit. + For multi-window: returns a dict of new block counts (legacy). 
+ """ + if self._single_ws is not None: + required = self.kv_cache_manager.get_needed_blocks_one_step( + req, self.two_steps_look_ahead, self._single_ws + ) + scheduled_total = self._single_scheduled + required + if not self.kv_cache_manager.scheduling_has_free_blocks( + scheduled_total, self._single_ws + ): + return None + self._pending_total = scheduled_total + return True + # Multi-window path + blocks_if_scheduled = {} + for window_size, num_scheduled in self.num_scheduled_blocks.items(): + required = self.kv_cache_manager.get_needed_blocks_one_step( + req, self.two_steps_look_ahead, window_size + ) + scheduled_total = num_scheduled + required + if not self.kv_cache_manager.scheduling_has_free_blocks(scheduled_total, window_size): + return None + blocks_if_scheduled[window_size] = scheduled_total + self._pending_blocks = blocks_if_scheduled + return True + + def update_scheduled_blocks(self) -> None: + """Commit the block counts from the last prepare_blocks_if_schedulable call.""" + if self._single_ws is not None: + self._single_scheduled = self._pending_total + self.num_scheduled_blocks[self._single_ws] = self._single_scheduled + return + for window_size, total in self._pending_blocks.items(): + self.num_scheduled_blocks[window_size] = total + + +class PyCapacityScheduler: + """KV cache capacity scheduler with optional fused token budget tracking. + + Python implementation based on the C++ CapacityScheduler. Core KV-block + logic follows cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp. + + Extension: accepts an optional TokenBudgetTracker via schedule_request() + to fuse token-budget checks into the capacity loop (single-pass + scheduling). When no tracker is provided, behaves identically to the + C++ implementation. + + Policies: + - MaxRequestsScheduler: No KV cache manager, simple request count limit + - GuaranteedNoEvictScheduler: Reserve blocks for completion, no eviction + - StaticBatchScheduler: Only schedule when no requests are active + - MaxUtilizationScheduler: Maximize utilization, may pause requests + """ + + def __init__( + self, + max_num_requests: int, + kv_cache_manager=None, + peft_cache_manager=None, + scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, + cross_kv_cache_manager=None, + two_step_lookahead: bool = False, + no_schedule_until_state: LlmRequestState = LlmRequestState.CONTEXT_INIT, + no_schedule_after_state: LlmRequestState = LlmRequestState.GENERATION_COMPLETE, + ): + self.max_num_requests = max_num_requests + self.kv_cache_manager = kv_cache_manager + self.peft_cache_manager = peft_cache_manager + self.cross_kv_cache_manager = cross_kv_cache_manager + self.scheduler_policy = scheduler_policy + self.two_step_lookahead = two_step_lookahead + self.no_schedule_until_state = no_schedule_until_state + self.no_schedule_after_state = no_schedule_after_state + # Cache state values to avoid repeated .value access (optimization) + self._no_schedule_until_state_value = no_schedule_until_state.value + self._no_schedule_after_state_value = no_schedule_after_state.value + self._gen_in_progress_state_value = LlmRequestState.GENERATION_IN_PROGRESS.value + + # Initialize the appropriate policy + self._policy = self._create_policy() + + def _create_policy(self) -> SchedulerPolicyBase: + """Create the appropriate policy based on configuration.""" + if self.kv_cache_manager is None: + return MaxRequestsPolicy() + elif self.scheduler_policy == CapacitySchedulerPolicy.MAX_UTILIZATION: + return MaxUtilizationPolicy() + elif 
self.scheduler_policy == CapacitySchedulerPolicy.GUARANTEED_NO_EVICT: + return GuaranteedNoEvictPolicy(static_batch=False) + elif self.scheduler_policy == CapacitySchedulerPolicy.STATIC_BATCH: + return GuaranteedNoEvictPolicy(static_batch=True) + else: + raise ValueError(f"Unsupported scheduler policy: {self.scheduler_policy}") + + def _can_be_scheduled(self, req: LlmRequest) -> bool: + """ + Check if request is within the schedulable state range. + Returns True if request has reached no_schedule_until_state + but has not yet reached no_schedule_after_state. + Optimized: use state_value property to avoid enum object creation + """ + # Use state_value property (returns int directly, avoids enum object creation) + state_value = req.state_value + # Inline comparison: must have reached until_state but not after_state + return ( + state_value >= self._no_schedule_until_state_value + and state_value < self._no_schedule_after_state_value + ) + + def _is_skipping_relevant(self) -> bool: + """ + Check if block reuse skip optimization is relevant. + Disabled for VSWA (Variable Sliding Window Attention). + C++ reference: capacityScheduler.cpp:207-208, 348 + """ + if self.kv_cache_manager is None: + return False + if self.kv_cache_manager.is_variable_window: + return False + if ( + self.cross_kv_cache_manager is not None + and self.cross_kv_cache_manager.is_variable_window + ): + return False + return True + + def _prefill_contributed_blocks(self, active_requests: RequestList) -> tuple[set, set]: + """ + Collect blocks contributed by chunked context requests already executing. + These blocks can be reused by later requests. + + C++ reference: capacityScheduler.cpp:34-68 (prefillWithChunkedContextsAlreadyExecuting) + """ + newly_contributed_context_blocks: Set = set() + newly_contributed_cross_context_blocks: Set = set() + + if self.kv_cache_manager is None: + return newly_contributed_context_blocks, newly_contributed_cross_context_blocks + + enable_block_reuse = self.kv_cache_manager.enable_block_reuse + cross_enable_reuse = ( + self.cross_kv_cache_manager is not None + and self.cross_kv_cache_manager.enable_block_reuse + ) + + for req in active_requests: + # Check: isContextInitState() && !isFirstContextChunk() + if req.is_context_init_state and not req.is_first_context_chunk: + # Chunked context request already executing + if enable_block_reuse: + unique_tokens = req.get_unique_tokens(0) + block_key = self.kv_cache_manager.find_new_context_block(unique_tokens, req) + if block_key is not None: + newly_contributed_context_blocks.add(block_key) + + if cross_enable_reuse: + encoder_unique_tokens = req.get_encoder_unique_tokens() + if encoder_unique_tokens is not None: + block_key = self.cross_kv_cache_manager.find_new_context_block( + encoder_unique_tokens, req + ) + if block_key is not None: + newly_contributed_cross_context_blocks.add(block_key) + + return newly_contributed_context_blocks, newly_contributed_cross_context_blocks + + def _one_manager_beneficial_to_skip( + self, kv_cache_manager, unique_tokens, req: LlmRequest, newly_contributed_blocks: set + ) -> bool: + """ + Check if skipping is beneficial for one KV cache manager. 
+ C++ reference: capacityScheduler.cpp:70-92 (oneManagerBeneficialToSkip) + """ + new_context_block = kv_cache_manager.find_new_context_block(unique_tokens, req) + if new_context_block is not None: + if new_context_block in newly_contributed_blocks: + return True + newly_contributed_blocks.add(new_context_block) + return False + + def _beneficial_to_skip( + self, + req: LlmRequest, + newly_contributed_context_blocks: set, + newly_contributed_cross_context_blocks: set, + ) -> bool: + """ + Check if it's beneficial to skip this request. + A request should be skipped if it can reuse blocks contributed by + already scheduled context requests. + + C++ reference: capacityScheduler.cpp:97-123 (beneficialToSkip) + """ + if not (req.is_context_init_state and req.is_first_context_chunk): + return False + + if self.kv_cache_manager is not None and self.kv_cache_manager.enable_block_reuse: + unique_tokens = req.get_unique_tokens(0) + if self._one_manager_beneficial_to_skip( + self.kv_cache_manager, unique_tokens, req, newly_contributed_context_blocks + ): + return True + + if ( + self.cross_kv_cache_manager is not None + and self.cross_kv_cache_manager.enable_block_reuse + ): + encoder_unique_tokens = req.get_encoder_unique_tokens() + if encoder_unique_tokens is not None: + if self._one_manager_beneficial_to_skip( + self.cross_kv_cache_manager, + encoder_unique_tokens, + req, + newly_contributed_cross_context_blocks, + ): + return True + + return False + + def _get_max_peft_pages(self) -> int: + """Get maximum PEFT cache pages.""" + if self.peft_cache_manager is None: + return 2**31 - 1 # INT_MAX equivalent + return self.peft_cache_manager.max_device_pages + + def _get_peft_pages_for_request(self, req: LlmRequest) -> int: + """Get PEFT pages needed for a request.""" + if self.peft_cache_manager is None: + return 0 + return self.peft_cache_manager.determine_num_pages(req) + + def _get_peft_task_info( + self, req: LlmRequest, seen_task_ids: set[int] + ) -> tuple[Optional[int], bool, int]: + """ + Get PEFT task information for a request. + Returns (lora_task_id, is_new_task, required_pages). + """ + lora_task_id = getattr(req, "lora_task_id", None) + is_new_task = lora_task_id is not None and lora_task_id not in seen_task_ids + required_pages = self._get_peft_pages_for_request(req) if is_new_task else 0 + return lora_task_id, is_new_task, required_pages + + def schedule_request( + self, + active_requests: RequestList, + token_tracker: Optional["TokenBudgetTracker"] = None, + ) -> tuple[RequestList, RequestList, RequestList]: + """ + Schedule requests based on the configured policy. + + Args: + active_requests: List of active requests to consider + token_tracker: If provided, fuses token-budget checks into + the capacity policy loop (single-pass scheduling). + + Returns: + Tuple of (fitting_requests, fitting_disagg_gen_init_requests, paused_requests) + + C++ reference: capacityScheduler.cpp:488-539 (CapacityScheduler::operator()) + """ + fitting_requests, fitting_disagg_gen_init_requests, paused = self._policy.schedule( + self, active_requests, token_tracker + ) + + logger.debug( + f"[Summary] Capacity scheduler allows {len(fitting_requests)} requests, " + f"pauses {len(paused)} requests" + ) + + return fitting_requests, fitting_disagg_gen_init_requests, paused + + +class UnifiedScheduler(RequestScheduler): + """Python-only scheduler — drop-in replacement for SimpleScheduler. 
+ + Replaces the two-pass pipeline in SimpleScheduler (C++ bindings: + BindCapacityScheduler → BindMicroBatchScheduler) with a single-pass + fused approach. Gated by SchedulerConfig(use_python_scheduler=True). + + Implements the same schedule_request() interface as SimpleScheduler, + so py_executor.py uses the same code path for both schedulers. + + Key difference: capacity and token-budget checks run in one loop via + TokenBudgetTracker, instead of capacity first then microbatch second. + This eliminates one full iteration over fitting_requests and reduces + scheduler-side KV bookkeeping on requests that would be dropped by the + token budget. + """ + + def __init__( + self, + max_batch_size: int, + max_num_tokens: int, + kv_cache_manager, + peft_cache_manager, + scheduler_policy: CapacitySchedulerPolicy, + ctx_chunk_config: Optional[tuple[StrEnum, int]] = None, + cross_kv_cache_manager=None, + two_step_lookahead: bool = False, + scheduler_capacity: Optional[int] = None, + schedule_step_config: Optional[ScheduleStepConfig] = None, + dist=None, + ): + super(UnifiedScheduler, self).__init__(schedule_step_config=schedule_step_config, dist=dist) + # Use scheduler_capacity if provided, otherwise fall back to max_batch_size + # scheduler_capacity may differ from max_batch_size (e.g., adjusted for attention_dp + disagg) + capacity = scheduler_capacity if scheduler_capacity is not None else max_batch_size + + self.capacity_scheduler = PyCapacityScheduler( + max_num_requests=capacity, + kv_cache_manager=kv_cache_manager, + peft_cache_manager=peft_cache_manager, + scheduler_policy=scheduler_policy, + cross_kv_cache_manager=cross_kv_cache_manager, + two_step_lookahead=two_step_lookahead, + ) + + self._max_batch_size = max_batch_size + self._max_num_tokens = max_num_tokens + self._ctx_chunk_config = None + if ctx_chunk_config: + input_policy = ctx_chunk_config[0] + if "EQUAL_PROGRESS" in str(input_policy): + policy_enum = ChunkingPolicy.EQUAL_PROGRESS + else: + policy_enum = ChunkingPolicy.FIRST_COME_FIRST_SERVED + self._ctx_chunk_config = ContextChunkingConfig(policy_enum, ctx_chunk_config[1]) + + def schedule_request( + self, active_requests: RequestList, inflight_request_ids: object = None + ) -> SchedulerOutput: + """Single-pass fused capacity + token budget scheduling. + + A TokenBudgetTracker is passed into the capacity policy loop so that + each request admission check includes both KV-block and token-budget + gates simultaneously. The tracker classifies requests into + context/generation and handles chunking/sorting. + """ + tracker = TokenBudgetTracker( + max_batch_size=self._max_batch_size, + max_num_tokens=self._max_num_tokens, + ctx_chunk_config=self._ctx_chunk_config, + inflight_request_ids=inflight_request_ids, + ) + _, fitting_disagg_gen_init, paused_requests = self.capacity_scheduler.schedule_request( + active_requests, tracker + ) + # num_fitting reflects requests passing both capacity AND token budget. 
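The fusion this class performs can be shown in miniature: one loop consults both a block gate and a token gate per request, rather than filtering on blocks first and then re-walking the survivors for tokens. A minimal sketch with illustrative names (fused_schedule, tuple-shaped requests):

def fused_schedule(requests, free_blocks: int, token_budget: int):
    """One pass; both gates are checked at each admission."""
    admitted = []
    for blocks, tokens in requests:   # (kv-block cost, token cost) pairs
        if blocks > free_blocks:      # capacity gate
            continue
        if tokens > token_budget:     # token gate, in the same iteration
            break                     # budget exhausted: stop, as try_add does
        free_blocks -= blocks
        token_budget -= tokens
        admitted.append((blocks, tokens))
    return admitted

# The third request trips the token gate; no block bookkeeping is done for it.
assert fused_schedule([(1, 4), (1, 4), (1, 4)], free_blocks=10, token_budget=8) == [
    (1, 4),
    (1, 4),
]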
+ context_requests, generation_requests, num_fitting = tracker.finalize() + return SchedulerOutput( + context_requests=context_requests, + generation_requests=generation_requests, + paused_requests=paused_requests, + fitting_disagg_gen_init_requests=fitting_disagg_gen_init, + num_fitting_requests=num_fitting, + ) + + def can_schedule(self, requests: RequestList) -> bool: + fitting, _, _ = self.capacity_scheduler.schedule_request(requests) + return len(fitting) == len(requests) diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/waiting_queue.py b/tensorrt_llm/_torch/pyexecutor/scheduler/waiting_queue.py index f35db02471b5..ea999487044d 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler/waiting_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler/waiting_queue.py @@ -3,10 +3,9 @@ from collections.abc import Iterable, Iterator from typing import Callable, Optional +from tensorrt_llm._torch.pyexecutor.executor_request_queue import RequestQueueItem from tensorrt_llm.llmapi.llm_args import WaitingQueuePolicy -from ..executor_request_queue import RequestQueueItem - class WaitingQueue(ABC): """Abstract base class for waiting queues.""" diff --git a/tensorrt_llm/tools/profiler/host_profile_tools/host_profiler.py b/tensorrt_llm/tools/profiler/host_profile_tools/host_profiler.py index 847bd361cc10..0856d2d8648a 100644 --- a/tensorrt_llm/tools/profiler/host_profile_tools/host_profiler.py +++ b/tensorrt_llm/tools/profiler/host_profile_tools/host_profiler.py @@ -44,6 +44,11 @@ # Format: "module.Class.method,module.Class.method2,..." LINE_PROFILER_FUNCTIONS_ENV_VAR = "TLLM_LINE_PROFILER_FUNCTIONS" +# Environment variable to disable default profile targets. +# When set (to any value), default targets are not loaded; +# only explicitly specified targets (via TLLM_LINE_PROFILER_FUNCTIONS or API) are used. 
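For reference, this is the usual shape of such an opt-out gate: presence of the variable, with any non-empty value, wins over the caller's flag. A minimal sketch; the helper name is illustrative, and the patch applies the same check inline:

import os

def resolve_use_defaults(requested: bool) -> bool:
    # Presence of the variable (any non-empty value) overrides the flag.
    return requested and not os.environ.get("TLLM_LINE_PROFILER_NO_DEFAULTS")

os.environ.pop("TLLM_LINE_PROFILER_NO_DEFAULTS", None)
assert resolve_use_defaults(True) is True
os.environ["TLLM_LINE_PROFILER_NO_DEFAULTS"] = "1"
assert resolve_use_defaults(True) is False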
+LINE_PROFILER_NO_DEFAULTS_ENV_VAR = "TLLM_LINE_PROFILER_NO_DEFAULTS" + @dataclass class ProfileTarget: @@ -152,10 +157,6 @@ def resolve(self) -> Optional[Callable]: "_select_generated_logits", "_sample_batched_by_strategy", ], - # Standalone module-level functions (use None as class_name) - None: [ - "_group_requests_by_strategy_key", - ], }, f"{_PYEXEC}.resource_manager": { "ResourceManager": ["prepare_resources", "update_resources", "free_resources"], @@ -164,15 +165,6 @@ def resolve(self) -> Optional[Callable]: f"{_PYEXEC}.scheduler": { "RequestScheduler": ["schedule_request"], }, - f"{_PYEXEC}.executor_request_queue": { - "ExecutorRequestQueue": [ - "_fetch_new_requests_attention_tp", - "_fetch_new_requests_attention_dp", - "_fetch_and_process_requests", - "_merge_requests", - "fetch_new_requests", - ], - }, } @@ -401,6 +393,10 @@ def __init__( self._line_profiler = None self._enabled = False + # When LINE_PROFILER_NO_DEFAULTS_ENV_VAR is set, disable default targets + if os.environ.get(LINE_PROFILER_NO_DEFAULTS_ENV_VAR): + use_defaults = False + # Add default targets if requested if use_defaults: self.targets.extend(DEFAULT_PROFILE_TARGETS) diff --git a/tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py b/tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py index a2488b8a1474..b914b6f4c87f 100644 --- a/tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py +++ b/tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py @@ -252,6 +252,17 @@ def test_max_num_tokens_none(self): sched = make_scheduler(mgr, max_num_tokens=None) assert sched.max_num_tokens is None + def test_schedule_step_returns_step_result(self): + mgr = make_kv_cache_manager() + sched = make_scheduler(mgr) + req = make_ctx_request(0, context_remaining_length=16) + + step_result = sched.schedule_step([req], set()) + + assert ids(step_result.scheduled_requests.context_requests) == [0] + assert ids(step_result.scheduled_requests.generation_requests) == [] + assert step_result.num_fitting_requests == 1 + # =========================================================================== # Token Budget Limits diff --git a/tests/unittest/_torch/executor/test_py_scheduler.py b/tests/unittest/_torch/executor/test_py_scheduler.py index 493a92a1cb9f..aa695a56d9dc 100644 --- a/tests/unittest/_torch/executor/test_py_scheduler.py +++ b/tests/unittest/_torch/executor/test_py_scheduler.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Unit tests for Python scheduler implementations (PyMicroBatchScheduler, -PyCapacityScheduler, SimpleUnifiedScheduler). +Unit tests for Python scheduler implementations (TokenBudgetTracker, +PyCapacityScheduler, UnifiedScheduler). 
These tests validate the pure-Python scheduler logic using real LlmRequest objects (from C++ bindings) and mock KV cache managers, without requiring @@ -25,18 +25,55 @@ from dataclasses import dataclass, field from typing import List, Optional +from unittest.mock import Mock from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest, LlmRequestState, SamplingConfig from tensorrt_llm._torch.pyexecutor.scheduler.scheduler import ( + RequestScheduler, + SchedulerOutput, + ScheduleStepConfig, +) +from tensorrt_llm._torch.pyexecutor.scheduler.simple_scheduler import SimpleScheduler +from tensorrt_llm._torch.pyexecutor.scheduler.unified_scheduler import ( ChunkingPolicy, ContextChunkingConfig, PyCapacityScheduler, - PyMicroBatchScheduler, - SimpleUnifiedScheduler, + TokenBudgetTracker, + UnifiedScheduler, ) from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy +def _schedule_with_tracker( + max_batch_size: int, + max_num_tokens: Optional[int], + requests: list, + inflight_ids: set, + ctx_chunk_config=None, + max_context_length: Optional[int] = None, +) -> tuple[list, list]: + """Helper that mimics the old PyMicroBatchScheduler.schedule() interface. + + Creates a TokenBudgetTracker, adds each request via try_add(), then + finalizes to get (context_requests, generation_requests). + """ + tracker = TokenBudgetTracker( + max_batch_size=max_batch_size, + max_num_tokens=max_num_tokens, + ctx_chunk_config=ctx_chunk_config, + inflight_request_ids=inflight_ids if inflight_ids else None, + ) + if max_context_length is not None: + tracker.max_context_length = max_context_length + for req in requests: + if not tracker.try_add(req): + break + if tracker.batch_full: + break + ctx, gen, _ = tracker.finalize() + return ctx, gen + + def _make_request( request_id: int, prompt_len: int = 10, @@ -174,6 +211,9 @@ def get_max_resource_count(self) -> int: def get_needed_resource_to_completion(self, req) -> int: return self._blocks_per_request + def get_remaining_blocks_to_completion_batch(self, requests, window_size: int) -> list: + return [self._blocks_per_request for _ in requests] + class MockPeftCacheManager: def __init__(self, max_pages: int = 100, pages_per_request: int = 10): @@ -184,28 +224,58 @@ def determine_num_pages(self, req) -> int: return self._pages_per_request +class _FakeScheduler(RequestScheduler): + def __init__(self, output, schedule_step_config=None, dist=None): + super().__init__(schedule_step_config=schedule_step_config, dist=dist) + self._output = output + + def schedule_request(self, active_requests, inflight_request_ids): + return self._output + + def can_schedule(self, requests): + return True + + +class _StubCapacityScheduler: + def __init__(self, fitting_requests, disagg_requests=None, paused_requests=None): + self._fitting_requests = fitting_requests + self._disagg_requests = disagg_requests or [] + self._paused_requests = paused_requests or [] + + def schedule_request(self, active_requests): + return self._fitting_requests, self._disagg_requests, self._paused_requests + + +class _StubMicroBatchScheduler: + def __init__(self, context_requests, generation_requests): + self._context_requests = context_requests + self._generation_requests = generation_requests + + def schedule(self, active_requests, inflight_request_ids): + return self._context_requests, self._generation_requests + + # ############################################################################ # -# Part 1: PyMicroBatchScheduler Tests +# Part 1: TokenBudgetTracker Tests (formerly PyMicroBatchScheduler) # # 
############################################################################ -class TestPyMicroBatchSchedulerBasic: +class TestTokenBudgetTrackerBasic: """ - Tests for PyMicroBatchScheduler — single-step scheduling decisions. + Tests for TokenBudgetTracker — single-step scheduling decisions. Aligned with C++ MicroBatchSchedulerTest in microBatchSchedulerTest.cpp. """ def test_simple_context_only(self): """All requests are context requests, batch size allows 2.""" - scheduler = PyMicroBatchScheduler(max_batch_size=2, max_num_tokens=None) requests = [ make_context_request(0, prompt_len=10), make_context_request(1, prompt_len=10), make_context_request(2, prompt_len=10), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(2, None, requests, set()) assert len(ctx) == 2 assert len(gen) == 0 assert ctx[0].request_id == 0 @@ -213,13 +283,12 @@ def test_simple_context_only(self): def test_simple_generation_only(self): """All requests are generation requests, batch size allows 2.""" - scheduler = PyMicroBatchScheduler(max_batch_size=2, max_num_tokens=None) requests = [ make_generation_request(0), make_generation_request(1), make_generation_request(2), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(2, None, requests, set()) assert len(ctx) == 0 assert len(gen) == 2 assert gen[0].request_id == 0 @@ -230,14 +299,13 @@ def test_context_generation_overlap(self): Mixed batch: context + generation requests. C++ ref: SimpleWithOverlap """ - scheduler = PyMicroBatchScheduler(max_batch_size=4, max_num_tokens=None) requests = [ make_context_request(0, prompt_len=10), make_generation_request(1), make_context_request(2, prompt_len=10), make_generation_request(3), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(4, None, requests, set()) assert len(ctx) == 2 assert len(gen) == 2 assert {r.request_id for r in ctx} == {0, 2} @@ -248,13 +316,12 @@ def test_max_num_tokens_limits_context(self): max_num_tokens limits how many context tokens can be scheduled. C++ ref: SimpleNoOverlapMaxNumTokens """ - scheduler = PyMicroBatchScheduler(max_batch_size=4, max_num_tokens=15) # Each context request has 10 tokens. Two would be 20 > 15. requests = [ make_context_request(0, prompt_len=10), make_context_request(1, prompt_len=10), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(4, 15, requests, set()) # Only 1 fits within token budget assert len(ctx) == 1 assert ctx[0].request_id == 0 @@ -264,26 +331,24 @@ def test_max_num_tokens_allows_gen_after_context(self): After scheduling a context request, generation requests still fit if their token count (beam_width) fits in remaining budget. 
""" - scheduler = PyMicroBatchScheduler(max_batch_size=4, max_num_tokens=12) requests = [ make_context_request(0, prompt_len=10), make_generation_request(1, beam_width=1), make_generation_request(2, beam_width=1), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(4, 12, requests, set()) # context: 10 tokens, gen1: 1 token, gen2: 1 token => total 12 assert len(ctx) == 1 assert len(gen) == 2 def test_max_batch_size_limits_total(self): """Batch size limits total (context + generation).""" - scheduler = PyMicroBatchScheduler(max_batch_size=2, max_num_tokens=None) requests = [ make_context_request(0, prompt_len=5), make_generation_request(1), make_generation_request(2), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(2, None, requests, set()) # batch_size=2: should schedule context_0 + gen_1 assert len(ctx) + len(gen) == 2 @@ -292,14 +357,13 @@ def test_beam_width_1(self): Generation requests with beam_width=1 each cost 1 token. C++ ref: SimpleMaxNumTokensBW1 """ - scheduler = PyMicroBatchScheduler(max_batch_size=4, max_num_tokens=12) requests = [ make_context_request(0, prompt_len=10, beam_width=1), make_generation_request(1, beam_width=1), make_generation_request(2, beam_width=1), make_generation_request(3, beam_width=1), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(4, 12, requests, set()) # context: 10, gen: 1+1 = 12 total. Can't fit gen_3 (would be 13). assert len(ctx) == 1 assert len(gen) == 2 @@ -309,13 +373,12 @@ def test_beam_width_4(self): Generation requests with beam_width=4 each cost 4 tokens. C++ ref: SimpleMaxNumTokensBW4 """ - scheduler = PyMicroBatchScheduler(max_batch_size=4, max_num_tokens=15) requests = [ make_context_request(0, prompt_len=10, beam_width=4), make_generation_request(1, beam_width=4), make_generation_request(2, beam_width=4), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(4, 15, requests, set()) # context: 10, gen1: 4 = 14. gen2: +4 = 18 > 15. assert len(ctx) == 1 assert len(gen) == 1 @@ -325,13 +388,12 @@ def test_beam_width_mismatch_skipped(self): Generation requests with different beam widths are skipped. C++ ensures all gen requests in a batch have same beam_width. """ - scheduler = PyMicroBatchScheduler(max_batch_size=4, max_num_tokens=None) requests = [ make_generation_request(0, beam_width=1), make_generation_request(1, beam_width=4), make_generation_request(2, beam_width=1), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(4, None, requests, set()) # gen_0 sets beam_width=1, gen_1 is skipped (beam_width=4), gen_2 fits assert len(gen) == 2 assert gen[0].request_id == 0 @@ -342,13 +404,12 @@ def test_draft_tokens_count_toward_budget(self): Draft tokens are added to the token count for both context and gen. C++ ref: DraftTokensMaxNumTokens """ - scheduler = PyMicroBatchScheduler(max_batch_size=4, max_num_tokens=15) # Context request: 10 prompt + 3 draft = 13 tokens requests = [ make_context_request(0, prompt_len=10, draft_tokens_len=3), make_generation_request(1, draft_tokens_len=2), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(4, 15, requests, set()) # context: 10+3=13, gen: 1+2=3, total=16 > 15 => only context fits assert len(ctx) == 1 assert len(gen) == 0 @@ -358,25 +419,23 @@ def test_gen_draft_tokens(self): Generation with draft tokens: cost = beam_width + num_draft_tokens. 
C++ ref: GenDraftTokensMaxNumTokens """ - scheduler = PyMicroBatchScheduler(max_batch_size=4, max_num_tokens=10) requests = [ make_generation_request(0, beam_width=1, draft_tokens_len=3), make_generation_request(1, beam_width=1, draft_tokens_len=3), make_generation_request(2, beam_width=1, draft_tokens_len=3), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(4, 10, requests, set()) # Each gen costs 1+3=4. Two fit (8), three don't (12 > 10). assert len(gen) == 2 def test_inflight_requests_excluded(self): """Requests already in flight are skipped.""" - scheduler = PyMicroBatchScheduler(max_batch_size=4, max_num_tokens=None) requests = [ make_context_request(0, prompt_len=10), make_context_request(1, prompt_len=10), make_generation_request(2), ] - ctx, gen = scheduler.schedule(requests, {0, 2}) + ctx, gen = _schedule_with_tracker(4, None, requests, {0, 2}) # Only request 1 is not in flight assert len(ctx) == 1 assert ctx[0].request_id == 1 @@ -384,13 +443,12 @@ def test_inflight_requests_excluded(self): def test_completed_requests_filtered(self): """Requests in GENERATION_COMPLETE state are filtered out.""" - scheduler = PyMicroBatchScheduler(max_batch_size=4, max_num_tokens=None) requests = [ make_context_request(0, prompt_len=10), make_completed_request(1), make_generation_request(2), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(4, None, requests, set()) # Completed request 1 is filtered by state gating assert len(ctx) == 1 assert len(gen) == 1 @@ -402,8 +460,6 @@ def test_simple_no_overlap(self): C++ ref: SimpleNoOverlap (multi-iteration; here we test single-step scheduling decisions that compose the same behavior). """ - scheduler = PyMicroBatchScheduler(max_batch_size=2, max_num_tokens=None) - # Step 1: 4 context requests, only 2 fit requests = [ make_context_request(0, prompt_len=10), @@ -411,7 +467,7 @@ def test_simple_no_overlap(self): make_context_request(2, prompt_len=10), make_context_request(3, prompt_len=10), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(2, None, requests, set()) assert len(ctx) == 2 assert len(gen) == 0 assert ctx[0].request_id == 0 @@ -425,7 +481,7 @@ def test_simple_no_overlap(self): make_context_request(2, prompt_len=10), make_context_request(3, prompt_len=10), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(2, None, requests, set()) assert len(gen) == 2 assert gen[0].request_id == 0 assert gen[1].request_id == 1 @@ -437,7 +493,7 @@ def test_simple_no_overlap(self): make_context_request(2, prompt_len=10), make_context_request(3, prompt_len=10), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(2, None, requests, set()) assert len(ctx) == 2 assert ctx[0].request_id == 2 assert ctx[1].request_id == 3 @@ -449,15 +505,12 @@ def test_simple_no_overlap_max_num_tokens(self): Req 0, 1: promptLen=12, maxNewTokens=5, maxNumTokens=7, chunkUnitSize=5 """ config = ContextChunkingConfig(ChunkingPolicy.EQUAL_PROGRESS, chunk_unit_size=5) - scheduler = PyMicroBatchScheduler( - max_batch_size=2, max_num_tokens=7, ctx_chunk_config=config - ) # Step 1 (it=0): Only req0 gets a chunk of 5, req1 doesn't fit # C++: Req 0: (0,1,2,3,4), Req 1: () r0 = make_context_request(0, prompt_len=12) r1 = make_context_request(1, prompt_len=12) - ctx, gen = scheduler.schedule([r0, r1], set()) + ctx, gen = _schedule_with_tracker(2, 7, [r0, r1], set(), ctx_chunk_config=config) assert len(ctx) >= 1 # First 
request gets a chunk within budget req0 = next(r for r in ctx if r.request_id == 0) @@ -472,16 +525,13 @@ def test_simple_no_overlap_max_context_length(self): Requests with promptLen=10 and 17, maxContextLength=12, chunkUnitSize=5. """ config = ContextChunkingConfig(ChunkingPolicy.EQUAL_PROGRESS, chunk_unit_size=5) - scheduler = PyMicroBatchScheduler( - max_batch_size=2, max_num_tokens=None, ctx_chunk_config=config - ) - # Override max_context_length (in C++ this is a separate constructor arg) - scheduler.max_context_length = 12 # Two requests with promptLen=10 fit within maxContextLength=12 r0 = make_context_request(0, prompt_len=10) r1 = make_context_request(1, prompt_len=10) - ctx, gen = scheduler.schedule([r0, r1], set()) + ctx, gen = _schedule_with_tracker( + 2, None, [r0, r1], set(), ctx_chunk_config=config, max_context_length=12 + ) assert len(ctx) == 2 # Each chunk should be at most max_context_length for r in ctx: @@ -489,7 +539,9 @@ def test_simple_no_overlap_max_context_length(self): # Request with promptLen=17 needs chunking (17 > 12) r3 = make_context_request(3, prompt_len=17) - ctx2, gen2 = scheduler.schedule([r3], set()) + ctx2, gen2 = _schedule_with_tracker( + 2, None, [r3], set(), ctx_chunk_config=config, max_context_length=12 + ) assert len(ctx2) == 1 assert ctx2[0].context_chunk_size <= 12 @@ -501,9 +553,9 @@ def test_simple_no_overlap_max_context_length(self): # ############################################################################ -class TestPyMicroBatchSchedulerChunking: +class TestTokenBudgetTrackerChunking: """ - Tests for context chunking logic in PyMicroBatchScheduler. + Tests for context chunking logic in TokenBudgetTracker. Aligned with C++ ContextChunkingTest in microBatchSchedulerTest.cpp. """ @@ -515,14 +567,11 @@ def test_equal_progress_basic(self): C++ ref: ContextChunkingTest with EQUAL_PROGRESS """ config = ContextChunkingConfig(ChunkingPolicy.EQUAL_PROGRESS, chunk_unit_size=5) - scheduler = PyMicroBatchScheduler( - max_batch_size=4, max_num_tokens=10, ctx_chunk_config=config - ) requests = [ make_context_request(0, prompt_len=20), make_context_request(1, prompt_len=20), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(4, 10, requests, set(), ctx_chunk_config=config) assert len(ctx) == 2 # Each should get ~5 tokens (equal progress, unit=5, total=10) total_chunk = sum(r.context_chunk_size for r in ctx) @@ -535,14 +584,11 @@ def test_equal_progress_uneven_remaining(self): After chunking, sort puts not-last-chunk requests first. """ config = ContextChunkingConfig(ChunkingPolicy.EQUAL_PROGRESS, chunk_unit_size=5) - scheduler = PyMicroBatchScheduler( - max_batch_size=4, max_num_tokens=15, ctx_chunk_config=config - ) requests = [ make_context_request(0, prompt_len=3), # Only 3 tokens remaining make_context_request(1, prompt_len=20), # Lots remaining ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(4, 15, requests, set(), ctx_chunk_config=config) assert len(ctx) == 2 # Look up by request_id since sort reorders (not-last-chunk first) req0 = next(r for r in ctx if r.request_id == 0) @@ -557,14 +603,11 @@ def test_fcfs_basic(self): FIRST_COME_FIRST_SERVED: first request gets as much as possible. 
""" config = ContextChunkingConfig(ChunkingPolicy.FIRST_COME_FIRST_SERVED, chunk_unit_size=5) - scheduler = PyMicroBatchScheduler( - max_batch_size=4, max_num_tokens=12, ctx_chunk_config=config - ) requests = [ make_context_request(0, prompt_len=20), make_context_request(1, prompt_len=20), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(4, 12, requests, set(), ctx_chunk_config=config) # FCFS: request 0 gets up to budget, request 1 gets remainder assert len(ctx) >= 1 # First request should get more tokens @@ -574,14 +617,11 @@ def test_fcfs_fills_first_request(self): """FCFS fills the first request completely if budget allows. After chunking, sort puts not-last-chunk requests first.""" config = ContextChunkingConfig(ChunkingPolicy.FIRST_COME_FIRST_SERVED, chunk_unit_size=5) - scheduler = PyMicroBatchScheduler( - max_batch_size=4, max_num_tokens=25, ctx_chunk_config=config - ) requests = [ make_context_request(0, prompt_len=10), make_context_request(1, prompt_len=20), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(4, 25, requests, set(), ctx_chunk_config=config) assert len(ctx) == 2 # Look up by request_id since sort reorders (not-last-chunk first) req0 = next(r for r in ctx if r.request_id == 0) @@ -597,15 +637,12 @@ def test_chunk_with_generation(self): Generation tokens reduce the available budget for context chunks. """ config = ContextChunkingConfig(ChunkingPolicy.EQUAL_PROGRESS, chunk_unit_size=5) - scheduler = PyMicroBatchScheduler( - max_batch_size=4, max_num_tokens=15, ctx_chunk_config=config - ) requests = [ make_generation_request(0), # costs 1 token make_context_request(1, prompt_len=20), make_context_request(2, prompt_len=20), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(4, 15, requests, set(), ctx_chunk_config=config) assert len(gen) == 1 # Remaining budget for context: 15 - 1 = 14 total_ctx_tokens = sum(r.context_chunk_size for r in ctx) @@ -617,14 +654,11 @@ def test_chunk_size_zero_not_scheduled(self): the scheduled context requests. """ config = ContextChunkingConfig(ChunkingPolicy.EQUAL_PROGRESS, chunk_unit_size=5) - scheduler = PyMicroBatchScheduler( - max_batch_size=2, max_num_tokens=5, ctx_chunk_config=config - ) requests = [ make_context_request(0, prompt_len=20), make_context_request(1, prompt_len=20), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(2, 5, requests, set(), ctx_chunk_config=config) # With budget 5, at most one request gets chunk_size=5, the other might get 0 for r in ctx: assert r.context_chunk_size > 0 @@ -635,13 +669,10 @@ def test_chunking_with_max_context_length(self): C++ ref: SimpleNoOverlapMaxContextLength """ config = ContextChunkingConfig(ChunkingPolicy.EQUAL_PROGRESS, chunk_unit_size=5) - scheduler = PyMicroBatchScheduler( - max_batch_size=4, max_num_tokens=12, ctx_chunk_config=config - ) requests = [ make_context_request(0, prompt_len=20), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(4, 12, requests, set(), ctx_chunk_config=config) assert len(ctx) == 1 # max_context_length = max_num_tokens = 12, so chunk <= 12 assert ctx[0].context_chunk_size <= 12 @@ -652,12 +683,9 @@ def test_continued_chunking(self): (context_position > 0) continues from where it left off. 
""" config = ContextChunkingConfig(ChunkingPolicy.EQUAL_PROGRESS, chunk_unit_size=5) - scheduler = PyMicroBatchScheduler( - max_batch_size=4, max_num_tokens=10, ctx_chunk_config=config - ) req = make_context_request(0, prompt_len=20, context_position=10) # remaining = 20 - 10 = 10 - ctx, gen = scheduler.schedule([req], set()) + ctx, gen = _schedule_with_tracker(4, 10, [req], set(), ctx_chunk_config=config) assert len(ctx) == 1 assert ctx[0].context_chunk_size <= 10 # remaining context @@ -668,13 +696,10 @@ def test_last_chunk_allows_draft_tokens(self): C++ ref: DraftTokensNoDiscard """ config = ContextChunkingConfig(ChunkingPolicy.FIRST_COME_FIRST_SERVED, chunk_unit_size=10) - scheduler = PyMicroBatchScheduler( - max_batch_size=4, max_num_tokens=20, ctx_chunk_config=config - ) # prompt_len=8, so chunk_size will be 8. Unit=10, remainder=2. # Draft tokens=2 fits in remainder. req = make_context_request(0, prompt_len=8, draft_tokens_len=2) - ctx, gen = scheduler.schedule([req], set()) + ctx, gen = _schedule_with_tracker(4, 20, [req], set(), ctx_chunk_config=config) assert len(ctx) == 1 assert req.is_last_context_chunk @@ -684,12 +709,9 @@ def test_draft_tokens_discarded_when_no_space(self): C++ ref: DraftTokensDiscard """ config = ContextChunkingConfig(ChunkingPolicy.FIRST_COME_FIRST_SERVED, chunk_unit_size=5) - scheduler = PyMicroBatchScheduler( - max_batch_size=4, max_num_tokens=20, ctx_chunk_config=config - ) # prompt_len=5, chunk_size=5, unit=5, remainder=0. Draft=3 won't fit. req = make_context_request(0, prompt_len=5, draft_tokens_len=3) - ctx, gen = scheduler.schedule([req], set()) + ctx, gen = _schedule_with_tracker(4, 20, [req], set(), ctx_chunk_config=config) assert len(ctx) == 1 def test_chunked_context_draft_tokens_max_num_tokens(self): @@ -702,13 +724,8 @@ def test_chunked_context_draft_tokens_max_num_tokens(self): Each request's draft reduced from 8 to 7. """ config = ContextChunkingConfig(ChunkingPolicy.FIRST_COME_FIRST_SERVED, chunk_unit_size=64) - scheduler = PyMicroBatchScheduler( - max_batch_size=64, - max_num_tokens=8192, - ctx_chunk_config=config, - ) requests = [make_context_request(i, prompt_len=2041, draft_tokens_len=8) for i in range(4)] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(64, 8192, requests, set(), ctx_chunk_config=config) assert len(ctx) == 4 for req in ctx: assert req.num_draft_tokens == 7 @@ -724,28 +741,21 @@ def test_chunked_context_draft_tokens_max_context_length(self): Draft reduced from 5 to 4. 
""" config = ContextChunkingConfig(ChunkingPolicy.FIRST_COME_FIRST_SERVED, chunk_unit_size=64) - scheduler = PyMicroBatchScheduler( - max_batch_size=64, - max_num_tokens=8192, - ctx_chunk_config=config, - ) - scheduler.max_context_length = 10 requests = [ make_context_request(0, prompt_len=6, draft_tokens_len=5), make_context_request(1, prompt_len=6, draft_tokens_len=5), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker( + 64, 8192, requests, set(), ctx_chunk_config=config, max_context_length=10 + ) assert len(ctx) == 2 for req in ctx: assert req.num_draft_tokens == 4 def test_no_chunking_context_fits(self): """Without chunking, context is scheduled in full if it fits.""" - scheduler = PyMicroBatchScheduler( - max_batch_size=4, max_num_tokens=20, ctx_chunk_config=None - ) req = make_context_request(0, prompt_len=15) - ctx, gen = scheduler.schedule([req], set()) + ctx, gen = _schedule_with_tracker(4, 20, [req], set()) assert len(ctx) == 1 def test_no_chunking_context_exceeds_budget(self): @@ -753,14 +763,11 @@ def test_no_chunking_context_exceeds_budget(self): Each individual request must fit within max_context_length (== max_num_tokens), but the cumulative token count is checked against the budget. The first request that would push the total over the limit breaks the loop.""" - scheduler = PyMicroBatchScheduler( - max_batch_size=4, max_num_tokens=10, ctx_chunk_config=None - ) requests = [ make_context_request(0, prompt_len=8), make_context_request(1, prompt_len=8), ] - ctx, gen = scheduler.schedule(requests, set()) + ctx, gen = _schedule_with_tracker(4, 10, requests, set()) # First request (8) fits (8 <= 10). Second (8+8=16 > 10) breaks the loop. assert len(ctx) == 1 assert ctx[0].request_id == 0 @@ -770,11 +777,10 @@ def test_sort_by_lora_task_id(self): Requests are sorted by lora_task_id for performance. C++ ref: sortRequests in inflightBatchingUtils.cpp """ - scheduler = PyMicroBatchScheduler(max_batch_size=4, max_num_tokens=None) r0 = _make_request(0, state=LlmRequestState.GENERATION_IN_PROGRESS, lora_task_id=5) r1 = _make_request(1, state=LlmRequestState.GENERATION_IN_PROGRESS) r2 = _make_request(2, state=LlmRequestState.GENERATION_IN_PROGRESS, lora_task_id=3) - ctx, gen = scheduler.schedule([r0, r1, r2], set()) + ctx, gen = _schedule_with_tracker(4, None, [r0, r1, r2], set()) # None < any value, so order should be: r1(None), r2(3), r0(5) assert gen[0].request_id == 1 assert gen[1].request_id == 2 @@ -804,11 +810,11 @@ def _run_context_chunking_test( For each policy (EQUAL_PROGRESS and FCFS), it: 1. Creates LlmRequests with given context_lengths and optional draft_lengths. - 2. Creates a PyMicroBatchScheduler with the right ContextChunkingConfig. - 3. If max_context_length is set, overrides scheduler.max_context_length. + 2. Creates a TokenBudgetTracker with the right ContextChunkingConfig. + 3. If max_context_length is set, overrides tracker.max_context_length. 4. For each iteration (each element in positions list): a. Filters requests where context_remaining_length > 0. - b. Calls scheduler._set_ctx_requests_chunk_size(active_reqs, ctx_tokens_capacity). + b. Calls tracker._set_ctx_requests_chunk_size(active_reqs, ctx_tokens_capacity). c. For each active req, calls req.move_to_next_context_chunk(). d. Verifies context position matches expected positions for ALL requests. 5. After all iterations, verifies final draft_lengths if specified. 
@@ -830,23 +836,23 @@ def _run_context_chunking_test(
         )
         requests.append(req)

-    # Create scheduler
+    # Create tracker
     config = ContextChunkingConfig(policy, chunk_unit_size=chunk_unit_size)
-    scheduler = PyMicroBatchScheduler(
+    tracker = TokenBudgetTracker(
         max_batch_size=64,
         max_num_tokens=1000,  # large enough not to limit
         ctx_chunk_config=config,
     )

     if max_context_length is not None:
-        scheduler.max_context_length = max_context_length
+        tracker.max_context_length = max_context_length

     # Run iterations
     for iteration_idx, expected_positions in enumerate(positions_list):
         # Filter active requests (those with remaining context)
         active_reqs = [r for r in requests if r.context_remaining_length > 0]

-        scheduler._set_ctx_requests_chunk_size(active_reqs, ctx_tokens_capacity)
+        tracker._set_ctx_requests_chunk_size(active_reqs, ctx_tokens_capacity)

         # Move each active request to next chunk
         for req in active_reqs:
@@ -1120,12 +1126,6 @@ def test_draft_tokens_greater_than_chunk_size(self):
         - Request 2: draftTokens = 5 (remaining budget)
         """
         config = ContextChunkingConfig(ChunkingPolicy.FIRST_COME_FIRST_SERVED, chunk_unit_size=16)
-        scheduler = PyMicroBatchScheduler(
-            max_batch_size=64,
-            max_num_tokens=40,
-            ctx_chunk_config=config,
-        )
-        scheduler.max_context_length = 64

         requests = [
             make_context_request(0, prompt_len=3, draft_tokens_len=17),
@@ -1133,7 +1133,9 @@
             make_context_request(2, prompt_len=3, draft_tokens_len=17),
         ]

-        ctx, gen = scheduler.schedule(requests, set())
+        ctx, gen = _schedule_with_tracker(
+            64, 40, requests, set(), ctx_chunk_config=config, max_context_length=64
+        )
         assert len(ctx) == 3

         req0 = next(r for r in ctx if r.request_id == 0)
@@ -1464,7 +1466,7 @@ def test_generation_to_complete_scheduled(self):

 # ############################################################################
 #
-# Part 5: PyCapacityScheduler Advanced Tests
+# Part 5: PyCapacityScheduler Advanced Tests (LoRA, Priority, Chunked, etc.)
 #
 # ############################################################################
@@ -1512,23 +1514,101 @@ def test_lora_doesnt_fit(self):
 # ############################################################################
 #
-# Part 6: SimpleUnifiedScheduler Integration Tests
+# Part 6: schedule_step Integration Tests
 #
 # ############################################################################


-class TestSimpleUnifiedScheduler:
+class TestScheduleStep:
+    def test_request_scheduler_schedule_step_passthrough(self):
+        ctx = make_context_request(0, prompt_len=10)
+        gen = make_generation_request(1)
+        scheduler = _FakeScheduler(
+            SchedulerOutput([ctx], [gen], [], [], 2),
+            schedule_step_config=ScheduleStepConfig(),
+        )
+
+        step_result = scheduler.schedule_step([ctx, gen], set())
+
+        assert step_result.scheduled_requests.context_requests == [ctx]
+        assert step_result.scheduled_requests.generation_requests == [gen]
+        assert step_result.fitting_disagg_gen_init_requests == []
+        assert step_result.num_fitting_requests == 2
+
+    def test_request_scheduler_schedule_step_applies_batch_waiting(self):
+        ctx = make_context_request(0, prompt_len=10)
+        gen = make_generation_request(1)
+        scheduler = _FakeScheduler(
+            SchedulerOutput([ctx], [gen], [], [], 2),
+            schedule_step_config=ScheduleStepConfig(
+                batch_wait_timeout_iters=1,
+                batch_wait_max_tokens_ratio=0.5,
+                max_num_tokens=100,
+            ),
+        )
+
+        step_result = scheduler.schedule_step([ctx, gen], set())
+
+        assert step_result.scheduled_requests.context_requests == []
+        assert step_result.scheduled_requests.generation_requests == [gen]
+
+    def test_simple_scheduler_schedule_step_uses_base_flow(self):
+        ctx = make_context_request(0, prompt_len=10)
+        gen = make_generation_request(1)
+        scheduler = SimpleScheduler(
+            _StubCapacityScheduler([ctx, gen]),
+            _StubMicroBatchScheduler([ctx], [gen]),
+            schedule_step_config=ScheduleStepConfig(),
+        )
+
+        step_result = scheduler.schedule_step([ctx, gen], set())
+
+        assert step_result.scheduled_requests.context_requests == [ctx]
+        assert step_result.scheduled_requests.generation_requests == [gen]
+        assert step_result.num_fitting_requests == 2
+
+    def test_request_scheduler_schedule_step_applies_adp_balance(self):
+        ctx = make_context_request(0, prompt_len=10)
+        gen = make_generation_request(1)
+        dist = Mock()
+        dist.tp_allgather.return_value = [[1, 1], [1, 1]]
+        scheduler = _FakeScheduler(
+            SchedulerOutput([ctx], [gen], [], [], 2),
+            schedule_step_config=ScheduleStepConfig(
+                enable_attention_dp=True,
+                attention_dp_enable_balance=True,
+                attention_dp_batching_wait_iters=1,
+                max_batch_size=4,
+            ),
+            dist=dist,
+        )
+
+        step_result = scheduler.schedule_step([ctx, gen], set())
+
+        dist.tp_allgather.assert_called_once_with([1, 1])
+        assert step_result.scheduled_requests.context_requests == []
+        assert step_result.scheduled_requests.generation_requests == [gen]
+
+
+# ############################################################################
+#
+# Part 7: UnifiedScheduler Integration Tests
+#
+# ############################################################################
+
+
+class TestUnifiedScheduler:
     """
-    Tests for the two-stage scheduling pipeline:
-    PyCapacityScheduler → PyMicroBatchScheduler
+    Tests for the fused scheduling pipeline:
+    PyCapacityScheduler + TokenBudgetTracker (single-pass via UnifiedScheduler)
     """

     def test_capacity_then_microbatch(self):
-        """Capacity filters, then microbatch selects within token budget.
+        """Capacity filters, then token budget selects within token budget.
         max_batch_size is used as max_num_requests for capacity scheduler,
         so it must be large enough for all requests to pass capacity."""
         kv = MockKVCacheManager(num_free_blocks=100, blocks_per_request=5)
-        scheduler = SimpleUnifiedScheduler(
+        scheduler = UnifiedScheduler(
             max_batch_size=4,
             max_num_tokens=15,
             kv_cache_manager=kv,
@@ -1541,15 +1621,16 @@
             make_generation_request(2),
         ]
         output = scheduler.schedule_request(requests, set())
-        # Capacity: all 3 fit (plenty of blocks, max_num_requests=4)
-        # Microbatch: gen_2 (1) + context_0 (10) = 11 <= 15, context_1 (10) would be 21 > 15
-        assert output.num_fitting_requests == 3
-        assert len(output.context_requests) + len(output.generation_requests) <= 2
+        # UnifiedScheduler fuses capacity + token budget in a single pass.
+        # num_fitting_requests counts requests passing BOTH gates (see docstring item 3).
+        # gen_2 (1) + context_0 (10) = 11 <= 15, context_1 (10) would be 21 > 15
+        assert output.num_fitting_requests == 2
+        assert len(output.context_requests) + len(output.generation_requests) == 2

     def test_can_schedule_dry_run(self):
         """can_schedule() checks capacity without side effects."""
         kv = MockKVCacheManager(num_free_blocks=100, blocks_per_request=5)
-        scheduler = SimpleUnifiedScheduler(
+        scheduler = UnifiedScheduler(
             max_batch_size=4,
             max_num_tokens=100,
             kv_cache_manager=kv,
@@ -1565,7 +1646,7 @@ def test_can_schedule_dry_run(self):
     def test_can_schedule_returns_false(self):
         """can_schedule() returns False when capacity is insufficient."""
         kv = MockKVCacheManager(num_free_blocks=3, blocks_per_request=5)
-        scheduler = SimpleUnifiedScheduler(
+        scheduler = UnifiedScheduler(
             max_batch_size=4,
             max_num_tokens=100,
             kv_cache_manager=kv,
@@ -1582,7 +1663,7 @@ def test_can_schedule_returns_false(self):
     def test_full_pipeline_output_structure(self):
         """Verify SchedulerOutput has all expected fields."""
         kv = MockKVCacheManager(num_free_blocks=100, blocks_per_request=5)
-        scheduler = SimpleUnifiedScheduler(
+        scheduler = UnifiedScheduler(
             max_batch_size=4,
             max_num_tokens=100,
             kv_cache_manager=kv,
@@ -1607,7 +1688,7 @@ def test_full_pipeline_output_structure(self):
     def test_paused_requests_propagated(self):
         """Paused requests from capacity scheduler appear in output."""
         kv = MockKVCacheManager(num_free_blocks=100, blocks_per_request=5)
-        scheduler = SimpleUnifiedScheduler(
+        scheduler = UnifiedScheduler(
             max_batch_size=4,
             max_num_tokens=100,
             kv_cache_manager=kv,
@@ -1622,10 +1703,31 @@ def test_paused_requests_propagated(self):
         output = scheduler.schedule_request(requests, set())
         assert isinstance(output.paused_requests, list)

+    def test_schedule_step_returns_scheduled_requests(self):
+        kv = MockKVCacheManager(num_free_blocks=100, blocks_per_request=5)
+        scheduler = UnifiedScheduler(
+            max_batch_size=4,
+            max_num_tokens=100,
+            kv_cache_manager=kv,
+            peft_cache_manager=None,
+            scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
+            schedule_step_config=ScheduleStepConfig(max_num_tokens=100),
+        )
+        requests = [
+            make_context_request(0, prompt_len=10),
+            make_generation_request(1),
+        ]
+
+        step_result = scheduler.schedule_step(requests, set())
+
+        assert len(step_result.scheduled_requests.context_requests) == 1
+        assert len(step_result.scheduled_requests.generation_requests) == 1
+        assert step_result.num_fitting_requests == 2

 # ############################################################################
 #
-# Part 7: Additional PyCapacityScheduler Tests
+# Part 8: Additional PyCapacityScheduler Tests
 #
 # ############################################################################
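
Usage sketch (not part of the patch): a minimal illustration of how the fused scheduler introduced by this diff is driven for one iteration. It reuses the test helpers above (MockKVCacheManager, make_context_request, make_generation_request) and only the names this diff itself adds (UnifiedScheduler, ScheduleStepConfig, schedule_step); everything else is an assumption, not the library's documented API.

# Hypothetical driver mirroring test_schedule_step_returns_scheduled_requests.
# MockKVCacheManager and the make_*_request helpers are assumed from the test file.
kv = MockKVCacheManager(num_free_blocks=100, blocks_per_request=5)
scheduler = UnifiedScheduler(
    max_batch_size=4,
    max_num_tokens=100,
    kv_cache_manager=kv,
    peft_cache_manager=None,
    scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
    schedule_step_config=ScheduleStepConfig(max_num_tokens=100),
)
requests = [make_context_request(0, prompt_len=10), make_generation_request(1)]
# schedule_step applies the capacity gate and the token-budget gate in a
# single pass and returns the batch to run this iteration.
step_result = scheduler.schedule_step(requests, set())
batch = step_result.scheduled_requests  # context_requests + generation_requests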