Skip to content
Merged
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ dev = [
]

[tool.uv.sources]
dp-sdk = { url = "https://github.com/openclimatefix/data-platform/releases/download/v0.22.0/dp_sdk-0.22.0-py3-none-any.whl" }
dp-sdk = { url = "https://github.com/openclimatefix/data-platform/releases/download/v0.23.0/dp_sdk-0.23.0-py3-none-any.whl" }


[project.urls]
Expand Down
2 changes: 2 additions & 0 deletions src/quartz_api/internal/backends/dataplatform/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ async def get_predicted_generation_snapshot(
capacity_kilowatts=v.effective_capacity_watts / 1000,
forecaster_name=forecaster.forecaster_name,
forecaster_version=forecaster.forecaster_version,
created_timestamp=v.created_timestamp_utc,
init_timestamp=v.initialization_timestamp_utc,
)
for v in resp.values
]
Expand Down
2 changes: 1 addition & 1 deletion src/quartz_api/internal/service/uk_national/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
log = logging.getLogger(__name__)

cache_dependent_scopes = ["read:intraday"]
legacy_query_params = ["compact", "historic"]
legacy_query_params = ["historic"]


async def key_builder(
Expand Down
4 changes: 2 additions & 2 deletions src/quartz_api/internal/service/uk_national/endpoint_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ class Forecast(EnhancedBaseModel):

location: Location = Field(..., description="The location object for this forecaster")
model: MLModel = Field(..., description="The name of the model that made this forecast")
forecast_creation_time: dt.datetime = Field(
...,
forecast_creation_time: dt.datetime | None = Field(
None,
description="The time when the forecaster was made",
)
historic: bool = Field(
Expand Down
89 changes: 72 additions & 17 deletions src/quartz_api/internal/service/uk_national/gsp_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
from quartz_api.internal import models
from quartz_api.internal.middleware.auth import AuthDependency
from quartz_api.internal.middleware.ratelimit import limiter
from quartz_api.internal.service.uk_national.metadata import format_metadata

from .cache import key_builder
from .endpoint_types import (
Forecast,
ForecastValue,
GSPYield,
GSPYieldGroupByDatetime,
Location,
MLModel,
OneDatetimeManyForecastValuesMW,
convert_list_of_gsp_ids,
)
Expand Down Expand Up @@ -207,7 +211,7 @@ async def get_truths_for_a_specific_gsp(

@router.get(
"/forecast/all/",
response_model=list[OneDatetimeManyForecastValuesMW],
response_model=list[OneDatetimeManyForecastValuesMW | Forecast],
include_in_schema=False,
)
@limiter.limit("3600/hour")
Expand All @@ -227,8 +231,9 @@ async def get_all_available_forecasts(
],
creation_utc_limit: models.UTCDatetime | None = None,
gsp_ids: str | None = None,
compact: bool = False,

) -> list[OneDatetimeManyForecastValuesMW]:
) -> list[OneDatetimeManyForecastValuesMW | Forecast]:
"""### Get all forecasts for all GSPs.

The return object contains a forecast object with system details and
Expand Down Expand Up @@ -290,28 +295,78 @@ async def get_all_available_forecasts(
results: list[list[models.PredictedGenerationValue] | Exception] = await asyncio.gather(
*tasks, return_exceptions=True,
)

# reorganize results by timestamp
grouped_data: dict[dt.datetime, dict[int, float]] = defaultdict(dict)
gsp_ids = list(gsp_uuid_id_map.values())

# We can zip these because the tasks will return in the same order as they were created
for snapshot in results:
for predicted_generation_value in snapshot:
if compact:
            # Results arrive in the same order as the tasks that produced them,
            # so they can be matched back to their GSPs via the uuid map
for snapshot in results:
for predicted_generation_value in snapshot:

gsp_id = gsp_uuid_id_map[predicted_generation_value.location_uuid]
grouped_data[predicted_generation_value.valid_timestamp][gsp_id] \
= round(predicted_generation_value.power_kilowatts / 1000.0,4)
gsp_id = gsp_uuid_id_map[predicted_generation_value.location_uuid]
grouped_data[predicted_generation_value.valid_timestamp][gsp_id] \
= round(predicted_generation_value.power_kilowatts / 1000.0, 4)

out: list[OneDatetimeManyForecastValuesMW] = [
OneDatetimeManyForecastValuesMW(
datetime_utc=ts,
forecast_values=dict(sorted(gsp_dict.items())),
)
for ts, gsp_dict in grouped_data.items()
]
out: list[OneDatetimeManyForecastValuesMW] = [
OneDatetimeManyForecastValuesMW(
datetime_utc=ts,
forecast_values=dict(sorted(gsp_dict.items())),
)
for ts, gsp_dict in grouped_data.items()
]

return out
return out
else:
        # Let's format the results as a list of Forecast objects

# 1. lets split the results up into groups of gsps
forecast_values_by_gsp_id: dict[int, list[ForecastValue]] = {}
forecasts_by_gsp_id: dict[int, Forecast] = {}
for snapshot in results:
for predicted_generation_value in snapshot:
gsp_id = gsp_uuid_id_map[predicted_generation_value.location_uuid]
forecast_value = ForecastValue(
expected_power_generation_megawatts
=round(predicted_generation_value.power_kilowatts / 1000, 4),
target_time=predicted_generation_value.valid_timestamp,
)
forecast_values_by_gsp_id.setdefault(gsp_id, []).append(forecast_value)

if gsp_id not in forecasts_by_gsp_id:

version = predicted_generation_value.metadata.get("app_version",
predicted_generation_value.forecaster_version)
input_data = format_metadata(predicted_generation_value.metadata)
gsp = next(g for g in gsps if int(g.metadata["gsp_id"]) == gsp_id)
forecast_creation_time = predicted_generation_value.created_timestamp

forecasts_by_gsp_id[gsp_id] = Forecast(
location=Location.from_location(gsp),
model=MLModel(
name=predicted_generation_value.forecaster_name,
version=version,
),
forecast_creation_time=forecast_creation_time,
initialization_datetime_utc=predicted_generation_value.init_timestamp,
# we will add to this later
forecast_values=[],
input_data_last_updated=input_data,
)


forecasts: list[Forecast] = []
gsp_ids = sorted(gsp_uuid_id_map.values())
for gsp_id in gsp_ids:

gsp_forecast = forecasts_by_gsp_id[gsp_id]
forecast_values = forecast_values_by_gsp_id[gsp_id]

gsp_forecast.forecast_values = forecast_values

forecasts.append(gsp_forecast)

return forecasts


@router.get(
Expand Down
17 changes: 17 additions & 0 deletions src/quartz_api/internal/service/uk_national/metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Function to format metadata."""
import datetime as dt

from quartz_api.internal.service.uk_national.endpoint_types import InputDataLastUpdated


def format_metadata(metadata: dict) -> InputDataLastUpdated:
    """Format metadata dictionary into InputDataLastUpdated object.

    Args:
        metadata: Mapping of data-source names to last-updated timestamps,
            e.g. "gsp_last_updated", "satellite_last_updated", and one or
            more NWP keys (see below).

    Returns:
        InputDataLastUpdated with per-source timestamps; any source missing
        from ``metadata`` falls back to the Unix epoch sentinel.
    """
    # Sentinel meaning "never updated".
    old = dt.datetime(1970, 1, 1, tzinfo=dt.UTC)
    gsp = metadata.get("gsp_last_updated", old)
    satellite = metadata.get("satellite_last_updated", old)

    # The nwp keys could be nwp_ukv_last_updated, nwp_ecmwf_last_updated,
    # or nwp_last_updated; take the most recent across all of them.
    # (Bug fix: the original computed max() over a single-element list per
    # key, so it kept only the LAST nwp key's value, not the maximum.)
    nwp = max((metadata[k] for k in metadata if "nwp" in k), default=old)
    return InputDataLastUpdated(gsp=gsp, nwp=nwp, pv=old, satellite=satellite)
15 changes: 2 additions & 13 deletions src/quartz_api/internal/service/uk_national/national_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from quartz_api.internal import models
from quartz_api.internal.middleware.auth import AuthDependency
from quartz_api.internal.middleware.ratelimit import limiter
from quartz_api.internal.service.uk_national.metadata import format_metadata

from .cache import key_builder
from .endpoint_types import (
InputDataLastUpdated,
Location,
MLModel,
ModelName,
Expand Down Expand Up @@ -163,7 +163,7 @@ async def get_national_forecast(
)
@limiter.limit("3600/hour")
@limiter.limit("10/second")
@cache(key_builder=key_builder)
# @cache(key_builder=key_builder)
async def get_national_pvlive(
request: Request, # noqa: ARG001
db: models.StorageClientDependency,
Expand Down Expand Up @@ -219,14 +219,3 @@ async def get_national_pvlive(
return out


def format_metadata(metadata: dict) -> InputDataLastUpdated:
    """Format metadata dictionary into InputDataLastUpdated object.

    Args:
        metadata: Mapping of data-source names to last-updated timestamps,
            e.g. "gsp_last_updated", "satellite_last_updated", and one or
            more NWP keys (see below).

    Returns:
        InputDataLastUpdated with per-source timestamps; any source missing
        from ``metadata`` falls back to the Unix epoch sentinel.
    """
    # Sentinel meaning "never updated".
    old = dt.datetime(1970, 1, 1, tzinfo=dt.UTC)
    gsp = metadata.get("gsp_last_updated", old)
    satellite = metadata.get("satellite_last_updated", old)

    # The nwp keys could be nwp_ukv_last_updated, nwp_ecmwf_last_updated,
    # or nwp_last_updated; take the most recent across all of them.
    # (Bug fix: the original computed max() over a single-element list per
    # key, so it kept only the LAST nwp key's value, not the maximum.)
    nwp = max((metadata[k] for k in metadata if "nwp" in k), default=old)
    return InputDataLastUpdated(gsp=gsp, nwp=nwp, pv=old, satellite=satellite)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
""" Test for format metadata"""
import datetime as dt

from .national_router import format_metadata
from .metadata import format_metadata


def test_format_metadata():
Expand Down
48 changes: 46 additions & 2 deletions src/quartz_api/tests/integration/uk_national/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ async def test_gsp_forecast_all(
) -> None:
"""Test a sample endpoint for UK National forecast data."""

response = await api_client_uk_national.get("/v0/solar/GB/gsp/forecast/all/")
response = await api_client_uk_national.get("/v0/solar/GB/gsp/forecast/all/?compact=true")

assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
Expand All @@ -215,7 +216,9 @@ async def test_gsp_forecast_all_gsp_ids(
) -> None:
"""Test a sample endpoint for UK National forecast data."""

response = await api_client_uk_national.get("/v0/solar/GB/gsp/forecast/all/?gsp_ids=1,2,3")
url = "/v0/solar/GB/gsp/forecast/all/?gsp_ids=1,2,3&compact=true"
response = await api_client_uk_national.get(url)

assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
Expand All @@ -225,6 +228,47 @@ async def test_gsp_forecast_all_gsp_ids(
assert len(data[0]["forecastValues"]) == 3


# 4.3.2 Check GSP forecast route, compact=false
@pytest.mark.asyncio(loop_scope="session")
async def test_gsp_forecast_compact_false(
    api_client_uk_national,
    gsp_locations,  # noqa: ARG001
    make_forecasters,  # noqa: ARG001
    make_gsp_forecast_values,  # noqa: ARG001
) -> None:
    """Check the non-compact GSP forecast response is a list of Forecast objects."""

    resp = await api_client_uk_national.get("/v0/solar/GB/gsp/forecast/all/")
    assert resp.status_code == 200

    payload = resp.json()
    assert isinstance(payload, list)
    assert len(payload) == 10

    # Each entry should look like a Forecast object, not a compact row.
    first = payload[0]
    for key in ("location", "model", "forecastValues"):
        assert key in first
    assert len(first["forecastValues"]) == 1

# 4.3.3 Check GSP forecast route, compact=false, and restrict gsps
@pytest.mark.asyncio(loop_scope="session")
async def test_gsp_forecast_compact_false_gsp_ids(
    api_client_uk_national,
    gsp_locations,  # noqa: ARG001
    make_forecasters,  # noqa: ARG001
    make_gsp_forecast_values,  # noqa: ARG001
) -> None:
    """Check the non-compact GSP forecast response respects the gsp_ids filter."""

    resp = await api_client_uk_national.get(
        "/v0/solar/GB/gsp/forecast/all/?gsp_ids=1,2,3",
    )
    assert resp.status_code == 200

    payload = resp.json()
    assert isinstance(payload, list)
    # Only the three requested GSPs should be returned.
    assert len(payload) == 3

    first = payload[0]
    for key in ("location", "model", "forecastValues"):
        assert key in first
    assert len(first["forecastValues"]) == 10


# 4.4 Check GSP pvlive route
@pytest.mark.asyncio(loop_scope="session")
async def test_gsp_pvlive_all(
Expand Down
Loading
Loading