From 725e2062d033a5c619631af8895fc477c7f9e099 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Fri, 25 Jul 2025 17:16:24 +0800 Subject: [PATCH 1/7] add _prepare_and_schedule_batch function in PyExecutor Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 131 +++++++----------- 1 file changed, 50 insertions(+), 81 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 715a70139856..30e6c2dd0ad1 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -800,6 +800,50 @@ def _executor_loop_pp(self): self.active_requests, previous_batch) + def _prepare_and_schedule_batch(self): + new_requests = self._fetch_new_requests() + if self.should_stop_processing: + return None, None + + if self.kv_cache_transceiver: + self._check_disagg_gen_transfer_status() + + iter_stats = None + if self.enable_iter_perf_stats: + iter_stats = self._get_init_iter_stats( + len(new_requests), + self.executor_request_queue. + get_new_active_requests_queue_latency()) + + self._pad_attention_dp_dummy_request() + + if self.drafter is not None: + self._prepare_draft_requests(self.active_requests) + + scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( + ) + + if self.kv_cache_transceiver: + # For requests that are fitting disagg gen init, also prepare resources for KV cache manager + self._prepare_disagg_gen_init(fitting_disagg_gen_init_requests) + + if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests: + logger.warning( + "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache" + ) + self.kv_cache_transceiver.check_context_transfer_status(1) + else: + assert scheduled_batch.batch_size > 0, ( + "fail to schedule any pending request, " + "probably run out of resource.") + + self.num_scheduled_requests = scheduled_batch.batch_size + logger.debug( + f'has {len(self.active_requests)} active_request, ' + f'scheduled {len(scheduled_batch.context_requests)} context requests and ' + f'{len(scheduled_batch.generation_requests)} generation requests') + return scheduled_batch, iter_stats + def _executor_loop(self): torch.cuda.set_device(self.device_id) with self._profiler() as profile_step: @@ -810,48 +854,10 @@ def _executor_loop(self): profile_step() if self.enable_iter_perf_stats: iter_start_time = time.time() - new_requests = self._fetch_new_requests() - if self.should_stop_processing: - break - - if self.kv_cache_transceiver: - self._check_disagg_gen_transfer_status() - - if self.enable_iter_perf_stats: - iter_stats = self._get_init_iter_stats( - len(new_requests), - self.executor_request_queue. - get_new_active_requests_queue_latency()) - - self._pad_attention_dp_dummy_request() - - if self.drafter is not None: - self._prepare_draft_requests(self.active_requests) - - scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( - ) - - if self.kv_cache_transceiver: - # For requests that are fitting disagg gen init, also prepare resources for KV cache manager - self._prepare_disagg_gen_init( - fitting_disagg_gen_init_requests) - if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests: - logger.warning( - "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache" - ) - self.kv_cache_transceiver.check_context_transfer_status( - 1) - else: - assert scheduled_batch.batch_size > 0, ( - "fail to schedule any pending request, " - "probably run out of resource.") - self.num_scheduled_requests = scheduled_batch.batch_size - logger.debug( - f'has {len(self.active_requests)} active_request, ' - f'scheduled {len(scheduled_batch.context_requests)} context requests and ' - f'{len(scheduled_batch.generation_requests)} generation requests' - ) + scheduled_batch, iter_stats = self._prepare_and_schedule_batch() + if scheduled_batch is None: + break self._pause_requests(scheduled_batch.paused_requests) @@ -954,47 +960,10 @@ def _executor_loop_overlap(self): profile_step() if self.enable_iter_perf_stats: iter_start_time = time.time() - new_requests = self._fetch_new_requests() - if self.should_stop_processing: - break - - if self.kv_cache_transceiver: - self._check_disagg_gen_transfer_status() - - if self.enable_iter_perf_stats: - iter_stats = self._get_init_iter_stats( - len(new_requests), - self.executor_request_queue. - get_new_active_requests_queue_latency()) - self._pad_attention_dp_dummy_request() - - scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( - ) - - if self.kv_cache_transceiver: - - # For requests that are fitting disagg gen init, also prepare resources for KV cache manager - self._prepare_disagg_gen_init( - fitting_disagg_gen_init_requests) - - if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests: - logger.warning( - "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache" - ) - self.kv_cache_transceiver.check_context_transfer_status( - 1) - else: - assert scheduled_batch.batch_size > 0, ( - "fail to schedule any pending request, " - "probably run out of resource.") - - self.num_scheduled_requests = scheduled_batch.batch_size - logger.debug( - f'has {len(self.active_requests)} active_request, ' - f'scheduled {len(scheduled_batch.context_requests)} context requests and ' - f'{len(scheduled_batch.generation_requests)} generation requests' - ) + scheduled_batch, iter_stats = self._prepare_and_schedule_batch() + if scheduled_batch is None: + break self._pause_requests(scheduled_batch.paused_requests) From 96d004d80073bf58d7ddf3af4ae546d55d361314 Mon Sep 17 00:00:00 2001 From: Liana Koleva <43767763+lianakoleva@users.noreply.github.com> Date: Sat, 26 Jul 2025 08:27:10 -0700 Subject: [PATCH 2/7] doc: fix invalid link in llama 4 example documentation (#6340) Signed-off-by: Liana Koleva <43767763+lianakoleva@users.noreply.github.com> --- examples/models/core/llama4/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/models/core/llama4/README.md b/examples/models/core/llama4/README.md index 7e1644d5d94b..ff4fe4b69ff5 100644 --- a/examples/models/core/llama4/README.md +++ b/examples/models/core/llama4/README.md @@ -134,7 +134,7 @@ python -m tensorrt_llm.serve.scripts.benchmark_serving \ - `max_batch_size` and `max_num_tokens` can easily affect the performance. The default values for them are already carefully designed and should deliver good performance on overall cases, however, you may still need to tune it for peak performance. - `max_batch_size` should not be too low to bottleneck the throughput. Note with Attention DP, the the whole system's max_batch_size will be `max_batch_size*dp_size`. - CUDA grah `max_batch_size` should be same value as TensorRT-LLM server's `max_batch_size`. -- For more details on `max_batch_size` and `max_num_tokens`, refer to [Tuning Max Batch Size and Max Num Tokens](../performance/performance-tuning-guide/tuning-max-batch-size-and-max-num-tokens.md). +- For more details on `max_batch_size` and `max_num_tokens`, refer to [Tuning Max Batch Size and Max Num Tokens](../../../../docs/source/performance/performance-tuning-guide/tuning-max-batch-size-and-max-num-tokens.md). ### Troubleshooting From d853811190378001a12a933bb7124ea1ee574607 Mon Sep 17 00:00:00 2001 From: Ziyi Xiong <219238287+ziyixiong-nv@users.noreply.github.com> Date: Sun, 27 Jul 2025 08:32:39 +0800 Subject: [PATCH 3/7] [https://nvbugs/5402719][fix]: Add cuda graph dummy requests to the spec_resource_manager (#6258) Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com> --- .../_torch/pyexecutor/model_engine.py | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 2875f19b5b4f..2ba4cafeda35 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -810,8 +810,11 @@ def _set_up_spec_metadata( is_draft_model=self.is_draft_model) return self.spec_metadata - def _get_padded_batch(self, scheduled_requests: ScheduledRequests, - kv_cache_manager) -> int: + def _get_padded_batch( + self, + scheduled_requests: ScheduledRequests, + kv_cache_manager, + spec_resource_manager: Optional[BaseResourceManager] = None) -> int: can_run_cuda_graph = scheduled_requests.can_run_cuda_graph batch_size = scheduled_requests.batch_size # The number of sequences in the batch is the number of prompts times the beam width. @@ -847,13 +850,17 @@ def _get_padded_batch(self, scheduled_requests: ScheduledRequests, if available_blocks < 1: return 0 + cuda_graph_dummy_request_ids = [MAX_UINT64 - 1] self.cuda_graph_dummy_request = kv_cache_manager.add_dummy_requests( - [MAX_UINT64 - 1], + cuda_graph_dummy_request_ids, is_gen=True, max_num_draft_tokens=self.max_draft_len, use_mrope=self.use_mrope, max_beam_width=self.max_beam_width)[0] self.cuda_graph_dummy_request.is_cuda_graph_dummy = True + if spec_resource_manager is not None: + spec_resource_manager.add_dummy_requests( + request_ids=cuda_graph_dummy_request_ids) scheduled_requests.generation_requests.extend( [self.cuda_graph_dummy_request] * padding_size) @@ -861,8 +868,11 @@ def _get_padded_batch(self, scheduled_requests: ScheduledRequests, return padding_size @contextlib.contextmanager - def _maybe_pad_batch(self, scheduled_requests: ScheduledRequests, - kv_cache_manager): + def _maybe_pad_batch( + self, + scheduled_requests: ScheduledRequests, + kv_cache_manager, + spec_resource_manager: Optional[BaseResourceManager] = None): """ CUDA graphs can only be used for specific batch sizes. @@ -871,7 +881,8 @@ def _maybe_pad_batch(self, scheduled_requests: ScheduledRequests, because the padded requests will be removed from scheduled requests. """ padding_size = self._get_padded_batch(scheduled_requests, - kv_cache_manager) + kv_cache_manager, + spec_resource_manager) try: yield scheduled_requests finally: @@ -2072,6 +2083,7 @@ def forward( spec_metadata.is_spec_dec_dynamic_tree, spec_metadata.max_draft_len) else: + spec_resource_manager = None spec_metadata = None moe_load_balancer = None @@ -2090,8 +2102,8 @@ def forward( with MoeLoadBalancerIterContext(moe_load_balancer): return self._forward_step(inputs, gather_ids, gather_context_logits) - with self._maybe_pad_batch(scheduled_requests, - kv_cache_manager) as scheduled_requests: + with self._maybe_pad_batch(scheduled_requests, kv_cache_manager, + spec_resource_manager) as scheduled_requests: maybe_graph = self._maybe_get_cuda_graph( scheduled_requests, spec_config=self.spec_config) if maybe_graph is not None: From 908f49a4adc533deaad970d64626d4ff9b2839f8 Mon Sep 17 00:00:00 2001 From: Yan Chunwei <328693+Superjomn@users.noreply.github.com> Date: Mon, 28 Jul 2025 09:01:10 +0800 Subject: [PATCH 4/7] [nvbug/5320234] fix: test_trtllm_bench_llmapi_launch (#6359) Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> --- tests/integration/defs/test_e2e.py | 2 +- tests/integration/test_lists/waives.txt | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index dfb0a1a0d1f9..82d828961b1f 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -551,7 +551,7 @@ def run_bench(self): if self.use_pytorch_backend: benchmark_cmd += " --backend pytorch" else: - benchmark_cmd += " --backend trt" + benchmark_cmd += " --backend tensorrt" if self.extra_llm_api_options: benchmark_cmd += f" --extra_llm_api_options {self.extra_llm_api_options}" diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index f6a876ad01fd..224f56edbc68 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -396,7 +396,6 @@ examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoder] SKIP (http examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoderplus] SKIP (https://nvbugs/5355128) examples/test_multimodal.py::test_llm_multimodal_general[fuyu-8b-pp:1-tp:1-float16-bs:8-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5360086) examples/test_multimodal.py::test_llm_multimodal_general[llava-1.5-7b-hf-pp:1-tp:1-float16-bs:8-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5360086) -test_e2e.py::test_trtllm_bench_llmapi_launch[trt_backend-llama-v3-llama3-8b] SKIP (https://nvbugs/5320234) examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5374145) examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5373451) examples/test_multimodal.py::test_llm_multimodal_general[llava-1.5-7b-hf-pp:1-tp:1-float16-bs:1-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5360086) From 7503c0382b75658d17324f7acd4b977918fdc9cb Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Mon, 28 Jul 2025 09:07:06 +0800 Subject: [PATCH 5/7] fix ci Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 30e6c2dd0ad1..2ccaf3ae493f 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -832,10 +832,10 @@ def _prepare_and_schedule_batch(self): "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache" ) self.kv_cache_transceiver.check_context_transfer_status(1) - else: - assert scheduled_batch.batch_size > 0, ( - "fail to schedule any pending request, " - "probably run out of resource.") + else: + assert scheduled_batch.batch_size > 0, ( + "fail to schedule any pending request, " + "probably run out of resource.") self.num_scheduled_requests = scheduled_batch.batch_size logger.debug( From 2dd3186727adde77011a455780073372baf306e4 Mon Sep 17 00:00:00 2001 From: YueWeng <25103990+yweng0828@users.noreply.github.com> Date: Mon, 28 Jul 2025 09:18:41 +0800 Subject: [PATCH 6/7] fix: remove cudaStreamSynchronize when using relaxed acceptance (#5262) Signed-off-by: Yue Weng <25103990+yweng0828@users.noreply.github.com> --- tensorrt_llm/_torch/speculative/mtp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 3c783e1443f1..83eaf5458b50 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -67,7 +67,8 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): if req.is_first_context_chunk: slot_id = self.slot_manager.add_slot(req.request_id) if self.use_relaxed_acceptance_for_thinking: - self.mtp_relaxed_delta_pool[slot_id] = 0. + self.mtp_relaxed_delta_pool[slot_id].copy_( + 0, non_blocking=True) def update_resources(self, scheduled_batch: ScheduledRequests): pass @@ -75,7 +76,8 @@ def update_resources(self, scheduled_batch: ScheduledRequests): def free_resources(self, request: LlmRequest): free_slot_id = self.slot_manager.get_slot(request.request_id) if self.use_relaxed_acceptance_for_thinking: - self.mtp_relaxed_delta_pool[free_slot_id] = 0. + self.mtp_relaxed_delta_pool[free_slot_id].copy_(0, + non_blocking=True) self.slot_manager.remove_slot(request.request_id) def add_dummy_requests(self, request_ids: List[int]): From 93a0fd0a23b5881f2d7f4da765836f8b08049fa7 Mon Sep 17 00:00:00 2001 From: Yukun He <23156053+hyukn@users.noreply.github.com> Date: Mon, 28 Jul 2025 09:36:26 +0800 Subject: [PATCH 7/7] [TRTLLM-6445] feat: Enable AllReduce-associated fusion patterns in Llama3/4. (#6205) Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com> --- .../allReduceFusionKernels.cu | 2 +- tensorrt_llm/_torch/models/modeling_llama.py | 233 +++++++++++++++--- 2 files changed, 203 insertions(+), 32 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu index 517acff4583f..27d041618e72 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu @@ -520,7 +520,7 @@ __global__ void __launch_bounds__(1024) allreduce_fusion_kernel_oneshot_lamport( } template -__global__ void allreduce_fusion_kernel_twoshot_sync( +__global__ void __launch_bounds__(1024) allreduce_fusion_kernel_twoshot_sync( AllReduceFusionParams params, std::array begin_tokens, std::array token_num_per_ranks) { IndexHelper index_helper(params); diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 33dddfc784c4..4af9762d1808 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -1,4 +1,5 @@ import copy +import os from typing import Dict, List, Optional, Tuple, Union import torch @@ -337,7 +338,7 @@ def forward( assert shared_output.size() == routed_output.size( ), f'unmatched tensor shape' final_hidden_states = shared_output + routed_output - if not self.enable_attention_dp and self.mapping.tp_size > 1: + if not self.enable_attention_dp and self.mapping.has_tp(): final_hidden_states = self.all_reduce( final_hidden_states, all_reduce_params=final_all_reduce_params) @@ -367,9 +368,6 @@ def __init__( self.fusion_config = EagerFusionConfig() # self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp( # ) - # TODO: re-enable these fusions - self.fusion_config.PRE_MOE_FUSION = False - self.fusion_config.POST_MLP_FUSION = False nope_layer = config.no_rope_layers[layer_idx] == 0 attention_chunk_size = getattr(config, "attention_chunk_size", @@ -387,6 +385,26 @@ def __init__( self.is_mlp_layer = (layer_idx + 1) % config.interleave_moe_layer_step != 0 + self.enable_fusion = os.environ.get( + "TRTLLM_LLAMA_EAGER_FUSION_DISABLED", "0") == "0" + + # MLP layer supports pre and post AR + Res + RMSNorm + NVFP4/FP8 + # MOE layer supports pre AR + Res + RMSNorm + # MOE layer supports post AR + Res + RMSNorm + QUANT + NVFP4/FP8 + self.pre_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM + self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM + + # # Determine the pre and post feed forward fusion op based on the quant mode + if self.is_nvfp4: + self.pre_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4 + self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4 + elif self.is_fp8_quant: + self.pre_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8 + self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8 + + if not self.is_mlp_layer: + self.pre_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM + if self.is_mlp_layer: self.feed_forward = GatedMLP( hidden_size=config.hidden_size, @@ -399,8 +417,10 @@ def __init__( layer_idx=layer_idx, ) - # self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp( - # ) + self.fusion_config.PRE_MLP_FUSION = model_config.mapping.has_tp( + ) and not self.enable_attention_dp and self.enable_fusion + self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp( + ) and not self.enable_attention_dp and self.enable_fusion else: self.feed_forward = Llama4MoE( num_experts=config.num_local_experts, @@ -413,8 +433,10 @@ def __init__( dtype=config.torch_dtype, layer_idx=layer_idx) - # self.fusion_config.POST_MOE_FUSION = model_config.mapping.has_tp( - # ) + self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp( + ) and not self.enable_attention_dp and self.enable_fusion + self.fusion_config.POST_MOE_FUSION = model_config.mapping.has_tp( + ) and not self.enable_attention_dp and self.enable_fusion self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, @@ -432,6 +454,15 @@ def __init__( self.moe_allreduce = MoEAllReduce(self.mapping) + self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION + or self.fusion_config.PRE_MLP_FUSION + or self.mapping.tp_size == 1 + or self.enable_attention_dp) + self.disable_feed_forward_allreduce = ( + self.fusion_config.POST_MOE_FUSION + or self.fusion_config.POST_MLP_FUSION or self.mapping.tp_size == 1 + or self.enable_attention_dp) + def forward( self, position_ids: torch.IntTensor, @@ -461,34 +492,48 @@ def forward( position_ids=position_ids, hidden_states=hidden_states, attn_metadata=attn_metadata, - all_reduce_params=AllReduceParams(enable_allreduce=not ( - self.fusion_config.PRE_MOE_FUSION or self.mapping.tp_size == 1 - or self.enable_attention_dp)), + all_reduce_params=AllReduceParams( + enable_allreduce=not self.disable_attn_allreduce), **kwargs, ) - if self.fusion_config.PRE_MOE_FUSION: - hidden_states, residual = self.all_reduce( + if self.fusion_config.PRE_MLP_FUSION or self.fusion_config.PRE_MOE_FUSION: + if self.is_mlp_layer and (self.is_nvfp4 or self.is_fp8_quant): + scale = self.feed_forward.gate_up_proj.input_scale + else: + scale = None + allreduce_output = self.all_reduce( hidden_states, all_reduce_params=AllReduceParams( - fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + fusion_op=self.pre_feed_forward_fusion_op, residual=residual, norm_weight=self.post_attention_layernorm.weight, + scale=scale, eps=self.post_attention_layernorm.variance_epsilon, )) + + if self.is_mlp_layer and self.is_nvfp4: + act_fp4, act_sf, residual = allreduce_output + hidden_states = Fp4QuantizedTensor(act_fp4, act_sf) + else: + hidden_states, residual = allreduce_output else: - # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) + # disable fusion for layers captured by spec_metadata + if spec_metadata is not None: + if spec_metadata.is_layer_capture(self.layer_idx): + self.fusion_config.POST_MLP_FUSION = False + self.fusion_config.POST_MOE_FUSION = False + self.disable_feed_forward_allreduce = self.mapping.tp_size == 1 or self.enable_attention_dp + hidden_states = self.feed_forward( hidden_states, all_rank_num_tokens=attn_metadata.all_rank_num_tokens, all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens, - final_all_reduce_params=AllReduceParams(enable_allreduce=not ( - self.fusion_config.POST_MOE_FUSION - or self.fusion_config.POST_MLP_FUSION - or self.mapping.tp_size == 1 or self.enable_attention_dp)), + final_all_reduce_params=AllReduceParams( + enable_allreduce=not self.disable_feed_forward_allreduce), cutlass_min_latency_mode=cutlass_min_latency_mode, ) @@ -503,13 +548,23 @@ def forward( if (self.fusion_config.POST_MOE_FUSION or self.fusion_config.POST_MLP_FUSION ) and self.next_layer_layernorm is not None: + # Get the scale for the next allreduce fusion op + if self.next_attn is not None and (self.is_nvfp4 + or self.is_fp8_quant): + scale = self.next_attn.qkv_proj.input_scale + else: + # Add just the fusion op to RESIDUAL_RMS_NORM due to this is the last decoder layer + self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM + scale = None + + # TODO: MIN_LATENCY_MODE is hardcoded to False if cutlass_min_latency_mode: shared_output = hidden_states[0] hidden_states_activated_experts = hidden_states[1] num_activated_experts_per_node = hidden_states[2] experts_to_token_score = hidden_states[3] - hidden_states, residual = self.moe_allreduce( + allreduce_output = self.moe_allreduce( residual, self.next_layer_layernorm.weight, device_num_experts=num_activated_experts_per_node, @@ -519,14 +574,22 @@ def forward( eps=self.next_layer_layernorm.variance_epsilon, ) else: - hidden_states, residual = self.all_reduce( + allreduce_output = self.all_reduce( hidden_states, all_reduce_params=AllReduceParams( - fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + fusion_op=self.post_feed_forward_fusion_op, residual=residual, norm_weight=self.next_layer_layernorm.weight, + scale=scale, eps=self.next_layer_layernorm.variance_epsilon, )) + + # Unpack the allreduce output + if self.next_attn is not None and self.is_nvfp4: + act_fp4, act_sf, residual = allreduce_output + hidden_states = Fp4QuantizedTensor(act_fp4, act_sf) + else: + hidden_states, residual = allreduce_output elif self.next_layer_layernorm: hidden_states, residual = self.next_layer_layernorm( hidden_states, residual) @@ -544,6 +607,14 @@ def __init__( super().__init__() config = model_config.pretrained_config self.layer_idx = layer_idx + self.mapping = model_config.mapping + self.enable_attention_dp = model_config.mapping.enable_attention_dp + self.is_quanted = model_config.quant_config and model_config.quant_config.quant_mode.has_any_quant( + ) + self.is_fp8_quant = self.is_quanted and model_config.quant_config.quant_mode.has_fp8_qdq( + ) + self.is_nvfp4 = self.is_quanted and model_config.quant_config.quant_mode.has_nvfp4( + ) self.self_attn = LlamaAttention( model_config, @@ -566,11 +637,42 @@ def __init__( eps=config.rms_norm_eps, dtype=config.torch_dtype) + self.all_reduce = AllReduce(mapping=model_config.mapping) + + self.next_layer_layernorm: RMSNorm = None + self.next_attn: LlamaAttention = None + self.attention_mask = PredefinedAttentionMask.CAUSAL # If the model is being used as an encoder model (prefill only) we use a full attention mask if not model_config.is_generation: self.attention_mask = PredefinedAttentionMask.FULL + self.enable_fusion = os.environ.get( + "TRTLLM_LLAMA_EAGER_FUSION_DISABLED", "0") == "0" + # Disable fusion for small models due to accuracy issues + self.enable_fusion &= config.hidden_size > 4096 + + self.PRE_MLP_FUSION = self.mapping.has_tp( + ) and not self.enable_attention_dp and self.enable_fusion + self.POST_MLP_FUSION = self.mapping.has_tp() and self.enable_fusion + + if self.is_nvfp4: + self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4 + self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4 + elif self.is_fp8_quant: + self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8 + self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8 + else: + self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM + self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM + + self.disable_attn_allreduce = (self.PRE_MLP_FUSION + or self.mapping.tp_size == 1 + or self.enable_attention_dp) + self.disable_mlp_allreduce = (self.POST_MLP_FUSION + or self.mapping.tp_size == 1 + or self.enable_attention_dp) + def forward( self, position_ids: torch.IntTensor, @@ -583,9 +685,6 @@ def forward( if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) # Self Attention hidden_states = self.self_attn( @@ -593,20 +692,81 @@ def forward( hidden_states=hidden_states, attn_metadata=attn_metadata, attention_mask=self.attention_mask, + all_reduce_params=AllReduceParams( + enable_allreduce=not self.disable_attn_allreduce), **kwargs, ) - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states, **kwargs) + if self.PRE_MLP_FUSION: + if self.is_nvfp4 or self.is_fp8_quant: + scale = self.mlp.gate_up_proj.input_scale + else: + scale = None + all_reduce_output = self.all_reduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=self.pre_mlp_fusion_op, + residual=residual, + norm_weight=self.post_attention_layernorm.weight, + scale=scale, + eps=self.post_attention_layernorm.variance_epsilon, + )) + if self.is_nvfp4: + act_fp4, act_sf, residual = all_reduce_output + hidden_states = Fp4QuantizedTensor(act_fp4, act_sf) + else: + hidden_states, residual = all_reduce_output + else: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + # disable fusion for layers captured by spec_metadata + if spec_metadata is not None: + # how to know if is_layer_capture exists, if not do not call + if hasattr(spec_metadata, + "is_layer_capture") and spec_metadata.is_layer_capture( + self.layer_idx): + self.POST_MLP_FUSION = False + self.disable_mlp_allreduce = self.mapping.tp_size == 1 or self.enable_attention_dp + + hidden_states = self.mlp( + hidden_states, + final_all_reduce_params=AllReduceParams( + enable_allreduce=not self.disable_mlp_allreduce), + **kwargs, + ) + if spec_metadata is not None: # We save the hidden states in the spec metadata here. In _prepare_draft_tokens, # PyExecutor will extract these from the model engine's spec metadata. # They will be passed to the draft model engine on the first draft iteration. # TODO: can we support multiple model outputs instead? + spec_metadata.maybe_capture_hidden_states(self.layer_idx, hidden_states, residual) + if self.POST_MLP_FUSION and self.next_attn is not None: + if self.is_nvfp4 or self.is_fp8_quant: + scale = self.next_attn.qkv_proj.input_scale + else: + scale = None + all_reduce_output = self.all_reduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=self.post_mlp_fusion_op, + residual=residual, + norm_weight=self.next_layer_layernorm.weight, + scale=scale, + eps=self.next_layer_layernorm.variance_epsilon, + )) + if self.is_nvfp4: + act_fp4, act_sf, residual = all_reduce_output + hidden_states = Fp4QuantizedTensor(act_fp4, act_sf) + else: + hidden_states, residual = all_reduce_output + elif self.next_layer_layernorm: + hidden_states, residual = self.next_layer_layernorm( + hidden_states, residual) + return hidden_states, residual @@ -729,7 +889,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]): if self.has_custom_embed_tokens: with torch.no_grad(): - if model_config.mapping.tp_size > 1: + if model_config.mapping.has_tp(): weight = split_matrix_tp( weight, model_config.mapping.tp_size, @@ -777,7 +937,6 @@ def forward( lora_params=lora_params, ) - hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -790,6 +949,18 @@ def __init__( ): super().__init__(LlamaModel(model_config), model_config) + def load_weights(self, weights: Dict): + super().load_weights(weights) + + for idx, layer in enumerate( + self.model.layers[:self.config.num_hidden_layers]): + if idx == self.config.num_hidden_layers - 1: + layer.next_layer_layernorm = self.model.norm + else: + layer.next_layer_layernorm = self.model.layers[ + idx + 1].input_layernorm + layer.next_attn = self.model.layers[idx + 1].self_attn + class Llama4InputProcessor(InputProcessor):