Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3639,6 +3639,8 @@ def _inner(folder):
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving current state to {output_dir}")

save_model_only = save_model_func_kwargs.pop("save_model_only", False)

if self.distributed_type == DistributedType.XLA:
# Finish running the previous step before checkpointing
xm.mark_step()
Expand All @@ -3664,23 +3666,25 @@ def _inner(folder):

# Save the optimizers taking care of FSDP and DeepSpeed nuances
optimizers = []
if self.distributed_type == DistributedType.FSDP:
for i, opt in enumerate(self._optimizers):
logger.info("Saving FSDP Optimizer")
save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
logger.info(f"FSDP Optimizer saved to output dir {output_dir}")
elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
optimizers = self._optimizers
if not save_model_only:
if self.distributed_type == DistributedType.FSDP:
for i, opt in enumerate(self._optimizers):
logger.info("Saving FSDP Optimizer")
save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
logger.info(f"FSDP Optimizer saved to output dir {output_dir}")
elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
optimizers = self._optimizers

# Save the lr schedulers taking care of DeepSpeed nuances
schedulers = []
if self.distributed_type == DistributedType.DEEPSPEED:
for i, scheduler in enumerate(self._schedulers):
if isinstance(scheduler, DeepSpeedSchedulerWrapper):
continue
schedulers.append(scheduler)
elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
schedulers = self._schedulers
if not save_model_only:
if self.distributed_type == DistributedType.DEEPSPEED:
for i, scheduler in enumerate(self._schedulers):
if isinstance(scheduler, DeepSpeedSchedulerWrapper):
continue
schedulers.append(scheduler)
elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
schedulers = self._schedulers

# Save the samplers of the dataloaders
dataloaders = self._dataloaders
Expand All @@ -3701,6 +3705,7 @@ def _inner(folder):
self.scaler,
save_on_each_node=self.project_configuration.save_on_each_node,
safe_serialization=safe_serialization,
save_model_only=save_model_only
)
for i, obj in enumerate(self._custom_objects):
save_custom_state(obj, output_dir, i, save_on_each_node=self.project_configuration.save_on_each_node)
Expand Down
5 changes: 5 additions & 0 deletions src/accelerate/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def save_accelerator_state(
scaler: Optional[GradScaler] = None,
save_on_each_node: bool = False,
safe_serialization: bool = True,
save_model_only: bool = False,
):
"""
Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.
Expand Down Expand Up @@ -113,6 +114,10 @@ def save_accelerator_state(
output_model_file = output_dir.joinpath(weights_name)
save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization)
logger.info(f"Model weights saved in {output_model_file}")

if save_model_only:
return output_dir

# Optimizer states
for i, opt in enumerate(optimizers):
state = opt.state_dict()
Expand Down
15 changes: 15 additions & 0 deletions tests/test_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,21 @@ def test_save_model(self, use_safetensors):
load_checkpoint_in_model(model, tmpdirname)
assert abs(model_signature - get_signature(model)) < 1e-3

@parameterized.expand([True, False], name_func=parameterized_custom_name_func)
def test_save_state_model_only(self, use_safetensors):
    # Checks that `save_state(..., save_model_only=True)` leaves exactly one file
    # (the model weights) in the output directory, for both serialization formats,
    # and that reloading that file keeps the model signature unchanged.
    accelerator = Accelerator()
    linear = torch.nn.Linear(10, 10)
    model = accelerator.prepare(linear)
    signature_before = get_signature(model)

    # The weights filename depends on the serialization format being exercised.
    if use_safetensors:
        expected_weights_file = 'model.safetensors'
    else:
        expected_weights_file = 'pytorch_model.bin'

    with tempfile.TemporaryDirectory() as tmpdirname:
        accelerator.save_state(tmpdirname, safe_serialization=use_safetensors, save_model_only=True)

        # Nothing except the weights file may be present in the checkpoint folder.
        saved_entries = os.listdir(tmpdirname)
        assert saved_entries == [expected_weights_file]

        # Loading the checkpoint back must reproduce the recorded signature.
        load_checkpoint_in_model(model, tmpdirname)
        signature_after = get_signature(model)
        assert abs(signature_before - signature_after) < 1e-3

@parameterized.expand([True, False], name_func=parameterized_custom_name_func)
def test_save_sharded_model(self, use_safetensors):
accelerator = Accelerator()
Expand Down