From cd019c1687f439f98700e2d1b7193ef0a86e5b01 Mon Sep 17 00:00:00 2001 From: ankitpatnala Date: Wed, 10 Jun 2026 23:09:24 +0200 Subject: [PATCH 1/5] added cosine matching loss --- config/config_forecasting.yml | 11 ++++ src/weathergen/model/model.py | 10 ++- src/weathergen/train/loss_modules/__init__.py | 3 +- .../loss_module_cosine_matching.py | 62 +++++++++++++++++++ 4 files changed, 84 insertions(+), 2 deletions(-) create mode 100644 src/weathergen/train/loss_modules/loss_module_cosine_matching.py diff --git a/config/config_forecasting.yml b/config/config_forecasting.yml index eb8b0f59d..de94c7bcc 100644 --- a/config/config_forecasting.yml +++ b/config/config_forecasting.yml @@ -173,6 +173,17 @@ training_config: type: LossPhysical, loss_fcts: { "mse": { }, }, }, + "cosine_matching": { + type: LossLatentCosineMatching, + weight: 1.0, + target_and_aux_calc: "Physical", + loss_fcts: { + "params": { + cosine_low: 0.68, + cosine_high: 0.78, + }, + }, + }, } model_input: { diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index d8a30a722..b5da13cfe 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -697,10 +697,18 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: without_grad = p_fwd and self.training and step != max(batch.get_output_idxs()) if without_grad: # Pushforward mode: advance tokens without grad; no decoding with torch.no_grad(): + prev_tokens = tokens tokens = self.forecast_engine(tokens, step, model_params.rope_coords) continue - + prev_tokens = tokens tokens = self.forecast_engine(tokens, step, model_params.rope_coords) + + # per-token cosine similarity between current and previous patch tokens + cur = tokens[:, self.num_aux_tokens:].reshape(-1, tokens.shape[-1]) + prv = prev_tokens[:, self.num_aux_tokens:].reshape(-1, tokens.shape[-1]) + cos_sim_to_prev = torch.nn.functional.cosine_similarity(cur, prv.detach(), dim=-1) + output.add_latent_prediction(step, "cos_sim_to_prev", cos_sim_to_prev) + # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) # latent predictions (raw and with SSL heads) diff --git a/src/weathergen/train/loss_modules/__init__.py b/src/weathergen/train/loss_modules/__init__.py index 00a8b7b31..5763eefe5 100644 --- a/src/weathergen/train/loss_modules/__init__.py +++ b/src/weathergen/train/loss_modules/__init__.py @@ -9,5 +9,6 @@ from .loss_module_physical import LossPhysical from .loss_module_ssl import LossLatentSSLStudentTeacher +from .loss_module_cosine_matching import LossLatentCosineMatching -__all__ = [LossPhysical, LossLatentSSLStudentTeacher] +__all__ = [LossPhysical, LossLatentSSLStudentTeacher, LossLatentCosineMatching] diff --git a/src/weathergen/train/loss_modules/loss_module_cosine_matching.py b/src/weathergen/train/loss_modules/loss_module_cosine_matching.py new file mode 100644 index 000000000..56741479c --- /dev/null +++ b/src/weathergen/train/loss_modules/loss_module_cosine_matching.py @@ -0,0 +1,62 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import torch +import torch.nn.functional as F +from omegaconf import DictConfig + +from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues +from weathergen.utils.train_logger import Stage + +_logger = logging.getLogger(__name__) + + +class LossLatentCosineMatching(LossModuleBase): + """ + Band hinge on per-token cosine similarity between consecutive FE latent steps. + + Penalises tokens whose cosine similarity to the previous step falls outside + [cosine_low, cosine_high]. Both bounds are enforced as soft hinge losses so + the FE is free inside the sweet-spot and pays a quadratic penalty outside it. + + cos_sim_to_prev is computed in model.forward() per patch token and stored in + output.latent[step]["cos_sim_to_prev"] — no target calculator needed. + """ + + def __init__(self, cf: DictConfig, mode_cfg: DictConfig, stage: Stage, device: str, **loss_fcts): + LossModuleBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + self.name = "LossLatentCosineMatching" + + params = next(iter(loss_fcts.values()), {}) if loss_fcts else {} + self.cosine_low = params.get("cosine_low", 0.68) + self.cosine_high = params.get("cosine_high", 0.78) + + def compute_loss(self, preds, targets, metadata, **kwargs) -> LossValues: + acc_loss = torch.tensor(0.0, device=self.device, requires_grad=True) + count = 0 + + + for step_pred in preds.latent: + cos_sim = step_pred.get("cos_sim_to_prev", None) + if cos_sim is None: + continue + step_loss = ( + F.relu(cos_sim - self.cosine_high) ** 2 + + F.relu(self.cosine_low - cos_sim) ** 2 + ).mean() + acc_loss = acc_loss + step_loss + count += 1 + + loss = acc_loss / count if count > 0 else acc_loss + return LossValues(loss=loss, losses_all={"cosine_band": loss.detach().item()}, stddev_all={}) From 03c71069b9f34acd4a8a91920b60d22e4944a0f6 Mon Sep 17 00:00:00 2001 From: ankitpatnala Date: Wed, 10 Jun 2026 23:15:29 +0200 Subject: [PATCH 2/5] lint check --- src/weathergen/model/model.py | 6 +++--- src/weathergen/train/loss_modules/__init__.py | 2 +- .../loss_modules/loss_module_cosine_matching.py | 12 +++++++----- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index b5da13cfe..cca0441eb 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -702,10 +702,10 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: continue prev_tokens = tokens tokens = self.forecast_engine(tokens, step, model_params.rope_coords) - + # per-token cosine similarity between current and previous patch tokens - cur = tokens[:, self.num_aux_tokens:].reshape(-1, tokens.shape[-1]) - prv = prev_tokens[:, self.num_aux_tokens:].reshape(-1, tokens.shape[-1]) + cur = tokens[:, self.num_aux_tokens :].reshape(-1, tokens.shape[-1]) + prv = prev_tokens[:, self.num_aux_tokens :].reshape(-1, tokens.shape[-1]) cos_sim_to_prev = torch.nn.functional.cosine_similarity(cur, prv.detach(), dim=-1) output.add_latent_prediction(step, "cos_sim_to_prev", cos_sim_to_prev) diff --git a/src/weathergen/train/loss_modules/__init__.py b/src/weathergen/train/loss_modules/__init__.py index 5763eefe5..f413688f3 100644 --- a/src/weathergen/train/loss_modules/__init__.py +++ b/src/weathergen/train/loss_modules/__init__.py @@ -7,8 +7,8 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +from .loss_module_cosine_matching import LossLatentCosineMatching from .loss_module_physical import LossPhysical from .loss_module_ssl import LossLatentSSLStudentTeacher -from .loss_module_cosine_matching import LossLatentCosineMatching __all__ = [LossPhysical, LossLatentSSLStudentTeacher, LossLatentCosineMatching] diff --git a/src/weathergen/train/loss_modules/loss_module_cosine_matching.py b/src/weathergen/train/loss_modules/loss_module_cosine_matching.py index 56741479c..4d93867e7 100644 --- a/src/weathergen/train/loss_modules/loss_module_cosine_matching.py +++ b/src/weathergen/train/loss_modules/loss_module_cosine_matching.py @@ -31,7 +31,9 @@ class LossLatentCosineMatching(LossModuleBase): output.latent[step]["cos_sim_to_prev"] — no target calculator needed. """ - def __init__(self, cf: DictConfig, mode_cfg: DictConfig, stage: Stage, device: str, **loss_fcts): + def __init__( + self, cf: DictConfig, mode_cfg: DictConfig, stage: Stage, device: str, **loss_fcts + ): LossModuleBase.__init__(self) self.cf = cf self.stage = stage @@ -46,17 +48,17 @@ def compute_loss(self, preds, targets, metadata, **kwargs) -> LossValues: acc_loss = torch.tensor(0.0, device=self.device, requires_grad=True) count = 0 - for step_pred in preds.latent: cos_sim = step_pred.get("cos_sim_to_prev", None) if cos_sim is None: continue step_loss = ( - F.relu(cos_sim - self.cosine_high) ** 2 - + F.relu(self.cosine_low - cos_sim) ** 2 + F.relu(cos_sim - self.cosine_high) ** 2 + F.relu(self.cosine_low - cos_sim) ** 2 ).mean() acc_loss = acc_loss + step_loss count += 1 loss = acc_loss / count if count > 0 else acc_loss - return LossValues(loss=loss, losses_all={"cosine_band": loss.detach().item()}, stddev_all={}) + return LossValues( + loss=loss, losses_all={"cosine_band": loss.detach().item()}, stddev_all={} + ) From 50e5e321e3e0d27e01dc3dd4943bec0ea5d2d41c Mon Sep 17 00:00:00 2001 From: ankitpatnala Date: Tue, 16 Jun 2026 14:53:23 +0200 Subject: [PATCH 3/5] moved cosine calculation to loss_modules/utils.py structured the model forward function renmaed LossLatentCosineMatching to LossLatent removed target_aux_caluclator from the config --- config/config_forecasting.yml | 3 +- src/weathergen/model/model.py | 17 ++++++---- src/weathergen/train/loss_modules/__init__.py | 4 +-- .../loss_module_cosine_matching.py | 4 +-- src/weathergen/train/loss_modules/utils.py | 34 +++++++++++++++++++ 5 files changed, 50 insertions(+), 12 deletions(-) create mode 100644 src/weathergen/train/loss_modules/utils.py diff --git a/config/config_forecasting.yml b/config/config_forecasting.yml index de94c7bcc..ca6dbedce 100644 --- a/config/config_forecasting.yml +++ b/config/config_forecasting.yml @@ -174,9 +174,8 @@ training_config: loss_fcts: { "mse": { }, }, }, "cosine_matching": { - type: LossLatentCosineMatching, + type: LossLatent, weight: 1.0, - target_and_aux_calc: "Physical", loss_fcts: { "params": { cosine_low: 0.68, diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index cca0441eb..05dbd6092 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -39,6 +39,7 @@ ) from weathergen.model.layers import MLP, NamedLinear from weathergen.model.utils import get_num_parameters +from weathergen.train.loss_modules.utils import compute_cos_sim_to_prev from weathergen.utils.distributed import is_root from weathergen.utils.utils import get_dtype, is_stream_forcing @@ -398,6 +399,10 @@ def create(self) -> "Model": v.type for _, v in cf.validation_config.losses.items() if v.get("enabled", True) ] + # cos_sim_to_prev is only needed by the latent loss; compute it + # (and keep prev_tokens around for it) only when that loss is configured. + self.compute_cos_sim_to_prev = "LossLatent" in loss_terms + if "LossPhysical" in loss_terms: for i_stream, (stream_name, si) in enumerate(self.streams.items()): # skip decoder if channels are empty @@ -697,17 +702,17 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: without_grad = p_fwd and self.training and step != max(batch.get_output_idxs()) if without_grad: # Pushforward mode: advance tokens without grad; no decoding with torch.no_grad(): - prev_tokens = tokens tokens = self.forecast_engine(tokens, step, model_params.rope_coords) continue - prev_tokens = tokens + prev_tokens = tokens if self.compute_cos_sim_to_prev else None tokens = self.forecast_engine(tokens, step, model_params.rope_coords) # per-token cosine similarity between current and previous patch tokens - cur = tokens[:, self.num_aux_tokens :].reshape(-1, tokens.shape[-1]) - prv = prev_tokens[:, self.num_aux_tokens :].reshape(-1, tokens.shape[-1]) - cos_sim_to_prev = torch.nn.functional.cosine_similarity(cur, prv.detach(), dim=-1) - output.add_latent_prediction(step, "cos_sim_to_prev", cos_sim_to_prev) + if self.compute_cos_sim_to_prev: + cos_sim_to_prev = compute_cos_sim_to_prev( + tokens, prev_tokens, self.num_aux_tokens + ) + output.add_latent_prediction(step, "cos_sim_to_prev", cos_sim_to_prev) # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) diff --git a/src/weathergen/train/loss_modules/__init__.py b/src/weathergen/train/loss_modules/__init__.py index f413688f3..f5b4be31f 100644 --- a/src/weathergen/train/loss_modules/__init__.py +++ b/src/weathergen/train/loss_modules/__init__.py @@ -7,8 +7,8 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -from .loss_module_cosine_matching import LossLatentCosineMatching +from .loss_module_cosine_matching import LossLatent from .loss_module_physical import LossPhysical from .loss_module_ssl import LossLatentSSLStudentTeacher -__all__ = [LossPhysical, LossLatentSSLStudentTeacher, LossLatentCosineMatching] +__all__ = [LossPhysical, LossLatentSSLStudentTeacher, LossLatent] \ No newline at end of file diff --git a/src/weathergen/train/loss_modules/loss_module_cosine_matching.py b/src/weathergen/train/loss_modules/loss_module_cosine_matching.py index 4d93867e7..f01357839 100644 --- a/src/weathergen/train/loss_modules/loss_module_cosine_matching.py +++ b/src/weathergen/train/loss_modules/loss_module_cosine_matching.py @@ -19,7 +19,7 @@ _logger = logging.getLogger(__name__) -class LossLatentCosineMatching(LossModuleBase): +class LossLatent(LossModuleBase): """ Band hinge on per-token cosine similarity between consecutive FE latent steps. @@ -38,7 +38,7 @@ def __init__( self.cf = cf self.stage = stage self.device = device - self.name = "LossLatentCosineMatching" + self.name = "LossLatent" params = next(iter(loss_fcts.values()), {}) if loss_fcts else {} self.cosine_low = params.get("cosine_low", 0.68) diff --git a/src/weathergen/train/loss_modules/utils.py b/src/weathergen/train/loss_modules/utils.py new file mode 100644 index 000000000..12d59b324 --- /dev/null +++ b/src/weathergen/train/loss_modules/utils.py @@ -0,0 +1,34 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import torch +import torch.nn.functional as F + + +def compute_cos_sim_to_prev( + tokens: torch.Tensor, prev_tokens: torch.Tensor, num_aux_tokens: int +) -> torch.Tensor: + """ + Per-token cosine similarity between the current and previous patch tokens. + + Auxiliary tokens (register/class) are dropped and the remaining patch tokens are + flattened so each token contributes one similarity value. The previous tokens are + detached so the loss only flows through the current forecast step. + + Args: + tokens: Current latent tokens, shape ``(batch, num_tokens, dim)``. + prev_tokens: Previous-step latent tokens, same shape as ``tokens``. + num_aux_tokens: Number of leading auxiliary tokens to skip. + + Returns: + 1D tensor of cosine similarities, one per patch token. + """ + cur = tokens[:, num_aux_tokens:].reshape(-1, tokens.shape[-1]) + prv = prev_tokens[:, num_aux_tokens:].reshape(-1, prev_tokens.shape[-1]) + return F.cosine_similarity(cur, prv.detach(), dim=-1) From c94c1128967ce5e240b08971dad8a1d105ee9251 Mon Sep 17 00:00:00 2001 From: ankitpatnala Date: Tue, 16 Jun 2026 15:06:06 +0200 Subject: [PATCH 4/5] lint formatted --- src/weathergen/model/model.py | 4 +--- src/weathergen/train/loss_modules/__init__.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 05dbd6092..0e15de9b4 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -709,9 +709,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # per-token cosine similarity between current and previous patch tokens if self.compute_cos_sim_to_prev: - cos_sim_to_prev = compute_cos_sim_to_prev( - tokens, prev_tokens, self.num_aux_tokens - ) + cos_sim_to_prev = compute_cos_sim_to_prev(tokens, prev_tokens, self.num_aux_tokens) output.add_latent_prediction(step, "cos_sim_to_prev", cos_sim_to_prev) # decoder predictions diff --git a/src/weathergen/train/loss_modules/__init__.py b/src/weathergen/train/loss_modules/__init__.py index f5b4be31f..12e02239a 100644 --- a/src/weathergen/train/loss_modules/__init__.py +++ b/src/weathergen/train/loss_modules/__init__.py @@ -11,4 +11,4 @@ from .loss_module_physical import LossPhysical from .loss_module_ssl import LossLatentSSLStudentTeacher -__all__ = [LossPhysical, LossLatentSSLStudentTeacher, LossLatent] \ No newline at end of file +__all__ = [LossPhysical, LossLatentSSLStudentTeacher, LossLatent] From 63792342adcc75c4b1451ada4452c9e124d7e98b Mon Sep 17 00:00:00 2001 From: ankitpatnala Date: Tue, 16 Jun 2026 15:19:29 +0200 Subject: [PATCH 5/5] lint check --- src/weathergen/model/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 0e15de9b4..6dc0eb978 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -326,6 +326,7 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.embed_target_coords = None self.encoder: EncoderModule | None = None self.forecast_engine: ForecastingEngine | IdentityEngine | None = None + self.compute_cos_sim_to_prev: bool = False self.pred_heads = None self.q_cells: torch.Tensor | None = None self.streams: dict[str, typing.Any] = cf.streams