diff --git a/src/accelerate/hooks.py b/src/accelerate/hooks.py index 6b28642b30f..3a6e3890471 100644 --- a/src/accelerate/hooks.py +++ b/src/accelerate/hooks.py @@ -342,6 +342,19 @@ def init_hook(self, module): return module + def _maybe_get_fp16_statistics(self, name, value): + # Some quantized weights keep scale statistics as separate state-dict entries rather than + # parameters or buffers. When materializing an int8 weight from `weights_map`, pass those + # statistics along so the restored parameter does not keep stale meta-device attributes. + if value is None or value.dtype != torch.int8 or "weight" not in name: + return None + + statistics_name = name.replace("weight", "SCB") + if statistics_name in self.weights_map: + return self.weights_map[statistics_name] + + return None + @_compiler_disable def pre_forward(self, module, *args, **kwargs): if self.io_same_device: @@ -355,11 +368,8 @@ def pre_forward(self, module, *args, **kwargs): recurse=self.place_submodules, remove_non_persistent=True, ): - fp16_statistics = None value = self.weights_map[name] - if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys(): - if value.dtype == torch.int8: - fp16_statistics = self.weights_map[name.replace("weight", "SCB")] + fp16_statistics = self._maybe_get_fp16_statistics(name, value) # In case we are using offloading with tied weights, we need to keep track of the offloaded weights # that are loaded on device at this point, as we will need to remove them as well from the dictionary @@ -424,7 +434,11 @@ def detach_hook(self, module): if self.offload: for name, device in self.original_devices.items(): if device != torch.device("meta"): - set_module_tensor_to_device(module, name, device, value=self.weights_map.get(name, None)) + value = self.weights_map.get(name, None) + fp16_statistics = self._maybe_get_fp16_statistics(name, value) + set_module_tensor_to_device( + module, name, device, value=value, fp16_statistics=fp16_statistics + ) return module