From 96d13a9f7ea195abed128a3b0ac65e026b7e807a Mon Sep 17 00:00:00 2001 From: Tai An Date: Fri, 24 Apr 2026 15:18:38 -0700 Subject: [PATCH] fix(modeling): include named_buffers in module split expansion for infer_auto_device_map When infer_auto_device_map splits a module into its children (because it doesn't fit on one device), it expands the module into named_parameters(recurse=False) + named_children() but omits named_buffers(recurse=False). As a result, any buffer registered directly on a module that gets split is never added to modules_to_treat (or to modules_to_search in fallback_allocate) and consequently never receives a device assignment. check_device_map then raises: ValueError: The device_map provided does not give any device for the following parameters: model.language_model.layers.8.layer_scalar Gemma-4 (google/gemma-4-E4B-it) exhibits this because its decoder layer registers layer_scalar via register_buffer, making it a buffer rather than a parameter. Fix: add list(module.named_buffers(recurse=False)) to the expansion at each of the four split sites (fallback_allocate module-split, fallback_allocate parent-expansion, infer_auto_device_map tied-module-split, infer_auto_device_map main-module-split). 
--- src/accelerate/utils/modeling.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 2eac9117dac..081b8ea5d7d 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1246,7 +1246,7 @@ def fallback_allocate( continue # split is possible, add the children to the list of modules to search - modules_children = list(module.named_parameters(recurse=False)) + modules_children + modules_children = list(module.named_parameters(recurse=False)) + list(module.named_buffers(recurse=False)) + modules_children modules_to_search = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_search if not module_found: @@ -1261,7 +1261,7 @@ def fallback_allocate( if parent_name in current_names: parent_module_idx = current_names.index(parent_name) _, parent_module = modules[parent_module_idx] - module_children = list(parent_module.named_parameters(recurse=False)) + list( + module_children = list(parent_module.named_parameters(recurse=False)) + list(parent_module.named_buffers(recurse=False)) + list( parent_module.named_children() ) modules = ( @@ -1462,7 +1462,7 @@ def infer_auto_device_map( if verbose: print(f"Splitting {tied_module_name}.") - tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children + tied_module_children = list(tied_module.named_parameters(recurse=False)) + list(tied_module.named_buffers(recurse=False)) + tied_module_children tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children] tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0] @@ -1510,7 +1510,7 @@ def infer_auto_device_map( # -> split, we replace the module studied by its children + parameters if verbose: print(f"Splitting {name}.") - modules_children = list(module.named_parameters(recurse=False)) + modules_children + modules_children = 
list(module.named_parameters(recurse=False)) + list(module.named_buffers(recurse=False)) + modules_children modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat # Update the max layer size. max_layer_size, max_layer_names = get_max_layer_size(