
Reduce EMA RAM usage and training overhead with local-shard EMA #15

Open

chijw wants to merge 1 commit into TencentARC:main from chijw:main
Conversation

chijw commented on Mar 15, 2026

Summary

This PR changes the EMA path so that EMA is maintained as local shards during training instead of materializing full parameters on every rank.

The previous implementation used summon_full_params() in the EMA hot path, which added unnecessary communication and kept a full CPU EMA copy on each rank. With this change, each rank updates only its local EMA shard during training, which reduces both EMA memory usage and per-step overhead.
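
For illustration, a minimal sketch of a shard-local EMA update, assuming a shadow dict of per-rank local shards (names and structure are illustrative; the actual EMA_FSDP in utils/distributed.py may differ):

    # Sketch only: shard-local EMA, assuming self.shadow holds each rank's
    # local parameter shards (no summon_full_params, no cross-rank gathers).
    import torch

    class EMA_FSDP:
        def __init__(self, fsdp_module, decay=0.999):
            self.decay = decay
            # Clone only the local shards owned by this rank.
            self.shadow = {
                n: p.detach().clone()
                for n, p in fsdp_module.module.named_parameters()
            }

        @torch.no_grad()
        def update(self, fsdp_module):
            # Each rank blends its own shard; no communication in the hot path.
            for n, p in fsdp_module.module.named_parameters():
                if n in self.shadow:
                    self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)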

To preserve the existing checkpoint format, generator_ema is still exported as a full state dict at save time. Since EMA is shard-local during training, full_state_dict() temporarily swaps the EMA shards into the FSDP-wrapped module and reuses fsdp_state_dict() to gather the full checkpoint, rather than introducing a separate EMA-specific export path.
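
The fsdp_state_dict() helper mentioned above presumably wraps FSDP's full-state-dict gathering; a minimal sketch under that assumption (the repo's actual helper may differ):

    # Rough sketch of an fsdp_state_dict()-style helper (assumed; not the
    # repo's exact code). Gathers a full, CPU-offloaded state dict on rank 0.
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

    def fsdp_state_dict(fsdp_module):
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(fsdp_module, StateDictType.FULL_STATE_DICT, cfg):
            return fsdp_module.state_dict()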

Changes

  • keep EMA_FSDP shard-local during training
  • remove summon_full_params() from EMA init/update/copy
  • export full generator_ema only at save time
  • switch trainer save paths to use self.generator_ema.full_state_dict(self.model.generator) (sketched below, after this list)
  • use to(dtype=..., device=...) when copying EMA tensors back, for better compatibility with newer PyTorch versions
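
A rough sketch of the resulting trainer save path (hypothetical surrounding code; only the generator_ema line reflects the actual change):

    # Hypothetical save-path sketch; names other than self.generator_ema and
    # self.model.generator are illustrative, not the trainers' exact code.
    state_dict = {
        "generator": fsdp_state_dict(self.model.generator),
        "generator_ema": self.generator_ema.full_state_dict(self.model.generator),
    }
    if self.global_rank == 0:
        torch.save(state_dict, checkpoint_path)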

Copilot AI review requested due to automatic review settings March 15, 2026 06:27

Copilot AI left a comment


Pull request overview

This PR updates the FSDP EMA implementation to keep EMA weights shard-local during training (avoiding summon_full_params() in the hot path) and to export a full generator_ema state dict only at checkpoint save time via an FSDP full-state-dict gather.

Changes:

  • Remove summon_full_params() from EMA initialization, update, and copy logic to keep EMA shard-local.
  • Add EMA_FSDP.full_state_dict() to swap in EMA shards, gather a full FSDP state dict, then restore live weights.
  • Update trainer checkpoint save paths to store generator_ema using full_state_dict(self.model.generator).

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

File | Description
utils/distributed.py | Updates EMA to operate on local parameter shards and adds a full-state-dict export path for checkpointing.
trainer/gan.py | Saves generator_ema via the new EMA full-state-dict gather path.
trainer/distillation.py | Saves generator_ema via the new EMA full-state-dict gather path.
trainer/diffusion.py | Saves generator_ema via the new EMA full-state-dict gather path.


Comment on lines 115 to +118

      def copy_to(self, fsdp_module):
          # load EMA weights into an (unwrapped) copy of the generator
-         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-         with FSDP.summon_full_params(fsdp_module, writeback=True):
-             for n, p in fsdp_module.module.named_parameters():
-                 if n in self.shadow:
-                     p.data.copy_(self.shadow[n].to(p.dtype, device=p.device))
+         for n, p in fsdp_module.module.named_parameters():
+             if n in self.shadow:
+                 p.data.copy_(self.shadow[n].to(dtype=p.dtype, device=p.device))
Comment on lines +125 to +133

      def full_state_dict(self, fsdp_module):
          # snapshot the live (training) weights for this rank
          live_state = {}
          for n, p in fsdp_module.module.named_parameters():
              live_state[n] = p.detach().clone()

          # swap the EMA shards into the FSDP-wrapped module
          for n, p in fsdp_module.module.named_parameters():
              if n in self.shadow:
                  p.data.copy_(self.shadow[n].to(dtype=p.dtype, device=p.device))

          # gather the full EMA state dict through the usual FSDP path
          checkpoint = fsdp_state_dict(fsdp_module)

          # restore the live weights
          for n, p in fsdp_module.module.named_parameters():
              if n in live_state:
                  p.data.copy_(live_state[n].to(dtype=p.dtype, device=p.device))
