Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import multiprocessing
from pathlib import Path

from openstef_beam.benchmarking.baselines import (
from openstef_beam.benchmarking.baselines.openstef4 import (
create_openstef4_preset_backtest_forecaster,
)
from openstef_beam.benchmarking.benchmarks.liander2024 import Liander2024Category, create_liander2024_benchmark_runner
Expand Down
2 changes: 1 addition & 1 deletion examples/benchmarks/liander_2024_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pathlib import Path

from openstef_beam.backtesting.backtest_forecaster import BacktestForecasterConfig
from openstef_beam.benchmarking.baselines import (
from openstef_beam.benchmarking.baselines.openstef4 import (
create_openstef4_preset_backtest_forecaster,
)
from openstef_beam.benchmarking.benchmarks.liander2024 import Liander2024Category, create_liander2024_benchmark_runner
Expand Down
7 changes: 7 additions & 0 deletions packages/openstef-beam/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,13 @@ dependencies = [
]

optional-dependencies.all = [
"openstef-beam[baselines]",
"s3fs>=2025.5.1",
]
optional-dependencies.baselines = [
"openstef-meta>=4.0.0.dev0,<5",
"openstef-models>=4.0.0.dev0,<5",
]
urls.Documentation = "https://openstef.github.io/openstef/index.html"
urls.Homepage = "https://lfenergy.org/projects/openstef/"
urls.Issues = "https://github.com/OpenSTEF/openstef/issues"
Expand All @@ -48,3 +53,5 @@ packages = [ "src/openstef_beam" ]

[tool.uv.sources]
openstef-core = { workspace = true }
openstef-models = { workspace = true }
openstef-meta = { workspace = true }
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
"""Benchmarks baselines used by the OpenSTEF Beam benchmarking utilities.

This package exposes baseline forecasters for use in backtesting.
The OpenSTEF v4 baselines require ``openstef-models`` and ``openstef-meta``,
available via the ``baselines`` extra: ``pip install openstef-beam[baselines]``.

Import directly from the submodule::

from openstef_beam.benchmarking.baselines.openstef4 import (
OpenSTEF4BacktestForecaster,
create_openstef4_preset_backtest_forecaster,
)
"""

# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0

from openstef_beam.benchmarking.baselines.openstef4 import (
OpenSTEF4BacktestForecaster,
WorkflowCreationContext,
create_openstef4_preset_backtest_forecaster,
)

__all__ = [
"OpenSTEF4BacktestForecaster",
"WorkflowCreationContext",
"create_openstef4_preset_backtest_forecaster",
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
#
# SPDX-License-Identifier: MPL-2.0

"""OpenSTEF 4.0 forecaster for backtesting pipelines."""
"""OpenSTEF 4.0 forecaster for backtesting pipelines.

Requires the ``baselines`` extra: ``pip install openstef-beam[baselines]``.
"""

import logging
from collections.abc import Callable
from datetime import timedelta
from functools import partial
from pathlib import Path
from typing import Any, cast, override
from typing import TYPE_CHECKING, Any, cast, override

import pandas as pd
from pydantic import Field, PrivateAttr
from pydantic_extra_types.coordinate import Coordinate

Expand All @@ -27,24 +28,20 @@
BenchmarkTarget,
ForecasterFactory,
)
from openstef_core.base_model import BaseConfig, BaseModel
from openstef_core.base_model import BaseModel
from openstef_core.datasets import TimeSeriesDataset
from openstef_core.exceptions import FlatlinerDetectedError, NotFittedError
from openstef_core.exceptions import FlatlinerDetectedError, MissingExtraError, NotFittedError
from openstef_core.types import Q
from openstef_meta.presets import EnsembleForecastingWorkflowConfig, create_ensemble_forecasting_workflow
from openstef_models.presets import ForecastingWorkflowConfig
from openstef_models.presets import ForecastingWorkflowConfig, create_forecasting_workflow
from openstef_models.presets.forecasting_workflow import LocationConfig
from openstef_models.workflows.callbacks.data_save import DataSaveCallback
from openstef_models.workflows.custom_forecasting_workflow import (
CustomForecastingWorkflow,
ForecastingCallback,
)


class WorkflowCreationContext(BaseConfig):
"""Context information for workflow execution within backtesting."""

step_name: str | None = Field(
default=None,
description="Name of the current backtesting step.",
)
if TYPE_CHECKING:
from openstef_meta.presets import EnsembleForecastingWorkflowConfig


class OpenSTEF4BacktestForecaster(BaseModel, BacktestForecasterMixin):
Expand All @@ -57,8 +54,8 @@ class OpenSTEF4BacktestForecaster(BaseModel, BacktestForecasterMixin):
config: BacktestForecasterConfig = Field(
description="Configuration for the backtest forecaster interface",
)
workflow_factory: Callable[[WorkflowCreationContext], CustomForecastingWorkflow] = Field(
description="Factory function that creates a new CustomForecastingWorkflow instance",
workflow_template: CustomForecastingWorkflow = Field(
description="Untrained workflow template; deep-copied for each fit() call",
)
cache_dir: Path = Field(
description="Directory to use for caching model artifacts during backtesting",
Expand All @@ -71,6 +68,10 @@ class OpenSTEF4BacktestForecaster(BaseModel, BacktestForecasterMixin):
default=False,
description="When True, saves base forecaster prediction contributions for ensemble models",
)
extra_callbacks: list[ForecastingCallback] = Field(
default_factory=list[ForecastingCallback],
description="Additional callbacks to inject into workflows created by the factory.",
)

_workflow: CustomForecastingWorkflow | None = PrivateAttr(default=None)
_is_flatliner_detected: bool = PrivateAttr(default=False)
Expand All @@ -80,22 +81,27 @@ class OpenSTEF4BacktestForecaster(BaseModel, BacktestForecasterMixin):
@override
def model_post_init(self, context: Any) -> None:
if self.debug or self.contributions:
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.extra_callbacks.append(
DataSaveCallback(
cache_dir=self.cache_dir,
save_training_data=self.debug,
save_prepared_data=self.debug,
save_predict_data=self.debug,
save_forecast=self.debug,
save_contributions=self.contributions,
)
)

@property
@override
def quantiles(self) -> list[Q]:
# Create a workflow instance if needed to get quantiles
if self._workflow is None:
self._workflow = self.workflow_factory(WorkflowCreationContext())

return self._workflow.model.quantiles
return self.workflow_template.model.quantiles

@override
def fit(self, data: RestrictedHorizonVersionedTimeSeries) -> None:
# Create a new workflow for this training cycle
context = WorkflowCreationContext(step_name=data.horizon.isoformat())
workflow = self.workflow_factory(context)
# Deep-copy the template for a fresh model
workflow = self.workflow_template.with_run_name(data.horizon.isoformat())
workflow.callbacks.extend(self.extra_callbacks)

# Extract the dataset for training
training_data = data.get_window(
Expand All @@ -104,10 +110,6 @@ def fit(self, data: RestrictedHorizonVersionedTimeSeries) -> None:
available_before=data.horizon,
)

if self.debug:
id_str = data.horizon.strftime("%Y%m%d%H%M%S")
training_data.to_parquet(path=self.cache_dir / f"debug_{id_str}_training.parquet")

try:
# Use the workflow's fit method
workflow.fit(data=training_data)
Expand All @@ -119,12 +121,6 @@ def fit(self, data: RestrictedHorizonVersionedTimeSeries) -> None:

self._workflow = workflow

if self.debug:
id_str = data.horizon.strftime("%Y%m%d%H%M%S")
self._workflow.model.prepare_input(training_data).to_parquet(
path=self.cache_dir / f"debug_{id_str}_prepared_training.parquet"
)

@override
def predict(self, data: RestrictedHorizonVersionedTimeSeries) -> TimeSeriesDataset | None:
if self._is_flatliner_detected:
Expand All @@ -150,72 +146,49 @@ def predict(self, data: RestrictedHorizonVersionedTimeSeries) -> TimeSeriesDatas
self._logger.info("Flatliner detected during prediction")
return None

if self.debug:
id_str = data.horizon.strftime("%Y%m%d%H%M%S")
predict_data.to_parquet(path=self.cache_dir / f"debug_{id_str}_predict.parquet")
forecast.to_parquet(path=self.cache_dir / f"debug_{id_str}_forecast.parquet")

if self.contributions:
id_str = data.horizon.strftime("%Y%m%d%H%M%S")
try:
contributions = self._workflow.model.predict_contributions(predict_data, forecast_start=data.horizon)
except NotImplementedError:
pass
else:
df = pd.concat([contributions.data, forecast.data.drop(columns=["load"])], axis=1)
df.to_parquet(path=self.cache_dir / f"contrib_{id_str}_predict.parquet")
return forecast


class OpenSTEF4PresetBacktestForecaster(OpenSTEF4BacktestForecaster):
pass


def _preset_target_forecaster_factory(
base_config: ForecastingWorkflowConfig | EnsembleForecastingWorkflowConfig,
base_config: "ForecastingWorkflowConfig | EnsembleForecastingWorkflowConfig",
backtest_config: BacktestForecasterConfig,
cache_dir: Path,
context: BenchmarkContext,
target: BenchmarkTarget,
) -> OpenSTEF4BacktestForecaster:
from openstef_models.presets import create_forecasting_workflow # noqa: PLC0415
from openstef_models.presets.forecasting_workflow import LocationConfig # noqa: PLC0415

# Factory function that creates a forecaster for a given target.
prefix = context.run_name

def _create_workflow(context: WorkflowCreationContext) -> CustomForecastingWorkflow:
# Create a new workflow instance with fresh model.
location = LocationConfig(
name=target.name,
description=target.description,
coordinate=Coordinate(
latitude=target.latitude,
longitude=target.longitude,
),
)

update = {
"model_id": f"{prefix}_{target.name}",
"location": location,
"run_name": context.step_name,
}
location = LocationConfig(
name=target.name,
description=target.description,
coordinate=Coordinate(
latitude=target.latitude,
longitude=target.longitude,
),
)

if isinstance(base_config, EnsembleForecastingWorkflowConfig):
return create_ensemble_forecasting_workflow(config=base_config.model_copy(update=update))
update: dict[str, Any] = {
"model_id": f"{context.run_name}_{target.name}",
"location": location,
}

return create_forecasting_workflow(config=base_config.model_copy(update=update))
if base_config.kind == "ensemble":
try:
from openstef_meta.presets import create_ensemble_forecasting_workflow # noqa: PLC0415
except ImportError as e:
raise MissingExtraError("openstef-meta") from e
workflow = create_ensemble_forecasting_workflow(config=base_config.model_copy(update=update))
else:
workflow = create_forecasting_workflow(config=base_config.model_copy(update=update))

return OpenSTEF4BacktestForecaster(
config=backtest_config,
workflow_factory=_create_workflow,
workflow_template=workflow,
debug=False,
cache_dir=cache_dir / f"{context.run_name}_{target.name}",
)


def create_openstef4_preset_backtest_forecaster(
workflow_config: ForecastingWorkflowConfig | EnsembleForecastingWorkflowConfig,
workflow_config: "ForecastingWorkflowConfig | EnsembleForecastingWorkflowConfig",
backtest_config: BacktestForecasterConfig | None = None,
cache_dir: Path = Path("cache"),
) -> ForecasterFactory[BenchmarkTarget]:
Expand Down Expand Up @@ -258,6 +231,5 @@ def create_openstef4_preset_backtest_forecaster(

__all__ = [
"OpenSTEF4BacktestForecaster",
"WorkflowCreationContext",
"create_openstef4_preset_backtest_forecaster",
]
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,23 @@ def get_metric(self, quantile: QuantileOrGlobal, metric_name: str) -> FloatOrNan
"""
return self.metrics.get(quantile, {}).get(metric_name)

def to_flat_dict(self, prefix: str = "") -> dict[str, float]:
"""Flatten metrics into a single dict suitable for logging (e.g. MLflow).

Each key is ``{prefix}{quantile}_{metric_name}``.

Args:
prefix: String prepended to every key.

Returns:
Flat mapping of metric names to values.
"""
return {
f"{prefix}{quantile}_{metric_name}": value
for quantile, metrics_dict in self.metrics.items()
for metric_name, value in metrics_dict.items()
}


def merge_quantile_metrics(metrics_list: list[QuantileMetricsDict]) -> QuantileMetricsDict:
"""Merge multiple quantile metrics dictionaries into a single one.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0
Loading