From 1a4baaf89272f805db7a4a1bb50f6bfcb6e130ed Mon Sep 17 00:00:00 2001 From: shuyix Date: Thu, 24 Jul 2025 02:29:11 -0700 Subject: [PATCH 1/2] enable partial load --- tensorrt_llm/_torch/models/modeling_utils.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py index a8ce31bf2cef..17dc332eb2f9 100755 --- a/tensorrt_llm/_torch/models/modeling_utils.py +++ b/tensorrt_llm/_torch/models/modeling_utils.py @@ -724,6 +724,9 @@ def load_single_module(name, module): for new_name in params_map[names[-1]]: fw = filter_weights('.'.join(names[:-1] + [new_name]), weights) + # tmp fixes to enable partial updates in old path + if not fw: + continue if new_name in ['k_proj', 'v_proj']: num_kv_heads_list = [num_kv_heads ] * len(fw) if isinstance( @@ -740,15 +743,18 @@ def load_single_module(name, module): } module_weights.append(fw) - module.load_weights(weights=module_weights) + if module_weights: + module.load_weights(weights=module_weights) + else: module_weights = filter_weights(name, weights) - if hasattr(module, 'load_weights'): - module.load_weights(weights=[module_weights]) - else: - for n, p in module._parameters.items(): - if p is not None: - p.data.copy_(module_weights[n][:]) + if module_weights: + if hasattr(module, 'load_weights'): + module.load_weights(weights=[module_weights]) + else: + for n, p in module._parameters.items(): + if p is not None: + p.data.copy_(module_weights[n][:]) if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL", False) in ["True", "true", "1", "yes", "y"]: From 98359938c091be23561cb6c2cb55f1bb5be1aa8c Mon Sep 17 00:00:00 2001 From: shuyix Date: Thu, 24 Jul 2025 02:30:04 -0700 Subject: [PATCH 2/2] align interfaces with ray branch --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 57 +++++++++++++------ 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py 
b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 30e2a8a16c3d..0d7382c9377d 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1084,6 +1084,45 @@ def _prepare_draft_requests(self): logger.error(f"Encountered an error in decode: {error_msg}") self._handle_errors(error_msg) + def update_weights(self, weights): + # Load weights into the model + self.model_engine.model.load_weights(weights) + torch.cuda.synchronize() + + # TODO: reset prefix cache + + def update_weight_from_ipc_handles(self, handles): + """ + Update model weights from IPC handles. + + Args: + handles (dict): Dictionary mapping device UUIDs to parameter IPC handles. + {device_uuid: all_handles} + + Returns: + bool: True if weights were updated successfully, False otherwise. + """ + from tensorrt_llm._torch.utils import get_device_uuid + device_uuid = get_device_uuid(self.device_id) + + if device_uuid not in handles: + raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles") + + try: + weights = {} + all_handles = handles[device_uuid] + + for param_name, tensor_handle in all_handles: + func, args = tensor_handle + list_args = list(args) + list_args[6] = self.device_id # Set target device + tensor = func(*list_args) + weights[param_name] = tensor + + self.update_weights(weights) + + return True + except Exception as e: + logger.error(f"failed to update weights from ipc handles: {e}") + return False + def _sleep(self, sleep_request): self.is_sleep_request = False self._enqueue_responses({sleep_request.id: LlmResponse(request_id=sleep_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=sleep_request.id)}) @@ -1096,24 +1136,7 @@ def _update_weight(self, update_weight_request): self.is_update_weight_request = False try: - # Get handles for this device - device_uuid = get_device_uuid(self.device_id) - handles = update_weight_request.weight_ipc_handles[device_uuid] - weights = {} - - # Process each handle to get the tensor - for name, handle in handles: - func,
args = handle - list_args = list(args) - # Update device ID to match the current device - list_args[6] = self.device_id - tensor = func(*list_args) - weights[name] = tensor - - # Load weights into the model - self.model_engine.model.load_weights(weights) - - torch.cuda.synchronize() + self.update_weight_from_ipc_handles(update_weight_request.weight_ipc_handles) update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=update_weight_request.id) self._enqueue_responses({update_weight_request.id: update_weight_response}) except Exception as e: