diff --git a/config/config_cerra_crps_decode.yml b/config/config_cerra_crps_decode.yml new file mode 100644 index 000000000..ae4f24258 --- /dev/null +++ b/config/config_cerra_crps_decode.yml @@ -0,0 +1,248 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +# ===================================================================================== +# CERRA CRPS decoding from the frozen pretrained encoder `yqo6nvmy-pretrain`. +# +# Goal: attach a CRPS decoder to the pretrained encoder and probe how well CERRA can be +# predicted at the *same* (forecast) step as the input window. +# +# - Encoder loaded from yqo6nvmy-pretrain and FROZEN (linear-probe style). +# - Inputs = the pretrain_multi_data_all_years streams, now all marked `forcing`. +# - CERRA = `diagnostic` output stream (contributes no input tokens), trained with CRPS. +# - Reconstruction / masking training mode; student-teacher (SSL) loss disabled. +# +# IMPORTANT: every architecture parameter below is copied verbatim from the +# yqo6nvmy-pretrain checkpoint config so that load_chkpt (strict=False) can match the +# encoder/forecast-engine weights. Do NOT change these unless you also retrain the +# encoder — a shape mismatch will crash load_state_dict. +# ===================================================================================== + + +decoder_type: PerceiverIOCoordConditioning +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 64 + +fe_num_blocks: 0 +fe_layer_norm_after_blocks: [] +fe_impute_latent_noise_std: 0.0 +forecast_att_dense_rate: 1.0 + +healpix_level: 5 + +rope_2D: True + +mlp_type: swiglu +qk_norm_type: RMSNorm +use_xsa: True +with_step_conditioning: True +noise_pre_predictor_std: 0 + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +# The CRPS decoder decodes all M ensemble members in a SINGLE varlen call +# (model.py predict_decoders): each member's B*H query groups pair with its own KV groups, +# so the decoder params are used exactly once per forward — no per-member reuse. Because +# there is no reuse, per-block activation checkpointing is DDP-safe (it is the combination +# reuse + checkpointing that trips the reducer's "marked ready twice"), and it keeps peak +# memory low. The single call launches M*B*H attention groups, which must stay under the +# CUDA grid-dim cap (65535) — true for M<=4 at healpix-5 with B=1, hence members capped at 4. +# static_graph is NOT usable here: predict_decoders has data-dependent branches (P==0 +# skip, NaN guard) so the participating-param set varies per iteration. +# (config_forecasting_crps.yml uses FSDP instead, which has none of these constraints.) +with_fsdp: False +decoder_checkpointing: True # per-block checkpoint on — safe: members batched, no reuse +ddp_find_unused_parameters: False # every trainable param must participate in the loss — + # freeze leftover modules rather than masking them +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +# ---- CRPS decoder (the merged work) ---------------------------------------------- +# Ensemble members are produced by perturbing the latent tokens before decoding. +# kernel_crps requires >1 member. pred_head.ens_size stays 1 in the stream config. +latent_perturbation_num_members: 4 # capped at 4 so M× un-checkpointed decoder activations fit +latent_perturbation_sigma_init: 0.05 # ASSUMPTION: initial latent noise std +latent_perturbation_sigma_learnable: True + +# ASSUMPTION: freeze the whole encoder (probe the frozen representation). Set to "" to +# fine-tune end-to-end, or narrow the regex to unfreeze parts. +freeze_modules: ".*encoder.*|.*deep_ssl_fusion.*" # |^(?!.*CERRA).*$" +# freeze_modules: "^(encoder|forecast_engine|latent_heads|latent_pre_norm|deep_ssl_fusion|deep_ssl_level_projections|embed_target_coords_(?!CERRA$).+|TargetPredictionEngine_(?!CERRA$).+|TargetPredictionEngineClassic_(?!CERRA$).+|BilinearDecoder_(?!CERRA$).+|EnsPredictionHead_(?!CERRA$).+)$" + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/pretrain_multi_data_all_years_cerra/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + with_ddp: True + + # zarr-v3's asyncio event loop is not fork-safe; forked DataLoader workers abort + # during CERRA validation. "spawn" gives each worker a fresh interpreter + event loop. + multiprocessing_method: "spawn" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + memory_pinning: True + # keep spawned workers alive across mini-epochs / validation calls to avoid paying + # the (expensive) spawn startup + reader re-open every time. + persistent_workers: True + + +# config for training +training_config: + + # reconstruction / masking mode (NOT student_teacher, NOT latent_loss) + training_mode: ["masking"] + + # ASSUMPTION: training budget. Tune to taste; encoder is frozen so only the decoder + # (+ latent perturbation sigma) is optimised. + num_mini_epochs: 32 + samples_per_mini_epoch: 4096 + shuffle: True + + # ASSUMPTION: CERRA is available 1985-2023; reuse the era5_cerra train/val split. + start_date: 1985-01-01T00:00 + end_date: 2022-12-31T00:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + beta1 : 0.98125 + beta2 : 0.9875 + eps : 2e-08 + + # CERRA decoded with the CRPS (kernel_crps) loss; SSL student-teacher disabled. + losses : { + "student-teacher": { + enabled: False, + type: Disabled, + }, + "physical": { + type: LossPhysical, + # "mse": null removes the mse loss inherited from default_config so only CRPS trains + loss_fcts: { "mse": null, "kernel_crps": { }, }, + }, + } + + # NOTE: when launched via `train_continue --from-run-id yqo6nvmy-pretrain`, the SSL + # pretrain config is the base, so its `random_easy` model_input / `random_easy_target` + # target_input are inherited and MUST be disabled here. Otherwise model_input and + # target_input end up with mismatched lengths and mask building asserts. (Harmless on + # the plain `train` path, where these keys are absent.) + model_input: { + "random_easy" : { + enabled: False, + num_samples: 0, + }, + "forecasting" : { + masking_strategy: "forecast", + }, + } + + target_input: { + "random_easy_target" : { + enabled: False, + }, + } + + # ASSUMPTION: "same forecast step" => offset 0 + num_steps 0, i.e. decode CERRA at the + # input window (auto-encoder/diagnostic), matching the pretrain encoder (num_steps=0). + forecast : + time_step: 06:00:00 + offset: 0 + num_steps: 0 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 256 + shuffle: False + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + + # no SSL here, so no EMA teacher needed for validation + validate_with_ema: + enabled : False + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + output : { + num_samples: 0, + normalized_samples: False, + streams: null, + } + + validate_before_training: False + + +# Tags for experiment tracking +wgtags: + org: null + issue: null + exp: "cerra_crps_decode" + grid: null diff --git a/config/config_forecasting.yml b/config/config_forecasting.yml index 2f91f8097..31becbe33 100644 --- a/config/config_forecasting.yml +++ b/config/config_forecasting.yml @@ -66,7 +66,6 @@ forecast_att_dense_rate: 1.0 healpix_level: 5 -<<<<<<< HEAD # Generalized RoPE selector. rope_mode: none # one of: none, 2d, spherical # Optional spherical harmonic band for spherical RoPE. If null, the model picks one @@ -74,9 +73,6 @@ rope_mode: none # one of: none, 2d, spherical rope_spherical_band: null mlp_type: swiglu use_xsa: True -======= -rope_2D: False ->>>>>>> origin/develop with_mixed_precision: True with_flash_attention: True diff --git a/config/config_forecasting_crps.yml b/config/config_forecasting_crps.yml new file mode 100644 index 000000000..ee8d59104 --- /dev/null +++ b/config/config_forecasting_crps.yml @@ -0,0 +1,256 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 2 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 2 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 16 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 + +healpix_level: 5 + +rope_2D: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +# Latent Gaussian perturbation for probabilistic forecasting (CRPS training) +# Set num_members > 1 to enable; use kernel_crps loss. +latent_perturbation_num_members: 5 # 0 / 1 = disabled; >1 = number of ensemble members +latent_perturbation_sigma_init: 0.05 # initial noise standard deviation +latent_perturbation_sigma_learnable: True # true = nn.Parameter; false = fixed buffer + +freeze_modules: "" +load_chkpt: {} + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg_forecasting/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + with_ddp: False, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking"] + + num_mini_epochs: 100 + samples_per_mini_epoch: 1 + shuffle: False + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T00:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.98125 # == 0.85 on 2 nodes x 4 gpus + beta2 : 0.9875 # == 0.90 on 2 nodes x 4 gpus + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: {"kernel_crps": { }, }, + }, + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + forecast : + time_step: 06:00:00 + offset: 1 + num_steps: 3 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 8 + shuffle: False + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : False + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_forecasting_crps_debug.yml b/config/config_forecasting_crps_debug.yml new file mode 100644 index 000000000..83939f4fc --- /dev/null +++ b/config/config_forecasting_crps_debug.yml @@ -0,0 +1,259 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 2 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 2 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 16 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 + +healpix_level: 5 + +rope_2D: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +# Latent Gaussian perturbation for probabilistic forecasting (CRPS training) +# Set num_members > 1 to enable; use kernel_crps loss. +latent_perturbation_num_members: 5 # 0 / 1 = disabled; >1 = number of ensemble members +latent_perturbation_sigma_init: 0.1 # initial noise standard deviation +latent_perturbation_sigma_learnable: False # true = nn.Parameter; false = fixed buffer +latent_perturbation_sigma_lr: 0.0 #1.0e-2 # separate LR for log_sigma (~200x main peak LR); 0.0 = disabled +latent_perturbation_sigma_override: null + +freeze_modules: "" +load_chkpt: {} + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/era5_1deg/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + with_ddp: True + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 1000 + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking"] + + num_mini_epochs: 64 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1979-01-01T00:00 + end_date: 2022-12-31T00:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 2048 + num_steps_cooldown: 6144 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.9604 # == 0.85 on 1 node x 4 gpus (0.85^(1/4)) + beta2 : 0.9747 # == 0.90 on 1 node x 4 gpus (0.90^(1/4)) + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: {"mse": null, "kernel_crps": {}, }, + }, + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + forecast : + time_step: 06:00:00 + offset: 1 + num_steps: 3 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 256 + shuffle: False + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : False + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: "crps_perturbation_v1" + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + # Override per-run: m_s_l + grid: null diff --git a/config/config_forecasting_era5_cerra.yml b/config/config_forecasting_era5_cerra.yml new file mode 100644 index 000000000..fd3f1fa9e --- /dev/null +++ b/config/config_forecasting_era5_cerra.yml @@ -0,0 +1,250 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 2 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 16 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 + +healpix_level: 5 + +rope_2D: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: "" +#load_chkpt: {'run_id': 'qgnvcrky', 'mini_epoch': -1} + +norm_type: "LayerNorm" + +##################################### + +#streams_directory: "./config/streams/era5_cerra/" #"./config/streams/era5_cerra/" set to "./config/streams/era5_1deg/" for era5 training only +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + with_ddp: True + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + +# parameters for data loading +data_loading : + + num_workers: 6 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking"] + + num_mini_epochs: 64 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1985-01-01T00:00 + end_date: 2022-12-31T00:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.98125 # == 0.85 on 2 nodes x 4 gpus + beta2 : 0.9875 # == 0.90 on 2 nodes x 4 gpus + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + forecast : + time_step: 06:00:00 + offset: 1 + num_steps: 3 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 256 + shuffle: False + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: "knmi-mf" + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: "" + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_jepa_multi_data.yml b/config/config_jepa_multi_data.yml index c14f957d9..8790e80cc 100644 --- a/config/config_jepa_multi_data.yml +++ b/config/config_jepa_multi_data.yml @@ -69,7 +69,7 @@ healpix_level: 5 # Use 2D RoPE instead of traditional global positional encoding # When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) # When False: uses traditional pe_global positional encoding -rope_2D: True +rope_2D: True with_mixed_precision: True with_flash_attention: True diff --git a/config/default_config.yml b/config/default_config.yml index 67bcc3e76..655d0720b 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -89,6 +89,12 @@ latent_noise_saturate_encodings: 5 latent_noise_use_additive_noise: False latent_noise_deterministic_latents: True +# Latent Gaussian perturbation for probabilistic forecasting (CRPS training) +# Set num_members > 1 to enable; use kernel_crps loss. +latent_perturbation_num_members: 0 # 0 / 1 = disabled; >1 = number of ensemble members +latent_perturbation_sigma_init: 0.01 # initial noise standard deviation +latent_perturbation_sigma_learnable: true # true = nn.Parameter; false = fixed buffer + freeze_modules: "" load_chkpt: {} diff --git a/config/evaluate/config_crps_debug.yml b/config/evaluate/config_crps_debug.yml new file mode 100644 index 000000000..a23e15ed8 --- /dev/null +++ b/config/evaluate/config_crps_debug.yml @@ -0,0 +1,24 @@ +evaluation: + metrics: ["rmse", "bias", "crps", "spread", "ssr", "rank_histogram"] + regions: ["global"] + summary_plots: true + plot_ensemble: "members" # visualise individual members + summary_dir: "./plots/crps_epoch2/" + +default_streams: + ERA5: + channels: ["2t", "10u", "10v", "z_500", "t_850", "u_850", "v_850", "q_850"] # a few key variables + evaluation: + forecast_step: "all" + sample: "all" + ensemble: "all" # required for all probabilistic metrics + plotting: + sample: [0] + forecast_step: "all" #supported: "all", [1,2,3,...], "1-50" (equivalent of [1,2,3,...50]) + plot_maps: true + plot_histograms: true + plot_animations: true + +run_ids: + zt0vcpho: + label: "CRPS_s0.05_learnable_5members" \ No newline at end of file diff --git a/config/evaluate/config_crps_spread.yml b/config/evaluate/config_crps_spread.yml new file mode 100644 index 000000000..f2e61227d --- /dev/null +++ b/config/evaluate/config_crps_spread.yml @@ -0,0 +1,50 @@ +evaluation: + metrics: ["rmse", "spread"] + regions: ["global"] + summary_plots: true + plot_score_maps: true + plot_ensemble: "members" # visualise individual members + summary_dir: "./plots/crps_spread_v3/" + +default_streams: + ERA5: + channels: ["2t", "10u", "10v", "z_500", "t_850", "u_850", "v_850", "q_850"] # a few key variables + evaluation: + forecast_step: "all" + sample: "all" + ensemble: "all" # required for all probabilistic metrics + plotting: + sample: [0,1,2,3] + forecast_step: "all" #supported: "all", [1,2,3,...], "1-50" (equivalent of [1,2,3,...50]) + ensemble: "std" #supported: "all", "mean", "std", [0,1,2] + plot_maps: true + plot_histograms: true + plot_animations: true + +run_ids: + xf74qbhg: + label: "CRPS_s0.05_learnable_5members" + bg93vcya: + label: "CRPS_s0.01_learnable_5members" + b7p4dbak: + label: "CRPS_s0.1_learnable_5members" + l6ji4cvl: + label: "CRPS_s0.1_learnable_5members_10_more_epochs" + xhl8j3wa: + label: "CRPS_s0.05_learnable_2members" + g2amz7ou: + label: "CRPS_s0.05_learnable_10members" + a9wa1rvq: + label: "CRPS_s0.01_fixed_5members" + ne54k2fa: + label: "CRPS_s0.05_fixed_5members" + g9z7j8ei: + label: "CRPS_s0.1_fixed_5members" + ra7qm4bc: + label: "CRPS_s0.2_fixed_5members" + ha23u84m: + label: "CRPS_s0.5_fixed_5members" + si91kfmq: + label: "CRPS_s0.05_learnable_5members_sigma+lr_1e-2" + o5scxlnr: + label: "CRPS_s0.5_learnable_5members_sigma+lr_1e-2" \ No newline at end of file diff --git a/config/evaluate/eval_cerra_europe_config.yml b/config/evaluate/eval_cerra_europe_config.yml new file mode 100644 index 000000000..466d3738d --- /dev/null +++ b/config/evaluate/eval_cerra_europe_config.yml @@ -0,0 +1,47 @@ + +global_plotting_options: + regions: ["europe"] + image_format : "png" + dpi_val : 300 + fps: 2 + ERA5: + marker_size: 2 + scale_marker_size: 1 + marker: "o" + +evaluation: + metrics : ["mae"] + regions: ["europe"] + summary_plots : true + ratio_plots : false + heat_maps : false + summary_dir: "./plots_zfhwspzc_17_4gpu" + plot_ensemble: false #supported: false, "std", "minmax", "members" + plot_score_maps: false #plot scores on a 2D maps. it slows down score computation + print_summary: false #print out score values on screen. it can be verbose + log_scale: false + add_grid: false + score_cards: false + bar_plots: false + num_processes: 0 #options: int, "auto", 0 means no parallelism (default) + +default_streams: + CERRA: + channels: ["2t", "10si", "10wdir", "r_850", "t_850", "u_850", "v_850", "z_500"] #, "blah"] + evaluation: + forecast_step: "all" + sample: "all" + ensemble: "all" + # plotting: + # sample: [0] + # forecast_step: "all" + # ensemble: "mean" + # plot_maps: true + # plot_bias: true + # plot_target: true + # plot_histograms: true + # plot_animations: true + +run_ids: + xxx: + label: "xxx" diff --git a/config/inference/inference_era5_1deg.yml b/config/inference/inference_era5_1deg.yml new file mode 100644 index 000000000..641130fe2 --- /dev/null +++ b/config/inference/inference_era5_1deg.yml @@ -0,0 +1,8 @@ +test_config: + start_date: 2023-10-01T00:00 + output: + num_samples: 50 + samples_per_mini_epoch: 50 + forecast: + num_steps: 2 +streams_directory: ./config/inference/streams/era5_1deg/ \ No newline at end of file diff --git a/config/inference/inference_era5_o96_cerra.yml b/config/inference/inference_era5_o96_cerra.yml new file mode 100644 index 000000000..6a54dc658 --- /dev/null +++ b/config/inference/inference_era5_o96_cerra.yml @@ -0,0 +1,8 @@ +test_config: + start_date: 2023-10-01T00:00 + output: + num_samples: 16 + samples_per_mini_epoch: 16 + forecast: + num_steps: 9 +streams_directory: ./config/inference/streams/era5_o96_cerra/ \ No newline at end of file diff --git a/config/inference/streams/era5_1deg/era5.yml b/config/inference/streams/era5_1deg/era5.yml new file mode 100644 index 000000000..39f655024 --- /dev/null +++ b/config/inference/streams/era5_1deg/era5.yml @@ -0,0 +1,40 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + stream_id : 0 + source_exclude : ['z', 'w_10', 'w_50', 'w_100', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925', 'w_1000', 'skt', 'tcw', 'cp', 'tp', 'q_50', 'q_100'] + target_exclude : ['z', 'w_10', 'w_50', 'w_100', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925', 'w_1000', 'slor', 'sdor', 'tcw', 'cp', 'tp', 'q_50', 'q_100'] + geoinfo_channels : ['z', 'lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] + loss_weight : 1. + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + frequency : 06:00:00 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/config/inference/streams/era5_o96_cerra/cerra.yml b/config/inference/streams/era5_o96_cerra/cerra.yml new file mode 100644 index 000000000..ccaf25139 --- /dev/null +++ b/config/inference/streams/era5_o96_cerra/cerra.yml @@ -0,0 +1,40 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +CERRA : + type : anemoi + filenames : ['cerra-rr-an-oper-se-al-ec-mars-5p5km-1985-2023-3h-v2.zarr'] + frequency : 6h + stream_id : 1 + source_exclude : ['skt','tciwv','tp','al','rsn','sde','sf'] + target_exclude : ['tciwv','tp','al','rsn','sde','sf'] + geoinfo_channels : ['orog', 'lsm'] + loss_weight : 1. + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 512 + tokenize_spacetime : True + #max_num_targets: 570000 + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file diff --git a/config/inference/streams/era5_o96_cerra/era5.yml b/config/inference/streams/era5_o96_cerra/era5.yml new file mode 100644 index 000000000..c6317a2d6 --- /dev/null +++ b/config/inference/streams/era5_o96_cerra/era5.yml @@ -0,0 +1,40 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + stream_id : 0 + source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + geoinfo_channels : ['z', 'lsm'] + loss_weight : 1. + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + #max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index 6d6737d44..9e14050e9 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -9,7 +9,7 @@ ERA5 : type : anemoi - filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] stream_id : 0 source_exclude : ['z', 'w_10', 'w_50', 'w_100', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925', 'w_1000', 'skt', 'tcw', 'cp', 'tp', 'q_50', 'q_100'] target_exclude : ['z', 'w_10', 'w_50', 'w_100', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925', 'w_1000', 'slor', 'sdor', 'tcw', 'cp', 'tp', 'q_50', 'q_100'] @@ -18,7 +18,7 @@ ERA5 : location_weight : cosine_latitude token_size : 8 tokenize_spacetime : True - max_num_targets: -1 + max_num_targets: 20000 frequency : 06:00:00 embed : net : transformer diff --git a/config/streams/era5_cerra/cerra.yml b/config/streams/era5_cerra/cerra.yml new file mode 100644 index 000000000..f5b27dd63 --- /dev/null +++ b/config/streams/era5_cerra/cerra.yml @@ -0,0 +1,40 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +CERRA : + type : anemoi + filenames : ['cerra-rr-an-oper-se-al-ec-mars-5p5km-1985-2023-3h-v2.zarr'] + frequency : 6h + stream_id : 1 + source_exclude : ['skt','tciwv','tp','al','rsn','sde','sf'] + target_exclude : ['tciwv','tp','al','rsn','sde','sf'] + geoinfo_channels : ['orog', 'lsm'] + loss_weight : 1. + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 64 + tokenize_spacetime : True + max_num_targets: 570000 + #max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/config/streams/era5_cerra/era5.yml b/config/streams/era5_cerra/era5.yml new file mode 100644 index 000000000..3f990eb23 --- /dev/null +++ b/config/streams/era5_cerra/era5.yml @@ -0,0 +1,40 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + stream_id : 0 + source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + geoinfo_channels : ['z', 'lsm'] + loss_weight : 1. + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 8 + tokenize_spacetime : True + max_num_targets: 20000 + #max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/config/streams/era5_crps/era5.yml b/config/streams/era5_crps/era5.yml new file mode 100644 index 000000000..31895423b --- /dev/null +++ b/config/streams/era5_crps/era5.yml @@ -0,0 +1,38 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + stream_id : 0 + source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + loss_weight : 1. + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 8 + tokenize_spacetime : True + max_num_targets: 20000 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file diff --git a/config/streams/era5_n320_cerra/cerra.yml b/config/streams/era5_n320_cerra/cerra.yml new file mode 100644 index 000000000..a0e888ea7 --- /dev/null +++ b/config/streams/era5_n320_cerra/cerra.yml @@ -0,0 +1,40 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +CERRA : + type : anemoi + filenames : ['cerra-rr-an-oper-se-al-ec-mars-5p5km-1985-2023-3h-v2.zarr'] + frequency : 6h + stream_id : 1 + source_exclude : ['skt','tciwv','tp','al','rsn','sde','sf'] + target_exclude : ['tciwv','tp','al','rsn','sde','sf'] + geoinfo_channels : ['orog', 'lsm'] + loss_weight : 1. + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 64 + tokenize_spacetime : True + max_num_targets: 228000 + #max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/config/streams/era5_n320_cerra/era5.yml b/config/streams/era5_n320_cerra/era5.yml new file mode 100644 index 000000000..580173b04 --- /dev/null +++ b/config/streams/era5_n320_cerra/era5.yml @@ -0,0 +1,40 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-n320-1979-2023-6h-v8.zarr'] + stream_id : 0 + source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + geoinfo_channels : ['orog', 'lsm'] + loss_weight : 1. + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 32 + tokenize_spacetime : True + max_num_targets: 108000 + #max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file diff --git a/evaluation_slurm.sh b/evaluation_slurm.sh new file mode 100755 index 000000000..263324806 --- /dev/null +++ b/evaluation_slurm.sh @@ -0,0 +1,43 @@ +#!/bin/bash +#SBATCH --job-name=wg-eval +#SBATCH --nodes=1 +#SBATCH --mem=368G +#SBATCH --cpus-per-task=8 +#SBATCH --ntasks-per-node=4 +#SBATCH --gpus-per-node=0 +#SBATCH --time=11:59:59 +#SBATCH --output=logs/%x.%j.out +#SBATCH --error=logs/%x.%j.err +#SBATCH --switches=1 + +module load gcc/12.2.0 +module load cuda/12.2 + +if [ $# -lt 3 ]; then + echo "Usage: $0