From 12f3e1efcb4fe7b252f7e9336a3c5d201c29f5a7 Mon Sep 17 00:00:00 2001 From: 3manifold <22544721+3manifold@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:06:48 +0100 Subject: [PATCH] Save state model only --- src/accelerate/accelerator.py | 33 +++++++++++++++++++-------------- src/accelerate/checkpointing.py | 5 +++++ tests/test_accelerator.py | 15 +++++++++++++++ 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index a1cbfaa8d68..a5270b9e8bc 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -3639,6 +3639,8 @@ def _inner(folder): os.makedirs(output_dir, exist_ok=True) logger.info(f"Saving current state to {output_dir}") + save_model_only = save_model_func_kwargs.pop("save_model_only", False) + if self.distributed_type == DistributedType.XLA: # Finish running the previous step before checkpointing xm.mark_step() @@ -3664,23 +3666,25 @@ def _inner(folder): # Save the optimizers taking care of FSDP and DeepSpeed nuances optimizers = [] - if self.distributed_type == DistributedType.FSDP: - for i, opt in enumerate(self._optimizers): - logger.info("Saving FSDP Optimizer") - save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i) - logger.info(f"FSDP Optimizer saved to output dir {output_dir}") - elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]: - optimizers = self._optimizers + if not save_model_only: + if self.distributed_type == DistributedType.FSDP: + for i, opt in enumerate(self._optimizers): + logger.info("Saving FSDP Optimizer") + save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i) + logger.info(f"FSDP Optimizer saved to output dir {output_dir}") + elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]: + optimizers = self._optimizers # Save the lr schedulers taking care of DeepSpeed nuances schedulers = [] - if self.distributed_type == DistributedType.DEEPSPEED: - for i, scheduler in enumerate(self._schedulers): - if isinstance(scheduler, DeepSpeedSchedulerWrapper): - continue - schedulers.append(scheduler) - elif self.distributed_type not in [DistributedType.MEGATRON_LM]: - schedulers = self._schedulers + if not save_model_only: + if self.distributed_type == DistributedType.DEEPSPEED: + for i, scheduler in enumerate(self._schedulers): + if isinstance(scheduler, DeepSpeedSchedulerWrapper): + continue + schedulers.append(scheduler) + elif self.distributed_type not in [DistributedType.MEGATRON_LM]: + schedulers = self._schedulers # Save the samplers of the dataloaders dataloaders = self._dataloaders @@ -3701,6 +3705,7 @@ def _inner(folder): self.scaler, save_on_each_node=self.project_configuration.save_on_each_node, safe_serialization=safe_serialization, + save_model_only=save_model_only ) for i, obj in enumerate(self._custom_objects): save_custom_state(obj, output_dir, i, save_on_each_node=self.project_configuration.save_on_each_node) diff --git a/src/accelerate/checkpointing.py b/src/accelerate/checkpointing.py index 2b753e6e206..a52d953c64c 100644 --- a/src/accelerate/checkpointing.py +++ b/src/accelerate/checkpointing.py @@ -71,6 +71,7 @@ def save_accelerator_state( scaler: Optional[GradScaler] = None, save_on_each_node: bool = False, safe_serialization: bool = True, + save_model_only: bool = False, ): """ Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory. @@ -113,6 +114,10 @@ def save_accelerator_state( output_model_file = output_dir.joinpath(weights_name) save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization) logger.info(f"Model weights saved in {output_model_file}") + + if save_model_only: + return output_dir + # Optimizer states for i, opt in enumerate(optimizers): state = opt.state_dict() diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index ebab8a8c057..ee860d2d1de 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -296,6 +296,21 @@ def test_save_model(self, use_safetensors): load_checkpoint_in_model(model, tmpdirname) assert abs(model_signature - get_signature(model)) < 1e-3 + @parameterized.expand([True, False], name_func=parameterized_custom_name_func) + def test_save_state_model_only(self, use_safetensors): + accelerator = Accelerator() + model = torch.nn.Linear(10, 10) + model = accelerator.prepare(model) + + model_signature = get_signature(model) + with tempfile.TemporaryDirectory() as tmpdirname: + accelerator.save_state(tmpdirname, safe_serialization=use_safetensors, save_model_only=True) + # make sure only the model was saved + assert os.listdir(tmpdirname) == ['model.safetensors' if use_safetensors else 'pytorch_model.bin'] + # make sure loaded weights match + load_checkpoint_in_model(model, tmpdirname) + assert abs(model_signature - get_signature(model)) < 1e-3 + @parameterized.expand([True, False], name_func=parameterized_custom_name_func) def test_save_sharded_model(self, use_safetensors): accelerator = Accelerator()