diff --git a/src/dataplatform/forecast/data.py b/src/dataplatform/forecast/data.py index d61a032..4799044 100644 --- a/src/dataplatform/forecast/data.py +++ b/src/dataplatform/forecast/data.py @@ -3,6 +3,7 @@ import time from datetime import datetime, timedelta +from grpc_requests import Client import betterproto import pandas as pd from aiocache import Cache, cached @@ -14,6 +15,7 @@ async def get_forecast_data( client: dp.DataPlatformDataServiceStub, + sync_client: Client, location: dp.ListLocationsResponseLocationSummary, start_date: datetime, end_date: datetime, @@ -24,7 +26,7 @@ async def get_forecast_data( for forecaster in selected_forecasters: forecaster_data_df = await get_forecast_data_one_forecaster( - client, + sync_client, location, start_date, end_date, @@ -63,7 +65,7 @@ async def get_forecast_data( @cached(ttl=cache_seconds, cache=Cache.MEMORY, key_builder=key_builder_remove_client) async def get_forecast_data_one_forecaster( - client: dp, + sync_client: Client, location: dp.ListLocationsResponseLocationSummary, start_date: datetime, end_date: datetime, @@ -78,20 +80,21 @@ async def get_forecast_data_one_forecaster( temp_end_date = min(temp_start_date + timedelta(days=30), end_date) # fetch data - stream_forecast_data_request = dp.StreamForecastDataRequest( - location_uuid=location.location_uuid, - energy_source=dp.EnergySource.SOLAR, - time_window=dp.TimeWindow( - start_timestamp_utc=temp_start_date, - end_timestamp_utc=temp_end_date, - ), - forecasters=[selected_forecaster], - ) + stream_forecast_data_request = { + "location_uuid": location.location_uuid, + "energy_source": dp.EnergySource.SOLAR.value, + "time_window": { + "start_timestamp_utc": temp_start_date.isoformat(), + "end_timestamp_utc": temp_end_date.isoformat(), + }, + "forecasters": [selected_forecaster.to_dict()], + } + + svc = sync_client.service("ocf.dp.DataPlatformDataService") + forecasts = [] - async for chunk in client.stream_forecast_data(stream_forecast_data_request): - forecasts.append( - chunk.to_dict(include_default_values=True, casing=betterproto.Casing.SNAKE), - ) + for chunk in svc.StreamForecastData(stream_forecast_data_request): + forecasts.append(chunk) if len(forecasts) > 0: all_data_list_dict.extend(forecasts) @@ -199,6 +202,7 @@ async def get_all_observations( async def get_all_data( client: dp.DataPlatformDataServiceStub, + sync_client: Client, selected_location: dp.ListLocationsResponseLocationSummary, start_date: datetime, end_date: datetime, @@ -219,6 +223,7 @@ async def get_all_data( time_start = time.time() all_forecast_data_df = await get_forecast_data( client, + sync_client, selected_location, start_date, end_date, diff --git a/src/dataplatform/forecast/main.py b/src/dataplatform/forecast/main.py index 9840c0c..e491a8c 100644 --- a/src/dataplatform/forecast/main.py +++ b/src/dataplatform/forecast/main.py @@ -7,6 +7,7 @@ import streamlit as st from dp_sdk.ocf import dp from grpclib.client import Channel +from grpc_requests import Client as GRPC_Client from dataplatform.forecast.constant import metrics, observer_names from dataplatform.forecast.data import align_t0, get_all_data @@ -34,6 +35,9 @@ async def async_dp_forecast_page() -> None: async with Channel(host=data_platform_host, port=data_platform_port) as channel: client = dp.DataPlatformDataServiceStub(channel) + # this is used to streamline requests + sync_client = GRPC_Client.get_by_endpoint(f"{data_platform_host}:{data_platform_port}") + setup_page_dict = await setup_page(client) selected_location = setup_page_dict["selected_location"] start_date = setup_page_dict["start_date"] @@ -51,6 +55,7 @@ async def async_dp_forecast_page() -> None: ### 1. Get all the data ### all_data_dict = await get_all_data( client=client, + sync_client=sync_client, start_date=start_date, end_date=end_date, selected_forecasters=selected_forecasters,