From e7c666a51dcaabd434f35f9b2c80d4c341521611 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=A0=E7=9A=84GitHub=E5=90=8D=E7=A8=B1?= <你的GitHub信箱> Date: Mon, 25 May 2026 05:46:51 +0000 Subject: [PATCH] test: add regression test for no_split_module_classes accepting set type get_balanced_memory and infer_auto_device_map should accept a set for no_split_module_classes without raising 'TypeError: unhashable type: set'. Bug was present in <=1.5.0, fixed in #2345, this test prevents regression. --- tests/test_modeling_utils.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index f71c757ac33..12dbbbe03ec 100644 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -906,6 +906,33 @@ def test_get_balanced_memory(self): max_memory = get_balanced_memory(model, max_memory={0: 0, "cpu": 100}) assert {0: 0, "cpu": 100} == max_memory + def test_get_balanced_memory_no_split_module_classes_set(self): + """Regression test: no_split_module_classes should accept a set without raising TypeError. + + In accelerate<=1.5.0, passing a set caused: + TypeError: unhashable type: 'set' + because the code did `set(no_split_module_classes)` without first + normalizing set inputs to a list. Fixed in #2345. + """ + model = ModelForTest() + # Pass no_split_module_classes as a set (not list/tuple) + # This should not raise "TypeError: unhashable type: 'set'" + max_memory = get_balanced_memory( + model, + max_memory={0: 300, 1: 300}, + no_split_module_classes={"Linear"}, + ) + assert isinstance(max_memory, dict) + + # Also verify infer_auto_device_map handles set input + from accelerate import infer_auto_device_map + device_map = infer_auto_device_map( + model, + max_memory={0: 300, 1: 300}, + no_split_module_classes={"Linear"}, + ) + assert isinstance(device_map, dict) + # Tests that get_module_size_with_ties returns the correct tied modules in # models with tied parameters whose parent modules share the same name prefix # See issue #3308: https://github.com/huggingface/accelerate/issues/3308