From 8fb099a7db8e7d85e8c234a10ffaecfa9dfc41d8 Mon Sep 17 00:00:00 2001 From: frozenleaves <914814442@qq.com> Date: Wed, 29 Apr 2026 15:01:59 +0800 Subject: [PATCH] fix train with fsdp2: AttributeError: 'Tensor' object has no attribute 'device_mesh' --- src/accelerate/utils/fsdp_utils.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/accelerate/utils/fsdp_utils.py b/src/accelerate/utils/fsdp_utils.py index 464718e1a29..db070b72ce1 100644 --- a/src/accelerate/utils/fsdp_utils.py +++ b/src/accelerate/utils/fsdp_utils.py @@ -518,6 +518,15 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype): f"Full state dict has {len(full_sd)} keys, sharded has {len(meta_sharded_sd)} keys." ) full_param = full_sd[param_name] + # Persistent buffers are not sharded by fully_shard and remain plain Tensors. + # Broadcast them and keep as-is instead of trying to access DTensor attributes. + if not isinstance(sharded_param, DTensor): + full_param = full_param.detach().to(accelerator.device) + dist.broadcast(full_param, src=0, group=dist.group.WORLD) + if cpu_offload: + full_param = full_param.to("cpu") + sharded_sd[param_name] = full_param + continue device_mesh = sharded_param.device_mesh full_param = full_param.detach().to(device_mesh.device_type) if isinstance(full_param, DTensor): @@ -540,6 +549,15 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype): # We need this else to have a matching `broadcast` for all of the ranks, else we deadlock else: for param_name, sharded_param in meta_sharded_sd.items(): + # Persistent buffers are not sharded by fully_shard and remain plain Tensors. + # Broadcast them and keep as-is instead of trying to access DTensor attributes. + if not isinstance(sharded_param, DTensor): + full_tensor = torch.empty(sharded_param.size(), device=accelerator.device, dtype=sharded_param.dtype) + dist.broadcast(full_tensor, src=0, group=dist.group.WORLD) + if cpu_offload: + full_tensor = full_tensor.to("cpu") + sharded_sd[param_name] = full_tensor + continue device_mesh = sharded_param.device_mesh full_tensor = torch.empty(sharded_param.size(), device=device_mesh.device_type, dtype=sharded_param.dtype) dist.broadcast(full_tensor, src=0, group=dist.group.WORLD) @@ -916,3 +934,4 @@ def get_parameters_from_modules( for module in modules: parameters.extend(list(module.parameters())) return set(parameters) +