diff --git a/config/config_forecasting.yml b/config/config_forecasting.yml index 14a9667a9..eb8b0f59d 100644 --- a/config/config_forecasting.yml +++ b/config/config_forecasting.yml @@ -66,11 +66,7 @@ forecast_att_dense_rate: 1.0 healpix_level: 5 -# Generalized RoPE selector. -rope_mode: none # one of: none, 2d, spherical -# Optional spherical harmonic band for spherical RoPE. If null, the model picks one -# conservative shared band that fits all spherical-RoPE attention modules. -rope_spherical_band: null +rope_2D: False with_mixed_precision: True with_flash_attention: True diff --git a/config/config_forecasting_eerie.yml b/config/config_forecasting_eerie.yml index 75215ee0c..53466ad7d 100644 --- a/config/config_forecasting_eerie.yml +++ b/config/config_forecasting_eerie.yml @@ -66,6 +66,8 @@ forecast_att_dense_rate: 1.0 healpix_level: 5 +rope_2D: False + with_mixed_precision: True with_flash_attention: True compile_model: False diff --git a/config/config_jepa.yml b/config/config_jepa.yml index 0e2ebe8f2..a4f7c8aed 100644 --- a/config/config_jepa.yml +++ b/config/config_jepa.yml @@ -66,12 +66,10 @@ forecast_att_dense_rate: 1.0 with_step_conditioning: True # False healpix_level: 5 - -# Generalized RoPE selector. -rope_mode: none # one of: none, 2d, spherical -# Optional spherical harmonic band for spherical RoPE. If null, the model picks one -# conservative shared band that fits all spherical-RoPE attention modules. -rope_spherical_band: null +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: False with_mixed_precision: True with_flash_attention: True diff --git a/config/config_operan_georing_avhrr_forecasting_lowres.yml b/config/config_operan_georing_avhrr_forecasting_lowres.yml new file mode 100644 index 000000000..538ecc6ce --- /dev/null +++ b/config/config_operan_georing_avhrr_forecasting_lowres.yml @@ -0,0 +1,270 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 4 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True +ae_local_max_tokens_per_cell: 128 + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 0 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 16 +fe_num_heads: 32 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 + +healpix_level: 5 + +rope_2D: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: "" +load_chkpt: {} + +norm_type: "LayerNorm" + +##################################### + +# streams_directory: "./config/streams/era5_georing_avhrr/" +streams_directory: "./config/streams/operan_georing_avhrr_synop_lowres/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 500 + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking"] + + num_mini_epochs: 56 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 2016-01-01T00:00 + end_date: 2022-12-31T00:00 + + time_window_step: 01:00:00 + time_window_len: 06:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.98125 # == 0.85 on 2 nodes x 4 gpus + beta2 : 0.9875 # == 0.90 on 2 nodes x 4 gpus + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "source_masking" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + forecast : + time_step: 06:00:00 + offset: 1 + num_steps: 2 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 256 + shuffle: False + + start_date: 2023-10-01T00:00 + end_date: 2023-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +test_config: + + samples_per_mini_epoch: 128 + shuffle: False + + start_date: 2023-06-01T00:00 + end_date: 2023-08-31T00:00 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 128, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/default_config.yml b/config/default_config.yml index 44887d581..39abe739b 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -67,11 +67,10 @@ num_register_tokens: 0 healpix_level: 5 -# Generalized RoPE selector. -rope_mode: none # one of: none, 2d, spherical -# Optional spherical harmonic band for spherical RoPE. If null, the model picks one -# conservative shared band that fits all spherical-RoPE attention modules. -rope_spherical_band: null +# Use 2D RoPE instead of traditional global positional encoding +# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) +# When False: uses traditional pe_global positional encoding +rope_2D: False with_mixed_precision: True with_flash_attention: True diff --git a/config/evaluate/eval_config.yml b/config/evaluate/eval_config.yml index a7415052b..79ee4967b 100644 --- a/config/evaluate/eval_config.yml +++ b/config/evaluate/eval_config.yml @@ -5,6 +5,8 @@ # regions: ["europe", "global"] # Have regions here, if you want for them to apply to all streams (map generation) # image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. # animation_format: "gif" #options: "mp4", "gif" +# log_colorbar: true + # dpi_val : 300 # fps: 2 # n_bins: 50 #number of bins for histograms. @@ -36,7 +38,8 @@ evaluation: heat_maps : false summary_dir: "./plots/" plot_ensemble: "members" #supported: false, "std", "minmax", "members" - plot_score_maps: false #plot scores on a 2D maps. it slows down score computation + plot_score_maps: false #plot scores on a 2D maps. it slows down score computation + plot_score_animations: false #plot animations of score maps across forecast steps. it slows down score computation print_summary: false #print out score values on screen. it can be verbose log_scale: false add_grid: false @@ -99,7 +102,7 @@ run_ids : color: "magenta" #optional: if not specified, the color is automatically assigned by the plotting module results_base_dir : "./results/" epoch: 1 #optional: if not specified epoch 0 (in inference it is always 0) is used - rank: 2 #optional: if not specified rank 0 is used + rank: "all" #optional: int, "all", or list of ints. Default: "all". Use "all" for multi-rank inference. streams: ERA5: channels: ["2t", "10u", "10v"] diff --git a/config/evaluate/eval_config_default.yml b/config/evaluate/eval_config_default.yml index 03293c296..928f4a723 100644 --- a/config/evaluate/eval_config_default.yml +++ b/config/evaluate/eval_config_default.yml @@ -31,7 +31,8 @@ evaluation: heat_maps : false summary_dir: "./plots/" plot_ensemble: "members" #supported: false, "std", "minmax", "members" - plot_score_maps: false #plot scores on a 2D maps. it slows down score computation + plot_score_maps: false #plot scores on a 2D maps. it slows down score computation + plot_score_animations: false #plot animations of score maps across forecast steps. it slows down score computation print_summary: false #print out score values on screen. it can be verbose log_scale: false add_grid: false diff --git a/config/streams/operan_georing_avhrr_synop_lowres/avhrr.yml b/config/streams/operan_georing_avhrr_synop_lowres/avhrr.yml new file mode 100644 index 000000000..7bba57e91 --- /dev/null +++ b/config/streams/operan_georing_avhrr_synop_lowres/avhrr.yml @@ -0,0 +1,37 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +METOP_ABC_AVHRR_IASI : + type : obs + stream_id : 20 + filenames : ['observations-ea-ofb-0001-2007-2021-metop-a-iasi-radiances-v1.zarr', 'observations-ea-ofb-0001-2013-2023-metop-b-iasi-radiances-v1.zarr', 'observations-ea-ofb-0001-2019-2023-metop-c-iasi-radiances-v1.zarr'] + geoinfo_channels : ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] + loss_weight : 1.0 + masking_override : + model_input : + masking_strategy_config : + rate: 1.0 + token_size : 512 # 256 + forcing: True + embed : + net : transformer + num_tokens : 1 + num_heads : 2 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 + diff --git a/config/streams/operan_georing_avhrr_synop_lowres/era5.yml b/config/streams/operan_georing_avhrr_synop_lowres/era5.yml new file mode 100644 index 000000000..8739bf0cd --- /dev/null +++ b/config/streams/operan_georing_avhrr_synop_lowres/era5.yml @@ -0,0 +1,48 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5_in : + type : anemoi_operan + # filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] + filenames : ['aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr'] + stream_id : 0 + source: ['q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 't_600', 't_700', 't_850', 't_925', 't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000', '10u', '10v', '2d', '2t', 'msl'] + target: ['q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 't_600', 't_700', 't_850', 't_925', 't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000', '10u', '10v', '2d', '2t', 'msl'] + geoinfo_channels : ['z', 'lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] + loss_weight : 1. + location_weight : cosine_latitude + masking_override : + model_input : + masking_strategy_config : + rate: 0.1 + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + forcing: True + frequency : 06:00:00 + nominal_time_mapping : + "0" : 5 # 04:30:00 + "6" : 9 # 09:00:00 + "12" : 17 #16:30:00 + "18" : 21 #21:00:00 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/config/streams/operan_georing_avhrr_synop_lowres/era5_out.yml b/config/streams/operan_georing_avhrr_synop_lowres/era5_out.yml new file mode 100644 index 000000000..9a30b289c --- /dev/null +++ b/config/streams/operan_georing_avhrr_synop_lowres/era5_out.yml @@ -0,0 +1,41 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] + stream_id : 1 + source_exclude : ['z', 'w_10', 'w_50', 'w_100', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925', 'w_1000', 'skt', 'tcw', 'cp', 'tp', 'q_50', 'q_100'] + target_exclude : ['z', 'w_10', 'w_50', 'w_100', 'w_150', 'w_200', 'w_250', 'w_300', 'w_400', 'w_500', 'w_600', 'w_700', 'w_850', 'w_925', 'w_1000', 'slor', 'sdor', 'tcw', 'cp', 'tp', 'q_50', 'q_100'] + geoinfo_channels : ['z', 'lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] + loss_weight : 1. + location_weight : cosine_latitude + masking_override : + target_input : + masking_strategy_config : + rate: 1.0 + token_size : 8 + tokenize_spacetime : False + max_num_targets: -1 + diagnostic: True + embed : + net : transformer + num_tokens : 1 + num_heads : 4 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/config/streams/operan_georing_avhrr_synop_lowres/geos.yml b/config/streams/operan_georing_avhrr_synop_lowres/geos.yml new file mode 100644 index 000000000..87c504574 --- /dev/null +++ b/config/streams/operan_georing_avhrr_synop_lowres/geos.yml @@ -0,0 +1,174 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +METEOSAT_SEVIRI_IR : + type : obs + stream_id : 10 + filenames : ['observations-file-2014-2024-seviri-o256-wegen-v3.zarr'] + geoinfo_channels : ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza'] + loss_weight : 1.0 + location_weight : cosine_latitude + masking_override : + model_input : + masking_strategy: "random" + masking_strategy_config : + rate: 1.0 + token_size : 1024 + tokenize_spacetime : False + max_num_targets: 131072 + forcing: True + embed : + net : transformer + num_tokens : 1 + num_heads : 4 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 + + +GOES_ABI_IR : + type : obs + stream_id : 11 + filenames : ['observations-file-2017-2024-abi-goes16-IR-o256-v2.zarr'] + geoinfo_channels : ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day','zenith', 'cos_sza'] + loss_weight : 1.0 + location_weight : cosine_latitude + masking_override : + model_input : + masking_strategy: "random" + masking_strategy_config : + rate: 1.0 + token_size : 1024 + tokenize_spacetime : False + max_num_targets: 131072 + forcing: True + embed : + net : transformer + num_tokens : 1 + num_heads : 4 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 + + +HIMAWARI_AHI_IR : + type : obs + stream_id : 12 + filenames : ['observations-file-2015-2022-himawari8-IR-o256-v1.zarr', 'observations-file-2022-2024-himawari9-IR-o256-v1.zarr'] + geoinfo_channels : ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza'] + loss_weight : 1.0 + location_weight : cosine_latitude + masking_override : + model_input : + masking_strategy: "random" + masking_strategy_config : + rate: 1.0 + token_size : 1024 + tokenize_spacetime : False + max_num_targets: 131072 + forcing: True + embed : + net : transformer + num_tokens : 1 + num_heads : 4 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 + + +GOES_ABI_VIS : + type : obs + stream_id : 13 + filenames : ['observations-file-2017-2024-abi-goes16-VIS-o256-v2.zarr'] + # geoinfo_channels : ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza'] + geoinfo_channels : ['zenith'] + loss_weight : 1.0 + location_weight : cosine_latitude + masking_override : + model_input : + masking_strategy: "random" + masking_strategy_config : + rate: 1.0 + token_size : 1024 + tokenize_spacetime : False + max_num_targets: 131072 + forcing: True + embed : + net : transformer + num_tokens : 1 + num_heads : 4 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 + + +HIMAWARI_AHI_VIS : + type : obs + stream_id : 14 + filenames : ['observations-file-2015-2022-himawari8-VIS-o256-v1.zarr', 'observations-file-2022-2024-himawari9-VIS-o256-v1.zarr'] + # geoinfo_channels : ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza'] + geoinfo_channels : ['zenith'] + loss_weight : 1.0 + location_weight : cosine_latitude + masking_override : + model_input : + masking_strategy: "random" + masking_strategy_config : + rate: 1.0 + token_size : 1024 + tokenize_spacetime : False + max_num_targets: 131072 + forcing: True + embed : + net : transformer + num_tokens : 1 + num_heads : 4 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file diff --git a/config/streams/operan_georing_avhrr_synop_lowres/synop.yml b/config/streams/operan_georing_avhrr_synop_lowres/synop.yml new file mode 100644 index 000000000..df796579b --- /dev/null +++ b/config/streams/operan_georing_avhrr_synop_lowres/synop.yml @@ -0,0 +1,43 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +SurfaceCombined : + type : obs + stream_id : 30 + filenames : ['observations-ea-ofb-0001-1979-2025-combined-surface-v5.zarr'] + # filenames : ['observations-ea-ofb-0001-1979-2023-combined-surface-v2.zarr'] + # filenames : ['observations-ea-ofb-0001-1979-2022-combined-surface-v3-fixed-land-spatial80-lsm09-min10.zarr', 'observations-ea-ofb-0001-2023-combined-surface-v3-fixed-land-heldout20-lsm09-min10.zarr'] + source: ['obsvalue_tsts_0', 'obsvalue_t2m_0', 'obsvalue_td2m_0', 'obsvalue_u10m_0', 'obsvalue_v10m_0', 'obsvalue_pmsl_0', 'obsvalue_ps_0'] + target: ['obsvalue_tsts_0', 'obsvalue_t2m_0', 'obsvalue_td2m_0', 'obsvalue_u10m_0', 'obsvalue_v10m_0', 'obsvalue_pmsl_0', 'obsvalue_ps_0'] + geoinfo_channels : ['stalt', 'lsm', 'cos_sza', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] + loss_weight : 1.0 + token_size : 64 + masking_override : + model_input : + masking_strategy: "random" + masking_strategy_config : + rate: 0.1 + tokenize_spacetime : False + max_num_targets: -1 + # diagnostic: True + embed : + net : transformer + num_tokens : 1 + num_heads : 2 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index b791ec13c..789ae018a 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -185,7 +185,7 @@ def format_cf(config: Config) -> str: for key, value in clean_cf.items(): match key: case "streams": - for rt in value: + for rt in value.values(): for k, v in rt.items(): whitespace = "" if k == "reportypes" else " " stream.write(f"{whitespace}{k} : {v}") @@ -312,6 +312,7 @@ def _apply_fixes(config: Config) -> Config: """ config = _check_time_interpolation(config) config = _check_datasets(config) + config = _check_streams(config) return config @@ -366,6 +367,18 @@ def _convert_interpolation(cfg, key): return config +def _check_streams(config: Config) -> Config: + """Convert streams stored as list to dict/DictConfig.""" + config = config.copy() + stream_conf = config.get("streams") + assert stream_conf + if isinstance(stream_conf, list | ListConfig): + stream_conf = OmegaConf.create({conf["name"]: conf for conf in stream_conf}) + + config["streams"] = stream_conf + return config + + def merge_configs(base_config: Config, update_config: Config): """ Merge two configs using OmegaConf's default strategy @@ -597,7 +610,7 @@ def _load_base_conf(base: Path | Config | None) -> Config: return conf -def load_streams(streams_directory: Path) -> list[Config]: +def load_streams(streams_directory: Path) -> Config: """Load all stream configurations from a directory.""" # TODO: might want to put this into config later instead of hardcoding it here... streams_history = { @@ -636,10 +649,7 @@ def load_streams(streams_directory: Path) -> list[Config]: try: config = OmegaConf.load(config_file) for stream_name, stream_config in config.items(): - # Stream config schema is {stream_name: stream_config} - # where stream_config itself is a dict containing the actual options. - # stream_name needs to be added to this dict since only stream_config - # will be further processed. + # include key in value to have bidirectional key <-> value mapping stream_config.name = stream_name if stream_name in streams: msg = f"Duplicate stream name found: {stream_name}." @@ -664,7 +674,7 @@ def load_streams(streams_directory: Path) -> list[Config]: if stream.get("frequency", None) is not None: stream = _patch_time("frequency", stream, _TIMEDELTA_TYPE_NAME) - return list(streams.values()) + return OmegaConf.create(streams) def get_path_run(config: Config) -> Path: diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 5d493370c..407239dbb 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -358,9 +358,9 @@ def __init__( def _append_dataset(self, dataset: OutputDataset | None, name: str) -> None: if dataset: self.datasets.append(dataset) - else: - msg = f"Missing {name} dataset for item: {self.key.path}" - raise ValueError(msg) + # else: + # msg = f"Missing {name} dataset for item: {self.key.path}" + # raise ValueError(msg) class ZarrIO: @@ -736,11 +736,13 @@ def _extract_coordinates(self, stream_idx, offset_key, datapoints) -> DataCoordi def _extract_sources( self, sample: int, stream_idx: int, key: ItemKey, source_interval: TimeRange - ) -> OutputDataset: + ) -> OutputDataset | None: channels = self.source_channels[stream_idx] geoinfo_channels = self.geoinfo_channels[stream_idx] source: IOReaderData = self.sources[sample][stream_idx] + if source is None: + return None assert source.data.shape[1] == len(channels), ( f"Number of source channel names {len(channels)} does not align with source data." diff --git a/packages/evaluate/src/weathergen/evaluate/io/data/io_orchestration.py b/packages/evaluate/src/weathergen/evaluate/io/data/io_orchestration.py index 40a317024..22a0320e9 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/data/io_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/io/data/io_orchestration.py @@ -77,6 +77,7 @@ class IOState: lon: NDArray n_workers: int backend: str = "loky" + rank: str = "0000" # --------------------------------------------------------------------------- @@ -247,6 +248,7 @@ def _build_io_state( ensemble: list[str], n_io_workers: int, ens_select: EnsembleSelect, + rank: str = "", ) -> IOState: """Resolve all I/O parameters that are shared between the two impl paths.""" zarr_path = str(fname_zarr) @@ -283,6 +285,7 @@ def _build_io_state( lat=lat, lon=lon, n_workers=n_io_workers, + rank=rank, ) @@ -459,7 +462,8 @@ def get_data_dirstore(state: IOState) -> ReaderOutput: ``n_samples × 1 × n_ipoints × n_channels × 4 bytes``. """ _logger.info( - f"RUN {state.run_id} - {state.stream}: Loading {len(state.samples)} samples × " + f"RUN {state.run_id} [rank {state.rank}] - {state.stream}: " + f"Loading {len(state.samples)} samples × " f"{len(state.fsteps)} fsteps via zarr I/O " f"(workers={state.n_workers}, backend={state.backend})..." ) @@ -472,7 +476,7 @@ def get_data_dirstore(state: IOState) -> ReaderOutput: for fi, fs in enumerate(state.fsteps): _logger.info( - f"RUN {state.run_id} - {state.stream}: " + f"RUN {state.run_id} [rank {state.rank}] - {state.stream}: " f"Reading fstep {fs} ({fi + 1}/{len(state.fsteps)})..." ) @@ -487,7 +491,7 @@ def get_data_dirstore(state: IOState) -> ReaderOutput: is_gridded=state.is_gridded, n_workers=n_workers, backend=state.backend, - label=f"RUN {state.run_id} - {state.stream} fstep {fs}", + label=f"RUN {state.run_id} [rank {state.rank}] - {state.stream} fstep {fs}", ) # If _parallel_read fell back to sequential, honour that for the rest if fell_back: @@ -525,7 +529,7 @@ def get_data_dirstore(state: IOState) -> ReaderOutput: get_reusable_executor().shutdown(wait=True) _logger.info( - f"RUN {state.run_id} - {state.stream}: I/O complete. " + f"RUN {state.run_id} [rank {state.rank}] - {state.stream}: I/O complete. " f"{len(da_tars_dict)} forecast entries loaded." ) return ReaderOutput(target=da_tars_dict, prediction=da_preds_dict) @@ -546,7 +550,8 @@ def get_data_zipstore(state: IOState) -> ReaderOutput: """ n_total = len(state.samples) * len(state.fsteps) _logger.info( - f"RUN {state.run_id} - {state.stream}: Loading {len(state.samples)} samples × " + f"RUN {state.run_id} [rank {state.rank}] - {state.stream}: " + f"Loading {len(state.samples)} samples × " f"{len(state.fsteps)} fsteps = {n_total} items via ZipStore-parallel zarr I/O " f"(workers={state.n_workers}, backend={state.backend})..." ) @@ -569,7 +574,7 @@ def get_data_zipstore(state: IOState) -> ReaderOutput: calls, n_workers=state.n_workers, backend=state.backend, - desc=f"RUN {state.run_id} - {state.stream} (ZipStore)", + desc=f"RUN {state.run_id} [rank {state.rank}] - {state.stream} (ZipStore)", verbose=5, ) @@ -633,7 +638,7 @@ def get_data_zipstore(state: IOState) -> ReaderOutput: get_reusable_executor().shutdown(wait=True) _logger.info( - f"RUN {state.run_id} - {state.stream}: ZipStore-parallel I/O complete. " + f"RUN {state.run_id} [rank {state.rank}] - {state.stream}: ZipStore-parallel I/O complete. " f"{len(da_tars_dict)} forecast entries loaded." ) return ReaderOutput(target=da_tars_dict, prediction=da_preds_dict) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 56b852d3a..fc089a910 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -131,14 +131,16 @@ def get_climatology_filename(self, stream: str) -> str | None: ) return None - clim_fn = next( - ( - item.get("filenames") - for item in self.inference_cfg.get("streams", []) - if item.get("name") == stream - ), - None, - ) + streams = self.inference_cfg.get("streams", {}) + if isinstance(streams, list | oc.ListConfig): + streams = {s["name"]: s for s in streams} + streams = oc.OmegaConf.create(streams) + + try: + clim_fn = streams[stream].get("filenames") + except KeyError: + clim_fn = None + if isinstance(clim_fn, oc.ListConfig) and len(clim_fn) == 1: climatology_partial_filename = clim_fn[0] else: @@ -299,9 +301,15 @@ def get_inference_stream_attr(self, stream_name: str, key: str, default=None): ------------ The parameter value if found, otherwise the default. """ - for stream in self.inference_cfg.get("streams", []): - if stream.get("name") == stream_name: - return stream.get(key, default) + + streams = self.inference_cfg.get("streams", {}) + if isinstance(streams, list | oc.ListConfig): + for stream in streams: + if stream.get("name") == stream_name: + return stream.get(key, default) + else: + return streams.get(stream_name, {}).get(key, default) + return default @@ -373,37 +381,163 @@ def get_recomputable_metrics(self, metrics): class WeatherGenZarrReader(WeatherGenReader): def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): - """Data reader class for WeatherGenerator model outputs stored in Zarr format.""" + """Data reader class for WeatherGenerator model outputs stored in Zarr format. + + Supports multi-rank inference outputs where each rank file contains a disjoint + subset of forecast initializations with overlapping local sample indices. + """ super().__init__(eval_cfg, run_id, private_paths) zarr_ext = self.inference_cfg.get("zarr_store", "zarr") - # For backwards compatibility, assume zarr store is local (.zarr format). + self.zarr_ext = zarr_ext - fname_zarr = self.results_dir.joinpath( - f"validation_chkpt{self.mini_epoch:05d}_rank{self.rank:04d}.{zarr_ext}" - ) - - assert fname_zarr.exists(), f"Zarr file {fname_zarr} does not exist." + # Discover rank files: support rank="all", rank=[0,1,2], or rank=0 (int) + self.rank_files: list[Path] = self._discover_rank_files() - assert (zarr_ext == "zarr" and fname_zarr.is_dir()) or ( - zarr_ext == "zip" and fname_zarr.is_file() - ), ( - f"Zarr file {fname_zarr} has unexpected format. ({zarr_ext}). " - f"Expected directory for 'zarr' or file for 'zip'." - ) - self.fname_zarr = fname_zarr + # Validate metadata consistency across all ranks (fail-fast) + self._validated_metadata: dict = self._validate_rank_metadata() # Metadata caches — populated lazily on first access self._cached_samples: set[int] | None = None - self._cached_fsteps: set[int] | None = None - self._cached_streams: set[str] | None = None self._cached_ensemble: dict[str, list[str]] = {} self._cached_is_gridded: dict[str, bool] = {} + self._rank_sample_map: dict[Path, tuple[list[int], int]] | None = None # Raw I/O worker config (direct zarr access) self._max_workers: int | None = eval_cfg.get("max_workers") self._num_io_workers: int = get_num_workers(max_workers=self._max_workers) + def _discover_rank_files(self) -> list[Path]: + """Discover zarr rank files based on the ``rank`` config parameter. + + Supports: + - ``rank: 0`` (int) — single specific rank (backward compatible) + - ``rank: "all"`` — glob all matching rank files + - ``rank: [0, 1, 2]`` — specific list of ranks + """ + rank_cfg = self.eval_cfg.get("rank", self.rank) + + if isinstance(rank_cfg, int): + # Single rank (backward compatible) + fname = self.results_dir / ( + f"validation_chkpt{self.mini_epoch:05d}_rank{rank_cfg:04d}.{self.zarr_ext}" + ) + if not fname.exists(): + raise FileNotFoundError(f"Zarr file {fname} does not exist.") + return self._validate_rank_files([fname]) + + elif rank_cfg == "all": + pattern = f"validation_chkpt{self.mini_epoch:05d}_rank*.{self.zarr_ext}" + files = sorted(self.results_dir.glob(pattern)) + if not files: + raise FileNotFoundError(f"No zarr files matching {pattern} in {self.results_dir}") + _logger.info(f"Discovered {len(files)} rank file(s) for run {self.run_id}.") + return self._validate_rank_files(files) + + elif isinstance(rank_cfg, list | tuple): + files = [] + for r in rank_cfg: + fname = self.results_dir / ( + f"validation_chkpt{self.mini_epoch:05d}_rank{int(r):04d}.{self.zarr_ext}" + ) + if not fname.exists(): + raise FileNotFoundError(f"Zarr file {fname} does not exist.") + files.append(fname) + return self._validate_rank_files(sorted(files)) + + else: + raise ValueError( + f"Invalid rank config: {rank_cfg!r}. Use an int, 'all', or a list of ints." + ) + + def _validate_rank_files(self, files: list[Path]) -> list[Path]: + """Validate that rank files have the expected format.""" + for f in files: + is_valid = (self.zarr_ext == "zarr" and f.is_dir()) or ( + self.zarr_ext == "zip" and f.is_file() + ) + if not is_valid: + raise FileNotFoundError( + f"Zarr file {f} has unexpected format ({self.zarr_ext}). " + f"Expected directory for 'zarr' or file for 'zip'." + ) + return files + + def _validate_rank_metadata(self) -> dict: + """Validate that all rank files share identical metadata (streams, fsteps). + + Returns a dict with the validated common metadata. + Raises ValueError if any inconsistency is detected. + """ + reference_streams: set[str] | None = None + reference_fsteps: set[int] | None = None + + for rank_file in self.rank_files: + with zarrio_reader(rank_file) as zio: + streams = set(zio.streams) + fsteps = set(int(f) for f in zio.forecast_steps) + + if reference_streams is None: + reference_streams = streams + reference_fsteps = fsteps + else: + if streams != reference_streams: + raise ValueError( + f"Stream mismatch: {rank_file.name} has {streams}, " + f"expected {reference_streams}" + ) + if fsteps != reference_fsteps: + raise ValueError( + f"Forecast step mismatch: {rank_file.name} has {fsteps}, " + f"expected {reference_fsteps}" + ) + + return { + "streams": reference_streams or set(), + "forecast_steps": reference_fsteps or set(), + } + + def _open_any_rank_for_metadata(self): + """Open a rank file for metadata queries. Tries each rank until one succeeds. + + Returns a context-manager (zarrio_reader) that the caller must use in a + ``with`` statement or close manually. + """ + for rank_file in self.rank_files: + try: + return zarrio_reader(rank_file) + except Exception: + _logger.warning(f"Failed to open {rank_file.name} for metadata, trying next...") + raise RuntimeError("No rank files could be opened for metadata queries.") + + def _get_rank_sample_map(self) -> dict[Path, tuple[list[int], int]]: + """Build and cache mapping of rank_file → (local_samples, global_offset). + + Since all ranks use local indices (0, 1, ...), we assign global offsets: + rank0: offset=0, rank1: offset=len(rank0_samples), etc. + """ + if self._rank_sample_map is None: + self._rank_sample_map = {} + offset = 0 + for zarr_file in self.rank_files: + with zarrio_reader(zarr_file) as zio: + local = sorted(int(s) for s in zio.samples) + self._rank_sample_map[zarr_file] = (local, offset) + offset += len(local) + return self._rank_sample_map + + def _merge_fsteps(self, all_das: dict, global_sample_coords) -> dict: + """Merge lists of DataArrays for each forecast step across ranks. + Concatenates along the sample dimension and re-indexes to global samples. + """ + merged = {} + for fstep, das in all_das.items(): + combined = xr.concat(das, dim="sample") if len(das) > 1 else das[0] + merged[fstep] = combined.assign_coords( + sample=global_sample_coords[: len(combined.sample)] + ) + return merged + def get_data( self, stream: str, @@ -414,6 +548,9 @@ def get_data( ) -> ReaderOutput: """Load prediction and target data via direct zarr array access. + When multiple rank files are present, loads from each rank sequentially + and concatenates along the sample dimension with re-indexed global samples. + Parameters ---------- stream : str @@ -428,76 +565,126 @@ def get_data( """ resolved_ensemble = to_list(ensemble or self.get_ensemble(stream)) ens_select = EnsembleSelect.from_names(resolved_ensemble, self.get_ensemble(stream)) - state = _build_io_state( - self.run_id, - self.fname_zarr, - stream, - self.get_stream(stream), - self.get_channels(stream), - self.is_gridded_data(stream), - sorted(int(f) for f in (fsteps or self.get_forecast_steps())), - sorted(int(s) for s in (samples or self.get_samples())), - to_list(channels or self.get_stream(stream).get("channels", self.get_channels(stream))), - resolved_ensemble, - self._num_io_workers, - ens_select, + resolved_fsteps = sorted(int(f) for f in (fsteps or self.get_forecast_steps())) + resolved_channels = to_list( + channels or self.get_stream(stream).get("channels", self.get_channels(stream)) ) - get_data = get_data_zipstore if state.is_zip else get_data_dirstore - return get_data(state) - def get_stream(self, stream: str): - """ - returns the dictionary associated to a particular stream. - Returns an empty dictionary if the stream does not exist in the Zarr file. + rank_sample_map = self._get_rank_sample_map() - Parameters - ---------- - stream: - the stream name + # Determine which ranks to load based on requested global samples + requested_globals = set(int(s) for s in (samples or self.get_samples())) - Returns - ------- - The config dictionary associated to that stream - """ - if self._cached_streams is None: - with zarrio_reader(self.fname_zarr) as zio: - self._cached_streams = set(zio.streams) + all_targets: dict[int, list[xr.DataArray]] = {} + all_predictions: dict[int, list[xr.DataArray]] = {} + ranks_loaded = 0 + + for rank_file in self.rank_files: + local_samples, global_offset = rank_sample_map[rank_file] + + # Check if any of this rank's global samples are requested + rank_globals = set(range(global_offset, global_offset + len(local_samples))) + if not rank_globals & requested_globals: + continue + + # Map requested global indices back to local indices for this rank + rank_local_to_load = [ + local_samples[g - global_offset] for g in sorted(rank_globals & requested_globals) + ] + + _logger.info( + f"RUN {self.run_id} [rank {rank_file.stem.split('rank')[-1]}]: " + f"Loading {len(rank_local_to_load)} samples" + ) + _logger.debug( + f"RUN {self.run_id} [rank {rank_file.stem.split('rank')[-1]}]: " + f"local indices {rank_local_to_load}, " + f"global samples {sorted(rank_globals & requested_globals)}" + ) + + state = _build_io_state( + self.run_id, + rank_file, + stream, + self.get_stream(stream), + self.get_channels(stream), + self.is_gridded_data(stream), + resolved_fsteps, + rank_local_to_load, + resolved_channels, + resolved_ensemble, + self._num_io_workers, + ens_select, + rank=rank_file.stem.split("rank")[-1], + ) + get_data_fn = get_data_zipstore if state.is_zip else get_data_dirstore + result = get_data_fn(state) + + for fstep, da in result.target.items(): + all_targets.setdefault(fstep, []).append(da) + for fstep, da in result.prediction.items(): + all_predictions.setdefault(fstep, []).append(da) + ranks_loaded += 1 - if stream in self._cached_streams: + # Concatenate across ranks along sample dimension and re-index + global_sample_coords = np.array(sorted(requested_globals)) + + merged_targets = self._merge_fsteps(all_targets, global_sample_coords) + merged_predictions = self._merge_fsteps(all_predictions, global_sample_coords) + + ranks_skipped = len(self.rank_files) - ranks_loaded + _logger.info( + f"RUN {self.run_id}: Multi-rank load complete. " + f"{len(global_sample_coords)} samples × {len(merged_targets)} fsteps " + f"from {ranks_loaded}/{len(self.rank_files)} ranks " + f"({ranks_skipped} skipped)." + ) + return ReaderOutput(target=merged_targets, prediction=merged_predictions) + + def get_stream(self, stream: str): + """Return the config dictionary for a particular stream. + + Returns an empty dictionary if the stream does not exist in the Zarr files. + """ + if stream in self._validated_metadata["streams"]: return self.eval_cfg.streams.get(stream, {}) return {} def get_samples(self) -> set[int]: - """Get the set of sample indices from the Zarr file.""" + """Get global sample indices across all rank files. + + Assigns contiguous global indices: rank0 gets 0..N0-1, rank1 gets N0..N0+N1-1, etc. + """ if self._cached_samples is None: - with zarrio_reader(self.fname_zarr) as zio: - self._cached_samples = set(int(s) for s in zio.samples) + rank_sample_map = self._get_rank_sample_map() + all_samples: set[int] = set() + for local_samples, offset in rank_sample_map.values(): + all_samples.update(range(offset, offset + len(local_samples))) + self._cached_samples = all_samples return self._cached_samples def get_forecast_steps(self) -> set[int]: - """Get the set of forecast steps from the Zarr file.""" - if self._cached_fsteps is None: - with zarrio_reader(self.fname_zarr) as zio: - self._cached_fsteps = set(int(f) for f in zio.forecast_steps) - return self._cached_fsteps + """Get the set of forecast steps (validated across all ranks at init).""" + return self._validated_metadata["forecast_steps"] def get_forecast_substep_valid_times(self, stream: str) -> set[str]: - """Get the set of forecast times from the Zarr file.""" + """Get the set of forecast times from a rank file.""" if not self.is_gridded_data(stream): _logger.warning(f"Stream {stream} is not gridded. Forecast times cannot be retrieved.") return set() - with zarrio_reader(self.fname_zarr) as zio: - dummy = zio.get_data(0, stream, zio.forecast_steps[0]) + with self._open_any_rank_for_metadata() as zio: + dummy = zio.get_data(zio.samples[0], stream, zio.forecast_steps[0]) unique_lead = np.unique(dummy.valid_time.data) return set(str(lt) for lt in unique_lead) def get_ensemble(self, stream: str | None = None) -> list[str]: - """Get the list of ensemble member names for a given stream from the config. + """Get the list of ensemble member names for a given stream. + Parameters ---------- stream : - The name of the stream to get channels for. + The name of the stream to get ensemble members for. Returns ------- @@ -506,18 +693,18 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: _logger.debug(f"Getting ensembles for stream {stream}...") if stream not in self._cached_ensemble: - # TODO: improve this to get ensemble from io class - with zarrio_reader(self.fname_zarr) as zio: - dummy = zio.get_data(0, stream, zio.forecast_steps[0]) + with self._open_any_rank_for_metadata() as zio: + dummy = zio.get_data(zio.samples[0], stream, zio.forecast_steps[0]) self._cached_ensemble[stream] = list(dummy.prediction.as_xarray().coords["ens"].values) return self._cached_ensemble[stream] def is_gridded_data(self, stream: str) -> bool: - """Check if the latitude and longitude coordinates are regularly spaced for a given stream. + """Check if lat/lon coordinates are regularly spaced for a given stream. + Parameters ---------- stream : - The name of the stream to get channels for. + The name of the stream to check. Returns ------- @@ -531,8 +718,8 @@ def _compute_is_gridded(self, stream: str) -> bool: """is_gridded_data logic, called once per stream and cached.""" _logger.debug(f"Checking regular spacing for stream {stream}...") - with zarrio_reader(self.fname_zarr) as zio: - dummy = zio.get_data(0, stream, zio.forecast_steps[0]) + with self._open_any_rank_for_metadata() as zio: + dummy = zio.get_data(zio.samples[0], stream, zio.forecast_steps[0]) sample_idx = zio.samples[1] if len(zio.samples) > 1 else zio.samples[0] fstep_idx = ( diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index f129c34eb..e776ec3b2 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -25,6 +25,7 @@ from weathergen.evaluate.io.io_reader import Reader, ReaderOutput from weathergen.evaluate.plotting.bar_plots import BarPlots from weathergen.evaluate.plotting.line_plots import LinePlots +from weathergen.evaluate.plotting.plot_orchestration_utils import _compute_ranges, _compute_scores from weathergen.evaluate.plotting.plot_utils import ( bar_plot_metric_region, heat_maps_metric_region, @@ -36,11 +37,8 @@ from weathergen.evaluate.plotting.plotter import Plotter from weathergen.evaluate.plotting.quantile_plots import QuantilePlots from weathergen.evaluate.plotting.score_cards import ScoreCards -from weathergen.evaluate.scores.score import VerifiedData, get_score -from weathergen.evaluate.scores.score_orchestration import get_next_fstep_data from weathergen.evaluate.utils.array_utils import bias_ranges, common_ranges from weathergen.evaluate.utils.clim_utils import get_climatology -from weathergen.evaluate.utils.regions import RegionBoundingBox _logger = logging.getLogger(__name__) @@ -49,12 +47,14 @@ # --------------------------------------------------------------------------- -def plot_score_maps_per_stream( +def run_score_map_pipeline( reader: Reader, stream: str, regions: list[str], metrics_dict: dict, output_data: "ReaderOutput | None" = None, + global_plotting_options: dict | None = None, + plot_score_animations: bool = False, ) -> None: """Plot spatial score maps for all regions and forecast steps. @@ -70,6 +70,10 @@ def plot_score_maps_per_stream( Dictionary mapping region names to metric dicts. output_data : ReaderOutput | None Pre-loaded data; when provided ``reader.get_data()`` is skipped. + global_plotting_options : dict | None + Global plotting options. These can be passed to the plotter and can be used to set options. + plot_score_animations : bool + Whether to build animations of score maps across forecast steps. """ if not reader.is_gridded_data(stream): _logger.debug(f"RUN {reader.run_id} - {stream}: Skipping score maps (non-gridded data).") @@ -105,38 +109,42 @@ def plot_score_maps_per_stream( max_workers=reader.eval_cfg.get("max_workers", None), ) - cfg = reader.global_plotting_options + cfg = global_plotting_options plotter_cfg = { "image_format": cfg.get("image_format", "png"), "dpi_val": cfg.get("dpi_val", 300), "fig_size": cfg.get("fig_size", None), + "animation_format": cfg.get("animation_format", "gif"), + "fps": cfg.get("fps", 2), + "log_colorbar": cfg.get("log_colorbar", False), } output_basedir = str(reader.runplot_dir) run_id = reader.run_id + _computed, raw_results = _compute_scores( + regions, + metrics_dict, + fsteps, + da_preds, + da_tars, + aligned_clim_data, + n_workers=n_plot_workers, + ) + + score_ranges_dict = _compute_ranges(raw_results) + fstep_tasks: list[dict] = [] for region in regions: - bbox = RegionBoundingBox.from_region_name(region) - metrics = metrics_dict[region] for fstep in fsteps: - tars_fs = da_tars[fstep] - preds_fs = da_preds[fstep] - preds_next, tars_next = get_next_fstep_data(fstep, da_preds, da_tars, fsteps) - climatology = aligned_clim_data[fstep] if aligned_clim_data else None - tars_r, preds_r, tars_next_r, preds_next_r = [ - bbox.apply_mask(x) if x is not None else None - for x in (tars_fs, preds_fs, tars_next, preds_next) - ] - score_data = VerifiedData(preds_r, tars_r, preds_next_r, tars_next_r, climatology) fstep_tasks.append( { "plotter_cfg": plotter_cfg, + "score_ranges_dict": score_ranges_dict, "output_basedir": output_basedir, "map_dir": str(map_dir), "stream": stream, "region": region, - "score_data": score_data, - "metrics": dict(metrics), + "computed": _computed[(region, fstep)], "fstep": fstep, "run_id": run_id, } @@ -151,28 +159,36 @@ def plot_score_maps_per_stream( calls = [delayed(_plot_score_maps_per_stream)(**t) for t in fstep_tasks] dispatch_parallel(calls, n_workers=n_plot_workers, backend="loky", desc=f"Score maps {stream}") + # Derive variables and ens_values from the computed results. + if plot_score_animations: + _dispatch_score_map_animations( + map_dir=map_dir, + plotter_cfg=plotter_cfg, + run_id=run_id, + stream=stream, + metrics=list(dict.fromkeys(m for metrics in metrics_dict.values() for m in metrics)), + regions=regions, + variables=channels, + ens_values=list(ensemble) if ensemble else [None], + fsteps=fsteps, + n_workers=n_plot_workers, + ) + def _plot_score_maps_per_stream( plotter_cfg: dict, + score_ranges_dict: dict, output_basedir: str, map_dir: str, stream: str, region: str, - score_data: "VerifiedData", - metrics: dict[str, object], + computed: tuple[list, xr.DataArray, list[str]], fstep: int, run_id: str = "", ) -> None: """Plot 2D score maps for all metrics/channels for one (region, fstep).""" - preds = score_data.prediction - - metric_names = list(metrics.keys()) - metric_params = list(metrics.values()) - score_results: list[xr.DataArray | None] = [ - get_score(score_data, m, agg_dims="sample", parameters=p) - for m, p in zip(metric_names, metric_params, strict=False) - ] + score_results, preds, metric_names = computed valid = [(m, r) for m, r in zip(metric_names, score_results, strict=False) if r is not None] if not valid: return @@ -198,9 +214,7 @@ def _plot_score_maps_per_stream( plot_tasks: list[dict] = [] for metric in plot_metrics.coords["metric"].values: for ens_val in ens_values: - tag = f"score_maps_{metric}_fstep_{fstep}" + ( - f"_ens_{ens_val}" if ens_val is not None else "" - ) + tag = "score_maps" + (f"_ens_{ens_val}" if ens_val is not None else "") + f"_{metric}" for channel in plot_metrics.coords["channel"].values: sel = {"metric": metric, "channel": channel} if ens_val is not None: @@ -209,15 +223,18 @@ def _plot_score_maps_per_stream( title = f"{metric} - {channel}: fstep {fstep}" + ( f", ens {ens_val}" if ens_val is not None else "" ) + scores_cfg = score_ranges_dict.get(metric, {}).get(region, {}).get(channel, {}) plot_tasks.append( { "plotter_cfg": plotter_cfg, + "scores_cfg": scores_cfg, "output_basedir": output_basedir, "stream": stream, "data": data, "map_dir": str(map_dir), "channel": str(channel), "region": region, + "fstep": fstep, "tag": tag, "title": title, } @@ -229,19 +246,24 @@ def _plot_score_maps_per_stream( def _scatter_plot_single( plotter_cfg: dict, + scores_cfg: dict, output_basedir: str, stream: str, data: xr.DataArray, map_dir: str, channel: str, region: str, + fstep: int, tag: str, title: str, ) -> None: """Plot a single score-map scatter plot (picklable for loky workers).""" matplotlib.use("Agg") plotter = Plotter(plotter_cfg, Path(output_basedir), stream) - plotter.scatter_plot(data, Path(map_dir), channel, region, tag=tag, title=title) + plotter.update_data_selection({"sample": None, "stream": stream, "forecast_step": fstep}) + plotter.scatter_plot( + data, Path(map_dir), channel, region, tag=tag, map_kwargs=scores_cfg, title=title + ) # --------------------------------------------------------------------------- @@ -256,34 +278,40 @@ def _build_single_animation( stream: str, region: str | None, var: str, - sa: object, + sample: object, fsteps: list, image_format: str, animation_format: str, duration_ms: int, prefix: str = "map", ) -> list[str]: - """Build one GIF for a single (region, sample, variable) combination. + """Build one animation for a single (region, sample/ens, variable) combination. All work is I/O + Pillow — no matplotlib state involved. - Returns the list of source frame paths that were assembled into the GIF - (empty list if no frames were found). + The function scans ``output_dir`` for per-sample map/histogram frames whose filenames follow: + + {prefix}_{run_id}_{tag}_{sample}_{valid_time}_{stream}_{region}_{var}_{fstep:03d} + + When ``score_animation=True`` filenames are constructed deterministically because + the fstep is embedded in the tag (``score_maps_{metric}_fstep_{N}``) rather + than being a zero-padded suffix. Pass ``tag="score_maps_{metric}"`` and + ``sample`` as the ensemble value (or ``None`` for no ensemble). + + Returns the list of source frame paths assembled into the animation, or an + empty list when no (or fewer than two for score maps) frames were found. """ - # Both map and histogram filenames follow the same pattern: - # {prefix}_{run_id}_{tag}_{sample}_{valid_time}_{stream}_{region}_{var}_{fstep:03d} - # For all_samples histograms, valid_time is omitted. - # We match files by checking a fixed prefix and suffix, allowing any - # valid_time (or none) in between — no glob wildcards needed. + if not output_dir.is_dir(): + return [] + region_part = region if region else "" - head = "_".join(filter(None, [prefix, run_id, tag, str(sa)])) + if sample is not None: + head = "_".join(filter(None, [prefix, run_id, tag, str(sample)])) + else: + head = "_".join(filter(None, [prefix, run_id, tag])) tail = "_".join(filter(None, [stream, region_part, var])) suffix = f".{image_format}" fstep_strs = {str(f).zfill(3) for f in fsteps} - - if not output_dir.is_dir(): - return [] - image_paths = sorted( str(f) for f in output_dir.iterdir() @@ -292,11 +320,12 @@ def _build_single_animation( and f"_{tail}_" in f.name and f.stem.rsplit("_", 1)[-1] in fstep_strs ) - if not image_paths: return [] - - anim_parts = ["animation", run_id, tag, str(sa), stream] + if sample is not None: + anim_parts = ["animation", run_id, tag, str(sample), stream] + else: + anim_parts = ["animation", run_id, tag, stream] if region: anim_parts.append(region) anim_parts.append(var) @@ -364,7 +393,7 @@ def _dispatch_animations( "stream": plotter.stream, "region": region, "var": var, - "sa": sa, + "sample": sample, "fsteps": list(fsteps), "image_format": plotter.image_format, "animation_format": plotter.animation_format, @@ -373,7 +402,7 @@ def _dispatch_animations( } for prefix, output_dir in prefixes for region in plotter.regions - for sa in samples + for sample in samples for var in variables ] @@ -390,6 +419,55 @@ def _dispatch_animations( return [p for r in results if r for p in r] +def _dispatch_score_map_animations( + map_dir: Path, + plotter_cfg: dict, + run_id: str, + stream: str, + metrics: list[str], + regions: list[str], + variables: list[str], + ens_values: list, + fsteps: list, + n_workers: int | None = None, +) -> list[str]: + """Build score-map animations in parallel for all (metric, region, variable[, ens]) combos. + + Returns the paths of all source frames assembled into animations. + """ + duration_ms = int(1000 / plotter_cfg["fps"]) if plotter_cfg["fps"] > 0 else 400 + + tasks = [ + dict( + output_dir=map_dir, + run_id=run_id, + tag="score_maps" + (f"_ens_{ens_val}" if ens_val is not None else "") + f"_{metric}", + stream=stream, + region=region, + var=var, + sample=None, + fsteps=list(fsteps), + image_format=plotter_cfg["image_format"], + animation_format=plotter_cfg["animation_format"], + duration_ms=duration_ms, + score_animation=True, + ) + for metric in metrics + for region in regions + for var in variables + for ens_val in ens_values + ] + + calls = [delayed(_build_single_animation)(**t) for t in tasks] + results = dispatch_parallel( + calls, + n_workers=n_workers, + backend="loky", + desc=f"Score map animations {stream}", + ) + return [p for r in results if r for p in r] + + # --------------------------------------------------------------------------- # Per-sample map / histogram plots # --------------------------------------------------------------------------- diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration_utils.py new file mode 100644 index 000000000..5323e38fa --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration_utils.py @@ -0,0 +1,124 @@ +import numpy as np +import xarray as xr +from joblib import delayed + +from weathergen.evaluate.io.data.io_orchestration import dispatch_parallel +from weathergen.evaluate.scores.score import VerifiedData, get_score +from weathergen.evaluate.scores.score_orchestration import get_next_fstep_data +from weathergen.evaluate.utils.regions import RegionBoundingBox + + +def _compute_scores_for_fstep( + region: str, + fstep: int, + metric_names: list[str], + metric_params: list, + score_data: VerifiedData, + preds_r: xr.DataArray, +) -> tuple[str, int, list, xr.DataArray, list[str]]: + """Compute scores for a single (region, fstep) pair (parallelisable worker). + + Returns ``(region, fstep, score_results, preds_r, metric_names)``. + """ + score_results: list[xr.DataArray | None] = [ + get_score(score_data, m, agg_dims="sample", parameters=p) + for m, p in zip(metric_names, metric_params, strict=False) + ] + return region, fstep, score_results, preds_r, metric_names + + +def _compute_scores( + regions: list[str], + metrics_dict: dict, + fsteps: list, + da_preds: dict, + da_tars: dict, + aligned_clim_data: dict | None, + n_workers: int | None = None, +) -> tuple[dict, dict]: + """Compute scores for all (region, fstep) pairs. Score computation is parallelised across + (region, fstep) pairs. + + Returns + ------- + computed : dict[tuple, tuple] + ``{(region, fstep): (score_results, preds_r, metric_names)}`` + raw_results : list[tuple] + List of raw results from parallel score computation, each item is a tuple of + ``(region, fstep, score_results, preds_r, metric_names)``. + """ + # Build one task per (region, fstep) with pre-applied region masking. + tasks = [] + for region in regions: + bbox = RegionBoundingBox.from_region_name(region) + metrics = metrics_dict[region] + metric_names = list(metrics.keys()) + metric_params = list(metrics.values()) + for fstep in fsteps: + tars_fs = da_tars[fstep] + preds_fs = da_preds[fstep] + preds_next, tars_next = get_next_fstep_data(fstep, da_preds, da_tars, fsteps) + climatology = aligned_clim_data[fstep] if aligned_clim_data else None + tars_r, preds_r, tars_next_r, preds_next_r = [ + bbox.apply_mask(x) if x is not None else None + for x in (tars_fs, preds_fs, tars_next, preds_next) + ] + tasks.append( + dict( + region=region, + fstep=fstep, + metric_names=metric_names, + metric_params=metric_params, + score_data=VerifiedData( + preds_r, tars_r, preds_next_r, tars_next_r, climatology + ), + preds_r=preds_r, + ) + ) + + # Compute scores in parallel across (region, fstep) pairs. + calls = [delayed(_compute_scores_for_fstep)(**t) for t in tasks] + raw_results = dispatch_parallel( + calls, n_workers=n_workers, backend="loky", desc="Score computation" + ) + + # Accumulate per-channel colour ranges from the completed results. + computed: dict[tuple, tuple] = {} + for region, fstep, score_results, preds_r, metric_names in raw_results: + computed[(region, fstep)] = (score_results, preds_r, metric_names) + + return computed, raw_results + + +def _compute_ranges( + raw_results: list[tuple[str, int, list, xr.DataArray, list[str]]], +) -> dict: + """Compute colour ranges for each metric/region/channel from the raw score results. + + Returns + ------- + score_ranges_dict : dict + ``{metric: {region: {channel: {'vmin': float, 'vmax': float}}}}`` + """ + # Accumulate per-channel colour ranges from the completed results. + + score_ranges_dict: dict = {} + for region, _, score_results, _, metric_names in raw_results: + for metric, result in zip(metric_names, score_results, strict=False): + if result is None: + continue + score_ranges_dict.setdefault(metric, {}).setdefault(region, {}) + for ch in result.coords["channel"].values: + vals = result.sel(channel=ch).values.flatten() + vals = vals[~np.isnan(vals)] + if vals.size == 0: + continue + ch_key = str(ch) + vmin, vmax = float(vals.min()), float(vals.max()) + prev = score_ranges_dict[metric][region].get(ch_key) + score_ranges_dict[metric][region][ch_key] = { + "vmin": min(prev["vmin"], vmin) if prev else vmin, + "vmax": max(prev["vmax"], vmax) if prev else vmax, + } + + return score_ranges_dict diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 97e0840a6..34c7a8cd3 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -137,6 +137,7 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | self.dpi_val = plotter_cfg.get("dpi_val") self.fig_size = plotter_cfg.get("fig_size") self.fps = plotter_cfg.get("fps") + self.log_colorbar = plotter_cfg.get("log_colorbar", False) self.regions = plotter_cfg.get("regions") self.log_x = plotter_cfg.get("log_x", False) self.log_y = plotter_cfg.get("log_y", False) @@ -626,6 +627,7 @@ def _parse_map_kwargs(map_kwargs: dict | None, stream: str | None) -> dict: "vmax": kw.pop("vmax", None), "cmap": plt.get_cmap(kw.pop("colormap", "coolwarm")), "use_datashader": kw.pop("use_datashader", False), + "levels": kw.pop("levels", None), # HEALPix grid "add_healpix_grid": kw.pop("add_healpix_grid", False), "healpix_nside": kw.pop("healpix_nside", 4), @@ -635,16 +637,6 @@ def _parse_map_kwargs(map_kwargs: dict | None, stream: str | None) -> dict: "healpix_linestyle": kw.pop("healpix_linestyle", "-"), } - # Colour normalisation - if isinstance(kw.get("levels", False), oc.listconfig.ListConfig): - parsed["norm"] = mpl.colors.BoundaryNorm( - kw.pop("levels", None), parsed["cmap"].N, extend="both" - ) - else: - parsed["norm"] = mpl.colors.Normalize( - vmin=parsed["vmin"], vmax=parsed["vmax"], clip=False - ) - parsed["extra"] = kw # remaining kwargs forwarded to scatter return parsed @@ -888,7 +880,16 @@ def scatter_plot( if figsize is None and data.size >= 200_000: figsize = (15, 7) - proj = ccrs.Robinson() if regionname == "global" else ccrs.PlateCarree() + proj = ccrs.PlateCarree() + if regionname: + try: + # This uses the method already available in RegionBoundingBox + bbox = RegionBoundingBox.from_region_name(regionname) + proj = bbox.projection + except ValueError: + # If regionname isn't in the library, fall back to PlateCarree + _logger.warning(f"Region '{regionname}' not found in library, using PlateCarree.") + proj = ccrs.PlateCarree() fig = plt.figure(figsize=figsize, dpi=self.dpi_val) ax = fig.add_subplot(1, 1, 1, projection=proj) try: @@ -907,10 +908,18 @@ def scatter_plot( opts["vmin"] = float(p_lo) if opts["vmax"] is None: opts["vmax"] = float(p_hi) - # Rebuild norm with the robust limits - opts["norm"] = mpl.colors.Normalize( - vmin=opts["vmin"], vmax=opts["vmax"], clip=False + + if isinstance(opts["levels"], oc.listconfig.ListConfig): + opts["norm"] = mpl.colors.BoundaryNorm(opts["levels"], opts["cmap"].N, extend="both") + elif self.log_colorbar and opts["vmin"] is not None and opts["vmin"] > 0: + opts["norm"] = mpl.colors.LogNorm(vmin=opts["vmin"], vmax=opts["vmax"]) + else: + if self.log_colorbar: + _logger.warning( + "log_colorbar=True but vmin=%.3g <= 0; falling back to linear norm.", + opts["vmin"], ) + opts["norm"] = mpl.colors.Normalize(vmin=opts["vmin"], vmax=opts["vmax"], clip=False) if regionname == "global": ax.set_global() diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 17825e12f..2964af569 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -16,7 +16,6 @@ from collections import defaultdict from pathlib import Path -# Third-party import mlflow from mlflow.client import MlflowClient from omegaconf import DictConfig, OmegaConf, open_dict @@ -34,8 +33,8 @@ ) from weathergen.evaluate.plotting.plot_orchestration import ( plot_data, - plot_score_maps_per_stream, plot_summary, + run_score_map_pipeline, ) from weathergen.evaluate.plotting.plot_utils import collect_channels from weathergen.evaluate.scores.score_orchestration import ( @@ -187,6 +186,7 @@ def _process_stream( regions: list[str], metrics: dict[str, object], plot_score_maps: bool, + plot_score_animations: bool, ) -> tuple[str, str, dict[str, dict[str, dict[str, float]]]]: """ Worker function for a single stream of a single run. @@ -210,6 +210,8 @@ def _process_stream( Dict of metrics to be processed and their parameters. plot_score_maps: Bool to define if the score maps need to be plotted or not. + plot_score_animations: + Bool to define if the score animations need to be plotted or not. """ type_ = run.get("type", "zarr") reader = get_reader(type_, run, run_id, private_paths, regions, metrics) @@ -270,12 +272,14 @@ def _process_stream( scores_dict = merge(stream_loaded_scores, stream_computed_scores) if score_maps: - plot_score_maps_per_stream( + run_score_map_pipeline( reader, stream, regions_to_compute, metrics_to_compute, output_data=output_data, + global_plotting_options=global_plotting_opts, + plot_score_animations=plot_score_animations, ) return run_id, stream, scores_dict @@ -299,6 +303,7 @@ def evaluate_from_config(cfg: dict, mlflow_client: MlflowClient | None) -> None: summary_dir = Path(cfg.evaluation.get("summary_dir", _DEFAULT_PLOT_DIR)) metrics = cfg.evaluation.metrics plot_score_maps = cfg.evaluation.get("plot_score_maps", False) + plot_score_animations = cfg.evaluation.get("plot_score_animations", False) global_plotting_opts = cfg.get("global_plotting_options", {}) default_streams = cfg.get("default_streams", {}) max_workers = cfg.get("max_workers") # global hard cap for parallel workers @@ -330,6 +335,7 @@ def evaluate_from_config(cfg: dict, mlflow_client: MlflowClient | None) -> None: "regions": regions, "metrics": metrics, "plot_score_maps": plot_score_maps, + "plot_score_animations": plot_score_animations, } ) diff --git a/packages/evaluate/src/weathergen/evaluate/utils/clim_utils.py b/packages/evaluate/src/weathergen/evaluate/utils/clim_utils.py index 6b84a79d1..68b01318a 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/clim_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/clim_utils.py @@ -15,6 +15,8 @@ from scipy.spatial import cKDTree from tqdm import tqdm +from weathergen.evaluate.utils.derived_channels import scale_z_channels + _logger = logging.getLogger(__name__) @@ -245,6 +247,7 @@ def get_climatology(reader, da_tars, stream: str) -> dict | None: if clim_data_path is not None: clim_data = xr.open_dataset(clim_data_path) _logger.info("Aligning climatological data with target structure...") - return align_clim_data(da_tars, clim_data) + aligned = align_clim_data(da_tars, clim_data) + return {fstep: scale_z_channels(da, stream) for fstep, da in aligned.items()} return None diff --git a/packages/evaluate/src/weathergen/evaluate/utils/regions.py b/packages/evaluate/src/weathergen/evaluate/utils/regions.py index b2893c314..de0606500 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/regions.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/regions.py @@ -11,6 +11,7 @@ from dataclasses import dataclass from typing import ClassVar +import cartopy.crs as ccrs import xarray as xr _logger = logging.getLogger(__name__) @@ -22,13 +23,23 @@ class RegionLibrary: Predefined bounding boxes for known regions. """ - REGIONS: ClassVar[dict[str, tuple[float, float, float, float]]] = { - "global": (-90.0, 90.0, -180.0, 180.0), - "nhem": (0.0, 90.0, -180.0, 180.0), - "shem": (-90.0, 0.0, -180.0, 180.0), - "tropics": (-30.0, 30.0, -180.0, 180.0), - "belgium": (49, 52, 2, 7), - "europe": (35, 70, -10, 40), + REGIONS: ClassVar[dict[str, tuple[float, float, float, float, ccrs.Projection]]] = { + "global": (-90.0, 90.0, -180.0, 180.0, ccrs.Robinson()), + "nhem": (0.0, 90.0, -180.0, 180.0, ccrs.PlateCarree()), + "shem": (-90.0, 0.0, -180.0, 180.0, ccrs.PlateCarree()), + "tropics": (-30.0, 30.0, -180.0, 180.0, ccrs.PlateCarree()), + "belgium": (49, 52, 2, 7, ccrs.PlateCarree()), + "europe": (35, 70, -10, 40, ccrs.PlateCarree()), + "arctic": ( + 50.0, + 90.0, + -180.0, + 180.0, + ccrs.Stereographic(central_longitude=0, central_latitude=90), + ), + "uwc-west": (39.0, 63.0, -26.0, 41.0, ccrs.PlateCarree()), + "arome": (37.0, 56.0, -12.0, 16.0, ccrs.PlateCarree()), + "icon": (42.0, 51.0, -1.0, 18.0, ccrs.PlateCarree()), } @@ -38,6 +49,7 @@ class RegionBoundingBox: lat_max: float lon_min: float lon_max: float + projection: ccrs.Projection def __post_init__(self): """Validate the bounding box coordinates.""" diff --git a/packages/readers_extra/src/weathergen/readers_extra/data_reader_anemoi_operan.py b/packages/readers_extra/src/weathergen/readers_extra/data_reader_anemoi_operan.py new file mode 100644 index 000000000..91033b5a2 --- /dev/null +++ b/packages/readers_extra/src/weathergen/readers_extra/data_reader_anemoi_operan.py @@ -0,0 +1,186 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from pathlib import Path +from typing import override + +import numpy as np +from anemoi.datasets.data import MissingDateError + +from weathergen.datasets.data_reader_anemoi import DataReaderAnemoi +from weathergen.datasets.data_reader_base import ( + ReaderData, + TimeWindowHandler, + TIndex, +) +from weathergen.train.utils import Stage + +_logger = logging.getLogger(__name__) + + +def dt2cal(dt): + """ + Convert array of datetime64 to a calendar array of year, month, day, hour, + minute, seconds, microsecond with these quantites indexed on the last axis. + + Parameters + ---------- + dt : datetime64 array (...) + numpy.ndarray of datetimes of arbitrary shape + + Returns + ------- + cal : uint32 array (..., 7) + calendar array with last axis representing year, month, day, hour, + minute, second, microsecond + """ + + # allocate output + out = np.empty(dt.shape + (7,), dtype="u4") + # decompose calendar floors + year, month, day, hour, min, sec = [dt.astype(f"M8[{x}]") for x in "YMDhms"] + out[..., 0] = year + 1970 # Gregorian Year + out[..., 1] = (month - year) + 1 # month + out[..., 2] = (day - month) + 1 # dat + out[..., 3] = (dt - day).astype("m8[h]") # hour + out[..., 4] = (dt - hour).astype("m8[m]") # minute + out[..., 5] = (dt - min).astype("m8[s]") # second + out[..., 6] = (dt - sec).astype("m8[us]") # microsecond + return out + + +class DataReaderAnemoiOperan(DataReaderAnemoi): + "Wrapper for Anemoi datasets" + + def __init__( + self, + tw_handler: TimeWindowHandler, + filename: Path, + stream_info: dict, + stage: Stage, + ) -> None: + """ + Construct data reader for anemoi dataset + + Parameters + ---------- + filename : + filename (and path) of dataset + stream_info : + information about stream + + Returns + ------- + None + """ + + super().__init__(tw_handler, filename, stream_info, stage) + + @override + def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: + """ + Get data for window (for either source or target, through public interface) + + Parameters + ---------- + idx : int + Index of temporal window + channels_idx : np.array + Selection of channels + + Returns + ------- + ReaderData providing coords, geoinfos, data, datetimes + """ + + t_idxs, dtr = self._get_dataset_idxs(idx) + if self.ds is None or self.len == 0 or len(t_idxs) == 0: + return ReaderData.empty( + num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + ) + + # get additional timestep to ensure we have one valid timestep + t_idxs = np.insert(t_idxs, 0, t_idxs[0] - 1) + + didx_start = t_idxs[0] + didx_end = t_idxs[-1] + 1 + datetimes = self.ds.dates[didx_start:didx_end] + datetimes_split = dt2cal(datetimes) + + # compute corrected datetimes that account for actual availability + nts = self.stream_info["nominal_time_mapping"] + deltas = [int(nts[str(hour)]) - int(hour) for hour in datetimes_split[:, 3]] + datetimes_offset = [ + dt + np.timedelta64(delta, "h") for dt, delta in zip(datetimes, deltas, strict=False) + ] + + # use latest available sample that is valid w.r.t the input data window + datetimes_mask = [dt < dtr.end for dt in datetimes_offset] + if np.array(datetimes_mask).sum() == 0: + t_idxs = [] + else: + t_idxs = [t_idxs[datetimes_mask][-1].item()] + + # _get from DataReaderAnemoi + + if self.ds is None or self.len == 0 or len(t_idxs) == 0: + return ReaderData.empty( + num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + ) + + assert t_idxs[0] >= 0, "index must be non-negative" + didx_start = t_idxs[0] + # End is inclusive + didx_end = t_idxs[-1] + 1 + + # extract number of time steps and collapse ensemble dimension + # ds is a wrapper around zarr with get_coordinate_selection not being exposed since + # subsetting is pushed to the ctor via frequency argument; this also ensures that no sub- + # sampling is required here + try: + data = self.ds[didx_start:didx_end][:, :, 0].astype(np.float32) + except MissingDateError as e: + _logger.debug(f"Date not present in anemoi dataset: {str(e)}. Skipping.") + return ReaderData.empty( + num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + ) + + # coords-first representation and collapse multiple steps + data = data.transpose([0, 2, 1]).reshape((data.shape[0] * data.shape[2], -1)) + + # extract geoinfo channels (can be time-varying, so read from dataset) + geoinfos = data[:, list(self.geoinfo_idx)] + # extract channels + data = data[:, list(channels_idx)] + + # construct lat/lon coords + latlon = np.concatenate( + [ + np.expand_dims(self.latitudes, 0), + np.expand_dims(self.longitudes, 0), + ], + axis=0, + ).transpose() + # repeat latlon len(t_idxs) times + coords = np.vstack((latlon,) * len(t_idxs)) + + # date time matching #data points of data + # Assuming a fixed frequency for the dataset + datetimes = np.repeat(self.ds.dates[didx_start:didx_end], len(data) // len(t_idxs)) + + rd = ReaderData( + coords=coords, + geoinfos=geoinfos, + data=data, + datetimes=datetimes, + ) + # check_reader_data(rd, dtr) + + return rd diff --git a/packages/readers_extra/src/weathergen/readers_extra/registry.py b/packages/readers_extra/src/weathergen/readers_extra/registry.py index 7ea7c7d59..303e87ab9 100644 --- a/packages/readers_extra/src/weathergen/readers_extra/registry.py +++ b/packages/readers_extra/src/weathergen/readers_extra/registry.py @@ -24,5 +24,9 @@ def get_extra_reader(stream_type: str) -> object | None: from weathergen.readers_extra.data_reader_mesh import DataReaderMesh return DataReaderMesh + case "anemoi_operan": + from weathergen.readers_extra.data_reader_anemoi_operan import DataReaderAnemoiOperan + + return DataReaderAnemoiOperan case _: return None diff --git a/pyproject.toml b/pyproject.toml index 0b5c022cd..00103cb8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,6 @@ requires-python = ">=3.12,<3.13" dependencies = [ 'numpy~=2.2', 'astropy_healpix~=1.1.2', - 'healpy>=1.19,<2', 'zarr~=3.1.3', 'pandas~=2.2', 'tqdm', @@ -274,3 +273,4 @@ members = [ # Explicitly not depending on 'packages/dashboard' : this causes issues when deploying # the streamlit dashboard. ] + diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 6c4c0f913..21641b71d 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -53,12 +53,12 @@ def pin_memory(self): return self - def __init__(self, streams: dict) -> None: + def __init__(self, stream_names: list[str]) -> None: self.meta_info = {} self.streams_data = {} - for stream_info in streams: - self.streams_data[stream_info["name"]] = None + for stream_name in stream_names: + self.streams_data[stream_name] = None def to_device(self, device) -> None: for key in self.meta_info.keys(): @@ -146,8 +146,10 @@ class BatchSamples: output_idxs: list[int] device: str | None - def __init__(self, streams: dict, num_samples: int, output_steps, output_idxs) -> None: - self.samples = [Sample(streams) for _ in range(num_samples)] + def __init__( + self, stream_names: list[str], num_samples: int, output_steps, output_idxs + ) -> None: + self.samples = [Sample(stream_names) for _ in range(num_samples)] self.tokens_lens = None self.output_steps = output_steps self.output_idxs = output_idxs @@ -275,7 +277,7 @@ class ModelBatch: def __init__( self, - streams: dict, + stream_names: list[str], num_source_samples: int, num_target_samples: int, output_offset, @@ -289,10 +291,10 @@ def __init__( self.output_idxs = list(range(output_offset, output_steps)) self.source_samples = BatchSamples( - streams, num_source_samples, output_steps, self.output_idxs + stream_names, num_source_samples, output_steps, self.output_idxs ) self.target_samples = BatchSamples( - streams, num_target_samples, output_steps, self.output_idxs + stream_names, num_target_samples, output_steps, self.output_idxs ) self.source2target_matching_idxs = np.full(num_source_samples, -1, dtype=np.int32) diff --git a/src/weathergen/datasets/data_reader_obs.py b/src/weathergen/datasets/data_reader_obs.py index 896638979..62b1dcfba 100644 --- a/src/weathergen/datasets/data_reader_obs.py +++ b/src/weathergen/datasets/data_reader_obs.py @@ -75,8 +75,21 @@ def __init__( # determine idx for coords and geoinfos self.coords_idx = [self.colnames.index("lat"), self.colnames.index("lon")] - self.geoinfo_idx = list(range(self.coords_idx[-1] + 1, data_idx[0])) - self.geoinfo_channels = [self.colnames[i] for i in self.geoinfo_idx] + + # geoinfo channels + sname = stream_info["name"] + if stream_info.get("geoinfo_channels") is not None: + self.geoinfo_idx, self.geoinfo_channels = [], [] + for c in stream_info.get("geoinfo_channels"): + if c not in self.colnames: + _logger.warning(f"{sname} : geoinfo {c} specified in config but not present.") + else: + self.geoinfo_idx.append(self.colnames.index(c)) + self.geoinfo_channels.append(c) + else: + self.geoinfo_idx = list(range(self.coords_idx[-1] + 1, data_idx[0])) + self.geoinfo_channels = [self.colnames[i] for i in self.geoinfo_idx] + _logger.info(f"{stream_info['name']} geoinfos : {self.geoinfo_channels}") # load additional properties (mean, var) self._load_properties() @@ -237,6 +250,11 @@ def _get(self, idx: int, channels_idx: list[int]) -> ReaderData: num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) ) + if idx >= len(self.indices_start) or idx >= len(self.indices_end): + return ReaderData.empty( + num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + ) + start_row = self.indices_start[idx] end_row = self.indices_end[idx] diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 240e53763..329762410 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -8,6 +8,7 @@ import torch from numpy.typing import NDArray +from weathergen.common.config import Config from weathergen.datasets.batch import SampleMetaData from weathergen.train.utils import Stage from weathergen.utils.utils import is_stream_diagnostic, is_stream_forcing @@ -196,13 +197,12 @@ def merge_masking_config(self, mode_cfg, override): return stream_cfg - def build_effective_masking_cfgs(self, streams, mode_cfg): + def build_effective_masking_cfgs(self, streams: Config, mode_cfg): """Build effective masking configs for all streams.""" cfgs = {} - for stream_info in streams: - name = stream_info["name"] + for stream_name, stream_info in streams.items(): override = stream_info.get("masking_override", {}) - cfgs[name] = self.merge_masking_config(mode_cfg, override) + cfgs[stream_name] = self.merge_masking_config(mode_cfg, override) return cfgs diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 8122cd084..2bebef363 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -7,6 +7,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import dataclasses import logging import pathlib from collections.abc import Sequence @@ -85,6 +86,12 @@ def collect_datasources(stream_datasets: list, idx: int, type: str, rng) -> IORe return IOReaderData.combine(rdatas) +@dataclasses.dataclass +class _Stream: + info: Config + readers: list[DataReaderBase] + + class MultiStreamDataSampler(torch.utils.data.IterableDataset): def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): super(MultiStreamDataSampler, self).__init__() @@ -94,7 +101,6 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): self.mini_epoch = 0 self.mask_value = 0.0 - self.streams = cf.streams self.rank = cf.rank self.world_size = cf.world_size self.repeat_data = cf.data_loading.get("repeat_data_in_mini_epoch", False) @@ -102,7 +108,7 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): # initialise healpic self.healpix_level = cf.healpix_level self.num_healpix_cells = 12 * 4**self.healpix_level - self.masker = Masker(cf.healpix_level, stage, self.streams, self.mode_cfg) + self.masker = Masker(cf.healpix_level, stage, cf.streams, self.mode_cfg) self.tokenizer = TokenizerMasking(cf.healpix_level, self.masker) forecast_cfg = FORECAST_DEFAULTS | OmegaConf.to_object(mode_cfg.get("forecast", {})) @@ -177,6 +183,20 @@ def check_samples(self, fsm: int): # streamlined calculation of length epoch_len = self.samples_per_mini_epoch + + # ensure epoch_len is large enough to produce at least one batch per rank + min_samples = self.world_size * self.batch_size + if epoch_len < min_samples: + logger.warning( + f"samples_per_mini_epoch={epoch_len} is too small for " + f"world_size={self.world_size} and batch_size={self.batch_size}. " + f"samples_per_mini_epoch has to be equal to or larger than" + f"world_size*batch_size to ensure that each rank can produce at least one sample. " + f"Automatically increasing to {min_samples}." + ) + epoch_len = min_samples + self.samples_per_mini_epoch = min_samples + # adjust len to split loading across all workers and ensure it is multiple of batch_size self.len = ((epoch_len // self.world_size) // self.batch_size) * self.batch_size @@ -192,13 +212,12 @@ def _calc_baseperms(self, fsm: int) -> np.typing.NDArray: return np.arange(perms_len) - def _init_stream_datasets(self, cf) -> dict[StreamName, list[AnyDataReader]]: + def _init_stream_datasets(self, cf) -> dict[StreamName, _Stream]: """Load dataset readers for all streams from config.""" - streams_datasets: dict[StreamName, list[AnyDataReader]] = {} - - for _, stream_info in enumerate(cf.streams): + streams_datasets: dict[StreamName, _Stream] = {} + for stream_name, stream_info in cf.streams.items(): # list of sources for current stream - streams_datasets[stream_info["name"]] = [] + streams_datasets[stream_name] = _Stream(stream_info, []) kwargs = { "tw_handler": self.time_window_handler, @@ -217,7 +236,7 @@ def _init_stream_datasets(self, cf) -> dict[StreamName, list[AnyDataReader]]: dataset = get_extra_reader(type_name) if dataset is None: msg = f"Unsupported stream type {stream_info['type']}" - f"for stream name '{stream_info['name']}'." + f"for stream name '{stream_name}'." raise ValueError(msg) for fname in stream_info["filenames"]: @@ -232,7 +251,7 @@ def _init_stream_datasets(self, cf) -> dict[StreamName, list[AnyDataReader]]: if not any(filename.exists() for filename in filenames): # see above msg = ( f"Did not find input data for {stream_info['type']} " - f"stream '{stream_info['name']}': {filenames}." + f"stream '{stream_name}': {filenames}." ) raise FileNotFoundError(msg) @@ -244,11 +263,11 @@ def _init_stream_datasets(self, cf) -> dict[StreamName, list[AnyDataReader]]: if is_root(): logger.info( f"Opening dataset with type: {ds_type}" - + f" from stream config {stream_info['name']}.", + + f" from stream config {stream_name}.", ) ds = dataset(filename=filename, **kwargs) - streams_datasets[stream_info["name"]] += [ds] + streams_datasets[stream_name].readers += [ds] stream_info[str(self._stage) + "_source_channels"] = ds.source_channels stream_info[str(self._stage) + "_target_channels"] = ds.target_channels @@ -333,35 +352,35 @@ def advance(self): def get_sources_size(self): return [ 0 - if ds[0].get_source_num_channels() == 0 - else ds[0].get_source_num_channels() - + ds[0].get_geoinfo_size() - + ds[0].get_coords_size() + if ds.readers[0].get_source_num_channels() == 0 + else ds.readers[0].get_source_num_channels() + + ds.readers[0].get_geoinfo_size() + + ds.readers[0].get_coords_size() + self.tokenizer.get_size_time_embedding() - for _, ds in self.streams_datasets.items() + for ds in self.streams_datasets.values() ] def get_sources_num_channels(self): - return [ds[0].get_source_num_channels() for _, ds in self.streams_datasets.items()] + return [ds.readers[0].get_source_num_channels() for ds in self.streams_datasets.values()] def get_targets_num_channels(self): - return [ds[0].get_target_num_channels() for _, ds in self.streams_datasets.items()] + return [ds.readers[0].get_target_num_channels() for ds in self.streams_datasets.values()] def get_targets_coords_size(self): # TODO: avoid hard coding magic values # +6 at the end for stream_id and time encoding return [ - (ds[0].get_geoinfo_size() + (5 * (3 * 5)) + 3 * 8) + 6 - for _, ds in self.streams_datasets.items() + (ds.readers[0].get_geoinfo_size() + (5 * (3 * 5)) + 3 * 8) + 6 + for ds in self.streams_datasets.values() ] def denormalize_source_channels(self, stream_name, data) -> torch.Tensor: # [0]: with multiple ds per stream we use the first one - return self.streams_datasets[stream_name][0].denormalize_source_channels(data) + return self.streams_datasets[stream_name].readers[0].denormalize_source_channels(data) def denormalize_target_channels(self, stream_name, data) -> torch.Tensor: # [0]: with multiple ds per stream we use the first one - return self.streams_datasets[stream_name][0].denormalize_target_channels(data) + return self.streams_datasets[stream_name].readers[0].denormalize_target_channels(data) def _build_stream_data_input( self, @@ -413,8 +432,9 @@ def _build_stream_data_input( mask, ) - # collect data for stream - stream_data.add_source(step, rdata, source_cells_lens, source_cells) + stream_data.add_source( + self._stage, step, rdata, source_cells_lens, source_cells, rdata.is_spoof + ) return stream_data @@ -455,7 +475,7 @@ def _build_stream_data_output( (time_win_target.start, time_win_target.end), target_mask, ) - stream_data.add_target_coords(timestep_idx, tc, tc_l, rdata.is_spoof) + stream_data.add_target_coords(self._stage, timestep_idx, tc, tc_l, rdata.is_spoof) if "target_values" in mode: (tt_cells, tt_t, tt_c, idxs_inv) = self.tokenizer.get_target_values( @@ -465,8 +485,9 @@ def _build_stream_data_output( (time_win_target.start, time_win_target.end), target_mask, ) + stream_data.add_target_values( - timestep_idx, tt_cells, tt_c, tt_t, idxs_inv, rdata.is_spoof + self._stage, timestep_idx, tt_cells, tt_c, tt_t, idxs_inv, rdata.is_spoof ) return stream_data @@ -594,16 +615,17 @@ def _get_source_target_masks(self, training_mode): Generate source and target masks for all streams. """ masks = {} - for stream_info in self.streams: + for stream_name, stream_data in self.streams_datasets.items(): + stream_info = stream_data.info # Build source and target sample masks - masks[stream_info["name"]] = self.tokenizer.build_samples_for_stream( + masks[stream_name] = self.tokenizer.build_samples_for_stream( training_mode, self.num_healpix_cells, stream_info, ) # identical for all streams - num_target_samples = len(masks[stream_info["name"]][0]) - num_source_samples = len(masks[stream_info["name"]][1]) + num_target_samples = len(masks[stream_name][0]) + num_source_samples = len(masks[stream_name][1]) return masks, num_source_samples, num_target_samples @@ -617,11 +639,12 @@ def _preprocess_model_batch( """ Perform necessary pre-processing of model batch """ + stream_names = list(self.streams_datasets.keys()) batch.source_samples.tokens_lens = get_tokens_lens( - self.streams, batch.source_samples, source_input_steps + stream_names, batch.source_samples, source_input_steps ) batch.target_samples.tokens_lens = get_tokens_lens( - self.streams, batch.target_samples, target_input_steps + stream_names, batch.target_samples, target_input_steps ) return batch @@ -652,7 +675,7 @@ def _get_batch(self, idx: int, num_forecast_steps: int): num_output_steps = self._get_output_length(num_forecast_steps) batch = ModelBatch( - self.streams, + list(self.streams_datasets.keys()), num_source_samples, num_target_samples, self.output_offset, @@ -660,9 +683,8 @@ def _get_batch(self, idx: int, num_forecast_steps: int): ) # for all streams - for stream_info, (stream_name, stream_ds) in zip( - self.streams, self.streams_datasets.items(), strict=True - ): + for stream_name, stream_data in self.streams_datasets.items(): + stream_info, stream_ds = stream_data.info, stream_data.readers (target_masks, source_masks, source_to_target) = masks_streams[stream_name] # max number of input steps diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index e993a0f03..c77e01282 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -12,6 +12,7 @@ import torch from weathergen.common.io import IOReaderData +from weathergen.train.utils import TRAIN, Stage def _pin_tensor(tensor: torch.Tensor) -> torch.Tensor: @@ -126,13 +127,6 @@ def pin_memory(self): self.source_tokens_cells = _pin_tensor_list(self.source_tokens_cells) self.source_tokens_lens = _pin_tensor_list(self.source_tokens_lens) self.source_idxs_embed = _pin_tensor_list(self.source_idxs_embed) - self.source_idxs_embed_pe = _pin_tensor_list(self.source_idxs_embed_pe) - - # Pin source_raw (list of IOReaderData objects) - if hasattr(self, "source_raw"): - for raw_data in self.source_raw: - if raw_data is not None and hasattr(raw_data, "pin_memory"): - raw_data.pin_memory() return self @@ -163,14 +157,17 @@ def to_device(self, device: str) -> None: self.source_tokens_lens = [s.to(dv, non_blocking=True) for s in self.source_tokens_lens] self.source_idxs_embed = [s.to(dv, non_blocking=True) for s in self.source_idxs_embed] - self.source_idxs_embed_pe = [ - s.to(dv, non_blocking=True) for s in self.source_idxs_embed_pe - ] return self def add_source( - self, step: int, ss_raw: IOReaderData, ss_lens: torch.Tensor, ss_cells: list + self, + stage: Stage, + step: int, + ss_raw: IOReaderData, + ss_lens: torch.Tensor, + ss_cells: list, + is_spoof: bool, ) -> None: """ Add data for source for one input. @@ -189,6 +186,10 @@ def add_source( assert step < self.input_steps + if stage == TRAIN: + del ss_raw + ss_raw = None + self.source_raw[step] = ss_raw self.source_tokens_lens[step] = ss_lens self.source_tokens_cells[step] = torch.stack(ss_cells) @@ -196,10 +197,11 @@ def add_source( idx = torch.isnan(self.source_tokens_cells[step]) self.source_tokens_cells[step][idx] = self.mask_value - self.source_is_spoof[step] = ss_raw.is_spoof + self.source_is_spoof[step] = is_spoof def add_target( self, + stage: Stage, fstep: int, targets: list, target_coords: torch.Tensor, @@ -245,6 +247,7 @@ def add_target( def add_target_values( self, + stage: Stage, fstep: int, targets: list, target_coords_raw: torch.Tensor, @@ -278,6 +281,10 @@ def add_target_values( None """ + if stage == TRAIN: + del idxs_inv + idxs_inv = None + self.target_tokens[fstep] = targets self.target_times_raw[fstep] = times_raw self.target_coords_raw[fstep] = target_coords_raw @@ -287,6 +294,7 @@ def add_target_values( def add_target_coords( self, + stage: Stage, fstep: int, target_coords: torch.Tensor, target_coords_per_cell: torch.Tensor, diff --git a/src/weathergen/datasets/tokenizer_utils.py b/src/weathergen/datasets/tokenizer_utils.py index b392f4ae0..1bd1d1722 100644 --- a/src/weathergen/datasets/tokenizer_utils.py +++ b/src/weathergen/datasets/tokenizer_utils.py @@ -35,13 +35,17 @@ def encode_times_source(times, time_win) -> torch.tensor: dt = pd.to_datetime(times) dt_win = pd.to_datetime(time_win) dt_delta = dt - dt_win[0] + year = np.atleast_1d(dt.year) + dayofyear = np.atleast_1d(dt.dayofyear) + minutes = np.atleast_1d(dt.hour * 60 + dt.minute) + delta_seconds = np.atleast_1d(dt_delta.seconds) time_tensor = torch.cat( ( - torch.tensor(dt.year, dtype=fp32).unsqueeze(1), - torch.tensor(dt.dayofyear, dtype=fp32).unsqueeze(1), - torch.tensor(dt.hour * 60 + dt.minute, dtype=fp32).unsqueeze(1), - torch.tensor(dt_delta.seconds, dtype=fp32).unsqueeze(1), - torch.tensor(dt_delta.seconds, dtype=fp32).unsqueeze(1), + torch.tensor(year, dtype=fp32).unsqueeze(1), + torch.tensor(dayofyear, dtype=fp32).unsqueeze(1), + torch.tensor(minutes, dtype=fp32).unsqueeze(1), + torch.tensor(delta_seconds, dtype=fp32).unsqueeze(1), + torch.tensor(delta_seconds, dtype=fp32).unsqueeze(1), ), 1, ) @@ -65,7 +69,9 @@ def encode_times_target(times, time_win) -> torch.tensor: dt = pd.to_datetime(times) dt_win = pd.to_datetime(time_win) # for target only provide local time - dt_delta = torch.tensor((dt - dt_win[0]).seconds, dtype=torch.float32).unsqueeze(1) + dt_delta = torch.tensor(np.atleast_1d((dt - dt_win[0]).seconds), dtype=torch.float32).unsqueeze( + 1 + ) time_tensor = torch.cat( ( dt_delta, @@ -349,7 +355,7 @@ def return_empty(rdata, idxs_cells_lens): idxs_data = torch.cat(idxs_data) # apply mask - datetimes = rdata.datetimes[idxs_data] + datetimes = np.atleast_1d(rdata.datetimes[idxs_data]) datetimes_enc = enc_time(datetimes, time_win) geoinfos = rdata.geoinfos[idxs_data] coords = rdata.coords[idxs_data] diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index 27b1af64a..90cb75ce4 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -257,7 +257,9 @@ def add_local_vert_coords_ctrs2(verts_local, tcs_lens, a, zi, geoinfo_offset): return a -def get_tokens_lens(streams: dict, batch_data: BatchSamples, input_steps: int) -> torch.Tensor: +def get_tokens_lens( + streams_names: list[str], batch_data: BatchSamples, input_steps: int +) -> torch.Tensor: """ Extract tokens_lens for (num_steps, num_samples, num_streams) """ @@ -268,8 +270,8 @@ def get_tokens_lens(streams: dict, batch_data: BatchSamples, input_steps: int) - [ torch.stack( [ - sample.streams_data[stream_info["name"]].source_tokens_lens[i] - for stream_info in streams + sample.streams_data[stream_name].source_tokens_lens[i] + for stream_name in streams_names ] ) for sample in batch_data.samples diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index b18791aa5..bf97479e6 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -14,13 +14,13 @@ from torch.nn.attention.flex_attention import create_block_mask, flex_attention from weathergen.model.norms import AdaLayerNorm, RMSNorm -from weathergen.model.positional_encoding import apply_rope +from weathergen.model.positional_encoding import rotary_pos_emb_2d """ Attention blocks used by WeatherGenerator. -Some blocks optionally apply RoPE-like positional modulation. When enabled, the caller must -provide per-token coordinates aligned with the token order (lat, lon in radians). +Some blocks optionally apply 2D RoPE. When enabled, the caller must provide per-token 2D +coordinates aligned with the token order (lat, lon in radians). """ @@ -40,7 +40,7 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, - rope_mode="none", + with_2d_rope=False, ): super(MultiSelfAttentionHeadVarlen, self).__init__() @@ -49,10 +49,7 @@ def __init__( self.with_flash = with_flash self.softcap = softcap self.with_residual = with_residual - self.rope_mode = rope_mode - self.rope_post_mod_qk_lnorm = rope_mode == "spherical" - if self.rope_post_mod_qk_lnorm: - assert with_qk_lnorm, "rope_post_mod_qk_lnorm=True requires with_qk_lnorm=True" + self.with_2d_rope = with_2d_rope assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -82,9 +79,6 @@ def __init__( lnorm = qk_norm if with_qk_lnorm else torch.nn.Identity self.lnorm_q = lnorm(self.dim_head_proj, eps=norm_eps) self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) - post_rope_lnorm = norm if self.rope_post_mod_qk_lnorm else torch.nn.Identity - self.post_rope_lnorm_q = post_rope_lnorm(self.dim_head_proj, eps=norm_eps) - self.post_rope_lnorm_k = post_rope_lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype @@ -102,12 +96,10 @@ def forward(self, x, x_lens, ada_ln_aux=None, coords=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype) vs = self.proj_heads_v(x).reshape(s) - qs, ks = apply_rope( - qs, ks, coords, self.rope_mode, 1 - ) - if self.rope_post_mod_qk_lnorm: - qs = self.post_rope_lnorm_q(qs).to(self.dtype) - ks = self.post_rope_lnorm_k(ks).to(self.dtype) + if self.with_2d_rope: + if coords is None: + raise ValueError("coords must be provided when with_2d_rope=True") + qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=1) # set dropout rate according to training/eval mode as required by flash_attn dropout_rate = self.dropout_rate if self.training else 0.0 @@ -233,7 +225,7 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, - rope_mode="none", + with_2d_rope=False, ): super(MultiSelfAttentionHeadLocal, self).__init__() @@ -241,10 +233,7 @@ def __init__( self.with_flash = with_flash self.softcap = softcap self.with_residual = with_residual - self.rope_mode = rope_mode - self.rope_post_mod_qk_lnorm = rope_mode == "spherical" - if self.rope_post_mod_qk_lnorm: - assert with_qk_lnorm, "rope_post_mod_qk_lnorm=True requires with_qk_lnorm=True" + self.with_2d_rope = with_2d_rope assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -274,9 +263,6 @@ def __init__( lnorm = qk_norm if with_qk_lnorm else torch.nn.Identity self.lnorm_q = lnorm(self.dim_head_proj, eps=norm_eps) self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) - post_rope_lnorm = norm if self.rope_post_mod_qk_lnorm else torch.nn.Identity - self.post_rope_lnorm_q = post_rope_lnorm(self.dim_head_proj, eps=norm_eps) - self.post_rope_lnorm_k = post_rope_lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype assert with_flash, "Only flash attention supported." @@ -302,12 +288,10 @@ def forward(self, x, coords=None, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3]) vs = self.proj_heads_v(x).reshape(s).permute([0, 2, 1, 3]) - qs, ks = apply_rope( - qs, ks, coords, self.rope_mode, 1 - ) - if self.rope_post_mod_qk_lnorm: - qs = self.post_rope_lnorm_q(qs).to(self.dtype) - ks = self.post_rope_lnorm_k(ks).to(self.dtype) + if self.with_2d_rope: + if coords is None: + raise ValueError("coords must be provided when with_2d_rope=True") + qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=1) outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2) @@ -556,7 +540,7 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, - rope_mode="none", + with_2d_rope=False, ): super(MultiSelfAttentionHead, self).__init__() @@ -565,10 +549,7 @@ def __init__( self.softcap = softcap self.dropout_rate = dropout_rate self.with_residual = with_residual - self.rope_mode = rope_mode - self.rope_post_mod_qk_lnorm = rope_mode == "spherical" - if self.rope_post_mod_qk_lnorm: - assert with_qk_lnorm, "rope_post_mod_qk_lnorm=True requires with_qk_lnorm=True" + self.with_2d_rope = with_2d_rope assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -598,9 +579,6 @@ def __init__( lnorm = qk_norm if with_qk_lnorm else torch.nn.Identity self.lnorm_q = lnorm(self.dim_head_proj, eps=norm_eps) self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) - post_rope_lnorm = norm if self.rope_post_mod_qk_lnorm else torch.nn.Identity - self.post_rope_lnorm_q = post_rope_lnorm(self.dim_head_proj, eps=norm_eps) - self.post_rope_lnorm_k = post_rope_lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype if with_flash: @@ -621,12 +599,10 @@ def forward(self, x, coords=None, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype) vs = self.proj_heads_v(x).reshape(s).to(self.dtype) - qs, ks = apply_rope( - qs, ks, coords, self.rope_mode, 2 - ) - if self.rope_post_mod_qk_lnorm: - qs = self.post_rope_lnorm_q(qs).to(self.dtype) - ks = self.post_rope_lnorm_k(ks).to(self.dtype) + if self.with_2d_rope: + if coords is None: + raise ValueError("coords must be provided when with_2d_rope=True") + qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=2) # set dropout rate according to training/eval mode as required by flash_attn dropout_rate = self.dropout_rate if self.training else 0.0 diff --git a/src/weathergen/model/blocks.py b/src/weathergen/model/blocks.py index a05e25ca9..f8b4facc9 100644 --- a/src/weathergen/model/blocks.py +++ b/src/weathergen/model/blocks.py @@ -201,12 +201,14 @@ def __init__( self.block = nn.ModuleList() + target_readout_num_heads = next(self.cf.streams.values())["target_readout"]["num_heads"] + # Multi-Cross Attention Head self.block.append( MultiCrossAttentionHeadVarlen( dim_in, self.cf.ae_global_dim_embed, - self.cf.streams[0]["target_readout"]["num_heads"], + target_readout_num_heads, dim_head_proj=self.tr_dim_head_proj, with_residual=True, with_qk_lnorm=True, @@ -226,7 +228,7 @@ def __init__( self.block.append( MultiSelfAttentionHeadVarlen( dim_in, - num_heads=self.cf.streams[0]["target_readout"]["num_heads"], + num_heads=target_readout_num_heads, dropout_rate=0.1, # Assuming dropout_rate is 0.1 with_qk_lnorm=True, with_flash=self.cf.with_flash_attention, diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 72340d142..5dea1bdae 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -58,7 +58,7 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord # embedding engine # determine stream names once so downstream components use consistent keys - self.stream_names = [str(stream_cfg["name"]) for stream_cfg in cf.streams] + self.stream_names = list(cf.streams.keys()) # separate embedding networks for differnt observation types self.embed_engine = EmbeddingEngine(cf, self.sources_size) @@ -133,11 +133,7 @@ def forward(self, model_params, batch): tokens_global = checkpoint( self.ae_global_engine, tokens_global, - coords=( - model_params.rope_spherical_coeffs.unbind(dim=-1) - if model_params.rope_spherical_coeffs is not None - else model_params.rope_coords - ), + coords=model_params.rope_coords, use_reentrant=False, ) @@ -225,8 +221,6 @@ def aggregation_engine_unmasked( tokens_global_register_class, tokens_lens, rope_cell_coords=None, - rope_cell_coeffs=None, - rope_extra_coeffs=None, ): """ Aggregation engine on the global latents of unmasked cells @@ -257,19 +251,8 @@ def aggregation_engine_unmasked( ) # Build packed coords matching the interleaved token order - num_extra = self.num_class_tokens + self.num_register_tokens - if rope_cell_coeffs is not None: - extra_real, extra_imag = rope_extra_coeffs.unbind(dim=-1) - cell_real, cell_imag = rope_cell_coeffs.unbind(dim=-1) - packed_real = [] - packed_imag = [] - for mask_b in cell_mask.flatten(0, 1): - packed_real.append(extra_real) - packed_imag.append(extra_imag) - packed_real.append(cell_real[mask_b]) - packed_imag.append(cell_imag[mask_b]) - packed_coords = (torch.cat(packed_real, dim=0), torch.cat(packed_imag, dim=0)) - elif rope_cell_coords is not None: + if rope_cell_coords is not None: + num_extra = self.num_class_tokens + self.num_register_tokens zero_coords = torch.zeros( num_extra, 2, device=rope_cell_coords.device, dtype=rope_cell_coords.dtype ) @@ -333,8 +316,6 @@ def assimilate_local( tokens_global_register_class, batch.tokens_lens, rope_cell_coords=model_params.rope_cell_coords, - rope_cell_coeffs=model_params.rope_spherical_cell_coeffs, - rope_extra_coeffs=model_params.rope_spherical_extra_coeffs, ) # final processing diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index e5af71d2e..9c1a1e3a9 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -29,7 +29,6 @@ StreamEmbedTransformer, ) from weathergen.model.layers import MLP -from weathergen.model.positional_encoding import get_rope_mode from weathergen.model.utils import ActivationFactory from weathergen.utils.utils import get_dtype @@ -49,9 +48,9 @@ def __init__(self, cf: Config, sources_size) -> None: self.dtype = get_dtype(self.cf.mixed_precision_dtype) self.sources_size = sources_size # KCT:iss130, what is this? self.embeds = torch.nn.ModuleDict() - self.stream_names = [str(stream_cfg["name"]) for stream_cfg in cf.streams] + self.streams = cf.streams - for i, (si, stream_name) in enumerate(zip(self.cf.streams, self.stream_names, strict=True)): + for i, (stream_name, si) in enumerate(self.streams.items()): if si.get("diagnostic", False) or self.sources_size[i] == 0: self.embeds[stream_name] = torch.nn.Identity() continue @@ -90,7 +89,7 @@ def forward(self, batch, pe_embed): # iterate over all streams x_embeds = [] - for stream_name in self.stream_names: + for stream_name in self.streams.keys(): # collect all source tokens from all input_steps and all samples in the batch sdata = [] for istep in range(num_steps_input): @@ -391,7 +390,6 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: super(QueryAggregationEngine, self).__init__() self.cf = cf self.num_healpix_cells = num_healpix_cells - rope_mode = get_rope_mode(self.cf) self.ae_aggregation_blocks = torch.nn.ModuleList() @@ -412,7 +410,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), - rope_mode=rope_mode, + with_2d_rope=self.cf.get("rope_2D", False), ) ) else: @@ -468,7 +466,6 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: super(GlobalAssimilationEngine, self).__init__() self.cf = cf self.num_healpix_cells = num_healpix_cells - rope_mode = get_rope_mode(self.cf) self.ae_global_blocks = torch.nn.ModuleList() @@ -489,7 +486,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), - rope_mode=rope_mode, + with_2d_rope=self.cf.get("rope_2D", False), ) ) else: @@ -506,7 +503,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), - rope_mode=rope_mode, + with_2d_rope=self.cf.get("rope_2D", False), ) ) # MLP block @@ -557,7 +554,6 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = super(ForecastingEngine, self).__init__() self.cf = cf self.num_healpix_cells = num_healpix_cells - rope_mode = get_rope_mode(self.cf) self.fe_blocks = torch.nn.ModuleList() global_rate = int(1 / self.cf.forecast_att_dense_rate) @@ -577,7 +573,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), - rope_mode=rope_mode, + with_2d_rope=self.cf.get("rope_2D", False), ) ) else: @@ -595,7 +591,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), - rope_mode=rope_mode, + with_2d_rope=self.cf.get("rope_2D", False), ) ) # Add MLP block @@ -869,6 +865,7 @@ def __init__( self.pos_embed = nn.Parameter(torch.zeros(1, 9, self.cf.ae_global_dim_embed)) dim_aux = self.cf.ae_global_dim_embed + target_readout_num_heads = next(self.cf.streams.values())["target_readout"]["num_heads"] for ith, dim in enumerate(self.dims_embed[:-1]): if self.cf.decoder_type == "PerceiverIO": # a single cross attention layer as per https://arxiv.org/pdf/2107.14795 @@ -877,7 +874,7 @@ def __init__( dim_q=dim, dim_kv=dim_aux, dim_aux=dim_aux, - num_heads=self.cf.streams[0]["target_readout"]["num_heads"], + num_heads=target_readout_num_heads, with_self_attn=False, with_adanorm=False, with_mlp=False, @@ -889,7 +886,7 @@ def __init__( SelfAttentionBlock( dim=dim, dim_aux=dim_aux, - num_heads=self.cf.streams[0]["target_readout"]["num_heads"], + num_heads=target_readout_num_heads, attention_kwargs=attention_kwargs, with_adanorm=True, dropout_rate=0.1, @@ -901,7 +898,7 @@ def __init__( dim_q=dim, dim_kv=self.cf.ae_global_dim_embed, dim_aux=dim_aux, - num_heads=self.cf.streams[0]["target_readout"]["num_heads"], + num_heads=target_readout_num_heads, with_self_attn=True, with_adanorm=False, with_mlp=True, @@ -915,7 +912,7 @@ def __init__( dim_q=dim, dim_kv=dim_aux, dim_aux=dim_aux, - num_heads=self.cf.streams[0]["target_readout"]["num_heads"], + num_heads=target_readout_num_heads, with_self_attn=True, with_adanorm=True, with_mlp=True, @@ -931,7 +928,7 @@ def __init__( dim_out=self.dims_embed[ith + 1], dim_kv=dim_aux, dim_aux=self.dim_coord_in, - num_heads=self.cf.streams[0]["target_readout"]["num_heads"], + num_heads=target_readout_num_heads, attention_kwargs=attention_kwargs, tr_dim_head_proj=tr_dim_head_proj, tr_mlp_hidden_factor=tr_mlp_hidden_factor, diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 244eac8f9..d8a30a722 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -11,6 +11,7 @@ import logging import math +import typing import warnings import astropy_healpix as hp @@ -37,11 +38,6 @@ TargetPredictionEngineClassic, ) from weathergen.model.layers import MLP, NamedLinear -from weathergen.model.positional_encoding import ( - build_spherical_rope_coeff_tensors, - get_rope_mode, - get_rope_spherical_band, -) from weathergen.model.utils import get_num_parameters from weathergen.utils.distributed import is_root from weathergen.utils.utils import get_dtype, is_stream_forcing @@ -114,8 +110,8 @@ def __init__(self, cf) -> None: self.pe_global = torch.nn.Parameter(pe, requires_grad=False) # RoPE coordinates - self.rope_mode = get_rope_mode(cf, logger) - if self.rope_mode != "none": + self.rope_2D = cf.get("rope_2D", False) + if self.rope_2D: self.num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens total_tokens = ( self.num_healpix_cells + self.num_extra_tokens @@ -137,31 +133,9 @@ def __init__(self, cf) -> None: dtype=self.dtype, ), ) - if self.rope_mode == "spherical": - rope_spherical_band = get_rope_spherical_band(cf) - num_modes = 2 * int(rope_spherical_band) + 1 - self.register_buffer( - "rope_spherical_coeffs", - torch.zeros(1, total_tokens, num_modes, 2, dtype=self.dtype), - ) - self.register_buffer( - "rope_spherical_cell_coeffs", - torch.zeros(self.num_healpix_cells, num_modes, 2, dtype=self.dtype), - ) - self.register_buffer( - "rope_spherical_extra_coeffs", - torch.zeros(self.num_extra_tokens, num_modes, 2, dtype=self.dtype), - ) - else: - self.rope_spherical_coeffs = None - self.rope_spherical_cell_coeffs = None - self.rope_spherical_extra_coeffs = None else: self.rope_coords = None self.rope_cell_coords = None - self.rope_spherical_coeffs = None - self.rope_spherical_cell_coeffs = None - self.rope_spherical_extra_coeffs = None # HEALPix neighbours hlc = self.healpix_level @@ -226,9 +200,12 @@ def reset_parameters(self, cf: Config) -> "ModelParams": dim_embed = cf.ae_global_dim_embed - if self.rope_mode != "none": + if self.rope_2D: + # Precompute per-cell center coordinates (lat, lon in radians) for 2D RoPE. + # Shape: (num_healpix_cells, ae_local_num_queries, 2) verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype) + # Per-cell coords for QueryAggregationEngine (no query expansion) self.rope_cell_coords.data.copy_(coords) coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) coords_flat = coords.flatten(0, 1).unsqueeze(0) @@ -236,36 +213,6 @@ def reset_parameters(self, cf: Config) -> "ModelParams": self.rope_coords.data.fill_(0.0) self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat) - if self.rope_mode == "spherical": - band = int(get_rope_spherical_band(cf)) - ( - (cell_real, cell_imag), - (extra_real, extra_imag), - (packed_extra_real, packed_extra_imag), - (packed_real, packed_imag), - ) = build_spherical_rope_coeff_tensors( - nside=2**self.healpix_level, - band=band, - num_local_queries=cf.ae_local_num_queries, - num_extra_tokens=self.num_extra_tokens, - device=self.rope_spherical_coeffs.device, - dtype=self.rope_spherical_coeffs.dtype, - ) - self.rope_spherical_cell_coeffs.data[..., 0].copy_(cell_real) - self.rope_spherical_cell_coeffs.data[..., 1].copy_(cell_imag) - self.rope_spherical_extra_coeffs.data[..., 0].copy_(extra_real) - self.rope_spherical_extra_coeffs.data[..., 1].copy_(extra_imag) - - self.rope_spherical_coeffs.data.fill_(0.0) - self.rope_spherical_coeffs.data[:, :offset, :, 0].copy_(packed_extra_real) - self.rope_spherical_coeffs.data[:, :offset, :, 1].copy_(packed_extra_imag) - self.rope_spherical_coeffs.data[ - :, offset : offset + packed_real.shape[1], :, 0 - ].copy_(packed_real) - self.rope_spherical_coeffs.data[ - :, offset : offset + packed_imag.shape[1], :, 1 - ].copy_(packed_imag) - # pe_global: always initialized. RoPE handles relative position in Q/K, but pe_global # provides per-cell token identity which is critical for masked cells that have no # content from local assimilation. Without it, masked cells are identical and the @@ -380,7 +327,7 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.forecast_engine: ForecastingEngine | IdentityEngine | None = None self.pred_heads = None self.q_cells: torch.Tensor | None = None - self.stream_names: list[str] = None + self.streams: dict[str, typing.Any] = cf.streams self.target_token_engines = None assert cf.get("forecast", {}).get("att_dense_rate", 1.0) == 1.0, ( @@ -443,11 +390,6 @@ def create(self) -> "Model": self.pred_heads = torch.nn.ModuleDict() # determine stream names once so downstream components use consistent keys - self.stream_names = [str(stream_cfg["name"]) for stream_cfg in cf.streams] - - for i_stream, _ in enumerate(cf.streams): - stream_name = self.stream_names[i_stream] - loss_terms = [ v.type for _, v in cf.training_config.losses.items() if v.get("enabled", True) ] @@ -457,9 +399,7 @@ def create(self) -> "Model": ] if "LossPhysical" in loss_terms: - for i_stream, si in enumerate(cf.streams): - stream_name = self.stream_names[i_stream] - + for i_stream, (stream_name, si) in enumerate(self.streams.items()): # skip decoder if channels are empty if is_stream_forcing(si): continue @@ -550,16 +490,14 @@ def create(self) -> "Model": ) # iterate again to setup shared spatial pred heads if specified in config - for i_stream, si in enumerate(cf.streams): - stream_name = self.stream_names[i_stream] - + for i_stream, (stream_name, si) in enumerate(self.streams.items()): # skip decoder if channels are empty if is_stream_forcing(si): continue pred_spatial_shared = si.get("pred_spatial_shared") if pred_spatial_shared is not None: - if pred_spatial_shared not in self.stream_names: + if pred_spatial_shared not in self.streams.keys(): msg = f"Stream {stream_name} has pred_spatial_shared={pred_spatial_shared}" msg += " but no stream with that name found." raise ValueError(msg) @@ -578,11 +516,8 @@ def create(self) -> "Model": pred_spatial_shared ] - idx_shared_s = [ - i for i, so in enumerate(cf.streams) if so["name"] == pred_spatial_shared - ] - assert (len(idx_shared_s)) == 1 - si_other = cf.streams[idx_shared_s[0]] + assert pred_spatial_shared in self.streams.keys() + si_other = self.streams[pred_spatial_shared] dims_embed = [ si_other["embed_target_coords"]["dim_embed"] for _ in range(num_layers + 1) ] @@ -657,9 +592,9 @@ def _reset_params(module): def print_num_parameters(self) -> None: """Print number of parameters for entire model and each module used to build the model""" - cf = self.cf num_params_embed = [ - get_num_parameters(self.encoder.embed_engine.embeds[name]) for name in self.stream_names + get_num_parameters(self.encoder.embed_engine.embeds[name]) + for name in self.streams.keys() ] num_params_total = get_num_parameters(self) num_params_ae_local = get_num_parameters(self.encoder.ae_local_engine.ae_local_blocks) @@ -682,17 +617,17 @@ def print_num_parameters(self) -> None: mdict = self.embed_target_coords num_params_embed_tcs = [ get_num_parameters(mdict[name]) if mdict and name in mdict else 0 - for name in self.stream_names + for name in self.streams.keys() ] mdict = self.target_token_engines num_params_tte = [ get_num_parameters(mdict[name]) if mdict and name in mdict else 0 - for name in self.stream_names + for name in self.streams.keys() ] mdict = self.pred_heads num_params_preds = [ get_num_parameters(mdict[name]) if mdict and name in mdict else 0 - for name in self.stream_names + for name in self.streams.keys() ] print("-----------------") @@ -701,7 +636,7 @@ def print_num_parameters(self) -> None: print(" Embedding networks:") [ print(" {} : {:,}".format(si["name"], np)) - for si, np in zip(cf.streams, num_params_embed, strict=False) + for si, np in zip(self.streams.values(), num_params_embed, strict=False) ] print(f" Local assimilation engine: {num_params_ae_local:,}") print(f" Local-global adapter: {num_params_ae_adapter:,}") @@ -712,16 +647,14 @@ def print_num_parameters(self) -> None: print(f" Forecast engine: {num_params_fe:,}") print(" coordinate embedding, prediction networks and prediction heads:") zps = zip( - cf.streams, + self.streams.keys(), num_params_embed_tcs, num_params_tte, num_params_preds, strict=False, ) - [ - print(" {} : {:,} / {:,} / {:,}".format(si["name"], np0, np1, np2)) - for si, np0, np1, np2 in zps - ] + for stream_name, np0, np1, np2 in zps: + print(f" {stream_name} : {np0:,} / {np1:,} / {np2:,}") print("-----------------") def tokens_to_latent_state(self, tokens_post_norm, tokens) -> LatentState: @@ -757,12 +690,6 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: # collapse along input step dimension tokens = tokens.reshape(shape).sum(axis=1) - rope_data = ( - model_params.rope_spherical_coeffs.unbind(dim=-1) - if model_params.rope_spherical_coeffs is not None - else model_params.rope_coords - ) - # Allow for pushforward trick p_fwd = self.cf.training_config.get("forecast", {}).get("pushforward", False) # roll-out in latent space, iterate and generate output over requested output steps @@ -770,10 +697,10 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: without_grad = p_fwd and self.training and step != max(batch.get_output_idxs()) if without_grad: # Pushforward mode: advance tokens without grad; no decoding with torch.no_grad(): - tokens = self.forecast_engine(tokens, step, coords=rope_data) + tokens = self.forecast_engine(tokens, step, model_params.rope_coords) continue - tokens = self.forecast_engine(tokens, step, coords=rope_data) + tokens = self.forecast_engine(tokens, step, model_params.rope_coords) # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) # latent predictions (raw and with SSL heads) @@ -847,7 +774,7 @@ def predict_decoders( tokens_nbors_lens[0] = 0 # pair with tokens from assimilation engine to obtain target tokens - for stream_name in self.stream_names: + for stream_name in self.streams.keys(): # extract target coords for current stream and fstep and convert to one tensor t_coords = [ batch.samples[i_b].streams_data[stream_name].target_coords[step] diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index ad1c54bee..411942a9f 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -7,18 +7,11 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import logging import math -from functools import lru_cache -import healpy as hp import numpy as np -import numpy.typing as npt import torch -# Suppress verbose healpy transform messages during spherical RoPE coefficient precomputation. -logging.getLogger("healpy").setLevel(logging.WARNING) - #################################################################################################### def positional_encoding_harmonic(x): @@ -180,202 +173,3 @@ def rotary_pos_emb_2d(q, k, coords, base=10000.0, unsqueeze_dim=1): cos, sin = rotary_embedding_2d(coords, q.shape[-1], base=base) return apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim) - -# Spherical RoPE -def _max_supported_spherical_band(dim_embed: int, num_heads: int) -> int: - head_dim = dim_embed // num_heads - max_complex = (head_dim - (head_dim % 2)) // 2 - return max(0, (max_complex - 1) // 2) - - -def get_rope_mode(cf, logger=None) -> str: - """Resolve RoPE mode, including temporary backwards compatibility for rope_2D.""" - - rope_mode = cf.get("rope_mode", "none") or "none" - rope_2d = cf.get("rope_2D", None) - if rope_2d is not None: - if logger is not None: - logger.warning( - "Config key 'rope_2D' is deprecated and will be removed. Use 'rope_mode' " - "with one of: none, 2d, spherical." - ) - if rope_mode == "none": - rope_mode = "2d" if rope_2d else "none" - return rope_mode - - -def get_rope_spherical_band(cf) -> int: - """Resolve spherical band index, supporting explicit config or automatic selection.""" - - rope_spherical_band = cf.get("rope_spherical_band", None) - if rope_spherical_band is not None: - return int(rope_spherical_band) - - candidates = [ - _max_supported_spherical_band(cf.ae_global_dim_embed, cf.ae_aggregation_num_heads), - _max_supported_spherical_band(cf.ae_global_dim_embed, cf.ae_global_num_heads), - ] - if cf.get("fe_num_blocks", 0) > 0: - candidates.append(_max_supported_spherical_band(cf.ae_global_dim_embed, cf.fe_num_heads)) - return min(candidates) - - -def apply_rope(qs, ks, coords, rope_mode, unsqueeze_dim): - rope_mode = rope_mode or "none" - if rope_mode == "none": - return qs, ks - if coords is None: - raise ValueError(f"coords must be provided when rope_mode={rope_mode}") - if rope_mode == "2d": - return rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=unsqueeze_dim) - if rope_mode == "spherical": - return rotary_pos_emb_spherical(qs, ks, coords, unsqueeze_dim=unsqueeze_dim) - raise ValueError(f"Unsupported rope_mode={rope_mode}") - - -def rotary_pos_emb_spherical( - q: torch.Tensor, - k: torch.Tensor, - coeffs: tuple[torch.Tensor, torch.Tensor], - unsqueeze_dim: int = 1, -): - """Apply spherical-harmonic RoPE-style modulation to q/k using precomputed coefficients. - - Both q and k are multiplied by Y_lm(omega) at their respective positions. Under the real-pair - representation of complex modes, the attention dot product is equivalent to - Re[sum_m Y_lm(omega_r) Y_lm*(omega_s) q_m k_m*]. - """ - - coeff_real, coeff_imag = coeffs - return ( - _apply_complex_modulation(q, coeff_real, coeff_imag, unsqueeze_dim), - _apply_complex_modulation(k, coeff_real, coeff_imag, unsqueeze_dim), - ) - - -def _apply_complex_modulation( - x: torch.Tensor, - coeff_real: torch.Tensor, - coeff_imag: torch.Tensor, - unsqueeze_dim: int, -) -> torch.Tensor: - coeff_real = coeff_real.unsqueeze(unsqueeze_dim).to(dtype=x.dtype) - coeff_imag = coeff_imag.unsqueeze(unsqueeze_dim).to(dtype=x.dtype) - num_complex = coeff_real.shape[-1] - max_complex = (x.shape[-1] - (x.shape[-1] % 2)) // 2 - if num_complex > max_complex: - raise ValueError( - f"Spherical RoPE requires {num_complex} complex modes but the head only supports " - f"{max_complex}. Reduce rope_spherical_band or increase the head dimension." - ) - num_rotary_dims = 2 * num_complex - if num_rotary_dims == 0: - return x - - x_rot = x[..., :num_rotary_dims].reshape(*x.shape[:-1], num_complex, 2) - x_real = x_rot[..., 0] - x_imag = x_rot[..., 1] - out_real = (x_real * coeff_real) - (x_imag * coeff_imag) - out_imag = (x_real * coeff_imag) + (x_imag * coeff_real) - out = torch.stack((out_real, out_imag), dim=-1).flatten(-2, -1) - if num_rotary_dims < x.shape[-1]: - out = torch.cat((out, x[..., num_rotary_dims:]), dim=-1) - return out - - -def build_spherical_rope_coeff_tensors( - nside: int, - band: int, - num_local_queries: int, - num_extra_tokens: int, - device=None, - dtype=torch.float32, -) -> tuple[ - tuple[torch.Tensor, torch.Tensor], - tuple[torch.Tensor, torch.Tensor], - tuple[torch.Tensor, torch.Tensor], - tuple[torch.Tensor, torch.Tensor], -]: - """Build spherical-RoPE coefficient tensors for cell-level, extra tokens, and packed tokens.""" - - real_maps, imag_maps = _healpy_band_maps(nside, band) - cell_real = torch.as_tensor(real_maps, device=device, dtype=dtype) - cell_imag = torch.as_tensor(imag_maps, device=device, dtype=dtype) - - extra_real = torch.ones( - num_extra_tokens, cell_real.shape[-1], device=cell_real.device, dtype=cell_real.dtype - ) - extra_imag = torch.zeros_like(extra_real) - packed_extra_real = ( - extra_real.unsqueeze(1).repeat(1, num_local_queries, 1).flatten(0, 1).unsqueeze(0) - ) - packed_extra_imag = ( - extra_imag.unsqueeze(1).repeat(1, num_local_queries, 1).flatten(0, 1).unsqueeze(0) - ) - - packed_real = cell_real.unsqueeze(1).repeat(1, num_local_queries, 1).flatten(0, 1).unsqueeze(0) - packed_imag = cell_imag.unsqueeze(1).repeat(1, num_local_queries, 1).flatten(0, 1).unsqueeze(0) - - return ( - (cell_real, cell_imag), - (extra_real, extra_imag), - (packed_extra_real, packed_extra_imag), - (packed_real, packed_imag), - ) - - - -@lru_cache(maxsize=32) -def _healpy_band_maps( - nside: int, band: int -) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: - """Precompute one spherical-harmonic band on the HEALPix grid using healpy. - - The returned columns store the complex coefficients Y_lm(omega) for fixed l=band and - m=-l,...,+l. These are the position factors used in spherical RoPE: - - q_m^omega = Y_lm(omega) q_m, k_m^omega = Y_lm(omega) k_m. - - The following attention dot product then implicitly forms - Y_lm(omega_r) Y_lm*(omega_s), matching the spherical harmonics addition-theorem - structure. - """ - - num_pixels = hp.nside2npix(nside) - real_maps = np.zeros((num_pixels, 2 * band + 1), dtype=np.float64) - imag_maps = np.zeros((num_pixels, 2 * band + 1), dtype=np.float64) - alm_size = hp.sphtfunc.Alm.getsize(band, band) - - for m in range(0, band + 1): - # healpy stores alm only for m >= 0 and alm2map reconstructs a real field. Setting - # a_lm=1 gives 2 Re[Y_lm] for m>0, while a_lm=i gives -2 Im[Y_lm]. We combine these - # two real maps below to recover the complex coefficient Y_lm itself. - alm_real = np.zeros(alm_size, dtype=np.complex128) - alm_real[hp.sphtfunc.Alm.getidx(band, band, m)] = 1.0 - real_map = hp.alm2map(alm_real, nside=nside, lmax=band, mmax=band, pol=False) - real_map = hp.reorder(real_map, r2n=True) - - if m == 0: - # Y_l0 is real, and healpy returns it directly because there is no -m counterpart - # to merge into the real map. - real_maps[:, band] = real_map - continue - - alm_imag = np.zeros(alm_size, dtype=np.complex128) - alm_imag[hp.sphtfunc.Alm.getidx(band, band, m)] = 1.0j - imag_map = hp.alm2map(alm_imag, nside=nside, lmax=band, mmax=band, pol=False) - imag_map = hp.reorder(imag_map, r2n=True) - - pos_idx = band + m - neg_idx = band - m - sign = -1.0 if m % 2 else 1.0 - - # Columns are ordered as m=-l,...,+l, hence band+m for +m and band-m for -m. - # The negative-order mode follows the standard convention - # Y_l,-m = (-1)^m Y_lm*. - real_maps[:, pos_idx] = real_map / 2.0 - imag_maps[:, pos_idx] = -imag_map / 2.0 - real_maps[:, neg_idx] = sign * real_map / 2.0 - imag_maps[:, neg_idx] = sign * imag_map / 2.0 - - return real_maps, imag_maps diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py index 8f878f900..6f63468c3 100644 --- a/src/weathergen/train/loss_modules/loss_module_physical.py +++ b/src/weathergen/train/loss_modules/loss_module_physical.py @@ -210,8 +210,7 @@ def compute_loss(self, preds: dict, targets: dict, metadata) -> LossValues: source2target_idxs, output_info, target2source_idxs, target_info = metadata # TODO: iterate over batch dimension - for stream_info in self.cf.streams: - stream_name = stream_info["name"] + for stream_name, stream_info in self.cf.streams.items(): # TODO: avoid this target_channels = ( stream_info.val_target_channels @@ -366,6 +365,7 @@ def _nested_dict(): for ch_n, output_step_dict in ch_dict.items(): if ch_n != "avg": for _, v in output_step_dict.items(): + v = 0.0 if type(v) is float and np.isnan(v) else v reordered_losses[stream_name][loss_fct_name]["avg"] += v count += 1 reordered_losses[stream_name][loss_fct_name]["avg"] /= count diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index f8cb9cafe..47bc74214 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -816,8 +816,9 @@ def _log_terminal(self, bidx: int, mini_epoch: int, stage: Stage): for key, value in losses_all.items(): if key.endswith("avg"): + val = np.nan if np.isnan(value).all() else f"{np.nanmean(value):0.4E}" logger.info( - f"{key} : {np.nanmean(value):0.4E} \t", + f"{key} : {val} \t", ) logger.info("\n") diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index b4a5f1279..15ec21167 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -27,15 +27,16 @@ DEFAULT_RUN_FILE = Path("./config/runs_plot_train.yml") MAX_FILENAME_LEN = 255 -LEGEND_FONT_SIZE = "x-small" _LEGEND_MAX_LABEL_LEN = 80 PLOT_DPI_VALUE = 150 def _add_legend( labels, + outside: bool, + font_size: str, + num_columns: int, ax=None, - legend_outside: bool = False, loc=None, bbox_to_anchor=None, **kwargs, @@ -49,25 +50,24 @@ def _add_legend( if ax is None: ax = plt.gca() + # avoid excessively long labels truncated = [ la if len(la) <= _LEGEND_MAX_LABEL_LEN else la[: _LEGEND_MAX_LABEL_LEN - 1] + "\u2026" for la in labels ] - n = len(truncated) - ncol = 1 if n <= 3 else (2 if n <= 8 else 3) if loc is None: - loc = "upper center" if legend_outside else "best" - if bbox_to_anchor is None and legend_outside: + loc = "upper center" if outside else "best" + if bbox_to_anchor is None and outside: bbox_to_anchor = (0.5, -0.13) legend_kwargs = { "loc": loc, - "ncol": ncol, - "fontsize": LEGEND_FONT_SIZE, + "ncol": num_columns, + "fontsize": font_size, "framealpha": 0.9, "edgecolor": "0.8", - "borderaxespad": 0.0, + "borderaxespad": 0.2, **kwargs, } if bbox_to_anchor is not None: @@ -222,7 +222,10 @@ def get_stream_names(run_id: str, model_path: Path | None = "./model"): """ # return col names from training (should be identical to validation) cf = config.load_run_config(run_id, None, model_path=model_path) - return [si["name"].replace(",", "").replace("/", "_").replace(" ", "_") for si in cf.streams] + return [ + stream_name.replace(",", "").replace("/", "_").replace(" ", "_") + for stream_name in cf.streams.keys() + ] #################################################################################################### @@ -233,6 +236,8 @@ def plot_lr( plot_dir: Path, x_axis: str = "samples", legend_outside: bool = False, + legend_font_size: str = "x-small", + legend_num_columns: int = 3, ): """ Plot learning rate curves of training runs. @@ -286,7 +291,12 @@ def plot_lr( plt.ylabel("lr") plt.xlabel(x_axis) plt.tight_layout() - _add_legend(legend_str, legend_outside=legend_outside) + _add_legend( + legend_str, + outside=legend_outside, + font_size=legend_font_size, + num_columns=legend_num_columns, + ) rstr = "".join([f"{r}_" for r in runs_ids]) if len(rstr) + 6 > MAX_FILENAME_LEN: @@ -307,6 +317,8 @@ def plot_loss_avg( stage=TRAIN, x_scale_log=False, legend_outside: bool = False, + legend_font_size: str = "x-small", + legend_num_columns: int = 3, ): prop_cycle = plt.rcParams["axes.prop_cycle"] colors = prop_cycle.by_key()["color"] + ["r", "g", "b", "k", "y", "m"] @@ -341,7 +353,12 @@ def plot_loss_avg( plt.ylabel("loss") plt.xlabel("step") plt.tight_layout() - _add_legend(legend_str, legend_outside=legend_outside) + _add_legend( + legend_str, + outside=legend_outside, + font_size=legend_font_size, + num_columns=legend_num_columns, + ) rstr = "".join([f"{r}_" for r in runs_ids]) if len(rstr) + len(f"{str(stage)}_avg.png") > MAX_FILENAME_LEN: @@ -370,6 +387,8 @@ def plot_loss_per_stream( y_lim: list[float] | None = None, x_scale_log: bool = False, legend_outside: bool = False, + legend_font_size: str = "x-small", + legend_num_columns: int = 3, ): """ Plot each stream in stream_names (using matching to data columns) for all run_ids @@ -411,6 +430,7 @@ def plot_loss_per_stream( legend_strs = [] min_val = np.finfo(np.float32).max max_val = 0.0 + title_col = None for mode in modes: legend_strs += [[]] linestyle = "-" if mode == "train" else ("--x" if len(modes) > 1 else "-x") @@ -433,6 +453,7 @@ def plot_loss_per_stream( if len(col_split) < 4: if stream_name in col: data_cols += [col] + title_col = col if title_col is None else title_col elif len(col_split) == 4: if ( col_split[1].lower() == stream_name.lower() @@ -440,6 +461,7 @@ def plot_loss_per_stream( and col_split[3] == channel ): data_cols += [col] + title_col = col if title_col is None else title_col elif len(col_split) == 5: if ( col_split[1].lower() == stream_name.lower() @@ -448,6 +470,7 @@ def plot_loss_per_stream( and int(col_split[4]) in forecast_steps ): data_cols += [col] + title_col = col if title_col is None else title_col for col in data_cols: x_vals = np.array(run_data_mode[x_col]) @@ -467,8 +490,6 @@ def plot_loss_per_stream( + run_data.run_id + " : " + runs_ids[run_data.run_id][1] - + ": " - + col ] # skip all-nan slices @@ -501,11 +522,19 @@ def plot_loss_per_stream( if x_lim is not None: plt.xlim(x_lim) - plt.title(stream_name + ": " + channel + " (" + ", ".join(modes) + ")") + # if len(title_col) == 0 : + # import code; code.interact( local=locals()) + title_loss = ".".join(title_col.split(".")[:-1]) + plt.title(title_loss + " (" + ", ".join(modes) + ")") plt.ylabel(err) plt.xlabel(x_axis if x_type == "step" else "rel. time [h]") plt.tight_layout() - _add_legend(legend_str, legend_outside=legend_outside) + _add_legend( + legend_str, + outside=legend_outside, + font_size=legend_font_size, + num_columns=legend_num_columns, + ) # construct file name run_ids_str = "".join([f"{r}_" for r in runs_ids]) @@ -544,6 +573,8 @@ def plot_loss_per_run( x_axis: str = "samples", x_scale_log: bool = False, legend_outside: bool = False, + legend_font_size: str = "x-small", + legend_num_columns: int = 3, ): """ Plot all stream_names (using matching to data columns) for given run_id @@ -594,17 +625,16 @@ def plot_loss_per_run( x_col = [c for _, c in enumerate(run_data_mode.columns) if x_axis in c][0] # find the cols of the requested metric (e.g. mse) for all streams - data_cols = [c for _, c in enumerate(run_data_mode.columns) if err in c] data_cols = [] for col in run_data_mode.columns: col_split = col.split(".") - if len(col_split) < 4: - continue - if col_split[2].lower() == err.lower() and col_split[3] == channels: + if ( + len(col_split) >= 4 + and col_split[2].lower() == err.lower() + and col_split[3] in channels + ): data_cols += [col] - data_cols = list(data_cols) - for _, col in enumerate(data_cols): for j, stream_name in enumerate(stream_names): if stream_name.lower() in col.lower(): @@ -638,7 +668,12 @@ def plot_loss_per_run( plt.ylabel("loss") plt.xlabel("samples") plt.tight_layout() - _add_legend(legend_str, legend_outside=legend_outside) + _add_legend( + legend_str, + outside=legend_outside, + font_size=legend_font_size, + num_columns=legend_num_columns, + ) sstr = "".join( [f"{r}_".replace(",", "").replace("/", "_").replace(" ", "_") for r in legend_str] @@ -775,6 +810,27 @@ def plot_train(args=None): action="store_true", help="Use log scale for the x-axis (produces log-log plots)", ) + parser.add_argument( + "--legend-font-size", + dest="legend_font_size", + default="x-small", + type=str, + help="Font size for the legend", + ) + parser.add_argument( + "--legend-num-columns", + dest="legend_num_columns", + default=3, + type=int, + help="Number of columns for the legend", + ) + parser.add_argument( + "--with-losses-per-run", + dest="with_losses_per_run", + default=False, + action="store_true", + help="Plot losses per run across channels and streams", + ) run_id_group = parser.add_mutually_exclusive_group() run_id_group.add_argument( @@ -833,15 +889,13 @@ def plot_train(args=None): from_run_id=run_id, mini_epoch=None, ) - for stream_info in cf.streams: - streams += [stream_info["name"]] + streams += list(cf.streams.keys()) # ensure items are unique streams = list(set(streams)) # remove "all" key that is a special flag and not an actual stream name streams.remove("all") # read logged data - runs_data = [ TrainLogger.read(run_id, model_path=model_base_dir, cols_patterns=streams) for run_id in runs_ids @@ -857,7 +911,15 @@ def plot_train(args=None): x_scale_log = args.log_x # plot learning rate - plot_lr(runs_ids, runs_data, runs_active, plot_dir=out_dir, legend_outside=args.legend_outside) + plot_lr( + runs_ids, + runs_data, + runs_active, + plot_dir=out_dir, + legend_outside=args.legend_outside, + legend_font_size=args.legend_font_size, + legend_num_columns=args.legend_num_columns, + ) # plot average loss plot_loss_avg( @@ -867,6 +929,8 @@ def plot_train(args=None): runs_active, stage=TRAIN, legend_outside=args.legend_outside, + legend_font_size=args.legend_font_size, + legend_num_columns=args.legend_num_columns, ) # compare different runs @@ -884,6 +948,8 @@ def plot_train(args=None): x_lim=args.per_stream_x_lim, y_lim=args.per_stream_y_lim, legend_outside=args.legend_outside, + legend_font_size=args.legend_font_size, + legend_num_columns=args.legend_num_columns, plot_dir=out_dir, ) plot_loss_per_stream( @@ -900,6 +966,8 @@ def plot_train(args=None): x_lim=args.per_stream_x_lim, y_lim=args.per_stream_y_lim, legend_outside=args.legend_outside, + legend_font_size=args.legend_font_size, + legend_num_columns=args.legend_num_columns, plot_dir=out_dir, ) plot_loss_per_stream( @@ -916,13 +984,28 @@ def plot_train(args=None): x_lim=args.per_stream_x_lim, y_lim=args.per_stream_y_lim, legend_outside=args.legend_outside, + legend_font_size=args.legend_font_size, + legend_num_columns=args.legend_num_columns, plot_dir=out_dir, ) # plot all cols for all run_ids - for run_id, run_data in zip(runs_ids, runs_data, strict=False): + if args.with_losses_per_run: + for run_id, run_data in zip(runs_ids, runs_data, strict=False): + plot_loss_per_run( + ["train", "val"], + run_id, + runs_ids[run_id], + run_data, + get_stream_names(run_id, model_path=model_base_dir), # limit to available streams + channels=args.channels, + plot_dir=out_dir, + legend_outside=args.legend_outside, + legend_font_size=args.legend_font_size, + legend_num_columns=args.legend_num_columns, + ) plot_loss_per_run( - ["train", "val"], + ["val"], run_id, runs_ids[run_id], run_data, @@ -930,17 +1013,9 @@ def plot_train(args=None): channels=args.channels, plot_dir=out_dir, legend_outside=args.legend_outside, + legend_font_size=args.legend_font_size, + legend_num_columns=args.legend_num_columns, ) - plot_loss_per_run( - ["val"], - run_id, - runs_ids[run_id], - run_data, - get_stream_names(run_id, model_path=model_base_dir), # limit to available streams - channels=args.channels, - plot_dir=out_dir, - legend_outside=args.legend_outside, - ) if __name__ == "__main__": diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 5f1550e42..53e4e551f 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -12,7 +12,6 @@ import logging import math import time -import traceback from collections import defaultdict from dataclasses import dataclass from pathlib import Path @@ -109,15 +108,18 @@ def add_logs( metrics: dict[str, float] = dict(num_samples=samples) if stage == "train": - metrics["loss_avg_mean"] = np.nanmean(avg_loss) + val = np.nan if np.isnan(avg_loss).all() else np.nanmean(avg_loss) + metrics["loss_avg_mean"] = val metrics["learning_rate"] = lr metrics["num_samples"] = int(samples) for key, value in losses_all.items(): - metrics[key] = np.nanmean(value) + val = np.nan if np.isnan(value).all() else np.nanmean(value) + metrics[key] = val for key, value in stddev_all.items(): - metrics[key] = np.nanmean(value) + val = np.nan if np.isnan(value).all() else np.nanmean(value) + metrics[key] = val self.log_metrics(stage, metrics) @@ -143,100 +145,20 @@ def read( run_id = cf.general.run_id result_dir_base = config.get_path_run(cf) - result_dir = result_dir_base / run_id - fname_log_train = result_dir / f"{run_id}_train_log.txt" - fname_log_val = result_dir / f"{run_id}_val_log.txt" - - # training # define cols for training - cols_train = ["dtime", "samples", "mse", "lr"] cols1 = [_weathergen_timestamp, "num_samples", "loss_avg_mean", "learning_rate"] cols1_patterns = ["loss_avg"] + cols_patterns - # read training log data - try: - with open(fname_log_train, "rb") as f: - log_train = np.loadtxt(f, delimiter=",") - log_train = log_train.reshape((log_train.shape[0] // len(cols_train), len(cols_train))) - except ( - TypeError, - AttributeError, - IndexError, - ZeroDivisionError, - ValueError, - ) as e: - _logger.warning( - ( - f"Warning: no training data loaded for run_id={run_id}", - "Data loading or reshaping failed — " - "possible format, dimension, or logic issue.", - f"Due to specific error: {e}", - ) - ) - except (FileNotFoundError, PermissionError, OSError) as e: - _logger.error( - ( - f"Error: no training data loaded for run_id={run_id}", - "File system error occurred while handling the log file.", - f"Due to specific error: {e}", - ) - ) - except Exception: - _logger.error( - ( - f"Error: no training data loaded for run_id={run_id}", - f"Due to exception with trace:\n{traceback.format_exc()}", - ) - ) - log_train = np.array([]) - - log_train_df = read_metrics(cf, run_id, "train", cols1, cols1_patterns, result_dir_base) + metrics_train = read_metrics(cf, run_id, "train", cols1, cols1_patterns, result_dir_base) # define cols for validation - cols_val = ["dtime", "samples"] cols2 = [_weathergen_timestamp, "num_samples"] cols2_patterns = ["loss_avg"] + cols_patterns - # read validation log data - try: - with open(fname_log_val, "rb") as f: - log_val = np.loadtxt(f, delimiter=",") - log_val = log_val.reshape((log_val.shape[0] // len(cols_val), len(cols_val))) - except ( - TypeError, - AttributeError, - IndexError, - ZeroDivisionError, - ValueError, - ) as e: - _logger.warning( - ( - f"Warning: no validation data loaded for run_id={run_id}", - "Data loading or reshaping failed — " - "possible format, dimension, or logic issue.", - f"Due to specific error: {e}", - ) - ) - except (FileNotFoundError, PermissionError, OSError) as e: - _logger.error( - ( - f"Error: no validation data loaded for run_id={run_id}", - "File system error occurred while handling the log file.", - f"Due to specific error: {e}", - ) - ) - except Exception: - _logger.error( - ( - f"Error: no validation data loaded for run_id={run_id}", - f"Due to exception with trace:\n{traceback.format_exc()}", - ) - ) - log_val = np.array([]) - metrics_val_df = read_metrics(cf, run_id, "val", cols2, cols2_patterns, result_dir_base) + metrics_val = read_metrics(cf, run_id, "val", cols2, cols2_patterns, result_dir_base) - return Metrics(run_id, "train", log_train_df, metrics_val_df, None) + return Metrics(run_id, "train", metrics_train, metrics_val, None) def read_metrics( diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index d21938dd5..b54bf5c73 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -51,19 +51,16 @@ def write_output( targets_coords_all += [[]] targets_times_all += [[]] targets_lens += [[]] - for stream_info in cf.streams: - sname = stream_info["name"] - + for sname in cf.streams.keys(): # handle spoof data: do not write since it might corrupt validation (spoofing invisible # there) if target_aux_out.physical[t_idx][sname]["is_spoof"][0]: - preds = model_output.get_physical_prediction(t_idx, sname) - preds_shape = preds[0].shape + targets = target_aux_out.physical[t_idx][sname]["target"] # for-loop to make sure we have a consistent number of samples - preds_s = [np.zeros((preds_shape[0], 0, preds_shape[2])) for _ in preds] - targets_s = [np.zeros((0, preds_shape[2])) for _ in preds] - t_coords_s = [np.zeros((0, 2)) for _ in preds] - t_times_s = [np.array([]).astype("datetime64[ns]") for _ in preds] + preds_s = [np.zeros((1, 0, t.shape[1])) for t in targets] + targets_s = [np.zeros((0, t.shape[1])) for t in targets] + t_coords_s = [np.zeros((0, 2)) for t in targets] + t_times_s = [np.array([]).astype("datetime64[ns]") for t in targets] else: preds = model_output.get_physical_prediction(t_idx, sname) @@ -125,7 +122,8 @@ def write_output( # more prep work # output stream names to be written, use specified ones or all if nothing specified - stream_names = [stream.name for stream in cf.streams] + stream_names = list(cf.streams.keys()) + stream_infos = list(cf.streams.values()) if val_cfg.get("output").get("streams") is not None: output_stream_names = val_cfg.output.streams else: @@ -134,10 +132,10 @@ def write_output( output_streams = {name: stream_names.index(name) for name in output_stream_names} _logger.debug(f"Using output streams: {output_streams} from streams: {stream_names}") - target_channels: list[list[str]] = [list(stream.val_target_channels) for stream in cf.streams] - source_channels: list[list[str]] = [list(stream.val_source_channels) for stream in cf.streams] + target_channels: list[list[str]] = [list(stream.val_target_channels) for stream in stream_infos] + source_channels: list[list[str]] = [list(stream.val_source_channels) for stream in stream_infos] - geoinfo_channels = [[] for _ in cf.streams] # TODO obtain channels + geoinfo_channels = [[] for _ in stream_infos] # TODO obtain channels # calculate global sample indices for this batch by offsetting by sample_start sample_start = batch_idx * batch_size diff --git a/tests/test_config.py b/tests/test_config.py index e04390341..74986a22b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -9,7 +9,7 @@ TEST_RUN_ID = "test123" SECRET_COMPONENT = "53CR3T" DUMMY_PRIVATE_CONF = { - "data_paths": ["/path/to/anmoi/data", "/path/to/observation/data"] + "data_paths": ["/path/to/anmoi/data", "/path/to/observation/data"], "secrets": { "my_big_secret": { "my_secret_id": f"{SECRET_COMPONENT}01234", @@ -276,12 +276,12 @@ def test_print_cf_no_secrets(config_fresh): @pytest.mark.parametrize("rel_path,cf", VALID_STREAMS) def test_load_streams(streams_dir, rel_path, cf): - expected = get_expected_config(*[*cf.items()][0]) + expected = get_expected_config(*next(cf.items())) write_stream_file(streams_dir / rel_path, OmegaConf.to_yaml(cf)) streams = config.load_streams(streams_dir) - assert all(is_equal(stream, expected) for stream in streams) + assert all(is_equal(stream, expected) for stream in streams.values()) @pytest.mark.parametrize("rel_path,cf", EXCLUDED_STREAMS) @@ -290,14 +290,14 @@ def test_load_streams_exclude_files(streams_dir, rel_path, cf): streams = config.load_streams(streams_dir) - assert streams == [] + assert streams == {} def test_load_empty_stream(streams_dir): write_stream_file(streams_dir / "empty.yml", "") streams = config.load_streams(streams_dir) - assert streams == [] + assert streams == {} def test_load_malformed_stream(streams_dir): @@ -323,9 +323,7 @@ def test_load_multiple_streams_content(streams_dir, rel_path, cf): streams = config.load_streams(streams_dir) - assert all( - is_equal(stream, stream_e) for stream, stream_e in zip(streams, expected, strict=True) - ) + assert all(is_equal(stream, expected) for stream, expected in zip(streams.values(), expected)) def test_load_duplicate_streams(streams_dir):