diff --git a/config/evaluate/config_zarr2cf.yaml b/config/evaluate/config_zarr2cf.yaml index 75677858b..7d0e7163a 100644 --- a/config/evaluate/config_zarr2cf.yaml +++ b/config/evaluate/config_zarr2cf.yaml @@ -106,6 +106,7 @@ coordinates: forecast_step: forecast_period forecast_reference_time: forecast_reference_time ncells: ncells + mem: mem pl: pressure_level: pressure valid_time: valid_time @@ -114,6 +115,7 @@ coordinates: forecast_step: forecast_period forecast_reference_time: forecast_reference_time ncells: ncells + mem: mem dimensions: valid_time: @@ -141,4 +143,9 @@ dimensions: std_unit: hours ncells: wg: ncells - std: ncells \ No newline at end of file + std: ncells + mem: + wg: mem + std: realization + long: realization + std_unit: 1 \ No newline at end of file diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_core.py b/packages/evaluate/src/weathergen/evaluate/export/export_core.py index eeea8576b..fb691db93 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/export_core.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_core.py @@ -62,21 +62,23 @@ def get_data_worker(args: tuple) -> tuple[int, int, xr.DataArray]: npoints = data_arr.shape[0] # Handle optional ensemble dimension: squeeze it out if present. - if data_arr.ndim == 3 and data_arr.shape[2] == 1: - data_arr = data_arr[:, :, 0] - - da_result = xr.DataArray( - data_arr, - dims=["ipoint", "channel"], - coords={ - "ipoint": np.arange(npoints), - "channel": channels, - "forecast_step": fstep, - "valid_time": ("ipoint", times_arr), - "lat": ("ipoint", coords_arr[:, 0]), - "lon": ("ipoint", coords_arr[:, 1]), - }, - ) + if data_arr.ndim == 3: + if data_arr.shape[2] == 1: + data_arr = data_arr[:, :, 0] + data_dims = ["ipoint", "channel"] + else: + data_dims = ["ipoint", "channel", "mem"] + + data_coords = { + "ipoint": np.arange(npoints), + "channel": channels, + "forecast_step": fstep, + "valid_time": ("ipoint", times_arr), + "lat": ("ipoint", coords_arr[:, 0]), + "lon": ("ipoint", coords_arr[:, 1]), + } + + da_result = xr.DataArray(data_arr, dims=data_dims, coords=data_coords) return (sample, fstep, da_result) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py index fc30b05f3..8767d6df3 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py @@ -141,17 +141,20 @@ def reshape(self, data: xr.DataArray) -> xr.Dataset: data_vars = {} for new_var, pls in var_dict.items(): + data_dims = ["ipoint"] + if "mem" in data.dims: + data_dims.append("mem") if pls[0] is not None: old_vars = [f"{new_var}_{p}" for p in pls] data_vars[new_var] = xr.DataArray( data.sel(channel=old_vars).values, - dims=["ipoint", "pressure_level"], + dims=[*data_dims, "pressure_level"], coords={"pressure_level": pls}, ) else: data_vars[new_var] = xr.DataArray( data.sel(channel=new_var).values, - dims=["ipoint"], + dims=data_dims, ) reshaped_dataset = xr.Dataset(data_vars) @@ -471,7 +474,6 @@ def _build_coordinate_mapping( f"Coordinate '{coord}' will be skipped for " f"variable '{var_cfg.get('var', 'unknown')}'." ) - return coords def _add_grid_attrs(self, ds: xr.Dataset, grid_info: dict | None = None) -> xr.Dataset: diff --git a/packages/evaluate/src/weathergen/evaluate/export/reshape.py b/packages/evaluate/src/weathergen/evaluate/export/reshape.py index d0d0bea6e..9ec334214 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/reshape.py +++ b/packages/evaluate/src/weathergen/evaluate/export/reshape.py @@ -267,15 +267,9 @@ def gaussian_regular_da(self, data: xr.DataArray) -> xr.DataArray: pos = dims.index("ncells") dims[pos : pos + 1] = ["latitude", "longitude"] dims = tuple(dims) - ordered_dims = ( - ["valid_time", "pressure", "latitude", "longitude"] - if len(dims) == 4 - else ["valid_time", "latitude", "longitude"] - ) - permutation_indices = [dims.index(o_dim) for o_dim in ordered_dims] - regridded_values = np.transpose(regridded_values, axes=permutation_indices) + regrid_data = xr.DataArray( - data=regridded_values, dims=ordered_dims, coords=new_coords, attrs=attrs, name=data.name + data=regridded_values, dims=dims, coords=new_coords, attrs=attrs, name=data.name ) return regrid_data @@ -549,7 +543,9 @@ def regrid_ds( regrid_vars[var] = self.regrid_da(ds[var]) regrid_ds = xr.Dataset(regrid_vars) regrid_ds = self.add_attrs(regrid_ds) - + regrid_ds = regrid_ds.transpose( + "valid_time", "pressure", "latitude", "longitude", "mem", ..., missing_dims="ignore" + ) return regrid_ds def regrid_da(self, da: xr.DataArray) -> xr.DataArray: