diff --git a/config/config_forecasting.yml b/config/config_forecasting.yml index eb8b0f59d..ca6dbedce 100644 --- a/config/config_forecasting.yml +++ b/config/config_forecasting.yml @@ -173,6 +173,16 @@ training_config: type: LossPhysical, loss_fcts: { "mse": { }, }, }, + "cosine_matching": { + type: LossLatent, + weight: 1.0, + loss_fcts: { + "params": { + cosine_low: 0.68, + cosine_high: 0.78, + }, + }, + }, } model_input: { diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index d8a30a722..6dc0eb978 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -39,6 +39,7 @@ ) from weathergen.model.layers import MLP, NamedLinear from weathergen.model.utils import get_num_parameters +from weathergen.train.loss_modules.utils import compute_cos_sim_to_prev from weathergen.utils.distributed import is_root from weathergen.utils.utils import get_dtype, is_stream_forcing @@ -325,6 +326,7 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.embed_target_coords = None self.encoder: EncoderModule | None = None self.forecast_engine: ForecastingEngine | IdentityEngine | None = None + self.compute_cos_sim_to_prev: bool = False self.pred_heads = None self.q_cells: torch.Tensor | None = None self.streams: dict[str, typing.Any] = cf.streams @@ -398,6 +400,10 @@ def create(self) -> "Model": v.type for _, v in cf.validation_config.losses.items() if v.get("enabled", True) ] + # cos_sim_to_prev is only needed by the latent loss; compute it + # (and keep prev_tokens around for it) only when that loss is configured. + self.compute_cos_sim_to_prev = "LossLatent" in loss_terms + if "LossPhysical" in loss_terms: for i_stream, (stream_name, si) in enumerate(self.streams.items()): # skip decoder if channels are empty @@ -699,8 +705,14 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # Pushforward mode: advance tokens without grad; no decoding with torch.no_grad(): tokens = self.forecast_engine(tokens, step, model_params.rope_coords) continue - + prev_tokens = tokens if self.compute_cos_sim_to_prev else None tokens = self.forecast_engine(tokens, step, model_params.rope_coords) + + # per-token cosine similarity between current and previous patch tokens + if self.compute_cos_sim_to_prev: + cos_sim_to_prev = compute_cos_sim_to_prev(tokens, prev_tokens, self.num_aux_tokens) + output.add_latent_prediction(step, "cos_sim_to_prev", cos_sim_to_prev) + # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) # latent predictions (raw and with SSL heads) diff --git a/src/weathergen/train/loss_modules/__init__.py b/src/weathergen/train/loss_modules/__init__.py index 00a8b7b31..12e02239a 100644 --- a/src/weathergen/train/loss_modules/__init__.py +++ b/src/weathergen/train/loss_modules/__init__.py @@ -7,7 +7,8 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +from .loss_module_cosine_matching import LossLatent from .loss_module_physical import LossPhysical from .loss_module_ssl import LossLatentSSLStudentTeacher -__all__ = [LossPhysical, LossLatentSSLStudentTeacher] +__all__ = [LossPhysical, LossLatentSSLStudentTeacher, LossLatent] diff --git a/src/weathergen/train/loss_modules/loss_module_cosine_matching.py b/src/weathergen/train/loss_modules/loss_module_cosine_matching.py new file mode 100644 index 000000000..f01357839 --- /dev/null +++ b/src/weathergen/train/loss_modules/loss_module_cosine_matching.py @@ -0,0 +1,64 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import torch +import torch.nn.functional as F +from omegaconf import DictConfig + +from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues +from weathergen.utils.train_logger import Stage + +_logger = logging.getLogger(__name__) + + +class LossLatent(LossModuleBase): + """ + Band hinge on per-token cosine similarity between consecutive FE latent steps. + + Penalises tokens whose cosine similarity to the previous step falls outside + [cosine_low, cosine_high]. Both bounds are enforced as soft hinge losses so + the FE is free inside the sweet-spot and pays a quadratic penalty outside it. + + cos_sim_to_prev is computed in model.forward() per patch token and stored in + output.latent[step]["cos_sim_to_prev"] — no target calculator needed. + """ + + def __init__( + self, cf: DictConfig, mode_cfg: DictConfig, stage: Stage, device: str, **loss_fcts + ): + LossModuleBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + self.name = "LossLatent" + + params = next(iter(loss_fcts.values()), {}) if loss_fcts else {} + self.cosine_low = params.get("cosine_low", 0.68) + self.cosine_high = params.get("cosine_high", 0.78) + + def compute_loss(self, preds, targets, metadata, **kwargs) -> LossValues: + acc_loss = torch.tensor(0.0, device=self.device, requires_grad=True) + count = 0 + + for step_pred in preds.latent: + cos_sim = step_pred.get("cos_sim_to_prev", None) + if cos_sim is None: + continue + step_loss = ( + F.relu(cos_sim - self.cosine_high) ** 2 + F.relu(self.cosine_low - cos_sim) ** 2 + ).mean() + acc_loss = acc_loss + step_loss + count += 1 + + loss = acc_loss / count if count > 0 else acc_loss + return LossValues( + loss=loss, losses_all={"cosine_band": loss.detach().item()}, stddev_all={} + ) diff --git a/src/weathergen/train/loss_modules/utils.py b/src/weathergen/train/loss_modules/utils.py new file mode 100644 index 000000000..12d59b324 --- /dev/null +++ b/src/weathergen/train/loss_modules/utils.py @@ -0,0 +1,34 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import torch +import torch.nn.functional as F + + +def compute_cos_sim_to_prev( + tokens: torch.Tensor, prev_tokens: torch.Tensor, num_aux_tokens: int +) -> torch.Tensor: + """ + Per-token cosine similarity between the current and previous patch tokens. + + Auxiliary tokens (register/class) are dropped and the remaining patch tokens are + flattened so each token contributes one similarity value. The previous tokens are + detached so the loss only flows through the current forecast step. + + Args: + tokens: Current latent tokens, shape ``(batch, num_tokens, dim)``. + prev_tokens: Previous-step latent tokens, same shape as ``tokens``. + num_aux_tokens: Number of leading auxiliary tokens to skip. + + Returns: + 1D tensor of cosine similarities, one per patch token. + """ + cur = tokens[:, num_aux_tokens:].reshape(-1, tokens.shape[-1]) + prv = prev_tokens[:, num_aux_tokens:].reshape(-1, prev_tokens.shape[-1]) + return F.cosine_similarity(cur, prv.detach(), dim=-1)