From a6955798ae6087544b843bf2962f1811d575d4a2 Mon Sep 17 00:00:00 2001 From: Rahul Solanki Date: Fri, 3 May 2024 01:02:20 +0000 Subject: [PATCH 1/5] handle weight sharing with init_on_device --- src/accelerate/big_modeling.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/accelerate/big_modeling.py b/src/accelerate/big_modeling.py index bddcaa8a0cc..986f1c34cfc 100644 --- a/src/accelerate/big_modeling.py +++ b/src/accelerate/big_modeling.py @@ -132,7 +132,14 @@ def register_empty_parameter(module, name, param): param_cls = type(module._parameters[name]) kwargs = module._parameters[name].__dict__ kwargs["requires_grad"] = param.requires_grad - module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) + # When we have a case of tensor2 = tensor1, it would call the set_attr + # of param, which in turn would call the register_parameter API. + # In this case, the new param is already on meta-device, since it was moved + # previously when it was initialized. Hence, when resetting, you can + # directly assign that tensor instead of re-init. If you re-init you would + # lose the relationship. + module._parameters[name] = param if param.device == device else \ + param_cls(module._parameters[name].to(device), **kwargs) def register_empty_buffer(module, name, buffer, persistent=True): old_register_buffer(module, name, buffer, persistent=persistent) From 892ba83647419a1fa8dc365ee5449385df20b25f Mon Sep 17 00:00:00 2001 From: "lanzongwei.lan" Date: Sat, 2 Aug 2025 15:17:28 +0800 Subject: [PATCH 2/5] ci: add unittest for tie-embedding empty_init --- tests/test_big_modeling.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_big_modeling.py b/tests/test_big_modeling.py index 7c960745565..6d20f74eff8 100644 --- a/tests/test_big_modeling.py +++ b/tests/test_big_modeling.py @@ -188,6 +188,13 @@ def test_init_empty_weights(self): assert module.weight.device == torch.device("cpu") assert module.running_mean.device == torch.device("cpu") + def test_init_empty_weights_with_tie_embedding(self): + with init_empty_weights(): + module = torch.nn.ModuleList([torch.nn.Embedding(12, 12), torch.nn.Linear(12, 12)]) + # tie embedding + module[0].weight = module[1].weight + assert module[0].weight is module[1].weight + def test_init_empty_weights_very_large_model(self): # This is a 100 billion parameters model. with init_empty_weights(): From a8575f9dcf543a0504110477e9fd5f53449fa486 Mon Sep 17 00:00:00 2001 From: "lanzongwei.lan" Date: Thu, 11 Sep 2025 19:23:58 +0800 Subject: [PATCH 3/5] ci: add unittest for tie-embedding qwen2 --- tests/test_big_modeling.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_big_modeling.py b/tests/test_big_modeling.py index 6d20f74eff8..3764ba2b98d 100644 --- a/tests/test_big_modeling.py +++ b/tests/test_big_modeling.py @@ -193,7 +193,12 @@ def test_init_empty_weights_with_tie_embedding(self): module = torch.nn.ModuleList([torch.nn.Embedding(12, 12), torch.nn.Linear(12, 12)]) # tie embedding module[0].weight = module[1].weight + + from transformers.models import Qwen2Config, Qwen2ForCausalLM + + qwen2 = Qwen2ForCausalLM(Qwen2Config(tie_word_embeddings=True)) assert module[0].weight is module[1].weight + assert qwen2.lm_head.weight is qwen2.model.embed_tokens.weight def test_init_empty_weights_very_large_model(self): # This is a 100 billion parameters model. From ac9f722e568499378d09f5bd7342534aed4e5b14 Mon Sep 17 00:00:00 2001 From: "lanzongwei.lan" Date: Thu, 18 Sep 2025 22:31:03 +0800 Subject: [PATCH 4/5] ci: no need tie_weights in test_infer_auto_device_map_on_t0pp --- tests/test_modeling_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 4857b3b5df2..68f3eae039c 100644 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -677,7 +677,6 @@ def test_infer_auto_device_map_on_t0pp(self): config = AutoConfig.from_pretrained("bigscience/T0pp") with init_empty_weights(): model = AutoModelForSeq2SeqLM.from_config(config) - model.tie_weights() special_dtypes = {n: torch.float32 for n, _ in model.named_parameters() if "wo" in n} max_memory = {0: 10**10, 1: 10**10, "cpu": 10**10} From 33a07f3923155ffca7f7e725e73761af3b037805 Mon Sep 17 00:00:00 2001 From: "lanzongwei.lan" Date: Fri, 19 Sep 2025 11:18:13 +0800 Subject: [PATCH 5/5] chore: lint fix --- src/accelerate/big_modeling.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/accelerate/big_modeling.py b/src/accelerate/big_modeling.py index 986f1c34cfc..f025fefa92a 100644 --- a/src/accelerate/big_modeling.py +++ b/src/accelerate/big_modeling.py @@ -138,8 +138,9 @@ def register_empty_parameter(module, name, param): # previously when it was initialized. Hence, when resetting, you can # directly assign that tensor instead of re-init. If you re-init you would # lose the relationship. - module._parameters[name] = param if param.device == device else \ - param_cls(module._parameters[name].to(device), **kwargs) + module._parameters[name] = ( + param if param.device == device else param_cls(module._parameters[name].to(device), **kwargs) + ) def register_empty_buffer(module, name, buffer, persistent=True): old_register_buffer(module, name, buffer, persistent=persistent)