Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 248 additions & 0 deletions config/config_cerra_crps_decode.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
# (C) Copyright 2025 WeatherGenerator contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# =====================================================================================
# CERRA CRPS decoding from the frozen pretrained encoder `yqo6nvmy-pretrain`.
#
# Goal: attach a CRPS decoder to the pretrained encoder and probe how well CERRA can be
# predicted at the *same* (forecast) step as the input window.
#
# - Encoder loaded from yqo6nvmy-pretrain and FROZEN (linear-probe style).
# - Inputs = the pretrain_multi_data_all_years streams, now all marked `forcing`.
# - CERRA = `diagnostic` output stream (contributes no input tokens), trained with CRPS.
# - Reconstruction / masking training mode; student-teacher (SSL) loss disabled.
#
# IMPORTANT: every architecture parameter below is copied verbatim from the
# yqo6nvmy-pretrain checkpoint config so that load_chkpt (strict=False) can match the
# encoder/forecast-engine weights. Do NOT change these unless you also retrain the
# encoder — a shape mismatch will crash load_state_dict.
# =====================================================================================


decoder_type: PerceiverIOCoordConditioning
pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True
num_class_tokens: 0
num_register_tokens: 64

fe_num_blocks: 0
fe_layer_norm_after_blocks: []
fe_impute_latent_noise_std: 0.0
forecast_att_dense_rate: 1.0

healpix_level: 5

rope_2D: True

mlp_type: swiglu
qk_norm_type: RMSNorm
use_xsa: True
with_step_conditioning: True
noise_pre_predictor_std: 0

with_mixed_precision: True
with_flash_attention: True
compile_model: False
# The CRPS decoder decodes all M ensemble members in a SINGLE varlen call
# (model.py predict_decoders): each member's B*H query groups pair with its own KV groups,
# so the decoder params are used exactly once per forward — no per-member reuse. Because
# there is no reuse, per-block activation checkpointing is DDP-safe (it is the combination
# reuse + checkpointing that trips the reducer's "marked ready twice"), and it keeps peak
# memory low. The single call launches M*B*H attention groups, which must stay under the
# CUDA grid-dim cap (65535) — true for M<=4 at healpix-5 with B=1, hence members capped at 4.
# static_graph is NOT usable here: predict_decoders has data-dependent branches (P==0
# skip, NaN guard) so the participating-param set varies per iteration.
# (config_forecasting_crps.yml uses FSDP instead, which has none of these constraints.)
with_fsdp: False
decoder_checkpointing: True # per-block checkpoint on — safe: members batched, no reuse
ddp_find_unused_parameters: False # every trainable param must participate in the loss —
# freeze leftover modules rather than masking them
attention_dtype: bf16
mixed_precision_dtype: bf16
mlp_norm_eps: 1e-5
norm_eps: 1e-4

latent_noise_kl_weight: 0.0
latent_noise_gamma: 2.0
latent_noise_saturate_encodings: 5
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True

# ---- CRPS decoder (the merged work) ----------------------------------------------
# Ensemble members are produced by perturbing the latent tokens before decoding.
# kernel_crps requires >1 member. pred_head.ens_size stays 1 in the stream config.
latent_perturbation_num_members: 4 # capped at 4 so M× un-checkpointed decoder activations fit
latent_perturbation_sigma_init: 0.05 # ASSUMPTION: initial latent noise std
latent_perturbation_sigma_learnable: True

# ASSUMPTION: freeze the whole encoder (probe the frozen representation). Set to "" to
# fine-tune end-to-end, or narrow the regex to unfreeze parts.
freeze_modules: ".*encoder.*|.*deep_ssl_fusion.*" # |^(?!.*CERRA).*$"
# freeze_modules: "^(encoder|forecast_engine|latent_heads|latent_pre_norm|deep_ssl_fusion|deep_ssl_level_projections|embed_target_coords_(?!CERRA$).+|TargetPredictionEngine_(?!CERRA$).+|TargetPredictionEngineClassic_(?!CERRA$).+|BilinearDecoder_(?!CERRA$).+|EnsPredictionHead_(?!CERRA$).+)$"

norm_type: "LayerNorm"

#####################################

streams_directory: "./config/streams/pretrain_multi_data_all_years_cerra/"
streams: ???

# type of zarr_store
zarr_store: "zip"

general:

