Skip to content
9 changes: 8 additions & 1 deletion config/evaluate/config_zarr2cf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ coordinates:
forecast_step: forecast_period
forecast_reference_time: forecast_reference_time
ncells: ncells
mem: mem
pl:
pressure_level: pressure
valid_time: valid_time
Expand All @@ -114,6 +115,7 @@ coordinates:
forecast_step: forecast_period
forecast_reference_time: forecast_reference_time
ncells: ncells
mem: mem

dimensions:
valid_time:
Expand Down Expand Up @@ -141,4 +143,9 @@ dimensions:
std_unit: hours
ncells:
wg: ncells
std: ncells
std: ncells
mem:
wg: mem
std: realization
long: realization
std_unit: 1
39 changes: 22 additions & 17 deletions packages/evaluate/src/weathergen/evaluate/export/export_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,23 @@ def get_data_worker(args: tuple) -> tuple[int, int, xr.DataArray]:
npoints = data_arr.shape[0]

# Handle optional ensemble dimension: squeeze it out if present.
if data_arr.ndim == 3 and data_arr.shape[2] == 1:
data_arr = data_arr[:, :, 0]

da_result = xr.DataArray(
data_arr,
dims=["ipoint", "channel"],
coords={
"ipoint": np.arange(npoints),
"channel": channels,
"forecast_step": fstep,
"valid_time": ("ipoint", times_arr),
"lat": ("ipoint", coords_arr[:, 0]),
"lon": ("ipoint", coords_arr[:, 1]),
},
)
if data_arr.ndim == 3:
if data_arr.shape[2] == 1:
data_arr = data_arr[:, :, 0]
data_dims = ["ipoint", "channel"]
else:
data_dims = ["ipoint", "channel", "mem"]

data_coords = {
"ipoint": np.arange(npoints),
"channel": channels,
"forecast_step": fstep,
"valid_time": ("ipoint", times_arr),
"lat": ("ipoint", coords_arr[:, 0]),
"lon": ("ipoint", coords_arr[:, 1]),
}

da_result = xr.DataArray(data_arr, dims=data_dims, coords=data_coords)

return (sample, fstep, da_result)

Expand Down Expand Up @@ -312,13 +314,16 @@ def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None:
fsteps = get_fsteps(fsteps, fname_zarr)
samples = get_samples(samples, fname_zarr)
streams = get_streams(stream, fname_zarr)
_logger.info(f"Streams to process: {streams}")
for stream in streams:
grid_type = get_grid_type(data_type, stream, fname_zarr)
channels = get_channels(channels, stream, fname_zarr)
stream_channels = get_channels(channels, stream, fname_zarr)
source_starts, source_ends = get_source_info(fname_zarr, stream, samples)
# prevent overwritting of channels between streams
kwargs["grid_type"] = grid_type
kwargs["channels"] = channels
kwargs["channels"] = stream_channels
kwargs["data_type"] = data_type
kwargs["stream"] = stream

