diff --git a/tensorrt_llm/__init__.py b/tensorrt_llm/__init__.py index 7f4a25dd8386..6e2c5c4c1773 100644 --- a/tensorrt_llm/__init__.py +++ b/tensorrt_llm/__init__.py @@ -17,6 +17,7 @@ # Disable UCC to WAR allgather issue before NGC PyTorch 25.12 upgrade. os.environ["OMPI_MCA_coll_ucc_enable"] = "0" +os.environ["TLLM_USE_PYTHON_SCHEDULER"] = "1" def _add_trt_llm_dll_directory(): diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 4fea1e0b4e67..9912d5834473 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -860,6 +860,9 @@ def create_py_executor_instance( if scheduler_capacity == 1 and mapping.enable_attention_dp and kv_cache_manager: scheduler_capacity += 1 + # Extract server_role from llm_args for scheduler selection + server_role = getattr(llm_args, 'server_role', None) + use_python_scheduler = os.getenv("TLLM_USE_PYTHON_SCHEDULER", "0") == "1" if use_python_scheduler and not isinstance(kv_cache_manager, KVCacheManagerV2): @@ -873,7 +876,10 @@ def create_py_executor_instance( scheduler_policy=scheduler_config.capacity_scheduler_policy, ctx_chunk_config=ctx_chunk_config, two_step_lookahead=mapping.has_pp(), - scheduler_capacity=scheduler_capacity) + scheduler_capacity=scheduler_capacity, + dist=dist, + max_num_active_requests=model_engine.get_max_num_sequences(), + server_role=server_role) else: if isinstance(kv_cache_manager, KVCacheManagerV2): capacity_scheduler = KVCacheV2DummyScheduler( diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index f48d724658ae..8112fed78757 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -583,6 +583,13 @@ def __init__( additional_outputs=additional_outputs) self.child_requests = [] + # Pre-validation cache for attention_dp optimization + # When a request passes simulation in GlobalCoordinator.can_accept_request(), + # we cache the estimated tokens/blocks to avoid recalculating in _fused_schedule_request() + self.py_pre_validated: bool = False + self.py_estimated_tokens: int = 0 + self.py_estimated_blocks: int = 0 + self._py_embedding_bias_1d: Optional[torch.Tensor] = None if hasattr(self, 'embedding_bias') and self.embedding_bias is not None: # Pre-squeeze to 1D if needed (remove batch dimension) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 2c0593d65105..d3603215d2e2 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -61,7 +61,7 @@ from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState, SampleStateTensors, TRTLLMSampler) from .scheduler import (RequestScheduler, ScheduledRequests, - SerializableSchedulerOutput) + SerializableSchedulerOutput, SimpleUnifiedScheduler) # Environment variable to specify iteration ranges for profiling start/stop. # Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..." 
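The Python scheduler path above is selected by the TLLM_USE_PYTHON_SCHEDULER environment variable together with the KV cache manager type, and tensorrt_llm/__init__.py now assigns that variable unconditionally (a plain assignment, not setdefault), so importing the package switches the Python scheduler on even if the variable was exported as "0" beforehand. A minimal sketch of the gate, mirroring create_py_executor_instance(); the helper name scheduler_gate and the kv_cache_manager_is_v2 flag are illustrative stand-ins for the isinstance(kv_cache_manager, KVCacheManagerV2) check:

import os
from typing import Optional, Tuple


def scheduler_gate(llm_args, kv_cache_manager_is_v2: bool) -> Tuple[bool, Optional[str]]:
    # Mirrors: use_python_scheduler = os.getenv("TLLM_USE_PYTHON_SCHEDULER", "0") == "1"
    use_python = os.getenv("TLLM_USE_PYTHON_SCHEDULER", "0") == "1"
    # server_role is read the same way _util.py does; None when the attribute is absent.
    server_role = getattr(llm_args, "server_role", None)
    return use_python and not kv_cache_manager_is_v2, server_role

When the gate is true, the unified scheduler is constructed with the extra dist, max_num_active_requests and server_role arguments shown in the hunk above; otherwise the existing capacity/micro-batch scheduler pair is used.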
@@ -354,6 +354,29 @@ def __init__(self, self.max_input_len = max_input_len # _executor_loop private data self.max_num_active_requests = model_engine.get_max_num_sequences() + + # Configure SimpleUnifiedScheduler and AttentionDPScheduler + if isinstance(scheduler, SimpleUnifiedScheduler): + # Configure batch waiting (for TP-only mode) + scheduler.batch_wait_timeout_iters = self.llm_args.batch_wait_timeout_iters + scheduler.batch_wait_max_tokens_ratio = self.llm_args.batch_wait_max_tokens_ratio + scheduler.enable_batch_waiting = ( + scheduler.batch_wait_timeout_iters > 0 + or scheduler.batch_wait_max_tokens_ratio > 0) + + # If this is AttentionDPScheduler, configure attention_dp parameters + if hasattr(scheduler, 'enable_global_scheduling' + ) and scheduler.enable_global_scheduling: + # Configure batching/waiting parameters for attention_dp + scheduler.global_coordinator.attention_dp_enable_balance = self.attention_dp_enable_balance + if self.attention_dp_enable_balance: + scheduler.global_coordinator.attention_dp_time_out_iters = self.attention_dp_time_out_iters + scheduler.global_coordinator.attention_dp_batching_wait_iters = self.attention_dp_batching_wait_iters + + logger.info( + "Configured global scheduling for attention_dp (balance=%s)", + self.attention_dp_enable_balance) + self.active_requests: List[LlmRequest] = [] self.expected_num_active_requests = 0 self.async_transfer_manager = AsyncTransferManager( @@ -1411,22 +1434,160 @@ def _can_queue(self, scheduled_batch): return can_queue, can_queue_this_rank def _prepare_and_schedule_batch(self): - new_requests = self._fetch_and_activate_new_requests() + """Prepare and schedule batch for execution.""" + # Use new unified scheduler interface if available + if isinstance(self.scheduler, SimpleUnifiedScheduler) and hasattr( + self.scheduler, 'schedule_iteration'): + return self._prepare_and_schedule_batch_v2() + else: + # Fallback to old path for backward compatibility + return self._prepare_and_schedule_batch_v1() + + def _prepare_and_schedule_batch_v2(self): + """ + Prepare and schedule batch using new unified scheduler interface. 
+ + SIMPLIFIED: Most scheduling logic moved to scheduler.schedule_iteration() + """ + # Step 1: Check stop condition + if self.should_stop_processing: + return None, None + + # Step 2: Check KV cache transfer status (disagg mode - executor responsibility) + if self.kv_cache_transceiver: + self._check_disagg_gen_transfer_status() + self._check_kv_transfer_timeout() + + # Step 3: Calculate iter_stats (executor responsibility) + iter_stats = None + if self.enable_iter_perf_stats: + num_new_in_queue = len(self.waiting_queue) + queue_latency = self._get_new_active_requests_queue_latency() + iter_stats = self._get_init_iter_stats(num_new_in_queue, + queue_latency) + + # Step 4: Fetch new requests from external queue (executor responsibility) + self._fetch_and_enqueue_requests(self.waiting_queue, + self.expected_num_active_requests) + + # Step 5: Validate new requests (executor responsibility) + # This happens inside schedule_iteration via _activate_new_requests, + # but we need to extract and validate them + waiting_queue_size_before = len(self.waiting_queue) + + # Step 6: Drafter pre-processing (calculate draft tokens for this iteration) + max_total_draft_tokens = 0 + use_spec_decode = False + + if self.drafter is not None: + # Calculate draft length based on batch size + if self.drafter.draft_len_schedule is not None: + batch_size_input = len(self.active_requests) + max_total_draft_tokens = self.drafter.get_draft_len_for_batch_size( + batch_size_input) + self.drafter.update_max_total_draft_tokens( + max_total_draft_tokens) + + # Determine if spec decode should be used + if self.drafter.draft_len_schedule is not None and max_total_draft_tokens == 0: + use_spec_decode = False + elif getattr(self, 'speculation_permanently_disabled', False): + use_spec_decode = False + else: + use_spec_decode = self.drafter.should_use_spec_decode( + self.active_requests, self.max_batch_size, + self.model_engine.llm_args.max_num_tokens, + max_total_draft_tokens) + + logger.debug(f"Use spec decode: {use_spec_decode}") + self.model_engine.enable_spec_decode = use_spec_decode + self.use_spec_decode = use_spec_decode + self.max_total_draft_tokens = max_total_draft_tokens + + # Set up draft_tokens in active_requests for scheduler awareness + for request in self.active_requests: + if request.state not in ( + LlmRequestState.GENERATION_IN_PROGRESS, + LlmRequestState.DISAGG_GENERATION_INIT): + continue + request.draft_tokens = [ + 0 + ] * max_total_draft_tokens if max_total_draft_tokens > 0 else [] + + # Step 7: ✨ SINGLE CALL TO SCHEDULER ✨ + scheduler_output = self.scheduler.schedule_iteration( + waiting_queue=self.waiting_queue, + active_requests=self.active_requests, + inflight_req_ids=self.inflight_req_ids, + drafter=self.drafter, + max_total_draft_tokens=max_total_draft_tokens, + use_spec_decode=use_spec_decode, + cp_config=self.dist.cp_config, + cp_rank=self.dist.cp_rank, + cp_size=self.dist.cp_size, + exclude_last_generation_logits=self. 
+ _should_exclude_last_generation_logits(), + kv_cache_manager=self.kv_cache_manager, + resource_manager=self.resource_manager, + ) + + # Convert to ScheduledRequests + scheduled_batch = scheduler_output.to_scheduled_requests() + + # Validate new requests that were activated + # (they're already in active_requests after schedule_iteration) + waiting_queue_size_after = len(self.waiting_queue) + num_new_activated = waiting_queue_size_before - waiting_queue_size_after + + # Step 8: Disagg mode post-processing (executor responsibility) + if self.kv_cache_transceiver: + # Prepare disagg gen init resources + if scheduler_output.fitting_disagg_gen_init_requests: + self._prepare_disagg_gen_init( + scheduler_output.fitting_disagg_gen_init_requests) + + # Warning if no requests fit + if scheduler_output.num_fitting_requests == 0 and not scheduler_output.fitting_disagg_gen_init_requests: + logger.warning( + "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache" + ) + self._check_disagg_ctx_cache_transfer_status(1) + + # Step 9: Update state and return + self.num_scheduled_requests = scheduled_batch.batch_size + logger.debug( + f'has {len(self.active_requests)} active_requests, ' + f'scheduled {len(scheduled_batch.context_requests)} context requests and ' + f'{len(scheduled_batch.generation_requests)} generation requests') + + return scheduled_batch, iter_stats + + def _prepare_and_schedule_batch_v1(self): + """ + OLD PATH: Prepare and schedule batch using old interface. + + Kept for backward compatibility with SimpleScheduler. + """ + # Step 1: Fetch and activate new requests + num_new_requests = self._fetch_and_activate_requests() if self.should_stop_processing: return None, None + # Step 2: Check KV cache transfer status if self.kv_cache_transceiver: self._check_disagg_gen_transfer_status() self._check_kv_transfer_timeout() + # Step 3: Calculate iter_stats iter_stats = None if self.enable_iter_perf_stats: iter_stats = self._get_init_iter_stats( - len(new_requests), - self._get_new_active_requests_queue_latency()) + num_new_requests, self._get_new_active_requests_queue_latency()) + # Step 4: Pad dummy requests self._pad_attention_dp_dummy_request() + # Step 5: Drafter logic if self.drafter is not None: # Honor permanent disable flag based on rolling acceptance first if self.drafter.draft_len_schedule is not None: @@ -1469,9 +1630,11 @@ def _prepare_and_schedule_batch(self): # that speculation is about to happen. self._prepare_draft_requests() - scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( + # Step 6: Schedule batch + scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule_batch( ) + # Step 7: Post-processing if self.drafter is not None and not self.use_spec_decode: for request in scheduled_batch.all_requests(): request.py_disable_speculative_decoding = True @@ -1493,6 +1656,84 @@ def _prepare_and_schedule_batch(self): f'{len(scheduled_batch.generation_requests)} generation requests') return scheduled_batch, iter_stats + def _validate_new_requests( + self, new_requests: List[LlmRequest]) -> List[LlmRequest]: + """ + Validate new requests and handle errors for invalid ones. 
+ + Args: + new_requests: List of new requests to validate + + Returns: + List of validated requests (invalid ones are removed and errors are handled) + """ + validated_requests = [] + for request in new_requests: + try: + self._validate_request(request) + validated_requests.append(request) + except Exception as e: + self._handle_errors(str(e), requests=[request]) + return validated_requests + + def _fetch_and_activate_requests(self): + """ + Fetch and activate new requests. + + Returns: + int: Number of newly activated requests + """ + if isinstance(self.scheduler, SimpleUnifiedScheduler): + # SimpleUnifiedScheduler path: Works for both attention_dp and TP-only modes + # Fetch and enqueue requests from executor queue + self._fetch_and_enqueue_requests(self.waiting_queue, + self.expected_num_active_requests) + + # Activate new requests through scheduler + # Scheduler returns LlmRequests (already converted, only once) + # For attention_dp: expected_num_active_requests is max across all ranks + # For TP-only: expected_num_active_requests is local count + new_llm_requests, self.expected_num_active_requests = \ + self.scheduler._activate_new_requests( + self.active_requests, + self.waiting_queue, + self.dist.cp_config, + self.dist.cp_rank, + self.dist.cp_size, + self._should_exclude_last_generation_logits() + ) + + # Validate and add new requests to active_requests + validated_new_requests = self._validate_new_requests( + new_llm_requests) + self.active_requests.extend(validated_new_requests) + + return len(validated_new_requests) + else: + # SimpleScheduler path + new_requests = self._fetch_and_activate_new_requests() + return len(new_requests) + + def _schedule_batch(self): + """ + Schedule the batch using the appropriate scheduler. + + Returns: + tuple: (scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs) + """ + if isinstance(self.scheduler, SimpleUnifiedScheduler): + # SimpleUnifiedScheduler path: Works for both attention_dp and TP-only modes + # - For attention_dp: Batching done during activation via _apply_batching_filter() + # - For TP-only: Batching done during scheduling via _apply_batch_waiting() + scheduler_output = self.scheduler.schedule_request( + self.active_requests, self.inflight_req_ids) + return (scheduler_output.to_scheduled_requests(), + scheduler_output.fitting_disagg_gen_init_requests, + scheduler_output.num_fitting_requests) + else: + # SimpleScheduler path + return self._schedule() + def _kv_connector_start_batch(self, scheduled_batch): if self.kv_connector_manager: self.kv_connector_manager.take_scheduled_requests_pending_load( diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 6631057251f3..e3801487c696 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -1,9 +1,11 @@ +import copy import dataclasses +import itertools from abc import ABC, abstractmethod -from collections import namedtuple -from dataclasses import dataclass +from collections import deque, namedtuple +from dataclasses import dataclass, field from enum import Enum -from typing import Optional +from typing import Dict, List, Optional, Set from strenum import StrEnum @@ -11,23 +13,77 @@ from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy from tensorrt_llm.logger import logger -# Assuming these imports exist in your environment from .llm_request import LlmRequest, LlmRequestState +from .request_utils import merge_requests +from .resource_manager import 
ResourceManagerType RequestList = list[LlmRequest] +# Standard scheduler output (used by both SimpleScheduler and SimpleUnifiedScheduler) SchedulerOutput = namedtuple("SchedulerOutput", [ "context_requests", "generation_requests", "paused_requests", "fitting_disagg_gen_init_requests", "num_fitting_requests" ]) +@dataclass +class UnifiedSchedulerOutput: + """ + Extended scheduler output for SimpleUnifiedScheduler with global coordination. + + Includes standard scheduling fields plus updated_active_requests for attention_dp mode. + """ + context_requests: RequestList + generation_requests: RequestList + paused_requests: RequestList + fitting_disagg_gen_init_requests: RequestList + num_fitting_requests: int + + # Optional: Only populated when global coordination is used (attention_dp) + updated_active_requests: Optional[RequestList] = None + + def to_scheduler_output(self) -> SchedulerOutput: + """Convert to standard SchedulerOutput (for backward compatibility).""" + return SchedulerOutput( + context_requests=self.context_requests, + generation_requests=self.generation_requests, + paused_requests=self.paused_requests, + fitting_disagg_gen_init_requests=self. + fitting_disagg_gen_init_requests, + num_fitting_requests=self.num_fitting_requests, + ) + + def to_scheduled_requests(self) -> 'ScheduledRequests': + """Convert to ScheduledRequests (used by PyExecutor).""" + return ScheduledRequests.from_lists( + context_requests=self.context_requests, + generation_requests=self.generation_requests, + paused_requests=self.paused_requests, + ) + + class ScheduledRequests: # to be aligned with ScheduledRequests in cpp/tensorrt_llm/batch_manager/common.h def __init__(self): self.context_requests: RequestList = [] self.generation_requests: RequestList = [] self.paused_requests: RequestList = [] + self.disagg_gen_init_requests: RequestList = [] + + @staticmethod + def from_lists( + context_requests: RequestList, + generation_requests: RequestList, + paused_requests: RequestList, + disagg_gen_init_requests: Optional[RequestList] = None, + ) -> 'ScheduledRequests': + """Factory method to create ScheduledRequests from lists.""" + scheduled = ScheduledRequests() + scheduled.context_requests = context_requests + scheduled.generation_requests = generation_requests + scheduled.paused_requests = paused_requests + scheduled.disagg_gen_init_requests = disagg_gen_init_requests if disagg_gen_init_requests is not None else [] + return scheduled @property def is_generation_only(self) -> bool: @@ -124,6 +180,36 @@ def to_scheduler_result( return scheduled_requests, fitting_disagg_gen_init_requests, self.num_fitting_requests +@dataclass +class RankResourceState: + """ + Snapshot of a single rank's resources for global coordination. + + This dataclass captures all information needed to simulate resource + allocation decisions without actually allocating resources. + Used by SimpleUnifiedScheduler for attention_dp global scheduling. 
+ """ + + rank_id: int + + # === Constraints (Safety) === + free_kv_blocks: int # From CapacityScheduler.get_kv_cache_stats() + max_kv_blocks: int # Total KV cache capacity + current_batch_tokens: int # Current token load + max_token_budget: float # From MicroBatchScheduler.max_num_tokens (can be float('inf')) + current_batch_size: int # Number of active requests + max_batch_size: int # From MicroBatchScheduler.max_batch_size + + # === Load Metrics (Balancing) === + num_active_gen_reqs: int # Generation requests in progress + num_active_ctx_reqs: int # Context requests in progress + + # === PEFT/LoRA (Optional - reserved for future use) === + active_lora_task_ids: Set[int] = field( + default_factory=set) # For LoRA co-location + available_peft_pages: int = 0 # PEFT cache capacity + + class CapacityScheduler(ABC): @abstractmethod @@ -311,1053 +397,2860 @@ class ContextChunkingConfig: chunk_unit_size: int -class MicroBatchScheduler: - """Base class to match structure.""" +class NoEvictScheduledBlocksManager: + """ + Python equivalent of C++ kv_cache_manager::NoEvictScheduledBlocksManager. + Tracks available blocks per window size for GUARANTEED_NO_EVICT scheduling. + Reference: cpp/tensorrt_llm/batch_manager/scheduledBlocksManager.h:29-62 + """ -class PyMicroBatchScheduler(MicroBatchScheduler): + def __init__(self, kv_cache_manager): + """ + Initialize with free blocks from KVCacheManager. + C++ equivalent: mAvailableBlocks = mKvCacheManager.getBlockManager().getNumFreeBlocksPerWindowSize() + """ + self.kv_cache_manager = kv_cache_manager + stats = kv_cache_manager.get_kv_cache_stats() + self.available_blocks: dict[int, int] = dict( + stats.num_free_blocks_per_window_size) - def __init__( - self, - max_batch_size: int, - max_num_tokens: Optional[int] = None, - ctx_chunk_config: Optional[ContextChunkingConfig] = None, - no_schedule_until_state: LlmRequestState = LlmRequestState.CONTEXT_INIT, - no_schedule_after_state: LlmRequestState = LlmRequestState. - GENERATION_TO_COMPLETE, - ): - super().__init__() - self.max_batch_size = max_batch_size - self.max_num_tokens = max_num_tokens - self.ctx_chunk_config = ctx_chunk_config - self.max_context_length = max_num_tokens - # Match C++ MicroBatchScheduler defaults (see algorithms.cpp line 68-70) - self.no_schedule_until_state = no_schedule_until_state - self.no_schedule_after_state = no_schedule_after_state - # Cache state values to avoid repeated .value access (optimization) - self._no_schedule_until_state_value = no_schedule_until_state.value - self._no_schedule_after_state_value = no_schedule_after_state.value - self._context_init_state_value = LlmRequestState.CONTEXT_INIT.value - self._encoder_init_state_value = LlmRequestState.ENCODER_INIT.value + def decrement_reserved_blocks(self, req: LlmRequest) -> None: + """ + Decrement available blocks by the blocks needed to complete this request. + C++ reference: scheduledBlocksManager.h:40-46 + """ + for window_size in self.available_blocks: + needed = self.kv_cache_manager.get_remaining_blocks_to_completion( + req, window_size) + self.available_blocks[window_size] -= needed - def _can_be_scheduled(self, req: LlmRequest) -> bool: + def enough_available_blocks(self, req: LlmRequest) -> bool: """ - Check if request is within the schedulable state range. - C++ reference: microBatchScheduler.cpp line 192-195 - Optimized: use state_value property to avoid enum object creation + Check if there are enough available blocks for this request across all window sizes. 
+ C++ reference: scheduledBlocksManager.h:48-57 """ - # Use state_value property (returns int directly, avoids enum object creation) - state_value = req.state_value - # Inline comparison: must have reached until_state but not after_state - return (state_value >= self._no_schedule_until_state_value - and state_value < self._no_schedule_after_state_value) + return all( + self.kv_cache_manager.get_remaining_blocks_to_completion(req, ws) <= + avail for ws, avail in self.available_blocks.items()) - def schedule( - self, active_requests: RequestList, - inflight_request_ids: set[int]) -> tuple[RequestList, RequestList]: - context_requests: RequestList = [] - generation_requests: RequestList = [] +class MaxUtilizationScheduledBlocksManager: + """ + Python equivalent of C++ kv_cache_manager::MaxUtilizationScheduledBlocksManager. + Tracks scheduled blocks per window size for MAX_UTILIZATION scheduling. - # Current total tokens in the scheduled batch (Generation + Context) - batch_num_tokens = 0 - scheduled_req_size = 0 - scheduled_beam_width = 0 + Reference: cpp/tensorrt_llm/batch_manager/scheduledBlocksManager.h:64-117 + """ - contexts_to_be_chunked: RequestList = [] - # Total tokens required by chunked requests (calculated tentatively) - num_chunked_tokens = 0 - all_context_requests_fit = True + def __init__(self, kv_cache_manager, two_steps_look_ahead: bool): + """ + Initialize scheduled blocks count per window size. + C++ equivalent: iterate windowSizes and set mNumScheduledBlocks[windowSize] = 0 + """ + self.kv_cache_manager = kv_cache_manager + self.two_steps_look_ahead = two_steps_look_ahead + window_sizes = set(kv_cache_manager.max_attention_window_vec) + self.num_scheduled_blocks: dict[int, int] = { + ws: 0 + for ws in window_sizes + } - # Cache instance attributes as locals for faster access in loop - max_batch_size = self.max_batch_size - max_num_tokens = self.max_num_tokens - max_context_length = self.max_context_length - ctx_chunk_config = self.ctx_chunk_config + def prepare_blocks_if_schedulable( + self, req: LlmRequest) -> Optional[dict[int, int]]: + """ + Check if request can be scheduled and return new block counts if so. + Returns None if request cannot fit. + C++ reference: scheduledBlocksManager.h:80-100 + """ + blocks_if_scheduled = {} + for window_size, num_scheduled in self.num_scheduled_blocks.items(): + required = self.kv_cache_manager.get_needed_blocks_one_step( + req, self.two_steps_look_ahead, window_size) + logger.debug( + f"MaxUtilizationScheduler: request ID {req.request_id} " + f"required blocks {required} for {window_size} window size") + scheduled_total = num_scheduled + required + has_free = self.kv_cache_manager.scheduling_has_free_blocks( + scheduled_total, window_size) + if not has_free: + return None + blocks_if_scheduled[window_size] = scheduled_total + return blocks_if_scheduled - # 1. Main Scheduling Loop - for req in active_requests: - req_state_value = req.state_value - # Skip requests already in flight (should be filtered by caller, but C++ checks) - if req.request_id in inflight_request_ids: - continue + def update_scheduled_blocks(self, blocks: dict[int, int]) -> None: + """ + Update the scheduled blocks after successfully scheduling a request. 
+ C++ reference: scheduledBlocksManager.h:102-110 + """ + assert len(blocks) == len(self.num_scheduled_blocks), \ + f"Block count mismatch: {len(blocks)} vs {len(self.num_scheduled_blocks)}" + for window_size, blocks_if_scheduled in blocks.items(): + logger.debug( + f"MaxUtilizationScheduler: scheduled blocks {blocks_if_scheduled} " + f"for window size {window_size}") + self.num_scheduled_blocks[window_size] = blocks_if_scheduled - # Skip if request cannot be scheduled yet or should no longer be scheduled, manually inline the condition to reuse req.state_value - if not (req_state_value >= self._no_schedule_until_state_value - and req_state_value < self._no_schedule_after_state_value): - continue - req_num_tokens = 0 +class PeftHelper: + """ + Helper class for PEFT/LoRA resource management. - # --- A. Encoder Request Handling --- - if req_state_value == self._encoder_init_state_value: - req_num_tokens = req.encoder_output_len + Encapsulates all PEFT-related logic including page calculation, + task tracking, and capacity management. + """ - assert max_context_length is None or req_num_tokens <= max_context_length, \ - f"The number of encoder tokens ({req_num_tokens}) exceeds the limit value ({max_context_length})" + def __init__(self, peft_cache_manager): + """ + Initialize PEFT helper. - if max_num_tokens is not None and ( - batch_num_tokens + req_num_tokens > max_num_tokens): - break + Args: + peft_cache_manager: PEFT cache manager instance (or None if PEFT disabled) + """ + self.peft_cache_manager = peft_cache_manager - logger.debug(f"encoder request scheduled: ID {req.request_id}") - context_requests.append(req) - batch_num_tokens += req_num_tokens + def get_max_pages(self) -> int: + """Get maximum PEFT cache pages available.""" + if self.peft_cache_manager is None: + return 2**31 - 1 # INT_MAX equivalent + return self.peft_cache_manager.max_device_pages - # --- B. Context Request Handling --- - elif req_state_value == self._context_init_state_value: - if not ctx_chunk_config: - # No Chunking: Schedule full context - # C++ uses getNumTokens(beam=0) which is tokens.size() - numPreDecodedTokens - base_tokens = req.get_num_tokens(0) - draft_tokens = req.num_draft_tokens if req.has_draft_tokens else 0 - req_num_tokens = base_tokens + draft_tokens - - assert max_context_length is None or req_num_tokens <= max_context_length, \ - f"The number of context tokens ({req_num_tokens}) exceeds the limit value ({max_context_length})" - - if max_num_tokens is not None and ( - batch_num_tokens + req_num_tokens > max_num_tokens): - break + def get_pages_for_request(self, req: LlmRequest) -> int: + """Get number of PEFT pages needed for a request.""" + if self.peft_cache_manager is None: + return 0 + return self.peft_cache_manager.determine_num_pages(req) - logger.debug( - f"context request scheduled: ID {req.request_id}") - context_requests.append(req) - batch_num_tokens += req_num_tokens - else: - # Chunking Enabled: Tentative schedule - req.context_chunk_size = req.context_remaining_length + def get_task_info( + self, req: LlmRequest, + seen_task_ids: set[int]) -> tuple[Optional[int], bool, int]: + """ + Get PEFT task information for a request. 
- draft_tokens = req.num_draft_tokens if ( - req.is_last_context_chunk - and req.has_draft_tokens) else 0 - req_num_tokens = req.context_chunk_size + draft_tokens + Args: + req: Request to check + seen_task_ids: Set of task IDs already seen/allocated - if max_context_length is not None: - if max_context_length < req_num_tokens: - req_num_tokens = max_context_length - all_context_requests_fit = False + Returns: + Tuple of (lora_task_id, is_new_task, required_pages) + """ + lora_task_id = getattr(req, 'lora_task_id', None) + is_new_task = lora_task_id is not None and lora_task_id not in seen_task_ids + required_pages = self.get_pages_for_request(req) if is_new_task else 0 + return lora_task_id, is_new_task, required_pages - logger.debug( - f"contexts-to-be-chunked request scheduled: ID {req.request_id}" - ) - contexts_to_be_chunked.append(req) - num_chunked_tokens += req_num_tokens - # --- C. Generation Request Handling --- - else: - # C++ uses getBeamWidthByIter() which returns dynamic beam width - # during beam search (1->2->3->...->beamWidth) - beam_width = req.get_beam_width_by_iter( - for_next_iteration=False) - req_num_tokens = beam_width + req.num_draft_tokens +class GlobalCoordinator: + """ + Handles global request coordination for attention_dp mode. + + This class encapsulates all the logic for coordinating request scheduling + across multiple TP ranks using a single allgather operation. It implements + a deterministic water-filling algorithm that all ranks execute identically (SPMD). + + Responsibilities: + - Build local rank resource state + - Gather states from all ranks via single allgather + - Simulate global scheduling with water-filling algorithm + - Calculate assignment scores for load balancing + - Apply batching filters for context request coordination + """ - if max_num_tokens is not None and ( - batch_num_tokens + req_num_tokens > max_num_tokens): - break + def __init__(self, scheduler, dist, max_num_active_requests: int): + """ + Initialize global coordinator. - # Beam Width Consistency Check - if scheduled_beam_width == 0: - scheduled_beam_width = beam_width - elif scheduled_beam_width != beam_width: - logger.debug( - f"generation request skipped: ID {req.request_id} since its " - f"beam width ({beam_width}) is different from scheduled ones " - f"({scheduled_beam_width})") - continue - generation_requests.append(req) - batch_num_tokens += req_num_tokens + Args: + scheduler: Reference to parent SimpleUnifiedScheduler (for estimation methods) + dist: Distributed communication object + max_num_active_requests: Maximum number of active requests across all ranks + """ + self.scheduler = scheduler + self.dist = dist + self.max_num_active_requests = max_num_active_requests + + # Attention DP balancing/batching state + self.attention_dp_enable_balance = False + self.attention_dp_time_out_iters = 0 + self.attention_dp_batching_wait_iters = 0 + self.adp_ctx_waiting_iters_count = 0 + self.adp_ctx_batching_wait_iters_count = 0 + + def _estimate_next_iteration_growth_tokens(self, + request: LlmRequest) -> int: + """ + Estimate how many additional tokens a request will consume in the NEXT iteration. - # --- Batch Size Limit Check --- - scheduled_req_size += 1 - if scheduled_req_size >= max_batch_size: - break + This is critical for accurate simulation: old active requests will grow + (generate tokens, process next chunk, etc.) before new requests are scheduled. - # 2. 
Verify Chunking Fits - if max_num_tokens is not None and num_chunked_tokens > ( - max_num_tokens - batch_num_tokens): - all_context_requests_fit = False + Args: + request: Active request to estimate growth for + + Returns: + int: Estimated additional tokens for next iteration + """ + state_value = request.state_value + + # Context requests: Check for chunking + if state_value == self.scheduler._context_init_state_value: + if not request.is_last_context_chunk: + # Will process another chunk in next iteration + remaining_length = request.context_remaining_length + if remaining_length > 0: + # Estimate next chunk size + max_chunk = self.scheduler.max_num_tokens if self.scheduler.max_num_tokens else 2048 + return min(remaining_length, max_chunk) + return 0 # Last chunk, no growth + + # Generation requests: Will generate more tokens + elif state_value != self.scheduler._encoder_init_state_value: + # Get beam width for next iteration + beam_width = request.get_beam_width_by_iter(for_next_iteration=True) + # Add draft tokens if applicable + draft_tokens = request.num_draft_tokens if request.has_draft_tokens else 0 + return beam_width + draft_tokens + + # Encoder requests: No growth (single-shot) + return 0 + + def _estimate_next_iteration_growth_blocks(self, + request: LlmRequest) -> int: + """ + Estimate how many additional KV cache blocks a request will need in the NEXT iteration. - # 3. Apply Chunking Strategy if needed - if not all_context_requests_fit and contexts_to_be_chunked: - assert ctx_chunk_config is not None, \ - "If chunking is not enabled, context scheduling should be completed." - remaining_capacity = ( - max_num_tokens - - batch_num_tokens) if max_num_tokens is not None else None + Args: + request: Active request to estimate growth for - self._set_ctx_requests_chunk_size(contexts_to_be_chunked, - remaining_capacity) + Returns: + int: Estimated additional blocks for next iteration + """ + if self.scheduler.kv_cache_manager is None: + return 0 - # 4. 
Finalize Chunked Requests - for req in contexts_to_be_chunked: - if req.context_chunk_size > 0: - context_requests.append(req) - batch_num_tokens += req.context_chunk_size - logger.debug(f"context request scheduled: ID {req.request_id}, " - f"chunk size {req.context_chunk_size}") + # Estimate growth tokens first + growth_tokens = self._estimate_next_iteration_growth_tokens(request) + if growth_tokens == 0: + return 0 - # Sort requests for consistency with C++ - # C++ reference: utils::sortRequests in inflightBatchingUtils.cpp - self._sort_requests(context_requests, generation_requests, - not all_context_requests_fit) + # Get current sequence length and blocks + current_length = request.get_num_tokens(0) + + # For VSWA, use worst-case across window sizes + if hasattr(self.scheduler.kv_cache_manager, 'is_variable_window') and \ + self.scheduler.kv_cache_manager.is_variable_window: + max_growth_blocks = 0 + for window_size_key in self.scheduler.kv_cache_manager.get_window_size_keys( + ): + current_blocks = self.scheduler.kv_cache_manager.get_num_required_blocks( + request, window_size_key) + + # Estimate blocks after growth (approximate) + # This is conservative: assume each token might need a new block + tokens_per_block = getattr(self.scheduler.kv_cache_manager, + 'tokens_per_block', 64) + future_length = current_length + growth_tokens + future_blocks = (future_length + tokens_per_block - + 1) // tokens_per_block + + growth_blocks = max(0, future_blocks - current_blocks) + max_growth_blocks = max(max_growth_blocks, growth_blocks) + + return max_growth_blocks + else: + # Standard case: estimate block growth + tokens_per_block = getattr(self.scheduler.kv_cache_manager, + 'tokens_per_block', 64) - # Summary logs - logger.debug(f"batchSize (num ctx/enc requests + num gen requests): " - f"{len(context_requests) + len(generation_requests)}") - logger.debug(f"batchNumTokens / maxNumTokens: {batch_num_tokens} / " - f"{max_num_tokens or 0}") + # Current blocks + current_blocks = self.scheduler.kv_cache_manager.get_num_required_blocks( + request) - return context_requests, generation_requests + # Future blocks after growth + future_length = current_length + growth_tokens + future_blocks = (future_length + tokens_per_block - + 1) // tokens_per_block - def _sort_requests(self, context_requests: RequestList, - generation_requests: RequestList, - chunks_present: bool) -> None: - """ - Sort requests for consistency with C++. - C++ reference: utils::sortRequests in inflightBatchingUtils.cpp + return max(0, future_blocks - current_blocks) - 1. If chunks are present, move context requests that reached the last - context chunk to the end of the vector. - 2. Sort all requests by lora task id for performance. + def build_local_state( + self, + active_requests: List[LlmRequest], + ) -> RankResourceState: """ + Build snapshot of local rank's current state. - def get_lora_task_id(req: LlmRequest): - # C++ uses std::optional comparison where nullopt < any_value - # So requests without LoRA (nullopt) should come first - lora_id = getattr(req, 'lora_task_id', None) - if lora_id is None: - return (0, 0) # (has_value=False, value=0) - comes first - return (1, lora_id) # (has_value=True, value) - sorted by value + ENHANCEMENT: Includes predicted growth of active requests for next iteration. + This makes simulation more accurate by accounting for resources that old requests + will consume before new requests are scheduled. 
- if chunks_present: - # Partition: non-last-chunk first, last-chunk at end - not_last_chunk = [ - r for r in context_requests if not r.is_last_context_chunk - ] - last_chunk = [ - r for r in context_requests if r.is_last_context_chunk - ] - # Sort each group by lora_task_id - not_last_chunk.sort(key=get_lora_task_id) - last_chunk.sort(key=get_lora_task_id) - # Rebuild the list in-place - context_requests.clear() - context_requests.extend(not_last_chunk) - context_requests.extend(last_chunk) + This captures all information needed for global coordination without + modifying any actual resources. + + Args: + active_requests: Currently active requests on this rank + + Returns: + RankResourceState: Snapshot of current rank state (including predicted growth) + """ + # Get KV cache stats + if self.scheduler.kv_cache_manager is not None: + stats = self.scheduler.kv_cache_manager.get_kv_cache_stats() + # For VSWA (Variable Sliding Window), we track per window size + if hasattr(stats, 'num_free_blocks_per_window_size'): + free_blocks_per_ws = dict(stats.num_free_blocks_per_window_size) + # Use the primary window size (0 or first key) + primary_ws = 0 if 0 in free_blocks_per_ws else next( + iter(free_blocks_per_ws), 0) + free_blocks = free_blocks_per_ws.get(primary_ws, 0) + else: + # Fallback for non-VSWA + free_blocks = getattr(stats, 'free_num_blocks', 0) + max_blocks = getattr(self.scheduler.kv_cache_manager, + 'max_num_blocks', 0) else: - context_requests.sort(key=get_lora_task_id) + free_blocks = 0 + max_blocks = 0 - generation_requests.sort(key=get_lora_task_id) + # Get token budget + max_token_budget = self.scheduler.max_num_tokens if self.scheduler.max_num_tokens is not None else float( + 'inf') - def _set_ctx_requests_chunk_size(self, requests: RequestList, - capacity: Optional[int]): - # C++: Resets all chunk sizes to 0 at start - for req in requests: - req.context_chunk_size = 0 + # Calculate current token load + current_tokens = self.scheduler._calculate_current_token_load( + active_requests) - policy = self.ctx_chunk_config.chunking_policy - unit_size = self.ctx_chunk_config.chunk_unit_size + # ENHANCEMENT: Predict growth for next iteration + # This accounts for old requests consuming more resources before new requests schedule + predicted_growth_tokens = 0 + predicted_growth_blocks = 0 - if policy == ChunkingPolicy.EQUAL_PROGRESS: - self._chunk_equal_progress(requests, capacity, unit_size) - elif policy == ChunkingPolicy.FIRST_COME_FIRST_SERVED: - self._chunk_fcfs(requests, capacity, unit_size) - else: - raise ValueError(f"Invalid chunking policy: {policy}") + for req in active_requests: + growth_tokens = self._estimate_next_iteration_growth_tokens(req) + predicted_growth_tokens += growth_tokens + + if growth_tokens > 0: + growth_blocks = self._estimate_next_iteration_growth_blocks(req) + predicted_growth_blocks += growth_blocks + + # Count active requests by type + num_active_gen = sum(1 for r in active_requests + if not r.is_context_init_state) + num_active_ctx = sum(1 for r in active_requests + if r.is_context_init_state) + + # Reserve resources for predicted growth + # This makes simulation conservative but accurate + return RankResourceState( + rank_id=self.dist.rank, + free_kv_blocks=max(0, free_blocks - + predicted_growth_blocks), # Reserve for growth + max_kv_blocks=max_blocks, + current_batch_tokens=current_tokens + + predicted_growth_tokens, # Include growth + max_token_budget=max_token_budget, + current_batch_size=len(active_requests), + 
max_batch_size=self.scheduler.max_batch_size, + num_active_gen_reqs=num_active_gen, + num_active_ctx_reqs=num_active_ctx, + ) + + def gather_all_states( + self, local_state: RankResourceState) -> List[RankResourceState]: + """ + THE SINGLE COMMUNICATION POINT. + Gather RankResourceState from all TP ranks via tp_allgather. - self._fit_draft_tokens(requests, capacity, unit_size) + This is the ONLY synchronization point in the unified scheduler, + replacing the 3+ tp_allgather calls in the old architecture. - def _chunk_equal_progress(self, requests: RequestList, - capacity: Optional[int], unit_size: int): - num_ctx_tokens = 0 - num_tokens_single_loop = 1 + Args: + local_state: This rank's resource state - # C++ Loop: while ((!capacity || numCtxTokens < capacity) && numTokensSingleLoop) - while (capacity is None - or num_ctx_tokens < capacity) and num_tokens_single_loop > 0: - num_tokens_single_loop = 0 - for req in requests: - past_size = req.context_chunk_size + Returns: + List[RankResourceState]: States from all ranks + """ + # Serialize to dict for communication (dataclasses are not directly serializable) + local_dict = { + 'rank_id': local_state.rank_id, + 'free_kv_blocks': local_state.free_kv_blocks, + 'max_kv_blocks': local_state.max_kv_blocks, + 'current_batch_tokens': local_state.current_batch_tokens, + 'max_token_budget': local_state.max_token_budget, + 'current_batch_size': local_state.current_batch_size, + 'max_batch_size': local_state.max_batch_size, + 'num_active_gen_reqs': local_state.num_active_gen_reqs, + 'num_active_ctx_reqs': local_state.num_active_ctx_reqs, + 'active_lora_task_ids': list(local_state.active_lora_task_ids), + 'available_peft_pages': local_state.available_peft_pages, + } - # C++ logic: suggested = past + unit - suggested_size = past_size + unit_size + # THE SINGLE tp_allgather + all_dicts = self.dist.tp_allgather(local_dict) - # Ensure we don't exceed what the request actually needs - remaining_total = req.context_remaining_length - suggested_size = min(suggested_size, remaining_total) + # Deserialize back to RankResourceState objects + result = [] + for d in all_dicts: + # Convert active_lora_task_ids back to set + d['active_lora_task_ids'] = set(d.get('active_lora_task_ids', [])) + result.append(RankResourceState(**d)) - req.context_chunk_size = suggested_size + return result - actual_size = req.context_chunk_size - actual_increment = actual_size - past_size - - # Check Constraints - # 1. Capacity - if capacity is not None and (num_ctx_tokens + actual_increment - > capacity): - req.context_chunk_size = past_size # Revert - continue - - # 2. Max Context Length - if self.max_context_length is not None and actual_size > self.max_context_length: - req.context_chunk_size = past_size # Revert - continue + def calculate_assignment_score( + self, + rank_state: RankResourceState, + ) -> float: + """ + Calculate assignment score for a rank. + Higher score = better assignment. - num_ctx_tokens += actual_increment - num_tokens_single_loop += actual_increment + PRIMARY GOAL: Balance TOKEN WORKLOAD across ranks. + Token balance directly correlates with execution time balance, + which is the true performance metric. Request count balance + is secondary since requests vary significantly in token count. - def _chunk_fcfs(self, requests: RequestList, capacity: Optional[int], - unit_size: int): - current_capacity = capacity if capacity is not None else float('inf') + Scoring components (by priority): + 1. Token load penalty (PRIMARY): Heavily penalize high token load + 2. 
Context request penalty (SECONDARY): Balance context vs generation + 3. Generation request penalty (SECONDARY): Minor load factor - for req in requests: - suggested_size = req.context_remaining_length - actual_size = suggested_size + Args: + rank_state: Current state of the candidate rank - if current_capacity < actual_size: - actual_size = current_capacity + Returns: + float: Assignment score (higher is better) + """ + score = 0.0 + + # === PRIMARY: Token Load Balance (Most Important) === + # This is the dominant factor because token count directly determines + # computation time. A rank with 2000 tokens takes 40x longer than + # a rank with 50 tokens, even if both have the same request count. + if rank_state.max_token_budget > 0 and rank_state.max_token_budget != float( + 'inf'): + load_ratio = rank_state.current_batch_tokens / rank_state.max_token_budget + score -= load_ratio * 1000.0 # Heavily penalize high token load + + # === SECONDARY: Request Count Balance (Less Important) === + # Context requests still matter because they block generation, + # but their weight is much lower relative to token load. + # This ensures token balance takes precedence over request count balance. + score -= rank_state.num_active_ctx_reqs * 10.0 + score -= rank_state.num_active_gen_reqs * 5.0 + + return score + + def can_accept_request( + self, + request: LlmRequest, + rank_state: RankResourceState, + ) -> bool: + """ + Check if rank can accept this request based on resource constraints. + This is the SIMULATION of capacity and token budget checks. - if self.max_context_length is not None: - actual_size = min(self.max_context_length, actual_size) + OPTIMIZATION: If the request can be accepted, cache the estimated tokens/blocks + to avoid recalculation in _fused_schedule_request(). 
- # Round down to unit size if we had to truncate - if actual_size < suggested_size: - actual_size = (int(actual_size) // unit_size) * unit_size + Args: + request: The request to check + rank_state: Current state of the candidate rank - req.context_chunk_size = int(actual_size) + Returns: + bool: True if rank can accept the request + """ + # Check batch size limit + if rank_state.current_batch_size >= rank_state.max_batch_size: + return False - # C++: ctxTokensCapacity = ctxTokensCapacity - actualChunkSize - if capacity is not None: - current_capacity -= req.context_chunk_size + # Check token budget limit + tokens_needed = self.scheduler._estimate_tokens_needed(request) + if rank_state.max_token_budget != float('inf'): + if rank_state.current_batch_tokens + tokens_needed > rank_state.max_token_budget: + return False - def _fit_draft_tokens(self, requests: RequestList, capacity: Optional[int], - unit_size: int): - # Calculate tokens already taken by the batch so far - num_ctx_tokens = sum(req.context_chunk_size for req in requests) + # Check KV cache capacity + blocks_needed = self.scheduler._estimate_blocks_needed(request) + if rank_state.free_kv_blocks < blocks_needed: + return False - for req in requests: - if req.is_last_context_chunk and req.has_draft_tokens: - remainder = req.context_chunk_size % unit_size - remaining_space = 0 if remainder == 0 else unit_size - remainder + # OPTIMIZATION: Cache estimates for later use in _fused_schedule_request() + # This avoids ~50% duplicate work for newly activated requests + request.py_pre_validated = True + request.py_estimated_tokens = tokens_needed + request.py_estimated_blocks = blocks_needed - if self.max_context_length is not None: - remaining_context_len = self.max_context_length - req.context_chunk_size - remaining_space = min(remaining_space, - remaining_context_len) + return True - if capacity is not None: - remaining_space = min(remaining_space, - capacity - num_ctx_tokens) - num_ctx_tokens += remaining_space + def update_rank_state_after_assignment( + self, + rank_state: RankResourceState, + request: LlmRequest, + ) -> None: + """ + Update simulated rank state after assigning a request. + This modifies the state IN PLACE during simulation. - draft_discard = req.num_draft_tokens - remaining_space - if draft_discard > 0: - logger.debug(f"Discarding {draft_discard} draft tokens") - if hasattr(req, "discard_draft_tokens"): - req.discard_draft_tokens(draft_discard) + Args: + rank_state: The rank state to update (modified in place) + request: The request that was assigned + """ + # Decrement resources + tokens_needed = self.scheduler._estimate_tokens_needed(request) + rank_state.current_batch_tokens += tokens_needed + rank_state.current_batch_size += 1 + blocks_needed = self.scheduler._estimate_blocks_needed(request) + rank_state.free_kv_blocks -= blocks_needed -class SchedulerPolicyBase(ABC): - """ - Abstract base class for capacity scheduler policies. - Each policy implements its own scheduling logic. - """ + # Update request counters + if request.is_context_init_state: + rank_state.num_active_ctx_reqs += 1 + else: + rank_state.num_active_gen_reqs += 1 - @abstractmethod - def schedule( - self, scheduler: 'PyCapacityScheduler', - active_requests: RequestList) -> tuple[RequestList, RequestList]: + def simulate_global_schedule( + self, + candidate_requests: + List, # List[RequestQueueItem] but avoid circular import + all_rank_states: List[RankResourceState], + ) -> Dict[int, List[int]]: """ - Schedule requests according to the policy. 
+ Deterministic water-filling algorithm. + ALL RANKS RUN THIS IDENTICALLY (SPMD). + + This is the core scheduling algorithm that assigns requests to ranks + based on resource availability and optimization criteria. Args: - scheduler: The capacity scheduler instance (for accessing shared state) - active_requests: List of active requests to schedule + candidate_requests: List of candidate requests to assign + all_rank_states: Current states of all ranks Returns: - Tuple of (scheduled_requests, paused_requests) + Dict mapping rank_id -> [assigned_request_ids] """ - raise NotImplementedError - - -class MaxRequestsPolicy(SchedulerPolicyBase): - """ - MaxRequestsScheduler: Simple request count limiting without KV cache checks. - C++ reference: capacityScheduler.cpp:154-176 - """ + # Deep copy to avoid modifying original states + sim_states = copy.deepcopy(all_rank_states) + + # Initialize assignments + assignments = {state.rank_id: [] for state in sim_states} + + # Sort candidates deterministically (all ranks must see same order!) + # Priority: non-relaxed first, then by request_id for determinism + sorted_candidates = sorted( + candidate_requests, + key=lambda item: ( + # Check if request has attention_dp_relax flag + (getattr(item, 'llm_request', None) and getattr( + item.llm_request, 'py_scheduling_params', None) and getattr( + item.llm_request.py_scheduling_params, + 'attention_dp_relax', False)) or False, + # Secondary sort by id for determinism (RequestQueueItem.id) + item.id, + )) + + # Water-filling algorithm + for req_item in sorted_candidates: + if not hasattr(req_item, 'llm_request') or not req_item.llm_request: + continue - def schedule( - self, scheduler: 'PyCapacityScheduler', - active_requests: RequestList) -> tuple[RequestList, RequestList]: - scheduled_requests: RequestList = [] + req = req_item.llm_request - for req in active_requests: - if not scheduler._can_be_scheduled(req): - continue + # Score all ranks for this request + best_rank_id = -1 + best_score = -float('inf') - if len(scheduled_requests) >= scheduler.max_num_requests: - break + for rank_state in sim_states: + # Feasibility check + if not self.can_accept_request(req, rank_state): + continue - if (req.is_encoder_init_state or req.is_context_init_state - or req.is_generation_in_progress_state): - scheduled_requests.append(req) + # Calculate score + score = self.calculate_assignment_score(rank_state) - return scheduled_requests, [] + if score > best_score: + best_score = score + best_rank_id = rank_state.rank_id + # Assign to best rank (if any rank can accept) + if best_rank_id != -1: + assignments[best_rank_id].append(req.request_id) -class GuaranteedNoEvictPolicy(SchedulerPolicyBase): - """ - GuaranteedNoEvictScheduler: Reserve blocks for requests to complete without eviction. - C++ reference: capacityScheduler.cpp:194-331 - """ + # Update simulated state + target_state = sim_states[best_rank_id] + self.update_rank_state_after_assignment(target_state, req) - def __init__(self, static_batch: bool = False): - self.static_batch = static_batch + return assignments - def schedule( - self, scheduler: 'PyCapacityScheduler', - active_requests: RequestList) -> tuple[RequestList, RequestList]: - scheduled_requests: RequestList = [] - has_peft = scheduler.peft_cache_manager is not None + def apply_batching_filter( + self, + assignments: Dict[int, List[int]], + candidate_requests: List, + ) -> Dict[int, List[int]]: + """ + Apply batching filter to assignments based on waiting logic. 
- skipping_is_relevant = scheduler._is_skipping_relevant() + If we should wait for all ranks to have context requests, this method + filters out context requests but keeps generation requests. - newly_contributed_context_blocks: Set = set() - newly_contributed_cross_context_blocks: Set = set() - if not self.static_batch and skipping_is_relevant: - newly_contributed_context_blocks, newly_contributed_cross_context_blocks = \ - scheduler._prefill_contributed_blocks(active_requests) + Args: + assignments: Dict mapping rank_id -> [assigned_request_ids] + candidate_requests: List of candidate requests - reserved_blocks = NoEvictScheduledBlocksManager( - scheduler.kv_cache_manager) - reserved_cross_blocks: Optional[NoEvictScheduledBlocksManager] = None - if scheduler.cross_kv_cache_manager is not None: - reserved_cross_blocks = NoEvictScheduledBlocksManager( - scheduler.cross_kv_cache_manager) + Returns: + Dict[int, List[int]]: Filtered assignments + """ + # Check if we should wait + should_wait = self.should_wait_for_context_batching( + assignments, candidate_requests) + if not should_wait: + return assignments + + # Build request ID to request mapping + req_id_to_req = {} + for req_item in candidate_requests: + if hasattr(req_item, 'llm_request') and req_item.llm_request: + req = req_item.llm_request + req_id_to_req[req.request_id] = req + + # Filter out context requests, keep generation requests + filtered_assignments = {} + for rank_id in assignments: + filtered_req_ids = [] + for req_id in assignments[rank_id]: + if req_id in req_id_to_req: + req = req_id_to_req[req_id] + # Keep only generation requests, remove context requests + if not req.is_context_init_state: + filtered_req_ids.append(req_id) + else: + # Unknown request (shouldn't happen but keep for safety) + filtered_req_ids.append(req_id) + filtered_assignments[rank_id] = filtered_req_ids - # PEFT state - only used when has_peft - claimed_peft_pages = 0 - available_peft_pages = scheduler._get_max_peft_pages( - ) if has_peft else 0 - uniq_task_ids: set[int] = set() if has_peft else None + return filtered_assignments - pending_requests: RequestList = [] - pending_dis_gen_init_requests: RequestList = [] + def should_wait_for_context_batching( + self, + assignments: Dict[int, List[int]], + candidate_requests: List, + ) -> bool: + """ + Check if we should wait for all ranks to have context requests (attention_dp batching). - # First pass: process in-progress generation and classify requests - for req in active_requests: - if not scheduler._can_be_scheduled_with_disagg_exception(req): - continue + This implements the same logic as _balance_adp_requests to ensure: + 1. All ranks have context requests before scheduling (avoid load imbalance) + 2. Batch context requests together when possible + 3. 
Timeout mechanism to avoid deadlock - if len(scheduled_requests) >= scheduler.max_num_requests: - break + Args: + assignments: Dict mapping rank_id -> [assigned_request_ids] + candidate_requests: List of candidate requests - if req.is_generation_in_progress_state: - scheduled_requests.append(req) - reserved_blocks.decrement_reserved_blocks(req) - if reserved_cross_blocks is not None: - reserved_cross_blocks.decrement_reserved_blocks(req) + Returns: + bool: True if we should wait (clear context requests), False if we should proceed + """ + if not self.attention_dp_enable_balance: + return False - if has_peft: - lora_task_id, is_new_task, peft_pages = scheduler._get_peft_task_info( - req, uniq_task_ids) - if is_new_task: - claimed_peft_pages += peft_pages - uniq_task_ids.add(lora_task_id) + # Build request ID to request mapping + req_id_to_req = {} + for req_item in candidate_requests: + if hasattr(req_item, 'llm_request') and req_item.llm_request: + req = req_item.llm_request + req_id_to_req[req.request_id] = req + + # Count context and generation requests per rank + rank_ctx_counts = {} + rank_gen_counts = {} + for rank_id, assigned_req_ids in assignments.items(): + ctx_count = 0 + gen_count = 0 + for req_id in assigned_req_ids: + if req_id in req_id_to_req: + req = req_id_to_req[req_id] + if req.is_context_init_state: + ctx_count += 1 + else: + gen_count += 1 + rank_ctx_counts[rank_id] = ctx_count + rank_gen_counts[rank_id] = gen_count + + # Check conditions (same as _balance_adp_requests) + all_ranks_have_ctx_requests = all(count > 0 + for count in rank_ctx_counts.values()) + all_ranks_have_gen_requests = all(count > 0 + for count in rank_gen_counts.values()) + + # Note: We don't check free_ctx_slots here because global coordination already handles capacity in can_accept_request + + if all_ranks_have_ctx_requests: + # All ranks have context requests + self.adp_ctx_waiting_iters_count = 0 + + # Check if we should batch (wait for more context requests) + if all_ranks_have_gen_requests: + if self.adp_ctx_batching_wait_iters_count < self.attention_dp_batching_wait_iters: + self.adp_ctx_batching_wait_iters_count += 1 + return True # Wait for batching + else: + self.adp_ctx_batching_wait_iters_count = 0 + return False # Proceed with scheduling + else: + return False # Proceed (no generation requests to compete with) + else: + # Not all ranks have context requests + self.adp_ctx_waiting_iters_count += 1 - elif req.is_disagg_generation_init_state: - pending_dis_gen_init_requests.append(req) + timeout_reached = self.adp_ctx_waiting_iters_count >= self.attention_dp_time_out_iters + if timeout_reached or not all_ranks_have_gen_requests: + # Timeout or no generation requests - proceed anyway + self.adp_ctx_waiting_iters_count = 0 + return False else: - pending_requests.append(req) - - # Second pass: process pending requests - if not self.static_batch or len(scheduled_requests) == 0: - if has_peft: - available_peft_pages -= claimed_peft_pages - - for requests in [pending_dis_gen_init_requests, pending_requests]: - for req in requests: - if (not self.static_batch and skipping_is_relevant - and not req.is_disagg_generation_init_state - and scheduler._beneficial_to_skip( - req, newly_contributed_context_blocks, - newly_contributed_cross_context_blocks)): - continue - - if len(scheduled_requests) >= scheduler.max_num_requests: - break + # Wait for all ranks to get context requests + return True - if req.is_context_init_state or req.is_disagg_generation_init_state: - enough_blocks = 
reserved_blocks.enough_available_blocks( - req) - enough_cross_blocks = True - if reserved_cross_blocks is not None: - enough_cross_blocks = reserved_cross_blocks.enough_available_blocks( - req) - if not enough_blocks or not enough_cross_blocks: - break +class CapacityChecker: + """ + Helper class for KV cache capacity checking. - # PEFT check only when needed - if has_peft: - lora_task_id, is_new_task, needed_peft_pages = scheduler._get_peft_task_info( - req, uniq_task_ids) - if needed_peft_pages > available_peft_pages: - continue - available_peft_pages -= needed_peft_pages - if is_new_task: - uniq_task_ids.add(lora_task_id) + Encapsulates all logic related to checking if requests fit in KV cache, + including block reuse optimization and policy-specific capacity checks. + """ - scheduled_requests.append(req) - reserved_blocks.decrement_reserved_blocks(req) - if reserved_cross_blocks is not None: - reserved_cross_blocks.decrement_reserved_blocks(req) + def __init__(self, kv_cache_manager, cross_kv_cache_manager, + scheduler_policy: CapacitySchedulerPolicy, + no_schedule_until_state_value, no_schedule_after_state_value): + """ + Initialize capacity checker. - return scheduled_requests, [] + Args: + kv_cache_manager: KV cache manager instance + cross_kv_cache_manager: Cross-attention KV cache manager (or None) + scheduler_policy: Capacity scheduling policy + no_schedule_until_state_value: Minimum state value for scheduling + no_schedule_after_state_value: Maximum state value for scheduling + """ + self.kv_cache_manager = kv_cache_manager + self.cross_kv_cache_manager = cross_kv_cache_manager + self.scheduler_policy = scheduler_policy + self._no_schedule_until_state_value = no_schedule_until_state_value + self._no_schedule_after_state_value = no_schedule_after_state_value + def can_be_scheduled(self, req: LlmRequest) -> bool: + """ + Check if request can be scheduled based on state value. + """ + # Use cached state values for performance + state_value = req.state_value + return (state_value >= self._no_schedule_until_state_value + and state_value < self._no_schedule_after_state_value) -class MaxUtilizationPolicy(SchedulerPolicyBase): - """ - MaxUtilizationScheduler: Maximize utilization, may pause started requests. - C++ reference: capacityScheduler.cpp:341-425 - """ + def is_skipping_relevant(self) -> bool: + """ + Check if block reuse skip optimization is relevant. + Disabled for VSWA (Variable Sliding Window Attention). + """ + if self.kv_cache_manager is None: + return False + if self.kv_cache_manager.is_variable_window: + return False + if (self.cross_kv_cache_manager is not None + and self.cross_kv_cache_manager.is_variable_window): + return False + return True - def schedule( - self, scheduler: 'PyCapacityScheduler', - active_requests: RequestList) -> tuple[RequestList, RequestList]: - scheduler.kv_cache_manager.start_scheduling() + def prefill_contributed_blocks( + self, active_requests: RequestList) -> tuple[set, set]: + """ + Collect blocks contributed by chunked context requests already executing. + These blocks can be reused by later requests. 
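+        A chunked context request that is past its first chunk has already
+        committed its earlier chunks to the block pool; find_new_context_block()
+        identifies the next block it will produce, which a pending first-chunk
+        request with a matching prefix may be able to reuse (see
+        beneficial_to_skip).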
- skipping_is_relevant = scheduler._is_skipping_relevant() + Args: + active_requests: Currently active requests - scheduled_blocks_manager = MaxUtilizationScheduledBlocksManager( - scheduler.kv_cache_manager, scheduler.two_step_lookahead) + Returns: + Tuple of (context_blocks, cross_context_blocks) that can be reused + """ + newly_contributed_context_blocks: set = set() + newly_contributed_cross_context_blocks: set = set() - num_scheduled_peft_pages = 0 - seen_task_ids: set[int] = set() + if self.kv_cache_manager is None: + return newly_contributed_context_blocks, newly_contributed_cross_context_blocks - newly_contributed_context_blocks, _ = scheduler._prefill_contributed_blocks( - active_requests) + enable_block_reuse = self.kv_cache_manager.enable_block_reuse + cross_enable_reuse = (self.cross_kv_cache_manager is not None and + self.cross_kv_cache_manager.enable_block_reuse) - def is_started_request(req: LlmRequest) -> bool: - if not scheduler._can_be_scheduled(req): - return False - return ((req.is_context_init_state - and not req.is_first_context_chunk) - or req.is_generation_in_progress_state) + for req in active_requests: + # Check: isContextInitState() && !isFirstContextChunk() + if req.is_context_init_state and not req.is_first_context_chunk: + # Chunked context request already executing + if enable_block_reuse: + unique_tokens = req.get_unique_tokens(0) + block_key = self.kv_cache_manager.find_new_context_block( + unique_tokens, req) + if block_key is not None: + newly_contributed_context_blocks.add(block_key) - scheduled_requests: RequestList = [] - paused_requests: RequestList = [] + if cross_enable_reuse: + encoder_unique_tokens = req.get_encoder_unique_tokens() + if encoder_unique_tokens is not None: + block_key = self.cross_kv_cache_manager.find_new_context_block( + encoder_unique_tokens, req) + if block_key is not None: + newly_contributed_cross_context_blocks.add( + block_key) - requests_list = list(active_requests) - req_it_end = len(requests_list) - req_it = 0 + return newly_contributed_context_blocks, newly_contributed_cross_context_blocks - while req_it < req_it_end: - req = requests_list[req_it] - logger.debug( - f"MaxUtilizationScheduler: scheduling request ID {req.request_id}" - ) + def _one_manager_beneficial_to_skip(self, kv_cache_manager, unique_tokens, + req: LlmRequest, + newly_contributed_blocks: set) -> bool: + """Check if skipping is beneficial for one KV cache manager.""" + new_context_block = kv_cache_manager.find_new_context_block( + unique_tokens, req) + if new_context_block is not None: + if new_context_block in newly_contributed_blocks: + return True + return False - if not scheduler._can_be_scheduled_with_disagg_exception(req): - logger.debug( - f"MaxUtilizationScheduler: request ID {req.request_id} " - "cannot / should not be scheduled") - req_it += 1 - continue + def beneficial_to_skip(self, req: LlmRequest, + newly_contributed_context_blocks: set, + newly_contributed_cross_context_blocks: set) -> bool: + """ + Check if it's beneficial to skip this request. + A request should be skipped if it can reuse blocks contributed by + already scheduled context requests. 
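+        Skipping defers the request to a later iteration, by which time the
+        contributing chunk has typically been written, so the deferred request
+        can reuse that block instead of computing it again.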
- if (skipping_is_relevant and scheduler._beneficial_to_skip( - req, newly_contributed_context_blocks, set())): - req_it += 1 - continue + Args: + req: Request to check + newly_contributed_context_blocks: Blocks from active context requests + newly_contributed_cross_context_blocks: Cross-attention blocks from active requests - was_scheduled = self._try_scheduling_request( - scheduler, req, scheduled_requests, scheduled_blocks_manager, - num_scheduled_peft_pages, seen_task_ids) + Returns: + True if request should be skipped for block reuse optimization + """ + if not (req.is_context_init_state and req.is_first_context_chunk): + return False - if was_scheduled: - logger.debug( - f"MaxUtilizationScheduler: request ID {req.request_id} -> start" - ) - req_it += 1 - else: - last_started_idx = None - for i in range(req_it_end - 1, req_it - 1, -1): - if is_started_request(requests_list[i]): - last_started_idx = i - break + if (self.kv_cache_manager is not None + and self.kv_cache_manager.enable_block_reuse): + unique_tokens = req.get_unique_tokens(0) + if self._one_manager_beneficial_to_skip( + self.kv_cache_manager, unique_tokens, req, + newly_contributed_context_blocks): + return True - if last_started_idx is not None: - paused_req = requests_list[last_started_idx] - scheduler.kv_cache_manager.scheduling_remove_sequence( - paused_req.py_request_id) - paused_requests.append(paused_req) - logger.debug( - f"MaxUtilizationScheduler: request ID {paused_req.request_id} -> pause" - ) - req_it_end = last_started_idx - else: - break + if (self.cross_kv_cache_manager is not None + and self.cross_kv_cache_manager.enable_block_reuse): + encoder_unique_tokens = req.get_encoder_unique_tokens() + if encoder_unique_tokens is not None: + if self._one_manager_beneficial_to_skip( + self.cross_kv_cache_manager, encoder_unique_tokens, req, + newly_contributed_cross_context_blocks): + return True - return scheduled_requests, paused_requests + return False - def _try_scheduling_request( - self, scheduler: 'PyCapacityScheduler', req: LlmRequest, - scheduled_requests: RequestList, - scheduled_blocks_manager: 'MaxUtilizationScheduledBlocksManager', - num_scheduled_peft_pages: int, seen_task_ids: set[int]) -> bool: - if len(scheduled_requests) >= scheduler.max_num_requests: - return False + def check_kv_capacity( + self, + req: LlmRequest, + scheduled_blocks_manager, + reserved_blocks, + reserved_cross_blocks, + simulation_mode: bool, + ) -> bool: + """ + Check if request fits in KV cache capacity. + Uses the appropriate block manager based on the scheduling policy. 
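+        MAX_REQUESTS performs no KV check, MAX_UTILIZATION consults the
+        per-window scheduled-blocks manager, and GUARANTEED_NO_EVICT /
+        STATIC_BATCH consult the reserved-blocks managers.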
- blocks_if_scheduled = scheduled_blocks_manager.prepare_blocks_if_schedulable( - req) - if blocks_if_scheduled is None: - return False + Args: + req: Request to check + scheduled_blocks_manager: MaxUtilizationScheduledBlocksManager (or None) + reserved_blocks: NoEvictScheduledBlocksManager (or None) + reserved_cross_blocks: NoEvictScheduledBlocksManager for cross-attention (or None) + simulation_mode: If True, don't update block manager state - # PEFT check only when needed - if scheduler.peft_cache_manager is not None: - lora_task_id, is_new_task, num_required_peft_pages = scheduler._get_peft_task_info( - req, seen_task_ids) - logger.debug( - f"MaxUtilizationScheduler: request ID {req.request_id} " - f"required peft pages: {num_required_peft_pages}") - max_peft_pages = scheduler._get_max_peft_pages() - if num_required_peft_pages + num_scheduled_peft_pages > max_peft_pages: + Returns: + True if request fits, False otherwise + """ + if self.scheduler_policy == CapacitySchedulerPolicy.MAX_REQUESTS: + # No KV cache manager, always fits + return True + elif self.scheduler_policy == CapacitySchedulerPolicy.MAX_UTILIZATION: + # Use MaxUtilizationScheduledBlocksManager + blocks_if_scheduled = scheduled_blocks_manager.prepare_blocks_if_schedulable( + req) + if blocks_if_scheduled is None: return False - logger.debug( - f"MaxUtilizationScheduler: scheduled peft pages: {num_required_peft_pages}" - ) - if is_new_task: - seen_task_ids.add(lora_task_id) - - scheduled_blocks_manager.update_scheduled_blocks(blocks_if_scheduled) - scheduled_requests.append(req) - return True + # Update state if not in simulation mode + if not simulation_mode: + scheduled_blocks_manager.update_scheduled_blocks( + blocks_if_scheduled) + return True + else: + # Use NoEvictScheduledBlocksManager (GUARANTEED_NO_EVICT or STATIC_BATCH) + if req.is_context_init_state: + enough_blocks = reserved_blocks.enough_available_blocks(req) + enough_cross_blocks = True + if reserved_cross_blocks is not None: + enough_cross_blocks = reserved_cross_blocks.enough_available_blocks( + req) + return enough_blocks and enough_cross_blocks + else: + # Generation requests always fit (blocks already reserved) + return True -class NoEvictScheduledBlocksManager: +class ChunkingManager: """ - Python equivalent of C++ kv_cache_manager::NoEvictScheduledBlocksManager. - Tracks available blocks per window size for GUARANTEED_NO_EVICT scheduling. + Helper class for context chunking management. - Reference: cpp/tensorrt_llm/batch_manager/scheduledBlocksManager.h:29-62 + Encapsulates all logic related to chunking context requests to fit within + token budgets, including sorting, chunk size calculation, and draft token fitting. """ - def __init__(self, kv_cache_manager): - """ - Initialize with free blocks from KVCacheManager. - C++ equivalent: mAvailableBlocks = mKvCacheManager.getBlockManager().getNumFreeBlocksPerWindowSize() + def __init__(self, ctx_chunk_config, max_context_length): """ - self.kv_cache_manager = kv_cache_manager - stats = kv_cache_manager.get_kv_cache_stats() - self.available_blocks: dict[int, int] = dict( - stats.num_free_blocks_per_window_size) + Initialize chunking manager. - def decrement_reserved_blocks(self, req: LlmRequest) -> None: - """ - Decrement available blocks by the blocks needed to complete this request. 
- C++ reference: scheduledBlocksManager.h:40-46 + Args: + ctx_chunk_config: Context chunking configuration (policy + unit_size) + max_context_length: Maximum context length per request """ - for window_size in self.available_blocks: - needed = self.kv_cache_manager.get_remaining_blocks_to_completion( - req, window_size) - self.available_blocks[window_size] -= needed + self.ctx_chunk_config = ctx_chunk_config + self.max_context_length = max_context_length - def enough_available_blocks(self, req: LlmRequest) -> bool: + def sort_requests(self, context_requests: RequestList, + generation_requests: RequestList, + chunks_present: bool) -> None: """ - Check if there are enough available blocks for this request across all window sizes. - C++ reference: scheduledBlocksManager.h:48-57 + Sort requests for consistency with C++. + + 1. If chunks are present, move context requests that reached the last + context chunk to the end of the vector. + 2. Sort all requests by lora task id for performance. + + Args: + context_requests: Context requests list (modified in-place) + generation_requests: Generation requests list (modified in-place) + chunks_present: Whether chunking is active """ - return all( - self.kv_cache_manager.get_remaining_blocks_to_completion(req, ws) <= - avail for ws, avail in self.available_blocks.items()) + def get_lora_task_id(req: LlmRequest): + # C++ uses std::optional comparison where nullopt < any_value + # So requests without LoRA (nullopt) should come first + lora_id = getattr(req, 'lora_task_id', None) + if lora_id is None: + return (0, 0) # (has_value=False, value=0) - comes first + return (1, lora_id) # (has_value=True, value) - sorted by value + + if chunks_present: + # Partition: non-last-chunk first, last-chunk at end + not_last_chunk = [ + r for r in context_requests if not r.is_last_context_chunk + ] + last_chunk = [ + r for r in context_requests if r.is_last_context_chunk + ] + # Sort each group by lora_task_id + not_last_chunk.sort(key=get_lora_task_id) + last_chunk.sort(key=get_lora_task_id) + # Rebuild the list in-place + context_requests.clear() + context_requests.extend(not_last_chunk) + context_requests.extend(last_chunk) + else: + context_requests.sort(key=get_lora_task_id) + + generation_requests.sort(key=get_lora_task_id) + + def apply_chunking(self, requests: RequestList, + capacity: Optional[int]) -> None: + """ + Apply chunking to context requests based on chunking policy. 
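+
+        Illustrative example (numbers assumed): with capacity=96, unit_size=32
+        and two requests needing 100 and 40 context tokens, EQUAL_PROGRESS grows
+        chunks round-robin and yields chunk sizes of 64 and 32, whereas
+        FIRST_COME_FIRST_SERVED gives the first request 96 tokens and leaves the
+        second at 0 (such requests are paused in _finalize_chunking).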
+ + Args: + requests: Context requests to chunk (modified in-place) + capacity: Available token capacity + """ + # C++: Resets all chunk sizes to 0 at start + for req in requests: + req.context_chunk_size = 0 + + policy = self.ctx_chunk_config.chunking_policy + unit_size = self.ctx_chunk_config.chunk_unit_size + + if policy == ChunkingPolicy.EQUAL_PROGRESS: + self._chunk_equal_progress(requests, capacity, unit_size) + elif policy == ChunkingPolicy.FIRST_COME_FIRST_SERVED: + self._chunk_fcfs(requests, capacity, unit_size) + else: + raise ValueError(f"Invalid chunking policy: {policy}") + + self._fit_draft_tokens(requests, capacity, unit_size) + + def _chunk_equal_progress(self, requests: RequestList, + capacity: Optional[int], unit_size: int): + """Apply equal progress chunking strategy.""" + num_ctx_tokens = 0 + num_tokens_single_loop = 1 + + # C++ Loop: while ((!capacity || numCtxTokens < capacity) && numTokensSingleLoop) + while (capacity is None + or num_ctx_tokens < capacity) and num_tokens_single_loop > 0: + num_tokens_single_loop = 0 + for req in requests: + past_size = req.context_chunk_size + + # C++ logic: suggested = past + unit + suggested_size = past_size + unit_size + + # Ensure we don't exceed what the request actually needs + remaining_total = req.context_remaining_length + suggested_size = min(suggested_size, remaining_total) + + req.context_chunk_size = suggested_size + + actual_size = req.context_chunk_size + actual_increment = actual_size - past_size + + # Check Constraints + # 1. Capacity + if capacity is not None and (num_ctx_tokens + actual_increment + > capacity): + req.context_chunk_size = past_size # Revert + continue + + # 2. Max Context Length + if self.max_context_length is not None and actual_size > self.max_context_length: + req.context_chunk_size = past_size # Revert + continue + + num_ctx_tokens += actual_increment + num_tokens_single_loop += actual_increment + + def _chunk_fcfs(self, requests: RequestList, capacity: Optional[int], + unit_size: int): + """Apply first-come-first-served chunking strategy.""" + current_capacity = capacity if capacity is not None else float('inf') + + for req in requests: + suggested_size = req.context_remaining_length + actual_size = suggested_size + + # Apply unit size constraint + if unit_size > 0: + actual_size = (actual_size // unit_size) * unit_size + + # Apply capacity constraint + if actual_size > current_capacity: + actual_size = (int(current_capacity) // unit_size) * unit_size + + # Apply max context length constraint + if self.max_context_length is not None and actual_size > self.max_context_length: + actual_size = (self.max_context_length // unit_size) * unit_size + + req.context_chunk_size = actual_size + current_capacity -= actual_size + + if current_capacity <= 0: + break + + def _fit_draft_tokens(self, requests: RequestList, capacity: Optional[int], + unit_size: int): + """Fit draft tokens into remaining capacity for chunked requests.""" + # Calculate tokens already taken by the batch so far + num_ctx_tokens = sum(req.context_chunk_size for req in requests) + + for req in requests: + if req.is_last_context_chunk and req.has_draft_tokens: + remainder = req.context_chunk_size % unit_size + remaining_space = 0 if remainder == 0 else unit_size - remainder + + if self.max_context_length is not None: + remaining_context_len = self.max_context_length - req.context_chunk_size + remaining_space = min(remaining_space, + remaining_context_len) + + if capacity is not None: + remaining_space = min(remaining_space, + capacity 
- num_ctx_tokens) + num_ctx_tokens += remaining_space + + draft_discard = req.num_draft_tokens - remaining_space + if draft_discard > 0: + logger.debug(f"Discarding {draft_discard} draft tokens") + + +@dataclass +class SchedulingState: + """ + State container for scheduling loop in _fused_schedule_request. + + Groups all state variables together to reduce parameter passing + and make the code more maintainable. + """ + # Block reuse optimization + skipping_is_relevant: bool + newly_contributed_context_blocks: set + newly_contributed_cross_context_blocks: set + + # PEFT/LoRA tracking + has_peft: bool + claimed_peft_pages: int + available_peft_pages: int + uniq_task_ids: set + + # Batch tracking + batch_num_tokens: int + scheduled_req_size: int + scheduled_beam_width: int + + # Output lists + context_requests: RequestList + generation_requests: RequestList + paused_requests: RequestList + fitting_disagg_gen_init: RequestList + + # Chunking state + contexts_to_be_chunked: RequestList + num_chunked_tokens: int + all_context_requests_fit: bool + + # Cached configuration (for faster access) + max_batch_size: int + max_num_tokens: Optional[int] + max_context_length: Optional[int] + ctx_chunk_config: Optional['ContextChunkingConfig'] + + +class AttentionDPMixin: + """ + Mixin providing optional attention_dp coordination functionality. + + Attention_dp is enabled if `dist` parameter is provided, disabled otherwise. + This makes it composable with any scheduler that extends SimpleUnifiedScheduler. + + When enabled, provides: + - Global coordination across TP ranks (GATHER → SIMULATE → COMMIT) + - Token-based load balancing via GlobalCoordinator + - Dummy request padding for collective operations + - Batching filters for synchronized scheduling + + When disabled (dist=None): + - Falls back to base scheduler implementation (TP-only) + - No performance overhead + + This mixin can be composed with disagg schedulers to provide + optional attention_dp support across disaggregated servers. 
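+
+    Illustrative composition (the concrete subclass in this file may differ
+    from this sketch):
+
+        class AttentionDPScheduler(AttentionDPMixin, SimpleUnifiedScheduler):
+            pass
+
+        scheduler = AttentionDPScheduler(..., dist=dist,
+                                         max_num_active_requests=max_seqs)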
+ """ + + def __init__(self, + *args, + dist=None, + max_num_active_requests: Optional[int] = None, + **kwargs): + super().__init__(*args, **kwargs) + self.dist = dist + self.max_num_active_requests = max_num_active_requests + self.enable_global_scheduling = dist is not None and max_num_active_requests is not None + + if self.enable_global_scheduling: + self.global_coordinator = GlobalCoordinator( + self, dist, max_num_active_requests) + else: + self.global_coordinator = None + + def _activate_new_requests( + self, + active_requests: RequestList, + waiting_queue: Optional[deque], + cp_config: dict, + cp_rank: int, + cp_size: int, + exclude_last_generation_logits: bool, + ) -> tuple[RequestList, int]: + """Override to use global coordination if attention_dp enabled.""" + if waiting_queue is None or len(waiting_queue) == 0: + return [], len(active_requests) + + if self.enable_global_scheduling: + return self._activate_with_global_coordination( + active_requests, waiting_queue, cp_config, cp_rank, cp_size, + exclude_last_generation_logits) + else: + return super()._activate_new_requests( + active_requests, waiting_queue, cp_config, cp_rank, cp_size, + exclude_last_generation_logits) + + def schedule_iteration( + self, + waiting_queue: Optional[deque], + active_requests: RequestList, + inflight_req_ids: set[int], + drafter=None, + max_total_draft_tokens: int = 0, + use_spec_decode: bool = False, + cp_config: Optional[dict] = None, + cp_rank: int = 0, + cp_size: int = 1, + exclude_last_generation_logits: bool = False, + kv_cache_manager=None, + resource_manager=None, + ) -> UnifiedSchedulerOutput: + """Override to add dummy padding step if attention_dp enabled.""" + if waiting_queue: + new_requests, expected_num_active = self._activate_new_requests( + active_requests, waiting_queue, cp_config, cp_rank, cp_size, + exclude_last_generation_logits) + active_requests.extend(new_requests) + + if self.global_coordinator is not None: + self.global_coordinator.expected_num_active_requests = expected_num_active + + if self.enable_global_scheduling and kv_cache_manager is not None: + self._pad_dummy_requests( + active_requests=active_requests, + kv_cache_manager=kv_cache_manager, + resource_manager=resource_manager, + max_total_draft_tokens=max_total_draft_tokens, + ) + + return super().schedule_iteration( + waiting_queue=deque(), + active_requests=active_requests, + inflight_req_ids=inflight_req_ids, + drafter=drafter, + max_total_draft_tokens=max_total_draft_tokens, + use_spec_decode=use_spec_decode, + cp_config=cp_config, + cp_rank=cp_rank, + cp_size=cp_size, + exclude_last_generation_logits=exclude_last_generation_logits, + kv_cache_manager=kv_cache_manager, + resource_manager=resource_manager, + ) + + def _activate_with_global_coordination( + self, + active_requests: RequestList, + waiting_queue: deque, + cp_config: dict, + cp_rank: int, + cp_size: int, + exclude_last_generation_logits: bool, + ) -> tuple[RequestList, int]: + """ + Activate new requests using global coordination (attention_dp). + + This performs the full GATHER → SIMULATE → COMMIT flow to assign + new requests to ranks, then extracts assigned requests from waiting_queue. 
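+
+        Sketch of one iteration (method names as used in the body below;
+        "candidates" is the slice of waiting_queue considered this step):
+
+            local_state = self.global_coordinator.build_local_state(active_requests)   # GATHER
+            all_states  = self.global_coordinator.gather_all_states(local_state)
+            assignments = self.global_coordinator.simulate_global_schedule(            # SIMULATE
+                candidates, all_states)
+            assignments = self.global_coordinator.apply_batching_filter(
+                assignments, candidates)
+            # COMMIT: keep the LlmRequests assigned to self.dist.rank and
+            # put every other queue item back on waiting_queue.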
+ + Args: + active_requests: Currently active requests + waiting_queue: Queue of waiting RequestQueueItems + cp_config: CP configuration dict + cp_rank: Current CP rank + cp_size: Total number of CP ranks + exclude_last_generation_logits: Whether to exclude last generation logits + + Returns: + Tuple of (new_llm_requests, expected_num_active_requests) + """ + # === PHASE 1: GATHER === + # Gather states first to know total active requests across all ranks + local_state = self.global_coordinator.build_local_state(active_requests) + all_rank_states = self.global_coordinator.gather_all_states(local_state) + + # Calculate total active requests across all ranks + total_num_active_requests = sum(state.current_batch_size + for state in all_rank_states) + + # Calculate how many new candidates we can accept + total_capacity = self.dist.tp_size * self.max_num_active_requests + num_new_candidates = max( + 0, + min(total_capacity - total_num_active_requests, len(waiting_queue))) + + if num_new_candidates == 0: + # No capacity for new requests + expected_num_active_requests = max(state.current_batch_size + for state in all_rank_states) + return [], expected_num_active_requests + + # Extract candidate requests + candidate_requests = list( + itertools.islice(waiting_queue, num_new_candidates)) + + # Convert candidate RequestQueueItems to LlmRequests ONCE + # These will be used for simulation AND execution (no recreation) + candidate_llm_requests = merge_requests( + candidate_requests, + cp_config=cp_config, + cp_rank=cp_rank, + cp_size=cp_size, + exclude_last_generation_logits=exclude_last_generation_logits) + + # Attach llm_request back to RequestQueueItem for simulation + # Note: merge_requests may create child requests, we need to map them back + llm_req_map = {} # request_id -> LlmRequest + for llm_req in candidate_llm_requests: + llm_req_map[llm_req.request_id] = llm_req + + for req_item in candidate_requests: + if req_item.id in llm_req_map: + req_item.llm_request = llm_req_map[req_item.id] + + # === PHASE 2: SIMULATE === + assignments = self.global_coordinator.simulate_global_schedule( + candidate_requests, all_rank_states) + + # === PHASE 2.5: BATCHING CHECK === + assignments = self.global_coordinator.apply_batching_filter( + assignments, candidate_requests) + + # Calculate expected_num_active_requests (max across all ranks after assignment) + # This uses data we already have from the allgather, no extra communication needed + expected_num_active_requests = max( + all_rank_states[rank_id].current_batch_size + + len(assignments[rank_id]) + for rank_id in range(len(all_rank_states))) + + # === PHASE 3: EXTRACT ASSIGNED LLMREQUESTS === + my_assigned_req_ids = set(assignments[self.dist.rank]) + assigned_llm_requests = [] + + # Convert to list to allow safe modification of waiting_queue + items_to_process = list(waiting_queue) + waiting_queue.clear() + + for req_item in items_to_process: + if (hasattr(req_item, 'llm_request') and req_item.llm_request + and req_item.llm_request.request_id in my_assigned_req_ids): + # Reuse the LlmRequest we created earlier ✅ (created only once!) 
+ assigned_llm_requests.append(req_item.llm_request) + # Also add child requests if they exist + if req_item.llm_request.child_requests: + assigned_llm_requests.extend( + req_item.llm_request.child_requests) + else: + # Put back unassigned items + waiting_queue.append(req_item) + + return assigned_llm_requests, expected_num_active_requests + + def _pad_dummy_requests( + self, + active_requests: RequestList, + kv_cache_manager, + resource_manager, + max_total_draft_tokens: int, + ) -> None: + """ + Pad active_requests with dummy requests for attention_dp mode. + + This ensures all ranks have the same number of active requests for collective operations. + Only pads if this rank has zero non-disagg requests but others have active requests. + + Args: + active_requests: List of currently active requests (modified in-place) + kv_cache_manager: KV cache manager for creating dummy requests + resource_manager: Resource manager for spec decode support + max_total_draft_tokens: Maximum draft tokens for dummy requests + """ + if not self.enable_global_scheduling: + # Only for attention_dp mode + return + + # Get expected number of active requests from global coordinator + if not hasattr(self.global_coordinator, 'expected_num_active_requests'): + return + + expected_num = self.global_coordinator.expected_num_active_requests + + # Count non-disagg active requests + # Disagg requests in certain states don't count toward the active limit + num_active = len([ + req for req in active_requests + if not (req.is_disagg_generation_init_state + or req.is_disagg_generation_transmission_in_progress) + ]) + + # Pad only if this rank has 0 requests but others have requests + if expected_num - num_active > 0 and num_active == 0: + # Create one dummy generation request + dummy_req = kv_cache_manager.add_dummy_requests( + request_ids=[0], + is_gen=True, + prepare_resource=True, + max_num_draft_tokens=max_total_draft_tokens, + )[0] + dummy_req.is_attention_dp_dummy = True + + # Add to spec resource manager if present + if resource_manager is not None: + spec_resource_mgr = resource_manager.get_resource_manager( + ResourceManagerType.SPEC_RESOURCE_MANAGER) + if spec_resource_mgr is not None: + spec_resource_mgr.add_dummy_requests([0]) + + active_requests.append(dummy_req) + + +class SimpleUnifiedScheduler(RequestScheduler): + """ + Base unified scheduler with FUSED single-pass scheduling (TP-only). + + This scheduler combines capacity (KV cache) and micro-batch (token budget) + checks into a single efficient loop, eliminating the double work of the + traditional two-pass approach. + + Operational mode: + - TP-only mode: Local scheduling on this rank only + - Supports batch waiting optimization + - Uses fused single-pass scheduling + + For attention_dp support, use AttentionDPScheduler instead. + For disaggregated setups, use DisaggGenerationScheduler or DisaggContextScheduler. 
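+
+    Illustrative construction (dispatch happens in __new__; only the required
+    positional arguments are shown):
+
+        # TP-only base scheduler
+        scheduler = SimpleUnifiedScheduler(max_batch_size, max_num_tokens,
+                                           kv_cache_manager, peft_cache_manager,
+                                           scheduler_policy)
+
+        # The same call with dist= and max_num_active_requests= returns an
+        # AttentionDPScheduler instance instead (see __new__ below).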
+ + Fused Scheduling Architecture: + - Single loop checks both KV cache AND token budget together + - Direct resource access (no wrapper schedulers) + - Reuses block manager infrastructure + - Supports all capacity policies: MAX_UTILIZATION, GUARANTEED_NO_EVICT, STATIC_BATCH, MAX_REQUESTS + - Supports chunking: EQUAL_PROGRESS and FIRST_COME_FIRST_SERVED + + Performance benefits: + - Faster: Single-pass vs two-pass (30-50% speedup) + - Simpler: Eliminates PyCapacityScheduler and PyMicroBatchScheduler + - More correct: No simulation/execution divergence bugs + - Less memory: No duplicate state tracking + + Internal Dispatch: + - Automatically dispatches to AttentionDPScheduler if dist parameter is provided + - This allows callers to use SimpleUnifiedScheduler as the entry point + - The appropriate implementation is selected based on parameters + """ + + def __new__(cls, + *args, + dist=None, + max_num_active_requests=None, + server_role=None, + **kwargs): + """ + Factory method that dispatches to the appropriate scheduler implementation. + + Dispatch rules: + 1. If server_role == ServerRole.CONTEXT → DisaggContextScheduler + 2. If server_role == ServerRole.GENERATION → DisaggGenerationScheduler + 3. If dist is not None and max_num_active_requests is not None → AttentionDPScheduler + 4. Otherwise → SimpleUnifiedScheduler (base) + + This allows callers to use SimpleUnifiedScheduler as the entry point + without needing to know about the different implementations. + """ + # If being called on a subclass, use normal instantiation + if cls is not SimpleUnifiedScheduler: + return super().__new__(cls) + + # Import ServerRole for comparison + from tensorrt_llm.llmapi.disagg_utils import ServerRole + + # Dispatch based on server_role (disagg mode takes precedence) + if server_role == ServerRole.CONTEXT: + return object.__new__(DisaggContextScheduler) + elif server_role == ServerRole.GENERATION: + return object.__new__(DisaggGenerationScheduler) + # Dispatch based on attention_dp + elif dist is not None and max_num_active_requests is not None: + # Return AttentionDPScheduler instance + return object.__new__(AttentionDPScheduler) + else: + # Return base SimpleUnifiedScheduler instance (TP-only) + return super().__new__(cls) + + def __init__( + self, + max_batch_size: int, + max_num_tokens: int, + kv_cache_manager, + peft_cache_manager, + scheduler_policy: CapacitySchedulerPolicy, + ctx_chunk_config: Optional[tuple[StrEnum, int]] = None, + cross_kv_cache_manager=None, + two_step_lookahead: bool = False, + scheduler_capacity: Optional[int] = None, + dist=None, # For internal dispatch (ignored in base, used by AttentionDPMixin) + max_num_active_requests: + Optional[ + int] = None, # For internal dispatch (ignored in base, used by AttentionDPMixin) + server_role=None, # For internal dispatch (ignored in base) + ): + # Note: dist and max_num_active_requests are accepted for API compatibility + # but ignored in the base class. They are used by AttentionDPMixin when + # __new__ returns an AttentionDPScheduler instance. + + # Use scheduler_capacity if provided, otherwise fall back to max_batch_size + # scheduler_capacity may differ from max_batch_size (e.g., adjusted for attention_dp + disagg) + capacity = scheduler_capacity if scheduler_capacity is not None else max_batch_size + + # Parse chunking config + py_chunk_config = None + if ctx_chunk_config: + # Fix: Use string comparison to identify the policy. + # This works regardless of whether the input is a Python Enum, C++ Binding Enum, or String. 
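+            # e.g. ChunkingPolicy.EQUAL_PROGRESS, the equivalent binding enum,
+            # and the plain string "EQUAL_PROGRESS" all contain "EQUAL_PROGRESS"
+            # in str(...) and resolve to ChunkingPolicy.EQUAL_PROGRESS below;
+            # anything else falls back to FIRST_COME_FIRST_SERVED.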
+ input_policy = ctx_chunk_config[0] + + if "EQUAL_PROGRESS" in str(input_policy): + policy_enum = ChunkingPolicy.EQUAL_PROGRESS + else: + # Default to FCFS for FIRST_COME_FIRST_SERVED or others + policy_enum = ChunkingPolicy.FIRST_COME_FIRST_SERVED + + py_chunk_config = ContextChunkingConfig(policy_enum, + ctx_chunk_config[1]) + + # FUSED PATH: Single-pass scheduling (TP-only) + # Store resources directly for single-pass scheduling + # This eliminates the double work of capacity + micro-batch scheduling + self.kv_cache_manager = kv_cache_manager + self.cross_kv_cache_manager = cross_kv_cache_manager + self.peft_cache_manager = peft_cache_manager + self.max_batch_size = max_batch_size + self.max_num_tokens = max_num_tokens + self.max_num_requests = capacity + self.ctx_chunk_config = py_chunk_config + self.max_context_length = max_num_tokens + self.scheduler_policy = scheduler_policy + self.two_step_lookahead = two_step_lookahead + + # Cache state values for performance + self._no_schedule_until_state_value = LlmRequestState.CONTEXT_INIT.value + self._no_schedule_after_state_value = LlmRequestState.GENERATION_TO_COMPLETE.value + self._context_init_state_value = LlmRequestState.CONTEXT_INIT.value + self._encoder_init_state_value = LlmRequestState.ENCODER_INIT.value + + # Helper components + self.peft_helper = PeftHelper(peft_cache_manager) + + self.capacity_checker = CapacityChecker( + kv_cache_manager, cross_kv_cache_manager, scheduler_policy, + self._no_schedule_until_state_value, + self._no_schedule_after_state_value) + + self.chunking_manager = ChunkingManager( + py_chunk_config, max_num_tokens) if py_chunk_config else None + + # Batch waiting state (for TP-only mode) + # These track the waiting logic for batch waiting in TP-only mode + # Will be configured by PyExecutor if needed + self.batch_wait_timeout_iters = 0 + self.batch_wait_max_tokens_ratio = 0.0 + self.enable_batch_waiting = False + self.batch_wait_iters_count = 0 + + def schedule_iteration( + self, + waiting_queue: Optional[deque], + active_requests: RequestList, + inflight_req_ids: set[int], + # Drafter integration + drafter=None, + max_total_draft_tokens: int = 0, + use_spec_decode: bool = False, + # Context parallelism + cp_config: Optional[dict] = None, + cp_rank: int = 0, + cp_size: int = 1, + exclude_last_generation_logits: bool = False, + # Resource managers (for padding dummy requests) + kv_cache_manager=None, + resource_manager=None, + ) -> UnifiedSchedulerOutput: + """ + Complete end-to-end scheduling for one iteration (TP-only). + + This is the unified entry point that handles: + 1. Activating new requests from waiting queue + 2. Drafter integration (speculative decoding) + 3. Batch scheduling + 4. Post-processing + + For attention_dp support with dummy padding, use AttentionDPScheduler instead. 
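+
+        Illustrative call from the executor loop (variable names assumed; only
+        the required arguments are shown):
+
+            output = scheduler.schedule_iteration(
+                waiting_queue=request_queue,
+                active_requests=self.active_requests,
+                inflight_req_ids=inflight_ids)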
+ """ + + # Step 1: Activate new requests (local only in base) + if waiting_queue: + new_requests, _ = self._activate_new_requests( + active_requests, waiting_queue, cp_config, cp_rank, cp_size, + exclude_last_generation_logits) + active_requests.extend(new_requests) + + # Step 2: Integrate drafter (if provided) + if drafter is not None: + self._integrate_drafter(active_requests, max_total_draft_tokens, + use_spec_decode) + + # Step 3: Schedule batch + output = self.schedule_request(active_requests, inflight_req_ids) + + # Step 4: Post-process drafter + if drafter is not None and not use_spec_decode: + self._postprocess_drafter(output) + + return output + + def _activate_new_requests( + self, + active_requests: RequestList, + waiting_queue: Optional[deque], + cp_config: dict, + cp_rank: int, + cp_size: int, + exclude_last_generation_logits: bool, + ) -> tuple[RequestList, int]: + """ + Activate new requests from waiting queue (TP-only mode). + + Args: + active_requests: Currently active requests + waiting_queue: Queue of waiting RequestQueueItems + cp_config: CP configuration dict + cp_rank: Current CP rank + cp_size: Total number of CP ranks + exclude_last_generation_logits: Whether to exclude last generation logits + + Returns: + Tuple of (new_llm_requests, expected_num_active_requests) + """ + if waiting_queue is None or len(waiting_queue) == 0: + return [], len(active_requests) + + # TP-only mode: Activate requests locally + return self._activate_local(active_requests, waiting_queue, cp_config, + cp_rank, cp_size, + exclude_last_generation_logits) + + def _schedule_generation_only_during_waiting( + self, + active_requests: RequestList, + inflight_request_ids: set[int], + ) -> Optional[UnifiedSchedulerOutput]: + """ + Proactive optimization: Schedule only generation requests when in waiting mode. + + This avoids expensive context request scheduling when we're already waiting + for more generation requests to accumulate. 
+ + Args: + active_requests: Currently active requests + inflight_request_ids: Set of inflight request IDs + + Returns: + UnifiedSchedulerOutput if still waiting (with empty context_requests), + None if should exit waiting mode and run normal scheduling + """ + # Split requests by type + generation_requests_only = [ + r for r in active_requests if not r.is_context_init_state + ] + + # Check if we have generation requests to avoid dead waiting + if len(generation_requests_only) == 0: + # No generation requests, stop waiting to avoid dead lock + self.batch_wait_iters_count = 0 + return None # Exit to normal path + + # Only schedule generation requests (skip expensive context scheduling) + # Use fused scheduler + result = self._fused_schedule_request(generation_requests_only, + inflight_request_ids) + + # Check if we should stop waiting + num_gen_tokens = sum( + self._estimate_tokens_needed(gen_req) + for gen_req in result.generation_requests) + + max_num_tokens = self.max_num_tokens + if max_num_tokens is not None: + # Check if we've timed out or have enough generation tokens + should_stop_waiting = ( + self.batch_wait_iters_count >= self.batch_wait_timeout_iters + or num_gen_tokens + >= self.batch_wait_max_tokens_ratio * max_num_tokens) + + if should_stop_waiting: + # Stop waiting, next iteration will schedule context requests + self.batch_wait_iters_count = 0 + return None # Exit to normal path + else: + # Continue waiting + self.batch_wait_iters_count += 1 + else: + # No token budget limit, stop waiting + self.batch_wait_iters_count = 0 + return None # Exit to normal path + + # Return with empty context requests (still waiting) + # DESIGN NOTE: Context requests in active_requests are intentionally NOT + # scheduled and NOT added to paused_requests. They remain in active_requests + # and will be scheduled in a future iteration when batch waiting stops. + # This deferred scheduling is the core of batch waiting optimization: + # we delay expensive context processing to accumulate more generation requests. + return UnifiedSchedulerOutput( + context_requests=[], # Intentionally empty (deferred, not paused) + generation_requests=result.generation_requests, + paused_requests=result.paused_requests, # Only paused generation + fitting_disagg_gen_init_requests=result. + fitting_disagg_gen_init_requests, + num_fitting_requests=result.num_fitting_requests, + updated_active_requests=None, + ) + + def _apply_batch_waiting( + self, + context_requests: RequestList, + generation_requests: RequestList, + ) -> RequestList: + """ + Apply batch waiting logic for TP-only mode. + + Return an empty list if scheduled requests fulfill the waiting conditions, + otherwise return the original context requests. 
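+
+        For example (illustrative numbers): with batch_wait_max_tokens_ratio=0.5,
+        max_num_tokens=8192 and batch_wait_timeout_iters=10, context requests are
+        held back while fewer than 4096 tokens are scheduled in total and fewer
+        than 10 waiting iterations have elapsed.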
+ + Waiting conditions: + - The number of scheduled tokens (both context and generation) is smaller than + `self.batch_wait_max_tokens_ratio * self.max_num_tokens` + - The number of waiting iterations is smaller than `self.batch_wait_timeout_iters` + + Args: + context_requests: Scheduled context requests + generation_requests: Scheduled generation requests + + Returns: + Empty list if should wait, otherwise original context_requests + """ + # Skip if batch waiting is not enabled + if not self.enable_batch_waiting: + return context_requests + + # Skip if no context requests to wait for + if len(context_requests) == 0: + return context_requests + + # Skip if no generation requests (to avoid dead waiting) + if len(generation_requests) == 0: + self.batch_wait_iters_count = 0 + return context_requests + + # Calculate scheduled tokens + num_scheduled_ctx_tokens = sum( + self._estimate_tokens_needed(ctx_req) + for ctx_req in context_requests) + num_scheduled_gen_tokens = sum( + self._estimate_tokens_needed(gen_req) + for gen_req in generation_requests) + num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens + + # Get max_num_tokens + max_num_tokens = self.max_num_tokens + if max_num_tokens is None: + # No token budget limit, cannot apply batch waiting + return context_requests + + # Check waiting conditions + should_waiting = (self.batch_wait_iters_count + < self.batch_wait_timeout_iters + and num_scheduled_tokens + < self.batch_wait_max_tokens_ratio * max_num_tokens) + + if should_waiting: + self.batch_wait_iters_count += 1 + return [] + + self.batch_wait_iters_count = 0 + return context_requests + + def schedule_request( + self, + active_requests: RequestList, + inflight_request_ids: set[int], + ) -> UnifiedSchedulerOutput: + """ + Schedule requests for execution. + + This method handles capacity scheduling (KV cache allocation) and + micro-batch scheduling (token budget + chunking). + + For TP-only mode, applies batch waiting logic if enabled. + For attention_dp support, use AttentionDPScheduler which handles batching during activation. + + Args: + active_requests: Currently active requests + inflight_request_ids: Set of inflight request IDs + + Returns: + UnifiedSchedulerOutput with scheduled requests + """ + # FUSED PATH: Single-pass scheduling (TP-only) + # Proactive optimization: + # If we're already in waiting mode, skip context scheduling to save computation + if (self.enable_batch_waiting and self.batch_wait_iters_count > 0): + # Try generation-only scheduling (optimization path) + result = self._schedule_generation_only_during_waiting( + active_requests, inflight_request_ids) + if result is not None: + # Still waiting, return early with empty context + return result + # Otherwise, exit waiting mode and fall through to normal path + + # Use fused single-pass scheduling + result = self._fused_schedule_request(active_requests, + inflight_request_ids) + + # Apply batch waiting (TP-only) + result.context_requests = self._apply_batch_waiting( + result.context_requests, result.generation_requests) + + return result + + # ========== Helper methods for schedule_iteration ========== + + def _integrate_drafter( + self, + active_requests: RequestList, + max_total_draft_tokens: int, + use_spec_decode: bool, + ) -> None: + """ + Integrate drafter (speculative decoding) with active requests. + + Sets up draft tokens for generation requests based on whether + speculation is enabled. 
+ + Args: + active_requests: List of currently active requests (modified in-place) + max_total_draft_tokens: Maximum draft tokens for this iteration + use_spec_decode: Whether speculation is enabled + """ + try: + for req in active_requests: + # Only process generation requests + if req.state not in (LlmRequestState.GENERATION_IN_PROGRESS, + LlmRequestState.DISAGG_GENERATION_INIT): + continue + + # Save previous draft tokens + req.py_last_draft_tokens = req.py_draft_tokens + + # Set up draft tokens based on spec decode state + if (max_total_draft_tokens > 0 and use_spec_decode + and not req.py_disable_speculative_decoding): + req.py_draft_tokens = [0] * max_total_draft_tokens + req.py_draft_pages_allocated = max_total_draft_tokens + else: + req.py_draft_tokens = [] + req.py_draft_pages_allocated = 0 + + except Exception as e: + logger.error(f"Error in _integrate_drafter: {e}") + raise + + def _postprocess_drafter(self, + scheduled_output: UnifiedSchedulerOutput) -> None: + """ + Post-process scheduled batch for drafter. + + If spec decode is disabled, mark all scheduled requests accordingly. + + Args: + scheduled_output: Scheduled output to post-process (modified in-place) + """ + # Mark all requests as having spec decode disabled + all_requests = (scheduled_output.context_requests + + scheduled_output.generation_requests) + for request in all_requests: + request.py_disable_speculative_decoding = True + + # ========== Helper methods for _fused_schedule_request ========== + + def _initialize_block_managers( + self, simulation_mode: bool + ) -> tuple[Optional['MaxUtilizationScheduledBlocksManager'], + Optional['NoEvictScheduledBlocksManager'], + Optional['NoEvictScheduledBlocksManager']]: + """ + Initialize block managers based on scheduling policy. + + Args: + simulation_mode: If True, skip start_scheduling call + + Returns: + Tuple of (scheduled_blocks_manager, reserved_blocks, reserved_cross_blocks) + """ + scheduled_blocks_manager = None + reserved_blocks = None + reserved_cross_blocks = None + + if self.scheduler_policy == CapacitySchedulerPolicy.MAX_UTILIZATION: + if not simulation_mode: + self.kv_cache_manager.start_scheduling() + scheduled_blocks_manager = MaxUtilizationScheduledBlocksManager( + self.kv_cache_manager, self.two_step_lookahead) + elif self.scheduler_policy == CapacitySchedulerPolicy.GUARANTEED_NO_EVICT or \ + self.scheduler_policy == CapacitySchedulerPolicy.STATIC_BATCH: + reserved_blocks = NoEvictScheduledBlocksManager( + self.kv_cache_manager) + if self.cross_kv_cache_manager is not None: + reserved_cross_blocks = NoEvictScheduledBlocksManager( + self.cross_kv_cache_manager) + + return scheduled_blocks_manager, reserved_blocks, reserved_cross_blocks + + def _initialize_scheduling_state(self, active_requests: RequestList, + has_peft: bool) -> SchedulingState: + """ + Initialize scheduling state for _fused_schedule_request. 
+ + Args: + active_requests: Currently active requests + has_peft: Whether PEFT is enabled + + Returns: + SchedulingState with initialized values + """ + # Block reuse optimization + skipping_is_relevant = self.capacity_checker.is_skipping_relevant() + newly_contributed_context_blocks: set = set() + newly_contributed_cross_context_blocks: set = set() + if skipping_is_relevant: + newly_contributed_context_blocks, newly_contributed_cross_context_blocks = \ + self.capacity_checker.prefill_contributed_blocks(active_requests) + + # PEFT/LoRA state + claimed_peft_pages = 0 + available_peft_pages = self.peft_helper.get_max_pages( + ) if has_peft else 0 + uniq_task_ids: set[int] = set() if has_peft else None + + return SchedulingState( + skipping_is_relevant=skipping_is_relevant, + newly_contributed_context_blocks=newly_contributed_context_blocks, + newly_contributed_cross_context_blocks= + newly_contributed_cross_context_blocks, + has_peft=has_peft, + claimed_peft_pages=claimed_peft_pages, + available_peft_pages=available_peft_pages, + uniq_task_ids=uniq_task_ids, + batch_num_tokens=0, + scheduled_req_size=0, + scheduled_beam_width=0, + context_requests=[], + generation_requests=[], + paused_requests=[], + fitting_disagg_gen_init=[], + contexts_to_be_chunked=[], + num_chunked_tokens=0, + all_context_requests_fit=True, + max_batch_size=self.max_batch_size, + max_num_tokens=self.max_num_tokens, + max_context_length=self.max_context_length, + ctx_chunk_config=self.ctx_chunk_config, + ) + + def _schedule_in_progress_generation( + self, + active_requests: RequestList, + state: SchedulingState, + reserved_blocks: 'NoEvictScheduledBlocksManager', + reserved_cross_blocks: Optional['NoEvictScheduledBlocksManager'], + simulation_mode: bool, + ) -> None: + """ + Schedule in-progress generation requests (GUARANTEED_NO_EVICT policy only). + + These must be scheduled first to free up reserved blocks. + Updates state in-place. + + Args: + active_requests: All active requests + state: Current scheduling state (modified in-place) + reserved_blocks: Reserved blocks manager + reserved_cross_blocks: Reserved cross-attention blocks manager (or None) + simulation_mode: If True, skip block updates + """ + for req in active_requests: + if not self.capacity_checker.can_be_scheduled_with_disagg_exception( + req): + continue + + if len(state.context_requests) + len( + state.generation_requests) >= self.max_num_requests: + break + + if req.is_generation_in_progress_state: + # Check token budget + req_num_tokens = self._estimate_tokens_needed(req) + + if state.max_num_tokens is not None and ( + state.batch_num_tokens + req_num_tokens + > state.max_num_tokens): + state.paused_requests.append(req) + continue + + # Fits! 
Schedule it + state.generation_requests.append(req) + state.batch_num_tokens += req_num_tokens + state.scheduled_req_size += 1 + + if not simulation_mode: + reserved_blocks.decrement_reserved_blocks(req) + if reserved_cross_blocks is not None: + reserved_cross_blocks.decrement_reserved_blocks(req) + + # Track PEFT + if state.has_peft: + lora_task_id, is_new_task, peft_pages = self.peft_helper.get_task_info( + req, state.uniq_task_ids) + if is_new_task: + state.claimed_peft_pages += peft_pages + state.uniq_task_ids.add(lora_task_id) + + # Update available PEFT pages + if state.has_peft: + state.available_peft_pages -= state.claimed_peft_pages + + def _should_schedule_request( + self, req: LlmRequest, inflight_request_ids: set[int], + reserved_blocks: Optional['NoEvictScheduledBlocksManager']) -> bool: + """ + Check if request should be considered for scheduling. + + Args: + req: Request to check + inflight_request_ids: Set of already in-flight request IDs + reserved_blocks: Reserved blocks manager (or None) + + Returns: + True if request should be processed, False if should skip + """ + # Skip inflight requests + if req.request_id in inflight_request_ids: + return False + + # Skip requests not in schedulable state range + req_state_value = req.state_value + if not (req_state_value >= self._no_schedule_until_state_value + and req_state_value < self._no_schedule_after_state_value): + return False + + # Skip in-progress generation (already handled for GUARANTEED_NO_EVICT) + if reserved_blocks is not None and req.is_generation_in_progress_state: + return False + + return True + + def _check_batch_limits(self, state: SchedulingState) -> bool: + """ + Check if batch limits are reached. + + Args: + state: Current scheduling state + + Returns: + True if can continue scheduling, False if limits reached + """ + # Check batch size limit + if state.scheduled_req_size >= state.max_batch_size: + return False + + # Check request count limit + if len(state.context_requests) + len( + state.generation_requests) >= self.max_num_requests: + return False + + return True + + def _finalize_chunking(self, state: SchedulingState) -> None: + """ + Apply chunking to queued context requests and finalize. + + Updates state in-place by moving chunked requests to context_requests. + + Args: + state: Current scheduling state (modified in-place) + """ + if not state.contexts_to_be_chunked: + return + + # Verify chunking fits + if state.max_num_tokens is not None and state.num_chunked_tokens > ( + state.max_num_tokens - state.batch_num_tokens): + state.all_context_requests_fit = False + + # Apply chunking + remaining_capacity = (state.max_num_tokens - state.batch_num_tokens + ) if state.max_num_tokens is not None else None + self.chunking_manager.apply_chunking(state.contexts_to_be_chunked, + remaining_capacity) + + # Finalize chunked requests + for req in state.contexts_to_be_chunked: + if req.context_chunk_size > 0: + state.context_requests.append(req) + draft_tokens = req.num_draft_tokens if ( + req.is_last_context_chunk and req.has_draft_tokens) else 0 + state.batch_num_tokens += req.context_chunk_size + draft_tokens + else: + state.paused_requests.append(req) + + def _build_scheduler_output( + self, state: SchedulingState) -> UnifiedSchedulerOutput: + """ + Build final scheduler output from scheduling state. 
+ + Args: + state: Final scheduling state + + Returns: + UnifiedSchedulerOutput with scheduled requests + """ + # Sort requests for consistency + chunks_present = state.ctx_chunk_config is not None + if self.chunking_manager and chunks_present: + self.chunking_manager.sort_requests(state.context_requests, + state.generation_requests, + chunks_present) + + # Return results + num_fitting = len(state.context_requests) + len( + state.generation_requests) + return UnifiedSchedulerOutput( + context_requests=state.context_requests, + generation_requests=state.generation_requests, + paused_requests=state.paused_requests, + fitting_disagg_gen_init_requests=state.fitting_disagg_gen_init, + num_fitting_requests=num_fitting, + updated_active_requests=None, + ) + + def _check_peft_capacity(self, req, state) -> bool: + """ + Check if request can fit in PEFT cache and update state if it can. + + Args: + req: Request to check + state: Current scheduling state + + Returns: + True if can fit (or no PEFT), False if insufficient PEFT capacity + """ + if not state.has_peft: + return True # No PEFT, always fits + + lora_task_id, is_new_task, needed_peft_pages = self.peft_helper.get_task_info( + req, state.uniq_task_ids) + + if needed_peft_pages > state.available_peft_pages: + return False # Insufficient PEFT capacity + + # Sufficient capacity - update state + if is_new_task: + state.available_peft_pages -= needed_peft_pages + state.uniq_task_ids.add(lora_task_id) + + return True + + def _should_skip_for_block_reuse(self, req, state) -> bool: + """ + Check if request should be skipped for block reuse optimization. + + Can be overridden by subclasses to add additional skip conditions. + + Args: + req: Request to check + state: Current scheduling state + + Returns: + True if request should be skipped, False otherwise + """ + return (state.skipping_is_relevant + and self.capacity_checker.beneficial_to_skip( + req, state.newly_contributed_context_blocks, + state.newly_contributed_cross_context_blocks)) + + def _try_handle_special_request(self, req, state, scheduled_blocks_manager, + reserved_blocks, reserved_cross_blocks, + simulation_mode) -> tuple[bool, bool]: + """ + Hook for subclasses to handle special request types. + + Base implementation does nothing (no special requests). + Subclasses can override to add custom request handling (e.g., disagg_generation_init). + + Args: + req: Request to potentially handle + state: Current scheduling state + scheduled_blocks_manager: Block manager for scheduling + reserved_blocks: Reserved blocks (for GUARANTEED_NO_EVICT) + reserved_cross_blocks: Reserved cross attention blocks + simulation_mode: Whether in simulation mode + + Returns: + Tuple of (handled, should_break): + - handled: True if request was handled by this method, False otherwise + - should_break: True if scheduling loop should break, False if should continue + """ + return False, False # Base implementation: no special requests + + def _fused_schedule_request( + self, + active_requests: RequestList, + inflight_request_ids: set[int], + simulation_mode: bool = False, + ) -> UnifiedSchedulerOutput: + """ + Fused single-pass scheduling combining capacity and micro-batch checks. + + This method merges the two-pass approach (capacity → micro-batch) into a single + loop that checks both KV cache capacity and token budget together. This eliminates + redundant work and improves performance for global coordination mode. 
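+
+        Per-request flow in the main loop (summarised from the body below):
+
+            _should_schedule_request -> _check_batch_limits
+              -> _try_handle_special_request -> _should_skip_for_block_reuse
+              -> per-state handling (encoder / context / generation), each
+                 applying the relevant token-budget, KV-capacity and PEFT
+                 checks before the request is scheduled, queued for chunking,
+                 or paused.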
+ + Args: + active_requests: Currently active requests to schedule + inflight_request_ids: Set of request IDs already in flight + simulation_mode: If True, only check feasibility without allocating blocks + (used for global coordination simulation) + + Returns: + UnifiedSchedulerOutput with scheduled requests + + Note: + Subclasses can customize behavior by overriding hook methods: + - _should_skip_for_block_reuse(): Customize skip optimization logic + - _try_handle_special_request(): Add handling for custom request types + """ + # Initialize block managers + scheduled_blocks_manager, reserved_blocks, reserved_cross_blocks = \ + self._initialize_block_managers(simulation_mode) + + # Initialize scheduling state + state = self._initialize_scheduling_state( + active_requests, self.peft_cache_manager is not None) + + # For GUARANTEED_NO_EVICT: Schedule in-progress generation first + if reserved_blocks is not None: + self._schedule_in_progress_generation(active_requests, state, + reserved_blocks, + reserved_cross_blocks, + simulation_mode) + + # MAIN SCHEDULING LOOP: Fused capacity + token budget checking + for req in active_requests: + req_state_value = req.state_value + + # Filtering checks + if not self._should_schedule_request(req, inflight_request_ids, + reserved_blocks): + continue + + # Batch limit checks + if not self._check_batch_limits(state): + state.paused_requests.append(req) + break + + # Try special request handling first (hook for subclasses) + handled, should_break = self._try_handle_special_request( + req, state, scheduled_blocks_manager, reserved_blocks, + reserved_cross_blocks, simulation_mode) + if handled: + if should_break: + break + else: + continue + + # Block reuse skip optimization + if self._should_skip_for_block_reuse(req, state): + continue + + # --- A. Encoder Request Handling --- + if req_state_value == self._encoder_init_state_value: + req_num_tokens = self._estimate_tokens_needed(req) + + assert state.max_context_length is None or req_num_tokens <= state.max_context_length, \ + f"The number of encoder tokens ({req_num_tokens}) exceeds the limit value ({state.max_context_length})" + + # Check token budget + if state.max_num_tokens is not None and ( + state.batch_num_tokens + req_num_tokens + > state.max_num_tokens): + state.paused_requests.append(req) + break + + # Check KV cache capacity + can_fit_kv = self.capacity_checker.check_kv_capacity( + req, scheduled_blocks_manager, reserved_blocks, + reserved_cross_blocks, simulation_mode) + if not can_fit_kv: + state.paused_requests.append(req) + break + + # Check PEFT capacity + if not self._check_peft_capacity(req, state): + state.paused_requests.append(req) + continue # Continue to next request (not break) + + # Fits! Schedule it + state.context_requests.append(req) + state.batch_num_tokens += req_num_tokens + state.scheduled_req_size += 1 + + # --- B. 
Context Request Handling --- + elif req_state_value == self._context_init_state_value: + if not state.ctx_chunk_config: + # No chunking: schedule full context + req_num_tokens = self._estimate_tokens_needed(req) + + assert state.max_context_length is None or req_num_tokens <= state.max_context_length, \ + f"The number of context tokens ({req_num_tokens}) exceeds the limit value ({state.max_context_length})" + + # Check token budget + if state.max_num_tokens is not None and ( + state.batch_num_tokens + req_num_tokens + > state.max_num_tokens): + state.paused_requests.append(req) + break + + # Check KV cache capacity + can_fit_kv = self.capacity_checker.check_kv_capacity( + req, scheduled_blocks_manager, reserved_blocks, + reserved_cross_blocks, simulation_mode) + if not can_fit_kv: + state.paused_requests.append(req) + break + + # Check PEFT capacity + if not self._check_peft_capacity(req, state): + state.paused_requests.append(req) + continue # Continue to next request (not break) + + # Fits! Schedule it + state.context_requests.append(req) + state.batch_num_tokens += req_num_tokens + state.scheduled_req_size += 1 + else: + # Chunking enabled: tentative schedule + # Check KV cache capacity first + can_fit_kv = self.capacity_checker.check_kv_capacity( + req, scheduled_blocks_manager, reserved_blocks, + reserved_cross_blocks, simulation_mode) + if not can_fit_kv: + state.paused_requests.append(req) + break + + # Check PEFT capacity + if not self._check_peft_capacity(req, state): + state.paused_requests.append(req) + continue # Continue to next request (not break) + + # Add to chunking queue + req.context_chunk_size = req.context_remaining_length + + draft_tokens = req.num_draft_tokens if ( + req.is_last_context_chunk + and req.has_draft_tokens) else 0 + req_num_tokens = req.context_chunk_size + draft_tokens + + if state.max_context_length is not None: + if state.max_context_length < req_num_tokens: + req_num_tokens = state.max_context_length + state.all_context_requests_fit = False + + state.contexts_to_be_chunked.append(req) + state.num_chunked_tokens += req_num_tokens + state.scheduled_req_size += 1 + + # --- C. Generation Request Handling --- + else: + # Regular generation request + req_num_tokens = self._estimate_tokens_needed(req) + beam_width = req.get_beam_width_by_iter( + for_next_iteration=False) + + # Check token budget + if state.max_num_tokens is not None and ( + state.batch_num_tokens + req_num_tokens + > state.max_num_tokens): + state.paused_requests.append(req) + break -class MaxUtilizationScheduledBlocksManager: - """ - Python equivalent of C++ kv_cache_manager::MaxUtilizationScheduledBlocksManager. - Tracks scheduled blocks per window size for MAX_UTILIZATION scheduling. + # Check PEFT capacity + if not self._check_peft_capacity(req, state): + state.paused_requests.append(req) + continue # Continue to next request (not break) - Reference: cpp/tensorrt_llm/batch_manager/scheduledBlocksManager.h:64-117 - """ + # Beam width consistency check + if state.scheduled_beam_width == 0: + state.scheduled_beam_width = beam_width + elif state.scheduled_beam_width != beam_width: + logger.debug( + f"generation request skipped: ID {req.request_id} since its " + f"beam width ({beam_width}) is different from scheduled ones " + f"({state.scheduled_beam_width})") + continue - def __init__(self, kv_cache_manager, two_steps_look_ahead: bool): - """ - Initialize scheduled blocks count per window size. 
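A tiny worked illustration of the chunked-context accounting above (simplified, not the PR's code): the tentative chunk is the full remaining context, plus draft tokens only on the last chunk, and capping it at max_context_length is what clears all_context_requests_fit.

# Simplified illustration; names are local stand-ins for the state fields in this diff.
from typing import Optional

def tentative_chunk_tokens(remaining_ctx: int, draft: int, is_last_chunk: bool,
                           max_context_length: Optional[int]):
    tokens = remaining_ctx + (draft if is_last_chunk else 0)
    all_fit = True
    if max_context_length is not None and tokens > max_context_length:
        tokens = max_context_length
        all_fit = False
    return tokens, all_fit

assert tentative_chunk_tokens(3000, 0, False, 2048) == (2048, False)
assert tentative_chunk_tokens(512, 4, True, 2048) == (516, True)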
- C++ equivalent: iterate windowSizes and set mNumScheduledBlocks[windowSize] = 0 - """ - self.kv_cache_manager = kv_cache_manager - self.two_steps_look_ahead = two_steps_look_ahead - window_sizes = set(kv_cache_manager.max_attention_window_vec) - self.num_scheduled_blocks: dict[int, int] = { - ws: 0 - for ws in window_sizes - } + # Fits! Schedule it + state.generation_requests.append(req) + state.batch_num_tokens += req_num_tokens + state.scheduled_req_size += 1 - def prepare_blocks_if_schedulable( - self, req: LlmRequest) -> Optional[dict[int, int]]: + # Apply chunking if needed + if state.contexts_to_be_chunked: + self._finalize_chunking(state) + + # Build and return output + return self._build_scheduler_output(state) + + def can_schedule(self, requests: RequestList) -> bool: """ - Check if request can be scheduled and return new block counts if so. - Returns None if request cannot fit. - C++ reference: scheduledBlocksManager.h:80-100 + Check if all requests can be scheduled (dry run). + Uses fused scheduler in simulation mode. """ - blocks_if_scheduled = {} - for window_size, num_scheduled in self.num_scheduled_blocks.items(): - required = self.kv_cache_manager.get_needed_blocks_one_step( - req, self.two_steps_look_ahead, window_size) - logger.debug( - f"MaxUtilizationScheduler: request ID {req.request_id} " - f"required blocks {required} for {window_size} window size") - scheduled_total = num_scheduled + required - has_free = self.kv_cache_manager.scheduling_has_free_blocks( - scheduled_total, window_size) - if not has_free: - return None - blocks_if_scheduled[window_size] = scheduled_total - return blocks_if_scheduled - - def update_scheduled_blocks(self, blocks: dict[int, int]) -> None: + # Use fused scheduler in simulation mode + result = self._fused_schedule_request(requests, + set(), + simulation_mode=True) + scheduled_count = len(result.context_requests) + len( + result.generation_requests) + len( + result.fitting_disagg_gen_init_requests) + return scheduled_count == len(requests) + + # ========== Estimation methods for global coordination ========== + # These methods provide resource estimation for global coordination, + # working with both fused and traditional scheduling paths + + def _estimate_tokens_needed(self, request: LlmRequest) -> int: """ - Update the scheduled blocks after successfully scheduling a request. - C++ reference: scheduledBlocksManager.h:102-110 + Estimate how many tokens this request will consume in the next step. + + OPTIMIZATION: For pre-validated requests (passed simulation in GlobalCoordinator), + use cached estimate to avoid recalculation (~30-40% speedup for new requests). + + Args: + request: The request to estimate for + + Returns: + int: Number of tokens needed for next iteration """ - assert len(blocks) == len(self.num_scheduled_blocks), \ - f"Block count mismatch: {len(blocks)} vs {len(self.num_scheduled_blocks)}" - for window_size, blocks_if_scheduled in blocks.items(): - logger.debug( - f"MaxUtilizationScheduler: scheduled blocks {blocks_if_scheduled} " - f"for window size {window_size}") - self.num_scheduled_blocks[window_size] = blocks_if_scheduled + # Fast path: Use cached estimate if available + if request.py_pre_validated and request.py_estimated_tokens > 0: + return request.py_estimated_tokens + # Slow path: Calculate from scratch + state_value = request.state_value -class PyCapacityScheduler: - """ - Python implementation of the C++ CapacityScheduler. - Aligned 1:1 with C++ logic in cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp. 
- Supports Multiple Window Sizes (VSWA), block reuse optimization, and all policies. + # Encoder tokens + if state_value == self._encoder_init_state_value: + return request.encoder_output_len - Policies: - - MaxRequestsScheduler: No KV cache manager, simple request count limit - - GuaranteedNoEvictScheduler: Reserve blocks for completion, no eviction - - StaticBatchScheduler: Only schedule when no requests are active - - MaxUtilizationScheduler: Maximize utilization, may pause requests + # Context tokens + elif state_value == self._context_init_state_value: + base_tokens = request.get_num_tokens(0) + draft_tokens = request.num_draft_tokens if request.has_draft_tokens else 0 + return base_tokens + draft_tokens - Reference: cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h - """ + # Generation tokens + else: + beam_width = request.get_beam_width_by_iter( + for_next_iteration=False) + draft_tokens = request.num_draft_tokens if request.has_draft_tokens else 0 + return beam_width + draft_tokens - def __init__( - self, - max_num_requests: int, - kv_cache_manager=None, - peft_cache_manager=None, - scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy. - GUARANTEED_NO_EVICT, - cross_kv_cache_manager=None, - two_step_lookahead: bool = False, - no_schedule_until_state: LlmRequestState = LlmRequestState.CONTEXT_INIT, - no_schedule_after_state: LlmRequestState = LlmRequestState. - GENERATION_COMPLETE, - ): + def _estimate_blocks_needed(self, request: LlmRequest) -> int: """ - Initialize the capacity scheduler. + Estimate how many KV cache blocks this request will consume in the next step. + + OPTIMIZATION: For pre-validated requests (passed simulation in GlobalCoordinator), + use cached estimate to avoid recalculation (~30-40% speedup for new requests). 
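As a concrete illustration of the per-step token estimate above (not the PR's code): an encoder request contributes its encoder output length, a context request its full prompt plus any draft tokens, and a generation request roughly beam_width plus draft tokens. The values below are hypothetical.

# Hypothetical values, purely to illustrate the branches of _estimate_tokens_needed.
def estimate_tokens(kind: str, prompt_len: int = 0, draft_tokens: int = 0,
                    beam_width: int = 1, encoder_len: int = 0) -> int:
    if kind == "encoder":
        return encoder_len                    # encoder output length
    if kind == "context":
        return prompt_len + draft_tokens      # whole prompt (+ draft) in one step
    return beam_width + draft_tokens          # generation: one token per beam (+ draft)

assert estimate_tokens("context", prompt_len=1000, draft_tokens=3) == 1003
assert estimate_tokens("generation", beam_width=4, draft_tokens=2) == 6
assert estimate_tokens("encoder", encoder_len=576) == 576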
Args: - max_num_requests: Maximum number of requests to schedule - kv_cache_manager: KV cache manager (None for MaxRequestsScheduler) - peft_cache_manager: PEFT/LoRA cache manager (optional) - scheduler_policy: Scheduling policy - cross_kv_cache_manager: Cross-attention KV cache manager for encoder-decoder - two_step_lookahead: Enable two-step lookahead for MAX_UTILIZATION - no_schedule_until_state: Don't schedule until this state is reached - no_schedule_after_state: Don't schedule after this state is reached - """ - self.max_num_requests = max_num_requests - self.kv_cache_manager = kv_cache_manager - self.peft_cache_manager = peft_cache_manager - self.cross_kv_cache_manager = cross_kv_cache_manager - self.scheduler_policy = scheduler_policy - self.two_step_lookahead = two_step_lookahead - self.no_schedule_until_state = no_schedule_until_state - self.no_schedule_after_state = no_schedule_after_state - # Cache state values to avoid repeated .value access (optimization) - self._no_schedule_until_state_value = no_schedule_until_state.value - self._no_schedule_after_state_value = no_schedule_after_state.value + request: The request to estimate for - # Initialize the appropriate policy - self._policy = self._create_policy() + Returns: + int: Number of blocks needed (worst-case for VSWA) + """ + # Fast path: Use cached estimate if available + if request.py_pre_validated and request.py_estimated_blocks > 0: + return request.py_estimated_blocks - def _create_policy(self) -> SchedulerPolicyBase: - """Create the appropriate policy based on configuration.""" + # Slow path: Calculate from scratch if self.kv_cache_manager is None: - return MaxRequestsPolicy() - elif self.scheduler_policy == CapacitySchedulerPolicy.MAX_UTILIZATION: - return MaxUtilizationPolicy() - elif self.scheduler_policy == CapacitySchedulerPolicy.GUARANTEED_NO_EVICT: - return GuaranteedNoEvictPolicy(static_batch=False) - elif self.scheduler_policy == CapacitySchedulerPolicy.STATIC_BATCH: - return GuaranteedNoEvictPolicy(static_batch=True) + return 0 + + # For VSWA, check all window sizes and return worst-case (maximum) + if hasattr(self.kv_cache_manager, 'is_variable_window' + ) and self.kv_cache_manager.is_variable_window: + max_blocks = 0 + for window_size_key in self.kv_cache_manager.get_window_size_keys(): + blocks = self.kv_cache_manager.get_num_required_blocks( + request, window_size_key) + max_blocks = max(max_blocks, blocks) + return max_blocks else: - raise ValueError( - f"Unsupported scheduler policy: {self.scheduler_policy}") + # Standard case: single window size + return self.kv_cache_manager.get_num_required_blocks(request) - def _can_be_scheduled(self, req: LlmRequest) -> bool: - """ - Check if request is within the schedulable state range. - Returns True if request has reached no_schedule_until_state - but has not yet reached no_schedule_after_state. - Optimized: use state_value property to avoid enum object creation + def _calculate_current_token_load(self, + active_requests: RequestList) -> int: """ - # Use state_value property (returns int directly, avoids enum object creation) - state_value = req.state_value - # Inline comparison: must have reached until_state but not after_state - return (state_value >= self._no_schedule_until_state_value - and state_value < self._no_schedule_after_state_value) + Calculate total tokens consumed by current active requests. 
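The pre-validation fast path above only pays off if the coordinator fills the py_* fields when it simulates admission. A hedged sketch of that hand-off follows; the field names mirror llm_request.py in this diff, but the coordinator side is an assumption, since GlobalCoordinator itself is not part of this hunk.

# Hypothetical coordinator-side hand-off; only the py_* field names come from this diff.
class _Req:  # stand-in for LlmRequest
    def __init__(self):
        self.py_pre_validated = False
        self.py_estimated_tokens = 0
        self.py_estimated_blocks = 0

def simulate_and_cache(req: _Req, est_tokens: int, est_blocks: int, fits: bool) -> bool:
    if fits:
        # Cache the simulation result so the fused scheduler can skip recomputation.
        req.py_pre_validated = True
        req.py_estimated_tokens = est_tokens
        req.py_estimated_blocks = est_blocks
    return fits

def estimated_tokens(req: _Req, recompute) -> int:
    # Fast path mirrors _estimate_tokens_needed: trust the cached value when present.
    if req.py_pre_validated and req.py_estimated_tokens > 0:
        return req.py_estimated_tokens
    return recompute(req)

r = _Req()
simulate_and_cache(r, est_tokens=1003, est_blocks=8, fits=True)
assert estimated_tokens(r, recompute=lambda _: 0) == 1003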
+ + Args: + active_requests: List of currently active requests - def _is_skipping_relevant(self) -> bool: + Returns: + int: Total token count """ - Check if block reuse skip optimization is relevant. - Disabled for VSWA (Variable Sliding Window Attention). - C++ reference: capacityScheduler.cpp:207-208, 348 + total_tokens = 0 + for req in active_requests: + # Only count schedulable requests + state_value = req.state_value + if (state_value >= self._no_schedule_until_state_value + and state_value < self._no_schedule_after_state_value): + total_tokens += self._estimate_tokens_needed(req) + return total_tokens + + def _activate_local( + self, + active_requests: RequestList, + waiting_queue: deque, + cp_config: dict, + cp_rank: int, + cp_size: int, + exclude_last_generation_logits: bool, + ) -> tuple[RequestList, int]: """ - if self.kv_cache_manager is None: - return False - if self.kv_cache_manager.is_variable_window: - return False - if (self.cross_kv_cache_manager is not None - and self.cross_kv_cache_manager.is_variable_window): - return False - return True + Activate new requests locally (TP-only mode, no global coordination). - def _prefill_contributed_blocks( - self, active_requests: RequestList) -> tuple[set, set]: - """ - Collect blocks contributed by chunked context requests already executing. - These blocks can be reused by later requests. + This method handles request activation for the base SimpleUnifiedScheduler + which operates in TP-only mode without attention_dp. + + Args: + active_requests: Currently active requests on this rank + waiting_queue: Queue of waiting RequestQueueItems + cp_config: CP configuration dict + cp_rank: Current CP rank + cp_size: Total number of CP ranks + exclude_last_generation_logits: Whether to exclude last generation logits - C++ reference: capacityScheduler.cpp:34-68 (prefillWithChunkedContextsAlreadyExecuting) + Returns: + Tuple of (new_llm_requests, expected_num_active_requests) """ - newly_contributed_context_blocks: Set = set() - newly_contributed_cross_context_blocks: Set = set() + # Calculate local capacity + # Use max_num_requests as fallback when max_num_active_requests is unset + max_active = self.max_num_active_requests if self.max_num_active_requests is not None else self.max_num_requests + max_new_requests = max(0, max_active - len(active_requests)) + + if max_new_requests == 0: + return [], len(active_requests) + + # Pop requests from waiting queue (local capacity only) + new_request_items = [] + for _ in range(min(max_new_requests, len(waiting_queue))): + if len(waiting_queue) == 0: + break + new_request_items.append(waiting_queue.popleft()) - if self.kv_cache_manager is None: - return newly_contributed_context_blocks, newly_contributed_cross_context_blocks + if len(new_request_items) == 0: + return [], len(active_requests) - enable_block_reuse = self.kv_cache_manager.enable_block_reuse - cross_enable_reuse = (self.cross_kv_cache_manager is not None and - self.cross_kv_cache_manager.enable_block_reuse) + # Convert RequestQueueItems to LlmRequests (ONLY ONCE) + new_llm_requests = merge_requests( + new_request_items, + cp_config=cp_config, + cp_rank=cp_rank, + cp_size=cp_size, + exclude_last_generation_logits=exclude_last_generation_logits) - for req in active_requests: - # Check: isContextInitState() && !isFirstContextChunk() - if req.is_context_init_state and not req.is_first_context_chunk: - # Chunked context request already executing - if enable_block_reuse: - unique_tokens = req.get_unique_tokens(0) - block_key = 
self.kv_cache_manager.find_new_context_block( - unique_tokens, req) - if block_key is not None: - newly_contributed_context_blocks.add(block_key) + # For TP-only mode, expected_num_active_requests is local count + expected_num_active_requests = len(active_requests) + len( + new_llm_requests) - if cross_enable_reuse: - encoder_unique_tokens = req.get_encoder_unique_tokens() - if encoder_unique_tokens is not None: - block_key = self.cross_kv_cache_manager.find_new_context_block( - encoder_unique_tokens, req) - if block_key is not None: - newly_contributed_cross_context_blocks.add( - block_key) + return new_llm_requests, expected_num_active_requests - return newly_contributed_context_blocks, newly_contributed_cross_context_blocks - def _one_manager_beneficial_to_skip(self, kv_cache_manager, unique_tokens, - req: LlmRequest, - newly_contributed_blocks: set) -> bool: - """ - Check if skipping is beneficial for one KV cache manager. - C++ reference: capacityScheduler.cpp:70-92 (oneManagerBeneficialToSkip) - """ - new_context_block = kv_cache_manager.find_new_context_block( - unique_tokens, req) - if new_context_block is not None: - if new_context_block in newly_contributed_blocks: - return True - newly_contributed_blocks.add(new_context_block) - return False +class AttentionDPScheduler(AttentionDPMixin, SimpleUnifiedScheduler): + """ + Unified scheduler with attention_dp support. + + Use this for: + - Single server with attention_dp across TP ranks + - No disaggregation + + Parameters: + - dist: Required - DistributionManager for attention_dp + - max_num_active_requests: Required - Maximum active requests across all ranks + - ... (other SimpleUnifiedScheduler parameters) + + Example: + scheduler = AttentionDPScheduler( + max_batch_size=64, + max_num_tokens=2048, + kv_cache_manager=kv_cache_mgr, + peft_cache_manager=None, + scheduler_policy=CapacitySchedulerPolicy.MAX_UTILIZATION, + dist=dist, # Required + max_num_active_requests=128, # Required + ) + """ + + def __init__(self, *args, dist, max_num_active_requests, **kwargs): + if dist is None: + raise ValueError("AttentionDPScheduler requires dist parameter") + if max_num_active_requests is None: + raise ValueError( + "AttentionDPScheduler requires max_num_active_requests parameter" + ) + super().__init__(*args, + dist=dist, + max_num_active_requests=max_num_active_requests, + **kwargs) + + +class DisaggGenerationScheduler(AttentionDPMixin, SimpleUnifiedScheduler): + """ + Unified scheduler for disaggregated generation server. + + Use this for: + - Generation/decoding server in disaggregated serving + - Handles generation requests and disagg_generation_init requests only + - Optional attention_dp (controlled by `dist` parameter) + + Examples: + # With attention_dp + scheduler = DisaggGenerationScheduler( + ..., + dist=dist, + max_num_active_requests=128, + kv_cache_transceiver=transceiver, + ) + + # Without attention_dp (TP-only) + scheduler = DisaggGenerationScheduler( + ..., + dist=None, + kv_cache_transceiver=transceiver, + ) + """ - def _beneficial_to_skip( - self, req: LlmRequest, newly_contributed_context_blocks: set, - newly_contributed_cross_context_blocks: set) -> bool: + def _should_schedule_request(self, req, inflight_request_ids: set[int], + reserved_blocks) -> bool: """ - Check if it's beneficial to skip this request. - A request should be skipped if it can reuse blocks contributed by - already scheduled context requests. + Filter requests for generation server. 
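Stepping back to _activate_local above: admission in TP-only mode reduces to "free slots = max active minus currently active", bounded by the waiting queue. A simplified sketch with stand-in types, not the PR's code.

# Simplified stand-in for the TP-only admission math in _activate_local.
from collections import deque

def pop_new_requests(waiting: deque, num_active: int, max_active: int) -> list:
    free_slots = max(0, max_active - num_active)
    popped = []
    while waiting and len(popped) < free_slots:
        popped.append(waiting.popleft())
    return popped

q = deque(["r1", "r2", "r3"])
assert pop_new_requests(q, num_active=6, max_active=8) == ["r1", "r2"]
assert list(q) == ["r3"]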
+ + Generation server only handles: + - Generation requests (in-progress decoding) + - Disagg generation init requests (receiving KV from context server) - C++ reference: capacityScheduler.cpp:97-123 (beneficialToSkip) + Skips: + - Encoder requests + - Context requests """ - if not (req.is_context_init_state and req.is_first_context_chunk): + # Skip inflight requests + if req.request_id in inflight_request_ids: return False - if (self.kv_cache_manager is not None - and self.kv_cache_manager.enable_block_reuse): - unique_tokens = req.get_unique_tokens(0) - if self._one_manager_beneficial_to_skip( - self.kv_cache_manager, unique_tokens, req, - newly_contributed_context_blocks): - return True + # Skip in-progress generation (already handled for GUARANTEED_NO_EVICT) + if reserved_blocks is not None and req.is_generation_in_progress_state: + return False - if (self.cross_kv_cache_manager is not None - and self.cross_kv_cache_manager.enable_block_reuse): - encoder_unique_tokens = req.get_encoder_unique_tokens() - if encoder_unique_tokens is not None: - if self._one_manager_beneficial_to_skip( - self.cross_kv_cache_manager, encoder_unique_tokens, req, - newly_contributed_cross_context_blocks): - return True + # Generation server specific filtering + req_state_value = req.state_value - return False + # Allow disagg generation init (bypass normal state range check) + if req.is_disagg_generation_init_state: + return True - def _get_max_peft_pages(self) -> int: - """Get maximum PEFT cache pages.""" - if self.peft_cache_manager is None: - return 2**31 - 1 # INT_MAX equivalent - return self.peft_cache_manager.max_device_pages + # Allow generation requests (within normal state range) + if req_state_value == self._generation_in_progress_state_value: + # Check normal state range + if (req_state_value >= self._no_schedule_until_state_value + and req_state_value < self._no_schedule_after_state_value): + return True - def _get_peft_pages_for_request(self, req: LlmRequest) -> int: - """Get PEFT pages needed for a request.""" - if self.peft_cache_manager is None: - return 0 - return self.peft_cache_manager.determine_num_pages(req) + # Skip encoder and context requests (handled by context server) + return False - def _get_peft_task_info( - self, req: LlmRequest, - seen_task_ids: set[int]) -> tuple[Optional[int], bool, int]: + def _finalize_chunking(self, state) -> None: """ - Get PEFT task information for a request. - Returns (lora_task_id, is_new_task, required_pages). + No-op for generation server - chunking only happens on context server. """ - lora_task_id = getattr(req, 'lora_task_id', None) - is_new_task = lora_task_id is not None and lora_task_id not in seen_task_ids - required_pages = self._get_peft_pages_for_request( - req) if is_new_task else 0 - return lora_task_id, is_new_task, required_pages - def _can_be_scheduled_with_disagg_exception(self, req: LlmRequest) -> bool: + def _should_skip_for_block_reuse(self, req, state) -> bool: """ - Check if request can be scheduled, with exception for disagg generation init state. - Disagg generation init requests bypass the normal state gating. + Override to exclude disagg_generation_init from skip optimization. + + Disagg gen init requests should always be considered for scheduling. 
""" if req.is_disagg_generation_init_state: - return True - return self._can_be_scheduled(req) + return False + return super()._should_skip_for_block_reuse(req, state) - def schedule_request( - self, active_requests: RequestList - ) -> tuple[RequestList, RequestList, RequestList]: + def _try_handle_special_request(self, req, state, scheduled_blocks_manager, + reserved_blocks, reserved_cross_blocks, + simulation_mode) -> tuple[bool, bool]: """ - Schedule requests based on the configured policy. + Handle disagg_generation_init requests. Args: - active_requests: List of active requests to consider + req: Request to potentially handle + state: Current scheduling state + scheduled_blocks_manager: Block manager for scheduling + reserved_blocks: Reserved blocks (for GUARANTEED_NO_EVICT) + reserved_cross_blocks: Reserved cross attention blocks + simulation_mode: Whether in simulation mode Returns: - Tuple of (fitting_requests, fitting_disagg_gen_init_requests, paused_requests) - - C++ reference: capacityScheduler.cpp:488-539 (CapacityScheduler::operator()) + Tuple of (handled, should_break): + - handled: True if disagg_gen_init request was processed + - should_break: True if scheduling should stop (capacity issue) """ - scheduled, paused = self._policy.schedule(self, active_requests) - - fitting_requests, fitting_disagg_gen_init_requests = self._classify_output( - scheduled) - - logger.debug( - f"[Summary] Capacity scheduler allows {len(fitting_requests)} requests, " - f"pauses {len(paused)} requests") + # Handle disagg_generation_init requests + if req.is_disagg_generation_init_state: + # Check KV cache capacity + can_fit_kv = self.capacity_checker.check_kv_capacity( + req, scheduled_blocks_manager, reserved_blocks, + reserved_cross_blocks, simulation_mode) + if not can_fit_kv: + state.paused_requests.append(req) + return True, True # Handled, should break + + # Check PEFT capacity (using shared helper) + if not self._check_peft_capacity(req, state): + state.paused_requests.append(req) + return True, False # Handled, continue to next request + + # Fits! Add to disagg gen init list + state.fitting_disagg_gen_init.append(req) + return True, False # Handled, continue to next request + + return False, False # Not a special request, use normal handling + + def _check_batch_limits(self, state: 'SchedulingState') -> bool: + """ + Check if batch limits are reached (with disagg_gen_init included). - return fitting_requests, fitting_disagg_gen_init_requests, paused + Args: + state: Current scheduling state - def _classify_output( - self, - scheduled_requests: RequestList) -> tuple[RequestList, RequestList]: - """ - Separate scheduled requests into normal requests and disagg gen init requests. 
- C++ reference: capacityScheduler.cpp:522-534 + Returns: + True if can continue scheduling, False if limits reached """ - fitting_requests: RequestList = [] - fitting_disagg_gen_init_requests: RequestList = [] - for req in scheduled_requests: - if req.is_disagg_generation_init_state: - fitting_disagg_gen_init_requests.append(req) - else: - fitting_requests.append(req) - return fitting_requests, fitting_disagg_gen_init_requests + # Check batch size limit + if state.scheduled_req_size >= state.max_batch_size: + return False + # Check request count limit (include disagg_gen_init for generation server) + if len(state.context_requests) + len(state.generation_requests) + len( + state.fitting_disagg_gen_init) >= self.max_num_requests: + return False -class SimpleUnifiedScheduler(RequestScheduler): + return True - def __init__( - self, - max_batch_size: int, - max_num_tokens: int, - kv_cache_manager, - peft_cache_manager, - scheduler_policy: CapacitySchedulerPolicy, - ctx_chunk_config: Optional[tuple[StrEnum, int]] = None, - cross_kv_cache_manager=None, - two_step_lookahead: bool = False, - scheduler_capacity: Optional[int] = None, - ): - # Use scheduler_capacity if provided, otherwise fall back to max_batch_size - # scheduler_capacity may differ from max_batch_size (e.g., adjusted for attention_dp + disagg) - capacity = scheduler_capacity if scheduler_capacity is not None else max_batch_size + def _build_scheduler_output( + self, state: 'SchedulingState') -> 'UnifiedSchedulerOutput': + """ + Build final scheduler output with disagg_generation_init requests included. - # 1. Initialize Python Capacity Scheduler - # Now fully aligned with C++ CapacityScheduler - self.capacity_scheduler = PyCapacityScheduler( - max_num_requests=capacity, - kv_cache_manager=kv_cache_manager, - peft_cache_manager=peft_cache_manager, - scheduler_policy=scheduler_policy, - cross_kv_cache_manager=cross_kv_cache_manager, - two_step_lookahead=two_step_lookahead) + Overrides base method to correctly count disagg_gen_init requests in num_fitting. - # 2. Initialize Python MicroBatch Scheduler - py_chunk_config = None - if ctx_chunk_config: - # Fix: Use string comparison to identify the policy. - # This works regardless of whether the input is a Python Enum, C++ Binding Enum, or String. - input_policy = ctx_chunk_config[0] + Args: + state: Final scheduling state - if "EQUAL_PROGRESS" in str(input_policy): - policy_enum = ChunkingPolicy.EQUAL_PROGRESS - else: - # Default to FCFS for FIRST_COME_FIRST_SERVED or others - policy_enum = ChunkingPolicy.FIRST_COME_FIRST_SERVED + Returns: + UnifiedSchedulerOutput with scheduled requests including disagg + """ + # Sort requests for consistency + chunks_present = state.ctx_chunk_config is not None + if self.chunking_manager and chunks_present: + self.chunking_manager.sort_requests(state.context_requests, + state.generation_requests, + chunks_present) + + # Return results (include disagg_gen_init in num_fitting count) + num_fitting = len(state.context_requests) + len( + state.generation_requests) + len(state.fitting_disagg_gen_init) + return UnifiedSchedulerOutput( + context_requests=state.context_requests, + generation_requests=state.generation_requests, + paused_requests=state.paused_requests, + fitting_disagg_gen_init_requests=state.fitting_disagg_gen_init, + num_fitting_requests=num_fitting, + updated_active_requests=None, + ) + + +class DisaggContextScheduler(AttentionDPMixin, SimpleUnifiedScheduler): + """ + Unified scheduler for disaggregated context server. 
+ + Use this for: + - Context/prefill server in disaggregated serving + - Handles encoder and context requests only + - Supports chunking + - Optional attention_dp (controlled by `dist` parameter) + + Examples: + # With attention_dp + scheduler = DisaggContextScheduler( + ..., + dist=dist, + max_num_active_requests=128, + kv_cache_transceiver=transceiver, + ) + + # Without attention_dp (TP-only) + scheduler = DisaggContextScheduler( + ..., + dist=None, + kv_cache_transceiver=transceiver, + ) + """ - py_chunk_config = ContextChunkingConfig(policy_enum, - ctx_chunk_config[1]) + def _should_schedule_request(self, req, inflight_request_ids: set[int], + reserved_blocks) -> bool: + """ + Filter requests for context server. - self.micro_batch_scheduler = PyMicroBatchScheduler( - max_batch_size=max_batch_size, - max_num_tokens=max_num_tokens, - ctx_chunk_config=py_chunk_config) + Context server only handles: + - Encoder requests (e.g., vision encoder) + - Context requests (prefill/prompt processing) - def schedule_request(self, active_requests: RequestList, - inflight_request_ids: set[int]) -> SchedulerOutput: - # Step 1: Capacity Check (Who fits in memory?) - fitting_requests, fitting_disagg_gen_init, paused_requests = \ - self.capacity_scheduler.schedule_request(active_requests) + Skips: + - Generation requests + - Disagg generation init requests (handled by generation server) + """ + # First, apply base filtering + if not super()._should_schedule_request(req, inflight_request_ids, + reserved_blocks): + return False - # Step 2: MicroBatch Check (Who fits in token budget? + Chunking) - context_requests, generation_requests = \ - self.micro_batch_scheduler.schedule(fitting_requests, inflight_request_ids) + # Context server specific filtering + req_state_value = req.state_value - return SchedulerOutput( - context_requests=context_requests, - generation_requests=generation_requests, - paused_requests=paused_requests, - fitting_disagg_gen_init_requests=fitting_disagg_gen_init, - num_fitting_requests=len(fitting_requests)) + # Allow encoder requests + if req_state_value == self._encoder_init_state_value: + return True - def can_schedule(self, requests: RequestList) -> bool: - # Dry run capacity check - fitting, _, _ = self.capacity_scheduler.schedule_request(requests) - return len(fitting) == len(requests) + # Allow context requests + if req_state_value == self._context_init_state_value: + return True + + # Skip generation and disagg_generation_init requests (handled by generation server) + return False diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index 76cbde9646f8..331b1533b9f5 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -684,6 +684,9 @@ def serve( raise ValueError(f"Invalid server role: {server_role}. " \ f"Must be one of: {', '.join([role.name for role in ServerRole])}") + # Pass server_role to LLM args for scheduler selection + llm_args['server_role'] = server_role + # Parse media_io_kwargs from JSON string to dict if provided parsed_media_io_kwargs = None if media_io_kwargs is not None:
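The serve.py hunk above only threads server_role into llm_args; how that value is mapped to one of the new scheduler classes is not shown in this excerpt. The sketch below is therefore an assumption of what that selection might look like, using the class names defined earlier in this diff; the real wiring in _util.py / py_executor.py may differ.

# Assumed selection logic only; the actual mapping is not part of this hunk.
from tensorrt_llm._torch.pyexecutor.scheduler import (
    AttentionDPScheduler, DisaggContextScheduler, DisaggGenerationScheduler,
    SimpleUnifiedScheduler)

def pick_scheduler_class(server_role, enable_attention_dp: bool):
    role = str(server_role).upper() if server_role is not None else ""
    if "CONTEXT" in role:
        return DisaggContextScheduler       # prefill-only server
    if "GENERATION" in role:
        return DisaggGenerationScheduler    # decode-only server
    if enable_attention_dp:
        return AttentionDPScheduler         # single server with attention_dp
    return SimpleUnifiedScheduler           # plain TP-only server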