# mutable parameters
istep: 0
rank: ???
world_size: ???

with_ddp: True

# zarr-v3's asyncio event loop is not fork-safe; forked DataLoader workers abort
# during CERRA validation. "spawn" gives each worker a fresh interpreter + event loop.
multiprocessing_method: "spawn"

desc: ""
run_id: ???
run_history: []

# logging frequency in the training loop (in number of batches)
train_logging:
terminal: 10
metrics: 20
checkpoint: 250

# parameters for data loading
data_loading :

num_workers: 12
rng_seed: ???
repeat_data_in_mini_epoch : False
memory_pinning: True
# keep spawned workers alive across mini-epochs / validation calls to avoid paying
# the (expensive) spawn startup + reader re-open every time.
persistent_workers: True


# config for training
training_config:

# reconstruction / masking mode (NOT student_teacher, NOT latent_loss)
training_mode: ["masking"]

# ASSUMPTION: training budget. Tune to taste; encoder is frozen so only the decoder
# (+ latent perturbation sigma) is optimised.
num_mini_epochs: 32
samples_per_mini_epoch: 4096
shuffle: True

# ASSUMPTION: CERRA is available 1985-2023; reuse the era5_cerra train/val split.
start_date: 1985-01-01T00:00
end_date: 2022-12-31T00:00

time_window_step: 06:00:00
time_window_len: 06:00:00

learning_rate_scheduling :
lr_start: 1e-6
lr_max: 5e-5
lr_final_decay: 2e-6
lr_final: 0.0
num_steps_warmup: 256
num_steps_cooldown: 512
policy_warmup: "cosine"
policy_decay: "constant"
policy_cooldown: "linear"
parallel_scaling_policy: "sqrt"

optimizer:
grad_clip: 1.0
weight_decay: 0.1
log_grad_norms: False
adamw :
beta1 : 0.98125
beta2 : 0.9875
eps : 2e-08

# CERRA decoded with the CRPS (kernel_crps) loss; SSL student-teacher disabled.
losses : {
"student-teacher": {
enabled: False,
type: Disabled,
},
"physical": {
type: LossPhysical,
# "mse": null removes the mse loss inherited from default_config so only CRPS trains
loss_fcts: { "mse": null, "kernel_crps": { }, },
},
}

# NOTE: when launched via `train_continue --from-run-id yqo6nvmy-pretrain`, the SSL
# pretrain config is the base, so its `random_easy` model_input / `random_easy_target`
# target_input are inherited and MUST be disabled here. Otherwise model_input and
# target_input end up with mismatched lengths and mask building asserts. (Harmless on
# the plain `train` path, where these keys are absent.)
model_input: {
"random_easy" : {
enabled: False,
num_samples: 0,
},
"forecasting" : {
masking_strategy: "forecast",
},
}

target_input: {
"random_easy_target" : {
enabled: False,
},
}

# ASSUMPTION: "same forecast step" => offset 0 + num_steps 0, i.e. decode CERRA at the
# input window (auto-encoder/diagnostic), matching the pretrain encoder (num_steps=0).
forecast :
time_step: 06:00:00
offset: 0
num_steps: 0
policy: "fixed"


# validation config; full validation config is merge of training and validation config
validation_config:

samples_per_mini_epoch: 256
shuffle: False

start_date: 2023-10-01T00:00
end_date: 2023-12-31T00:00

# no SSL here, so no EMA teacher needed for validation
validate_with_ema:
enabled : False
ema_ramp_up_ratio: 0.09
ema_halflife_in_thousands: 1e-3

output : {
num_samples: 0,
normalized_samples: False,
streams: null,
}

validate_before_training: False


# Tags for experiment tracking
wgtags:
org: null
issue: null
exp: "cerra_crps_decode"
grid: null
4 changes: 0 additions & 4 deletions config/config_forecasting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,13 @@ forecast_att_dense_rate: 1.0

healpix_level: 5

<<<<<<< HEAD
# Generalized RoPE selector.
rope_mode: none # one of: none, 2d, spherical
# Optional spherical harmonic band for spherical RoPE. If null, the model picks one
# conservative shared band that fits all spherical-RoPE attention modules.
rope_spherical_band: null
mlp_type: swiglu
use_xsa: True
=======
rope_2D: False
>>>>>>> origin/develop

with_mixed_precision: True
with_flash_attention: True
Expand Down
Loading
Loading