diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 407239dbb..c0eb18c6f 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -491,7 +491,9 @@ def example_key(self) -> ItemKey: try: sample, example_sample = next(self.data_root.groups()) stream, example_stream = next(example_sample.groups()) - fstep = 0 + # Use the lowest available forecast step rather than assuming step 0 + # exists: some runs (e.g. without a source group) start at step 1. + fstep = min(int(s) for s in example_stream.group_keys()) except StopIteration as e: msg = f"Data store at: {self._store_path} is empty." raise FileNotFoundError(msg) from e diff --git a/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py b/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py index 3a9e7b2e5..506aec78e 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py @@ -96,10 +96,10 @@ def _get_file_extension(output_format: str) -> str: return "nc" if output_format == "verif": return "nc" - elif output_format == "quaver": + elif output_format in ("quaver", "evalml"): return "grib" else: raise ValueError( f"Unsupported output format: {output_format}," - "supported formats are ['netcdf', 'verif', 'quaver']" + "supported formats are ['netcdf', 'verif', 'quaver', 'evalml']" ) diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_core.py b/packages/evaluate/src/weathergen/evaluate/export/export_core.py index eeea8576b..da741df1a 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/export_core.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_core.py @@ -252,16 +252,29 @@ def get_source_info(fname_zarr, stream, samples) -> tuple[list[np.datetime64], l source_starts = [] source_ends = [] with zarrio_reader(fname_zarr) as zio: + fstep0 = sorted(int(s) for s in 
zio.forecast_steps)[0] for sample in tqdm(samples, desc="Getting source info"): group_path = f"{sample}/{stream}/0/source" source_group = zio.data_root.get(group_path) - if source_group is None: - raise FileNotFoundError(f"Zarr group '{group_path}' not found in {fname_zarr}") - - times_arr = np.asarray(source_group["times"]).astype("datetime64[ns]") - source_start = np.min(times_arr) - source_end = np.max(times_arr) + if source_group is not None: + times_arr = np.asarray(source_group["times"]).astype("datetime64[ns]") + source_start = np.min(times_arr) + source_end = np.max(times_arr) + else: + # No step-0 source group: derive the interval from the + # source_interval attribute stored on the prediction/target group. + grp = zio.data_root.get( + f"{sample}/{stream}/{fstep0}/prediction" + ) or zio.data_root.get(f"{sample}/{stream}/{fstep0}/target") + if grp is None: + raise FileNotFoundError( + f"Neither '{group_path}' nor a fstep-{fstep0} prediction/target " + f"group found for sample {sample} in {fname_zarr}" + ) + interval = dict(grp.attrs)["source_interval"] + source_start = np.datetime64(interval["start"]).astype("datetime64[ns]") + source_end = np.datetime64(interval["end"]).astype("datetime64[ns]") _logger.debug(f"Sample {sample}: source_interval=[{source_start} .. 
{source_end}]") source_starts.append(source_start) diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py index 162377512..c70e49ed8 100755 --- a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py @@ -95,16 +95,17 @@ def parse_args(args: list) -> argparse.Namespace: "--format", dest="output_format", type=str, - choices=["netcdf", "verif", "quaver"], + choices=["netcdf", "verif", "quaver", "evalml"], help="Output file format; netcdf (CF-compliant netcdfs), \ - verif (netcdf compatible with MetNor verif tool), quaver (GRIB files for Quaver tool)", + verif (netcdf compatible with MetNor verif tool), quaver (GRIB files for Quaver tool), \ + evalml (GRIB files for the evalML tool)", required=True, ) parser.add_argument( "--stream", type=str, - choices=["N320", "ERA5", "CERRA", "MEPS", "NORA3", "IMERG_ANEMOI"], + choices=["N320", "ERA5", "CERRA", "MEPS", "NORA3", "IMERG_ANEMOI", "ICON"], help="Stream name to retrieve data for", ) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parser_factory.py b/packages/evaluate/src/weathergen/evaluate/export/parser_factory.py index 1d3f7e26b..425bca306 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parser_factory.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parser_factory.py @@ -1,6 +1,7 @@ from omegaconf import OmegaConf from weathergen.evaluate.export.cf_utils import CfParser +from weathergen.evaluate.export.parsers.evalml_parser import EvalmlParser from weathergen.evaluate.export.parsers.netcdf_parser import NetcdfParser from weathergen.evaluate.export.parsers.quaver_parser import QuaverParser from weathergen.evaluate.export.parsers.verif_parser import VerifParser @@ -31,6 +32,7 @@ def get_parser(config: OmegaConf, **kwargs) -> CfParser: _parser_map = { "netcdf": (NetcdfParser, ["grid_type"]), "quaver": 
(QuaverParser, ["grid_type", "channels", "template"]), + "evalml": (EvalmlParser, ["grid_type", "channels", "template"]), "verif": (VerifParser, ["obs", "method", "verif_template"]), }
diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/evalml_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/evalml_parser.py
new file mode 100644
index 000000000..de98593f8
--- /dev/null
+++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/evalml_parser.py
@@ -0,0 +1,270 @@
+# pylint: disable=bad-builtin
+"""evalML GRIB exporter.
+
+Usage:
+
+uv run export --run-id z8mi61v4 --stream ICON
+--output-dir ./out
+--format evalml --type prediction
+--quaver-template-folder /path/to/evalml/resources/inference/templates
+--quaver-template-grid-type icon-ch1
+
+evalML expects one GRIB file per lead time per forecast initialisation:
+
+    {output_dir}/{YYYYmmddHHMM}/{YYYYmmddHHMM}_{step:03d}.grib
+
+Each file holds every requested variable at that lead time. Total precipitation
+is written cumulative-from-start (evalML diffs it back to hourly values).
+
+This parser mirrors scripts/zarr_to_grib.py: it clones the per-typeOfLevel GRIB
+templates shipped with evalML and fills in metadata + values with eccodes,
+keeping the data in the native (ICON) grid order.
+"""
+
+import contextlib
+import logging
+from collections.abc import Iterable
+from pathlib import Path
+
+import eccodes
+import numpy as np
+import pandas as pd
+import xarray as xr
+from omegaconf import OmegaConf
+
+from weathergen.evaluate.export.cf_utils import CfParser
+
+_logger = logging.getLogger(__name__)
+_logger.setLevel(logging.INFO)
+
+# WeatherGenerator surface channel -> GRIB template + shortName + level.
+SURFACE_CHANNELS: dict[str, dict] = {
+    "T_2M": {"template_key": "heightAboveGround", "shortName": "2t", "level": 2},
+    "TD_2M": {"template_key": "heightAboveGround", "shortName": "2d", "level": 2},
+    "U_10M": {"template_key": "heightAboveGround", "shortName": "10u", "level": 10},
+    "V_10M": {"template_key": "heightAboveGround", "shortName": "10v", "level": 10},
+    "PMSL": {"template_key": "meanSea", "shortName": "prmsl", "level": 0},
+    "PS": {"template_key": "surface", "shortName": "sp", "level": 0},
+    "TOT_PREC_1H": {"template_key": "TOT_PREC", "shortName": "tp", "level": 0, "accumulated": True},
+    "TOT_PREC": {"template_key": "TOT_PREC", "shortName": "tp", "level": 0, "accumulated": True},
+}
+
+# Pressure-level channel prefix -> GRIB shortName (names like T_850, QV_500, FI_250).
+PLEVEL_CHANNELS: dict[str, str] = {"T": "t", "QV": "q", "U": "u", "V": "v", "FI": "z"}
+
+# typeOfLevel key -> template filename, formatted with the grid type (e.g. icon-ch1).
+TEMPLATE_FILES: dict[str, str] = {
+    "heightAboveGround": "{grid}-typeOfLevel=heightAboveGround.grib",
+    "surface": "{grid}-typeOfLevel=surface.grib",
+    "meanSea": "{grid}-typeOfLevel=meanSea.grib",
+    "TOT_PREC": "{grid}-shortName=TOT_PREC.grib",
+    "isobaricInhPa": "{grid}-typeOfLevel=isobaricInhPa.grib",
+}
+
+
+class EvalmlParser(CfParser):
+    """evalML GRIB exporter (one file per lead time, per init time).
+
+    Keyword arguments are stored as instance attributes.  The template folder,
+    template grid type and channel list are mandatory; ``output_dir``,
+    ``output_format`` and ``scale_data`` are used but not defined in this class
+    (presumably provided through the keyword arguments or by ``CfParser``).
+    """
+
+    def __init__(self, config: OmegaConf, **kwargs):
+        """Validate the evalml options and load the GRIB templates.
+
+        Raises:
+            ValueError: if the template folder, the template grid type or the
+                channels are missing or empty.
+            RuntimeError: if no template could be loaded.
+        """
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+        if not getattr(self, "quaver_template_folder", None):
+            raise ValueError("Template folder must be provided for evalml format.")
+        if not getattr(self, "quaver_template_grid_type", None):
+            raise ValueError("Template grid type (e.g. icon-ch1) must be provided for evalml.")
+        if not getattr(self, "channels", None):
+            raise ValueError("Channels must be provided for evalml format.")
+
+        super().__init__(config, **kwargs)
+
+        self.templates = self._load_templates()
+        if not self.templates:
+            raise RuntimeError(
+                f"No templates loaded from {self.quaver_template_folder} "
+                f"for grid '{self.quaver_template_grid_type}'."
+            )
+
+    def _load_templates(self) -> dict[str, object]:
+        """Read one GRIB message from each template file; return {template_key: handle}.
+
+        Missing template files and files without a GRIB message are skipped with
+        a warning.
+        """
+        folder = Path(self.quaver_template_folder)
+        grid = self.quaver_template_grid_type
+        templates = {}
+        for key, pattern in TEMPLATE_FILES.items():
+            path = folder / pattern.format(grid=grid)
+            if not path.exists():
+                _logger.warning("Template not found, skipping %s: %s", key, path)
+                continue
+            with open(path, "rb") as f:
+                handle = eccodes.codes_grib_new_from_file(f)
+            # eccodes returns None when the file holds no message; storing it
+            # would only fail later, when the template is cloned.
+            if handle is None:
+                _logger.warning("Template has no GRIB message, skipping %s: %s", key, path)
+                continue
+            templates[key] = handle
+            _logger.info("Loaded template: %s (%s)", key, path.name)
+        return templates
+
+    def process_sample(
+        self,
+        fstep_iterator_results: Iterable,
+        ref_time: np.datetime64,
+        source_interval_start: np.datetime64 | None = None,
+        source_interval_end: np.datetime64 | None = None,
+    ):
+        """Write one GRIB file per lead time into a per-init-time directory.
+
+        The end of the source interval is used as the forecast initialisation
+        time; ``ref_time`` and ``source_interval_start`` are not used.
+
+        Raises:
+            ValueError: if ``source_interval_end`` is not provided.
+        """
+        if source_interval_end is None:
+            raise ValueError("source_interval_end is required for evalml export.")
+
+        init_time = np.datetime64(source_interval_end)
+        init_dt = pd.Timestamp(source_interval_end)
+        stamp = init_dt.strftime("%Y%m%d%H%M")
+        sample_dir = Path(self.output_dir) / stamp
+        sample_dir.mkdir(parents=True, exist_ok=True)
+
+        subs = self._collect_substeps(fstep_iterator_results)
+        fields = self._writable_fields()
+
+        # Written in time order so accumulated fields can be summed as they go.
+        cumulative: dict[str, np.typing.NDArray] = {}
+        for vt in sorted(subs):
+            da_sub = subs[vt]
+            step_h = int((vt - init_time) / np.timedelta64(1, "h"))
+            out_file = sample_dir / f"{stamp}_{step_h:03d}.grib"
+            # _write_field appends, so remove any file that already exists.
+            out_file.unlink(missing_ok=True)
+
+            for var, (template_key, short_name, level, accumulated) in fields:
+                values = self.scale_data(da_sub.sel(channel=var), var).values
+                if accumulated:
+                    if var in cumulative:
+                        values = cumulative[var] + values
+                    cumulative[var] = values
+
+                self._write_field(
+                    out_file, template_key, values, init_dt, step_h, short_name, level, accumulated
+                )
+
+            _logger.info(
+                " step %03d valid=%s msgs=%d",
+                step_h,
+                np.datetime_as_string(vt, "h"),
+                len(fields),
+            )
+
+        _logger.info("Saved sample data to %s in %s.", self.output_format, sample_dir)
+
+    def _collect_substeps(
+        self, fstep_iterator_results: Iterable
+    ) -> dict[np.datetime64, xr.DataArray]:
+        """Split the forecast-step results into one DataArray per valid time.
+
+        ``None`` results are skipped.  If several results share a valid time,
+        the last one wins.
+        """
+        subs: dict[np.datetime64, xr.DataArray] = {}
+        for result in fstep_iterator_results:
+            if result is None:
+                continue
+            if not isinstance(result, xr.DataArray):
+                result = result.as_xarray().squeeze()
+            result = result.sel(channel=self.channels)
+            valid_times = result.valid_time.values
+            for vt in np.unique(valid_times):
+                subs[vt] = result.isel(ipoint=valid_times == vt)
+        return subs
+
+    def _writable_fields(self) -> list[tuple[str, tuple[str, str, int, bool]]]:
+        """Return (channel, spec) for each channel with a mapping and a loaded template.
+
+        Resolved once per sample, so a channel that cannot be written is
+        reported once instead of at every lead time.
+        """
+        fields = []
+        for var in self.channels:
+            spec = self._channel_spec(var)
+            if spec is None:
+                continue
+            if spec[0] not in self.templates:
+                _logger.warning("No template '%s' for %s, skipping", spec[0], var)
+                continue
+            fields.append((var, spec))
+        return fields
+
+    @staticmethod
+    def _channel_spec(var: str) -> tuple[str, str, int, bool] | None:
+        """Return (template_key, shortName, level, accumulated) for a channel, or None."""
+        if var in SURFACE_CHANNELS:
+            cfg = SURFACE_CHANNELS[var]
+            return (
+                cfg["template_key"],
+                cfg["shortName"],
+                cfg["level"],
+                cfg.get("accumulated", False),
+            )
+        # Pressure-level channels are named "<prefix>_<level>", e.g. T_850.
+        parts = var.rsplit("_", 1)
+        if len(parts) == 2 and parts[0] in PLEVEL_CHANNELS and parts[1].isdigit():
+            return "isobaricInhPa", PLEVEL_CHANNELS[parts[0]], int(parts[1]), False
+        _logger.warning("No evalml mapping for channel '%s', skipping", var)
+        return None
+
+    def _write_field(
+        self,
+        out_path: Path,
+        template_key: str,
+        values: np.typing.NDArray,
+        init_dt: pd.Timestamp,
+        step_h: int,
+        short_name: str,
+        level: int,
+        accumulated: bool,
+    ) -> None:
+        """Clone the template, set metadata + values, and append the message to out_path."""
+        msg = eccodes.codes_clone(self.templates[template_key])
+        try:
+            eccodes.codes_set(msg, "shortName", short_name)
+            eccodes.codes_set(msg, "level", level)
+            eccodes.codes_set(msg, "dataDate", int(init_dt.strftime("%Y%m%d")))
+            eccodes.codes_set(msg, "dataTime", int(init_dt.strftime("%H%M")))
+            eccodes.codes_set(msg, "stepUnits", 1)  # hours
+            if accumulated:
+                eccodes.codes_set(msg, "startStep", 0)
+            eccodes.codes_set(msg, "endStep", step_h)
+            eccodes.codes_set_values(msg, np.asarray(values, dtype=float))
+            with open(out_path, "ab") as f:
+                eccodes.codes_write(msg, f)
+        finally:
+            # The clone is released even when setting a key or writing fails.
+            eccodes.codes_release(msg)
+
+    def __del__(self):
+        """Release the template handles held by this parser."""
+        for msg in getattr(self, "templates", {}).values():
+            with contextlib.suppress(Exception):  # best-effort cleanup
+                eccodes.codes_release(msg)
diff --git a/packages/evaluate/src/weathergen/evaluate/utils/clim_utils.py b/packages/evaluate/src/weathergen/evaluate/utils/clim_utils.py index 0cd860137..82036021e 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/clim_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/clim_utils.py @@ -268,5 +268,5 @@ def needs_climatology(metrics_dict: dict) -> bool: True if any metric requires climatology, False otherwise """ metrics = [m for metrics in metrics_dict.values() for m in metrics.keys()] - req_clim = ["acc", "rps", "rpss"] + req_clim = ["acc", "rps", "rpss"] return any(m in req_clim for m in metrics)