Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions config/config_forecasting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,16 @@ training_config:
type: LossPhysical,
loss_fcts: { "mse": { }, },
},
"cosine_matching": {
type: LossLatent,
weight: 1.0,
loss_fcts: {
"params": {
cosine_low: 0.68,
cosine_high: 0.78,
},
},
},
}

model_input: {
Expand Down
14 changes: 13 additions & 1 deletion src/weathergen/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from weathergen.model.layers import MLP, NamedLinear
from weathergen.model.utils import get_num_parameters
from weathergen.train.loss_modules.utils import compute_cos_sim_to_prev
from weathergen.utils.distributed import is_root
from weathergen.utils.utils import get_dtype, is_stream_forcing

Expand Down Expand Up @@ -325,6 +326,7 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord
self.embed_target_coords = None
self.encoder: EncoderModule | None = None
self.forecast_engine: ForecastingEngine | IdentityEngine | None = None
self.compute_cos_sim_to_prev: bool = False
self.pred_heads = None
self.q_cells: torch.Tensor | None = None
self.streams: dict[str, typing.Any] = cf.streams
Expand Down Expand Up @@ -398,6 +400,10 @@ def create(self) -> "Model":
v.type for _, v in cf.validation_config.losses.items() if v.get("enabled", True)
]

# cos_sim_to_prev is only needed by the latent loss; compute it
# (and keep prev_tokens around for it) only when that loss is configured.
self.compute_cos_sim_to_prev = "LossLatent" in loss_terms

if "LossPhysical" in loss_terms:
for i_stream, (stream_name, si) in enumerate(self.streams.items()):
# skip decoder if channels are empty
Expand Down Expand Up @@ -699,8 +705,14 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput:
# Pushforward mode: advance tokens without grad; no decoding with torch.no_grad():
tokens = self.forecast_engine(tokens, step, model_params.rope_coords)
continue

prev_tokens = tokens if self.compute_cos_sim_to_prev else None
tokens = self.forecast_engine(tokens, step, model_params.rope_coords)

# per-token cosine similarity between current and previous patch tokens
if self.compute_cos_sim_to_prev:
cos_sim_to_prev = compute_cos_sim_to_prev(tokens, prev_tokens, self.num_aux_tokens)
output.add_latent_prediction(step, "cos_sim_to_prev", cos_sim_to_prev)

# decoder predictions
output = self.predict_decoders(model_params, step, tokens, batch, output)
# latent predictions (raw and with SSL heads)
Expand Down
3 changes: 2 additions & 1 deletion src/weathergen/train/loss_modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from .loss_module_cosine_matching import LossLatent
from .loss_module_physical import LossPhysical
from .loss_module_ssl import LossLatentSSLStudentTeacher

__all__ = [LossPhysical, LossLatentSSLStudentTeacher]
# __all__ must list public names as strings; listing the class objects themselves
# makes `from weathergen.train.loss_modules import *` raise TypeError.
__all__ = ["LossPhysical", "LossLatentSSLStudentTeacher", "LossLatent"]
64 changes: 64 additions & 0 deletions src/weathergen/train/loss_modules/loss_module_cosine_matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# (C) Copyright 2025 WeatherGenerator contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging

import torch
import torch.nn.functional as F
from omegaconf import DictConfig

from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues
from weathergen.utils.train_logger import Stage

_logger = logging.getLogger(__name__)


