diff --git a/trainer/diffusion.py b/trainer/diffusion.py
index 06f4e9e..26a9faf 100644
--- a/trainer/diffusion.py
+++ b/trainer/diffusion.py
@@ -145,7 +145,7 @@ def save(self):
         if self.config.ema_start_step < self.step:
             state_dict = {
                 "generator": generator_state_dict,
-                "generator_ema": self.generator_ema.state_dict(),
+                "generator_ema": self.generator_ema.full_state_dict(self.model.generator),
             }
         else:
             state_dict = {
diff --git a/trainer/distillation.py b/trainer/distillation.py
index 6330118..52f2a08 100644
--- a/trainer/distillation.py
+++ b/trainer/distillation.py
@@ -184,7 +184,7 @@ def save(self):
             state_dict = {
                 "generator": generator_state_dict,
                 "critic": critic_state_dict,
-                "generator_ema": self.generator_ema.state_dict(),
+                "generator_ema": self.generator_ema.full_state_dict(self.model.generator),
             }
         else:
             state_dict = {
diff --git a/trainer/gan.py b/trainer/gan.py
index e632e81..485f0c1 100644
--- a/trainer/gan.py
+++ b/trainer/gan.py
@@ -216,7 +216,7 @@ def save(self):
             state_dict = {
                 "generator": generator_state_dict,
                 "critic": critic_state_dict,
-                "generator_ema": self.generator_ema.state_dict(),
+                "generator_ema": self.generator_ema.full_state_dict(self.model.generator),
             }
         else:
             state_dict = {
diff --git a/utils/distributed.py b/utils/distributed.py
index 4367ded..c2e8913 100644
--- a/utils/distributed.py
+++ b/utils/distributed.py
@@ -96,18 +96,14 @@ def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
 
     @torch.no_grad()
     def _init_shadow(self, fsdp_module):
-        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-        with FSDP.summon_full_params(fsdp_module, writeback=False):
-            for n, p in fsdp_module.module.named_parameters():
-                self.shadow[n] = p.detach().clone().float().cpu()
+        for n, p in fsdp_module.module.named_parameters():
+            self.shadow[n] = p.detach().clone().float().cpu()
 
     @torch.no_grad()
     def update(self, fsdp_module):
         d = self.decay
-        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-        with FSDP.summon_full_params(fsdp_module, writeback=False):
-            for n, p in fsdp_module.module.named_parameters():
-                self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)
+        for n, p in fsdp_module.module.named_parameters():
+            self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)
 
     # Optional helpers ---------------------------------------------------
     def state_dict(self):
@@ -117,9 +113,22 @@ def load_state_dict(self, sd):
         self.shadow = {k: v.clone() for k, v in sd.items()}
 
     def copy_to(self, fsdp_module):
-        # load EMA weights into an (unwrapped) copy of the generator
-        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-        with FSDP.summon_full_params(fsdp_module, writeback=True):
-            for n, p in fsdp_module.module.named_parameters():
-                if n in self.shadow:
-                    p.data.copy_(self.shadow[n].to(p.dtype, device=p.device))
+        for n, p in fsdp_module.module.named_parameters():
+            if n in self.shadow:
+                p.data.copy_(self.shadow[n].to(dtype=p.dtype, device=p.device))
+
+    @torch.no_grad()
+    def full_state_dict(self, fsdp_module):
+        live_state = {}
+        for n, p in fsdp_module.module.named_parameters():
+            live_state[n] = p.detach().clone()
+        for n, p in fsdp_module.module.named_parameters():
+            if n in self.shadow:
+                p.data.copy_(self.shadow[n].to(dtype=p.dtype, device=p.device))
+
+        checkpoint = fsdp_state_dict(fsdp_module)
+        for n, p in fsdp_module.module.named_parameters():
+            if n in live_state:
+                p.data.copy_(live_state[n].to(dtype=p.dtype, device=p.device))
+
+        return checkpoint