Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion config/evaluate/config_zarr2cf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ coordinates:
forecast_step: forecast_period
forecast_reference_time: forecast_reference_time
ncells: ncells
mem: mem
pl:
pressure_level: pressure
valid_time: valid_time
Expand All @@ -114,6 +115,7 @@ coordinates:
forecast_step: forecast_period
forecast_reference_time: forecast_reference_time
ncells: ncells
mem: mem

dimensions:
valid_time:
Expand Down Expand Up @@ -141,4 +143,9 @@ dimensions:
std_unit: hours
ncells:
wg: ncells
std: ncells
std: ncells
mem:
wg: mem
std: realization
long: realization
std_unit: 1
32 changes: 17 additions & 15 deletions packages/evaluate/src/weathergen/evaluate/export/export_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,23 @@ def get_data_worker(args: tuple) -> tuple[int, int, xr.DataArray]:
npoints = data_arr.shape[0]

# Handle optional ensemble dimension: squeeze it out if present.
if data_arr.ndim == 3 and data_arr.shape[2] == 1:
data_arr = data_arr[:, :, 0]

da_result = xr.DataArray(
data_arr,
dims=["ipoint", "channel"],
coords={
"ipoint": np.arange(npoints),
"channel": channels,
"forecast_step": fstep,
"valid_time": ("ipoint", times_arr),
"lat": ("ipoint", coords_arr[:, 0]),
"lon": ("ipoint", coords_arr[:, 1]),
},
)
if data_arr.ndim == 3:
if data_arr.shape[2] == 1:
data_arr = data_arr[:, :, 0]
data_dims = ["ipoint", "channel"]
else:
data_dims = ["ipoint", "channel", "mem"]

data_coords = {
"ipoint": np.arange(npoints),
"channel": channels,
"forecast_step": fstep,
"valid_time": ("ipoint", times_arr),
"lat": ("ipoint", coords_arr[:, 0]),
"lon": ("ipoint", coords_arr[:, 1]),
}

da_result = xr.DataArray(data_arr, dims=data_dims, coords=data_coords)

return (sample, fstep, da_result)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,20 @@ def reshape(self, data: xr.DataArray) -> xr.Dataset:
data_vars = {}

for new_var, pls in var_dict.items():
data_dims = ["ipoint"]
if "mem" in data.dims:
data_dims.append("mem")
if pls[0] is not None:
old_vars = [f"{new_var}_{p}" for p in pls]
data_vars[new_var] = xr.DataArray(
data.sel(channel=old_vars).values,
dims=["ipoint", "pressure_level"],
dims=[*data_dims, "pressure_level"],
coords={"pressure_level": pls},
)
else:
data_vars[new_var] = xr.DataArray(
data.sel(channel=new_var).values,
dims=["ipoint"],
dims=data_dims,
)

reshaped_dataset = xr.Dataset(data_vars)
Expand Down Expand Up @@ -471,7 +474,6 @@ def _build_coordinate_mapping(
f"Coordinate '{coord}' will be skipped for "
f"variable '{var_cfg.get('var', 'unknown')}'."
)

return coords

def _add_grid_attrs(self, ds: xr.Dataset, grid_info: dict | None = None) -> xr.Dataset:
Expand Down
14 changes: 5 additions & 9 deletions packages/evaluate/src/weathergen/evaluate/export/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,9 @@ def gaussian_regular_da(self, data: xr.DataArray) -> xr.DataArray:
pos = dims.index("ncells")
dims[pos : pos + 1] = ["latitude", "longitude"]
dims = tuple(dims)
ordered_dims = (
["valid_time", "pressure", "latitude", "longitude"]
if len(dims) == 4
else ["valid_time", "latitude", "longitude"]
)
permutation_indices = [dims.index(o_dim) for o_dim in ordered_dims]
regridded_values = np.transpose(regridded_values, axes=permutation_indices)

regrid_data = xr.DataArray(
data=regridded_values, dims=ordered_dims, coords=new_coords, attrs=attrs, name=data.name
data=regridded_values, dims=dims, coords=new_coords, attrs=attrs, name=data.name
)

return regrid_data
Expand Down Expand Up @@ -549,7 +543,9 @@ def regrid_ds(
regrid_vars[var] = self.regrid_da(ds[var])
regrid_ds = xr.Dataset(regrid_vars)
regrid_ds = self.add_attrs(regrid_ds)

regrid_ds = regrid_ds.transpose(
"valid_time", "pressure", "latitude", "longitude", "mem", ..., missing_dims="ignore"
)
return regrid_ds

def regrid_da(self, da: xr.DataArray) -> xr.DataArray:
Expand Down
Loading