1 change: 1 addition & 0 deletions tensorrt_llm/__init__.py
@@ -17,6 +17,7 @@

# Disable UCC to WAR allgather issue before NGC PyTorch 25.12 upgrade.
os.environ["OMPI_MCA_coll_ucc_enable"] = "0"
os.environ["TLLM_USE_PYTHON_SCHEDULER"] = "1"


def _add_trt_llm_dll_directory():
8 changes: 7 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/_util.py
@@ -860,6 +860,9 @@ def create_py_executor_instance(
if scheduler_capacity == 1 and mapping.enable_attention_dp and kv_cache_manager:
scheduler_capacity += 1

# Extract server_role from llm_args for scheduler selection
server_role = getattr(llm_args, 'server_role', None)

use_python_scheduler = os.getenv("TLLM_USE_PYTHON_SCHEDULER", "0") == "1"
if use_python_scheduler and not isinstance(kv_cache_manager,
KVCacheManagerV2):
@@ -873,7 +876,10 @@
scheduler_policy=scheduler_config.capacity_scheduler_policy,
ctx_chunk_config=ctx_chunk_config,
two_step_lookahead=mapping.has_pp(),
scheduler_capacity=scheduler_capacity)
scheduler_capacity=scheduler_capacity,
dist=dist,
max_num_active_requests=model_engine.get_max_num_sequences(),
server_role=server_role)
else:
if isinstance(kv_cache_manager, KVCacheManagerV2):
capacity_scheduler = KVCacheV2DummyScheduler(
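To illustrate what the extracted server_role could feed into, here is a hedged sketch of role-dependent scheduler policy selection for disaggregated serving. The ServerRole values and the select_scheduler_policy helper are illustrative assumptions, not the selection logic added by this PR.

from enum import Enum
from typing import Optional

class ServerRole(Enum):
    # Illustrative stand-in; the real role type comes from the LLM args.
    CONTEXT = "context"
    GENERATION = "generation"

def select_scheduler_policy(server_role: Optional[ServerRole]) -> str:
    # Hypothetical helper: a context-only (prefill) server can favor
    # throughput-oriented batching, while a generation (decode) server
    # favors latency; aggregated deployments use the unified default.
    if server_role is ServerRole.CONTEXT:
        return "context_throughput"
    if server_role is ServerRole.GENERATION:
        return "generation_latency"
    return "unified"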
7 changes: 7 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -583,6 +583,13 @@ def __init__(
additional_outputs=additional_outputs)
self.child_requests = []

# Pre-validation cache for attention_dp optimization
# When a request passes simulation in GlobalCoordinator.can_accept_request(),
# we cache the estimated tokens/blocks to avoid recalculating in _fused_schedule_request()
self.py_pre_validated: bool = False
self.py_estimated_tokens: int = 0
self.py_estimated_blocks: int = 0

self._py_embedding_bias_1d: Optional[torch.Tensor] = None
if hasattr(self, 'embedding_bias') and self.embedding_bias is not None:
# Pre-squeeze to 1D if needed (remove batch dimension)
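To make the intent of the py_pre_validated / py_estimated_* fields added above concrete, here is a minimal sketch of how a scheduling pass could reuse the cached estimate instead of recomputing it. The plan_request function and the estimate_tokens_and_blocks helper are hypothetical names used only for illustration.

def plan_request(request, kv_cache_manager):
    # Reuse the estimate cached when GlobalCoordinator.can_accept_request()
    # pre-validated this request; otherwise compute and cache it now so the
    # fused scheduling step does not repeat the work.
    if request.py_pre_validated:
        return request.py_estimated_tokens, request.py_estimated_blocks
    tokens, blocks = estimate_tokens_and_blocks(request, kv_cache_manager)  # hypothetical helper
    request.py_estimated_tokens = tokens
    request.py_estimated_blocks = blocks
    request.py_pre_validated = True
    return tokens, blocks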
251 changes: 246 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -61,7 +61,7 @@
from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState,
SampleStateTensors, TRTLLMSampler)
from .scheduler import (RequestScheduler, ScheduledRequests,
SerializableSchedulerOutput)
SerializableSchedulerOutput, SimpleUnifiedScheduler)

# Environment variable to specify iteration ranges for profiling start/stop.
# Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..."
@@ -354,6 +354,29 @@ def __init__(self,
self.max_input_len = max_input_len
# _executor_loop private data
self.max_num_active_requests = model_engine.get_max_num_sequences()

# Configure SimpleUnifiedScheduler and AttentionDPScheduler
if isinstance(scheduler, SimpleUnifiedScheduler):
# Configure batch waiting (for TP-only mode)
scheduler.batch_wait_timeout_iters = self.llm_args.batch_wait_timeout_iters
scheduler.batch_wait_max_tokens_ratio = self.llm_args.batch_wait_max_tokens_ratio
scheduler.enable_batch_waiting = (
scheduler.batch_wait_timeout_iters > 0
or scheduler.batch_wait_max_tokens_ratio > 0)

# If this is AttentionDPScheduler, configure attention_dp parameters
if hasattr(scheduler, 'enable_global_scheduling'
) and scheduler.enable_global_scheduling:
# Configure batching/waiting parameters for attention_dp
scheduler.global_coordinator.attention_dp_enable_balance = self.attention_dp_enable_balance
if self.attention_dp_enable_balance:
scheduler.global_coordinator.attention_dp_time_out_iters = self.attention_dp_time_out_iters
scheduler.global_coordinator.attention_dp_batching_wait_iters = self.attention_dp_batching_wait_iters

logger.info(
    f"Configured global scheduling for attention_dp "
    f"(balance={self.attention_dp_enable_balance})")

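For intuition on what batch_wait_timeout_iters and batch_wait_max_tokens_ratio control, a minimal sketch of the waiting decision follows; the function name and exact thresholds are assumptions about the scheduler internals, not code from this PR.

def should_release_waiting_batch(waited_iters: int,
                                 pending_tokens: int,
                                 max_num_tokens: int,
                                 batch_wait_timeout_iters: int,
                                 batch_wait_max_tokens_ratio: float) -> bool:
    # Assumed semantics: hold back newly arrived context requests to build a
    # larger batch, releasing it once either the iteration timeout expires or
    # enough tokens have accumulated relative to the per-iteration budget.
    if batch_wait_timeout_iters > 0 and waited_iters >= batch_wait_timeout_iters:
        return True
    if (batch_wait_max_tokens_ratio > 0
            and pending_tokens >= batch_wait_max_tokens_ratio * max_num_tokens):
        return True
    return False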
self.active_requests: List[LlmRequest] = []
self.expected_num_active_requests = 0
self.async_transfer_manager = AsyncTransferManager(
@@ -1411,22 +1434,160 @@ def _can_queue(self, scheduled_batch):
return can_queue, can_queue_this_rank

def _prepare_and_schedule_batch(self):
new_requests = self._fetch_and_activate_new_requests()
"""Prepare and schedule batch for execution."""
# Use new unified scheduler interface if available
if isinstance(self.scheduler, SimpleUnifiedScheduler) and hasattr(
self.scheduler, 'schedule_iteration'):
return self._prepare_and_schedule_batch_v2()
else:
# Fallback to old path for backward compatibility
return self._prepare_and_schedule_batch_v1()

def _prepare_and_schedule_batch_v2(self):
"""
Prepare and schedule batch using new unified scheduler interface.

SIMPLIFIED: Most scheduling logic moved to scheduler.schedule_iteration()
"""
# Step 1: Check stop condition
if self.should_stop_processing:
return None, None

# Step 2: Check KV cache transfer status (disagg mode - executor responsibility)
if self.kv_cache_transceiver:
self._check_disagg_gen_transfer_status()
self._check_kv_transfer_timeout()

# Step 3: Calculate iter_stats (executor responsibility)
iter_stats = None
if self.enable_iter_perf_stats:
num_new_in_queue = len(self.waiting_queue)
queue_latency = self._get_new_active_requests_queue_latency()
iter_stats = self._get_init_iter_stats(num_new_in_queue,
queue_latency)

# Step 4: Fetch new requests from external queue (executor responsibility)
self._fetch_and_enqueue_requests(self.waiting_queue,
self.expected_num_active_requests)

# Step 5: New-request validation happens inside schedule_iteration via
# _activate_new_requests; record the waiting-queue size here so we can
# count how many requests get activated this iteration.
waiting_queue_size_before = len(self.waiting_queue)

# Step 6: Drafter pre-processing (calculate draft tokens for this iteration)
max_total_draft_tokens = 0
use_spec_decode = False

if self.drafter is not None:
# Calculate draft length based on batch size
if self.drafter.draft_len_schedule is not None:
batch_size_input = len(self.active_requests)
max_total_draft_tokens = self.drafter.get_draft_len_for_batch_size(
batch_size_input)
self.drafter.update_max_total_draft_tokens(
max_total_draft_tokens)

# Determine if spec decode should be used
if self.drafter.draft_len_schedule is not None and max_total_draft_tokens == 0:
use_spec_decode = False
elif getattr(self, 'speculation_permanently_disabled', False):
use_spec_decode = False
else:
use_spec_decode = self.drafter.should_use_spec_decode(
self.active_requests, self.max_batch_size,
self.model_engine.llm_args.max_num_tokens,
max_total_draft_tokens)

logger.debug(f"Use spec decode: {use_spec_decode}")
self.model_engine.enable_spec_decode = use_spec_decode
self.use_spec_decode = use_spec_decode
self.max_total_draft_tokens = max_total_draft_tokens

# Set up draft_tokens in active_requests for scheduler awareness
for request in self.active_requests:
if request.state not in (
LlmRequestState.GENERATION_IN_PROGRESS,
LlmRequestState.DISAGG_GENERATION_INIT):
continue
request.draft_tokens = [
0
] * max_total_draft_tokens if max_total_draft_tokens > 0 else []

# Step 7: Single call into the unified scheduler
scheduler_output = self.scheduler.schedule_iteration(
waiting_queue=self.waiting_queue,
active_requests=self.active_requests,
inflight_req_ids=self.inflight_req_ids,
drafter=self.drafter,
max_total_draft_tokens=max_total_draft_tokens,
use_spec_decode=use_spec_decode,
cp_config=self.dist.cp_config,
cp_rank=self.dist.cp_rank,
cp_size=self.dist.cp_size,
exclude_last_generation_logits=self.
_should_exclude_last_generation_logits(),
kv_cache_manager=self.kv_cache_manager,
resource_manager=self.resource_manager,
)

# Convert to ScheduledRequests
scheduled_batch = scheduler_output.to_scheduled_requests()

# Count the new requests that schedule_iteration activated
# (they are already in active_requests at this point)
waiting_queue_size_after = len(self.waiting_queue)
num_new_activated = waiting_queue_size_before - waiting_queue_size_after

# Step 8: Disagg mode post-processing (executor responsibility)
if self.kv_cache_transceiver:
# Prepare disagg gen init resources
if scheduler_output.fitting_disagg_gen_init_requests:
self._prepare_disagg_gen_init(
scheduler_output.fitting_disagg_gen_init_requests)

# Warning if no requests fit
if scheduler_output.num_fitting_requests == 0 and not scheduler_output.fitting_disagg_gen_init_requests:
logger.warning(
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
)
self._check_disagg_ctx_cache_transfer_status(1)

# Step 9: Update state and return
self.num_scheduled_requests = scheduled_batch.batch_size
logger.debug(
f'has {len(self.active_requests)} active_requests, '
f'scheduled {len(scheduled_batch.context_requests)} context requests and '
f'{len(scheduled_batch.generation_requests)} generation requests')

return scheduled_batch, iter_stats
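
As a rough sketch of what the to_scheduled_requests() conversion above might amount to, the scheduler output could be split into context and generation requests by request state; the actual SerializableSchedulerOutput.to_scheduled_requests() in this PR may differ.

def to_scheduled_requests(fitting_requests):
    # Illustrative conversion, assuming ScheduledRequests exposes the
    # context_requests / generation_requests lists used elsewhere in this file.
    scheduled = ScheduledRequests()
    for request in fitting_requests:
        if request.state == LlmRequestState.GENERATION_IN_PROGRESS:
            scheduled.generation_requests.append(request)
        else:
            scheduled.context_requests.append(request)
    return scheduled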

def _prepare_and_schedule_batch_v1(self):
"""
OLD PATH: Prepare and schedule batch using old interface.

Kept for backward compatibility with SimpleScheduler.
"""
# Step 1: Fetch and activate new requests
num_new_requests = self._fetch_and_activate_requests()
if self.should_stop_processing:
return None, None

# Step 2: Check KV cache transfer status
if self.kv_cache_transceiver:
self._check_disagg_gen_transfer_status()
self._check_kv_transfer_timeout()

# Step 3: Calculate iter_stats
iter_stats = None
if self.enable_iter_perf_stats:
iter_stats = self._get_init_iter_stats(
len(new_requests),
self._get_new_active_requests_queue_latency())
num_new_requests, self._get_new_active_requests_queue_latency())

# Step 4: Pad dummy requests
self._pad_attention_dp_dummy_request()

# Step 5: Drafter logic
if self.drafter is not None:
# Honor permanent disable flag based on rolling acceptance first
if self.drafter.draft_len_schedule is not None:
@@ -1469,9 +1630,11 @@ def _prepare_and_schedule_batch(self):
# that speculation is about to happen.
self._prepare_draft_requests()

scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
# Step 6: Schedule batch
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule_batch(
)

# Step 7: Post-processing
if self.drafter is not None and not self.use_spec_decode:
for request in scheduled_batch.all_requests():
request.py_disable_speculative_decoding = True
@@ -1493,6 +1656,84 @@ def _prepare_and_schedule_batch(self):
f'{len(scheduled_batch.generation_requests)} generation requests')
return scheduled_batch, iter_stats

def _validate_new_requests(
self, new_requests: List[LlmRequest]) -> List[LlmRequest]:
"""
Validate new requests and handle errors for invalid ones.

Args:
new_requests: List of new requests to validate

Returns:
List of validated requests (invalid ones are removed and errors are handled)
"""
validated_requests = []
for request in new_requests:
try:
self._validate_request(request)
validated_requests.append(request)
except Exception as e:
self._handle_errors(str(e), requests=[request])
return validated_requests

def _fetch_and_activate_requests(self):
"""
Fetch and activate new requests.

Returns:
int: Number of newly activated requests
"""
if isinstance(self.scheduler, SimpleUnifiedScheduler):
# SimpleUnifiedScheduler path: Works for both attention_dp and TP-only modes
# Fetch and enqueue requests from executor queue
self._fetch_and_enqueue_requests(self.waiting_queue,
self.expected_num_active_requests)

# Activate new requests through scheduler
# Scheduler returns LlmRequests (already converted, only once)
# For attention_dp: expected_num_active_requests is max across all ranks
# For TP-only: expected_num_active_requests is local count
new_llm_requests, self.expected_num_active_requests = \
self.scheduler._activate_new_requests(
self.active_requests,
self.waiting_queue,
self.dist.cp_config,
self.dist.cp_rank,
self.dist.cp_size,
self._should_exclude_last_generation_logits()
)

# Validate and add new requests to active_requests
validated_new_requests = self._validate_new_requests(
new_llm_requests)
self.active_requests.extend(validated_new_requests)

return len(validated_new_requests)
else:
# SimpleScheduler path
new_requests = self._fetch_and_activate_new_requests()
return len(new_requests)

def _schedule_batch(self):
"""
Schedule the batch using the appropriate scheduler.

Returns:
tuple: (scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs)
"""
if isinstance(self.scheduler, SimpleUnifiedScheduler):
# SimpleUnifiedScheduler path: Works for both attention_dp and TP-only modes
# - For attention_dp: Batching done during activation via _apply_batching_filter()
# - For TP-only: Batching done during scheduling via _apply_batch_waiting()
scheduler_output = self.scheduler.schedule_request(
self.active_requests, self.inflight_req_ids)
return (scheduler_output.to_scheduled_requests(),
scheduler_output.fitting_disagg_gen_init_requests,
scheduler_output.num_fitting_requests)
else:
# SimpleScheduler path
return self._schedule()

def _kv_connector_start_batch(self, scheduled_batch):
if self.kv_connector_manager:
self.kv_connector_manager.take_scheduled_requests_pending_load(
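Finally, to place _prepare_and_schedule_batch in context, a simplified sketch of how one iteration of the executor loop could consume it; the _forward_and_record_stats step is a hypothetical stand-in for the model-execution path, which this PR does not change.

def _executor_loop_iteration(self):
    # Simplified sketch of one iteration built on the methods above.
    scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
    if scheduled_batch is None:
        # Stop condition reached (should_stop_processing was set).
        return False
    if scheduled_batch.batch_size > 0:
        # Hypothetical stand-in for forward pass, sampling, and stats logging.
        self._forward_and_record_stats(scheduled_batch, iter_stats)
    return True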