diff --git a/simclr.py b/simclr.py index 9a58f94a..0535bac5 100644 --- a/simclr.py +++ b/simclr.py @@ -145,6 +145,6 @@ def _validate(self, model, valid_loader): loss = self._step(model, xis, xjs, counter) valid_loss += loss.item() - valid_loss /= counter + valid_loss = valid_loss / (counter + 1) model.train() return valid_loss