Fix int8 offload hook detachment statistics restoration #4044

Open
jiqing-feng wants to merge 1 commit into huggingface:main from jiqing-feng:offload

jiqing-feng commented May 21, 2026

Summary

Fix AlignDevicesHook.detach_hook() so removing offload hooks restores bitsandbytes int8 weights together with their fp16 quantization statistics.

Sequential CPU offload stores module weights in weights_map and moves parameters to the meta device. For bnb int8 modules, the SCB statistics are saved as extra state-dict entries instead of regular parameters or buffers. pre_forward() already passed those statistics back into set_module_tensor_to_device(), but detach_hook() did not. As a result, removing hooks could restore the int8 weight while leaving weight.SCB on the meta device, causing the next offload pass to fail with:

NotImplementedError: Cannot copy out of meta tensor; no data!
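
The failure mode can be modeled without accelerate or bitsandbytes. The sketch below is a pure-Python toy, not the library code: `ToyTensor`, `ToyInt8Linear`, `offload`, `detach_weight_only`, `detach_with_statistics`, and `meta_entries` are hypothetical names invented for illustration, and devices are plain strings. It only shows why restoring the weight without its SCB leaves one entry on the meta device and breaks the next offload pass.

```python
from dataclasses import dataclass


@dataclass
class ToyTensor:
    """Hypothetical stand-in for a tensor: a device tag plus optional data."""
    device: str
    data: object = None

    def to(self, device):
        # Mirrors the real behavior: a meta tensor has no data to copy.
        if self.device == "meta":
            raise NotImplementedError("Cannot copy out of meta tensor; no data!")
        return ToyTensor(device, self.data)


class ToyInt8Linear:
    """Hypothetical int8 module: a weight plus its SCB statistics."""

    def __init__(self):
        self.weight = ToyTensor("cpu", [1, -2, 3])
        self.scb = ToyTensor("cpu", [0.5])

    def state(self):
        return {"weight": self.weight, "SCB": self.scb}


def offload(module):
    # Copy the real values into a weights map, then leave meta placeholders behind.
    weights_map = {name: t.to("cpu") for name, t in module.state().items()}
    module.weight = ToyTensor("meta")
    module.scb = ToyTensor("meta")
    return weights_map


def detach_weight_only(module, weights_map):
    # Models the bug described above: only the weight is restored.
    module.weight = weights_map["weight"]


def detach_with_statistics(module, weights_map):
    # Models the fix described above: weight and SCB are restored together.
    module.weight = weights_map["weight"]
    module.scb = weights_map["SCB"]


def meta_entries(module):
    # Names of state entries that are still meta placeholders.
    return [name for name, t in module.state().items() if t.device == "meta"]


buggy = ToyInt8Linear()
detach_weight_only(buggy, offload(buggy))
print(meta_entries(buggy))  # ['SCB']

fixed = ToyInt8Linear()
detach_with_statistics(fixed, offload(fixed))
print(meta_entries(fixed))  # []
```

In this toy, calling `offload(buggy)` a second time raises the same `NotImplementedError`, because the leftover SCB placeholder has no data to copy, while `offload(fixed)` succeeds.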

This change centralizes the lookup of int8 fp16 statistics in AlignDevicesHook._get_fp16_statistics() and uses it from both pre_forward() and detach_hook(). The existing forward path keeps the same behavior, while hook removal now materializes int8 weights with their matching statistics.
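
A shared lookup of this kind might look like the following. This is a sketch, not the code from this PR: the free function `get_fp16_statistics`, its signature, and the `ToyTensor` type are assumptions made for illustration, as is the key layout in which the statistics for `<prefix>.weight` are stored under `<prefix>.SCB` in the weights map.

```python
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass
class ToyTensor:
    """Hypothetical stand-in for a tensor; only the dtype matters here."""
    dtype: str


def get_fp16_statistics(weights_map: Mapping[str, ToyTensor], name: str) -> Optional[ToyTensor]:
    """Return the SCB entry that belongs to an int8 weight, or None otherwise."""
    if "weight" not in name:
        return None
    # Assumed key layout: "proj.weight" -> "proj.SCB".
    scb_name = name.replace("weight", "SCB")
    if scb_name not in weights_map:
        return None
    # Only int8 weights carry quantization statistics.
    if weights_map[name].dtype != "int8":
        return None
    return weights_map[scb_name]


weights_map = {
    "proj.weight": ToyTensor("int8"),
    "proj.SCB": ToyTensor("float16"),
    "norm.weight": ToyTensor("float16"),
}

print(get_fp16_statistics(weights_map, "proj.weight"))  # ToyTensor(dtype='float16')
print(get_fp16_statistics(weights_map, "norm.weight"))  # None
```

With one lookup shared by both call sites, the forward path and hook removal would pass the same value as `fp16_statistics` to `set_module_tensor_to_device()`, so the two paths cannot drift apart.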

Reproduction

This is a minimal reproduction of the failing sequence: enable sequential CPU offload, remove hooks recursively, then enable sequential CPU offload again.

import torch
from transformers import T5EncoderModel
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from accelerate.hooks import remove_hook_from_module

model_id = 'hf-internal-testing/flux.1-dev-int8-pkg'

t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder='text_encoder_2')
transformer_8bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder='transformer')
pipe = DiffusionPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    text_encoder_2=t5_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
)

pipe.enable_sequential_cpu_offload()

for _, component in pipe.components.items():
    if isinstance(component, torch.nn.Module) and hasattr(component, '_hf_hook'):
        remove_hook_from_module(component, recurse=True)

print('meta_after_remove')
for name, component in pipe.components.items():
    if isinstance(component, torch.nn.Module):
        meta = [n for n, p in component.state_dict().items() if p.device == torch.device('meta')]
        print(name, len(meta), meta[:3])

pipe.enable_sequential_cpu_offload()
print('reoffload_ok')

Expected output:

meta_after_remove
vae 0 []
text_encoder 0 []
text_encoder_2 0 []
transformer 0 []
reoffload_ok

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng

Hi @SunMarc. Would you please review this PR? Thanks!
