Reduce EMA RAM usage and training overhead with local-shard EMA #15
Open
chijw wants to merge 1 commit into TencentARC:main from
Conversation
Pull request overview
This PR updates the FSDP EMA implementation to keep EMA weights shard-local during training (avoiding summon_full_params() in the hot path) and to export a full generator_ema state dict only at checkpoint save time via an FSDP full-state-dict gather.
Changes:
- Remove `summon_full_params()` from EMA initialization, update, and copy logic to keep EMA shard-local.
- Add `EMA_FSDP.full_state_dict()` to swap in EMA shards, gather a full FSDP state dict, then restore live weights.
- Update trainer checkpoint save paths to store `generator_ema` using `full_state_dict(self.model.generator)`.
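For orientation, here is a minimal sketch of what a shard-local EMA init/update can look like under this design (the class skeleton, constructor signature, and `decay` default are assumptions for illustration; only the pattern of iterating `fsdp_module.module.named_parameters()` over local shards is taken from the PR):

```python
import torch

class EMA_FSDP:  # illustrative skeleton only; the real class lives in utils/distributed.py
    def __init__(self, fsdp_module, decay=0.999):
        self.decay = decay
        self.shadow = {}
        # EMA buffers are created from the local FSDP shards, so each rank
        # only ever holds its own slice of the EMA weights.
        for n, p in fsdp_module.module.named_parameters():
            self.shadow[n] = p.detach().clone()

    @torch.no_grad()
    def update(self, fsdp_module):
        # Hot path: pure local-shard arithmetic, no summon_full_params()
        # and therefore no extra all-gather per training step.
        for n, p in fsdp_module.module.named_parameters():
            if n in self.shadow:
                self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)
```

Because `shadow` now stores shard-sized tensors, the EMA memory footprint scales down with the FSDP world size instead of keeping a full copy on every rank.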
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| utils/distributed.py | Updates EMA to operate on local parameter shards and adds a full-state-dict export path for checkpointing. |
| trainer/gan.py | Saves generator_ema via the new EMA full-state-dict gather path. |
| trainer/distillation.py | Saves generator_ema via the new EMA full-state-dict gather path. |
| trainer/diffusion.py | Saves generator_ema via the new EMA full-state-dict gather path. |
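For context, a rough sketch of how the updated save path in the trainers can look (the checkpoint dictionary layout, rank check, and file name are hypothetical; only the `full_state_dict(self.model.generator)` call comes from this PR):

```python
import torch
import torch.distributed as dist

# Hypothetical excerpt from a trainer's checkpoint-saving step; fsdp_state_dict
# is the repository's existing full-state-dict helper in utils/distributed.py.
state_dict = {
    "generator": fsdp_state_dict(self.model.generator),
    "generator_ema": self.generator_ema.full_state_dict(self.model.generator),
}
if dist.get_rank() == 0:
    torch.save(state_dict, "checkpoint.pt")
```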
Comment on lines 115 to +118
```diff
 def copy_to(self, fsdp_module):
     # load EMA weights into an (unwrapped) copy of the generator
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    with FSDP.summon_full_params(fsdp_module, writeback=True):
-        for n, p in fsdp_module.module.named_parameters():
-            if n in self.shadow:
-                p.data.copy_(self.shadow[n].to(p.dtype, device=p.device))
+    for n, p in fsdp_module.module.named_parameters():
+        if n in self.shadow:
+            p.data.copy_(self.shadow[n].to(dtype=p.dtype, device=p.device))
```
Comment on lines +125 to +133
```python
def full_state_dict(self, fsdp_module):
    live_state = {}
    for n, p in fsdp_module.module.named_parameters():
        live_state[n] = p.detach().clone()

    for n, p in fsdp_module.module.named_parameters():
        if n in self.shadow:
            p.data.copy_(self.shadow[n].to(dtype=p.dtype, device=p.device))

    checkpoint = fsdp_state_dict(fsdp_module)

    for n, p in fsdp_module.module.named_parameters():
        if n in live_state:
            p.data.copy_(live_state[n].to(dtype=p.dtype, device=p.device))
```
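For reference, `fsdp_state_dict()` above is the repository's existing helper in `utils/distributed.py` and is not shown in this diff; a typical full-state-dict gather with PyTorch FSDP looks roughly like the sketch below (an assumption about the helper, not its exact code):

```python
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

def fsdp_state_dict(fsdp_module):
    # Gather the sharded parameters into one full state dict, offloading the
    # result to CPU and materializing it only on rank 0.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_module, StateDictType.FULL_STATE_DICT, cfg):
        return fsdp_module.state_dict()
```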
Summary
This PR changes the EMA path so that EMA is maintained as local shards during training instead of materializing full parameters on every rank.
The previous implementation used `summon_full_params()` in the EMA hot path, which adds unnecessary communication and keeps a full CPU EMA copy on each rank. With this change, each rank updates only its local EMA shard during training, which reduces both EMA memory usage and per-step overhead.
To preserve the existing checkpoint format, `generator_ema` is still exported as a full state dict at save time. Since EMA is shard-local during training, `full_state_dict()` reuses the FSDP-wrapped module together with `fsdp_state_dict()` to gather the full checkpoint, instead of introducing a separate EMA-specific export path.
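Because the exported `generator_ema` stays a plain full state dict, downstream loading code should not need to change; a hedged example (file path and variable names are hypothetical):

```python
import torch

ckpt = torch.load("checkpoint.pt", map_location="cpu")
# generator_ema is still a fully materialized state dict, so it loads into an
# unwrapped generator exactly as before this change.
ema_generator.load_state_dict(ckpt["generator_ema"])
```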
Changes

- Keep `EMA_FSDP` shard-local during training
- Remove `summon_full_params()` from EMA init/update/copy
- Gather the full `generator_ema` only at save time via `self.generator_ema.full_state_dict(self.model.generator)`
- Use `.to(dtype=..., device=...)` when copying EMA tensors back, for better compatibility with newer PyTorch versions