diff --git a/config/evaluate/config_zarr2cf.yaml b/config/evaluate/config_zarr2cf.yaml index 75677858b..7d0e7163a 100644 --- a/config/evaluate/config_zarr2cf.yaml +++ b/config/evaluate/config_zarr2cf.yaml @@ -106,6 +106,7 @@ coordinates: forecast_step: forecast_period forecast_reference_time: forecast_reference_time ncells: ncells + mem: mem pl: pressure_level: pressure valid_time: valid_time @@ -114,6 +115,7 @@ coordinates: forecast_step: forecast_period forecast_reference_time: forecast_reference_time ncells: ncells + mem: mem dimensions: valid_time: @@ -141,4 +143,9 @@ dimensions: std_unit: hours ncells: wg: ncells - std: ncells \ No newline at end of file + std: ncells + mem: + wg: mem + std: realization + long: realization + std_unit: 1 \ No newline at end of file diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_core.py b/packages/evaluate/src/weathergen/evaluate/export/export_core.py index eeea8576b..35dc9135b 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/export_core.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_core.py @@ -62,21 +62,23 @@ def get_data_worker(args: tuple) -> tuple[int, int, xr.DataArray]: npoints = data_arr.shape[0] # Handle optional ensemble dimension: squeeze it out if present. - if data_arr.ndim == 3 and data_arr.shape[2] == 1: - data_arr = data_arr[:, :, 0] - - da_result = xr.DataArray( - data_arr, - dims=["ipoint", "channel"], - coords={ - "ipoint": np.arange(npoints), - "channel": channels, - "forecast_step": fstep, - "valid_time": ("ipoint", times_arr), - "lat": ("ipoint", coords_arr[:, 0]), - "lon": ("ipoint", coords_arr[:, 1]), - }, - ) + if data_arr.ndim == 3: + if data_arr.shape[2] == 1: + data_arr = data_arr[:, :, 0] + data_dims = ["ipoint", "channel"] + else: + data_dims = ["ipoint", "channel", "mem"] + + data_coords = { + "ipoint": np.arange(npoints), + "channel": channels, + "forecast_step": fstep, + "valid_time": ("ipoint", times_arr), + "lat": ("ipoint", coords_arr[:, 0]), + "lon": ("ipoint", coords_arr[:, 1]), + } + + da_result = xr.DataArray(data_arr, dims=data_dims, coords=data_coords) return (sample, fstep, da_result) @@ -312,13 +314,16 @@ def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None: fsteps = get_fsteps(fsteps, fname_zarr) samples = get_samples(samples, fname_zarr) streams = get_streams(stream, fname_zarr) + _logger.info(f"Streams to process: {streams}") for stream in streams: grid_type = get_grid_type(data_type, stream, fname_zarr) - channels = get_channels(channels, stream, fname_zarr) + stream_channels = get_channels(channels, stream, fname_zarr) source_starts, source_ends = get_source_info(fname_zarr, stream, samples) + # prevent overwritting of channels between streams kwargs["grid_type"] = grid_type - kwargs["channels"] = channels + kwargs["channels"] = stream_channels kwargs["data_type"] = data_type + kwargs["stream"] = stream parser = CfParserFactory.get_parser(config=config, **kwargs) diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py index 22a674a86..0409a461d 100755 --- a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py @@ -104,7 +104,7 @@ def parse_args(args: list) -> argparse.Namespace: parser.add_argument( "--stream", type=str, - choices=["ERA5", "CERRA", "MEPS", "NORA3", "IMERG_ANEMOI"], + choices=["ERA5", "ERA5pl", "ERA5ml", "CERRA", "MEPS", "NORA3", "IMERG_ANEMOI"], help="Stream name to retrieve data for", ) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py index fc30b05f3..e0b6bbe90 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py @@ -70,7 +70,7 @@ def process_sample( da_fs = [] for result in fstep_iterator_results: - if result is None: + if result is None or (isinstance(result, xr.DataArray) and result.size == 0): continue # result is already a materialized xarray DataArray (built in the worker). @@ -139,19 +139,25 @@ def reshape(self, data: xr.DataArray) -> xr.Dataset: # Original logic var_dict = find_pl(data.channel.values) data_vars = {} - + #order of appending upoints should be ipoint, pressure_level, mem (if mem exists) for new_var, pls in var_dict.items(): + data_dims = ["ipoint"] if pls[0] is not None: + data_dims.append("pressure_level") + if "mem" in data.dims: + data_dims.append("mem") old_vars = [f"{new_var}_{p}" for p in pls] data_vars[new_var] = xr.DataArray( data.sel(channel=old_vars).values, - dims=["ipoint", "pressure_level"], + dims=data_dims, coords={"pressure_level": pls}, ) else: + if "mem" in data.dims: + data_dims.append("mem") data_vars[new_var] = xr.DataArray( data.sel(channel=new_var).values, - dims=["ipoint"], + dims=data_dims, ) reshaped_dataset = xr.Dataset(data_vars) @@ -471,7 +477,6 @@ def _build_coordinate_mapping( f"Coordinate '{coord}' will be skipped for " f"variable '{var_cfg.get('var', 'unknown')}'." ) - return coords def _add_grid_attrs(self, ds: xr.Dataset, grid_info: dict | None = None) -> xr.Dataset: diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py index a19077ccb..f68fa7c78 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py @@ -74,6 +74,7 @@ def process_sample( ref_time: np.datetime64, source_interval_start: np.datetime64 = None, source_interval_end: np.datetime64 = None, + **kwargs ): """ Process results from get_data_worker: reshape, concatenate, add metadata, and save. diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index 3b33786a4..956ef671c 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -205,24 +205,26 @@ def reshape(self, data: xr.DataArray) -> xr.Dataset: data_vars = {} for new_var, pls in var_dict.items(): + data_dims = ["ipoint"] + if "mem" in data.dims: + data_dims.append("mem") if pls[0] is not None: old_vars = [f"{new_var}_{p}" for p in pls] data_vars[new_var] = xr.DataArray( data.sel(channel=old_vars).values, - dims=["ipoint", "pressure_level"], + dims=[*data_dims, "pressure_level"], coords={"pressure_level": pls}, ) else: data_vars[new_var] = xr.DataArray( data.sel(channel=new_var).values, - dims=["ipoint"], + dims=data_dims, ) reshaped_dataset = xr.Dataset(data_vars) reshaped_dataset = reshaped_dataset.assign_coords( ipoint=data.coords["ipoint"], ) - # order using pressure_level coord if "pressure_level" in reshaped_dataset.coords: reshaped_dataset = reshaped_dataset.sortby("pressure_level") @@ -544,7 +546,8 @@ def _attrs_gaussian_grid(self, ds: xr.Dataset) -> xr.Dataset: coords = self._build_coordinate_mapping(ds, mapped_info, ds_attrs) - wg_unit = mapped_units.get(self.stream, "DEFAULT") + wg_unit = mapped_units.get(self.stream, mapped_units.get("DEFAULT", None)) + print(wg_unit) verif_unit = mapped_info.get("verif_unit", None) if wg_unit != verif_unit: # perform unit conversion diff --git a/packages/evaluate/src/weathergen/evaluate/export/reshape.py b/packages/evaluate/src/weathergen/evaluate/export/reshape.py index d0d0bea6e..9ec334214 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/reshape.py +++ b/packages/evaluate/src/weathergen/evaluate/export/reshape.py @@ -267,15 +267,9 @@ def gaussian_regular_da(self, data: xr.DataArray) -> xr.DataArray: pos = dims.index("ncells") dims[pos : pos + 1] = ["latitude", "longitude"] dims = tuple(dims) - ordered_dims = ( - ["valid_time", "pressure", "latitude", "longitude"] - if len(dims) == 4 - else ["valid_time", "latitude", "longitude"] - ) - permutation_indices = [dims.index(o_dim) for o_dim in ordered_dims] - regridded_values = np.transpose(regridded_values, axes=permutation_indices) + regrid_data = xr.DataArray( - data=regridded_values, dims=ordered_dims, coords=new_coords, attrs=attrs, name=data.name + data=regridded_values, dims=dims, coords=new_coords, attrs=attrs, name=data.name ) return regrid_data @@ -549,7 +543,9 @@ def regrid_ds( regrid_vars[var] = self.regrid_da(ds[var]) regrid_ds = xr.Dataset(regrid_vars) regrid_ds = self.add_attrs(regrid_ds) - + regrid_ds = regrid_ds.transpose( + "valid_time", "pressure", "latitude", "longitude", "mem", ..., missing_dims="ignore" + ) return regrid_ds def regrid_da(self, da: xr.DataArray) -> xr.DataArray: