Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions tensorrt_llm/_torch/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,9 @@ def load_single_module(name, module):
for new_name in params_map[names[-1]]:
fw = filter_weights('.'.join(names[:-1] + [new_name]),
weights)
# tmp fixes to enable partial updates in old path
if not fw:
continue
if new_name in ['k_proj', 'v_proj']:
num_kv_heads_list = [num_kv_heads
] * len(fw) if isinstance(
Expand All @@ -740,15 +743,18 @@ def load_single_module(name, module):
}

module_weights.append(fw)
module.load_weights(weights=module_weights)
if module_weights:
module.load_weights(weights=module_weights)

else:
module_weights = filter_weights(name, weights)
if hasattr(module, 'load_weights'):
module.load_weights(weights=[module_weights])
else:
for n, p in module._parameters.items():
if p is not None:
p.data.copy_(module_weights[n][:])
if module_weights:
if hasattr(module, 'load_weights'):
module.load_weights(weights=[module_weights])
else:
for n, p in module._parameters.items():
if p is not None:
p.data.copy_(module_weights[n][:])

if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
False) in ["True", "true", "1", "yes", "y"]:
Expand Down
57 changes: 39 additions & 18 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,44 @@ def _prepare_draft_requests(self):
logger.error(f"Encountered an error in decode: {error_msg}")
self._handle_errors(error_msg)

def update_weights(self, weights):
    """Load the given weight dict into the model and wait for the device.

    Args:
        weights: Mapping of parameter names to tensors to load into
            ``self.model_engine.model``.
    """
    model = self.model_engine.model
    model.load_weights(weights)
    # Block until the device has finished applying the new weights.
    torch.cuda.synchronize()

    # TODO: reset prefix cache

def update_weight_from_ipc_handles(self, handles):
    """
    Update model weights from CUDA IPC handles.

    Args:
        handles (dict): Dictionary mapping device UUIDs to parameter IPC
            handles, i.e. {device_uuid: [(param_name, tensor_handle), ...]}.

    Returns:
        bool: True if the weights were updated successfully, False otherwise.

    Raises:
        ValueError: If this executor's device UUID is not present in `handles`.
    """
    from tensorrt_llm._torch.utils import get_device_uuid
    device_uuid = get_device_uuid(self.device_id)

    if device_uuid not in handles:
        raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles")

    try:
        weights = {}
        all_handles = handles[device_uuid]

        for param_name, tensor_handle in all_handles:
            # Each handle is a (rebuild_fn, args) pair produced by CUDA IPC
            # serialization; presumably args[6] is the device index the
            # tensor is rebuilt on — TODO confirm against the producer side.
            func, args = tensor_handle
            list_args = list(args)
            list_args[6] = self.device_id  # Retarget rebuild to this device
            weights[param_name] = func(*list_args)

        self.update_weights(weights)
        # Bug fix: the original fell off the end (returning None) on success
        # while returning False on failure; report success explicitly.
        return True

    except Exception as e:
        logger.error(f"failed to update weights from ipc handles: {e}")
        return False

def _sleep(self, sleep_request):
self.is_sleep_request = False
self._enqueue_responses({sleep_request.id: LlmResponse(request_id=sleep_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=sleep_request.id)})
Expand All @@ -1096,24 +1134,7 @@ def _update_weight(self, update_weight_request):
self.is_update_weight_request = False

try:
# Get handles for this device
device_uuid = get_device_uuid(self.device_id)
handles = update_weight_request.weight_ipc_handles[device_uuid]
weights = {}

# Process each handle to get the tensor
for name, handle in handles:
func, args = handle
list_args = list(args)
# Update device ID to match the current device
list_args[6] = self.device_id
tensor = func(*list_args)
weights[name] = tensor

# Load weights into the model
self.model_engine.model.load_weights(weights)

torch.cuda.synchronize()
self.update_weight_from_ipc_handles(update_weight_request.weight_ipc_handles)
update_weight_response = LlmResponse(request_id=update_weight_request.id, result=LlmResult(result=None, py_result=PyResult(0, 0, success=True), is_final=True), client_id=update_weight_request.id)
self._enqueue_responses({update_weight_request.id: update_weight_response})
except Exception as e:
Expand Down
Loading