From 4e7ec4774fc7d18a0e0365b8694146eb1e20a1af Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Wed, 3 Jun 2026 15:06:03 +0200 Subject: [PATCH] Fixed IO for forcing --- packages/common/src/weathergen/common/io.py | 13 ++++++------- src/weathergen/utils/validation_io.py | 20 +++++++------------- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 0431b4698..407239dbb 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -182,9 +182,6 @@ def combine(cls, others: list["IOReaderData"]) -> "IOReaderData": assert other.geoinfos.shape[0] == n_datapoints, "number of datapoints do not match" assert other.datetimes.shape[0] == n_datapoints, "number of datapoints do not match" - if n_datapoints == 0: - continue - coords = np.concatenate([coords, other.coords]) geoinfos = np.concatenate([geoinfos, other.geoinfos]) data = np.concatenate([data, other.data]) @@ -361,9 +358,9 @@ def __init__( def _append_dataset(self, dataset: OutputDataset | None, name: str) -> None: if dataset: self.datasets.append(dataset) - else: - msg = f"Missing {name} dataset for item: {self.key.path}" - raise ValueError(msg) + # else: + # msg = f"Missing {name} dataset for item: {self.key.path}" + # raise ValueError(msg) class ZarrIO: @@ -739,11 +736,13 @@ def _extract_coordinates(self, stream_idx, offset_key, datapoints) -> DataCoordi def _extract_sources( self, sample: int, stream_idx: int, key: ItemKey, source_interval: TimeRange - ) -> OutputDataset: + ) -> OutputDataset | None: channels = self.source_channels[stream_idx] geoinfo_channels = self.geoinfo_channels[stream_idx] source: IOReaderData = self.sources[sample][stream_idx] + if source is None: + return None assert source.data.shape[1] == len(channels), ( f"Number of source channel names {len(channels)} does not align with source data." diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index b71c3fe5b..47fb4466e 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -57,19 +57,12 @@ def write_output( # handle spoof data: do not write since it might corrupt validation (spoofing invisible # there) if target_aux_out.physical[t_idx][sname]["is_spoof"][0]: - preds = model_output.get_physical_prediction(t_idx, sname) - # handle forcing streams or if sample is empty - if preds is None: - targets = target_aux_out.physical[t_idx][sname]["target"] - # preds are empty so create copy of target and add ensemble dimension - assert targets[0].shape[0] == 0, "Empty preds but non-empty targets." - preds = [target.clone().unsqueeze(0) for target in targets] - preds_shape = preds[0].shape + targets = target_aux_out.physical[t_idx][sname]["target"] # for-loop to make sure we have a consistent number of samples - preds_s = [np.zeros((preds_shape[0], 0, preds_shape[2])) for _ in preds] - targets_s = [np.zeros((0, preds_shape[2])) for _ in preds] - t_coords_s = [np.zeros((0, 2)) for _ in preds] - t_times_s = [np.array([]).astype("datetime64[ns]") for _ in preds] + preds_s = [np.zeros((1, 0, t.shape[1])) for t in targets] + targets_s = [np.zeros((0, t.shape[1])) for t in targets] + t_coords_s = [np.zeros((0, 2)) for t in targets] + t_times_s = [np.array([]).astype("datetime64[ns]") for t in targets] else: preds = model_output.get_physical_prediction(t_idx, sname) @@ -138,7 +131,8 @@ def write_output( output_stream_names = stream_names output_streams = {name: stream_names.index(name) for name in output_stream_names} - _logger.debug(f"Using output streams: {output_streams} from streams: {stream_names}") + if batch_idx == 0: + _logger.info(f"Using output streams: {output_streams} from streams: {stream_names}") target_channels: list[list[str]] = [list(stream.val_target_channels) for stream in cf.streams] source_channels: list[list[str]] = [list(stream.val_source_channels) for stream in cf.streams]