diff --git a/src/weathergen/train/lr_scheduler.py b/src/weathergen/train/lr_scheduler.py index e85cd1abf..9c8185fc2 100644 --- a/src/weathergen/train/lr_scheduler.py +++ b/src/weathergen/train/lr_scheduler.py @@ -53,9 +53,9 @@ def __init__( logger.debug(f"steps_decay={self.n_steps_decay} lr_steps={lr_steps}") # ensure that steps_decay has a reasonable value if self.n_steps_decay < int(0.2 * lr_steps): - self.n_steps_warmup = int(0.1 * lr_steps) - self.n_steps_cooldown = int(0.05 * lr_steps) - self.n_steps_decay = lr_steps - self.n_steps_warmup - self.n_steps_cooldown + self.n_steps_warmup = max(2, int(0.1 * lr_steps)) + self.n_steps_cooldown = max(1, int(0.05 * lr_steps)) + self.n_steps_decay = max(1, lr_steps - self.n_steps_warmup - self.n_steps_cooldown) s = ( "cf.lr_steps_warmup and cf.lr_steps_cooldown", f" were larger than cf.lr_steps={lr_steps}", diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index f396e611c..a31dd57a7 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -86,6 +86,7 @@ def __init__(self, train_logging: Config): self.batch_size_test_per_gpu = -1 self.collapse_monitor: CollapseMonitor | None = None self.perf_tracker: ThroughputTracker | NullThroughputTracker = NullThroughputTracker() + self.t_training_start: float = 0 def get_batch_size_total(self, batch_size_per_gpu) -> int: """ @@ -321,7 +322,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # aiming for beta1 = 0.9 at one node, ie kappa=B=4 beta1 = max(0.5, 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta1)) # aiming for beta2 = 0.95 at one node, ie B=4 - beta2 = 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta2) + beta2 = max(0.9, 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta2)) eps = self.training_cfg.optimizer.adamw.get("eps", 2e-08) / np.sqrt(kappa) self.optimizer = torch.optim.AdamW( @@ -361,7 +362,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): mini_epoch_base = int(self.cf.general.istep / len(self.data_loader)) else: len_per_rank = ( - len(self.dataset) // (self.world_size_original * self.batch_size_per_gpu) + max(1, len(self.dataset) // (self.world_size_original * self.batch_size_per_gpu)) ) * self.batch_size_per_gpu mini_epoch_base = int( self.cf.general.istep @@ -379,6 +380,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): self.validate_before_training() # training loop + self.t_training_start = time.time() for mini_epoch in range(mini_epoch_base, self.training_cfg.num_mini_epochs): if is_root(): @@ -747,6 +749,7 @@ def _log(self, stage: Stage): self.train_logger.add_logs(stage, samples, losses_all, stddev_all) elif self.cf.general.istep >= 0: + elapsed_time = time.time() - self.t_training_start self.train_logger.add_logs( stage, samples, @@ -754,6 +757,7 @@ def _log(self, stage: Stage): stddev_all, avg_loss=avg_loss, lr=self.lr_scheduler.get_lr(), + elapsed_training_time_seconds=elapsed_time, ) loss_calculator.loss_hist = [] diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index 5678cbbc2..059ea57bf 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -311,6 +311,7 @@ def plot_loss_avg( runs_data, runs_active, stage=TRAIN, + x_axis: str = "samples", x_scale_log=False, legend_outside: bool = False, legend_font_size: str = "x-small", @@ -322,10 +323,14 @@ def plot_loss_avg( _fig = plt.figure(figsize=(10, 7), dpi=PLOT_DPI_VALUE) + # x-axis label: "elapsed_training_time" -> "elapsed training time [s]", else "step" + x_label = "elapsed training time [s]" if "elapsed_training_time" in x_axis else "step" + legend_str = [] for i_run, (run_id, run_data) in enumerate(zip(runs_ids, runs_data, strict=False)): run_data_stage = run_data.train if stage == TRAIN else run_data.val - x_vals = np.array(run_data_stage["num_samples"]) + x_col = next(filter(lambda c: x_axis in c, run_data_stage.columns)) + x_vals = np.array(run_data_stage[x_col]) y_vals = np.array(run_data_stage["loss_avg_mean"]) mask = np.logical_and(~np.isnan(x_vals), ~np.isnan(y_vals)) @@ -347,7 +352,7 @@ def plot_loss_avg( plt.xscale("log") plt.title("average loss") plt.ylabel("loss") - plt.xlabel("step") + plt.xlabel(x_label) plt.tight_layout() _add_legend( legend_str, @@ -379,10 +384,9 @@ def plot_loss_per_stream( channels: list[str], forecast_steps: list[int], x_axis: str = "samples", - x_type: str = "step", + x_scale_log: bool = False, x_lim: list[float] | None = None, y_lim: list[float] | None = None, - x_scale_log: bool = False, legend_outside: bool = False, legend_font_size: str = "x-small", legend_num_columns: int = 3, @@ -408,9 +412,7 @@ def plot_loss_per_stream( errs : list list of errors to plot (e.g. mse, stddev) x_axis : str - x-axis strings used in the column names (options: "samples", "dtime") - x_type : str - x-axis type (options: "step", "reltime") + x-axis column name used for the x-axis (options: "samples", "elapsed_training_time") x_scale_log : bool whether to use log scale for x-axis """ @@ -525,7 +527,13 @@ def plot_loss_per_stream( title_loss = ".".join(title_col.split(".")[:-1]) plt.title(title_loss + " (" + ", ".join(modes) + ")") plt.ylabel(err) - plt.xlabel(x_axis if x_type == "step" else "rel. time [h]") + # x-axis label: "elapsed_training_time" -> friendly name, else use column as-is + x_label = ( + "elapsed training time [s]" + if "elapsed_training_time" in x_axis + else x_axis + ) + plt.xlabel(x_label) plt.tight_layout() _add_legend( legend_str, @@ -596,7 +604,7 @@ def plot_loss_per_run( errs : list list of errors to plot (e.g. mse, stddev) x_axis : str - x-axis strings used in the column names (options: "samples", "dtime") + x-axis column name used for the x-axis (options: "samples", "elapsed_training_time") x_scale_log : bool whether to use log scale for x-axis """ @@ -666,7 +674,10 @@ def plot_loss_per_run( plt.xscale("log") plt.grid(True, which="both", ls="-") plt.ylabel("loss") - plt.xlabel("samples") + x_label = ( + "elapsed training time [s]" if "elapsed_training_time" in x_axis else x_axis + ) + plt.xlabel(x_label) plt.tight_layout() _add_legend( legend_str, @@ -794,13 +805,13 @@ def plot_train(args=None): help="x-lim for per-stream plots", ) parser.add_argument( - "--x_type", + "--x-axis", "-x", - dest="x_type", - default="step", + dest="x_axis", + default="samples", type=str, - choices=["step", "reltime"], - help="Type of x-axis used in plots. Options: 'step' or 'reltime'", + choices=["samples", "elapsed_training_time"], + help="X-axis column for plots: 'samples' (default) or 'elapsed_training_time'", ) parser.add_argument( "--log-x", @@ -862,9 +873,7 @@ def plot_train(args=None): model_base_dir = Path(args.model_base_dir) if args.model_base_dir else None out_dir = Path(args.output_dir) streams = list(args.streams) - x_types_valid = ["step"] # TODO: add "reltime" support when fix available - if args.x_type not in x_types_valid: - raise ValueError(f"x_type must be one of {x_types_valid}, but got {args.x_type}") + x_axis = args.x_axis # Post-processing default logic for config from YAML-file if args.fd is None and args.fy is None: @@ -924,6 +933,7 @@ def plot_train(args=None): runs_data, runs_active, plot_dir=out_dir, + x_axis=x_axis, legend_outside=args.legend_outside, legend_font_size=args.legend_font_size, legend_num_columns=args.legend_num_columns, @@ -937,6 +947,7 @@ def plot_train(args=None): runs_data, runs_active, stage=TRAIN, + x_axis=x_axis, legend_outside=args.legend_outside, legend_font_size=args.legend_font_size, legend_num_columns=args.legend_num_columns, @@ -953,7 +964,7 @@ def plot_train(args=None): errs=args.metrics, channels=args.channels, forecast_steps=args.forecast_steps, - x_type=args.x_type, + x_axis=x_axis, x_scale_log=x_scale_log, x_lim=args.per_stream_x_lim, y_lim=args.per_stream_y_lim, @@ -972,7 +983,7 @@ def plot_train(args=None): errs=args.metrics, channels=args.channels, forecast_steps=args.forecast_steps, - x_type=args.x_type, + x_axis=x_axis, x_scale_log=x_scale_log, x_lim=args.per_stream_x_lim, y_lim=args.per_stream_y_lim, @@ -991,7 +1002,7 @@ def plot_train(args=None): errs=args.metrics, channels=args.channels, forecast_steps=args.forecast_steps, - x_type=args.x_type, + x_axis=x_axis, x_scale_log=x_scale_log, x_lim=args.per_stream_x_lim, y_lim=args.per_stream_y_lim, diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 53e4e551f..d7501a1cd 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -101,9 +101,10 @@ def add_logs( stddev_all: dict, avg_loss: list[float] = None, lr: float = None, + elapsed_training_time_seconds: float | None = None, ) -> None: """ - Log training or validation data + Log training or validation data. """ metrics: dict[str, float] = dict(num_samples=samples) @@ -112,6 +113,13 @@ def add_logs( metrics["loss_avg_mean"] = val metrics["learning_rate"] = lr metrics["num_samples"] = int(samples) + if elapsed_training_time_seconds is not None: + metrics["elapsed_training_time_seconds"] = elapsed_training_time_seconds + metrics["average_samples_per_second"] = ( + samples / elapsed_training_time_seconds + if elapsed_training_time_seconds > 0 + else 0 + ) for key, value in losses_all.items(): val = np.nan if np.isnan(value).all() else np.nanmean(value)