From 95ca0cc93ced028c7d7db52a66ac251e2b89419b Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 21 May 2026 09:38:13 +0800 Subject: [PATCH 1/2] fix hook for bnb Signed-off-by: jiqing-feng --- src/accelerate/hooks.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/accelerate/hooks.py b/src/accelerate/hooks.py index 6b28642b30f..396f6f27a57 100644 --- a/src/accelerate/hooks.py +++ b/src/accelerate/hooks.py @@ -342,6 +342,19 @@ def init_hook(self, module): return module + def _get_fp16_statistics(self, name, value): + # Some quantized weights keep scale statistics as separate state-dict entries rather than + # parameters or buffers. When materializing an int8 weight from `weights_map`, pass those + # statistics along so the restored parameter does not keep stale meta-device attributes. + if value is None or value.dtype != torch.int8 or "weight" not in name: + return None + + statistics_name = name.replace("weight", "SCB") + if statistics_name in self.weights_map: + return self.weights_map[statistics_name] + + return None + @_compiler_disable def pre_forward(self, module, *args, **kwargs): if self.io_same_device: @@ -355,11 +368,8 @@ def pre_forward(self, module, *args, **kwargs): recurse=self.place_submodules, remove_non_persistent=True, ): - fp16_statistics = None value = self.weights_map[name] - if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys(): - if value.dtype == torch.int8: - fp16_statistics = self.weights_map[name.replace("weight", "SCB")] + fp16_statistics = self._get_fp16_statistics(name, value) # In case we are using offloading with tied weights, we need to keep track of the offloaded weights # that are loaded on device at this point, as we will need to remove them as well from the dictionary @@ -424,7 +434,11 @@ def detach_hook(self, module): if self.offload: for name, device in self.original_devices.items(): if device != torch.device("meta"): - set_module_tensor_to_device(module, name, device, value=self.weights_map.get(name, None)) + value = self.weights_map.get(name, None) + fp16_statistics = self._get_fp16_statistics(name, value) + set_module_tensor_to_device( + module, name, device, value=value, fp16_statistics=fp16_statistics + ) return module From 77817bb30cb078829d84efdc32d4ae7154126d02 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 28 May 2026 09:23:58 +0800 Subject: [PATCH 2/2] fix name Signed-off-by: jiqing-feng --- src/accelerate/hooks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/accelerate/hooks.py b/src/accelerate/hooks.py index 396f6f27a57..3a6e3890471 100644 --- a/src/accelerate/hooks.py +++ b/src/accelerate/hooks.py @@ -342,7 +342,7 @@ def init_hook(self, module): return module - def _get_fp16_statistics(self, name, value): + def _maybe_get_fp16_statistics(self, name, value): # Some quantized weights keep scale statistics as separate state-dict entries rather than # parameters or buffers. When materializing an int8 weight from `weights_map`, pass those # statistics along so the restored parameter does not keep stale meta-device attributes. @@ -369,7 +369,7 @@ def pre_forward(self, module, *args, **kwargs): remove_non_persistent=True, ): value = self.weights_map[name] - fp16_statistics = self._get_fp16_statistics(name, value) + fp16_statistics = self._maybe_get_fp16_statistics(name, value) # In case we are using offloading with tied weights, we need to keep track of the offloaded weights # that are loaded on device at this point, as we will need to remove them as well from the dictionary @@ -435,7 +435,7 @@ def detach_hook(self, module): for name, device in self.original_devices.items(): if device != torch.device("meta"): value = self.weights_map.get(name, None) - fp16_statistics = self._get_fp16_statistics(name, value) + fp16_statistics = self._maybe_get_fp16_statistics(name, value) set_module_tensor_to_device( module, name, device, value=value, fp16_statistics=fp16_statistics )