Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions config/config_jepa_multi_data_ft_forecast.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
streams_directory: "./config/streams/jepa_forecast_multi_data_od/"
streams_directory: "./config/streams/jepa_forecast_multi_data_od_ckpt_order/"

freeze_modules: "" # "^(?!.*ERA5)(?=.*(?:encoder|latent_pre_norm|latent_heads)).*$"

general:

# mutable parameters
istep: 0
rank: ???
world_size: ???

training_config:
num_mini_epochs: 32
start_date: 2016-01-01T00:00
end_date: 2022-12-31T00:00

num_mini_epochs: 12
samples_per_mini_epoch: 8192
2 changes: 1 addition & 1 deletion config/streams/jepa_forecast_multi_data_all_years/era5.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ ERA5:
type: anemoi
filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] #['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr']
# filenames: # ['aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr']
stream_id: 0
stream_id: 42
source: ['q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 't_600', 't_700', 't_850', 't_925', 't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000', '10u', '10v', '2d', '2t', 'msl']
target: ['q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 'q_1000', 't_50', 't_100', 't_150', 't_200', 't_250', 't_300', 't_400', 't_500', 't_600', 't_700', 't_850', 't_925', 't_1000', 'u_50', 'u_100', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'u_1000', 'v_50', 'v_100', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'v_1000', 'z_50', 'z_100', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925', 'z_1000', '10u', '10v', '2d', '2t', 'msl']
geoinfo_channels: ['z', 'lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'noise_time']
Expand Down
39 changes: 39 additions & 0 deletions config/streams/jepa_forecast_multi_data_od_ckpt_order/analysis.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
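# Operational-analysis input stream (ERA5_in) for temporal JEPA.
# This mirrors era5_georing_avhrr_forecast_random_inputs/era5.yml, but reads the
# operational analysis dataset (aifs-od-an-oper-...) instead of ERA5.
# NOTE(review): this stream sets `forcing: True` below, so the earlier remark that it
# "is not marked forcing so the teacher metadata mask remains populated for SSL
# targets" does not describe this file as written -- confirm which is intended.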

ERA5_in:
type: anemoi
# filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] #['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr']
filenames: ['aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr']
stream_id: 0
source: ['10u', '10v', '2d', '2t', 'msl', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925', 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925']
target: ['10u', '10v', '2d', '2t', 'msl', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925', 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925']
geoinfo_channels: ['cos_julian_day', 'cos_local_time', 'insolation', 'lsm', 'noise_time', 'sdor', 'sin_julian_day', 'sin_local_time', 'slor', 'z']
loss_weight: 1.0
forcing: True
location_weight: cosine_latitude
token_size: 8
tokenize_spacetime: True
max_num_targets: -1
frequency: 06:00:00
nominal_time_mapping :
"0" : 5 # 04:30:00
"6" : 9 # 09:00:00
"12" : 17 #16:30:00
"18" : 21 #21:00:00
embed:
net: transformer
num_tokens: 1
num_heads: 8
dim_embed: 512
num_blocks: 2
embed_target_coords:
net: linear
dim_embed: 256
target_readout:
num_layers: 2
num_heads: 4
pred_head:
ens_size: 1
num_layers: 1
28 changes: 28 additions & 0 deletions config/streams/jepa_forecast_multi_data_od_ckpt_order/avhrr.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Polar-orbiting observation stream for temporal JEPA.

METOP_ABC_AVHRR_IASI:
type: obs
stream_id: 20
filenames: ['observations-ea-ofb-0001-2007-2021-metop-a-iasi-radiances-v1.zarr', 'observations-ea-ofb-0001-2013-2023-metop-b-iasi-radiances-v1.zarr', 'observations-ea-ofb-0001-2019-2023-metop-c-iasi-radiances-v1.zarr']
geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'noise_time']
source: ['obsvalue_avhrr_mean_vis_0', 'obsvalue_rawbt_16', 'obsvalue_rawbt_63', 'obsvalue_rawbt_138', 'obsvalue_rawbt_170', 'obsvalue_rawbt_185', 'obsvalue_rawbt_224', 'obsvalue_rawbt_249', 'obsvalue_rawbt_271', 'obsvalue_rawbt_445', 'obsvalue_rawbt_756', 'obsvalue_rawbt_867', 'obsvalue_rawbt_921', 'obsvalue_rawbt_2907', 'obsvalue_rawbt_2991', 'obsvalue_rawbt_3093', 'obsvalue_rawbt_3160', 'obsvalue_rawbt_5383']
target: ['obsvalue_avhrr_mean_vis_0', 'obsvalue_rawbt_16', 'obsvalue_rawbt_63', 'obsvalue_rawbt_138', 'obsvalue_rawbt_170', 'obsvalue_rawbt_185', 'obsvalue_rawbt_224', 'obsvalue_rawbt_249', 'obsvalue_rawbt_271', 'obsvalue_rawbt_445', 'obsvalue_rawbt_756', 'obsvalue_rawbt_867', 'obsvalue_rawbt_921', 'obsvalue_rawbt_2907', 'obsvalue_rawbt_2991', 'obsvalue_rawbt_3093', 'obsvalue_rawbt_3160', 'obsvalue_rawbt_5383']
forcing: True
loss_weight: 1.0
token_size: 512
tokenize_spacetime: False
embed:
net: transformer
num_tokens: 1
num_heads: 2
dim_embed: 256
num_blocks: 2
embed_target_coords:
net: linear
dim_embed: 256
target_readout:
num_layers: 2
num_heads: 4
pred_head:
ens_size: 1
num_layers: 1
36 changes: 36 additions & 0 deletions config/streams/jepa_forecast_multi_data_od_ckpt_order/era5.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
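# ERA5 stream for temporal JEPA.
# This mirrors era5_georing_avhrr_forecast_random_inputs/era5.yml, but is not marked
# forcing so the teacher metadata mask remains populated for SSL targets.
# NOTE(review): the stream is marked `diagnostic: True` below; confirm that it is
# still meant to be described as an input stream.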

ERA5:
type: anemoi
filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2024-1h-v3-with-era51.zarr'] #['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr']
# filenames: ['aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr']
stream_id: 42
source: ['10u', '10v', '2d', '2t', 'msl', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925', 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925']
target: ['10u', '10v', '2d', '2t', 'msl', 'q_1000', 'q_150', 'q_200', 'q_250', 'q_300', 'q_400', 'q_500', 'q_600', 'q_700', 'q_850', 'q_925', 't_100', 't_1000', 't_150', 't_200', 't_250', 't_300', 't_400', 't_50', 't_500', 't_600', 't_700', 't_850', 't_925', 'u_100', 'u_1000', 'u_150', 'u_200', 'u_250', 'u_300', 'u_400', 'u_50', 'u_500', 'u_600', 'u_700', 'u_850', 'u_925', 'v_100', 'v_1000', 'v_150', 'v_200', 'v_250', 'v_300', 'v_400', 'v_50', 'v_500', 'v_600', 'v_700', 'v_850', 'v_925', 'z_100', 'z_1000', 'z_150', 'z_200', 'z_250', 'z_300', 'z_400', 'z_50', 'z_500', 'z_600', 'z_700', 'z_850', 'z_925']
geoinfo_channels: ['cos_julian_day', 'cos_local_time', 'insolation', 'lsm', 'noise_time', 'sdor', 'sin_julian_day', 'sin_local_time', 'slor', 'z']
loss_weight: 1.0
diagnostic: True
location_weight: cosine_latitude
token_size: 8
tokenize_spacetime: False
max_num_targets: -1
frequency: 01:00:00
embed:
net: transformer
num_tokens: 1
num_heads: 8
dim_embed: 512
num_blocks: 2
embed_target_coords:
net: linear
dim_embed: 256
target_readout:
num_layers: 2
num_heads: 4
pred_head:
ens_size: 1
num_layers: 1


140 changes: 140 additions & 0 deletions config/streams/jepa_forecast_multi_data_od_ckpt_order/geos.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
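# Geostationary observation streams for temporal JEPA.
# These mirror era5_georing_avhrr_forecast_random_inputs/geos.yml.
# NOTE(review): every stream in this file sets `forcing: True`, so the earlier remark
# "with forcing removed so teacher masks are available to the SSL loss" does not
# describe this file as written -- confirm which is intended.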

METEOSAT_SEVIRI_IR:
type: obs
stream_id: 10
filenames: ['observations-file-2014-2024-seviri-o256-wegen-v3.zarr']
geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza', 'noise_time']
source: ['obsvalue_rawbt_065', 'obsvalue_rawbt_086', 'obsvalue_rawbt_160', 'obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133']
target: ['obsvalue_rawbt_065', 'obsvalue_rawbt_086', 'obsvalue_rawbt_160', 'obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133']
forcing: True
loss_weight: 1.0
token_size: 1024
tokenize_spacetime: False
max_num_targets: 262144
embed:
net: transformer
num_tokens: 1
num_heads: 4
dim_embed: 512
num_blocks: 2
embed_target_coords:
net: linear
dim_embed: 512
target_readout:
num_layers: 2
num_heads: 4
pred_head:
ens_size: 1
num_layers: 1

GOES_ABI_IR:
type: obs
stream_id: 11
filenames: ['observations-file-2017-2024-abi-goes16-IR-o256-v2.zarr']
geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza', 'noise_time']
source: ['obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133']
target: ['obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133']
forcing: True
loss_weight: 1.0
token_size: 1024
tokenize_spacetime: False
max_num_targets: 262144
embed:
net: transformer
num_tokens: 1
num_heads: 4
dim_embed: 512
num_blocks: 2
embed_target_coords:
net: linear
dim_embed: 512
target_readout:
num_layers: 2
num_heads: 4
pred_head:
ens_size: 1
num_layers: 1

HIMAWARI_AHI_IR:
type: obs
stream_id: 12
filenames: ['observations-file-2015-2022-himawari8-IR-o256-v1.zarr', 'observations-file-2022-2024-himawari9-IR-o256-v1.zarr']
geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza', 'noise_time']
source: ['obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133']
target: ['obsvalue_rawbt_380', 'obsvalue_rawbt_620', 'obsvalue_rawbt_730', 'obsvalue_rawbt_850', 'obsvalue_rawbt_960', 'obsvalue_rawbt_105', 'obsvalue_rawbt_120', 'obsvalue_rawbt_133']
forcing: True
loss_weight: 1.0
token_size: 1024
tokenize_spacetime: False
max_num_targets: 262144
embed:
net: transformer
num_tokens: 1
num_heads: 4
dim_embed: 512
num_blocks: 2
embed_target_coords:
net: linear
dim_embed: 512
target_readout:
num_layers: 2
num_heads: 4
pred_head:
ens_size: 1
num_layers: 1

GOES_ABI_VIS:
type: obs
stream_id: 13
filenames: ['observations-file-2017-2024-abi-goes16-VIS-o256-v2.zarr']
geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza', 'noise_time']
forcing: True
loss_weight: 1.0
token_size: 1024
tokenize_spacetime: False
max_num_targets: 262144
embed:
net: transformer
num_tokens: 1
num_heads: 4
dim_embed: 512
num_blocks: 2
embed_target_coords:
net: linear
dim_embed: 512
target_readout:
num_layers: 2
num_heads: 4
pred_head:
ens_size: 1
num_layers: 1

HIMAWARI_AHI_VIS:
type: obs
stream_id: 14
filenames: ['observations-file-2015-2022-himawari8-VIS-o256-v1.zarr', 'observations-file-2022-2024-himawari9-VIS-o256-v1.zarr']
geoinfo_channels: ['zenith', 'noise_time']
# geoinfo_channels: ['cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day', 'zenith', 'cos_sza', 'noise_time']
forcing: True
loss_weight: 1.0
token_size: 1024
tokenize_spacetime: False
max_num_targets: 262144
embed:
net: transformer
num_tokens: 1
num_heads: 4
dim_embed: 512
num_blocks: 2
embed_target_coords:
net: linear
dim_embed: 512
target_readout:
num_layers: 2
num_heads: 4
pred_head:
ens_size: 1
num_layers: 1
61 changes: 38 additions & 23 deletions src/weathergen/datasets/data_reader_anemoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,28 +259,43 @@ def select_channels(self, ds0: anemoi_datasets, ch_type: str) -> NDArray[np.int6

channels = self.stream_info.get(ch_type)
channels_exclude = self.stream_info.get(ch_type + "_exclude", [])
stream_name = self.stream_info["name"]
# sanity check
is_empty = len(channels) == 0 if channels is not None else False
if is_empty:
stream_name = self.stream_info["name"]
if channels is not None and len(channels) == 0:
_logger.warning(f"No channel for {stream_name} for {ch_type}.")

chs_idx = np.sort(
[
ds0.name_to_index[k]
for (k, v) in ds0.typed_variables.items()
if (
not v.is_computed_forcing
and not v.is_constant_in_time
and (
np.array([f == k for f in channels]).any() if channels is not None else True
)
and not np.array([f == k for f in channels_exclude]).any()
# Variables eligible for selection: physical fields only (no computed forcings,
# no constants-in-time), and not explicitly excluded.
eligible = {
k
for k, v in ds0.typed_variables.items()
if not v.is_computed_forcing
and not v.is_constant_in_time
and not any(ex == k for ex in channels_exclude)
}

if channels is not None:
# Respect the order given in the stream config so the channel layout is identical
# across datasets that share channels (e.g. ERA5 vs operational analysis),
# regardless of each dataset's on-disk variable order.
seen: set[str] = set()
ordered = []
for k in channels:
if k in eligible and k not in seen:
ordered.append(k)
seen.add(k)
missing = [k for k in channels if k not in eligible]
if missing:
_logger.warning(
f"{stream_name}: requested {ch_type} channels not available "
f"(excluded/forcing/constant-in-time or absent), skipped: {missing}"
)
]
)
else:
# No explicit selection: fall back to deterministic lexicographic order so the
# layout is still reproducible across datasets.
ordered = sorted(eligible)

return np.array(chs_idx, dtype=np.int64)
return np.array([ds0.name_to_index[k] for k in ordered], dtype=np.int64)

def select_geoinfo_channels(self, ds0: anemoi_datasets) -> NDArray[np.int64]:
"""
Expand All @@ -302,19 +317,19 @@ def select_geoinfo_channels(self, ds0: anemoi_datasets) -> NDArray[np.int64]:
if len(geoinfo_channels) == 0:
return np.array([], dtype=np.int64)

# Select channels that match the geoinfo list (exact match required)
chs_idx = np.sort(
[ds0.name_to_index[k] for k in ds0.typed_variables.keys() if k in geoinfo_channels]
)
# Select channels that match the geoinfo list (exact match required), preserving the
# order requested in the config so the geoinfo layout is dataset-independent.
available = set(ds0.typed_variables.keys())
ordered = [k for k in geoinfo_channels if k in available]

if len(chs_idx) == 0 and len(geoinfo_channels) > 0:
if len(ordered) == 0:
stream_name = self.stream_info["name"]
_logger.warning(
f"No matching geoinfo channels found for {stream_name}. "
f"Requested: {geoinfo_channels}"
)

return np.array(chs_idx, dtype=np.int64)
return np.array([ds0.name_to_index[k] for k in ordered], dtype=np.int64)


def _clip_lat(lats: NDArray) -> NDArray[np.float32]:
Expand Down
Loading
Loading