class LossLatent(LossModuleBase):
    """
    Cosine-band regulariser on the latent tokens of consecutive forecast steps.

    For every step that provides a ``"cos_sim_to_prev"`` entry, each similarity value
    is compared against the band ``[cosine_low, cosine_high]``. Values inside the band
    contribute nothing; values outside it contribute the squared distance to the
    nearest band edge. The per-step penalty is the mean over all values of that step,
    and the final loss is the mean over all steps that supplied the entry.

    The similarities are read from ``preds.latent``; ``targets`` and ``metadata`` are
    not consulted by this module.
    """

    def __init__(
        self, cf: DictConfig, mode_cfg: DictConfig, stage: Stage, device: str, **loss_fcts
    ):
        """
        Store the configuration and read the band limits.

        Args:
            cf: Global configuration, kept on the instance as ``self.cf``.
            mode_cfg: Mode-specific configuration; accepted but not used here.
            stage: Stage this loss module is used in.
            device: Device on which the loss accumulator is created.
            **loss_fcts: Loss-function entries from the config. Only the first entry
                is inspected; its ``cosine_low`` / ``cosine_high`` keys set the band
                (defaults 0.68 and 0.78 when absent).
        """
        LossModuleBase.__init__(self)
        self.cf = cf
        self.stage = stage
        self.device = device
        self.name = "LossLatent"

        # The band limits live in the first configured entry, whatever its key is.
        band_cfg = next(iter(loss_fcts.values())) if loss_fcts else {}
        self.cosine_low = band_cfg.get("cosine_low", 0.68)
        self.cosine_high = band_cfg.get("cosine_high", 0.78)

    def compute_loss(self, preds, targets, metadata, **kwargs) -> LossValues:
        """
        Compute the band-hinge loss averaged over the contributing forecast steps.

        Args:
            preds: Model output; ``preds.latent`` is iterated and each item is queried
                with ``.get("cos_sim_to_prev")``.
            targets: Unused.
            metadata: Unused.
            **kwargs: Unused.

        Returns:
            LossValues whose ``loss`` is the step-averaged penalty and whose
            ``losses_all`` reports the same value as a float under ``"cosine_band"``.
            If no step carries a similarity entry, the loss is a zero-valued tensor
            that still requires grad.
        """
        total = torch.tensor(0.0, device=self.device, requires_grad=True)
        num_steps = 0

        for latent_step in preds.latent:
            similarity = latent_step.get("cos_sim_to_prev", None)
            if similarity is not None:
                # Squared hinge on each side of the allowed band.
                too_high = F.relu(similarity - self.cosine_high)
                too_low = F.relu(self.cosine_low - similarity)
                total = total + (too_high**2 + too_low**2).mean()
                num_steps += 1

        # Average over contributing steps; with none, keep the zero accumulator as is.
        if num_steps > 0:
            total = total / num_steps

        return LossValues(
            loss=total, losses_all={"cosine_band": total.detach().item()}, stddev_all={}
        )
34 changes: 34 additions & 0 deletions src/weathergen/train/loss_modules/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# (C) Copyright 2025 WeatherGenerator contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import torch
import torch.nn.functional as F


def compute_cos_sim_to_prev(
    tokens: torch.Tensor, prev_tokens: torch.Tensor, num_aux_tokens: int
) -> torch.Tensor:
    """
    Compute the cosine similarity of each patch token to its previous-step counterpart.

    The first ``num_aux_tokens`` entries along dimension 1 are discarded from both
    inputs. What remains is flattened to rows of length ``shape[-1]`` and compared
    row by row. ``prev_tokens`` is detached, so gradients only reach ``tokens``.

    Args:
        tokens: Current latent tokens; indexed as ``(batch, num_tokens, dim)``.
        prev_tokens: Latent tokens of the previous step, expected to match ``tokens``
            in shape.
        num_aux_tokens: Count of leading tokens along dimension 1 to leave out.

    Returns:
        1D tensor holding one cosine similarity per remaining token.
    """
    feature_dim = tokens.shape[-1]
    prev_feature_dim = prev_tokens.shape[-1]

    # Drop the leading auxiliary tokens, keeping only the patch tokens.
    patch_now = tokens[:, num_aux_tokens:]
    # Cut the graph on the previous step so it acts as a fixed reference.
    patch_before = prev_tokens[:, num_aux_tokens:].detach()

    flat_now = patch_now.reshape(-1, feature_dim)
    flat_before = patch_before.reshape(-1, prev_feature_dim)
    return F.cosine_similarity(flat_now, flat_before, dim=-1)
Loading