diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index f71c757ac33..12dbbbe03ec 100644 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -906,6 +906,33 @@ def test_get_balanced_memory(self): max_memory = get_balanced_memory(model, max_memory={0: 0, "cpu": 100}) assert {0: 0, "cpu": 100} == max_memory + def test_get_balanced_memory_no_split_module_classes_set(self): + """Regression test: no_split_module_classes should accept a set without raising TypeError. + + In accelerate<=1.5.0, passing a set caused: + TypeError: unhashable type: 'set' + because the code did `set(no_split_module_classes)` without first + normalizing set inputs to a list. Fixed in #2345. + """ + model = ModelForTest() + # Pass no_split_module_classes as a set (not list/tuple) + # This should not raise "TypeError: unhashable type: 'set'" + max_memory = get_balanced_memory( + model, + max_memory={0: 300, 1: 300}, + no_split_module_classes={"Linear"}, + ) + assert isinstance(max_memory, dict) + + # Also verify infer_auto_device_map handles set input + from accelerate import infer_auto_device_map + device_map = infer_auto_device_map( + model, + max_memory={0: 300, 1: 300}, + no_split_module_classes={"Linear"}, + ) + assert isinstance(device_map, dict) + # Tests that get_module_size_with_ties returns the correct tied modules in # models with tied parameters whose parent modules share the same name prefix # See issue #3308: https://github.com/huggingface/accelerate/issues/3308