From dbcefb7e126c9c5eefaee5b1659668258240459b Mon Sep 17 00:00:00 2001 From: lessig2 Date: Tue, 16 Jun 2026 20:15:06 +0200 Subject: [PATCH] This pulls in the latest data_readers (and related code) to allow training with operan offset and IASI-PCs --- .../data_reader_anemoi_operan.py | 186 ++++++++++++++++++ .../src/weathergen/readers_extra/registry.py | 8 + src/weathergen/datasets/data_reader_anemoi.py | 36 +++- src/weathergen/datasets/data_reader_obs.py | 54 +++-- .../datasets/multi_stream_data_sampler.py | 1 + 5 files changed, 259 insertions(+), 26 deletions(-) create mode 100644 packages/readers_extra/src/weathergen/readers_extra/data_reader_anemoi_operan.py diff --git a/packages/readers_extra/src/weathergen/readers_extra/data_reader_anemoi_operan.py b/packages/readers_extra/src/weathergen/readers_extra/data_reader_anemoi_operan.py new file mode 100644 index 000000000..91033b5a2 --- /dev/null +++ b/packages/readers_extra/src/weathergen/readers_extra/data_reader_anemoi_operan.py @@ -0,0 +1,186 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import logging +from pathlib import Path +from typing import override + +import numpy as np +from anemoi.datasets.data import MissingDateError + +from weathergen.datasets.data_reader_anemoi import DataReaderAnemoi +from weathergen.datasets.data_reader_base import ( + ReaderData, + TimeWindowHandler, + TIndex, +) +from weathergen.train.utils import Stage + +_logger = logging.getLogger(__name__) + + +def dt2cal(dt): + """ + Convert array of datetime64 to a calendar array of year, month, day, hour, + minute, seconds, microsecond with these quantities indexed on the last axis. + + Parameters + ---------- + dt : datetime64 array (...) + numpy.ndarray of datetimes of arbitrary shape + + Returns + ------- + cal : uint32 array (..., 7) + calendar array with last axis representing year, month, day, hour, + minute, second, microsecond + """ + + # allocate output + out = np.empty(dt.shape + (7,), dtype="u4") + # decompose calendar floors + year, month, day, hour, min, sec = [dt.astype(f"M8[{x}]") for x in "YMDhms"] + out[..., 0] = year + 1970 # Gregorian Year + out[..., 1] = (month - year) + 1 # month + out[..., 2] = (day - month) + 1 # day + out[..., 3] = (dt - day).astype("m8[h]") # hour + out[..., 4] = (dt - hour).astype("m8[m]") # minute + out[..., 5] = (dt - min).astype("m8[s]") # second + out[..., 6] = (dt - sec).astype("m8[us]") # microsecond + return out + + +class DataReaderAnemoiOperan(DataReaderAnemoi): + "Wrapper for Anemoi datasets" + + def __init__( + self, + tw_handler: TimeWindowHandler, + filename: Path, + stream_info: dict, + stage: Stage, + ) -> None: + """ + Construct data reader for anemoi dataset + + Parameters + ---------- + filename : + filename (and path) of dataset + stream_info : + information about stream + + Returns + ------- + None + """ + + super().__init__(tw_handler, filename, stream_info, stage) + + @override + def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: + """ + Get data for window (for either source or target, 
through public interface) + + Parameters + ---------- + idx : int + Index of temporal window + channels_idx : np.array + Selection of channels + + Returns + ------- + ReaderData providing coords, geoinfos, data, datetimes + """ + + t_idxs, dtr = self._get_dataset_idxs(idx) + if self.ds is None or self.len == 0 or len(t_idxs) == 0: + return ReaderData.empty( + num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + ) + + # get additional timestep to ensure we have one valid timestep + t_idxs = np.insert(t_idxs, 0, t_idxs[0] - 1) + + didx_start = t_idxs[0] + didx_end = t_idxs[-1] + 1 + datetimes = self.ds.dates[didx_start:didx_end] + datetimes_split = dt2cal(datetimes) + + # compute corrected datetimes that account for actual availability + nts = self.stream_info["nominal_time_mapping"] + deltas = [int(nts[str(hour)]) - int(hour) for hour in datetimes_split[:, 3]] + datetimes_offset = [ + dt + np.timedelta64(delta, "h") for dt, delta in zip(datetimes, deltas, strict=False) + ] + + # use latest available sample that is valid w.r.t the input data window + datetimes_mask = [dt < dtr.end for dt in datetimes_offset] + if np.array(datetimes_mask).sum() == 0: + t_idxs = [] + else: + t_idxs = [t_idxs[datetimes_mask][-1].item()] + + # _get from DataReaderAnemoi + + if self.ds is None or self.len == 0 or len(t_idxs) == 0: + return ReaderData.empty( + num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + ) + + assert t_idxs[0] >= 0, "index must be non-negative" + didx_start = t_idxs[0] + # End is inclusive + didx_end = t_idxs[-1] + 1 + + # extract number of time steps and collapse ensemble dimension + # ds is a wrapper around zarr with get_coordinate_selection not being exposed since + # subsetting is pushed to the ctor via frequency argument; this also ensures that no sub- + # sampling is required here + try: + data = self.ds[didx_start:didx_end][:, :, 0].astype(np.float32) + except MissingDateError as e: + _logger.debug(f"Date not present 
in anemoi dataset: {str(e)}. Skipping.") + return ReaderData.empty( + num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + ) + + # coords-first representation and collapse multiple steps + data = data.transpose([0, 2, 1]).reshape((data.shape[0] * data.shape[2], -1)) + + # extract geoinfo channels (can be time-varying, so read from dataset) + geoinfos = data[:, list(self.geoinfo_idx)] + # extract channels + data = data[:, list(channels_idx)] + + # construct lat/lon coords + latlon = np.concatenate( + [ + np.expand_dims(self.latitudes, 0), + np.expand_dims(self.longitudes, 0), + ], + axis=0, + ).transpose() + # repeat latlon len(t_idxs) times + coords = np.vstack((latlon,) * len(t_idxs)) + + # date time matching #data points of data + # Assuming a fixed frequency for the dataset + datetimes = np.repeat(self.ds.dates[didx_start:didx_end], len(data) // len(t_idxs)) + + rd = ReaderData( + coords=coords, + geoinfos=geoinfos, + data=data, + datetimes=datetimes, + ) + # check_reader_data(rd, dtr) + + return rd diff --git a/packages/readers_extra/src/weathergen/readers_extra/registry.py b/packages/readers_extra/src/weathergen/readers_extra/registry.py index 7ea7c7d59..b13ee4655 100644 --- a/packages/readers_extra/src/weathergen/readers_extra/registry.py +++ b/packages/readers_extra/src/weathergen/readers_extra/registry.py @@ -24,5 +24,13 @@ def get_extra_reader(stream_type: str) -> object | None: from weathergen.readers_extra.data_reader_mesh import DataReaderMesh return DataReaderMesh + case "anemoi_operan": + from weathergen.readers_extra.data_reader_anemoi_operan import DataReaderAnemoiOperan + + return DataReaderAnemoiOperan + case "fesom": + from weathergen.readers_extra.data_reader_fesom import DataReaderFesom + + return DataReaderFesom case _: return None diff --git a/src/weathergen/datasets/data_reader_anemoi.py b/src/weathergen/datasets/data_reader_anemoi.py index ea0493149..f16559fe1 100644 --- a/src/weathergen/datasets/data_reader_anemoi.py 
+++ b/src/weathergen/datasets/data_reader_anemoi.py @@ -25,6 +25,7 @@ TIndex, check_reader_data, ) +from weathergen.train.utils import Stage _logger = logging.getLogger(__name__) @@ -37,6 +38,7 @@ def __init__( tw_handler: TimeWindowHandler, filename: Path, stream_info: dict, + stage: Stage, ) -> None: """ Construct data reader for anemoi dataset @@ -104,19 +106,35 @@ def __init__( self.longitudes = _clip_lon(ds.longitudes) # select/filter requested source channels - self.source_idx = self.select_channels(ds0, "source") - self.source_channels = [ds.variables[i] for i in self.source_idx] + if stream_info.get(str(stage) + "_source_channels") is None: + self.source_idx = self.select_channels(ds, "source") + self.source_channels = [ds.variables[i] for i in self.source_idx] + else: + self.source_channels = stream_info.get(str(stage) + "_source_channels") + self.source_idx = [ds.variables.index(ch) for ch in self.source_channels] # select/filter requested target channels - self.target_idx = self.select_channels(ds0, "target") - self.target_channels = [ds.variables[i] for i in self.target_idx] + if stream_info.get(str(stage) + "_target_channels") is None: + self.target_idx = self.select_channels(ds, "target") + self.target_channels = [ds.variables[i] for i in self.target_idx] + else: + self.target_channels = stream_info.get(str(stage) + "_target_channels") + self.target_idx = [ds.variables.index(ch) for ch in self.target_channels] # get target channel weights from stream config - self.target_channel_weights = self.parse_target_channel_weights() + if stream_info.get("target_channel_weights") is None: + self.target_channel_weights = self.parse_target_channel_weights() + else: + self.target_channel_weights = stream_info.get("target_channel_weights") # select/filter requested geoinfo channels (can be any variable, not just constant-in-time) - self.geoinfo_idx = self.select_geoinfo_channels(ds0) - self.geoinfo_channels = [ds.variables[i] for i in self.geoinfo_idx] + if 
stream_info.get("geoinfo_channels") is None: + self.geoinfo_idx = self.select_geoinfo_channels(ds) + self.geoinfo_channels = [ds.variables[i] for i in self.geoinfo_idx] + else: + self.geoinfo_channels = stream_info.get("geoinfo_channels") + self.geoinfo_idx = [ds.variables.index(ch) for ch in self.geoinfo_channels] + # set geoinfo normalization statistics if len(self.geoinfo_idx) > 0: self.mean_geoinfo = ds.statistics["mean"][self.geoinfo_idx] @@ -253,9 +271,9 @@ def select_channels(self, ds0: anemoi_datasets, ch_type: str) -> NDArray[np.int6 not v.is_computed_forcing and not v.is_constant_in_time and ( - np.array([f in k for f in channels]).any() if channels is not None else True + np.array([f == k for f in channels]).any() if channels is not None else True ) - and not np.array([f in k for f in channels_exclude]).any() + and not np.array([f == k for f in channels_exclude]).any() ) ] ) diff --git a/src/weathergen/datasets/data_reader_obs.py b/src/weathergen/datasets/data_reader_obs.py index 963b910d7..d5547ebfb 100644 --- a/src/weathergen/datasets/data_reader_obs.py +++ b/src/weathergen/datasets/data_reader_obs.py @@ -21,12 +21,15 @@ TimeWindowHandler, check_reader_data, ) +from weathergen.train.utils import Stage _logger = logging.getLogger(__name__) class DataReaderObs(DataReaderBase): - def __init__(self, tw_handler: TimeWindowHandler, filename: Path, stream_info: dict) -> None: + def __init__( + self, tw_handler: TimeWindowHandler, filename: Path, stream_info: dict, stage: Stage + ) -> None: super().__init__(tw_handler, stream_info) self.filename = filename @@ -52,29 +55,41 @@ def __init__(self, tw_handler: TimeWindowHandler, filename: Path, stream_info: d # determine source / target channels and corresponding idx using include and exclude lists - s_chs = stream_info.get("source") - s_chs_exclude = stream_info.get("source_exclude", []) - - t_chs = stream_info.get("target") - t_chs_exclude = stream_info.get("target_exclude", []) - - # source_n_empty = len(s_chs) 
> 0 if s_chs is not None else True - # assert source_n_empty, "source is empty; at least one channels must be present." - # target_n_empty = len(t_chs) > 0 if t_chs is not None else True - # assert target_n_empty, "target is empty; at least one channels must be present." - - self.source_channels = self.select_channels(data_colnames, s_chs, s_chs_exclude) + if stream_info.get(str(stage) + "_source_channels") is None: + s_chs = stream_info.get("source") + s_chs_exclude = stream_info.get("source_exclude", []) + self.source_channels = self.select_channels(data_colnames, s_chs, s_chs_exclude) + else: + self.source_channels = stream_info.get(str(stage) + "_source_channels") self.source_idx = [self.colnames.index(c) for c in self.source_channels] self.source_idx = np.array(self.source_idx, dtype=np.int64) - self.target_channels = self.select_channels(data_colnames, t_chs, t_chs_exclude) + if stream_info.get(str(stage) + "_target_channels") is None: + t_chs = stream_info.get("target") + t_chs_exclude = stream_info.get("target_exclude", []) + self.target_channels = self.select_channels(data_colnames, t_chs, t_chs_exclude) + else: + self.target_channels = stream_info.get(str(stage) + "_target_channels") self.target_idx = [self.colnames.index(c) for c in self.target_channels] self.target_idx = np.array(self.target_idx, dtype=np.int64) # determine idx for coords and geoinfos self.coords_idx = [self.colnames.index("lat"), self.colnames.index("lon")] - self.geoinfo_idx = list(range(self.coords_idx[-1] + 1, data_idx[0])) - self.geoinfo_channels = [self.colnames[i] for i in self.geoinfo_idx] + + # geoinfo channels + sname = stream_info["name"] + if stream_info.get("geoinfo_channels") is not None: + self.geoinfo_idx, self.geoinfo_channels = [], [] + for c in stream_info.get("geoinfo_channels"): + if c not in self.colnames: + _logger.warning(f"{sname} : geoinfo {c} specified in config but not present.") + else: + self.geoinfo_idx.append(self.colnames.index(c)) + 
self.geoinfo_channels.append(c) + else: + self.geoinfo_idx = list(range(self.coords_idx[-1] + 1, data_idx[0])) + self.geoinfo_channels = [self.colnames[i] for i in self.geoinfo_idx] + _logger.info(f"{stream_info['name']} geoinfos : {self.geoinfo_channels}") # load additional properties (mean, var) self._load_properties() @@ -185,7 +200,7 @@ def _setup_sample_index(self) -> None: self.indices_start = np.append( self.indices_start, np.ones( - (diff_in_hours_end - self.hrly_index.shape[0] - 1) // step_hrs, dtype=int + (diff_in_hours_end - (self.hrly_index.shape[0] - 1)) // step_hrs, dtype=int ) * self.indices_start[-1], ) @@ -235,6 +250,11 @@ def _get(self, idx: int, channels_idx: list[int]) -> ReaderData: num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) ) + if idx >= len(self.indices_start) or idx >= len(self.indices_end): + return ReaderData.empty( + num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + ) + start_row = self.indices_start[idx] end_row = self.indices_end[idx] diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index f637eb082..b98131f80 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -220,6 +220,7 @@ def _init_stream_datasets(self, cf) -> dict[StreamName, list[AnyDataReader]]: kwargs = { "tw_handler": self.time_window_handler, "stream_info": stream_info, + "stage": self._stage, } dataset: type[AnyDataReader] | None = None match stream_info["type"]: