diff --git a/examples/power_forecast/power_forecast_example.py b/examples/power_forecast/power_forecast_example.py
index 1330230..b7d7e36 100644
--- a/examples/power_forecast/power_forecast_example.py
+++ b/examples/power_forecast/power_forecast_example.py
@@ -72,6 +72,19 @@ def main():
         end_time=now,
     )
     print(ds3)
+    print()
+
+    # --- Stitched day-ahead series (by init hour) ---
+    if zones:
+        print(f"Stitched day-ahead series for {zones[0]} Solar, runs at 09:00 UTC")
+        ds4 = pf.get_day_ahead_timeseries(
+            zone_keys=[zones[0]],
+            psr_types=["Solar"],
+            init_hour=9,  # e.g. D-1 09:00
+            time_zone="UTC",
+            max_init_times=10,  # init-times listing limit, filtered to 09:00 runs
+        )
+        print(ds4)
 
 
 if __name__ == "__main__":
diff --git a/src/jua/power_forecast/power_forecast.py b/src/jua/power_forecast/power_forecast.py
index c5611d4..a4b0418 100644
--- a/src/jua/power_forecast/power_forecast.py
+++ b/src/jua/power_forecast/power_forecast.py
@@ -4,8 +4,9 @@
 
 import re
 from dataclasses import dataclass
-from datetime import datetime
+from datetime import datetime, timedelta
 from typing import TYPE_CHECKING
+from zoneinfo import ZoneInfo
 
 import pandas as pd
 import xarray as xr
