From 9af4d0bf4bcca84b13dea2cf52eb23a89148a854 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Tue, 9 Jun 2026 16:43:12 +0200 Subject: [PATCH 1/3] Implement first prototype to be tested --- src/weathergen/datasets/data_reader_anemoi.py | 61 ++++++++++------- src/weathergen/datasets/data_reader_fesom.py | 30 +++++---- src/weathergen/datasets/data_reader_obs.py | 29 +++++---- src/weathergen/train/trainer.py | 65 ++++++++++++++++++- 4 files changed, 138 insertions(+), 47 deletions(-) diff --git a/src/weathergen/datasets/data_reader_anemoi.py b/src/weathergen/datasets/data_reader_anemoi.py index 6d4237fd3..4fc3c496d 100644 --- a/src/weathergen/datasets/data_reader_anemoi.py +++ b/src/weathergen/datasets/data_reader_anemoi.py @@ -259,28 +259,43 @@ def select_channels(self, ds0: anemoi_datasets, ch_type: str) -> NDArray[np.int6 channels = self.stream_info.get(ch_type) channels_exclude = self.stream_info.get(ch_type + "_exclude", []) + stream_name = self.stream_info["name"] # sanity check - is_empty = len(channels) == 0 if channels is not None else False - if is_empty: - stream_name = self.stream_info["name"] + if channels is not None and len(channels) == 0: _logger.warning(f"No channel for {stream_name} for {ch_type}.") - chs_idx = np.sort( - [ - ds0.name_to_index[k] - for (k, v) in ds0.typed_variables.items() - if ( - not v.is_computed_forcing - and not v.is_constant_in_time - and ( - np.array([f == k for f in channels]).any() if channels is not None else True - ) - and not np.array([f == k for f in channels_exclude]).any() + # Variables eligible for selection: physical fields only (no computed forcings, + # no constants-in-time), and not explicitly excluded. 
+ eligible = { + k + for k, v in ds0.typed_variables.items() + if not v.is_computed_forcing + and not v.is_constant_in_time + and not any(ex == k for ex in channels_exclude) + } + + if channels is not None: + # Respect the order given in the stream config so the channel layout is identical + # across datasets that share channels (e.g. ERA5 vs operational analysis), + # regardless of each dataset's on-disk variable order. + seen: set[str] = set() + ordered = [] + for k in channels: + if k in eligible and k not in seen: + ordered.append(k) + seen.add(k) + missing = [k for k in channels if k not in eligible] + if missing: + _logger.warning( + f"{stream_name}: requested {ch_type} channels not available " + f"(excluded/forcing/constant-in-time or absent), skipped: {missing}" ) - ] - ) + else: + # No explicit selection: fall back to deterministic lexicographic order so the + # layout is still reproducible across datasets. + ordered = sorted(eligible) - return np.array(chs_idx, dtype=np.int64) + return np.array([ds0.name_to_index[k] for k in ordered], dtype=np.int64) def select_geoinfo_channels(self, ds0: anemoi_datasets) -> NDArray[np.int64]: """ @@ -302,19 +317,19 @@ def select_geoinfo_channels(self, ds0: anemoi_datasets) -> NDArray[np.int64]: if len(geoinfo_channels) == 0: return np.array([], dtype=np.int64) - # Select channels that match the geoinfo list (exact match required) - chs_idx = np.sort( - [ds0.name_to_index[k] for k in ds0.typed_variables.keys() if k in geoinfo_channels] - ) + # Select channels that match the geoinfo list (exact match required), preserving the + # order requested in the config so the geoinfo layout is dataset-independent. + available = set(ds0.typed_variables.keys()) + ordered = [k for k in geoinfo_channels if k in available] - if len(chs_idx) == 0 and len(geoinfo_channels) > 0: + if len(ordered) == 0: stream_name = self.stream_info["name"] _logger.warning( f"No matching geoinfo channels found for {stream_name}. 
" f"Requested: {geoinfo_channels}" ) - return np.array(chs_idx, dtype=np.int64) + return np.array([ds0.name_to_index[k] for k in ordered], dtype=np.int64) def _clip_lat(lats: NDArray) -> NDArray[np.float32]: diff --git a/src/weathergen/datasets/data_reader_fesom.py b/src/weathergen/datasets/data_reader_fesom.py index b37352a7e..2067ea882 100644 --- a/src/weathergen/datasets/data_reader_fesom.py +++ b/src/weathergen/datasets/data_reader_fesom.py @@ -394,20 +394,26 @@ def select( ch_filters: list[str] | None, excl: list[str] | None = None, ) -> tuple[list[str], NDArray]: - if excl and ch_filters: - mask = [ - any(f == c for f in ch_filters) and all(ex not in c for ex in excl) - for c in colnames - ] - elif ch_filters: - mask = [any(f == c for f in ch_filters) for c in colnames] - elif excl: - mask = [all(ex not in c for ex in excl) for c in colnames] + excl = excl or [] + name_to_pos = {c: i for i, c in enumerate(colnames)} + + if ch_filters: + # Respect config order (exact match) so the channel layout is identical across + # datasets that share channels, regardless of each dataset's column order. + seen: set[str] = set() + ordered = [] + for f in ch_filters: + if f in name_to_pos and f not in seen and all(ex not in f for ex in excl): + ordered.append(f) + seen.add(f) else: - assert False, "Cannot use select with both ch_filters and excl as None" + assert excl, "Cannot use select with both ch_filters and excl as None" + # No explicit selection: deterministic lexicographic order of non-excluded columns. 
+ ordered = sorted(c for c in colnames if all(ex not in c for ex in excl)) - selected_cols_idx = cols_idx[np.where(mask)[0]] - selected_colnames = [colnames[i] for i in np.where(mask)[0]] + positions = [name_to_pos[c] for c in ordered] + selected_cols_idx = cols_idx[positions] + selected_colnames = [colnames[i] for i in positions] return selected_colnames, selected_cols_idx @override diff --git a/src/weathergen/datasets/data_reader_obs.py b/src/weathergen/datasets/data_reader_obs.py index 62b1dcfba..461adcc48 100644 --- a/src/weathergen/datasets/data_reader_obs.py +++ b/src/weathergen/datasets/data_reader_obs.py @@ -113,18 +113,25 @@ def select_channels( """ Allow user to specify which columns they want to access. Get functions only returned for these specified columns. - """ - selected_colnames = [ - c - for c in colnames - if ( - np.array([c_sel in c for c_sel in cols_select]).any() - if cols_select is not None - else True and not np.array([c_nsel in c for c_nsel in cols_exclude]).any() - ) - ] - return selected_colnames + When ``cols_select`` is given, the returned columns follow the order of the select + filters (config order) so the channel layout is identical across datasets that share + channels, regardless of each dataset's column order. Without a selection, columns fall + back to deterministic lexicographic order. + """ + cols_exclude = cols_exclude or [] + + if cols_select is not None: + # Respect config order: group matching columns by the order of the select filters. + # Matching is substring-based (a filter may match several columns). 
+ selected_colnames: list[str] = [] + for c_sel in cols_select: + for c in colnames: + if c_sel in c and c not in selected_colnames: + selected_colnames.append(c) + return selected_colnames + + return sorted(c for c in colnames if not any(c_nsel in c for c_nsel in cols_exclude)) def first_sample_with_data(self) -> int: """ diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 0b3c504db..b60bb26f4 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -256,11 +256,69 @@ def inference(self, cf, devices, run_id_contd, mini_epoch_contd): logger.info(f"Starting inference with id={self.cf.general.run_id}.") - # inference validation set self.validate(0, self.test_cfg, self.batch_size_test_per_gpu) logger.info(f"Finished inference run with id: {cf.general.run_id}") + def _check_channel_order_consistency( + self, + dataset: MultiStreamDataSampler, + from_run_id: str, + mini_epoch: int | None, + stage: Stage, + ) -> None: + """Guard against silently scrambling the channel<->weight mapping when continuing. + + Compares the channel order resolved for the current data against the order stored in + the checkpoint's config and raises if they differ. Streams (or channel lists) absent + from the checkpoint config cannot be verified and are skipped with a warning. + """ + try: + prev_cf = config.load_run_config(from_run_id, mini_epoch, None) + except FileNotFoundError: + logger.warning( + f"Could not load config for run_id '{from_run_id}' to verify channel order; " + "skipping channel-order consistency check." 
+ ) + return + + prev_streams = {s["name"]: s for s in prev_cf.get("streams", [])} + src_key = f"{stage}_source_channels" + tgt_key = f"{stage}_target_channels" + + mismatches: list[str] = [] + for name, readers in dataset.streams_datasets.items(): + prev = prev_streams.get(name) + if prev is None: + continue + reader = readers[0] + for key, resolved in ( + (src_key, list(reader.source_channels)), + (tgt_key, list(reader.target_channels)), + ): + stored = prev.get(key) + if stored is None: + logger.warning( + f"Checkpoint '{from_run_id}' has no '{key}' for stream '{name}'; " + "cannot verify channel order for it." + ) + continue + if list(stored) != resolved: + mismatches.append( + f" [{name}] {key}:\n" + f" checkpoint: {list(stored)}\n" + f" current: {resolved}" + ) + + if mismatches: + details = "\n".join(mismatches) + raise ValueError( + f"Channel order/content differs from the checkpoint being continued " + f"(run_id='{from_run_id}'). Continuing would scramble the learned " + f"channel<->weight mapping. 
Align the stream configs (channel order matters):\n" + f"{details}" + ) + def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # general initalization self.init(cf, devices) @@ -276,6 +334,11 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): self.dataset = MultiStreamDataSampler(cf, self.training_cfg, stage=TRAIN) self.dataset_val = MultiStreamDataSampler(cf, self.validation_cfg, stage=VAL) + if run_id_contd is not None: + self._check_channel_order_consistency( + self.dataset, run_id_contd, mini_epoch_contd, TRAIN + ) + loader_params = { "batch_size": None, "batch_sampler": None, From 7e26241caedd0474489a6f1d85e17ed88003cc10 Mon Sep 17 00:00:00 2001 From: Sophie Xhonneux Date: Tue, 9 Jun 2026 17:44:10 +0200 Subject: [PATCH 2/3] Add geoinfo check and stream dir w/ era5 ch order --- config/config_jepa_multi_data_ft_forecast.yml | 2 +- .../analysis.yml | 39 +++++ .../avhrr.yml | 28 ++++ .../era5.yml | 36 +++++ .../geos.yml | 140 ++++++++++++++++++ .../datasets/multi_stream_data_sampler.py | 9 +- src/weathergen/train/trainer.py | 9 +- 7 files changed, 256 insertions(+), 7 deletions(-) create mode 100644 config/streams/jepa_forecast_multi_data_od_ckpt_order/analysis.yml create mode 100644 config/streams/jepa_forecast_multi_data_od_ckpt_order/avhrr.yml create mode 100644 config/streams/jepa_forecast_multi_data_od_ckpt_order/era5.yml create mode 100644 config/streams/jepa_forecast_multi_data_od_ckpt_order/geos.yml diff --git a/config/config_jepa_multi_data_ft_forecast.yml b/config/config_jepa_multi_data_ft_forecast.yml index f824a9147..2f50937a1 100644 --- a/config/config_jepa_multi_data_ft_forecast.yml +++ b/config/config_jepa_multi_data_ft_forecast.yml @@ -1,4 +1,4 @@ -streams_directory: "./config/streams/jepa_forecast_multi_data_od/" +streams_directory: "./config/streams/jepa_forecast_multi_data_od_ckpt_order/" general: diff --git a/config/streams/jepa_forecast_multi_data_od_ckpt_order/analysis.yml 
b/config/streams/jepa_forecast_multi_data_od_ckpt_order/analysis.yml new file mode 100644 index 000000000..c316696cc --- /dev/null +++ b/config/streams/jepa_forecast_multi_data_od_ckpt_order/analysis.yml @@ -0,0 +1,39 @@ +# Operational-analysis (aifs-od-an-oper) input stream for temporal JEPA. +# Channel lists match era5.yml in this directory; unlike that file, this stream +# (ERA5_in) is marked `forcing: True`. + +ERA5_in: + type: anemoi + # filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] #['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + filenames: ['aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr'] + stream_id: 0 + source: ['10u', '10v', '2d', '2t', 'msl', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925', 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925'] + target: ['10u', '10v', '2d', '2t', 'msl', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925', 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925'] + geoinfo_channels: ['cos_julian_day', 'cos_local_time', 'insolation', 'lsm', 'noise_time', 'sdor', 
'sin_julian_day', 'sin_local_time', 'slor', 'z'] + loss_weight: 1.0 + forcing: True + location_weight: cosine_latitude + token_size: 8 + tokenize_spacetime: True + max_num_targets: -1 + frequency: 06:00:00 + nominal_time_mapping : + "0" : 5 # 04:30:00 + "6" : 9 # 09:00:00 + "12" : 17 #16:30:00 + "18" : 21 #21:00:00 + embed: + net: transformer + num_tokens: 1 + num_heads: 8 + dim_embed: 512 + num_blocks: 2 + embed_target_coords: + net: linear + dim_embed: 256 + target_readout: + num_layers: 2 + num_heads: 4 + pred_head: + ens_size: 1 + num_layers: 1 diff --git a/config/streams/jepa_forecast_multi_data_od_ckpt_order/avhrr.yml b/config/streams/jepa_forecast_multi_data_od_ckpt_order/avhrr.yml new file mode 100644 index 000000000..14f2589e4 --- /dev/null +++ b/config/streams/jepa_forecast_multi_data_od_ckpt_order/avhrr.yml @@ -0,0 +1,28 @@ +# Polar-orbiting observation stream for temporal JEPA. + +METOP_ABC_AVHRR_IASI: + type: obs + stream_id: 20 + filenames: ['observations-ea-ofb-0001-2007-2021-metop-a-iasi-radiances-v1.zarr', 'observations-ea-ofb-0001-2013-2023-metop-b-iasi-radiances-v1.zarr', 'observations-ea-ofb-0001-2019-2023-metop-c-iasi-radiances-v1.zarr'] + geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'noise_time'] + source: ['obsvalue_avhrr_mean_vis_0', 'obsvalue_rawbt_16', 'obsvalue_rawbt_63', 'obsvalue_rawbt_138', 'obsvalue_rawbt_170', 'obsvalue_rawbt_185', 'obsvalue_rawbt_224', 'obsvalue_rawbt_249', 'obsvalue_rawbt_271', 'obsvalue_rawbt_445', 'obsvalue_rawbt_756', 'obsvalue_rawbt_867', 'obsvalue_rawbt_921', 'obsvalue_rawbt_2907', 'obsvalue_rawbt_2991', 'obsvalue_rawbt_3093', 'obsvalue_rawbt_3160', 'obsvalue_rawbt_5383'] + target: ['obsvalue_avhrr_mean_vis_0', 'obsvalue_rawbt_16', 'obsvalue_rawbt_63', 'obsvalue_rawbt_138', 'obsvalue_rawbt_170', 'obsvalue_rawbt_185', 'obsvalue_rawbt_224', 'obsvalue_rawbt_249', 'obsvalue_rawbt_271', 'obsvalue_rawbt_445', 'obsvalue_rawbt_756', 'obsvalue_rawbt_867', 
'obsvalue_rawbt_921', 'obsvalue_rawbt_2907', 'obsvalue_rawbt_2991', 'obsvalue_rawbt_3093', 'obsvalue_rawbt_3160', 'obsvalue_rawbt_5383'] + forcing: True + loss_weight: 1.0 + token_size: 512 + tokenize_spacetime: False + embed: + net: transformer + num_tokens: 1 + num_heads: 2 + dim_embed: 256 + num_blocks: 2 + embed_target_coords: + net: linear + dim_embed: 256 + target_readout: + num_layers: 2 + num_heads: 4 + pred_head: + ens_size: 1 + num_layers: 1 diff --git a/config/streams/jepa_forecast_multi_data_od_ckpt_order/era5.yml b/config/streams/jepa_forecast_multi_data_od_ckpt_order/era5.yml new file mode 100644 index 000000000..9e22ee48e --- /dev/null +++ b/config/streams/jepa_forecast_multi_data_od_ckpt_order/era5.yml @@ -0,0 +1,36 @@ +# ERA5 input stream for temporal JEPA. +# This mirrors era5_georing_avhrr_forecast_random_inputs/era5.yml, but is not marked +# forcing so the teacher metadata mask remains populated for SSL targets. + +ERA5: + type: anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] #['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + # filenames: ['aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr'] + stream_id: 0 + source: ['10u', '10v', '2d', '2t', 'msl', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925', 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925'] + target: ['10u', '10v', '2d', '2t', 'msl', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 't_100', 't_1000', 't_150', 't_200', 't_250', 
't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925', 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925'] + geoinfo_channels: ['cos_julian_day', 'cos_local_time', 'insolation', 'lsm', 'noise_time', 'sdor', 'sin_julian_day', 'sin_local_time', 'slor', 'z'] + loss_weight: 1.0 + diagnostic: True + location_weight: cosine_latitude + token_size: 8 + tokenize_spacetime: False + max_num_targets: -1 + frequency: 01:00:00 + embed: + net: transformer + num_tokens: 1 + num_heads: 8 + dim_embed: 512 + num_blocks: 2 + embed_target_coords: + net: linear + dim_embed: 256 + target_readout: + num_layers: 2 + num_heads: 4 + pred_head: + ens_size: 1 + num_layers: 1 + + diff --git a/config/streams/jepa_forecast_multi_data_od_ckpt_order/geos.yml b/config/streams/jepa_forecast_multi_data_od_ckpt_order/geos.yml new file mode 100644 index 000000000..04f9abb41 --- /dev/null +++ b/config/streams/jepa_forecast_multi_data_od_ckpt_order/geos.yml @@ -0,0 +1,140 @@ +# Geostationary observation streams for temporal JEPA. +# These mirror era5_georing_avhrr_forecast_random_inputs/geos.yml. NOTE: every +# stream below is still marked `forcing: True`; forcing has not been removed here. 
+ +METEOSAT_SEVIRI_IR: + type: obs + stream_id: 10 + filenames: ['observations-file-2014-2024-seviri-o256-wegen-v3.zarr'] + geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza', 'noise_time'] + source: ['obsvalue_rawbt_065', 'obsvalue_rawbt_086', 'obsvalue_rawbt_160', 'obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133'] + target: ['obsvalue_rawbt_065', 'obsvalue_rawbt_086', 'obsvalue_rawbt_160', 'obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133'] + forcing: True + loss_weight: 1.0 + token_size: 1024 + tokenize_spacetime: False + max_num_targets: 262144 + embed: + net: transformer + num_tokens: 1 + num_heads: 4 + dim_embed: 512 + num_blocks: 2 + embed_target_coords: + net: linear + dim_embed: 512 + target_readout: + num_layers: 2 + num_heads: 4 + pred_head: + ens_size: 1 + num_layers: 1 + +GOES_ABI_IR: + type: obs + stream_id: 11 + filenames: ['observations-file-2017-2024-abi-goes16-IR-o256-v2.zarr'] + geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza', 'noise_time'] + source: ['obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133'] + target: ['obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133'] + forcing: True + loss_weight: 1.0 + token_size: 1024 + tokenize_spacetime: False + max_num_targets: 262144 + embed: + net: transformer + num_tokens: 1 + num_heads: 4 + dim_embed: 512 + num_blocks: 2 + embed_target_coords: + net: linear + dim_embed: 512 + target_readout: + 
num_layers: 2 + num_heads: 4 + pred_head: + ens_size: 1 + num_layers: 1 + +HIMAWARI_AHI_IR: + type: obs + stream_id: 12 + filenames: ['observations-file-2015-2022-himawari8-IR-o256-v1.zarr', 'observations-file-2022-2024-himawari9-IR-o256-v1.zarr'] + geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza', 'noise_time'] + source: ['obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133'] + target: ['obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133'] + forcing: True + loss_weight: 1.0 + token_size: 1024 + tokenize_spacetime: False + max_num_targets: 262144 + embed: + net: transformer + num_tokens: 1 + num_heads: 4 + dim_embed: 512 + num_blocks: 2 + embed_target_coords: + net: linear + dim_embed: 512 + target_readout: + num_layers: 2 + num_heads: 4 + pred_head: + ens_size: 1 + num_layers: 1 + +GOES_ABI_VIS: + type: obs + stream_id: 13 + filenames: ['observations-file-2017-2024-abi-goes16-VIS-o256-v2.zarr'] + geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza', 'noise_time'] + forcing: True + loss_weight: 1.0 + token_size: 1024 + tokenize_spacetime: False + max_num_targets: 262144 + embed: + net: transformer + num_tokens: 1 + num_heads: 4 + dim_embed: 512 + num_blocks: 2 + embed_target_coords: + net: linear + dim_embed: 512 + target_readout: + num_layers: 2 + num_heads: 4 + pred_head: + ens_size: 1 + num_layers: 1 + +HIMAWARI_AHI_VIS: + type: obs + stream_id: 14 + filenames: ['observations-file-2015-2022-himawari8-VIS-o256-v1.zarr', 'observations-file-2022-2024-himawari9-VIS-o256-v1.zarr'] + geoinfo_channels: ['zenith', 'noise_time'] + # geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 
'sin_julian_day', 'zenith', 'cos_sza', 'noise_time'] + forcing: True + loss_weight: 1.0 + token_size: 1024 + tokenize_spacetime: False + max_num_targets: 262144 + embed: + net: transformer + num_tokens: 1 + num_heads: 4 + dim_embed: 512 + num_blocks: 2 + embed_target_coords: + net: linear + dim_embed: 512 + target_readout: + num_layers: 2 + num_heads: 4 + pred_head: + ens_size: 1 + num_layers: 1 diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index d845d2e7d..f966dafac 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -269,6 +269,7 @@ def _init_stream_datasets(self, cf) -> dict[StreamName, list[AnyDataReader]]: stream_info[str(self._stage) + "_source_channels"] = ds.source_channels stream_info[str(self._stage) + "_target_channels"] = ds.target_channels + stream_info[str(self._stage) + "_geoinfo_channels"] = ds.geoinfo_channels stream_info["target_channel_weights"] = ( ds.target_channel_weights if ds.target_channel_weights is not None @@ -732,7 +733,7 @@ def _get_batch(self, idx: int, num_forecast_steps: int): input_data, source_masks.metadata[sidx], is_student=True, - add_geoinfo_noise="noise_time" in stream_info.get("geoinfo_channels",[]), + add_geoinfo_noise="noise_time" in stream_info.get("geoinfo_channels", []), ) sdata = self._build_stream_data( @@ -761,8 +762,10 @@ def _get_batch(self, idx: int, num_forecast_steps: int): # Apply self-flow noise to teacher data (handled by masker) input_data_target = self.masker.apply_noise_to_data( - input_data_target_orig, target_masks.metadata[tidx], is_student=False, - add_geoinfo_noise="noise_time" in stream_info.get("geoinfo_channels",[]), + input_data_target_orig, + target_masks.metadata[tidx], + is_student=False, + add_geoinfo_noise="noise_time" in stream_info.get("geoinfo_channels", []), ) sdata = self._build_stream_data( diff --git a/src/weathergen/train/trainer.py 
b/src/weathergen/train/trainer.py index b60bb26f4..7704f9412 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -269,9 +269,10 @@ def _check_channel_order_consistency( ) -> None: """Guard against silently scrambling the channel<->weight mapping when continuing. - Compares the channel order resolved for the current data against the order stored in - the checkpoint's config and raises if they differ. Streams (or channel lists) absent - from the checkpoint config cannot be verified and are skipped with a warning. + Compares the source/target/geoinfo channel order resolved for the current data against + the order stored in the checkpoint's config and raises if they differ. Streams (or + channel lists) absent from the checkpoint config cannot be verified and are skipped with + a warning (e.g. geoinfo for checkpoints predating the resolved-geoinfo back-fill). """ try: prev_cf = config.load_run_config(from_run_id, mini_epoch, None) @@ -285,6 +286,7 @@ def _check_channel_order_consistency( prev_streams = {s["name"]: s for s in prev_cf.get("streams", [])} src_key = f"{stage}_source_channels" tgt_key = f"{stage}_target_channels" + geo_key = f"{stage}_geoinfo_channels" mismatches: list[str] = [] for name, readers in dataset.streams_datasets.items(): @@ -295,6 +297,7 @@ def _check_channel_order_consistency( for key, resolved in ( (src_key, list(reader.source_channels)), (tgt_key, list(reader.target_channels)), + (geo_key, list(reader.geoinfo_channels)), ): stored = prev.get(key) if stored is None: From 1a0b72137ca82f859791818f7b0a071ea4920d26 Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Wed, 10 Jun 2026 18:07:08 +0200 Subject: [PATCH 3/3] Fix missing start date in oper forecast finetune --- config/config_jepa_multi_data_ft_forecast.yml | 8 +++++++- .../streams/jepa_forecast_multi_data_all_years/era5.yml | 2 +- .../jepa_forecast_multi_data_od_ckpt_order/era5.yml | 2 +- 
src/weathergen/model/utils.py | 1 + src/weathergen/train/trainer.py | 4 ++-- 5 files changed, 12 insertions(+), 5 deletions(-) diff --git a/config/config_jepa_multi_data_ft_forecast.yml b/config/config_jepa_multi_data_ft_forecast.yml index 2f50937a1..30285197d 100644 --- a/config/config_jepa_multi_data_ft_forecast.yml +++ b/config/config_jepa_multi_data_ft_forecast.yml @@ -1,11 +1,17 @@ streams_directory: "./config/streams/jepa_forecast_multi_data_od_ckpt_order/" +freeze_modules: "" # "^(?!.*ERA5)(?=.*(?:encoder|latent_pre_norm|latent_heads)).*$" + general: # mutable parameters + istep: 0 rank: ??? world_size: ??? training_config: - num_mini_epochs: 32 + start_date: 2016-01-01T00:00 + end_date: 2022-12-31T00:00 + + num_mini_epochs: 12 samples_per_mini_epoch: 8192 diff --git a/config/streams/jepa_forecast_multi_data_all_years/era5.yml b/config/streams/jepa_forecast_multi_data_all_years/era5.yml index 5f9268163..6fa370ab5 100644 --- a/config/streams/jepa_forecast_multi_data_all_years/era5.yml +++ b/config/streams/jepa_forecast_multi_data_all_years/era5.yml @@ -6,7 +6,7 @@ ERA5: type: anemoi filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] #['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] # filenames: # ['aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr'] - stream_id: 0 + stream_id: 42 source: ['q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 't_600', 't_700', 't_850', 't_925', 't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000', '10u', '10v', '2d', '2t', 'msl'] target: ['q_150', 'q_200', 'q_250', 'q_300', 
'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 't_600', 't_700', 't_850', 't_925', 't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000', '10u', '10v', '2d', '2t', 'msl'] geoinfo_channels: ['z', 'lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'noise_time'] diff --git a/config/streams/jepa_forecast_multi_data_od_ckpt_order/era5.yml b/config/streams/jepa_forecast_multi_data_od_ckpt_order/era5.yml index 9e22ee48e..492c6296e 100644 --- a/config/streams/jepa_forecast_multi_data_od_ckpt_order/era5.yml +++ b/config/streams/jepa_forecast_multi_data_od_ckpt_order/era5.yml @@ -6,7 +6,7 @@ ERA5: type: anemoi filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] #['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] # filenames: ['aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr'] - stream_id: 0 + stream_id: 42 source: ['10u', '10v', '2d', '2t', 'msl', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925', 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925'] target: ['10u', '10v', '2d', '2t', 'msl', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 
'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925', 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925'] geoinfo_channels: ['cos_julian_day', 'cos_local_time', 'insolation', 'lsm', 'noise_time', 'sdor', 'sin_julian_day', 'sin_local_time', 'slor', 'z'] diff --git a/src/weathergen/model/utils.py b/src/weathergen/model/utils.py index 7dd2060bb..7181f68fa 100644 --- a/src/weathergen/model/utils.py +++ b/src/weathergen/model/utils.py @@ -48,6 +48,7 @@ def apply_fct_to_blocks(model, blocks, fct): name = module.name if hasattr(module, "name") else name # avoid the whole model element which has name '' if (re.fullmatch(blocks, name) is not None) and (name != ""): + logger.info(f"Freezing weights of {name}") fct(module) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 7704f9412..afefb0aed 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -396,7 +396,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): self.target_and_aux_calculators_val = self.get_target_aux_calculators(self.validation_cfg) # Restore EMA teacher weights when continuing from a checkpoint - if run_id_contd is not None: + if run_id_contd is not None: # and self.cf.general.istep != 0: # To be tested self._load_ema_teacher_state(run_id_contd, mini_epoch_contd) # if with_fsdp then parameter count is unreliable @@ -440,7 +440,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): ) # Restore optimizer momentum buffers when continuing from a checkpoint - if run_id_contd is not None: + if run_id_contd is 
not None and self.cf.general.istep != 0: self._load_optimizer_state(run_id_contd, mini_epoch_contd) if self.cf.general.istep > 0 and is_root():