diff --git a/examples/weather/forecast_solar_runs_comparison.py b/examples/weather/forecast_solar_runs_comparison.py new file mode 100644 index 0000000..b9ac0a5 --- /dev/null +++ b/examples/weather/forecast_solar_runs_comparison.py @@ -0,0 +1,165 @@ +"""Compare recent solar-radiation forecast runs at a point, valid-time aligned. + +Plots 1h surface downwelling shortwave flux (SSRD) at Zurich for the most recent +runs of Helios and ICON-EU. Every run is drawn against valid time, colored by +model with a light->dark gradient over init time (older runs lighter). + +A black "ground truth" line is built from each Helios run's T+1h value (valid at +init + 1h), stitched across runs to approximate the analysis. + +This uses cheap single-point queries, so it runs against the live latest +forecasts without needing a raised credit limit. +""" + +import logging +from datetime import timedelta + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from matplotlib.lines import Line2D + +from jua import JuaClient +from jua.types.geo import LatLon +from jua.weather import Models, Variables + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +ZURICH = LatLon(lat=47.3769, lon=8.5417, label="Zurich") +VARIABLE = Variables.SURFACE_DOWNWELLING_SHORTWAVE_FLUX_SUM_1H + +# Only keep runs initialized within this many hours of the latest available run. +WINDOW_HOURS = 6 + +# (label, matplotlib colormap) per model. +MODEL_CONFIG = { + Models.EPT2_HELIOS: ("Helios", "Oranges"), + Models.ICON_EU: ("ICON-EU", "Purples"), +} + + +def _naive_utc(ts) -> pd.Timestamp: + """Normalize a timestamp to tz-naive UTC.""" + t = pd.Timestamp(ts) + if t.tzinfo is not None: + t = t.tz_convert("UTC").tz_localize(None) + return t + + +def available_init_times(model_obj) -> list[pd.Timestamp]: + """Sorted (ascending) tz-naive init times available for a model.""" + available = model_obj.get_available_forecasts(limit=50) + return sorted(_naive_utc(f.init_time) for f in available.forecasts) + + +def ground_truth(model_obj, init_times, day_start, day_end): + """Stitch each run's T+1h value into a pseudo ground-truth series.""" + inits = [t for t in init_times if day_start <= t + timedelta(hours=1) <= day_end] + if not inits: + return None, None + + forecast = model_obj.get_forecasts( + init_time=[t.to_pydatetime() for t in inits], + points=ZURICH, + variables=[VARIABLE], + prediction_timedelta=[np.timedelta64(60, "m")], + stream=False, + ) + da = forecast[VARIABLE].squeeze() + valid_times = pd.to_datetime(da["init_time"].values) + pd.Timedelta(minutes=60) + values = np.asarray(da.values).ravel() + order = np.argsort(valid_times) + return valid_times[order], values[order] + + +def main(): + client = JuaClient() + + helios = client.weather.get_model(Models.EPT2_HELIOS) + helios_inits = available_init_times(helios) + if not helios_inits: + logger.warning("No Helios runs available; nothing to plot.") + return + + # Frame the chart on the day of the latest Helios run. + latest_init = helios_inits[-1] + day_start = latest_init.normalize() + day_end = day_start + timedelta(days=1) + + fig, ax = plt.subplots(figsize=(14, 7)) + legend_handles = [] + + for model_enum, (label, cmap_name) in MODEL_CONFIG.items(): + model_obj = client.weather.get_model(model_enum) + init_times = available_init_times(model_obj) + runs = [ + t for t in init_times if t >= init_times[-1] - timedelta(hours=WINDOW_HOURS) + ] + if not runs: + logger.warning(f"No runs found for {label}") + continue + + cmap = plt.get_cmap(cmap_name) + n = len(runs) + logger.info(f"{label}: {n} runs from {runs[0]} to {runs[-1]}") + + for i, init_time in enumerate(runs): + max_lead = int((day_end - init_time).total_seconds() / 3600) + 1 + if max_lead <= 0: + continue + + forecast = model_obj.get_forecasts( + init_time=init_time.to_pydatetime(), + points=ZURICH, + variables=[VARIABLE], + max_lead_time=min(48, max_lead), + stream=False, + ) + da = forecast[VARIABLE].to_absolute_time().squeeze() + times = pd.to_datetime(da["time"].values) + values = np.asarray(da.values).ravel() + + mask = times <= day_end + times, values = times[mask], values[mask] + if len(times) == 0: + continue + + # Older runs lighter, newest darkest. + shade = 0.35 + 0.6 * (i / max(n - 1, 1)) + ax.plot(times, values, color=cmap(shade), linewidth=1.6, alpha=0.9) + + legend_handles.append( + Line2D([0], [0], color=cmap(0.85), linewidth=2.5, label=label) + ) + + # Pseudo ground truth: Helios T+1h stitched across runs. + gt_times, gt_values = ground_truth(helios, helios_inits, day_start, day_end) + if gt_times is not None: + ax.plot(gt_times, gt_values, color="black", linewidth=2.8, zorder=10) + legend_handles.append( + Line2D( + [0], + [0], + color="black", + linewidth=2.8, + label="Ground truth (Helios T+1h)", + ) + ) + + ax.set_xlim(day_start, day_end) + ax.set_xlabel("Valid time (UTC)") + ax.set_ylabel(VARIABLE.display_name_with_unit) + ax.set_title( + f"SSRD 1h at {ZURICH.label} — recent runs (last {WINDOW_HOURS}h), " + "gradient = init time (light=older)" + ) + ax.legend(handles=legend_handles) + ax.grid(True, alpha=0.3) + fig.autofmt_xdate() + plt.tight_layout() + plt.show() + + +if __name__ == "__main__": + main() diff --git a/src/jua/weather/_model_meta.py b/src/jua/weather/_model_meta.py index 434497e..353df12 100644 --- a/src/jua/weather/_model_meta.py +++ b/src/jua/weather/_model_meta.py @@ -9,17 +9,19 @@ class TemporalResolution: """Internal class to store model temporal resolution Used for models with variable temporal resolution, such as EPT2. + Resolutions are expressed in hours and may be fractional (e.g. ``0.5`` + for a 30-minute cadence such as EPT2 Helios). Attributes: - default: The default temporal resolution for the model. - segments: The resolution of the model for prediction_timedelta ranges. + base: The default temporal resolution for the model, in hours. + special: The resolution of the model for prediction_timedelta ranges. Defined as `(resolution, from_hour, to_hour)`, where the model has a prediction every `resolution` hours when the prediction_timedelta is in the interval [`from_hour`, `to_hour`]. """ - base: int - special: tuple[tuple[int, int, int], ...] = tuple() + base: float + special: tuple[tuple[float, int, int], ...] = tuple() def __post_init__(self) -> None: """Checks that the special cases make sense""" @@ -45,6 +47,9 @@ def __post_init__(self) -> None: def num_prediction_timedeltas(self, from_hour: int, to_hour: int) -> int: """Determines the number of `prediction_timedeltas` in an interval. + Iterates internally in minutes so that sub-hourly resolutions + (e.g. a 30-minute cadence) are counted correctly. + Attributes: from_hour: The start hour for the interval to_hour: The end hour for the interval @@ -56,13 +61,15 @@ def num_prediction_timedeltas(self, from_hour: int, to_hour: int) -> int: ) num_timedeltas = 0 - for h in range(from_hour, to_hour + 1): + for minute in range(from_hour * 60, to_hour * 60 + 1): + hour = minute / 60 resolution = self.base for s_res, s_start, s_end in self.special: - if s_start <= h <= s_end: + if s_start <= hour <= s_end: resolution = s_res break - if h % resolution == 0: + resolution_minutes = round(resolution * 60) + if minute % resolution_minutes == 0: num_timedeltas += 1 return num_timedeltas @@ -153,6 +160,18 @@ class ModelMetaInfo: full_forecasted_hours=480, temporal_resolution=TemporalResolution(base=6, special=((1, 0, 10 * 24),)), ) +_MODEL_META_INFO[Models.EPT2_HELIOS] = ModelMetaInfo( + has_grid_access=True, + full_forecasted_hours=48, + forecasts_per_day=48, + temporal_resolution=TemporalResolution(base=0.5), +) +_MODEL_META_INFO[Models.EPT2_EUROPA] = ModelMetaInfo( + has_grid_access=True, + full_forecasted_hours=48, + forecasts_per_day=24, + temporal_resolution=TemporalResolution(base=1), +) _MODEL_META_INFO[Models.AIFS] = ModelMetaInfo( has_grid_access=True, forecast_name_mapping="aifs", diff --git a/src/jua/weather/models.py b/src/jua/weather/models.py index 97cc15f..c0086a9 100644 --- a/src/jua/weather/models.py +++ b/src/jua/weather/models.py @@ -27,6 +27,8 @@ class Models(str, Enum): EPT2_HRRR = "ept2_hrrr" EPT2_RR = "ept2_rr" EPT2_REASONING = "ept2_reasoning" + EPT2_HELIOS = "ept2_1_helios" + EPT2_EUROPA = "ept2_1_europa" AIFS = "aifs" AURORA = "aurora" ECMWF_IFS_SINGLE = "ecmwf_ifs_single" diff --git a/src/jua/weather/variables.py b/src/jua/weather/variables.py index 94b419a..e2982cf 100644 --- a/src/jua/weather/variables.py +++ b/src/jua/weather/variables.py @@ -200,10 +200,18 @@ class Variables(Enum): "surface_downwelling_shortwave_flux_sum_1h", "J / m^2", "ssrd", None ) + SURFACE_DOWNWELLING_SHORTWAVE_FLUX_SUM_30MIN = Variable( + "surface_downwelling_shortwave_flux_sum_30min", "J / m^2", None, None + ) + SURFACE_DIRECT_DOWNWELLING_SHORTWAVE_FLUX_SUM_1H = Variable( "surface_direct_downwelling_shortwave_flux_sum_1h", "J / m^2", "fdir", None ) + SURFACE_DIRECT_DOWNWELLING_SHORTWAVE_FLUX_SUM_30MIN = Variable( + "surface_direct_downwelling_shortwave_flux_sum_30min", "J / m^2", None, None + ) + SURFACE_NET_DOWNWARD_SHORTWAVE_FLUX_SUM_1H = Variable( "surface_net_downward_shortwave_flux_sum_1h", "J / m^2", "ssr", None ) diff --git a/tests/functional/test_forecasts.py b/tests/functional/test_forecasts.py index 8c3031e..57cb0dd 100644 --- a/tests/functional/test_forecasts.py +++ b/tests/functional/test_forecasts.py @@ -38,8 +38,14 @@ Models.ICON_EU: datetime(2026, 2, 9, 0, 0, 0), } -ALL_MODELS = list(Models) -INTERNAL_MODELS = [m for m in Models if get_model_meta_info(m).has_grid_access] +SOLAR_ONLY_MODELS = {Models.EPT2_HELIOS} + +ALL_MODELS = [m for m in Models if m not in SOLAR_ONLY_MODELS] +INTERNAL_MODELS = [ + m + for m in Models + if get_model_meta_info(m).has_grid_access and m not in SOLAR_ONLY_MODELS +] def get_forecast_date(model: Models) -> datetime: diff --git a/tests/weather/test_temporal_resolution.py b/tests/weather/test_temporal_resolution.py index f473c06..38e54f6 100644 --- a/tests/weather/test_temporal_resolution.py +++ b/tests/weather/test_temporal_resolution.py @@ -13,11 +13,13 @@ ((1, 0, 24), (3, 24, 48)), [((0, 48), 33), ((0, 72), 37)], ), + # Sub-hourly (30min) resolution: 2 steps per hour + (0.5, tuple(), [((0, 1), 3), ((0, 2), 5), ((0, 48), 97)]), ], ) def test_temporal_resolution( - base: int, - special: tuple[tuple[int, int, int]], + base: float, + special: tuple[tuple[float, int, int]], test_cases: list[tuple[int, int], int], ) -> None: tr = TemporalResolution(base=base, special=special)