diff --git a/curriculum_learning.py b/curriculum_learning.py index 4b25fe25..6bd0d15a 100644 --- a/curriculum_learning.py +++ b/curriculum_learning.py @@ -1115,6 +1115,15 @@ def _train_stage( loss = self._get_model().compute_loss(batch) loss.backward() + # _get_model() unwraps DDP, so its gradient sync is bypassed; average grads manually. + if dist.is_initialized(): + for p in self._get_model().parameters(): + if not p.requires_grad: + continue + if p.grad is None: + p.grad = torch.zeros_like(p) + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + # Handle gradient clipping for distributed training clip_grad_norm_(self._get_model().parameters(), GRAD_CLIP_NORM)