@@ -292,6 +293,314 @@ def get_data(
         except Exception as e:
             raise RuntimeError(f"Failed to fetch power forecast data: {e}") from e
 
+    # ------------------------------------------------------------------
+    # Stitched day-ahead time series
+    # ------------------------------------------------------------------
+    def get_day_ahead_timeseries(
+        self,
+        *,
+        zone_keys: list[str],
+        psr_types: list[str] | None = None,
+        init_hour: int,
+        time_zone: str = "UTC",
+        start_date: datetime | None = None,
+        end_date: datetime | None = None,
+        max_init_times: int = 365,
+    ) -> xr.Dataset:
+        """Return a continuous day-ahead time series stitched across runs.
+
+        This helper selects the forecast runs whose local init-time hour matches
+        ``init_hour`` (interpreted in ``time_zone``), takes from each run the
+        day-ahead window, and concatenates the results into a single continuous
+        ``time`` axis.
+
+        The day-ahead window is defined by the forecast lead range:
+        ``[s, s + 24)`` hours from the init time, with
+        ``s = (24 - init_hour) % 24``. For example, for ``init_hour = 9`` the
+        selected window is ``[15h, 39h)`` from each init (i.e., 00:00..23:00 of
+        D when the run is at D-1 09:00). For ``init_hour = 0`` the window is
+        ``[0h, 24h)``, i.e. the run's own day.
+
+        Two selection modes are supported:
+
+        **Date-range mode** (``start_date`` and/or ``end_date`` given):
+            The daily ``init_hour`` runs spanning the range are constructed
+            directly and fetched in a single request. This bypasses the
+            init-times listing limit, so arbitrarily long histories (e.g. a full
+            year) can be stitched. ``start_date``/``end_date`` bound the
+            resulting *valid time* axis.
+
+        **Latest mode** (no dates):
+            The most recent init times are listed via the init-times endpoint
+            (requested with ``limit=max_init_times``) and then filtered to the
+            runs whose local hour equals ``init_hour``.
+
+        Args:
+            zone_keys: Zone codes to query (e.g. ``["DE"]``).
+            psr_types: Optional PSR types to include
+                (e.g. ``["Solar"]``). If ``None``, returns all available types.
+            init_hour: Local hour-of-day (0..23) of the runs to stitch together.
+            time_zone: IANA time zone used to interpret ``init_hour`` when
+                matching runs (default ``"UTC"``).
+            start_date: Inclusive lower bound on the valid-time axis. Enables
+                date-range mode. Naive datetimes are interpreted in
+                ``time_zone``.
+            end_date: Exclusive upper bound on the valid-time axis. Enables
+                date-range mode. Naive datetimes are interpreted in
+                ``time_zone``.
+            max_init_times: ``limit`` passed to the init-times listing in
+                latest mode. The hour filter is applied to that listing
+                afterwards, so fewer runs than this may be stitched.
+
+        Returns:
+            ``xarray.Dataset`` with dims ``(zone_key, psr_type, time)`` and
+            variable ``value`` (MW). The series is continuous across days.
+        """
+        if not (0 <= init_hour <= 23):
+            raise ValueError("init_hour must be in the range 0..23")
+        if not zone_keys or not isinstance(zone_keys, list):
+            raise ValueError("zone_keys must be a non-empty list of zone codes")
+
+        # Compute lead range in minutes for the day-ahead slice
+        start_lead_hours = (24 - init_hour) % 24
+        end_lead_hours = start_lead_hours + 24
+        end_lead_minutes = int(end_lead_hours * 60)
+
+        tz = ZoneInfo(time_zone)
+
+        if start_date is not None or end_date is not None:
+            df = self._fetch_day_ahead_by_date_range(
+                zone_keys=zone_keys,
+                psr_types=psr_types,
+                init_hour=init_hour,
+                tz=tz,
+                time_zone=time_zone,
+                start_date=start_date,
+                end_date=end_date,
+                end_lead_minutes=end_lead_minutes,
+            )
+        else:
+            df = self._fetch_day_ahead_latest(
+                zone_keys=zone_keys,
+                psr_types=psr_types,
+                init_hour=init_hour,
+                tz=tz,
+                time_zone=time_zone,
+                end_lead_minutes=end_lead_minutes,
+                max_init_times=max_init_times,
+            )
+
+        return self._stitch_day_ahead(
+            df,
+            start_lead_hours=start_lead_hours,
+            end_lead_hours=end_lead_hours,
+            tz=tz,
+            start_date=start_date,
+            end_date=end_date,
+        )
+
+    def _fetch_day_ahead_latest(
+        self,
+        *,
+        zone_keys: list[str],
+        psr_types: list[str] | None,
+        init_hour: int,
+        tz: ZoneInfo,
+        time_zone: str,
+        end_lead_minutes: int,
+        max_init_times: int,
+    ) -> pd.DataFrame:
+        """Fetch day-ahead data for the most recent matching runs."""
+        init_infos = self.get_init_times(
+            zone_key=zone_keys, psr_type=psr_types, limit=max_init_times
+        )
+        matching_inits: list[str | int | datetime] = []
+        for info in init_infos:
+            it = info.init_time
+            local_hour = (
+                (it if it.tzinfo else it.replace(tzinfo=ZoneInfo("UTC")))
+                .astimezone(tz)
+                .hour
+            )
+            if local_hour == init_hour:
+                matching_inits.append(it)
+
+        if not matching_inits:
+            return pd.DataFrame()
+
+        ds = self.get_data(
+            zone_keys=zone_keys,
+            psr_types=psr_types,
+            init_time=matching_inits,
+            max_prediction_timedelta=end_lead_minutes,
+            time_zone=time_zone,
+        )
+        if "value" not in ds:
+            return pd.DataFrame()
+        return ds.to_dataframe().reset_index()
+
+    def _fetch_day_ahead_by_date_range(
+        self,
+        *,
+        zone_keys: list[str],
+        psr_types: list[str] | None,
+        init_hour: int,
+        tz: ZoneInfo,
+        time_zone: str,
+        start_date: datetime | None,
+        end_date: datetime | None,
+        end_lead_minutes: int,
+    ) -> pd.DataFrame:
+        """Fetch day-ahead data by constructing daily init runs over a range.
+
+        The day-ahead run for valid day ``D`` is issued on ``D - 1`` at
+        ``init_hour``. We therefore build one init datetime per day from
+        ``start_date - 1`` through ``end_date`` and fetch them in a single
+        request.
+        """
+        init_times = self._build_day_ahead_inits(
+            init_hour=init_hour,
+            tz=tz,
+            start_date=start_date,
+            end_date=end_date,
+        )
+        if not init_times:
+            return pd.DataFrame()
+
+        ds = self.get_data(
+            zone_keys=zone_keys,
+            psr_types=psr_types,
+            init_time=init_times,
+            max_prediction_timedelta=end_lead_minutes,
+            time_zone=time_zone,
+        )
+        if "value" not in ds:
+            return pd.DataFrame()
+        return ds.to_dataframe().reset_index()
+
+    @staticmethod
+    def _build_day_ahead_inits(
+        *,
+        init_hour: int,
+        tz: ZoneInfo,
+        start_date: datetime | None,
+        end_date: datetime | None,
+    ) -> list[str | int | datetime]:
+        """Construct one ``init_hour`` init datetime per day spanning the range.
+
+        ``start_date``/``end_date`` bound the valid-time axis; the day-ahead run
+        for valid day ``D`` is issued the previous day. Naive bounds are
+        interpreted in ``tz``. Returned datetimes are timezone-aware (UTC).
+        """
+        utc = ZoneInfo("UTC")
+
+        def _localize(value: datetime) -> datetime:
+            return value if value.tzinfo else value.replace(tzinfo=tz)
+
+        if end_date is None:
+            end_local = datetime.now(utc).astimezone(tz)
+        else:
+            end_local = _localize(end_date).astimezone(tz)
+
+        if start_date is None:
+            # Default to ~1 year of history when only an end is supplied.
+            start_local = end_local - timedelta(days=365)
+        else:
+            start_local = _localize(start_date).astimezone(tz)
+
+        # Runs issued from (start_date - 1 day) cover valid times from start_date
+        first_init_day = (start_local - timedelta(days=1)).date()
+        last_init_day = end_local.date()
+
+        inits: list[str | int | datetime] = []
+        day = first_init_day
+        while day <= last_init_day:
+            local_init = datetime(day.year, day.month, day.day, init_hour, tzinfo=tz)
+            inits.append(local_init.astimezone(utc))
+            day = day + timedelta(days=1)
+        return inits
+
+    @staticmethod
+    def _stitch_day_ahead(
+        df: pd.DataFrame,
+        *,
+        start_lead_hours: int,
+        end_lead_hours: int,
+        tz: ZoneInfo,
+        start_date: datetime | None,
+        end_date: datetime | None,
+    ) -> xr.Dataset:
+        """Filter to the day-ahead window, dedupe overlaps, and build a Dataset."""
+        if df.empty or "value" not in df.columns:
+            return xr.Dataset(attrs={"unit": "MW"})
+
+        df = df.dropna(subset=["value"]).copy()
+        if df.empty:
+            return xr.Dataset(attrs={"unit": "MW"})
+
+        df["lead_hours"] = (df["time"] - df["init_time"]) / pd.Timedelta(hours=1)
+        mask = (df["lead_hours"] >= start_lead_hours) & (
+            df["lead_hours"] < end_lead_hours
+        )
+        df = df.loc[mask].drop(columns=["lead_hours"])
+
+        if df.empty:
+            return xr.Dataset(attrs={"unit": "MW"})
+
+        index_cols = [c for c in ["zone_key", "psr_type", "time"] if c in df.columns]
+
+        # Several matching runs can share the same target hour (e.g. sub-hourly
+        # runs like 07:00 and 07:30), so their day-ahead windows overlap and
+        # produce duplicate (zone_key, psr_type, time) rows. Keep the value from
+        # the most recent init for each valid time so the stitched index stays
+        # unique and reflects the freshest forecast.
+        sort_cols = [c for c in index_cols if c != "time"] + ["time", "init_time"]
+        df = (
+            df.sort_values(sort_cols)
+            .drop_duplicates(subset=index_cols, keep="last")
+            .drop(columns=["init_time"])
+        )
+
+        # Clip the valid-time axis to the requested bounds (date-range mode).
+        if start_date is not None or end_date is not None:
+            df = PowerForecast._clip_time(df, tz, start_date, end_date)
+            if df.empty:
+                return xr.Dataset(attrs={"unit": "MW"})
+
+        df = df.sort_values(index_cols)
+        stitched = xr.Dataset.from_dataframe(df.set_index(index_cols))
+        stitched = stitched.assign_attrs(unit="MW")
+        return stitched
+
+    @staticmethod
+    def _clip_time(
+        df: pd.DataFrame,
+        tz: ZoneInfo,
+        start_date: datetime | None,
+        end_date: datetime | None,
+    ) -> pd.DataFrame:
+        """Clip ``df`` to ``[start_date, end_date)`` on the ``time`` column.
+
+        Comparison is done in UTC to avoid tz/naive mismatches regardless of how
+        the API localized the returned ``time`` values.
+        """
+        times = pd.DatetimeIndex(df["time"])
+        if times.tz is None:
+            times_utc = times.tz_localize(tz).tz_convert("UTC")
+        else:
+            times_utc = times.tz_convert("UTC")
+
+        keep = pd.Series(True, index=df.index)
+        if start_date is not None:
+            lo = pd.Timestamp(start_date)
+            lo = lo.tz_localize(tz) if lo.tzinfo is None else lo
+            keep &= times_utc >= lo
+        if end_date is not None:
+            hi = pd.Timestamp(end_date)
+            hi = hi.tz_localize(tz) if hi.tzinfo is None else hi
+            keep &= times_utc < hi
+        return df.loc[keep.values]
+
     # ------------------------------------------------------------------
     # Init-time resolution
     # ------------------------------------------------------------------
diff --git a/tests/power_forecast/test_day_ahead_timeseries.py b/tests/power_forecast/test_day_ahead_timeseries.py
new file mode 100644
index 0000000..175b941
--- /dev/null
+++ b/tests/power_forecast/test_day_ahead_timeseries.py
@@ -0,0 +1,177 @@
+from datetime import datetime, timedelta, timezone
+
+import pandas as pd
+import xarray as xr
+
+from jua import JuaClient
+from jua.power_forecast.power_forecast import InitTimeInfo
+
+
+def _make_ds(zone: str, psr: str, init_times: list[datetime]) -> xr.Dataset:
+    """Create a simple dataset with 40 hours of horizon per init."""
+    rows = []
+    for i, it in enumerate(init_times):
+        for h in range(0, 40):  # 0..39h
+            rows.append(
+                {
+                    "zone_key": zone,
+                    "psr_type": psr,
+                    "init_time": pd.Timestamp(it),
+                    "time": pd.Timestamp(it + timedelta(hours=h)),
+                    "value": float(i * 1000 + h),
+                }
+            )
+    df = pd.DataFrame(rows)
+    return xr.Dataset.from_dataframe(
+        df.set_index(["zone_key", "psr_type", "init_time", "time"])
+    )
+
+
+def test_get_day_ahead_timeseries_stitches_across_days(monkeypatch):
+    client = JuaClient()
+    pf = client.power_forecast
+
+    zone = "GB"
+    psr = "Solar"
+    t1 = datetime(2025, 1, 1, 9, 0, tzinfo=timezone.utc)
+    t2 = datetime(2025, 1, 2, 9, 0, tzinfo=timezone.utc)
+    init_infos = [
+        InitTimeInfo(init_time=t1, max_prediction_timedelta=40 * 60),
+        InitTimeInfo(init_time=t2, max_prediction_timedelta=40 * 60),
+    ]
+
+    # Patch network methods
+    monkeypatch.setattr(
+        pf, "get_init_times", lambda zone_key=None, psr_type=None, limit=96: init_infos
+    )
+    monkeypatch.setattr(pf, "get_data", lambda **kwargs: _make_ds(zone, psr, [t1, t2]))
+
+    stitched = pf.get_day_ahead_timeseries(
+        zone_keys=[zone],
+        psr_types=[psr],
+        init_hour=9,
+        time_zone="UTC",
+        max_init_times=10,
+    )
+
+    assert "time" in stitched.dims
+    # Expect 48 hours (two days) starting at midnight after the first init
+    assert stitched.sizes["time"] == 48
+    first_time = pd.Timestamp(datetime(2025, 1, 2, 0, 0)).tz_localize(None)
+    last_time = pd.Timestamp(datetime(2025, 1, 3, 23, 0)).tz_localize(None)
+    assert pd.Timestamp(stitched.time.values[0]) == first_time
+    assert pd.Timestamp(stitched.time.values[-1]) == last_time
+
+
+def _make_ds_15min(zone: str, psr: str, init_times: list[datetime]) -> xr.Dataset:
+    """Create a 15-minute-resolution dataset with ~40h of horizon per init."""
+    rows = []
+    for i, it in enumerate(init_times):
+        for step in range(0, 40 * 4):  # 0..40h at 15-min steps
+            rows.append(
+                {
+                    "zone_key": zone,
+                    "psr_type": psr,
+                    "init_time": pd.Timestamp(it),
+                    "time": pd.Timestamp(it + timedelta(minutes=15 * step)),
+                    "value": float(i * 1000 + step),
+                }
+            )
+    df = pd.DataFrame(rows)
+    return xr.Dataset.from_dataframe(
+        df.set_index(["zone_key", "psr_type", "init_time", "time"])
+    )
+
+
+def test_get_day_ahead_timeseries_dedupes_overlapping_runs(monkeypatch):
+    """Sub-hourly runs sharing a target hour produce overlapping windows.
+
+    The stitched series must stay unique by keeping the most recent init's
+    value for each valid time instead of raising on a non-unique index.
+    """
+    client = JuaClient()
+    pf = client.power_forecast
+
+    zone = "GB"
+    psr = "Solar"
+    # Two runs in the same hour-of-day (09:00 and 09:30) on a 15-min grid.
+    t_early = datetime(2025, 1, 1, 9, 0, tzinfo=timezone.utc)
+    t_late = datetime(2025, 1, 1, 9, 30, tzinfo=timezone.utc)
+    init_infos = [
+        InitTimeInfo(init_time=t_late, max_prediction_timedelta=40 * 60),
+        InitTimeInfo(init_time=t_early, max_prediction_timedelta=40 * 60),
+    ]
+
+    monkeypatch.setattr(
+        pf,
+        "get_init_times",
+        lambda zone_key=None, psr_type=None, limit=96: init_infos,
+    )
+    monkeypatch.setattr(
+        pf,
+        "get_data",
+        lambda **kwargs: _make_ds_15min(zone, psr, [t_early, t_late]),
+    )
+
+    stitched = pf.get_day_ahead_timeseries(
+        zone_keys=[zone],
+        psr_types=[psr],
+        init_hour=9,
+        time_zone="UTC",
+        max_init_times=10,
+    )
+
+    times = pd.to_datetime(stitched.time.values)
+    assert len(times) == len(set(times)), "stitched time index must be unique"
+
+
+def test_get_day_ahead_timeseries_date_range_builds_inits_and_clips(monkeypatch):
+    """Date-range mode constructs daily inits in a single request and clips
+    the result to the requested valid-time window."""
+    client = JuaClient()
+    pf = client.power_forecast
+
+    zone, psr = "DE", "Solar"
+    init_hour = 9
+
+    calls = {"n_inits": []}
+
+    def fake_get_data(**kwargs):
+        inits = kwargs["init_time"]
+        calls["n_inits"].append(len(inits))
+        return _make_ds_15min(zone, psr, list(inits))
+
+    # get_init_times must NOT be used in date-range mode.
+    def fail_init_times(*a, **k):
+        raise AssertionError("get_init_times should not be called in date-range mode")
+
+    monkeypatch.setattr(pf, "get_data", fake_get_data)
+    monkeypatch.setattr(pf, "get_init_times", fail_init_times)
+
+    start = datetime(2025, 6, 1, tzinfo=timezone.utc)
+    end = datetime(2025, 6, 11, tzinfo=timezone.utc)  # 10 valid days
+
+    ds = pf.get_day_ahead_timeseries(
+        zone_keys=[zone],
+        psr_types=[psr],
+        init_hour=init_hour,
+        time_zone="UTC",
+        start_date=start,
+        end_date=end,
+    )
+
+    # Single request with daily inits from (start-1) through end (12 days). The
+    # final init's window falls outside [start, end) and is clipped away.
+    assert calls["n_inits"] == [12]
+
+    times = pd.to_datetime(ds.time.values)
+    assert len(times) == len(set(times)), "time index must be unique"
+    # 10 days at 15-min resolution
+    assert ds.sizes["time"] == 10 * 96
+    lo = pd.Timestamp(start).tz_convert("UTC")
+    hi = pd.Timestamp(end).tz_convert("UTC")
+    tmin = pd.Timestamp(times.min())
+    tmax = pd.Timestamp(times.max())
+    tmin = tmin.tz_localize("UTC") if tmin.tzinfo is None else tmin
+    tmax = tmax.tz_localize("UTC") if tmax.tzinfo is None else tmax
+    assert tmin >= lo and tmax < hi