Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,33 @@ def test_get_balanced_memory(self):
max_memory = get_balanced_memory(model, max_memory={0: 0, "cpu": 100})
assert {0: 0, "cpu": 100} == max_memory

def test_get_balanced_memory_no_split_module_classes_set(self):
"""Regression test: no_split_module_classes should accept a set without raising TypeError.

In accelerate<=1.5.0, passing a set caused:
TypeError: unhashable type: 'set'
because the code did `set(no_split_module_classes)` without first
normalizing set inputs to a list. Fixed in #2345.
"""
model = ModelForTest()
# Pass no_split_module_classes as a set (not list/tuple)
# This should not raise "TypeError: unhashable type: 'set'"
max_memory = get_balanced_memory(
model,
max_memory={0: 300, 1: 300},
no_split_module_classes={"Linear"},
)
assert isinstance(max_memory, dict)

# Also verify infer_auto_device_map handles set input
from accelerate import infer_auto_device_map
device_map = infer_auto_device_map(
model,
max_memory={0: 300, 1: 300},
no_split_module_classes={"Linear"},
)
assert isinstance(device_map, dict)

# Tests that get_module_size_with_ties returns the correct tied modules in
# models with tied parameters whose parent modules share the same name prefix
# See issue #3308: https://github.com/huggingface/accelerate/issues/3308
Expand Down