diff --git a/src/weathergen/train/lr_scheduler.py b/src/weathergen/train/lr_scheduler.py index e85cd1abf..9c8185fc2 100644 --- a/src/weathergen/train/lr_scheduler.py +++ b/src/weathergen/train/lr_scheduler.py @@ -53,9 +53,9 @@ def __init__( logger.debug(f"steps_decay={self.n_steps_decay} lr_steps={lr_steps}") # ensure that steps_decay has a reasonable value if self.n_steps_decay < int(0.2 * lr_steps): - self.n_steps_warmup = int(0.1 * lr_steps) - self.n_steps_cooldown = int(0.05 * lr_steps) - self.n_steps_decay = lr_steps - self.n_steps_warmup - self.n_steps_cooldown + self.n_steps_warmup = max(2, int(0.1 * lr_steps)) + self.n_steps_cooldown = max(1, int(0.05 * lr_steps)) + self.n_steps_decay = max(1, lr_steps - self.n_steps_warmup - self.n_steps_cooldown) s = ( "cf.lr_steps_warmup and cf.lr_steps_cooldown", f" were larger than cf.lr_steps={lr_steps}", diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index f396e611c..a31dd57a7 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -86,6 +86,7 @@ def __init__(self, train_logging: Config): self.batch_size_test_per_gpu = -1 self.collapse_monitor: CollapseMonitor | None = None self.perf_tracker: ThroughputTracker | NullThroughputTracker = NullThroughputTracker() + self.t_training_start: float = 0 def get_batch_size_total(self, batch_size_per_gpu) -> int: """ @@ -321,7 +322,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # aiming for beta1 = 0.9 at one node, ie kappa=B=4 beta1 = max(0.5, 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta1)) # aiming for beta2 = 0.95 at one node, ie B=4 - beta2 = 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta2) + beta2 = max(0.9, 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta2)) eps = self.training_cfg.optimizer.adamw.get("eps", 2e-08) / np.sqrt(kappa) self.optimizer = torch.optim.AdamW( @@ -361,7 +362,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): mini_epoch_base = int(self.cf.general.istep / len(self.data_loader)) else: len_per_rank = ( - len(self.dataset) // (self.world_size_original * self.batch_size_per_gpu) + max(1, len(self.dataset) // (self.world_size_original * self.batch_size_per_gpu)) ) * self.batch_size_per_gpu mini_epoch_base = int( self.cf.general.istep @@ -379,6 +380,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): self.validate_before_training() # training loop + self.t_training_start = time.time() for mini_epoch in range(mini_epoch_base, self.training_cfg.num_mini_epochs): if is_root(): @@ -747,6 +749,7 @@ def _log(self, stage: Stage): self.train_logger.add_logs(stage, samples, losses_all, stddev_all) elif self.cf.general.istep >= 0: + elapsed_time = time.time() - self.t_training_start self.train_logger.add_logs( stage, samples, @@ -754,6 +757,7 @@ def _log(self, stage: Stage): stddev_all, avg_loss=avg_loss, lr=self.lr_scheduler.get_lr(), + elapsed_training_time_seconds=elapsed_time, ) loss_calculator.loss_hist = [] diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 53e4e551f..d7501a1cd 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -101,9 +101,10 @@ def add_logs( stddev_all: dict, avg_loss: list[float] = None, lr: float = None, + elapsed_training_time_seconds: float | None = None, ) -> None: """ - Log training or validation data + Log training or validation data. """ metrics: dict[str, float] = dict(num_samples=samples) @@ -112,6 +113,13 @@ def add_logs( metrics["loss_avg_mean"] = val metrics["learning_rate"] = lr metrics["num_samples"] = int(samples) + if elapsed_training_time_seconds is not None: + metrics["elapsed_training_time_seconds"] = elapsed_training_time_seconds + metrics["average_samples_per_second"] = ( + samples / elapsed_training_time_seconds + if elapsed_training_time_seconds > 0 + else 0 + ) for key, value in losses_all.items(): val = np.nan if np.isnan(value).all() else np.nanmean(value)