Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
88 commits
Select commit Hold shift + click to select a range
644273a
feat: add timing metrics (startup, training, overall)
florianscheidl Apr 17, 2026
92120b8
docs: add agent structure with skills, tasks, and docs
florianscheidl Apr 17, 2026
df69dcb
docs: add skills review cycle for periodic compactification
florianscheidl Apr 17, 2026
7f0648f
configs
florianscheidl Apr 17, 2026
8fe45a0
Merge branch 'feature/timing-metrics' into ekfs/scaling-plots-20260417
florianscheidl Apr 17, 2026
dd55fb0
Remove hermes tool tracking for now
florianscheidl Apr 17, 2026
09b6e82
Try duration metrics
florianscheidl Apr 17, 2026
da3c29b
Update metrics, store after each mini-epoch
florianscheidl Apr 17, 2026
fc9a111
Refactor configs/streams
florianscheidl Apr 20, 2026
cfc4c62
Extract scaling data
florianscheidl Apr 20, 2026
82b503a
Script to generate scaling plots
florianscheidl Apr 20, 2026
70053b1
Script update
florianscheidl Apr 20, 2026
0c2df97
Repeat data in mini epoch
florianscheidl Apr 20, 2026
2c79d28
corrected time window length
florianscheidl Apr 20, 2026
6374986
Merge branch 'ekfs/scaling-plots-20260417' of github.com:florianschei…
florianscheidl Apr 20, 2026
b5d70f6
Lower to 512 samples per mini epoch
florianscheidl Apr 20, 2026
f46828c
Updated extraction script
florianscheidl Apr 20, 2026
89ac519
Merge branch 'ekfs/scaling-plots-20260417' of github.com:florianschei…
florianscheidl Apr 20, 2026
7cad6b5
Log time more often
florianscheidl Apr 21, 2026
30ac102
Fix training start scope
florianscheidl Apr 21, 2026
5e7f63e
Minimal validation
florianscheidl Apr 22, 2026
2be95c6
Increase samples_per_mini_epoch to 1024
florianscheidl Apr 22, 2026
93b203b
Final training duration and terminal/metric logging
florianscheidl Apr 23, 2026
2b708e3
log metrics after mini-epoch
florianscheidl Apr 23, 2026
0d8407d
Log metrics after mini-epoch, change schema
florianscheidl Apr 23, 2026
422fc60
MEtric typo
florianscheidl Apr 23, 2026
f63cba9
Logging refactor
florianscheidl Apr 23, 2026
b596c14
Update extraction script
florianscheidl Apr 23, 2026
42ba646
NNode extraction
florianscheidl Apr 23, 2026
c9fa64d
Logs path
florianscheidl Apr 23, 2026
ccfbc64
Wait until all training complete and wait with validation until logs …
florianscheidl Apr 23, 2026
c0f96b7
Log seconds rather than hours
florianscheidl Apr 23, 2026
701eb00
Merge branch 'ekfs/scaling-plots-20260417' of github.com:florianschei…
florianscheidl Apr 23, 2026
e6475e9
Measure dataset advancement time
florianscheidl Apr 23, 2026
6fd001f
LR scheduler lower bounds
florianscheidl Apr 23, 2026
313cec6
At least two warmup steps
florianscheidl Apr 23, 2026
aa4d399
Len per rank at least 1 to avoid zero division error
florianscheidl Apr 24, 2026
7956c52
Write csv for easier viewing
florianscheidl Apr 24, 2026
cf659e1
Extraction and plotting
florianscheidl Apr 24, 2026
177df79
Remove parent dir creation
florianscheidl Apr 24, 2026
8a4bc56
more detailed extraction script
florianscheidl Apr 24, 2026
bca6d3d
Remove overall time logging
florianscheidl Apr 24, 2026
9436811
Cleanup trainer
florianscheidl Apr 24, 2026
21c1575
Metrics extraction and plot generation scripts
florianscheidl Apr 24, 2026
e67616a
Add efficiency factor in plot
florianscheidl Apr 24, 2026
b1e4ea4
RM checkpoint and log metrics at last iteration
florianscheidl Apr 27, 2026
c89fe20
Detailed metrics
florianscheidl Apr 27, 2026
c14b749
Merge branch 'ekfs/scaling-plots-20260417' of github.com:florianschei…
florianscheidl Apr 27, 2026
b0bc6c2
Remove barrier and extra logging on last batch
florianscheidl Apr 27, 2026
133ee4c
trainer code cleanup
florianscheidl Apr 27, 2026
ec665da
Lower bound beta2 in adam
florianscheidl Apr 27, 2026
2088311
update script for scaling plots, loss as separate entry point
florianscheidl Apr 28, 2026
5bd88d9
specify nodes in scaling data script
florianscheidl Apr 29, 2026
7e1ae1c
Update extract scaling data
florianscheidl Apr 29, 2026
b42432c
Add pyarrow
florianscheidl Apr 30, 2026
6d5683b
Update script for scaling plots
florianscheidl Apr 30, 2026
0396290
Update to generating scaling plots
florianscheidl May 4, 2026
86093d7
Merge branch 'develop' into ekfs/scaling-plots-20260417
florianscheidl May 4, 2026
6800262
Move scaling scripts to package
florianscheidl May 4, 2026
c5af276
init refactor
florianscheidl May 4, 2026
e760c13
Setup and linting
florianscheidl May 4, 2026
556106e
Updated plot generation script
florianscheidl May 4, 2026
b93813d
Update readme
florianscheidl May 4, 2026
b2fe866
Fewer diffs
florianscheidl May 4, 2026
c574514
no gitignore changes
florianscheidl May 4, 2026
dad5462
Refactor logging and move time for mini epoch logging outside loop
florianscheidl May 4, 2026
4f11519
Formatting and style fixes
florianscheidl May 4, 2026
b02b38f
Update config
florianscheidl May 4, 2026
55d8219
Avoid duplicate metrics
florianscheidl May 4, 2026
904713d
Fix lint issues
florianscheidl May 4, 2026
9ecd544
t_training in __init__
florianscheidl May 4, 2026
9f02dc1
Renamed metric
florianscheidl May 8, 2026
bfd5424
Merge branch 'develop' into ekfs/scaling-plots-20260417
clessig Jun 8, 2026
0785e3b
mv performance package
florianscheidl Jun 12, 2026
e982776
Plot losses against elapsed training time via --x-axis flag
florianscheidl Jun 12, 2026
367454d
Fewer changes
florianscheidl Jun 12, 2026
d0f851c
rm configs
florianscheidl Jun 12, 2026
2f25ee8
Remove startup time
florianscheidl Jun 12, 2026
198b542
Remove startup time
florianscheidl Jun 12, 2026
6f7ff39
Merge branch 'develop' into flo/plot-training-time-axis
florianscheidl Jun 12, 2026
0de8660
Formatting and removed time per epoch
florianscheidl Jun 12, 2026
11bcd2b
Merge branch 'flo/plot-training-time-axis' into ekfs/scaling-plots-20…
florianscheidl Jun 12, 2026
39f8075
Undo pyproject change
florianscheidl Jun 12, 2026
41b1fbf
Merge branch 'develop' into ekfs/scaling-plots-20260417
florianscheidl Jun 12, 2026
61ad9cf
ploting changes wip
florianscheidl Jun 12, 2026
e099f84
Undo pyproject changes
florianscheidl Jun 12, 2026
a45c255
Merge branch 'ekfs/scaling-plots-20260417' of github.com:florianschei…
florianscheidl Jun 12, 2026
68c0428
Merge branch 'ekfs/scaling-plots-20260417' into flo/plot-training-tim…
florianscheidl Jun 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/weathergen/train/lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def __init__(
logger.debug(f"steps_decay={self.n_steps_decay} lr_steps={lr_steps}")
# ensure that steps_decay has a reasonable value
if self.n_steps_decay < int(0.2 * lr_steps):
self.n_steps_warmup = int(0.1 * lr_steps)
self.n_steps_cooldown = int(0.05 * lr_steps)
self.n_steps_decay = lr_steps - self.n_steps_warmup - self.n_steps_cooldown
self.n_steps_warmup = max(2, int(0.1 * lr_steps))
self.n_steps_cooldown = max(1, int(0.05 * lr_steps))
self.n_steps_decay = max(1, lr_steps - self.n_steps_warmup - self.n_steps_cooldown)
s = (
"cf.lr_steps_warmup and cf.lr_steps_cooldown",
f" were larger than cf.lr_steps={lr_steps}",
Expand Down
8 changes: 6 additions & 2 deletions src/weathergen/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(self, train_logging: Config):
self.batch_size_test_per_gpu = -1
self.collapse_monitor: CollapseMonitor | None = None
self.perf_tracker: ThroughputTracker | NullThroughputTracker = NullThroughputTracker()
self.t_training_start: float = 0

def get_batch_size_total(self, batch_size_per_gpu) -> int:
"""
Expand Down Expand Up @@ -321,7 +322,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None):
# aiming for beta1 = 0.9 at one node, ie kappa=B=4
beta1 = max(0.5, 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta1))
# aiming for beta2 = 0.95 at one node, ie B=4
beta2 = 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta2)
beta2 = max(0.9, 1.0 - kappa * (1.0 - self.training_cfg.optimizer.adamw.beta2))
eps = self.training_cfg.optimizer.adamw.get("eps", 2e-08) / np.sqrt(kappa)

self.optimizer = torch.optim.AdamW(
Expand Down Expand Up @@ -361,7 +362,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None):
mini_epoch_base = int(self.cf.general.istep / len(self.data_loader))
else:
len_per_rank = (
len(self.dataset) // (self.world_size_original * self.batch_size_per_gpu)
max(1, len(self.dataset) // (self.world_size_original * self.batch_size_per_gpu))
) * self.batch_size_per_gpu
mini_epoch_base = int(
self.cf.general.istep
Expand All @@ -379,6 +380,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None):
self.validate_before_training()

# training loop
self.t_training_start = time.time()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For plot_train, timing should start here. This would also avoid that run_train is modified.


for mini_epoch in range(mini_epoch_base, self.training_cfg.num_mini_epochs):
if is_root():
Expand Down Expand Up @@ -747,13 +749,15 @@ def _log(self, stage: Stage):
self.train_logger.add_logs(stage, samples, losses_all, stddev_all)

elif self.cf.general.istep >= 0:
elapsed_time = time.time() - self.t_training_start
self.train_logger.add_logs(
stage,
samples,
losses_all,
stddev_all,
avg_loss=avg_loss,
lr=self.lr_scheduler.get_lr(),
elapsed_training_time_seconds=elapsed_time,
)

loss_calculator.loss_hist = []
Expand Down
53 changes: 32 additions & 21 deletions src/weathergen/utils/plot_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ def plot_loss_avg(
runs_data,
runs_active,
stage=TRAIN,
x_axis: str = "samples",
x_scale_log=False,
legend_outside: bool = False,
legend_font_size: str = "x-small",
Expand All @@ -322,10 +323,14 @@ def plot_loss_avg(

_fig = plt.figure(figsize=(10, 7), dpi=PLOT_DPI_VALUE)

# x-axis label: "elapsed_training_time" -> "elapsed training time [s]", else "step"
x_label = "elapsed training time [s]" if "elapsed_training_time" in x_axis else "step"

legend_str = []
for i_run, (run_id, run_data) in enumerate(zip(runs_ids, runs_data, strict=False)):
run_data_stage = run_data.train if stage == TRAIN else run_data.val
x_vals = np.array(run_data_stage["num_samples"])
x_col = next(filter(lambda c: x_axis in c, run_data_stage.columns))
x_vals = np.array(run_data_stage[x_col])
y_vals = np.array(run_data_stage["loss_avg_mean"])

mask = np.logical_and(~np.isnan(x_vals), ~np.isnan(y_vals))
Expand All @@ -347,7 +352,7 @@ def plot_loss_avg(
plt.xscale("log")
plt.title("average loss")
plt.ylabel("loss")
plt.xlabel("step")
plt.xlabel(x_label)
plt.tight_layout()
_add_legend(
legend_str,
Expand Down Expand Up @@ -379,10 +384,9 @@ def plot_loss_per_stream(
channels: list[str],
forecast_steps: list[int],
x_axis: str = "samples",
x_type: str = "step",
x_scale_log: bool = False,
x_lim: list[float] | None = None,
y_lim: list[float] | None = None,
x_scale_log: bool = False,
legend_outside: bool = False,
legend_font_size: str = "x-small",
legend_num_columns: int = 3,
Expand All @@ -408,9 +412,7 @@ def plot_loss_per_stream(
errs : list
list of errors to plot (e.g. mse, stddev)
x_axis : str
x-axis strings used in the column names (options: "samples", "dtime")
x_type : str
x-axis type (options: "step", "reltime")
x-axis column name used for the x-axis (options: "samples", "elapsed_training_time")
x_scale_log : bool
whether to use log scale for x-axis
"""
Expand Down Expand Up @@ -525,7 +527,13 @@ def plot_loss_per_stream(
title_loss = ".".join(title_col.split(".")[:-1])
plt.title(title_loss + " (" + ", ".join(modes) + ")")
plt.ylabel(err)
plt.xlabel(x_axis if x_type == "step" else "rel. time [h]")
# x-axis label: "elapsed_training_time" -> friendly name, else use column as-is
x_label = (
"elapsed training time [s]"
if "elapsed_training_time" in x_axis
else x_axis
)
plt.xlabel(x_label)
plt.tight_layout()
_add_legend(
legend_str,
Expand Down Expand Up @@ -596,7 +604,7 @@ def plot_loss_per_run(
errs : list
list of errors to plot (e.g. mse, stddev)
x_axis : str
x-axis strings used in the column names (options: "samples", "dtime")
x-axis column name used for the x-axis (options: "samples", "elapsed_training_time")
x_scale_log : bool
whether to use log scale for x-axis
"""
Expand Down Expand Up @@ -666,7 +674,10 @@ def plot_loss_per_run(
plt.xscale("log")
plt.grid(True, which="both", ls="-")
plt.ylabel("loss")
plt.xlabel("samples")
x_label = (
"elapsed training time [s]" if "elapsed_training_time" in x_axis else x_axis
)
plt.xlabel(x_label)
plt.tight_layout()
_add_legend(
legend_str,
Expand Down Expand Up @@ -794,13 +805,13 @@ def plot_train(args=None):
help="x-lim for per-stream plots",
)
parser.add_argument(
"--x_type",
"--x-axis",
"-x",
dest="x_type",
default="step",
dest="x_axis",
default="samples",
type=str,
choices=["step", "reltime"],
help="Type of x-axis used in plots. Options: 'step' or 'reltime'",
choices=["samples", "elapsed_training_time"],
help="X-axis column for plots: 'samples' (default) or 'elapsed_training_time'",
)
parser.add_argument(
"--log-x",
Expand Down Expand Up @@ -862,9 +873,7 @@ def plot_train(args=None):
model_base_dir = Path(args.model_base_dir) if args.model_base_dir else None
out_dir = Path(args.output_dir)
streams = list(args.streams)
x_types_valid = ["step"] # TODO: add "reltime" support when fix available
if args.x_type not in x_types_valid:
raise ValueError(f"x_type must be one of {x_types_valid}, but got {args.x_type}")
x_axis = args.x_axis

# Post-processing default logic for config from YAML-file
if args.fd is None and args.fy is None:
Expand Down Expand Up @@ -924,6 +933,7 @@ def plot_train(args=None):
runs_data,
runs_active,
plot_dir=out_dir,
x_axis=x_axis,
legend_outside=args.legend_outside,
legend_font_size=args.legend_font_size,
legend_num_columns=args.legend_num_columns,
Expand All @@ -937,6 +947,7 @@ def plot_train(args=None):
runs_data,
runs_active,
stage=TRAIN,
x_axis=x_axis,
legend_outside=args.legend_outside,
legend_font_size=args.legend_font_size,
legend_num_columns=args.legend_num_columns,
Expand All @@ -953,7 +964,7 @@ def plot_train(args=None):
errs=args.metrics,
channels=args.channels,
forecast_steps=args.forecast_steps,
x_type=args.x_type,
x_axis=x_axis,
x_scale_log=x_scale_log,
x_lim=args.per_stream_x_lim,
y_lim=args.per_stream_y_lim,
Expand All @@ -972,7 +983,7 @@ def plot_train(args=None):
errs=args.metrics,
channels=args.channels,
forecast_steps=args.forecast_steps,
x_type=args.x_type,
x_axis=x_axis,
x_scale_log=x_scale_log,
x_lim=args.per_stream_x_lim,
y_lim=args.per_stream_y_lim,
Expand All @@ -991,7 +1002,7 @@ def plot_train(args=None):
errs=args.metrics,
channels=args.channels,
forecast_steps=args.forecast_steps,
x_type=args.x_type,
x_axis=x_axis,
x_scale_log=x_scale_log,
x_lim=args.per_stream_x_lim,
y_lim=args.per_stream_y_lim,
Expand Down
10 changes: 9 additions & 1 deletion src/weathergen/utils/train_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,10 @@ def add_logs(
stddev_all: dict,
avg_loss: list[float] = None,
lr: float = None,
elapsed_training_time_seconds: float | None = None,
) -> None:
"""
Log training or validation data
Log training or validation data.
"""
metrics: dict[str, float] = dict(num_samples=samples)

Expand All @@ -112,6 +113,13 @@ def add_logs(
metrics["loss_avg_mean"] = val
metrics["learning_rate"] = lr
metrics["num_samples"] = int(samples)
if elapsed_training_time_seconds is not None:
metrics["elapsed_training_time_seconds"] = elapsed_training_time_seconds
metrics["average_samples_per_second"] = (
samples / elapsed_training_time_seconds
if elapsed_training_time_seconds > 0
else 0
)

for key, value in losses_all.items():
val = np.nan if np.isnan(value).all() else np.nanmean(value)
Expand Down
Loading