diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 2eac9117dac..081b8ea5d7d 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1246,7 +1246,7 @@ def fallback_allocate( continue # split is possible, add the children to the list of modules to search - modules_children = list(module.named_parameters(recurse=False)) + modules_children + modules_children = list(module.named_parameters(recurse=False)) + list(module.named_buffers(recurse=False)) + modules_children modules_to_search = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_search if not module_found: @@ -1261,7 +1261,7 @@ def fallback_allocate( if parent_name in current_names: parent_module_idx = current_names.index(parent_name) _, parent_module = modules[parent_module_idx] - module_children = list(parent_module.named_parameters(recurse=False)) + list( + module_children = list(parent_module.named_parameters(recurse=False)) + list(parent_module.named_buffers(recurse=False)) + list( parent_module.named_children() ) modules = ( @@ -1462,7 +1462,7 @@ def infer_auto_device_map( if verbose: print(f"Splitting {tied_module_name}.") - tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children + tied_module_children = list(tied_module.named_parameters(recurse=False)) + list(tied_module.named_buffers(recurse=False)) + tied_module_children tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children] tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0] @@ -1510,7 +1510,7 @@ def infer_auto_device_map( # -> split, we replace the module studied by its children + parameters if verbose: print(f"Splitting {name}.") - modules_children = list(module.named_parameters(recurse=False)) + modules_children + modules_children = list(module.named_parameters(recurse=False)) + list(module.named_buffers(recurse=False)) + modules_children modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat # Update the max layer size. max_layer_size, max_layer_names = get_max_layer_size(