parser = CfParserFactory.get_parser(config=config, **kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def parse_args(args: list) -> argparse.Namespace:
parser.add_argument(
"--stream",
type=str,
choices=["ERA5", "CERRA", "MEPS", "NORA3", "IMERG_ANEMOI"],
choices=["ERA5", "ERA5pl", "ERA5ml", "CERRA", "MEPS", "NORA3", "IMERG_ANEMOI"],
help="Stream name to retrieve data for",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def process_sample(
da_fs = []

for result in fstep_iterator_results:
if result is None:
if result is None or (isinstance(result, xr.DataArray) and result.size == 0):
continue

# result is already a materialized xarray DataArray (built in the worker).
Expand Down Expand Up @@ -139,19 +139,25 @@ def reshape(self, data: xr.DataArray) -> xr.Dataset:
# Original logic
var_dict = find_pl(data.channel.values)
data_vars = {}

#order of appending upoints should be ipoint, pressure_level, mem (if mem exists)
for new_var, pls in var_dict.items():
data_dims = ["ipoint"]
if pls[0] is not None:
data_dims.append("pressure_level")
if "mem" in data.dims:
data_dims.append("mem")
old_vars = [f"{new_var}_{p}" for p in pls]
data_vars[new_var] = xr.DataArray(
data.sel(channel=old_vars).values,
dims=["ipoint", "pressure_level"],
dims=data_dims,
coords={"pressure_level": pls},
)
else:
if "mem" in data.dims:
data_dims.append("mem")
data_vars[new_var] = xr.DataArray(
data.sel(channel=new_var).values,
dims=["ipoint"],
dims=data_dims,
)

reshaped_dataset = xr.Dataset(data_vars)
Expand Down Expand Up @@ -471,7 +477,6 @@ def _build_coordinate_mapping(
f"Coordinate '{coord}' will be skipped for "
f"variable '{var_cfg.get('var', 'unknown')}'."
)

return coords

def _add_grid_attrs(self, ds: xr.Dataset, grid_info: dict | None = None) -> xr.Dataset:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def process_sample(
ref_time: np.datetime64,
source_interval_start: np.datetime64 = None,
source_interval_end: np.datetime64 = None,
**kwargs
):
"""
Process results from get_data_worker: reshape, concatenate, add metadata, and save.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,24 +205,26 @@ def reshape(self, data: xr.DataArray) -> xr.Dataset:
data_vars = {}

for new_var, pls in var_dict.items():
data_dims = ["ipoint"]
if "mem" in data.dims:
data_dims.append("mem")
if pls[0] is not None:
old_vars = [f"{new_var}_{p}" for p in pls]
data_vars[new_var] = xr.DataArray(
data.sel(channel=old_vars).values,
dims=["ipoint", "pressure_level"],
dims=[*data_dims, "pressure_level"],
coords={"pressure_level": pls},
)
else:
data_vars[new_var] = xr.DataArray(
data.sel(channel=new_var).values,
dims=["ipoint"],
dims=data_dims,
)

reshaped_dataset = xr.Dataset(data_vars)
reshaped_dataset = reshaped_dataset.assign_coords(
ipoint=data.coords["ipoint"],
)

# order using pressure_level coord
if "pressure_level" in reshaped_dataset.coords:
reshaped_dataset = reshaped_dataset.sortby("pressure_level")
Expand Down Expand Up @@ -544,7 +546,8 @@ def _attrs_gaussian_grid(self, ds: xr.Dataset) -> xr.Dataset:

coords = self._build_coordinate_mapping(ds, mapped_info, ds_attrs)

wg_unit = mapped_units.get(self.stream, "DEFAULT")
wg_unit = mapped_units.get(self.stream, mapped_units.get("DEFAULT", None))
print(wg_unit)
verif_unit = mapped_info.get("verif_unit", None)
if wg_unit != verif_unit:
# perform unit conversion
Expand Down
14 changes: 5 additions & 9 deletions packages/evaluate/src/weathergen/evaluate/export/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,9 @@ def gaussian_regular_da(self, data: xr.DataArray) -> xr.DataArray:
pos = dims.index("ncells")
dims[pos : pos + 1] = ["latitude", "longitude"]
dims = tuple(dims)
ordered_dims = (
["valid_time", "pressure", "latitude", "longitude"]
if len(dims) == 4
else ["valid_time", "latitude", "longitude"]
)
permutation_indices = [dims.index(o_dim) for o_dim in ordered_dims]
regridded_values = np.transpose(regridded_values, axes=permutation_indices)

regrid_data = xr.DataArray(
data=regridded_values, dims=ordered_dims, coords=new_coords, attrs=attrs, name=data.name
data=regridded_values, dims=dims, coords=new_coords, attrs=attrs, name=data.name
)

return regrid_data
Expand Down Expand Up @@ -549,7 +543,9 @@ def regrid_ds(
regrid_vars[var] = self.regrid_da(ds[var])
regrid_ds = xr.Dataset(regrid_vars)
regrid_ds = self.add_attrs(regrid_ds)

regrid_ds = regrid_ds.transpose(
"valid_time", "pressure", "latitude", "longitude", "mem", ..., missing_dims="ignore"
)
return regrid_ds

def regrid_da(self, da: xr.DataArray) -> xr.DataArray:
Expand Down
Loading