diff --git a/config/config_jepa_multi_data_all_years.yml b/config/config_jepa_multi_data_all_years.yml index 25e169bb5..aac193800 100644 --- a/config/config_jepa_multi_data_all_years.yml +++ b/config/config_jepa_multi_data_all_years.yml @@ -54,7 +54,7 @@ num_class_tokens: 0 num_register_tokens: 64 # noise before predictor in JEPA -noise_pre_predictor_std: 0 +noise_pre_predictor_std: 1e-3 # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder @@ -171,7 +171,7 @@ data_loading : training_config: # training_mode: "masking", "student_teacher", "latent_loss" - training_mode: ["student_teacher"] + training_mode: ["masking","student_teacher"] # Deep self-supervision (V-JEPA 2.1 style): compute SSL loss at multiple encoder depths. # When enabled, intermediate encoder representations are tapped and used as additional @@ -209,8 +209,8 @@ training_config: lr_max: 1e-4 lr_final_decay: 1e-6 lr_final: 0.0 - num_steps_warmup: 512 - num_steps_cooldown: 1024 + num_steps_warmup: 1024 + num_steps_cooldown: 16384 policy_warmup: "cosine" policy_decay: "constant" policy_cooldown: "linear" @@ -218,14 +218,20 @@ training_config: optimizer: grad_clip: 0.25 - weight_decay: 0.05 + weight_decay: 0.1 adamw : # parameters are scaled by number of DDP workers beta1 : 0.9875 # at B=8 beta1 =0.9 - beta2 : 0.994 # at B=8 beta1 approx 0.95 + beta2 : 0.99875 # at B=8 beta2 approx 0.99 eps : 2e-08 losses : { + "physical": { + enabled: True, + type: LossPhysical, + weight: 1.0, + loss_fcts: { "mse": { target_source_correspondence: {0 : {0 : "subset"} },}, }, + }, "student-teacher": { enabled: True, type: LossLatentSSLStudentTeacher, @@ -233,8 +239,8 @@ training_config: loss_fcts : { "JEPA": { 'weight': 4, "loss_extra_args": {}, "out_dim": 2048, "head": transformer, - "num_blocks": 6, "num_heads": 16, "with_qk_lnorm": True, "intermediate_dim": 1024, - "dropout_rate": 0.1, + "num_blocks": 6, "num_heads": 16, "with_qk_lnorm": True, 
"intermediate_dim": 2048, + "dropout_rate": 0.2, target_source_correspondence: {0 : {0 : "subset"} }, }, }, @@ -280,6 +286,8 @@ training_config: # validation config; full validation config is merge of training and validation config validation_config: + time_window_step: 06:00:00 + samples_per_mini_epoch: 256 shuffle: False diff --git a/config/config_jepa_multi_data_all_years_ft.yml b/config/config_jepa_multi_data_all_years_ft.yml index 870a29fd8..08d9df126 100644 --- a/config/config_jepa_multi_data_all_years_ft.yml +++ b/config/config_jepa_multi_data_all_years_ft.yml @@ -1,5 +1,5 @@ -streams_directory: "./config/streams/pretrain_multi_data_od/" +streams_directory: "./config/streams/pretrain_multi_data_all_years/" general: diff --git a/config/config_jepa_multi_data_ft_forecast.yml b/config/config_jepa_multi_data_ft_forecast.yml index f824a9147..0c71120e6 100644 --- a/config/config_jepa_multi_data_ft_forecast.yml +++ b/config/config_jepa_multi_data_ft_forecast.yml @@ -1,11 +1,40 @@ -streams_directory: "./config/streams/jepa_forecast_multi_data_od/" +streams_directory: "./config/streams/jepa_forecast_multi_data_od_ckpt_order/" + +freeze_modules: "^(?!.*ERA5)(?=.*(?:encoder|latent_pre_norm|latent_heads)).*$" general: # mutable parameters + istep: 0 rank: ??? world_size: ??? 
training_config: - num_mini_epochs: 32 + start_date: 2016-01-01T00:00 + # OND-2022 carved out of training so the heldout-train-years extra validation + # set below is genuinely held out (was 2022-12-31) + end_date: 2022-09-30T00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 1e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 32768 + + num_mini_epochs: 6 samples_per_mini_epoch: 8192 + +# # extra validation sets, evaluated each mini-epoch and logged as stage "val_"; +# # each entry overrides the primary validation_config +# extra_validation_configs: +# # held-out slice inside the training years (excluded from training via the +# # end_date above), season-matched to the OND-2023 primary val window; +# # memorization probe: if its loss tracks val, the train/val gap is memorization, +# # if it stays well below val, the gap is distribution shift +# heldout-train-years: +# start_date: 2022-10-01T00:00 +# end_date: 2022-12-31T00:00 +# shuffle: True +# samples_per_mini_epoch: 256 diff --git a/config/config_jepa_multi_data_ft_forecast_all_years.yml b/config/config_jepa_multi_data_ft_forecast_all_years.yml index c3f941850..b1bfd955a 100644 --- a/config/config_jepa_multi_data_ft_forecast_all_years.yml +++ b/config/config_jepa_multi_data_ft_forecast_all_years.yml @@ -107,7 +107,7 @@ training_config: enabled: false - num_mini_epochs: 24 + num_mini_epochs: 32 samples_per_mini_epoch: 8192 shuffle: True @@ -119,11 +119,11 @@ training_config: learning_rate_scheduling : lr_start: 1e-6 - lr_max: 8e-5 + lr_max: 4e-5 lr_final_decay: 2e-6 - lr_final: 0.0 + lr_final: 2e-6 num_steps_warmup: 256 - num_steps_cooldown: 512 + num_steps_cooldown: 16384 policy_warmup: "cosine" policy_decay: "constant" policy_cooldown: "linear" @@ -136,7 +136,7 @@ training_config: adamw : # parameters are scaled by number of DDP workers beta1 : 0.9875 # at B=8 beta1 =0.9 - beta2 : 0.994 # at B=8 beta1 approx 0.95 + beta2 : 0.99875 # at B=8 beta2 approx 0.99 eps 
: 1e-08 losses : { @@ -144,10 +144,10 @@ training_config: enabled: False, type: Disabled, }, - # "physical": { - # enabled: False, - # type: Disabled, - # }, + "physical": { + enabled: False, + type: Disabled, + }, "forecast": { type: LossPhysical, loss_fcts: { "mse": { }, }, @@ -173,7 +173,7 @@ training_config: forecast : time_step: 06:00:00 - num_steps: 3 + num_steps: 2 offset: 1 policy: "fixed" @@ -183,6 +183,7 @@ validation_config: samples_per_mini_epoch: 256 shuffle: False + time_window_step: 6:00:00 start_date: 2023-10-01T00:00 end_date: 2023-12-31T00:00 @@ -205,6 +206,19 @@ validation_config: # run validation before training starts (mainly for model development) validate_before_training: False +# # extra validation sets, evaluated each mini-epoch and logged as stage "val_"; +# # each entry overrides the primary validation_config +# extra_validation_configs: +# # held-out slice inside the training years (excluded from training via the +# # end_date above), season-matched to the OND-2023 primary val window; +# # memorization probe: if its loss tracks val, the train/val gap is memorization, +# # if it stays well below val, the gap is distribution shift +# heldout-train-years: +# start_date: 2022-10-01T00:00 +# end_date: 2022-12-31T00:00 +# shuffle: True +# samples_per_mini_epoch: 256 + # test config; full test config is merge of validation and test config test_config: diff --git a/config/config_jepa_multi_data_ft_forecast_all_years_noise.yml b/config/config_jepa_multi_data_ft_forecast_all_years_noise.yml new file mode 100644 index 000000000..e51082d82 --- /dev/null +++ b/config/config_jepa_multi_data_ft_forecast_all_years_noise.yml @@ -0,0 +1,270 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 16 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 5e-5 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 + +healpix_level: 5 + +# performance +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +ddp_find_unused_parameters: False #True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: ".*encoder.*|.*latent_pre_norm.*|.*latent_heads.*" + +norm_type: "LayerNorm" + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +##################################### + +streams_directory: "./config/streams/jepa_forecast_multi_data_all_years/" +streams: ??? + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "spawn" # "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 50 + metrics: 20 + checkpoint: 500 + log_grad_norms: False + +# parameters for data loading +data_loading : + + num_workers: 6 + rng_seed: ??? 
+ repeat_data_in_mini_epoch : False + persistent_workers: True + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + # training_mode: ["student_teacher"] + training_mode: ["masking"] + + context_loss: + enabled: false + + # Deep self-supervision (V-JEPA 2.1 style): compute SSL loss at multiple encoder depths. + # When enabled, intermediate encoder representations are tapped and used as additional + # targets/predictions for the SSL loss. Disabled by default (empty tap_after). + deep_ssl: + enabled: false + + + num_mini_epochs: 32 + samples_per_mini_epoch: 8192 + shuffle: True + + start_date: 1980-01-01T00:00 + end_date: 2022-12-31T00:00 + + time_window_step: 01:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 4e-5 + lr_final_decay: 2e-6 + lr_final: 2e-6 + num_steps_warmup: 256 + num_steps_cooldown: 16384 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 0.25 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.9875 # at B=8 beta1 =0.9 + beta2 : 0.99875 # at B=8 beta2 approx 0.99 + eps : 1e-08 + + losses : { + "student-teacher": { + enabled: False, + type: Disabled, + }, + # "physical": { + # enabled: False, + # type: Disabled, + # }, + "forecast": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "random_easy" : { + enabled: False, + num_samples: 0, + }, + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + target_input: { + 
"random_easy_target" : { + enabled: False, + }, + } + + forecast : + time_step: 06:00:00 + num_steps: 2 + offset: 1 + policy: "fixed" + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 256 + shuffle: False + time_window_step: 6:00:00 + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 600 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + +# # extra validation sets, evaluated each mini-epoch and logged as stage "val_"; +# # each entry overrides the primary validation_config +# extra_validation_configs: +# # held-out slice inside the training years (excluded from training via the +# # end_date above), season-matched to the OND-2023 primary val window; +# # memorization probe: if its loss tracks val, the train/val gap is memorization, +# # if it stays well below val, the gap is distribution shift +# heldout-train-years: +# start_date: 2022-10-01T00:00 +# end_date: 2022-12-31T00:00 +# shuffle: True +# samples_per_mini_epoch: 256 + +# test config; full test config is merge of validation and test config +test_config: + + samples_per_mini_epoch: 128 + shuffle: False + + start_date: 2023-06-01T00:00 + end_date: 2023-08-31T00:00 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 128, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + 
streams: null, + } + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. 
+ grid: null + + + diff --git a/config/config_jepa_self_flow_multi_data_all_years.yml b/config/config_jepa_self_flow_multi_data_all_years.yml index 09f366711..e6204ef50 100644 --- a/config/config_jepa_self_flow_multi_data_all_years.yml +++ b/config/config_jepa_self_flow_multi_data_all_years.yml @@ -59,7 +59,7 @@ num_class_tokens: 0 num_register_tokens: 64 # noise before predictor in JEPA -noise_pre_predictor_std: 0 +noise_pre_predictor_std: 1e-4 # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder @@ -176,7 +176,7 @@ data_loading : training_config: # training_mode: "masking", "student_teacher", "latent_loss" - training_mode: ["student_teacher"] + training_mode: ["student_teacher", "masking"] # Deep self-supervision (V-JEPA 2.1 style): compute SSL loss at multiple encoder depths. # When enabled, intermediate encoder representations are tapped and used as additional @@ -211,7 +211,7 @@ training_config: learning_rate_scheduling : lr_start: 1e-6 - lr_max: 1e-4 + lr_max: 8e-5 lr_final_decay: 1e-6 lr_final: 0.0 num_steps_warmup: 512 @@ -223,14 +223,20 @@ training_config: optimizer: grad_clip: 0.25 - weight_decay: 0.05 + weight_decay: 0.1 adamw : # parameters are scaled by number of DDP workers beta1 : 0.9875 # at B=8 beta1 =0.9 - beta2 : 0.994 # at B=8 beta1 approx 0.95 + beta2 : 0.99875 # at B=8 beta2 approx 0.99 eps : 2e-08 losses : { + "physical": { + enabled: True, + type: LossPhysical, + weight: 1.0, + loss_fcts: { "mse": { target_source_correspondence: {0 : {0 : "identity"} },}, }, + }, "student-teacher": { enabled: True, type: LossLatentSSLStudentTeacher, @@ -238,14 +244,14 @@ training_config: loss_fcts : { "JEPA": { 'weight': 4, "loss_extra_args": {}, "out_dim": 2048, "head": transformer, - "num_blocks": 6, "num_heads": 16, "with_qk_lnorm": True, "intermediate_dim": 1024, + "num_blocks": 6, "num_heads": 16, "with_qk_lnorm": True, "intermediate_dim": 2048, "dropout_rate": 0.1, 
target_source_correspondence: {0 : {0 : "identity"} }, }, }, target_and_aux_calc: { "EMATeacher" : { ema_ramp_up_ratio : null, - ema_halflife_in_thousands: 400000, + ema_halflife_in_thousands: 300000, model_param_overrides : { training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}} }, diff --git a/config/config_smoke_extra_val.yml b/config/config_smoke_extra_val.yml new file mode 100644 index 000000000..609daedcc --- /dev/null +++ b/config/config_smoke_extra_val.yml @@ -0,0 +1,56 @@ +# Minimal smoke-test config for the extra_validation_configs feature: +# tiny model (1 block per engine, small latent dims), few samples, ERA5 only. +# Run: +# uv run train --config ./config/config_smoke_extra_val.yml +# Then (replace RUN_ID with the run id printed by the train command): +# uv run plot_train -fd "{RUN_ID: [0, 'smoke extra val']}" + +streams_directory: "./config/streams/era5_1deg/" + +ae_local_dim_embed: 64 +ae_local_num_blocks: 1 +ae_local_num_heads: 4 +ae_adapter_num_heads: 4 +ae_adapter_embed: 32 + +ae_global_dim_embed: 64 +ae_global_num_blocks: 1 +ae_global_num_heads: 4 + +fe_num_blocks: 1 +fe_num_heads: 4 + +data_loading: + num_workers: 4 + +train_logging: + terminal: 2 + metrics: 2 + checkpoint: 250 + +training_config: + num_mini_epochs: 2 + samples_per_mini_epoch: 32 + + # tiny run: keep scheduler phases shorter than total steps + learning_rate_scheduling: + num_steps_warmup: 4 + num_steps_cooldown: 4 + + # end of training range carved to keep OND-2022 held out + start_date: 2016-01-01T00:00 + end_date: 2022-09-30T00:00 + +validation_config: + samples_per_mini_epoch: 4 + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + +extra_validation_configs: + # held-out slice inside the training years, season-matched to the primary val window + heldout-train-years: + start_date: 2022-10-01T00:00 + end_date: 2022-12-31T00:00 + shuffle: True + samples_per_mini_epoch: 4 diff --git a/config/evaluate/eval_config_best.yml b/config/evaluate/eval_config_best.yml index 59e83e158..90a4fe02f 100644 --- 
a/config/evaluate/eval_config_best.yml +++ b/config/evaluate/eval_config_best.yml @@ -9,9 +9,10 @@ default_streams: - u_850 - v_850 - q_850 + - q_200 evaluation: ensemble: all - forecast_step: all + forecast_step: [0,1,2,3,4] sample: all plotting: forecast_step: all @@ -51,8 +52,8 @@ run_ids: # label: Inference of k47tobdl-ft32-with-noise # type: json # # results_base_dir: ./results/bmbywqn6/ - ay98rzxp: - label: Inference of e8bd7ejo-2-nodes - type: json - # results_base_dir: ./results/ay98rzxp/ + i3p7ybfn: + label: Inference of xi0mujc2-finetune-forecast (SSL multidata) + # type: json + results_base_dir: ./results/i3p7ybfn/ diff --git a/config/experiments/combined_overfitting_pipeline.yml b/config/experiments/combined_overfitting_pipeline.yml new file mode 100644 index 000000000..2ec1b3eae --- /dev/null +++ b/config/experiments/combined_overfitting_pipeline.yml @@ -0,0 +1,81 @@ +# Combined overfitting-followup pipeline: shared pretrain, then BRANCH into the +# three experiments so the expensive upstream runs ONCE. +# +# pretrain ──► pretrain-ft ──┬─► finetune-forecast ──┬─► exp1 lr-decay-fix +# │ (ERA5 frozen, stage3)└─► exp3 unfreeze-lowlr-reg +# └─► exp2 era5-size-control (alternative stage 3) +# +# Branching is native to launch-slurm-sophie.py: each stage's `from:` names its +# parent, multiple stages may share a parent, and the cleanup job waits on all +# leaves (exp1, exp2, exp3). All stages share one run_id; a child continues from +# the RUN_ID-PARENT_STAGE checkpoint, so exp1/exp3 continue from the SAME finetune-forecast +# checkpoint and exp2 from the SAME pretrain-ft checkpoint. +# +# Budget: 6 stages (<=8) and 2+2+4+4+3+3 = 18 chained jobs (<=24). Streams dirs +# unchanged (referenced inside each stage config). 
+# +# Usage: +# ../WeatherGenerator-private/hpc/launch-slurm-sophie.py \ +# config/experiments/combined_overfitting_pipeline.yml # dry-run +# ../WeatherGenerator-private/hpc/launch-slurm-sophie.py --no-dry-run \ +# config/experiments/combined_overfitting_pipeline.yml # submit + +nodes: 1 +register: true +cleanup-scripts: yes +results-dir: shared +mini-epoch: -1 + +stages: + # ---- shared upstream (runs once) ---- + - name: pretrain + command: train + base-config: ./config/config_jepa_multi_data_all_years.yml + config: ./config/config_jepa_multi_data_all_years.yml + options: [] + chain: 2 + nodes: 2 + + - name: pretrain-ft + command: train + from: pretrain + config: ./config/config_jepa_multi_data_all_years_ft.yml + options: [] + chain: 2 + nodes: 2 + + # ERA5 frozen-encoder forecast finetune (the good forecaster exp1/exp3 build on) + - name: finetune-forecast + command: train-continue + from: pretrain-ft + config: ./config/config_jepa_multi_data_ft_forecast_all_years.yml + chain: 4 + nodes: 2 + + # ---- branches ---- + # EXP2 — H1 control: alternative stage 3, ERA5 restricted to the oper-sized + # window. Branches from pretrain-ft (sibling of finetune-forecast), NOT a + # continuation of it. No oper stage 4: the signal is at this stage. + - name: exp2-era5-size-control + command: train-continue + from: pretrain-ft + config: ./config/experiments/exp2_era5_size_control/ft_forecast_era5_2yr.yml + chain: 4 + nodes: 2 + + # EXP1 — H3 fix: oper forecast finetune with decaying LR. From finetune-forecast. + - name: exp1-lr-decay-fix + command: train-continue + from: finetune-forecast + config: ./config/experiments/exp1_lr_decay_fix/ft_forecast_lr_decay.yml + chain: 3 + nodes: 2 + + # EXP3 — H1 amplifier: full-unfreeze oper finetune at sane LR + reg. From + # finetune-forecast (same parent as exp1 -> direct frozen-vs-unfrozen comparison). 
+ - name: exp3-unfreeze-lowlr-reg + command: train-continue + from: finetune-forecast + config: ./config/experiments/exp3_unfreeze_lowlr_reg/ft_forecast_unfreeze_lowlr.yml + chain: 3 + nodes: 2 diff --git a/config/experiments/exp1_lr_decay_fix/ft_forecast_lr_decay.yml b/config/experiments/exp1_lr_decay_fix/ft_forecast_lr_decay.yml new file mode 100644 index 000000000..de1f69181 --- /dev/null +++ b/config/experiments/exp1_lr_decay_fix/ft_forecast_lr_decay.yml @@ -0,0 +1,54 @@ +# EXP1 — lr-decay-fix (targets H3, the optimization amplifier) +# +# Stage-4 (operational-analysis forecast finetune) override. Identical to the +# baseline config/config_jepa_multi_data_ft_forecast.yml EXCEPT: +# - policy_decay: constant -> linear (LR now actually decays instead of being +# pinned at lr_max for the whole body; see lr_scheduler.py:119 vs :213) +# - lr_max: 1e-5 -> 5e-6 (lower peak, within the notes' 5e-6..1e-5) +# - extra_validation_configs enabled (free H2 kill-shot: heldout 2022-OND) +# +# Everything else (window, frozen non-ERA5 encoder, streams dir) is unchanged so +# the only knobs that move are the LR schedule. Expected: a later, lower val +# minimum and a smaller train/val separation than the baseline stage 4. + +streams_directory: "./config/streams/jepa_forecast_multi_data_od_ckpt_order/" + +freeze_modules: "^(?!.*ERA5)(?=.*(?:encoder|latent_pre_norm|latent_heads)).*$" + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? 
+ +training_config: + start_date: 2016-01-01T00:00 + # OND-2022 carved out of training so the heldout-train-years extra validation + # set below is genuinely held out + end_date: 2022-09-30T00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-6 # was 1e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 32768 + policy_decay: "linear" # was inherited "constant" -> THIS is the H3 fix + + num_mini_epochs: 6 + samples_per_mini_epoch: 8192 + +# extra validation sets, evaluated each mini-epoch and logged as stage "val_"; +# each entry overrides the primary validation_config. +# heldout slice inside the training years (excluded from training via end_date +# above), season-matched to the OND-2023 primary val window. Memorization probe: +# if its loss tracks val -> the train/val gap is memorization (H1); if it stays +# well below val -> the gap is distribution shift (H2). +extra_validation_configs: + heldout-train-years: + start_date: 2022-10-01T00:00 + end_date: 2022-12-31T00:00 + shuffle: True + samples_per_mini_epoch: 256 diff --git a/config/experiments/exp1_lr_decay_fix/pipeline.yml b/config/experiments/exp1_lr_decay_fix/pipeline.yml new file mode 100644 index 000000000..6b03a97f6 --- /dev/null +++ b/config/experiments/exp1_lr_decay_fix/pipeline.yml @@ -0,0 +1,46 @@ +# EXP1 — lr-decay-fix pipeline (targets H3, optimization amplifier). +# +# Identical to config/jepa_multi_data_pipeline.yml except the final stage uses +# the decaying-LR override. Self-contained / launchable as-is. +# +# COST NOTE: stages 1-3 are byte-for-byte the baseline. If your launcher can +# seed a train-continue from the existing baseline `finetune-forecast` +# checkpoint (xi0mujc2-finetune-forecast), delete stages 1-3 and keep only the +# last stage to avoid re-pretraining. 
+ +nodes: 1 +register: true +cleanup-scripts: yes +results-dir: shared +mini-epoch: -1 + +stages: + - name: pretrain + command: train + base-config: ./config/config_jepa_multi_data_all_years.yml + config: ./config/config_jepa_multi_data_all_years.yml + options: [] + chain: 2 + nodes: 2 + + - name: pretrain-ft + command: train + config: ./config/config_jepa_multi_data_all_years_ft.yml + options: [] + chain: 2 + nodes: 2 + + - name: finetune-forecast + command: train-continue + from: pretrain-ft + config: ./config/config_jepa_multi_data_ft_forecast_all_years.yml + chain: 4 + nodes: 2 + + # Stage 4: operational-analysis forecast finetune with the LR-decay fix + - name: finetune-forecast-ft + command: train-continue + from: finetune-forecast + config: ./config/experiments/exp1_lr_decay_fix/ft_forecast_lr_decay.yml + chain: 3 + nodes: 2 diff --git a/config/experiments/exp2_era5_size_control/ft_forecast_era5_2yr.yml b/config/experiments/exp2_era5_size_control/ft_forecast_era5_2yr.yml new file mode 100644 index 000000000..5ffddb7d0 --- /dev/null +++ b/config/experiments/exp2_era5_size_control/ft_forecast_era5_2yr.yml @@ -0,0 +1,211 @@ +# EXP2 — era5-size-control (targets H1, the primary cause: small-data memorization) +# +# This is a COPY of config/config_jepa_multi_data_ft_forecast_all_years.yml (the +# baseline stage-3 frozen-encoder ERA5 forecast finetune) with two intended changes (NOTE(review): num_mini_epochs=24, lr_final=0.0 and forecast.num_steps=3 below also differ from the current baseline's 32 / 2e-6 / 2, and the validation time_window_step is not set here; re-sync or confirm before comparing runs): +# 1. training window 1980-01-01..2022-12-31 -> 2016-01-01..2022-09-30 +# i.e. ERA5 restricted to the SAME ~7-yr window as the operational stage 4 +# (~10k distinct 6-hourly states, ~6x smaller than the full ERA5 set). +# 2. an extra heldout-train-years (2022-OND) validation slice added. +# +# Source stays ERA5, encoder stays frozen, LR stays constant 4e-5 — so the ONLY +# variable that moves vs the baseline stage 3 is the amount of training data. +# Baseline stage 3 (full ERA5, frozen) does NOT overfit. 
If this size-matched +# ERA5 run now overfits like the oper stage 4, the train/val gap is driven by +# data volume (H1), independent of the oper source -> H1 confirmed. If it stays +# clean, size alone is not sufficient and the oper source / non-stationarity +# (H2) is implicated. Decisive, cheap control. +# +# No stage 4 in this pipeline: the diagnostic signal is entirely at this stage. + +fe_num_blocks: 16 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [7] +fe_impute_latent_noise_std: 0 +forecast_att_dense_rate: 1.0 + +healpix_level: 5 + +# performance +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +ddp_find_unused_parameters: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: ".*encoder.*|.*latent_pre_norm.*|.*latent_heads.*" + +norm_type: "LayerNorm" + +# type of zarr_store +zarr_store: "zip" + +##################################### + +# UNCHANGED streams directory (ERA5 all-years forecast streams) +streams_directory: "./config/streams/jepa_forecast_multi_data_all_years/" +streams: ??? + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + multiprocessing_method: "spawn" + + desc: "" + run_id: ??? + run_history: [] + +train_logging: + terminal: 50 + metrics: 20 + checkpoint: 500 + log_grad_norms: False + +data_loading : + + num_workers: 6 + rng_seed: ??? 
+ repeat_data_in_mini_epoch : False + persistent_workers: True + memory_pinning: True + + +training_config: + + training_mode: ["masking"] + + context_loss: + enabled: false + + deep_ssl: + enabled: false + + num_mini_epochs: 24 + samples_per_mini_epoch: 8192 + shuffle: True + + # *** ONLY FUNCTIONAL CHANGE vs baseline stage 3: size-matched to the oper window *** + start_date: 2016-01-01T00:00 + end_date: 2022-09-30T00:00 # OND-2022 carved out for the heldout probe below + + time_window_step: 01:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 4e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 16384 + policy_warmup: "cosine" + policy_decay: "constant" # kept constant on purpose: only data size varies + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 0.25 + weight_decay: 0.1 + log_grad_norms: False + adamw : + beta1 : 0.9875 + beta2 : 0.99875 + eps : 1e-08 + + losses : { + "student-teacher": { + enabled: False, + type: Disabled, + }, + "forecast": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "random_easy" : { + enabled: False, + num_samples: 0, + }, + "forecasting" : { + masking_strategy: "forecast", + }, + } + + target_input: { + "random_easy_target" : { + enabled: False, + }, + } + + forecast : + time_step: 06:00:00 + num_steps: 3 + offset: 1 + policy: "fixed" + +validation_config: + + samples_per_mini_epoch: 256 + shuffle: False + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 600 + + output : { + num_samples: 0, + normalized_samples: False, + streams: null, + } + + validate_before_training: False + +# heldout slice INSIDE the (now restricted) training years, season-matched to the +# OND-2023 primary val window. Direct memorization gap measurement for H1. 
+extra_validation_configs: + heldout-train-years: + start_date: 2022-10-01T00:00 + end_date: 2022-12-31T00:00 + shuffle: True + samples_per_mini_epoch: 256 + +test_config: + + samples_per_mini_epoch: 128 + shuffle: False + + start_date: 2023-06-01T00:00 + end_date: 2023-08-31T00:00 + + output : { + num_samples: 128, + normalized_samples: False, + streams: null, + } + +wgtags: + org: null + issue: null + exp: null + grid: null diff --git a/config/experiments/exp2_era5_size_control/pipeline.yml b/config/experiments/exp2_era5_size_control/pipeline.yml new file mode 100644 index 000000000..f7e366fe6 --- /dev/null +++ b/config/experiments/exp2_era5_size_control/pipeline.yml @@ -0,0 +1,40 @@ +# EXP2 — era5-size-control pipeline (targets H1, the primary cause). +# +# Same pretrain chain as baseline, but the forecast-finetune stage is the +# size-matched ERA5 control (frozen encoder, ERA5 source, ~7-yr window) and +# there is NO operational stage 4 — the diagnostic signal is entirely at the +# size-controlled stage 3. +# +# COST NOTE: stages 1-2 are byte-for-byte the baseline. If your launcher can seed +# a train-continue from the existing baseline `pretrain-ft` checkpoint, delete +# stages 1-2 and keep only the size-control stage. 
+ +nodes: 1 +register: true +cleanup-scripts: yes +results-dir: shared +mini-epoch: -1 + +stages: + - name: pretrain + command: train + base-config: ./config/config_jepa_multi_data_all_years.yml + config: ./config/config_jepa_multi_data_all_years.yml + options: [] + chain: 2 + nodes: 2 + + - name: pretrain-ft + command: train + config: ./config/config_jepa_multi_data_all_years_ft.yml + options: [] + chain: 2 + nodes: 2 + + # Stage 3 (control): ERA5 forecast finetune restricted to the oper-sized window + - name: finetune-forecast-era5-control + command: train-continue + from: pretrain-ft + config: ./config/experiments/exp2_era5_size_control/ft_forecast_era5_2yr.yml + chain: 4 + nodes: 2 diff --git a/config/experiments/exp3_unfreeze_lowlr_reg/ft_forecast_unfreeze_lowlr.yml b/config/experiments/exp3_unfreeze_lowlr_reg/ft_forecast_unfreeze_lowlr.yml new file mode 100644 index 000000000..a442e8fa2 --- /dev/null +++ b/config/experiments/exp3_unfreeze_lowlr_reg/ft_forecast_unfreeze_lowlr.yml @@ -0,0 +1,60 @@ +# EXP3 — unfreeze-lowlr-reg (targets H1's amplifier: full-model plasticity) +# +# Stage-4 (operational-analysis forecast finetune) override. Unlike the existing +# config_jepa_multi_data_ft_forecast_unfreeze_encoder.yml (which unfroze at the +# bad constant 8e-5 LR), this unfreezes the WHOLE model but at a sane, decaying +# LR and with stronger regularization, so the comparison to EXP1 is fair: +# +# EXP1 (frozen non-ERA5 encoder, lr fix) vs EXP3 (full unfreeze, same lr fix) +# => isolates how much unfreezing actually buys once LR/reg are sane. 
+# +# Deltas vs baseline config_jepa_multi_data_ft_forecast.yml: +# - freeze_modules: "" (fully unfrozen) +# - fe_dropout_rate: 0.1 -> 0.2 (regularization, top-level model param) +# - policy_decay: constant -> linear, lr_max 1e-5 -> 5e-6 (H3-safe unfreeze) +# - optimizer.weight_decay: 0.1 -> 0.2 (regularization) +# - extra_validation_configs enabled (memorization probe / early-stop read) + +# top-level model param: bump forecast-engine dropout for regularization +fe_dropout_rate: 0.2 + +streams_directory: "./config/streams/jepa_forecast_multi_data_od_ckpt_order/" + +freeze_modules: "" + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + +training_config: + start_date: 2016-01-01T00:00 + # OND-2022 carved out so the heldout-train-years probe below is genuinely held out + end_date: 2022-09-30T00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-6 # was 1e-5; lower because the whole model is now plastic + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 32768 + policy_decay: "linear" # was inherited "constant" + + optimizer: + weight_decay: 0.2 # was 0.1; stronger reg against memorization + + num_mini_epochs: 6 + samples_per_mini_epoch: 8192 + +# Memorization probe / early-stopping read-out (see EXP1 for rationale): ft-5 in +# the analysis showed a real val minimum at ~epoch 3, so watch this curve to pick +# the early-stop checkpoint. +extra_validation_configs: + heldout-train-years: + start_date: 2022-10-01T00:00 + end_date: 2022-12-31T00:00 + shuffle: True + samples_per_mini_epoch: 256 diff --git a/config/experiments/exp3_unfreeze_lowlr_reg/pipeline.yml b/config/experiments/exp3_unfreeze_lowlr_reg/pipeline.yml new file mode 100644 index 000000000..3543ee8f0 --- /dev/null +++ b/config/experiments/exp3_unfreeze_lowlr_reg/pipeline.yml @@ -0,0 +1,45 @@ +# EXP3 — unfreeze-lowlr-reg pipeline (targets H1 amplifier: full-model plasticity). 
+# +# Identical to config/jepa_multi_data_pipeline.yml except the final stage fully +# unfreezes the model at a sane decaying LR + stronger regularization. +# +# COST NOTE: stages 1-3 are byte-for-byte the baseline. If your launcher can seed +# a train-continue from the existing baseline `finetune-forecast` checkpoint +# (xi0mujc2-finetune-forecast), delete stages 1-3 and keep only the last stage. + +nodes: 1 +register: true +cleanup-scripts: yes +results-dir: shared +mini-epoch: -1 + +stages: + - name: pretrain + command: train + base-config: ./config/config_jepa_multi_data_all_years.yml + config: ./config/config_jepa_multi_data_all_years.yml + options: [] + chain: 2 + nodes: 2 + + - name: pretrain-ft + command: train + config: ./config/config_jepa_multi_data_all_years_ft.yml + options: [] + chain: 2 + nodes: 2 + + - name: finetune-forecast + command: train-continue + from: pretrain-ft + config: ./config/config_jepa_multi_data_ft_forecast_all_years.yml + chain: 4 + nodes: 2 + + # Stage 4: full-unfreeze operational finetune at sane LR + stronger reg + - name: finetune-forecast-unfreeze + command: train-continue + from: finetune-forecast + config: ./config/experiments/exp3_unfreeze_lowlr_reg/ft_forecast_unfreeze_lowlr.yml + chain: 3 + nodes: 2 diff --git a/config/jepa_multi_data_pipeline.yml b/config/jepa_multi_data_pipeline.yml index f03e8a092..2e1afc13b 100644 --- a/config/jepa_multi_data_pipeline.yml +++ b/config/jepa_multi_data_pipeline.yml @@ -34,10 +34,10 @@ stages: # Stage 1: JEPA pre-training with EMA teacher + deep SSL - name: pretrain-ft - command: train + command: train-continue config: ./config/config_jepa_multi_data_all_years_ft.yml options: [] - chain: 2 + chain: 1 nodes: 2 # # Stage 1.1: JEPA pre-training with EMA teacher + deep SSL @@ -58,13 +58,21 @@ stages: nodes: 2 # Stage 3: Forecasting finetuning (freezes encoder) - - name: finetune-forecast-ft + - name: finetune-forecast-noise command: train-continue - from: finetune-forecast - config: 
./config/config_jepa_multi_data_ft_forecast.yml - chain: 3 + from: pretrain-ft + config: ./config/config_jepa_multi_data_ft_forecast_all_years_noise.yml + chain: 4 nodes: 2 + # # Stage 3: Forecasting finetuning (freezes encoder) + # - name: finetune-forecast-ft + # command: train-continue + # from: finetune-forecast + # config: ./config/config_jepa_multi_data_ft_forecast.yml + # chain: 2 + # nodes: 2 + # # Stage 3: Forecasting finetuning (unfreezes encoder) # - name: finetune-forecast-2 # command: train-continue diff --git a/config/streams/era5_1deg_forecasting/era5.yml b/config/streams/era5_1deg_forecasting/era5.yml index 1460e3d0c..480e17db2 100644 --- a/config/streams/era5_1deg_forecasting/era5.yml +++ b/config/streams/era5_1deg_forecasting/era5.yml @@ -17,7 +17,7 @@ ERA5 : loss_weight : 1. location_weight : cosine_latitude token_size : 8 - tokenize_spacetime : True + tokenize_spacetime : False max_num_targets: 20000 frequency : 06:00:00 embed : diff --git a/config/streams/jepa_forecast_multi_data_all_years/analysis.yml b/config/streams/jepa_forecast_multi_data_all_years/analysis.yml index 9f9c4ff16..e7e1763aa 100644 --- a/config/streams/jepa_forecast_multi_data_all_years/analysis.yml +++ b/config/streams/jepa_forecast_multi_data_all_years/analysis.yml @@ -3,7 +3,7 @@ # forcing so the teacher metadata mask remains populated for SSL targets. 
ERA5_in: - type: anemoi + type: anemoi_operan filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] #['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] # filenames: # ['aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr'] stream_id: 0 @@ -14,14 +14,13 @@ ERA5_in: forcing: True location_weight: cosine_latitude token_size: 8 - tokenize_spacetime: True - max_num_targets: -1 + tokenize_spacetime : False frequency: 06:00:00 nominal_time_mapping : "0" : 5 # 04:30:00 - "6" : 9 # 09:00:00 + "6" : 11 # 09:00:00 "12" : 17 #16:30:00 - "18" : 21 #21:00:00 + "18" : 23 #21:00:00 embed: net: transformer num_tokens: 1 diff --git a/config/streams/jepa_forecast_multi_data_all_years/era5.yml b/config/streams/jepa_forecast_multi_data_all_years/era5.yml index 5f9268163..0b8a34463 100644 --- a/config/streams/jepa_forecast_multi_data_all_years/era5.yml +++ b/config/streams/jepa_forecast_multi_data_all_years/era5.yml @@ -6,10 +6,10 @@ ERA5: type: anemoi filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] #['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] # filenames: # ['aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr'] - stream_id: 0 + stream_id: 42 source: ['q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 't_600', 't_700', 't_850', 't_925', 't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000', '10u', '10v', '2d', '2t', 'msl'] target: ['q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 
't_600', 't_700', 't_850', 't_925', 't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000', '10u', '10v', '2d', '2t', 'msl'] - geoinfo_channels: ['z', 'lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'noise_time'] + geoinfo_channels: ['z', 'lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] loss_weight: 1.0 diagnostic: True location_weight: cosine_latitude @@ -32,4 +32,75 @@ ERA5: pred_head: ens_size: 1 num_layers: 1 + channel_weights : + q_10: 0.2 + q_50: 0.2 + q_100: 0.23 + q_150: 0.26 + q_200: 0.29 + q_250: 0.33 + q_300: 0.36 + q_400: 0.42 + q_500: 0.48 + q_600: 0.55 + q_700: 0.61 + q_850: 0.71 + q_925: 0.75 + q_1000: 0.8 + t_10: 0.2 + t_50: 0.2 + t_100: 0.23 + t_150: 0.26 + t_200: 0.29 + t_250: 0.33 + t_300: 0.36 + t_400: 0.42 + t_500: 0.48 + t_600: 0.55 + t_700: 0.61 + t_850: 0.71 + t_925: 0.75 + t_1000: 0.8 + u_10: 0.2 + u_50: 0.2 + u_100: 0.23 + u_150: 0.26 + u_200: 0.29 + u_250: 0.33 + u_300: 0.36 + u_400: 0.42 + u_500: 0.48 + u_600: 0.55 + u_700: 0.61 + u_850: 0.71 + u_925: 0.75 + u_1000: 0.8 + v_10: 0.2 + v_50: 0.2 + v_100: 0.23 + v_150: 0.26 + v_200: 0.29 + v_250: 0.33 + v_300: 0.36 + v_400: 0.42 + v_500: 0.48 + v_600: 0.55 + v_700: 0.61 + v_850: 0.71 + v_925: 0.75 + v_1000: 0.8 + z_10: 0.2 + z_50: 0.2 + z_100: 0.23 + z_150: 0.26 + z_200: 0.29 + z_250: 0.33 + z_300: 0.36 + z_400: 0.42 + z_500: 0.48 + z_600: 0.55 + z_700: 0.61 + z_850: 0.71 + z_925: 0.75 + z_1000: 0.8 diff --git a/config/streams/jepa_forecast_multi_data_od/analysis.yml b/config/streams/jepa_forecast_multi_data_od/analysis.yml index 4ecad74da..bae4e5876 100644 --- 
a/config/streams/jepa_forecast_multi_data_od/analysis.yml +++ b/config/streams/jepa_forecast_multi_data_od/analysis.yml @@ -3,7 +3,7 @@ # forcing so the teacher metadata mask remains populated for SSL targets. ERA5_in: - type: anemoi + type: anemoi_operan # filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] #['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] filenames: ['aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr'] stream_id: 0 @@ -14,14 +14,14 @@ ERA5_in: forcing: True location_weight: cosine_latitude token_size: 8 - tokenize_spacetime: True + tokenize_spacetime: False max_num_targets: -1 frequency: 06:00:00 nominal_time_mapping : "0" : 5 # 04:30:00 - "6" : 9 # 09:00:00 + "6" : 11 # 09:00:00 "12" : 17 #16:30:00 - "18" : 21 #21:00:00 + "18" : 23 #21:00:00 embed: net: transformer num_tokens: 1 diff --git a/config/streams/jepa_forecast_multi_data_od/era5.yml b/config/streams/jepa_forecast_multi_data_od/era5.yml index 7c316c49c..5c1b5e4e5 100644 --- a/config/streams/jepa_forecast_multi_data_od/era5.yml +++ b/config/streams/jepa_forecast_multi_data_od/era5.yml @@ -32,5 +32,76 @@ ERA5: pred_head: ens_size: 1 num_layers: 1 + channel_weights : + q_10: 0.2 + q_50: 0.2 + q_100: 0.23 + q_150: 0.26 + q_200: 0.29 + q_250: 0.33 + q_300: 0.36 + q_400: 0.42 + q_500: 0.48 + q_600: 0.55 + q_700: 0.61 + q_850: 0.71 + q_925: 0.75 + q_1000: 0.8 + t_10: 0.2 + t_50: 0.2 + t_100: 0.23 + t_150: 0.26 + t_200: 0.29 + t_250: 0.33 + t_300: 0.36 + t_400: 0.42 + t_500: 0.48 + t_600: 0.55 + t_700: 0.61 + t_850: 0.71 + t_925: 0.75 + t_1000: 0.8 + u_10: 0.2 + u_50: 0.2 + u_100: 0.23 + u_150: 0.26 + u_200: 0.29 + u_250: 0.33 + u_300: 0.36 + u_400: 0.42 + u_500: 0.48 + u_600: 0.55 + u_700: 0.61 + u_850: 0.71 + u_925: 0.75 + u_1000: 0.8 + v_10: 0.2 + v_50: 0.2 + v_100: 0.23 + v_150: 0.26 + v_200: 0.29 + v_250: 0.33 + v_300: 0.36 + v_400: 0.42 + v_500: 0.48 + v_600: 0.55 + v_700: 0.61 + v_850: 0.71 + v_925: 0.75 + v_1000: 0.8 + z_10: 0.2 + z_50: 0.2 + z_100: 0.23 
+ z_150: 0.26 + z_200: 0.29 + z_250: 0.33 + z_300: 0.36 + z_400: 0.42 + z_500: 0.48 + z_600: 0.55 + z_700: 0.61 + z_850: 0.71 + z_925: 0.75 + z_1000: 0.8 diff --git a/config/streams/pretrain_multi_data_all_years/synop.yml b/config/streams/jepa_forecast_multi_data_od/synop.yml similarity index 65% rename from config/streams/pretrain_multi_data_all_years/synop.yml rename to config/streams/jepa_forecast_multi_data_od/synop.yml index 6501bcf7e..4833ace3c 100644 --- a/config/streams/pretrain_multi_data_all_years/synop.yml +++ b/config/streams/jepa_forecast_multi_data_od/synop.yml @@ -6,10 +6,10 @@ SurfaceCombined : type : obs stream_id : 2 - filenames : ['observations-ea-ofb-0001-1979-2022-combined-surface-v5-fixed-land-spatial80-min1km-lat60S70N-lsm09-min10.zarr', 'observations-ea-ofb-0001-2023-combined-surface-v5-fixed-land-heldout20-min1km-lat60S70N-lsm09-min10.zarr'] + filenames : ['observations-ea-ofb-0001-1979-2025-combined-surface-v5.zarr'] geoinfos: ['reportype', 'stalt', 'lsm'] - loss_weight : 0.5 - token_size : 64 + forcing: True + token_size : 512 tokenize_spacetime : False max_num_targets: -1 embed : @@ -27,4 +27,3 @@ SurfaceCombined : pred_head : ens_size : 1 num_layers : 1 - diff --git a/config/streams/jepa_forecast_multi_data_od_ckpt_order/analysis.yml b/config/streams/jepa_forecast_multi_data_od_ckpt_order/analysis.yml new file mode 100644 index 000000000..21a139313 --- /dev/null +++ b/config/streams/jepa_forecast_multi_data_od_ckpt_order/analysis.yml @@ -0,0 +1,39 @@ +# ERA5 input stream for temporal JEPA. +# Mirrors era5_georing_avhrr_forecast_random_inputs/era5.yml. NOTE(review): forcing: True +# is set below, so the teacher metadata mask may not be populated for SSL targets; confirm. 
+ +ERA5_in: + type: anemoi_operan + # filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] #['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + filenames: ['aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr'] + stream_id: 0 + source: ['10u', '10v', '2d', '2t', 'msl', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925', 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925'] + target: ['10u', '10v', '2d', '2t', 'msl', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925', 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925'] + geoinfo_channels: ['cos_julian_day', 'cos_local_time', 'insolation', 'lsm', 'noise_time', 'sdor', 'sin_julian_day', 'sin_local_time', 'slor', 'z'] + loss_weight: 1.0 + forcing: True + location_weight: cosine_latitude + token_size: 8 + tokenize_spacetime: False + max_num_targets: -1 + frequency: 06:00:00 + nominal_time_mapping : + "0" : 5 # 04:30:00 + "6" : 11 # 09:00:00 + "12" : 17 #16:30:00 + "18" : 23 #21:00:00 + embed: + net: transformer + num_tokens: 1 + num_heads: 8 + dim_embed: 512 + num_blocks: 2 + 
embed_target_coords: + net: linear + dim_embed: 256 + target_readout: + num_layers: 2 + num_heads: 4 + pred_head: + ens_size: 1 + num_layers: 1 diff --git a/config/streams/jepa_forecast_multi_data_od_ckpt_order/avhrr.yml b/config/streams/jepa_forecast_multi_data_od_ckpt_order/avhrr.yml new file mode 100644 index 000000000..14f2589e4 --- /dev/null +++ b/config/streams/jepa_forecast_multi_data_od_ckpt_order/avhrr.yml @@ -0,0 +1,28 @@ +# Polar-orbiting observation stream for temporal JEPA. + +METOP_ABC_AVHRR_IASI: + type: obs + stream_id: 20 + filenames: ['observations-ea-ofb-0001-2007-2021-metop-a-iasi-radiances-v1.zarr', 'observations-ea-ofb-0001-2013-2023-metop-b-iasi-radiances-v1.zarr', 'observations-ea-ofb-0001-2019-2023-metop-c-iasi-radiances-v1.zarr'] + geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'noise_time'] + source: ['obsvalue_avhrr_mean_vis_0', 'obsvalue_rawbt_16', 'obsvalue_rawbt_63', 'obsvalue_rawbt_138', 'obsvalue_rawbt_170', 'obsvalue_rawbt_185', 'obsvalue_rawbt_224', 'obsvalue_rawbt_249', 'obsvalue_rawbt_271', 'obsvalue_rawbt_445', 'obsvalue_rawbt_756', 'obsvalue_rawbt_867', 'obsvalue_rawbt_921', 'obsvalue_rawbt_2907', 'obsvalue_rawbt_2991', 'obsvalue_rawbt_3093', 'obsvalue_rawbt_3160', 'obsvalue_rawbt_5383'] + target: ['obsvalue_avhrr_mean_vis_0', 'obsvalue_rawbt_16', 'obsvalue_rawbt_63', 'obsvalue_rawbt_138', 'obsvalue_rawbt_170', 'obsvalue_rawbt_185', 'obsvalue_rawbt_224', 'obsvalue_rawbt_249', 'obsvalue_rawbt_271', 'obsvalue_rawbt_445', 'obsvalue_rawbt_756', 'obsvalue_rawbt_867', 'obsvalue_rawbt_921', 'obsvalue_rawbt_2907', 'obsvalue_rawbt_2991', 'obsvalue_rawbt_3093', 'obsvalue_rawbt_3160', 'obsvalue_rawbt_5383'] + forcing: True + loss_weight: 1.0 + token_size: 512 + tokenize_spacetime: False + embed: + net: transformer + num_tokens: 1 + num_heads: 2 + dim_embed: 256 + num_blocks: 2 + embed_target_coords: + net: linear + dim_embed: 256 + target_readout: + num_layers: 2 + num_heads: 4 + pred_head: + 
ens_size: 1 + num_layers: 1 diff --git a/config/streams/jepa_forecast_multi_data_od_ckpt_order/era5.yml b/config/streams/jepa_forecast_multi_data_od_ckpt_order/era5.yml new file mode 100644 index 000000000..700d414ee --- /dev/null +++ b/config/streams/jepa_forecast_multi_data_od_ckpt_order/era5.yml @@ -0,0 +1,107 @@ +# ERA5 input stream for temporal JEPA. +# This mirrors era5_georing_avhrr_forecast_random_inputs/era5.yml, but is not marked +# forcing so the teacher metadata mask remains populated for SSL targets. + +ERA5: + type: anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] #['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + # filenames: ['aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr'] + stream_id: 42 + source: ['10u', '10v', '2d', '2t', 'msl', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925', 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925'] + target: ['10u', '10v', '2d', '2t', 'msl', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925', 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925'] + 
geoinfo_channels: ['cos_julian_day', 'cos_local_time', 'insolation', 'lsm', 'noise_time', 'sdor', 'sin_julian_day', 'sin_local_time', 'slor', 'z'] + loss_weight: 1.0 + diagnostic: True + location_weight: cosine_latitude + token_size: 8 + tokenize_spacetime: False + max_num_targets: -1 + frequency: 01:00:00 + embed: + net: transformer + num_tokens: 1 + num_heads: 8 + dim_embed: 512 + num_blocks: 2 + embed_target_coords: + net: linear + dim_embed: 256 + target_readout: + num_layers: 2 + num_heads: 4 + pred_head: + ens_size: 1 + num_layers: 1 + channel_weights : + q_10: 0.2 + q_50: 0.2 + q_100: 0.23 + q_150: 0.26 + q_200: 0.29 + q_250: 0.33 + q_300: 0.36 + q_400: 0.42 + q_500: 0.48 + q_600: 0.55 + q_700: 0.61 + q_850: 0.71 + q_925: 0.75 + q_1000: 0.8 + t_10: 0.2 + t_50: 0.2 + t_100: 0.23 + t_150: 0.26 + t_200: 0.29 + t_250: 0.33 + t_300: 0.36 + t_400: 0.42 + t_500: 0.48 + t_600: 0.55 + t_700: 0.61 + t_850: 0.71 + t_925: 0.75 + t_1000: 0.8 + u_10: 0.2 + u_50: 0.2 + u_100: 0.23 + u_150: 0.26 + u_200: 0.29 + u_250: 0.33 + u_300: 0.36 + u_400: 0.42 + u_500: 0.48 + u_600: 0.55 + u_700: 0.61 + u_850: 0.71 + u_925: 0.75 + u_1000: 0.8 + v_10: 0.2 + v_50: 0.2 + v_100: 0.23 + v_150: 0.26 + v_200: 0.29 + v_250: 0.33 + v_300: 0.36 + v_400: 0.42 + v_500: 0.48 + v_600: 0.55 + v_700: 0.61 + v_850: 0.71 + v_925: 0.75 + v_1000: 0.8 + z_10: 0.2 + z_50: 0.2 + z_100: 0.23 + z_150: 0.26 + z_200: 0.29 + z_250: 0.33 + z_300: 0.36 + z_400: 0.42 + z_500: 0.48 + z_600: 0.55 + z_700: 0.61 + z_850: 0.71 + z_925: 0.75 + z_1000: 0.8 + + diff --git a/config/streams/jepa_forecast_multi_data_od_ckpt_order/geos.yml b/config/streams/jepa_forecast_multi_data_od_ckpt_order/geos.yml new file mode 100644 index 000000000..04f9abb41 --- /dev/null +++ b/config/streams/jepa_forecast_multi_data_od_ckpt_order/geos.yml @@ -0,0 +1,140 @@ +# Geostationary observation streams for temporal JEPA. 
+# These mirror era5_georing_avhrr_forecast_random_inputs/geos.yml. NOTE(review): forcing: True +# is set on every stream here, so teacher masks may be unavailable to the SSL loss; confirm. + +METEOSAT_SEVIRI_IR: + type: obs + stream_id: 10 + filenames: ['observations-file-2014-2024-seviri-o256-wegen-v3.zarr'] + geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza', 'noise_time'] + source: ['obsvalue_rawbt_065', 'obsvalue_rawbt_086', 'obsvalue_rawbt_160', 'obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133'] + target: ['obsvalue_rawbt_065', 'obsvalue_rawbt_086', 'obsvalue_rawbt_160', 'obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133'] + forcing: True + loss_weight: 1.0 + token_size: 1024 + tokenize_spacetime: False + max_num_targets: 262144 + embed: + net: transformer + num_tokens: 1 + num_heads: 4 + dim_embed: 512 + num_blocks: 2 + embed_target_coords: + net: linear + dim_embed: 512 + target_readout: + num_layers: 2 + num_heads: 4 + pred_head: + ens_size: 1 + num_layers: 1 + +GOES_ABI_IR: + type: obs + stream_id: 11 + filenames: ['observations-file-2017-2024-abi-goes16-IR-o256-v2.zarr'] + geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza', 'noise_time'] + source: ['obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133'] + target: ['obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133'] + forcing: True + loss_weight: 1.0 + token_size: 1024 + tokenize_spacetime: False + max_num_targets: 262144 + embed: + net: transformer + 
num_tokens: 1 + num_heads: 4 + dim_embed: 512 + num_blocks: 2 + embed_target_coords: + net: linear + dim_embed: 512 + target_readout: + num_layers: 2 + num_heads: 4 + pred_head: + ens_size: 1 + num_layers: 1 + +HIMAWARI_AHI_IR: + type: obs + stream_id: 12 + filenames: ['observations-file-2015-2022-himawari8-IR-o256-v1.zarr', 'observations-file-2022-2024-himawari9-IR-o256-v1.zarr'] + geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza', 'noise_time'] + source: ['obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133'] + target: ['obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133'] + forcing: True + loss_weight: 1.0 + token_size: 1024 + tokenize_spacetime: False + max_num_targets: 262144 + embed: + net: transformer + num_tokens: 1 + num_heads: 4 + dim_embed: 512 + num_blocks: 2 + embed_target_coords: + net: linear + dim_embed: 512 + target_readout: + num_layers: 2 + num_heads: 4 + pred_head: + ens_size: 1 + num_layers: 1 + +GOES_ABI_VIS: + type: obs + stream_id: 13 + filenames: ['observations-file-2017-2024-abi-goes16-VIS-o256-v2.zarr'] + geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza', 'noise_time'] + forcing: True + loss_weight: 1.0 + token_size: 1024 + tokenize_spacetime: False + max_num_targets: 262144 + embed: + net: transformer + num_tokens: 1 + num_heads: 4 + dim_embed: 512 + num_blocks: 2 + embed_target_coords: + net: linear + dim_embed: 512 + target_readout: + num_layers: 2 + num_heads: 4 + pred_head: + ens_size: 1 + num_layers: 1 + +HIMAWARI_AHI_VIS: + type: obs + stream_id: 14 + filenames: ['observations-file-2015-2022-himawari8-VIS-o256-v1.zarr', 
'observations-file-2022-2024-himawari9-VIS-o256-v1.zarr'] + geoinfo_channels: ['zenith', 'noise_time'] + # geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza', 'noise_time'] + forcing: True + loss_weight: 1.0 + token_size: 1024 + tokenize_spacetime: False + max_num_targets: 262144 + embed: + net: transformer + num_tokens: 1 + num_heads: 4 + dim_embed: 512 + num_blocks: 2 + embed_target_coords: + net: linear + dim_embed: 512 + target_readout: + num_layers: 2 + num_heads: 4 + pred_head: + ens_size: 1 + num_layers: 1 diff --git a/config/streams/pretrain_multi_data_all_years/analysis.yml b/config/streams/pretrain_multi_data_all_years/analysis.yml index 1d7164660..d3fa4d0b1 100644 --- a/config/streams/pretrain_multi_data_all_years/analysis.yml +++ b/config/streams/pretrain_multi_data_all_years/analysis.yml @@ -3,7 +3,7 @@ # forcing so the teacher metadata mask remains populated for SSL targets. ERA5_in: - type: anemoi + type: anemoi_operan filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] #['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] # filenames: # ['aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr'] stream_id: 0 @@ -18,9 +18,9 @@ ERA5_in: frequency: 06:00:00 nominal_time_mapping : "0" : 5 # 04:30:00 - "6" : 9 # 09:00:00 + "6" : 11 # 09:00:00 "12" : 17 #16:30:00 - "18" : 21 #21:00:00 + "18" : 23 #21:00:00 embed: net: transformer num_tokens: 1 diff --git a/config/streams/pretrain_multi_data_all_years/avhrr.yml b/config/streams/pretrain_multi_data_all_years/avhrr.yml index 17cbbcd23..97827c01e 100644 --- a/config/streams/pretrain_multi_data_all_years/avhrr.yml +++ b/config/streams/pretrain_multi_data_all_years/avhrr.yml @@ -7,6 +7,8 @@ METOP_ABC_AVHRR_IASI: geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'noise_time'] loss_weight: 1.0 token_size: 512 + reconstruct: False + tokenize_spacetime: False embed: net: transformer 
num_tokens: 1 diff --git a/config/streams/pretrain_multi_data_all_years/geos.yml b/config/streams/pretrain_multi_data_all_years/geos.yml index 25b66646e..607bd4209 100644 --- a/config/streams/pretrain_multi_data_all_years/geos.yml +++ b/config/streams/pretrain_multi_data_all_years/geos.yml @@ -11,6 +11,7 @@ METEOSAT_SEVIRI_IR: token_size: 1024 tokenize_spacetime: False max_num_targets: 262144 + reconstruct: False embed: net: transformer num_tokens: 1 @@ -36,6 +37,7 @@ GOES_ABI_IR: token_size: 1024 tokenize_spacetime: False max_num_targets: 262144 + reconstruct: False embed: net: transformer num_tokens: 1 @@ -61,6 +63,7 @@ HIMAWARI_AHI_IR: token_size: 1024 tokenize_spacetime: False max_num_targets: 262144 + reconstruct: False embed: net: transformer num_tokens: 1 @@ -86,6 +89,7 @@ GOES_ABI_VIS: token_size: 1024 tokenize_spacetime: False max_num_targets: 262144 + reconstruct: False embed: net: transformer num_tokens: 1 @@ -112,6 +116,7 @@ HIMAWARI_AHI_VIS: token_size: 1024 tokenize_spacetime: False max_num_targets: 262144 + reconstruct: False embed: net: transformer num_tokens: 1 diff --git a/config/streams/pretrain_multi_data_od/analysis.yml b/config/streams/pretrain_multi_data_od/analysis.yml index 26e81ae5e..c8022d966 100644 --- a/config/streams/pretrain_multi_data_od/analysis.yml +++ b/config/streams/pretrain_multi_data_od/analysis.yml @@ -3,7 +3,7 @@ # forcing so the teacher metadata mask remains populated for SSL targets. 
ERA5_in: - type: anemoi + type: anemoi_operan filenames: ['aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr'] stream_id: 0 source: ['q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 't_600', 't_700', 't_850', 't_925', 't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000', '10u', '10v', '2d', '2t', 'msl'] @@ -15,6 +15,11 @@ ERA5_in: tokenize_spacetime: True max_num_targets: -1 frequency: 06:00:00 + nominal_time_mapping : + "0" : 5 # 04:30:00 + "6" : 11 # 09:00:00 + "12" : 17 #16:30:00 + "18" : 23 #21:00:00 embed: net: transformer num_tokens: 1 diff --git a/config/streams/pretrain_multi_data_od/avhrr.yml b/config/streams/pretrain_multi_data_od/avhrr.yml index 17cbbcd23..aeed22714 100644 --- a/config/streams/pretrain_multi_data_od/avhrr.yml +++ b/config/streams/pretrain_multi_data_od/avhrr.yml @@ -7,6 +7,7 @@ METOP_ABC_AVHRR_IASI: geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'noise_time'] loss_weight: 1.0 token_size: 512 + tokenize_spacetime: False embed: net: transformer num_tokens: 1 diff --git a/config/streams/pretrain_multi_data_od/synop.yml b/config/streams/pretrain_multi_data_od/synop.yml index fb85b297f..bbd911b94 100644 --- a/config/streams/pretrain_multi_data_od/synop.yml +++ b/config/streams/pretrain_multi_data_od/synop.yml @@ -6,10 +6,9 @@ SurfaceCombined : type : obs stream_id : 2 - filenames : ['observations-ea-ofb-0001-1979-2022-combined-surface-v5-fixed-land-spatial80-min1km-lat60S70N-lsm09-min10.zarr', 
'observations-ea-ofb-0001-2023-combined-surface-v5-fixed-land-heldout20-min1km-lat60S70N-lsm09-min10.zarr'] + filenames : ['observations-ea-ofb-0001-1979-2025-combined-surface-v5.zarr'] geoinfos: ['reportype', 'stalt', 'lsm'] - loss_weight : 0.5 - token_size : 64 + token_size : 512 tokenize_spacetime : False max_num_targets: -1 embed : @@ -27,5 +26,3 @@ SurfaceCombined : pred_head : ens_size : 1 num_layers : 1 - - diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index e18de025a..53198abba 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -132,6 +132,10 @@ def _sanitize_time_keys(conf: Config) -> Config: _sanitize_delta_time_keys(conf.test_config) _sanitize_start_end_time_keys(conf.test_config) + for extra_conf in (conf.get("extra_validation_configs") or {}).values(): + _sanitize_delta_time_keys(extra_conf) + _sanitize_start_end_time_keys(extra_conf) + return conf @@ -354,6 +358,7 @@ def _convert_interpolation(cfg, key): config.get("training_config"), config.get("test_config"), config.get("validation_config"), + *(config.get("extra_validation_configs") or {}).values(), ] for subconf in subconfs: diff --git a/src/weathergen/datasets/data_reader_anemoi.py b/src/weathergen/datasets/data_reader_anemoi.py index 6d4237fd3..4fc3c496d 100644 --- a/src/weathergen/datasets/data_reader_anemoi.py +++ b/src/weathergen/datasets/data_reader_anemoi.py @@ -259,28 +259,43 @@ def select_channels(self, ds0: anemoi_datasets, ch_type: str) -> NDArray[np.int6 channels = self.stream_info.get(ch_type) channels_exclude = self.stream_info.get(ch_type + "_exclude", []) + stream_name = self.stream_info["name"] # sanity check - is_empty = len(channels) == 0 if channels is not None else False - if is_empty: - stream_name = self.stream_info["name"] + if channels is not None and len(channels) == 0: _logger.warning(f"No channel for {stream_name} for {ch_type}.") - chs_idx 
= np.sort( - [ - ds0.name_to_index[k] - for (k, v) in ds0.typed_variables.items() - if ( - not v.is_computed_forcing - and not v.is_constant_in_time - and ( - np.array([f == k for f in channels]).any() if channels is not None else True - ) - and not np.array([f == k for f in channels_exclude]).any() + # Variables eligible for selection: physical fields only (no computed forcings, + # no constants-in-time), and not explicitly excluded. + eligible = { + k + for k, v in ds0.typed_variables.items() + if not v.is_computed_forcing + and not v.is_constant_in_time + and not any(ex == k for ex in channels_exclude) + } + + if channels is not None: + # Respect the order given in the stream config so the channel layout is identical + # across datasets that share channels (e.g. ERA5 vs operational analysis), + # regardless of each dataset's on-disk variable order. + seen: set[str] = set() + ordered = [] + for k in channels: + if k in eligible and k not in seen: + ordered.append(k) + seen.add(k) + missing = [k for k in channels if k not in eligible] + if missing: + _logger.warning( + f"{stream_name}: requested {ch_type} channels not available " + f"(excluded/forcing/constant-in-time or absent), skipped: {missing}" ) - ] - ) + else: + # No explicit selection: fall back to deterministic lexicographic order so the + # layout is still reproducible across datasets. 
+ ordered = sorted(eligible) - return np.array(chs_idx, dtype=np.int64) + return np.array([ds0.name_to_index[k] for k in ordered], dtype=np.int64) def select_geoinfo_channels(self, ds0: anemoi_datasets) -> NDArray[np.int64]: """ @@ -302,19 +317,19 @@ def select_geoinfo_channels(self, ds0: anemoi_datasets) -> NDArray[np.int64]: if len(geoinfo_channels) == 0: return np.array([], dtype=np.int64) - # Select channels that match the geoinfo list (exact match required) - chs_idx = np.sort( - [ds0.name_to_index[k] for k in ds0.typed_variables.keys() if k in geoinfo_channels] - ) + # Select channels that match the geoinfo list (exact match required), preserving the + # order requested in the config so the geoinfo layout is dataset-independent. + available = set(ds0.typed_variables.keys()) + ordered = [k for k in geoinfo_channels if k in available] - if len(chs_idx) == 0 and len(geoinfo_channels) > 0: + if len(ordered) == 0: stream_name = self.stream_info["name"] _logger.warning( f"No matching geoinfo channels found for {stream_name}. 
" f"Requested: {geoinfo_channels}" ) - return np.array(chs_idx, dtype=np.int64) + return np.array([ds0.name_to_index[k] for k in ordered], dtype=np.int64) def _clip_lat(lats: NDArray) -> NDArray[np.float32]: diff --git a/src/weathergen/datasets/data_reader_fesom.py b/src/weathergen/datasets/data_reader_fesom.py index b37352a7e..2067ea882 100644 --- a/src/weathergen/datasets/data_reader_fesom.py +++ b/src/weathergen/datasets/data_reader_fesom.py @@ -394,20 +394,26 @@ def select( ch_filters: list[str] | None, excl: list[str] | None = None, ) -> tuple[list[str], NDArray]: - if excl and ch_filters: - mask = [ - any(f == c for f in ch_filters) and all(ex not in c for ex in excl) - for c in colnames - ] - elif ch_filters: - mask = [any(f == c for f in ch_filters) for c in colnames] - elif excl: - mask = [all(ex not in c for ex in excl) for c in colnames] + excl = excl or [] + name_to_pos = {c: i for i, c in enumerate(colnames)} + + if ch_filters: + # Respect config order (exact match) so the channel layout is identical across + # datasets that share channels, regardless of each dataset's column order. + seen: set[str] = set() + ordered = [] + for f in ch_filters: + if f in name_to_pos and f not in seen and all(ex not in f for ex in excl): + ordered.append(f) + seen.add(f) else: - assert False, "Cannot use select with both ch_filters and excl as None" + assert excl, "Cannot use select with both ch_filters and excl as None" + # No explicit selection: deterministic lexicographic order of non-excluded columns. 
+ ordered = sorted(c for c in colnames if all(ex not in c for ex in excl)) - selected_cols_idx = cols_idx[np.where(mask)[0]] - selected_colnames = [colnames[i] for i in np.where(mask)[0]] + positions = [name_to_pos[c] for c in ordered] + selected_cols_idx = cols_idx[positions] + selected_colnames = [colnames[i] for i in positions] return selected_colnames, selected_cols_idx @override diff --git a/src/weathergen/datasets/data_reader_obs.py b/src/weathergen/datasets/data_reader_obs.py index 62b1dcfba..461adcc48 100644 --- a/src/weathergen/datasets/data_reader_obs.py +++ b/src/weathergen/datasets/data_reader_obs.py @@ -113,18 +113,25 @@ def select_channels( """ Allow user to specify which columns they want to access. Get functions only returned for these specified columns. - """ - selected_colnames = [ - c - for c in colnames - if ( - np.array([c_sel in c for c_sel in cols_select]).any() - if cols_select is not None - else True and not np.array([c_nsel in c for c_nsel in cols_exclude]).any() - ) - ] - return selected_colnames + When ``cols_select`` is given, the returned columns follow the order of the select + filters (config order) so the channel layout is identical across datasets that share + channels, regardless of each dataset's column order. Without a selection, columns fall + back to deterministic lexicographic order. + """ + cols_exclude = cols_exclude or [] + + if cols_select is not None: + # Respect config order: group matching columns by the order of the select filters. + # Matching is substring-based (a filter may match several columns). 
+ selected_colnames: list[str] = [] + for c_sel in cols_select: + for c in colnames: + if c_sel in c and c not in selected_colnames: + selected_colnames.append(c) + return selected_colnames + + return sorted(c for c in colnames if not any(c_nsel in c for c_nsel in cols_exclude)) def first_sample_with_data(self) -> int: """ diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index d845d2e7d..b523a375c 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -131,6 +131,7 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): self.teacher_time_offset = 0 self.batch_size = get_batch_size_from_config(mode_cfg) + self.num_workers = cf.data_loading.num_workers self.shuffle = mode_cfg.shuffle self.len_timedelta = mode_cfg.time_window_len @@ -194,8 +195,13 @@ def check_samples(self, fsm: int): # streamlined calculation of length epoch_len = self.samples_per_mini_epoch - # adjust len to split loading across all workers and ensure it is multiple of batch_size - self.len = ((epoch_len // self.world_size) // self.batch_size) * self.batch_size + # adjust len to split loading across all workers and ensure it is multiple of batch_size; + # also account for num_workers so per-worker slice is a multiple of batch_size, + # preventing the range-loop in __iter__ from yielding extra batches via ceiling division + effective_workers = max(1, self.num_workers) + self.len = ( + (epoch_len // self.world_size) // (self.batch_size * effective_workers) + ) * (self.batch_size * effective_workers) n_duplicates = self.len * self.world_size - available_samples if not self.repeat_data: @@ -269,6 +275,7 @@ def _init_stream_datasets(self, cf) -> dict[StreamName, list[AnyDataReader]]: stream_info[str(self._stage) + "_source_channels"] = ds.source_channels stream_info[str(self._stage) + "_target_channels"] = ds.target_channels + 
stream_info[str(self._stage) + "_geoinfo_channels"] = ds.geoinfo_channels stream_info["target_channel_weights"] = ( ds.target_channel_weights if ds.target_channel_weights is not None @@ -732,7 +739,7 @@ def _get_batch(self, idx: int, num_forecast_steps: int): input_data, source_masks.metadata[sidx], is_student=True, - add_geoinfo_noise="noise_time" in stream_info.get("geoinfo_channels",[]), + add_geoinfo_noise="noise_time" in stream_info.get("geoinfo_channels", []), ) sdata = self._build_stream_data( @@ -761,8 +768,10 @@ def _get_batch(self, idx: int, num_forecast_steps: int): # Apply self-flow noise to teacher data (handled by masker) input_data_target = self.masker.apply_noise_to_data( - input_data_target_orig, target_masks.metadata[tidx], is_student=False, - add_geoinfo_noise="noise_time" in stream_info.get("geoinfo_channels",[]), + input_data_target_orig, + target_masks.metadata[tidx], + is_student=False, + add_geoinfo_noise="noise_time" in stream_info.get("geoinfo_channels", []), ) sdata = self._build_stream_data( diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 7818612be..38fcde666 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -40,7 +40,7 @@ from weathergen.model.layers import MLP, NamedLinear from weathergen.model.utils import get_num_parameters from weathergen.utils.distributed import is_root -from weathergen.utils.utils import get_dtype, is_stream_forcing +from weathergen.utils.utils import get_dtype, is_stream_reconstructed logger = logging.getLogger(__name__) @@ -419,8 +419,9 @@ def create(self) -> "Model": for i_stream, si in enumerate(cf.streams): stream_name = self.stream_names[i_stream] - # skip decoder if channels are empty - if is_stream_forcing(si): + # skip decoder for streams that are not physically reconstructed + # (forcing/input-only, or explicit reconstruct: false -> JEPA-only target) + if not is_stream_reconstructed(si): continue # skip for the moment to ensure target 
embedding and tte exist (ordering of @@ -515,8 +516,9 @@ def create(self) -> "Model": for i_stream, si in enumerate(cf.streams): stream_name = self.stream_names[i_stream] - # skip decoder if channels are empty - if is_stream_forcing(si): + # skip decoder for streams that are not physically reconstructed + # (forcing/input-only, or explicit reconstruct: false -> JEPA-only target) + if not is_stream_reconstructed(si): continue pred_spatial_shared = si.get("pred_spatial_shared") @@ -779,10 +781,12 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: if self.forecast_engine: tokens = self.forecast_engine(tokens, step, model_params.rope_coords) - # decoder predictions - output = self.predict_decoders(model_params, step, tokens, batch, output) - # latent predictions (raw and with SSL heads) - output = self.predict_latent(model_params, step, tokens, batch, output, intermediates) + if "masking" in self.cf.training_config.training_mode: + # decoder predictions + output = self.predict_decoders(model_params, step, tokens, batch, output) + if "student_teacher" in self.cf.training_config.training_mode: + # latent predictions (raw and with SSL heads) + output = self.predict_latent(model_params, step, tokens, batch, output, intermediates) return output @@ -890,6 +894,12 @@ def predict_decoders( # pair with tokens from assimilation engine to obtain target tokens for stream_name in self.stream_names: + # streams without a physical decoder (forcing, or reconstruct: false JEPA-only + # targets) have no embed_target_coords/target_token_engine. Skip them here even + # though they may still carry (unused) target coords on the student view. 
+ if stream_name not in self.embed_target_coords: + continue + # extract target coords for current stream and fstep and convert to one tensor t_coords = [ batch.samples[i_b].streams_data[stream_name].target_coords[step] diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 6b42fca9a..255903ce1 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -40,6 +40,20 @@ type TrainingMode = str +def _has_trainable_params(module: torch.nn.Module) -> bool: + """True if the module has at least one parameter with requires_grad=True. + + FSDP2 raises "RuntimeError: _chunk_cat expects non-empty tensor" in the + backward reduce-scatter (foreach_reduce) when a fully_shard group contains + only frozen parameters, since there are no gradients to reduce. This happens + during fine-tuning (e.g. forecast fine-tuning freezes the encoder and + latent_heads). Skipping fully_shard for fully-frozen modules leaves their + parameters in the root FSDP group, which still has trainable parameters, so + they remain sharded without triggering the empty-gradient reduce. 
+ """ + return any(p.requires_grad for p in module.parameters()) + + def init_model_and_shard( cf, dataset, @@ -96,36 +110,36 @@ def init_model_and_shard( ) for module in model.encoder.ae_local_engine.ae_local_blocks.modules(): - if isinstance(module, modules_to_shard): + if isinstance(module, modules_to_shard) and _has_trainable_params(module): fully_shard(module, **fsdp_kwargs) for module in model.encoder.ae_local_global_engine.ae_adapter.modules(): - if isinstance(module, modules_to_shard): + if isinstance(module, modules_to_shard) and _has_trainable_params(module): fully_shard(module, **fsdp_kwargs) for module in model.encoder.ae_global_engine.ae_global_blocks.modules(): - if isinstance(module, modules_to_shard): + if isinstance(module, modules_to_shard) and _has_trainable_params(module): fully_shard(module, **fsdp_kwargs) for module in model.forecast_engine.fe_blocks.modules(): - if isinstance(module, modules_to_shard): + if isinstance(module, modules_to_shard) and _has_trainable_params(module): # reshard_after_forward=False keeps FE parameters unsharded # during the multi-step rollout loop. # Needed for pushforward trick. 
fully_shard(module, reshard_after_forward=False, **fsdp_kwargs) for module in model.latent_heads.modules(): - if isinstance(module, modules_to_shard): + if isinstance(module, modules_to_shard) and _has_trainable_params(module): fully_shard(module, **fsdp_kwargs) if model.deep_ssl_fusion is not None: for module in model.deep_ssl_fusion.modules(): - if isinstance(module, modules_to_shard): + if isinstance(module, modules_to_shard) and _has_trainable_params(module): fully_shard(module, **fsdp_kwargs) if model.deep_ssl_level_projections is not None: for module in model.deep_ssl_level_projections.modules(): - if isinstance(module, modules_to_shard): + if isinstance(module, modules_to_shard) and _has_trainable_params(module): fully_shard(module, **fsdp_kwargs) full_precision_fsdp_kwargs = { @@ -140,7 +154,7 @@ def init_model_and_shard( } for module in model.target_token_engines.modules(): - if isinstance(module, modules_to_shard): + if isinstance(module, modules_to_shard) and _has_trainable_params(module): fully_shard(module, **full_precision_fsdp_kwargs) if with_ddp and with_fsdp: diff --git a/src/weathergen/model/utils.py b/src/weathergen/model/utils.py index 7dd2060bb..035bd8b31 100644 --- a/src/weathergen/model/utils.py +++ b/src/weathergen/model/utils.py @@ -48,6 +48,7 @@ def apply_fct_to_blocks(model, blocks, fct): name = module.name if hasattr(module, "name") else name # avoid the whole model element which has name '' if (re.fullmatch(blocks, name) is not None) and (name != ""): + # logger.info(f"Freezing weights of {name}") fct(module) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 0b3c504db..f5e2e052f 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -80,6 +80,13 @@ def __init__(self, train_logging: Config): self.data_loader_validation: torch.utils.data.DataLoader | None = None self.dataset: MultiStreamDataSampler | None = None self.dataset_val: MultiStreamDataSampler | None = None + # 
extra validation sets, keyed by stage label "val_" + self.extra_val_cfgs: dict[str, Config] = {} + self.datasets_val_extra: dict[str, MultiStreamDataSampler] = {} + self.data_loaders_val_extra: dict[str, torch.utils.data.DataLoader] = {} + self.target_and_aux_calculators_val_extra: dict[str, dict] = {} + self.loss_calculators_val_extra: dict[str, LossCalculator] = {} + self.batch_sizes_val_extra: dict[str, int] = {} self.device: torch.device = None self.ema_model = None self.grad_scaler: torch.amp.GradScaler | None = None @@ -144,6 +151,19 @@ def init(self, cf: Config, devices): self.validation_cfg, cf.get("test_config", {}), cfg_keys_to_filter ) + # extra validation sets, each derived from the validation cfg like test_cfg; + # extra sets must only override dates/shuffle/sample-count/batch-size, not + # stream/channel selection (samplers share cf.streams) + for name, overrides in (cf.get("extra_validation_configs", None) or {}).items(): + if not (overrides or {}).get("enabled", True): + continue # set disabled, e.g. 
by a train_continue override + stage_label = f"val_{name}" + extra_cfg = get_active_stage_config(self.validation_cfg, overrides, cfg_keys_to_filter) + # extra sets never write sample output files (would collide with primary val output) + extra_cfg.output = {} + self.extra_val_cfgs[stage_label] = extra_cfg + self.batch_sizes_val_extra[stage_label] = get_batch_size_from_config(extra_cfg) + # batch sizes self.batch_size_per_gpu = get_batch_size_from_config(self.training_cfg) self.batch_size_validation_per_gpu = get_batch_size_from_config(self.validation_cfg) @@ -155,6 +175,8 @@ def init(self, cf: Config, devices): strict=True, ): config.validate_forecast_policy_and_steps(mode_cfg.get("forecast", {}), mode) + for stage_label, extra_cfg in self.extra_val_cfgs.items(): + config.validate_forecast_policy_and_steps(extra_cfg.get("forecast", {}), stage_label) self.mixed_precision_dtype = get_dtype(cf.mixed_precision_dtype) @@ -256,11 +278,72 @@ def inference(self, cf, devices, run_id_contd, mini_epoch_contd): logger.info(f"Starting inference with id={self.cf.general.run_id}.") - # inference validation set self.validate(0, self.test_cfg, self.batch_size_test_per_gpu) logger.info(f"Finished inference run with id: {cf.general.run_id}") + def _check_channel_order_consistency( + self, + dataset: MultiStreamDataSampler, + from_run_id: str, + mini_epoch: int | None, + stage: Stage, + ) -> None: + """Guard against silently scrambling the channel<->weight mapping when continuing. + + Compares the source/target/geoinfo channel order resolved for the current data against + the order stored in the checkpoint's config and raises if they differ. Streams (or + channel lists) absent from the checkpoint config cannot be verified and are skipped with + a warning (e.g. geoinfo for checkpoints predating the resolved-geoinfo back-fill). 
+ """ + try: + prev_cf = config.load_run_config(from_run_id, mini_epoch, None) + except FileNotFoundError: + logger.warning( + f"Could not load config for run_id '{from_run_id}' to verify channel order; " + "skipping channel-order consistency check." + ) + return + + prev_streams = {s["name"]: s for s in prev_cf.get("streams", [])} + src_key = f"{stage}_source_channels" + tgt_key = f"{stage}_target_channels" + geo_key = f"{stage}_geoinfo_channels" + + mismatches: list[str] = [] + for name, readers in dataset.streams_datasets.items(): + prev = prev_streams.get(name) + if prev is None: + continue + reader = readers[0] + for key, resolved in ( + (src_key, list(reader.source_channels)), + (tgt_key, list(reader.target_channels)), + (geo_key, list(reader.geoinfo_channels)), + ): + stored = prev.get(key) + if stored is None: + logger.warning( + f"Checkpoint '{from_run_id}' has no '{key}' for stream '{name}'; " + "cannot verify channel order for it." + ) + continue + if list(stored) != resolved: + mismatches.append( + f" [{name}] {key}:\n" + f" checkpoint: {list(stored)}\n" + f" current: {resolved}" + ) + + if mismatches: + details = "\n".join(mismatches) + raise ValueError( + f"Channel order/content differs from the checkpoint being continued " + f"(run_id='{from_run_id}'). Continuing would scramble the learned " + f"channel<->weight mapping. 
Align the stream configs (channel order matters):\n" + f"{details}" + ) + def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # general initalization self.init(cf, devices) @@ -276,6 +359,11 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): self.dataset = MultiStreamDataSampler(cf, self.training_cfg, stage=TRAIN) self.dataset_val = MultiStreamDataSampler(cf, self.validation_cfg, stage=VAL) + if run_id_contd is not None: + self._check_channel_order_consistency( + self.dataset, run_id_contd, mini_epoch_contd, TRAIN + ) + loader_params = { "batch_size": None, "batch_sampler": None, @@ -288,6 +376,19 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): self.dataset_val, **loader_params, sampler=None ) + for stage_label, extra_cfg in self.extra_val_cfgs.items(): + # stage=VAL so masking/loss behave as in validation; the stage label is + # only used for logging + self.datasets_val_extra[stage_label] = MultiStreamDataSampler(cf, extra_cfg, stage=VAL) + # cap workers: each loader spawns its own processes, each re-opening all + # stream readers + extra_loader_params = loader_params | { + "num_workers": min(cf.data_loading.num_workers, 2) + } + self.data_loaders_val_extra[stage_label] = torch.utils.data.DataLoader( + self.datasets_val_extra[stage_label], **extra_loader_params, sampler=None + ) + self.model, self.model_params = init_model_and_shard( cf, self.dataset, @@ -328,9 +429,13 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # get target_aux calculators for different loss terms self.target_and_aux_calculators = self.get_target_aux_calculators(self.training_cfg) self.target_and_aux_calculators_val = self.get_target_aux_calculators(self.validation_cfg) + for stage_label, extra_cfg in self.extra_val_cfgs.items(): + self.target_and_aux_calculators_val_extra[stage_label] = ( + self.get_target_aux_calculators(extra_cfg) + ) # Restore EMA teacher weights when continuing from a checkpoint - 
if run_id_contd is not None: + if run_id_contd is not None: # and self.cf.general.istep != 0: # To be tested self._load_ema_teacher_state(run_id_contd, mini_epoch_contd) # if with_fsdp then parameter count is unreliable @@ -374,7 +479,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): ) # Restore optimizer momentum buffers when continuing from a checkpoint - if run_id_contd is not None: + if run_id_contd is not None and self.cf.general.istep != 0: self._load_optimizer_state(run_id_contd, mini_epoch_contd) if self.cf.general.istep > 0 and is_root(): @@ -384,6 +489,10 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): self.loss_calculator = LossCalculator(cf, self.training_cfg, TRAIN, device=self.device) val_cfg = self.validation_cfg self.loss_calculator_val = LossCalculator(cf, val_cfg, VAL, device=self.device) + for stage_label, extra_cfg in self.extra_val_cfgs.items(): + self.loss_calculators_val_extra[stage_label] = LossCalculator( + cf, extra_cfg, VAL, device=self.device + ) # recover mini_epoch when continuing run if self.world_size_original is None: @@ -418,6 +527,18 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): ) self.validate(mini_epoch, self.validation_cfg, self.batch_size_validation_per_gpu) + for stage_label, extra_cfg in self.extra_val_cfgs.items(): + logger.info( + f"Mini_epoch {mini_epoch} of {self.training_cfg.num_mini_epochs}: " + f"validate {stage_label}." + ) + self.validate( + mini_epoch, + extra_cfg, + self.batch_sizes_val_extra[stage_label], + stage=stage_label, + ) + logger.info( f"Mini_epoch {mini_epoch} of {self.training_cfg.num_mini_epochs}: save_model." 
) @@ -530,6 +651,11 @@ def train(self, mini_epoch): self.optimizer.zero_grad() self.grad_scaler.scale(loss).backward() + #Comment in this code when trying to debug DDP errors, it will find the offending params + # for name, param in self.model.named_parameters(): + # if param.requires_grad and param.grad is None: + # print(f"UNUSED (no grad): {name}") + # gradient clipping self.grad_scaler.unscale_(self.optimizer) total_norm = torch.nn.utils.clip_grad_norm_( @@ -598,23 +724,33 @@ def train(self, mini_epoch): self.dataset.advance() - def validate(self, mini_epoch, mode_cfg, batch_size): + def validate(self, mini_epoch, mode_cfg, batch_size, stage: Stage = VAL): """ - Perform validation / test computation as specified by mode_cfg + Perform validation / test computation as specified by mode_cfg. + + stage selects the dataset/loss objects: VAL for the primary validation set, + "val_" for an extra validation set; it is also the logging label. """ cf = self.cf self.model.eval() - dataset_val_iter = iter(self.data_loader_validation) + if stage == VAL: + dataset, data_loader = self.dataset_val, self.data_loader_validation + target_aux_calcs = self.target_and_aux_calculators_val + else: + dataset = self.datasets_val_extra[stage] + data_loader = self.data_loaders_val_extra[stage] + target_aux_calcs = self.target_and_aux_calculators_val_extra[stage] + loss_calculator = self._loss_calculator_for(stage) + + dataset_val_iter = iter(data_loader) num_samples_write = mode_cfg.get("output", {}).get("num_samples", 0) * batch_size with torch.no_grad(): # print progress bar but only in interactive mode, i.e. 
when without ddp - with tqdm.tqdm( - total=len(self.data_loader_validation), disable=self.cf.with_ddp - ) as pbar: + with tqdm.tqdm(total=len(data_loader), disable=self.cf.with_ddp) as pbar: for bidx, batch in enumerate(dataset_val_iter): if cf.data_loading.get("memory_pinning", False): # pin memory for faster CPU-GPU transfer @@ -640,7 +776,7 @@ def validate(self, mini_epoch, mode_cfg, batch_size): ) targets_and_auxs = {} - for loss_name, target_aux in self.target_and_aux_calculators_val.items(): + for loss_name, target_aux in target_aux_calcs.items(): target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name) targets_and_auxs[loss_name] = target_aux.compute( self.cf.general.istep, @@ -649,7 +785,7 @@ def validate(self, mini_epoch, mode_cfg, batch_size): self.model, ) - _ = self.loss_calculator_val.compute_loss( + _ = loss_calculator.compute_loss( preds=preds, targets_and_aux=targets_and_auxs, metadata=extract_batch_metadata(batch), @@ -662,7 +798,7 @@ def validate(self, mini_epoch, mode_cfg, batch_size): denormalize_data_fct = ( (lambda x0, x1: x1) if mode_cfg.get("output", {}).get("normalized_samples", False) - else self.dataset_val.denormalize_target_channels + else dataset.denormalize_target_channels ) # write output write_output( @@ -682,11 +818,11 @@ def validate(self, mini_epoch, mode_cfg, batch_size): if (bidx * batch_size) > mode_cfg.samples_per_mini_epoch: break - self._log_terminal(0, mini_epoch, VAL) - self._log(VAL) + self._log_terminal(0, mini_epoch, stage) + self._log(stage) # avoid that there is a systematic bias in the validation subset - self.dataset_val.advance() + dataset.advance() def _get_full_model_state_dict(self): maybe_sharded_sd = ( @@ -892,19 +1028,29 @@ def _load_optimizer_state(self, run_id: str, mini_epoch): total = sum(1 for _ in self.model.parameters()) logger.info(f"Loaded optimizer state for {loaded}/{total} parameters.") + def _loss_calculator_for(self, stage: Stage) -> LossCalculator: + """ + Get the loss calculator for the 
given stage (TRAIN, VAL or "val_"). + """ + if stage == TRAIN: + return self.loss_calculator + if stage == VAL: + return self.loss_calculator_val + return self.loss_calculators_val_extra[stage] + def _log(self, stage: Stage): """ Logs training or validation metrics. Args: - stage: Stage Is it's VAL, logs are treated as validation logs. - If TRAIN, logs are treated as training logs + stage: Stage If TRAIN, logs are treated as training logs. + Otherwise (VAL or "val_"), as validation logs. Notes: - This method only executes logging on the main process (rank 0). - After logging, historical loss and standard deviation records are cleared. """ - loss_calculator = self.loss_calculator_val if stage == VAL else self.loss_calculator + loss_calculator = self._loss_calculator_for(stage) avg_loss, losses_all, stddev_all = prepare_losses_for_logging( loss_calculator.loss_hist, loss_calculator.losses_unweighted_hist, @@ -915,7 +1061,7 @@ def _log(self, stage: Stage): if is_root(): # plain logger - if stage == VAL: + if stage != TRAIN: self.train_logger.add_logs(stage, samples, losses_all, stddev_all) elif self.cf.general.istep >= 0: @@ -1097,9 +1243,9 @@ def _log_instant_grad_norms(self, stage: Stage): def _log_terminal(self, bidx: int, mini_epoch: int, stage: Stage): print_freq = self.train_logging.terminal - if bidx % print_freq == 0 and bidx > 0 or stage == VAL: + if bidx % print_freq == 0 and bidx > 0 or stage != TRAIN: # compute from last iteration - loss_calculator = self.loss_calculator_val if stage == VAL else self.loss_calculator + loss_calculator = self._loss_calculator_for(stage) avg_loss, losses_all, _ = prepare_losses_for_logging( loss_calculator.loss_hist, loss_calculator.losses_unweighted_hist, @@ -1107,9 +1253,9 @@ def _log_terminal(self, bidx: int, mini_epoch: int, stage: Stage): ) if is_root(): - if stage == VAL: + if stage != TRAIN: logger.info( - f"""validation ({self.cf.general.run_id}) : {mini_epoch:03d} : + f"""validation {stage} ({self.cf.general.run_id}) 
: {mini_epoch:03d} : {np.nanmean(avg_loss)}""" ) diff --git a/src/weathergen/train/utils.py b/src/weathergen/train/utils.py index 430870f71..4566b9ad7 100644 --- a/src/weathergen/train/utils.py +++ b/src/weathergen/train/utils.py @@ -9,7 +9,6 @@ import copy import json -from typing import Literal import torch from omegaconf import OmegaConf @@ -17,8 +16,8 @@ from weathergen.common import config from weathergen.common.config import Config, merge_configs -# Run stages -Stage = Literal["train", "val", "test"] +# Run stages: "train", "val", "test", or "val_" for extra validation sets +Stage = str TRAIN: Stage = "train" VAL: Stage = "val" TEST: Stage = "test" diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index ea5a864d5..3b6df27e8 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -419,7 +419,7 @@ def plot_loss_per_stream( linestyle = ":" if "stddev" in err else linestyle alpha = 1.0 if "train" in modes and "val" in modes: - alpha = 0.35 if "train" in mode else alpha + alpha = 0.35 if mode == "train" else alpha for j, run_data in enumerate(runs_data): run_data_mode = run_data.by_mode(mode) @@ -591,8 +591,10 @@ def plot_loss_per_run( linestyle = ":" if "stddev" in err else linestyle alpha = 1.0 if "train" in modes and "val" in modes: - alpha = 0.35 if "train" in mode else alpha + alpha = 0.35 if mode == "train" else alpha run_data_mode = run_data.by_mode(mode) + if run_data_mode.is_empty(): + continue x_col = [c for _, c in enumerate(run_data_mode.columns) if x_axis in c][0] # find the cols of the requested metric (e.g. 
mse) for all streams @@ -849,6 +851,9 @@ def plot_train(args=None): for run_id in runs_ids ] + # extra validation sets ("val_") discovered in any run's metrics + extra_modes = sorted({m for rd in runs_data for m in rd.extra}) + # determine which runs are still alive (as a process, though they might hang internally) ret = subprocess.run(["squeue"], capture_output=True) lines = str(ret.stdout).split("\\n") @@ -873,7 +878,7 @@ def plot_train(args=None): # compare different runs plot_loss_per_stream( - ["train", "val"], + ["train", "val", *extra_modes], runs_ids, runs_data, runs_active, @@ -889,7 +894,7 @@ def plot_train(args=None): plot_dir=out_dir, ) plot_loss_per_stream( - ["val"], + ["val", *extra_modes], runs_ids, runs_data, runs_active, @@ -924,7 +929,7 @@ def plot_train(args=None): # plot all cols for all run_ids for run_id, run_data in zip(runs_ids, runs_data, strict=False): plot_loss_per_run( - ["train", "val"], + ["train", "val", *extra_modes], run_id, runs_ids[run_id], run_data, @@ -934,7 +939,7 @@ def plot_train(args=None): legend_outside=args.legend_outside, ) plot_loss_per_run( - ["val"], + ["val", *extra_modes], run_id, runs_ids[run_id], run_data, diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 5f1550e42..46c847e8a 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -14,7 +14,7 @@ import time import traceback from collections import defaultdict -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path import numpy as np @@ -44,6 +44,8 @@ class Metrics: train: pl.DataFrame val: pl.DataFrame system: pl.DataFrame + # extra validation sets, keyed by stage label "val_" + extra: dict[str, pl.DataFrame] = field(default_factory=dict) def by_mode(self, s: str) -> pl.DataFrame: match s: @@ -53,8 +55,11 @@ def by_mode(self, s: str) -> pl.DataFrame: return self.val case "system": return self.system + case _ if s.startswith("val_"): + 
# empty frame when this run lacks the extra validation set + return self.extra.get(s, pl.DataFrame()) case _: - raise ValueError(f"Unknown mode {s}. Use 'train', 'val' or 'system'.") + raise ValueError(f"Unknown mode {s}. Use 'train', 'val', 'val_' or 'system'.") class TrainLogger: @@ -236,7 +241,16 @@ def read( log_val = np.array([]) metrics_val_df = read_metrics(cf, run_id, "val", cols2, cols2_patterns, result_dir_base) - return Metrics(run_id, "train", log_train_df, metrics_val_df, None) + # extra validation sets: discover "val_" stages present in the metrics file + metrics_path = get_train_metrics_path(base_path=result_dir_base, run_id=run_id) + stages = read_metrics_file(metrics_path)["stage"].unique().to_list() + extra = { + stage: read_metrics(cf, run_id, stage, list(cols2), cols2_patterns, result_dir_base) + for stage in sorted(stages) + if stage.startswith("val_") + } + + return Metrics(run_id, "train", log_train_df, metrics_val_df, None, extra) def read_metrics( diff --git a/src/weathergen/utils/utils.py b/src/weathergen/utils/utils.py index 291ab1521..1bda83d9d 100644 --- a/src/weathergen/utils/utils.py +++ b/src/weathergen/utils/utils.py @@ -49,6 +49,23 @@ def is_stream_forcing(stream_cfg: dict, stage: Stage | None = None) -> bool: return is_forcing +def is_stream_reconstructed(stream_cfg: dict, stage: Stage | None = None) -> bool: + """ + Determine if a stream is physically reconstructed, i.e. has a decoder and contributes + to the physical (decoder) reconstruction loss. + + A stream is NOT reconstructed if it is forcing (input-only) or if it explicitly opts + out via ``reconstruct: false``. The latter lets a stream still serve as a + student-teacher (JEPA) target while having no physical decoder, so JEPA can be trained + on all streams while only a subset is reconstructed in physical space. Note that, unlike + forcing streams, ``reconstruct: false`` streams keep a normal (non-empty) target mask, + so the teacher still encodes them. 
+ """ + if is_stream_forcing(stream_cfg, stage): + return False + return stream_cfg.get("reconstruct", True) + + def is_stream_diagnostic(stream_cfg: dict, stage: Stage | None = None) -> bool: """ Determine if stream is diagnostic, i.e. does not contribute to model input