From 544a5927516bd31612c156585dfa8008b935c516 Mon Sep 17 00:00:00 2001 From: Tim Hunter Date: Wed, 17 Jun 2026 12:09:22 +0200 Subject: [PATCH] adding toy datasets --- config/streams/era5_toy/era5.yml | 39 ++ config/toy_era5.yml | 205 ++++++++++ config/toy_era5_private.yml | 2 + .../common/src/weathergen/common/config.py | 2 + scripts/era5_o96_2020_1pct.yaml | 22 + src/weathergen/model/attention.py | 385 ++++++++++++++---- src/weathergen/run_train.py | 25 +- src/weathergen/train/trainer.py | 2 +- 8 files changed, 586 insertions(+), 96 deletions(-) create mode 100644 config/streams/era5_toy/era5.yml create mode 100644 config/toy_era5.yml create mode 100644 config/toy_era5_private.yml create mode 100644 scripts/era5_o96_2020_1pct.yaml diff --git a/config/streams/era5_toy/era5.yml b/config/streams/era5_toy/era5.yml new file mode 100644 index 000000000..67566b84c --- /dev/null +++ b/config/streams/era5_toy/era5.yml @@ -0,0 +1,39 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5_toy: + type: anemoi + stream_id: 0 + filenames: ["era5-o96-2020-1pct-6h-v1.zarr"] + loss_weight: 1. + source_exclude: ["w_", "skt", "sp", "tcw", "cp", "tp"] + target_exclude: ["w_", "skt", "sp", "tcw", "cp", "tp"] + source: ["t_850", "z_850"] + target: ["t_850"] + diagnostic: False + masking_rate: 0.6 + masking_rate_none: 0.05 + token_size: 32 + tokenize_spacetime: True + max_num_targets: -1 + embed: + net: transformer + num_tokens: 1 + num_heads: 2 + dim_embed: 16 + num_blocks: 2 + embed_target_coords: + net: linear + dim_embed: 16 + target_readout: + num_layers: 2 + num_heads: 2 + pred_head: + ens_size: 1 + num_layers: 1 diff --git a/config/toy_era5.yml b/config/toy_era5.yml new file mode 100644 index 000000000..ebbd5f7cf --- /dev/null +++ b/config/toy_era5.yml @@ -0,0 +1,205 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 512 #1024 +ae_local_num_blocks: 2 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 512 #1024 #2048 +ae_global_num_blocks: 2 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 2 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 1 +num_register_tokens: 7 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 2 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 0.0 # 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 + +healpix_level: 4 + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: False +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: "" + +norm_type: "LayerNorm" + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +################ + +streams_directory: "./config/streams/era5_toy" +model_path: "./models" +results_path: "./results" + +general: + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + +# parameters for data loading +data_loading: + num_workers: 2 + rng_seed: ??? + repeat_data_in_mini_epoch: False + + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. + # If this happens, you can disable the flag, but performance will drop on GH200. + memory_pinning: True + +# config for training +training_config: + training_mode: ["masking"] + + num_mini_epochs: 1 + samples_per_mini_epoch: 64 + shuffle: True + + start_date: 2020-01-01T00:00 + end_date: 2020-01-02T06:00 + + time_window_step: 06:00:00 + time_window_len: 06:00:00 + + window_offset_prediction: 1 + + learning_rate_scheduling: + lr_start: 1e-6 + lr_max: 0.00005 + lr_final_decay: 1e-6 + lr_final: 0.0 + num_steps_warmup: 4 + num_steps_cooldown: 4 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw: + # parameters are scaled by number of DDP workers + beta1: 0.975 + beta2: 0.9875 + eps: 2e-08 + + losses: { "physical": { type: LossPhysical, loss_fcts: { "mse": {} } } } + + model_input: { "forecasting": { masking_strategy: "forecast" } } + + forecast: + time_step: 06:00:00 + num_steps: 1 + policy: "fixed" + offset: 1 + +# validation config; full validation config is merge of training and validation config +validation_config: + samples_per_mini_epoch: 32 + shuffle: False + + start_date: 2020-01-01T00:00 + end_date: 2020-01-02T06:00 + + output: + streams: ["ERA5_toy"] + + validate_with_ema: + enabled: True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + +test_config: + start_date: 2021-10-10T00:00 + end_date: 2022-10-11T00:00 + output: + num_samples: 10 + +# TODO: read latent from here +inference_config: + output: + streams: ["ERA5_toy"] diff --git a/config/toy_era5_private.yml b/config/toy_era5_private.yml new file mode 100644 index 000000000..9cf7e82c7 --- /dev/null +++ b/config/toy_era5_private.yml @@ -0,0 +1,2 @@ +path_shared_working_dir: "." +data_path_anemoi: "./datasets/" diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index 7737504cf..a7e1975e2 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -551,6 +551,7 @@ def _load_private_conf(private_home: Path | None = None) -> DictConfig: elif "WEATHERGEN_PRIVATE_CONF" in os.environ: private_home = Path(os.environ["WEATHERGEN_PRIVATE_CONF"]) + print(f"Loading private config from WEATHERGEN_PRIVATE_CONF:{private_home}.") _logger.info(f"Loading private config from WEATHERGEN_PRIVATE_CONF:{private_home}.") elif env_script_path.is_file(): @@ -587,6 +588,7 @@ def _load_private_conf(private_home: Path | None = None) -> DictConfig: "WEATHERGEN_PRIVATE_CONF or provide a path." ) private_cf = OmegaConf.load(private_home) + print(f"private_cf: {private_cf}") if "secrets" in private_cf: del private_cf["secrets"] diff --git a/scripts/era5_o96_2020_1pct.yaml b/scripts/era5_o96_2020_1pct.yaml new file mode 100644 index 000000000..b830de0e3 --- /dev/null +++ b/scripts/era5_o96_2020_1pct.yaml @@ -0,0 +1,22 @@ +description: | + Approximately 1% time slice of ECMWF Anemoi ERA5 O96 dataset. + +name: era5-o96-2020-1pct-6h-v1 +licence: CC-BY-4.0 +attribution: ECMWF + +dates: + start: "2020-01-01T00:00:00" + end: "2020-06-02T06:00:00" + frequency: 6h + +input: + anemoi-dataset: + dataset: "https://data.ecmwf.int/anemoi-datasets/era5-o96-1979-2023-6h-v8.zarr" + +output: + layout: gridded + dtype: float32 + +build: + group_by: 10 diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index bf97479e6..964b71b27 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -10,9 +10,20 @@ from functools import partial import torch -from flash_attn import flash_attn_func, flash_attn_varlen_func from torch.nn.attention.flex_attention import create_block_mask, flex_attention +try: + from flash_attn import ( # pyright: ignore[reportMissingImports] + flash_attn_func, + flash_attn_varlen_func, + ) +except ImportError: + flash_attn_func = None + flash_attn_varlen_func = None + FLASH_ATTN_AVAILABLE = False +else: + FLASH_ATTN_AVAILABLE = True + from weathergen.model.norms import AdaLayerNorm, RMSNorm from weathergen.model.positional_encoding import rotary_pos_emb_2d @@ -24,6 +35,135 @@ """ +def _maybe_to_flash_dtype(tensor, attention_dtype): + return tensor.to(attention_dtype) if FLASH_ATTN_AVAILABLE else tensor + + +def _match_attention_dtypes(qs, ks, vs): + if qs.dtype == ks.dtype == vs.dtype: + return qs, ks, vs + + dtype = torch.promote_types(torch.promote_types(qs.dtype, ks.dtype), vs.dtype) + return qs.to(dtype), ks.to(dtype), vs.to(dtype) + + +def _attention_output(output): + return output[0] if isinstance(output, tuple) else output + + +def _softcap_score_mod(softcap): + if softcap <= 0.0: + return None + + def score_mod(score, batch, head, q_idx, kv_idx): + return softcap * torch.tanh(score / softcap) + + return score_mod + + +def _normalise_varlen_lens(lens, total_tokens, name): + if lens is None: + raise ValueError(f"{name} must be provided for variable-length attention") + + lens = lens.to(dtype=torch.long) + if lens.numel() > 0 and int(lens[0].detach().cpu().item()) == 0: + lens_without_pad = lens[1:] + if int(lens_without_pad.sum().detach().cpu().item()) == total_tokens: + return lens_without_pad + + if int(lens.sum().detach().cpu().item()) != total_tokens: + raise ValueError( + f"{name} sums to {int(lens.sum().detach().cpu().item())}, " + f"but expected {total_tokens} tokens" + ) + return lens + + +def _pad_packed_sequence(x, lens): + lens_list = [int(length) for length in lens.detach().cpu().tolist()] + if len(lens_list) == 0: + return x.new_empty((0, x.shape[1], 0, x.shape[2])), lens_list + + max_len = max(lens_list) + if max_len == 0: + return x.new_empty((len(lens_list), x.shape[1], 0, x.shape[2])), lens_list + + chunks = list(torch.split(x, lens_list, dim=0)) + return torch.nn.utils.rnn.pad_sequence(chunks, batch_first=True).transpose(1, 2), lens_list + + +def _unpad_packed_sequence(x, lens_list): + if len(lens_list) == 0: + return x.new_empty((0, x.shape[1], x.shape[-1])) + + return torch.cat([x[i, :, :length].transpose(0, 1) for i, length in enumerate(lens_list)]) + + +def _dense_attention(qs, ks, vs, dropout_rate=0.0, softcap=0.0): + qs, ks, vs = _match_attention_dtypes(qs, ks, vs) + score_mod = _softcap_score_mod(softcap) + if score_mod is not None: + return _attention_output(flex_attention(qs, ks, vs, score_mod=score_mod)) + + return torch.nn.functional.scaled_dot_product_attention( + qs, + ks, + vs, + dropout_p=dropout_rate, + ) + + +def _varlen_attention(qs, ks, vs, q_lens, kv_lens, dropout_rate=0.0, softcap=0.0): + q_lens = _normalise_varlen_lens(q_lens, qs.shape[0], "q_lens") + kv_lens = _normalise_varlen_lens(kv_lens, ks.shape[0], "kv_lens") + + if qs.shape[0] == 0: + return qs.new_empty(qs.shape) + + qs, ks, vs = _match_attention_dtypes(qs, ks, vs) + qs, q_lens_list = _pad_packed_sequence(qs, q_lens) + ks, kv_lens_list = _pad_packed_sequence(ks, kv_lens) + vs, _ = _pad_packed_sequence(vs, kv_lens) + + max_q_len = qs.shape[-2] + max_kv_len = ks.shape[-2] + q_lens = q_lens.to(qs.device) + kv_lens = kv_lens.to(qs.device) + + score_mod = _softcap_score_mod(softcap) + if score_mod is not None: + + def mask_mod(batch, head, q_idx, kv_idx): + return (q_idx < q_lens[batch]) & (kv_idx < kv_lens[batch]) + + block_mask = create_block_mask( + mask_mod, + B=qs.shape[0], + H=None, + Q_LEN=max_q_len, + KV_LEN=max_kv_len, + device=str(qs.device), + ) + outs = _attention_output( + flex_attention(qs, ks, vs, score_mod=score_mod, block_mask=block_mask) + ) + else: + q_idx = torch.arange(max_q_len, device=qs.device) + kv_idx = torch.arange(max_kv_len, device=qs.device) + attn_mask = (q_idx[None, :, None] < q_lens[:, None, None]) & ( + kv_idx[None, None, :] < kv_lens[:, None, None] + ) + outs = torch.nn.functional.scaled_dot_product_attention( + qs, + ks, + vs, + attn_mask=attn_mask[:, None, :, :], + dropout_p=dropout_rate, + ) + + return _unpad_packed_sequence(outs, q_lens_list) + + class MultiSelfAttentionHeadVarlen(torch.nn.Module): def __init__( self, @@ -82,18 +222,15 @@ def __init__( self.dtype = attention_dtype - assert with_flash, "Only flash attention supported at the moment" - def forward(self, x, x_lens, ada_ln_aux=None, coords=None): - if self.with_residual: - x_in = x + x_in = x x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) # project onto heads and q,k,v and # ensure these are 4D tensors as required for flash attention s = [x.shape[0], self.num_heads, x.shape[-1] // self.num_heads] - qs = self.lnorm_q(self.proj_heads_q(x).reshape(s)).to(self.dtype) - ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype) + qs = _maybe_to_flash_dtype(self.lnorm_q(self.proj_heads_q(x).reshape(s)), self.dtype) + ks = _maybe_to_flash_dtype(self.lnorm_k(self.proj_heads_k(x).reshape(s)), self.dtype) vs = self.proj_heads_v(x).reshape(s) if self.with_2d_rope: @@ -104,19 +241,32 @@ def forward(self, x, x_lens, ada_ln_aux=None, coords=None): # set dropout rate according to training/eval mode as required by flash_attn dropout_rate = self.dropout_rate if self.training else 0.0 - cum_x_lens = torch.cumsum(x_lens, 0, dtype=torch.int32) - # ordering of tensors (seq, heads, embed) (which differs from torch's flash attention implt) - outs = flash_attn_varlen_func( - qs, - ks, - vs, - cum_x_lens, - cum_x_lens, - x_lens.max(), - x_lens.max(), - softcap=self.softcap, - dropout_p=dropout_rate, - ) + if FLASH_ATTN_AVAILABLE: + assert flash_attn_varlen_func is not None + cum_x_lens = torch.cumsum(x_lens, 0, dtype=torch.int32) + # ordering of tensors (seq, heads, embed) differs from torch's + # flash attention implementation. + outs = flash_attn_varlen_func( + qs, + ks, + vs, + cum_x_lens, + cum_x_lens, + x_lens.max(), + x_lens.max(), + softcap=self.softcap, + dropout_p=dropout_rate, + ) + else: + outs = _varlen_attention( + qs, + ks, + vs, + x_lens, + x_lens, + dropout_rate=dropout_rate, + softcap=self.softcap, + ) out = self.proj_out(outs.flatten(-2, -1)) @@ -176,29 +326,40 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype - assert with_flash, "Only flash attention supported at the moment" + def att(qs, ks, vs): + qs, ks, vs = _match_attention_dtypes(qs, ks, vs) - def att(qs, ks, vs, x_mask): - def sparsity_mask(score, b, h, q_idx, kv_idx): + def sparsity_mask(batch, head, q_idx, kv_idx): return (q_idx // 16) == (kv_idx % 16) - return flex_attention(qs, ks, vs, score_mod=sparsity_mask) + block_mask = create_block_mask( + sparsity_mask, + B=qs.shape[0], + H=None, + Q_LEN=qs.shape[-2], + KV_LEN=ks.shape[-2], + device=str(qs.device), + ) + return _attention_output(flex_attention(qs, ks, vs, block_mask=block_mask)) - self.compiled_flex_attention = torch.compile(att, dynamic=False) + self.compiled_flex_attention = att def forward(self, x, x_lens=None): - if self.with_residual: - x_in = x + x_in = x x = self.lnorm(x) # project onto heads and q,k,v and # ensure these are 4D tensors as required for flash attention s = [x.shape[0], 1, self.num_heads, -1] - qs = self.lnorm_q(self.proj_heads_q(x).reshape(s)).to(self.dtype).permute([1, 2, 0, 3]) - ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype).permute([1, 2, 0, 3]) + qs = _maybe_to_flash_dtype( + self.lnorm_q(self.proj_heads_q(x).reshape(s)), self.dtype + ).permute([1, 2, 0, 3]) + ks = _maybe_to_flash_dtype( + self.lnorm_k(self.proj_heads_k(x).reshape(s)), self.dtype + ).permute([1, 2, 0, 3]) vs = self.proj_heads_v(x).reshape(s).permute([1, 2, 0, 3]) - outs = self.compiled_flex_attention(qs, ks, vs).transpose(1, 2).squeeze() + outs = self.compiled_flex_attention(qs, ks, vs).transpose(1, 2).squeeze(0) out = self.dropout(self.proj_out(outs.flatten(-2, -1))) if self.with_residual: @@ -265,27 +426,26 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype - assert with_flash, "Only flash attention supported." + self.block_factor = block_factor + self.qkv_len = qkv_len - # define block mask def mask_block_local(batch, head, idx_q, idx_kv): - return (idx_q // block_factor) == (idx_kv // block_factor) + return (idx_q // self.block_factor) == (idx_kv // self.block_factor) - self.block_mask = create_block_mask( - mask_block_local, B=None, H=None, Q_LEN=qkv_len, KV_LEN=qkv_len - ) - # compile for efficiency - self.flex_attention = torch.compile(flex_attention, dynamic=False) + self.mask_block_local = mask_block_local def forward(self, x, coords=None, ada_ln_aux=None): - if self.with_residual: - x_in = x + x_in = x x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) # project onto heads s = [x.shape[0], x.shape[1], self.num_heads, -1] - qs = self.lnorm_q(self.proj_heads_q(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3]) - ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3]) + qs = _maybe_to_flash_dtype( + self.lnorm_q(self.proj_heads_q(x).reshape(s)), self.dtype + ).permute([0, 2, 1, 3]) + ks = _maybe_to_flash_dtype( + self.lnorm_k(self.proj_heads_k(x).reshape(s)), self.dtype + ).permute([0, 2, 1, 3]) vs = self.proj_heads_v(x).reshape(s).permute([0, 2, 1, 3]) if self.with_2d_rope: @@ -293,7 +453,24 @@ def forward(self, x, coords=None, ada_ln_aux=None): raise ValueError("coords must be provided when with_2d_rope=True") qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=1) - outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2) + qs, ks, vs = _match_attention_dtypes(qs, ks, vs) + block_mask = create_block_mask( + self.mask_block_local, + B=qs.shape[0], + H=None, + Q_LEN=qs.shape[-2], + KV_LEN=ks.shape[-2], + device=str(qs.device), + ) + outs = _attention_output( + flex_attention( + qs, + ks, + vs, + score_mod=_softcap_score_mod(self.softcap), + block_mask=block_mask, + ) + ).transpose(1, 2) out = self.proj_out(self.dropout(outs.flatten(-2, -1))) if self.with_residual: @@ -364,26 +541,30 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype - assert with_flash, "Only flash attention supported at the moment" def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): - if self.with_residual: - x_q_in = x_q + x_q_in = x_q x_q = self.lnorm_in_q(x_q) if ada_ln_aux is None else self.lnorm_in_q(x_q, ada_ln_aux) x_kv = self.lnorm_in_kv(x_kv) # project onto heads and q,k,v and # ensure these are 4D tensors as required for flash attention s = [x_q.shape[0], self.num_heads, self.dim_head_proj] - qs = self.lnorm_q(self.proj_heads_q(x_q).reshape(s)).to(self.dtype) + qs = _maybe_to_flash_dtype(self.lnorm_q(self.proj_heads_q(x_q).reshape(s)), self.dtype) s = [x_kv.shape[0], self.num_heads, self.dim_head_proj] - ks = self.lnorm_k(self.proj_heads_k(x_kv).reshape(s)).to(self.dtype) + ks = _maybe_to_flash_dtype(self.lnorm_k(self.proj_heads_k(x_kv).reshape(s)), self.dtype) vs = self.proj_heads_v(x_kv).reshape(s) # set dropout rate according to training/eval mode as required by flash_attn dropout_rate = self.dropout_rate if self.training else 0.0 - if x_kv_lens is not None: + if x_q_lens is None or x_kv_lens is None: + raise ValueError( + "x_q_lens and x_kv_lens must be provided for variable-length attention" + ) + + if FLASH_ATTN_AVAILABLE: + assert flash_attn_varlen_func is not None cum_x_q_lens = torch.cumsum(x_q_lens, 0, dtype=torch.int32) cum_x_kv_lens = torch.cumsum(x_kv_lens, 0, dtype=torch.int32) outs = flash_attn_varlen_func( @@ -398,7 +579,15 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): dropout_p=dropout_rate, ) else: - assert False + outs = _varlen_attention( + qs, + ks, + vs, + x_q_lens, + x_kv_lens, + dropout_rate=dropout_rate, + softcap=self.softcap, + ) outs = self.proj_out(outs.flatten(-2, -1)) if self.with_residual: @@ -477,11 +666,9 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype - assert with_flash, "Only flash attention supported at the moment" def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): - if self.with_residual: - x_q_in = x_q + x_q_in = x_q x_q = self.lnorm_in_q(x_q) if ada_ln_aux is None else self.lnorm_in_q(x_q, ada_ln_aux) x_kv = self.lnorm_in_kv(x_kv) @@ -489,32 +676,52 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): # ensure these are 4D tensors as required for flash attention s = [x_q.shape[0], self.num_heads, self.dim_head_proj] qs = [ - self.lnorm_q(head_proj(x_q_i).reshape(s)).to(self.dtype) + _maybe_to_flash_dtype(self.lnorm_q(head_proj(x_q_i).reshape(s)), self.dtype) for head_proj, x_q_i in zip(self.proj_heads_q, x_q.transpose(1, 0), strict=False) ] s = [x_kv.shape[0], self.num_heads, self.dim_head_proj] - ks = self.lnorm_k(self.proj_heads_k(x_kv).reshape(s)).to(self.dtype) + ks = _maybe_to_flash_dtype(self.lnorm_k(self.proj_heads_k(x_kv).reshape(s)), self.dtype) vs = self.proj_heads_v(x_kv).reshape(s) # set dropout rate according to training/eval mode as required by flash_attn dropout_rate = self.dropout_rate if self.training else 0.0 - cum_x_q_lens = torch.cumsum(x_q_lens, 0, dtype=torch.int32) - cum_x_kv_lens = torch.cumsum(x_kv_lens, 0, dtype=torch.int32) - outs = [] - for _i, qs_i in enumerate(qs): - outs += [ - flash_attn_varlen_func( + if x_q_lens is None or x_kv_lens is None: + raise ValueError( + "x_q_lens and x_kv_lens must be provided for variable-length attention" + ) + + if FLASH_ATTN_AVAILABLE: + assert flash_attn_varlen_func is not None + cum_x_q_lens = torch.cumsum(x_q_lens, 0, dtype=torch.int32) + cum_x_kv_lens = torch.cumsum(x_kv_lens, 0, dtype=torch.int32) + outs = [] + for _i, qs_i in enumerate(qs): + outs += [ + flash_attn_varlen_func( + qs_i, + ks, + vs, + cum_x_q_lens, + cum_x_kv_lens, + x_q_lens.max(), + x_kv_lens.max(), + softcap=self.softcap, + dropout_p=dropout_rate, + ) + ] + else: + outs = [ + _varlen_attention( qs_i, ks, vs, - cum_x_q_lens, - cum_x_kv_lens, - x_q_lens.max(), - x_kv_lens.max(), + x_q_lens, + x_kv_lens, + dropout_rate=dropout_rate, softcap=self.softcap, - dropout_p=dropout_rate, ) + for qs_i in qs ] outs = self.proj_out(torch.stack(outs).transpose(1, 0).flatten(-2, -1)) @@ -581,23 +788,18 @@ def __init__( self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype - if with_flash: - self.att = torch.nn.functional.scaled_dot_product_attention - else: - self.att = self.attention - self.softmax = torch.nn.Softmax(dim=-1) + self.att = torch.nn.functional.scaled_dot_product_attention def forward(self, x, coords=None, ada_ln_aux=None): - if self.with_residual: - x_in = x + x_in = x x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) # project onto heads and q,k,v and # ensure these are 4D tensors as required for flash attention s = [*([x.shape[0], 1] if len(x.shape) == 2 else x.shape[:-1]), self.num_heads, -1] - qs = self.lnorm_q(self.proj_heads_q(x).reshape(s)).to(self.dtype) - ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype) - vs = self.proj_heads_v(x).reshape(s).to(self.dtype) + qs = _maybe_to_flash_dtype(self.lnorm_q(self.proj_heads_q(x).reshape(s)), self.dtype) + ks = _maybe_to_flash_dtype(self.lnorm_k(self.proj_heads_k(x).reshape(s)), self.dtype) + vs = _maybe_to_flash_dtype(self.proj_heads_v(x).reshape(s), self.dtype) if self.with_2d_rope: if coords is None: @@ -607,8 +809,19 @@ def forward(self, x, coords=None, ada_ln_aux=None): # set dropout rate according to training/eval mode as required by flash_attn dropout_rate = self.dropout_rate if self.training else 0.0 - # ordering of tensors (seq, heads, embed) (which differs from torch's flash attention implt) - outs = flash_attn_func(qs, ks, vs, softcap=self.softcap, dropout_p=dropout_rate) + if FLASH_ATTN_AVAILABLE: + assert flash_attn_func is not None + # ordering of tensors (seq, heads, embed) differs from torch's + # flash attention implementation. + outs = flash_attn_func(qs, ks, vs, softcap=self.softcap, dropout_p=dropout_rate) + else: + outs = _dense_attention( + qs.transpose(-3, -2), + ks.transpose(-3, -2), + vs.transpose(-3, -2), + dropout_rate=dropout_rate, + softcap=self.softcap, + ).transpose(-3, -2) out = self.proj_out(outs.flatten(-2, -1)) if self.with_residual: @@ -678,21 +891,23 @@ def __init__( ######################################### def forward(self, x_q, x_kv): - if self.with_residual: - x_q_in = x_q + x_q_in = x_q x_q, x_kv = self.lnorm_in_q(x_q), self.lnorm_in_kv(x_kv) # project onto heads and q,k,v and # ensure these are 4D tensors as required for flash attention s = [x_q.shape[0], -1, self.num_heads, self.dim_head_proj] - qs = self.lnorm_q(self.proj_heads_q(x_q).reshape(s)).to(self.dtype).transpose(-3, -2) + qs = _maybe_to_flash_dtype( + self.lnorm_q(self.proj_heads_q(x_q).reshape(s)), self.dtype + ).transpose(-3, -2) s = [x_kv.shape[0], -1, self.num_heads, self.dim_head_proj] - ks = self.lnorm_k(self.proj_heads_k(x_kv).reshape(s)).to(self.dtype).transpose(-3, -2) - vs = self.proj_heads_v(x_kv).reshape(s).transpose(-3, -2) + ks = _maybe_to_flash_dtype( + self.lnorm_k(self.proj_heads_k(x_kv).reshape(s)), self.dtype + ).transpose(-3, -2) + vs = _maybe_to_flash_dtype(self.proj_heads_v(x_kv).reshape(s), self.dtype).transpose(-3, -2) # correct ordering of tensors with seq dimension second but last is critical - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): - outs = self.att(qs, ks, vs).transpose(2, 1) + outs = _dense_attention(qs, ks, vs).transpose(2, 1) outs = self.dropout(self.proj_out(outs.flatten(-2, -1))) if self.with_residual: diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index 7995b5864..12a032f6d 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -137,6 +137,8 @@ def run_continue(args): mp_method = cf.general.get("multiprocessing_method", "fork") devices = Trainer.init_torch(multiprocessing_method=mp_method) + if not isinstance(devices, list): + devices = [devices] cf = Trainer.init_ddp(cf) init_loggers(cf.general.run_id) @@ -151,8 +153,8 @@ def run_continue(args): except Exception: extype, value, tb = sys.exc_info() traceback.print_exc() - if cf.world_size == 1: - pdb.post_mortem(tb) + # if cf.world_size == 1: + # pdb.post_mortem(tb) def run_train(args): @@ -172,6 +174,8 @@ def run_train(args): cf.data_loading.rng_seed = int(time.time()) mp_method = cf.general.get("multiprocessing_method", "fork") devices = Trainer.init_torch(multiprocessing_method=mp_method) + if not isinstance(devices, list): + devices = [devices] cf = Trainer.init_ddp(cf) # this line should probably come after the processes have been sorted out else we get lots @@ -186,14 +190,15 @@ def run_train(args): assert cf.with_mixed_precision trainer = Trainer(cf.train_logging) - - try: - trainer.run(cf, devices) - except Exception: - extype, value, tb = sys.exc_info() - traceback.print_exc() - if cf.world_size == 1: - pdb.post_mortem(tb) + trainer.run(cf, devices) + + # try: + # trainer.run(cf, devices) + # except Exception: + # extype, value, tb = sys.exc_info() + # traceback.print_exc() + # if cf.world_size == 1: + # pdb.post_mortem(tb) if __name__ == "__main__": diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index f396e611c..92843f227 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -796,7 +796,7 @@ def _log_terminal(self, bidx: int, mini_epoch: int, stage: Stage): if is_root(): if stage == VAL: logger.info( - f"""validation ({self.cf.general.run_id}) : {mini_epoch:03d} : + f"""validation ({self.cf.general.run_id}) : {mini_epoch:03d} : {np.nanmean(avg_loss)}""" )