Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,7 +1246,7 @@ def fallback_allocate(
continue

# split is possible, add the children to the list of modules to search
- modules_children = list(module.named_parameters(recurse=False)) + modules_children
+ modules_children = list(module.named_parameters(recurse=False)) + list(module.named_buffers(recurse=False)) + modules_children
modules_to_search = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_search

if not module_found:
Expand All @@ -1261,7 +1261,7 @@ def fallback_allocate(
if parent_name in current_names:
parent_module_idx = current_names.index(parent_name)
_, parent_module = modules[parent_module_idx]
- module_children = list(parent_module.named_parameters(recurse=False)) + list(
+ module_children = list(parent_module.named_parameters(recurse=False)) + list(parent_module.named_buffers(recurse=False)) + list(
parent_module.named_children()
)
modules = (
Expand Down Expand Up @@ -1462,7 +1462,7 @@ def infer_auto_device_map(

if verbose:
print(f"Splitting {tied_module_name}.")
- tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children
+ tied_module_children = list(tied_module.named_parameters(recurse=False)) + list(tied_module.named_buffers(recurse=False)) + tied_module_children
tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children]
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]

Expand Down Expand Up @@ -1510,7 +1510,7 @@ def infer_auto_device_map(
# -> split, we replace the module studied by its children + parameters
if verbose:
print(f"Splitting {name}.")
- modules_children = list(module.named_parameters(recurse=False)) + modules_children
+ modules_children = list(module.named_parameters(recurse=False)) + list(module.named_buffers(recurse=False)) + modules_children
modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat
# Update the max layer size.
max_layer_size, max_layer_names = get_max_layer_size(
Expand Down