20 changes: 20 additions & 0 deletions cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
@@ -367,6 +367,26 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
nb::call_guard<nb::gil_scoped_release>())
.def("get_remaining_blocks_to_completion", &BaseKVCacheManager::getRemainingBlocksToCompletion,
nb::call_guard<nb::gil_scoped_release>())
.def("get_remaining_blocks_to_completion_batch",
[](tbk::BaseKVCacheManager& self, nb::list const& pyRequests, SizeType32 windowSize)
{
// Extract C++ request pointers while GIL is held
std::vector<tb::LlmRequest const*> requests;
requests.reserve(nb::len(pyRequests));
for (auto const& item : pyRequests)
{
requests.push_back(&nb::cast<tb::LlmRequest const&>(item));
}
// Release GIL for the C++ computation
nb::gil_scoped_release release;
std::vector<SizeType32> result;
result.reserve(requests.size());
for (auto const* req : requests)
{
result.push_back(self.getRemainingBlocksToCompletion(*req, windowSize));
}
return result;
})
.def("add_token", &BaseKVCacheManager::addToken, nb::call_guard<nb::gil_scoped_release>())
.def("add_sequence", &BaseKVCacheManager::addSequence, nb::call_guard<nb::gil_scoped_release>())
.def("remove_sequence", &BaseKVCacheManager::removeSequence, nb::call_guard<nb::gil_scoped_release>())
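For reference, a minimal usage sketch of the new batched binding from Python. The names `kv_cache_manager.impl`, `requests`, and `window_size` are illustrative assumptions; the point of the binding is that the GIL is released once for the whole batch instead of once per request:

```python
# Hypothetical call into the nanobind-wrapped BaseKVCacheManager.
# `requests` is a list of LlmRequest binding objects, `window_size` an int.
remaining = kv_cache_manager.impl.get_remaining_blocks_to_completion_batch(
    requests, window_size)

# Equivalent per-request loop through the existing single-request binding,
# which pays the GIL release/acquire cost on every call:
remaining_loop = [
    kv_cache_manager.impl.get_remaining_blocks_to_completion(req, window_size)
    for req in requests
]
assert remaining == remaining_loop
```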
31 changes: 31 additions & 0 deletions scripts/build_wheel.py
@@ -474,6 +474,36 @@ def build_kv_cache_manager_v2(project_dir, venv_python, use_mypyc=False):
print("-- Done building kv_cache_manager_v2.")


def build_pyexecutor_scheduler(project_dir, venv_python, use_mypyc=False):
print("-- Building pyexecutor scheduler...")
scheduler_dir = project_dir / "tensorrt_llm/_torch/pyexecutor/scheduler"
pyexecutor_dir = project_dir / "tensorrt_llm/_torch/pyexecutor"

# Clean up any existing mypyc artifacts to prevent stale inclusion
if not use_mypyc:
for so_file in pyexecutor_dir.glob("*__mypyc*.so"):
print(f"Removing stale mypyc artifact: {so_file}")
so_file.unlink()

for so_file in scheduler_dir.glob("*.so"):
print(f"Removing stale artifact: {so_file}")
so_file.unlink()

if use_mypyc:
print("-- Building scheduler mypyc extensions...", end=" ")
setup_mypyc = scheduler_dir / "setup_mypyc.py"
build_run(f'"{venv_python}" "{setup_mypyc}" build_ext --inplace',
cwd=pyexecutor_dir)

# Verify that the shared library was generated
if not list(pyexecutor_dir.glob("*__mypyc*.so")) and not list(
scheduler_dir.glob("*.so")):
raise RuntimeError(
"Failed to build scheduler: no shared library generated.")
print("Done")
print("-- Done building pyexecutor scheduler.")


def main(*,
build_type: str = "Release",
generator: str = "",
@@ -969,6 +999,7 @@ def get_binding_lib(subdirectory, name):
binding_lib_file_name)

build_kv_cache_manager_v2(project_dir, venv_python, use_mypyc=mypyc)
build_pyexecutor_scheduler(project_dir, venv_python, use_mypyc=mypyc)

if not skip_building_wheel:
if dist_dir is None:
8 changes: 7 additions & 1 deletion setup.py
@@ -159,6 +159,9 @@ def has_ext_modules(self):
'runtime/kv_cache_manager_v2/rawref/*.py',
'runtime/kv_cache_manager_v2/rawref/*.pyi',
'runtime/*__mypyc*.so',
'_torch/pyexecutor/scheduler/*.so',
'_torch/pyexecutor/scheduler/*.pyi',
'_torch/pyexecutor/*__mypyc*.so',
]

package_data += [
@@ -372,7 +375,10 @@ def extract_from_precompiled(precompiled_location: str, package_data: List[str],
package_data = [
p for p in package_data if p not in [
'runtime/kv_cache_manager_v2/*.so',
'runtime/kv_cache_manager_v2/**/*.so', 'runtime/*__mypyc*.so'
'runtime/kv_cache_manager_v2/**/*.so',
'runtime/*__mypyc*.so',
'_torch/pyexecutor/scheduler/*.so',
'_torch/pyexecutor/*__mypyc*.so',
]
]
# Ensure rawref is included
79 changes: 54 additions & 25 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -44,8 +44,8 @@
from .sampler import (EarlyStopSampler, EarlyStopWithMMResult, TorchSampler,
TRTLLMSampler)
from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
KVCacheV2Scheduler, SimpleScheduler,
SimpleUnifiedScheduler)
KVCacheV2Scheduler, ScheduleStepConfig, SimpleScheduler,
UnifiedScheduler)
from .seq_slot_manager import SeqSlotManager

GB = 1 << 30
@@ -1268,6 +1268,21 @@ def create_py_executor_instance(
if scheduler_capacity == 1 and mapping.enable_attention_dp and kv_cache_manager:
scheduler_capacity += 1

adp_cfg = llm_args.attention_dp_config
adp_enable_balance = adp_cfg is not None and adp_cfg.enable_balance
schedule_step_config = ScheduleStepConfig(
enable_attention_dp=mapping.enable_attention_dp,
attention_dp_enable_balance=adp_enable_balance,
attention_dp_time_out_iters=adp_cfg.timeout_iters
if adp_enable_balance else 0,
attention_dp_batching_wait_iters=(adp_cfg.batching_wait_iters
if adp_enable_balance else 0),
batch_wait_timeout_iters=llm_args.batch_wait_timeout_iters,
batch_wait_max_tokens_ratio=llm_args.batch_wait_max_tokens_ratio,
max_batch_size=max_batch_size,
max_num_tokens=max_num_tokens,
)

if isinstance(kv_cache_manager, KVCacheManagerV2):
# V2: interleaved scheduler handles both capacity and budget
draft_kv_cache_manager = resources.get(
@@ -1285,31 +1300,45 @@
if peft_cache_manager is not None else None,
scheduler_capacity=scheduler_capacity,
draft_kv_cache_manager=draft_kv_cache_manager,
schedule_step_config=schedule_step_config,
dist=dist,
)
elif (scheduler_config is not None
and scheduler_config.use_python_scheduler):
scheduler = SimpleUnifiedScheduler(
max_batch_size=max_batch_size,
max_num_tokens=max_num_tokens,
kv_cache_manager=kv_cache_manager.impl
if kv_cache_manager is not None else None,
peft_cache_manager=peft_cache_manager.impl
if peft_cache_manager is not None else None,
scheduler_policy=scheduler_config.capacity_scheduler_policy,
ctx_chunk_config=ctx_chunk_config,
two_step_lookahead=mapping.has_pp(),
scheduler_capacity=scheduler_capacity)
else:
capacity_scheduler = BindCapacityScheduler(
scheduler_capacity,
kv_cache_manager.impl if kv_cache_manager is not None else None,
peft_cache_manager.impl if peft_cache_manager is not None else None,
scheduler_config.capacity_scheduler_policy,
two_step_lookahead=mapping.has_pp())

mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens,
ctx_chunk_config)
scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)
use_python_scheduler = (scheduler_config.use_python_scheduler
if scheduler_config is not None else False)
if use_python_scheduler:
scheduler = UnifiedScheduler(
max_batch_size=max_batch_size,
max_num_tokens=max_num_tokens,
kv_cache_manager=kv_cache_manager.impl
if kv_cache_manager is not None else None,
peft_cache_manager=peft_cache_manager.impl
if peft_cache_manager is not None else None,
scheduler_policy=scheduler_config.capacity_scheduler_policy,
ctx_chunk_config=ctx_chunk_config,
two_step_lookahead=mapping.has_pp(),
scheduler_capacity=scheduler_capacity,
schedule_step_config=schedule_step_config,
dist=dist,
)
else:
capacity_scheduler = BindCapacityScheduler(
scheduler_capacity,
kv_cache_manager.impl if kv_cache_manager is not None else None,
peft_cache_manager.impl
if peft_cache_manager is not None else None,
scheduler_config.capacity_scheduler_policy,
two_step_lookahead=mapping.has_pp())

mb_scheduler = BindMicroBatchScheduler(max_batch_size,
max_num_tokens,
ctx_chunk_config)
scheduler = SimpleScheduler(
capacity_scheduler,
mb_scheduler,
schedule_step_config=schedule_step_config,
dist=dist,
)

config = model_engine.model.model_config.pretrained_config
attention_type = AttentionTypeCpp.MLA if is_mla(
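The `ScheduleStepConfig` constructed above bundles the attention-DP balancing and batch-waiting knobs that previously lived on `PyExecutor` and hands them to the scheduler. A minimal sketch of what such a carrier might look like, assuming a plain dataclass; the field names mirror the construction in `create_py_executor_instance`, while the types and defaults are assumptions for illustration:

```python
from dataclasses import dataclass


@dataclass
class ScheduleStepConfig:
    """Per-step scheduling knobs forwarded to the unified scheduler.

    Field names mirror the construction in create_py_executor_instance;
    the types and defaults shown here are illustrative assumptions.
    """
    enable_attention_dp: bool = False
    attention_dp_enable_balance: bool = False
    attention_dp_time_out_iters: int = 0
    attention_dp_batching_wait_iters: int = 0
    batch_wait_timeout_iters: int = 0
    batch_wait_max_tokens_ratio: float = 0.0
    max_batch_size: int = 0
    max_num_tokens: int = 0
```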
130 changes: 11 additions & 119 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -336,16 +336,7 @@ def __init__(
self.stream_interval = self.llm_args.stream_interval
self.perf_manager = PerfMetricsManager(
enabled=getattr(self.llm_args, 'return_perf_metrics', False))
self.attention_dp_enable_balance = (
self.llm_args.attention_dp_config is not None
and self.llm_args.attention_dp_config.enable_balance)
if self.attention_dp_enable_balance:
self.attention_dp_time_out_iters = self.llm_args.attention_dp_config.timeout_iters
self.attention_dp_batching_wait_iters = self.llm_args.attention_dp_config.batching_wait_iters
self.batch_wait_timeout_ms = self.llm_args.batch_wait_timeout_ms
self.batch_wait_timeout_iters = self.llm_args.batch_wait_timeout_iters
self.batch_wait_max_tokens_ratio = self.llm_args.batch_wait_max_tokens_ratio
self.enable_batch_waiting = self.batch_wait_timeout_iters > 0 or self.batch_wait_max_tokens_ratio > 0

self.num_fetch_requests_cur_rank = 0
self.num_fetch_requests = 0
@@ -451,9 +442,6 @@ def __init__(

self.is_shutdown = False
self.max_batch_size = max_batch_size
self.adp_ctx_waiting_iters_count = 0
self.adp_ctx_batching_wait_iters_count = 0
self.batch_wait_iters_count = 0

def on_detected():
self._handle_errors(
@@ -1173,8 +1161,11 @@ def _pp_schedule_and_propagate(self, microbatch_id: int):
is_dp_broadcast = self.dist.tp_size > 1 and self.enable_attention_dp
if self.dist.rank == 0 or (self.dist.is_first_pp_rank
and is_dp_broadcast):
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)
step_result = self.scheduler.schedule_step(self.active_requests,
self.inflight_req_ids)
scheduled_batch = step_result.scheduled_requests
fitting_disagg_gen_init_requests = step_result.fitting_disagg_gen_init_requests
num_fitting_reqs = step_result.num_fitting_requests
serializable_schedule = SerializableSchedulerOutput.from_scheduler_result(
scheduled_batch, fitting_disagg_gen_init_requests,
num_fitting_reqs)
@@ -1291,7 +1282,7 @@ def _executor_loop_pp(self):
if self.dist.rank != 0:
# Retry until current rank can run first PP's schedule result.
self._pp_retry_until_can_schedule(scheduled_batch)
# Run scheduler locally because scheduler may change llm requests' state.
# Replay scheduler-local request state changes only.
self.scheduler.schedule_request(self.active_requests,
self.inflight_req_ids)

@@ -1780,8 +1771,11 @@ def _prepare_and_schedule_batch(self):
# that speculation is about to happen.
self._prepare_draft_requests()

scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)
step_result = self.scheduler.schedule_step(self.active_requests,
self.inflight_req_ids)
scheduled_batch = step_result.scheduled_requests
fitting_disagg_gen_init_requests = step_result.fitting_disagg_gen_init_requests
num_fitting_reqs = step_result.num_fitting_requests

if self.drafter is not None and not self.use_spec_decode:
for request in scheduled_batch.all_requests():
@@ -2662,108 +2656,6 @@ def _add_kv_cache_events(self):
# to be transferred to main thread when user needs them.
kv_cache_manager.flush_iteration_events()

def _balance_adp_requests(self, context_requests: list[LlmRequest],
generation_requests: list[LlmRequest]):
balanced_context_requests = context_requests
num_scheduled_context_requests = len(context_requests)
num_scheduled_generation_requests = len(generation_requests)
num_scheduled_tokens = sum(
[len(req.get_tokens(0))
for req in context_requests]) + num_scheduled_generation_requests
# Note: We use tp_allgather instead of tp_cp_allgather because we want to
# balance the requests across DP ranks; not CP ranks within those DP ranks.
responses_list = self.dist.tp_allgather([
num_scheduled_context_requests, num_scheduled_generation_requests,
num_scheduled_tokens
])
all_ranks_num_scheduled_context_requests = [
response[0] for response in responses_list
]
all_ranks_num_scheduled_generation_requests = [
response[1] for response in responses_list
]
all_ranks_have_free_ctx_slots = all([
num_gen < self.max_batch_size
for num_gen in all_ranks_num_scheduled_generation_requests
])
all_ranks_have_ctx_requests = all([
num_ctx > 0 for num_ctx in all_ranks_num_scheduled_context_requests
])
all_ranks_have_gen_requests = all([
num_gen > 0
for num_gen in all_ranks_num_scheduled_generation_requests
])

if self.attention_dp_enable_balance:
# wait for all ranks have context requests
if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests:
self.adp_ctx_waiting_iters_count = 0
# balance number of context requests across ranks
if all_ranks_have_gen_requests:
if self.adp_ctx_batching_wait_iters_count < self.attention_dp_batching_wait_iters:
self.adp_ctx_batching_wait_iters_count += 1
balanced_context_requests = []
else:
self.adp_ctx_batching_wait_iters_count = 0
else:
self.adp_ctx_waiting_iters_count += 1
balanced_context_requests = []
timeout_reached = self.adp_ctx_waiting_iters_count >= self.attention_dp_time_out_iters
if timeout_reached or not all_ranks_have_gen_requests:
self.adp_ctx_waiting_iters_count = 0
balanced_context_requests = context_requests
return balanced_context_requests

def _waiting_requests(self, context_requests: list[LlmRequest],
generation_requests: list[LlmRequest]):
"""
Return an empty list if scheduled requests fulfill the waiting conditions, otherwise return the original context requests.
Waiting conditions:
- The number of scheduled tokens (both context and generation) is smaller than `self.batch_wait_max_tokens_ratio * self.max_num_tokens`
- The number of waiting iterations is smaller than `self.batch_wait_timeout_iters`.
"""

num_scheduled_ctx_tokens = sum(
len(ctx_req.get_tokens(0)) for ctx_req in context_requests)
num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens
for gen_req in generation_requests)
num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens

should_waiting = self.batch_wait_iters_count < self.batch_wait_timeout_iters and num_scheduled_tokens < self.batch_wait_max_tokens_ratio * self.max_num_tokens
if should_waiting:
self.batch_wait_iters_count += 1
return []

self.batch_wait_iters_count = 0
return context_requests

@nvtx_range("_schedule")
def _schedule(self):
scheduler_output = self.scheduler.schedule_request(
self.active_requests, self.inflight_req_ids)

scheduled_context_requests = scheduler_output.context_requests
if self.enable_attention_dp and self.attention_dp_enable_balance:
scheduled_context_requests = self._balance_adp_requests(
scheduler_output.context_requests,
scheduler_output.generation_requests)

# If no generation requests, no need to wait, to avoid dead waiting
should_check_waiting = not self.enable_attention_dp and self.enable_batch_waiting and len(
scheduler_output.context_requests) > 0 and len(
scheduler_output.generation_requests) > 0
if should_check_waiting:
scheduled_context_requests = self._waiting_requests(
scheduler_output.context_requests,
scheduler_output.generation_requests)

scheduled_requests = ScheduledRequests()
scheduled_requests.reset_context_requests(scheduled_context_requests)
scheduled_requests.generation_requests = scheduler_output.generation_requests
scheduled_requests.paused_requests = scheduler_output.paused_requests

return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests

@nvtx_range("_check_disagg_gen_transfer_status")
def _check_disagg_gen_transfer_status(self):

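In short, the executor no longer post-processes scheduler output itself: the removed `_balance_adp_requests`, `_waiting_requests`, and `_schedule` logic now sits behind the scheduler's `schedule_step` entry point, configured via `ScheduleStepConfig`. A condensed before/after sketch, with names taken from the diff and the shape of the result object inferred from its usage:

```python
# Before: the executor scheduled, then balanced/waited on the result itself.
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = \
    self._schedule()

# After: a single call; ADP balancing and batch waiting happen inside the
# scheduler, driven by ScheduleStepConfig.
step_result = self.scheduler.schedule_step(self.active_requests,
                                           self.inflight_req_ids)
scheduled_batch = step_result.scheduled_requests
fitting_disagg_gen_init_requests = step_result.fitting_disagg_gen_init_requests
num_fitting_reqs = step_result.num_fitting_requests
```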