From e027182361a2f7175bcd4130a0efe1b7ebf31228 Mon Sep 17 00:00:00 2001 From: Stefano Piani Date: Tue, 27 May 2025 11:39:26 +0200 Subject: [PATCH 1/8] Added method get_available_data_range to EfasOperationalDownloader --- .pre-commit-config.yaml | 1 + pyproject.toml | 6 ++ src/ogs_riverger/efas/download_tools.py | 108 ++++++++++++++++++++++-- tests/test_efas.py | 13 +++ 4 files changed, 121 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f112291..825d20c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,6 +29,7 @@ repos: rev: '7.2.0' hooks: - id: flake8 + additional_dependencies: [Flake8-pyproject] - repo: https://github.com/pycqa/doc8 rev: 'v1.1.2' hooks: diff --git a/pyproject.toml b/pyproject.toml index e6d1887..b1bac24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,12 @@ markers = [ riverger = "ogs_riverger.__main__:main" +[tool.flake8] +max-line-length = 79 +max-complexity = 14 +ignore = ["E203"] + + [build-system] requires = ["poetry-core>=2.0.0,<3.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/src/ogs_riverger/efas/download_tools.py b/src/ogs_riverger/efas/download_tools.py index c6edaef..65852e6 100644 --- a/src/ogs_riverger/efas/download_tools.py +++ b/src/ogs_riverger/efas/download_tools.py @@ -1,6 +1,7 @@ import asyncio import inspect import logging +import re from collections.abc import Iterable from collections.abc import Mapping from datetime import datetime @@ -272,7 +273,7 @@ def download_from_cdsapi( """Downloads data from the CDSAPI EFAS archives (forecast or historical). Due to limitations in the `cdsapi` interface, this function may - need to download multiple files. For example, if the `start_date` + need to download multiple files. For example, if the `start_date` and `end_date` span different months, a separate file for each month must be downloaded. @@ -536,6 +537,53 @@ class EfasOperationalDownloader: sizes. """ + FILE_NAME_MASK = re.compile( + r"^(?P[a-z0-9]+)\.fc\.dis_(?P\d{10})\.grb$" + ) + + @staticmethod + def _generate_file_name(version: str, file_date: datetime) -> str: + """ + Generates the expected file name for a EFAS file based on the given + file version and date of production. + + Args: + version: A string representing the version identifier of the file. + file_date: A `datetime` object representing the date and time + information of when the file has been produced. + + Returns: + A string containing the generated file name in the format + "{version}.fc.dis_{YYYYMMDDHH}.grb". + """ + time_str = file_date.strftime("%Y%m%d%H") + return f"{version}.fc.dis_{time_str}.grb" + + @staticmethod + def _check_file_name(file_name: str) -> dict[str : str | datetime] | None: + """ + Checks if a file name is conformal to the typical file name of the EFAS + operative files. If this is the case, it returns a dictionary + containing the name and the date of production. Otherwise, it returns + `None` + + Args: + file_name: the name of the file + + Returns: + A dictionary with the version and the production date of the + file if the name of the file is conformal to the typical file name + of the EFAS operative service. Otherwise, it returns `None`. + """ + file_match = EfasOperationalDownloader.FILE_NAME_MASK.match(file_name) + if file_match is None: + return None + output = { + "version": file_match.group("version"), + "date": datetime.strptime(file_match.group("date"), "%Y%m%d%H"), + } + return output + def __init__( self, data_dir: Path, @@ -561,6 +609,8 @@ def __init__( for efas_file in efas_dir.iterdir(): if not efas_file.is_file(): continue + if self._check_file_name(efas_file.name) is None: + continue file_stat = efas_file.stat() file_size = file_stat.st_size efas_cache[Path(efas_file)] = file_size @@ -762,6 +812,7 @@ async def _single_download( """ logger = logging.getLogger(f"{__name__}.{inspect.stack()[0][3]}") logger.info("Downloading file %s", remote_file_path) + for attempt in range(1, retries + 1): try: async with self._semaphore: @@ -796,6 +847,48 @@ async def _single_download( ) return output_path + async def get_available_data_range(self) -> tuple[datetime, datetime]: + """ + Retrieve the range of dates available on the remote server. + + This method connects to the remote server, fetches the list of + available files, and determines the minimum and maximum + dates present in the valid filenames. Files with non-conforming + filenames are discarded, and if no valid file is present, an exception + is raised. + + Returns: + A tuple containing the earliest and latest dates found in the + valid filenames from the remote server + + Raises: + RuntimeError: If no valid files are available on the remote + server + """ + logger = logging.getLogger(f"{__name__}.{inspect.stack()[0][3]}") + + available_files = await self._get_remote_server_available_files() + logger.debug( + "%s files found on the remote server", len(available_files) + ) + + dates = [] + for file_name in available_files: + file_name_check = self._check_file_name(file_name) + if file_name_check is None: + logger.debug( + "File %s has been discarded because its name it not" + "conformal", + file_name, + ) + continue + dates.append(file_name_check["date"]) + if len(dates) == 0: + raise RuntimeError( + "No files available on the remote EFAS FTP server" + ) + return min(dates), max(dates) + async def download_efas_operational_data( self, start_time: datetime, @@ -848,9 +941,10 @@ async def download_efas_operational_data( returned_files = set() files_to_be_downloaded = set() for time_step in time_steps: - time_str = time_step.strftime("%Y%m%d%H") for version in self._versions: - expected_file_name = f"{version}.fc.dis_{time_str}.grb" + expected_file_name = self._generate_file_name( + version, time_step + ) logger.debug( "Checking if a file named %s exists", expected_file_name ) @@ -865,8 +959,8 @@ async def download_efas_operational_data( if cache_position is None and version in self._fallback: fallback_version = self._fallback[version] - fallback_file_name = ( - f"{fallback_version}.fc.dis_{time_str}.grb" + fallback_file_name = self._generate_file_name( + fallback_version, time_step ) logger.debug( "Trying the fallback version %s: checking for file %s", @@ -895,14 +989,14 @@ async def download_efas_operational_data( "time-step %s", expected_file_name, fallback_file_name, - time_str, + time_step.strftime("%Y/%m/%d-%H:%M:%s"), ) elif cache_position is None: logger.debug( "There is not fallback version for %s; no file will " "be downloaded for time-step %s", version, - time_str, + time_step.strftime("%Y/%m/%d-%H:%M:%S"), ) else: returned_files.add(cache_position) diff --git a/tests/test_efas.py b/tests/test_efas.py index 0cbc6f3..58e9323 100644 --- a/tests/test_efas.py +++ b/tests/test_efas.py @@ -692,6 +692,19 @@ def failing_download(remote_path): assert log_exception.call_count == 1 +@pytest.mark.external_resources +async def test_efas_operational_get_available_data_range(settings, tmp_path): + downloader = await EfasOperationalDownloader.create( + Path(tmp_path), + efas_user=settings.EFAS_FTP_USER, + efas_password=settings.EFAS_FTP_PASSWORD, + versions=("eud",), + fallback_versions={"eud": "dwd"}, + ) + data_min, data_max = await downloader.get_available_data_range() + assert data_max - data_min < timedelta(days=90) + + @pytest.mark.external_resources async def test_efas_operational_download_and_read( settings, config_example, efas_domain_file, tmp_path From 29fd8f0100b9f63e88b6746ae5e58ba755360731 Mon Sep 17 00:00:00 2001 From: Stefano Piani Date: Tue, 27 May 2025 11:45:25 +0200 Subject: [PATCH 2/8] fixing small bug in efas_manager.py --- src/ogs_riverger/efas/efas_manager.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/ogs_riverger/efas/efas_manager.py b/src/ogs_riverger/efas/efas_manager.py index f475db9..938c83b 100644 --- a/src/ogs_riverger/efas/efas_manager.py +++ b/src/ogs_riverger/efas/efas_manager.py @@ -353,11 +353,7 @@ def _read_single_efas_file(dataset: xr.Dataset) -> xr.Dataset: ) return _read_efas_historical_file(dataset) - if ( - "time" in dataset.coords - and "time" in dataset.coords - and "valid_time" in dataset.coords - ): + if "time" in dataset.coords and "valid_time" in dataset.coords: logger.debug( 'The file has a coordinate named "time"; usually this means that ' "is a GRIB historical file" From 50f88c924337d43a081fb770ebf9abb7b7deb010 Mon Sep 17 00:00:00 2001 From: Stefano Piani Date: Tue, 27 May 2025 13:01:46 +0200 Subject: [PATCH 3/8] EfasOperationalDownloader saves the file as .temp during download --- src/ogs_riverger/efas/download_tools.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/ogs_riverger/efas/download_tools.py b/src/ogs_riverger/efas/download_tools.py index 65852e6..de9eef8 100644 --- a/src/ogs_riverger/efas/download_tools.py +++ b/src/ogs_riverger/efas/download_tools.py @@ -2,6 +2,7 @@ import inspect import logging import re +import shutil from collections.abc import Iterable from collections.abc import Mapping from datetime import datetime @@ -813,6 +814,9 @@ async def _single_download( logger = logging.getLogger(f"{__name__}.{inspect.stack()[0][3]}") logger.info("Downloading file %s", remote_file_path) + temp_path = self.data_dir / f"{remote_file_path.name}.tmp" + output_path = self.data_dir / remote_file_path.name + for attempt in range(1, retries + 1): try: async with self._semaphore: @@ -824,7 +828,9 @@ async def _single_download( socket_timeout=30, connection_timeout=30, ) as client: - await client.download(remote_file_path, self.data_dir) + await client.download( + remote_file_path, temp_path, write_into=True + ) break except aioftp.errors.StatusCodeError as e: if "425" in str(e) and attempt != retries: @@ -838,8 +844,8 @@ async def _single_download( await asyncio.sleep(attempt * 4) continue raise + shutil.move(temp_path, output_path) - output_path = self.data_dir / remote_file_path.name logger.info( "Download of file %s into %s completed", remote_file_path, From 4a932db2974b30c43b7d7f45104f69620b3b9db3 Mon Sep 17 00:00:00 2001 From: Stefano Piani Date: Wed, 4 Jun 2025 16:45:07 +0200 Subject: [PATCH 4/8] Added EfasArchiveDownloader class --- src/ogs_riverger/efas/download_tools.py | 290 +++++++++++++++++------ src/ogs_riverger/utils/datetime_utils.py | 81 +++++++ tests/test_efas.py | 122 ++++++++-- 3 files changed, 411 insertions(+), 82 deletions(-) create mode 100644 src/ogs_riverger/utils/datetime_utils.py diff --git a/src/ogs_riverger/efas/download_tools.py b/src/ogs_riverger/efas/download_tools.py index de9eef8..85140a4 100644 --- a/src/ogs_riverger/efas/download_tools.py +++ b/src/ogs_riverger/efas/download_tools.py @@ -3,12 +3,15 @@ import logging import re import shutil +from collections.abc import Container from collections.abc import Iterable from collections.abc import Mapping from datetime import datetime from datetime import time from datetime import timedelta +from datetime import timezone from enum import Enum +from itertools import product as cart_prod from os import PathLike from pathlib import Path from typing import Annotated @@ -23,6 +26,7 @@ from pydantic import SecretStr from ogs_riverger.utils.area_selection import AreaSelection +from ogs_riverger.utils.datetime_utils import check_all_timezone_awareness EFAS_FTP_URL = "aux.ecmwf.int" @@ -270,6 +274,7 @@ def download_from_cdsapi( file_data_format: EfasCEMSDataFormat = EfasCEMSDataFormat.NETCDF, output_file_mask: str = "efas_{SERVICE}_{DATE}.{FORMAT}{IS_ZIP}", date_format: str = "{YEAR}_{MONTH:02d}_{START_DAY:02d}-{END_DAY:02d}", + skip: Container[str] | None = None, ) -> tuple[Path, ...]: """Downloads data from the CDSAPI EFAS archives (forecast or historical). @@ -311,8 +316,16 @@ def download_from_cdsapi( - {START_DAY} is the start day of the file - {END_DAY} is the end day of the file + skip: list of file names that should not be downloaded. If one of the + files that must be downloaded by this function has a name + contained inside the skip object, this function will skip the + download of that file. For example, you may want to skip the + download of a file that has already been downloaded. + Returns: - A tuple with the paths of all the files that have been downloaded + A tuple with the paths of all the files that have been downloaded. If + a file has been skipped because its name was in the `skip` list, its + name will not be returned by this function. Raises: ValueError: the output directory does not exist or is not a directory. @@ -346,76 +359,82 @@ def download_from_cdsapi( one_day = timedelta(days=1) retrieve_args = [] - for year in years: - for month in range(1, 13): - start_month_date = datetime(year, month, 1) - if month != 12: - end_month_date = datetime(year, month + 1, 1) - one_day - else: - end_month_date = datetime(year + 1, 1, 1) - one_day - - if start_month_date > end_date: - logger.debug( - "Month %s of year %s will not be downloaded because it is " - "after the end_date (%s)", - month, - year, - end_date, - ) - continue - if end_month_date < start_date: - logger.debug( - "Month %s of year %s will not be downloaded because it is " - "before the start_date (%s)", - month, - year, - start_date, - ) - continue + for year, month in cart_prod(years, range(1, 13)): + start_month_date = datetime(year, month, 1) + if month != 12: + end_month_date = datetime(year, month + 1, 1) - one_day + else: + end_month_date = datetime(year + 1, 1, 1) - one_day + if start_month_date > end_date: logger.debug( - "Preparing request for month %s of year %s", month, year + "Month %s of year %s will not be downloaded because it is " + "after the end_date (%s)", + month, + year, + end_date, ) - - start_month_date = max(start_date, start_month_date) - end_month_date = min(end_date, end_month_date) - - start_day = start_month_date.day - end_day = end_month_date.day - is_zip_str = "" - if download_format is EfasCEMSDownloadFormat.ZIP: - is_zip_str = ".zip" - file_date_str = date_format.format( - YEAR=year, - MONTH=month, - START_DAY=start_day, - END_DAY=end_day, - ) - output_file_name = output_file_mask.format( - SERVICE=service.value, - DATE=file_date_str, - FORMAT=file_data_format.value, - IS_ZIP=is_zip_str, + continue + if end_month_date < start_date: + logger.debug( + "Month %s of year %s will not be downloaded because it is " + "before the start_date (%s)", + month, + year, + start_date, ) - output_file_path = output_dir / output_file_name - - days = tuple(range(start_day, end_day + 1)) - - request = request_class( - year=(year,), - month=(month,), - day=days, - area=area, - data_format=file_data_format, - download_format=download_format, + continue + + logger.debug("Preparing request for month %s of year %s", month, year) + + start_month_date = max(start_date, start_month_date) + end_month_date = min(end_date, end_month_date) + + start_day = start_month_date.day + end_day = end_month_date.day + is_zip_str = "" + if download_format is EfasCEMSDownloadFormat.ZIP: + is_zip_str = ".zip" + file_date_str = date_format.format( + YEAR=year, + MONTH=month, + START_DAY=start_day, + END_DAY=end_day, + ) + output_file_name = output_file_mask.format( + SERVICE=service.value, + DATE=file_date_str, + FORMAT=file_data_format.value, + IS_ZIP=is_zip_str, + ) + if skip is not None and output_file_name in skip: + logger.debug( + "Skipping download of file %s because it was in the skip list", + output_file_name, ) + continue + output_file_path = output_dir / output_file_name + + days = tuple(range(start_day, end_day + 1)) + + request = request_class( + year=(year,), + month=(month,), + day=days, + area=area, + data_format=file_data_format, + download_format=download_format, + ) - retrieve_args.append( - (f"efas-{service.value}", request, output_file_path) - ) + retrieve_args.append( + (f"efas-{service.value}", request, output_file_path) + ) def download_file(retrieve_arg): _request_name, _request, _output_file_path = retrieve_arg + temp_file_path = ( + _output_file_path.parent / f"{_output_file_path.name}.temp" + ) logger.debug( "Downloading file %s using the following request: %s", _output_file_path, @@ -424,10 +443,11 @@ def download_file(retrieve_arg): cdsapi_client.retrieve( _request_name, _request.dump(), - _output_file_path, + temp_file_path, ) + shutil.move(temp_file_path, _output_file_path) logger.debug("File %s has been downloaded", _output_file_path) - return output_file_path + return _output_file_path saved_files = map(download_file, retrieve_args) @@ -851,6 +871,19 @@ async def _single_download( remote_file_path, output_path, ) + + # Add the file into the cache + try: + output_path_async = anyio.Path(output_path) + file_stat = await output_path_async.stat() + file_size = file_stat.st_size + self._cache[output_path] = file_size + except Exception: + logger.exception( + "Failed to add file %s into the cache", output_path + ) + raise + return output_path async def get_available_data_range(self) -> tuple[datetime, datetime]: @@ -899,7 +932,7 @@ async def download_efas_operational_data( self, start_time: datetime, end_time: datetime, - ) -> list[Path]: + ) -> tuple[Path, ...]: """Downloads all files from the FTP server for a specified time interval. @@ -1030,4 +1063,129 @@ async def download_efas_operational_data( else: returned_files.add(file_local_download) - return sorted(list(returned_files)) + return tuple(sorted(list(returned_files))) + + +class EfasArchiveDownloader: + FILE_NAME_MASK = re.compile( + r"^efas_(?P[a-zA-Z0-9]+)_(?P\d{4}_\d{2}_\d{2}-\d{2})" + r"\.netcdf\.zip$" + ) + OUTPUT_FILE_MASK: str = "efas_{SERVICE}_{DATE}.{FORMAT}{IS_ZIP}" + FILE_DATA_FORMAT: str = "{YEAR}_{MONTH:02d}_{START_DAY:02d}-{END_DAY:02d}" + + def __init__( + self, + data_dir: Path, + service: EfasDataSource = EfasDataSource.HISTORICAL, + area: AreaSelection | None = None, + cache: Iterable[Path] | None = None, + cdsapi_client: cdsapi.Client | None = None, + ): + self.data_dir = data_dir + self.service = service + self.area = area + + if cdsapi_client is None: + self._cdsapi_client = get_cdsapi_client() + else: + self._cdsapi_client = cdsapi_client + + if cache is None: + cache = self.data_dir.iterdir() + + self._cache: set[Path] = set() + for file_path in cache: + if not self.FILE_NAME_MASK.match(file_path.name): + continue + self._cache.add(file_path) + + @staticmethod + def _get_file_start_end_dates(file_name: str) -> tuple[datetime, datetime]: + """Extracts the start and end dates from a file name.""" + f_match = EfasArchiveDownloader.FILE_NAME_MASK.match(file_name) + + f_date = f_match.group("DATE") + f_date_start, f_date_end_day = f_date.split("-") + f_date_start = datetime.strptime(f_date_start, "%Y_%m_%d") + f_date_end_day = int(f_date_end_day) + + f_date_end = datetime( + year=f_date_start.year, + month=f_date_start.month, + day=f_date_end_day, + ) + return f_date_start, f_date_end + + def download_efas_archived_files( + self, start_time: datetime, end_time: datetime + ) -> tuple[Path, ...]: + logger = logging.getLogger(f"{__name__}.{inspect.stack()[0][3]}") + + timezone_aware = check_all_timezone_awareness(start_time, end_time) + logger.debug( + "The datetime objects are timezone%s", + "-aware" if timezone_aware else " naive", + ) + + already_downloaded: set[Path] = set() + for f in self._cache: + logger.debug( + "Checking if file %s saved in cache is useful for the request", + f, + ) + f_match = self.FILE_NAME_MASK.match(f.name) + if f_match is None: + logger.warning( + "There is a file into this %s cache that is not " + "conformal to the expected file mask: %s", + self.__class__.__name__, + f, + ) + continue + + file_start_date, file_end_date = self._get_file_start_end_dates( + f.name + ) + + if timezone_aware: + file_start_date = file_start_date.replace(tzinfo=timezone.utc) + file_end_date = file_end_date.replace(tzinfo=timezone.utc) + + if file_start_date > end_time: + logger.debug( + "File %s is not useful because it is too recent", f.name + ) + continue + if file_end_date < start_time: + logger.debug( + "File %s is not useful because it is too old", f.name + ) + continue + + logger.debug("File %s will be inserted inside the cache") + already_downloaded.add(f) + + logger.debug( + "%s files will be provided by the cache", len(already_downloaded) + ) + + downloaded_files = download_from_cdsapi( + self.service, + start_time, + end_time, + self.data_dir, + area=self.area, + cdsapi_client=self._cdsapi_client, + download_format=EfasCEMSDownloadFormat.ZIP, + file_data_format=EfasCEMSDataFormat.NETCDF, + output_file_mask=self.OUTPUT_FILE_MASK, + date_format=self.FILE_DATA_FORMAT, + skip=set(f.name for f in already_downloaded), + ) + + for d_file in downloaded_files: + self._cache.add(d_file) + + useful_files = downloaded_files + tuple(already_downloaded) + return tuple(sorted(useful_files)) diff --git a/src/ogs_riverger/utils/datetime_utils.py b/src/ogs_riverger/utils/datetime_utils.py new file mode 100644 index 0000000..586d53d --- /dev/null +++ b/src/ogs_riverger/utils/datetime_utils.py @@ -0,0 +1,81 @@ +from datetime import datetime + + +class MixedTimezoneAwareness(ValueError): + """ + Represents an error raised when there is a mismatch or inconsistency in + timezone awareness between two or more datetime objects. + + This error is typically used to identify and handle cases where operations + on datetime objects require consistent timezone-awareness, but the + provided datetime objects have mixed timezone awareness. + + For instance, this might occur when attempting to compare or perform an + operation between a timezone-aware datetime object and a naive datetime + object. + + This exception inherits from the built-in ValueError to signify that it is + raised due to invalid or inconsistent values related to timezone awareness. + """ + + pass + + +def is_timezone_aware(dt: datetime) -> bool: + """ + Determine whether a given datetime object is timezone-aware or not. + + This function checks the `tzinfo` attribute and the `utcoffset` method of + the provided datetime object to determine whether it is time-aware or not. + A timezone-aware datetime object contains timezone information that allows + it to handle different time zones accurately. + + Args: + dt: A datetime object to be checked for timezone awareness. + If the `tzinfo` attribute is `None` or the result of `utcoffset` + is `None`, the datetime is considered naive (not time-aware). + Otherwise, it is considered time-aware. + + Returns: + A boolean indicating whether the provided datetime object is + timezone-aware (`True`) or timezone-naive (`False`). + """ + naive = dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None + return not naive + + +def check_all_timezone_awareness(*datetime_objects: datetime) -> bool: + """ + Determines if all provided datetime objects have consistent timezone + awareness (either all timezone-aware or all timezone-naive). If no + datetime objects are provided, it assumes consistency and returns True. + + Args: + *datetime_objects (datetime): Variable-length argument list of + datetime objects to be checked for timezone awareness + consistency. + + Returns: + bool: True if all datetime objects have consistent timezone + awareness or no datetime objects are provided, False otherwise. + + Raises: + MixedTimezoneAwareness: If the datetime objects have mixed + timezone awareness (some are timezone-aware and some are + timezone-naive). + """ + if len(datetime_objects) == 0: + return True + + timezone_aware = is_timezone_aware(datetime_objects[0]) + + for dt in datetime_objects: + current_timezone_aware = is_timezone_aware(dt) + if current_timezone_aware != timezone_aware: + raise MixedTimezoneAwareness( + "The provided datetime objects have mixed timezone " + "awareness. All datetime objects must either be " + "time-aware or time-naive." + ) + + return timezone_aware diff --git a/tests/test_efas.py b/tests/test_efas.py index 58e9323..bf53c0e 100644 --- a/tests/test_efas.py +++ b/tests/test_efas.py @@ -5,17 +5,20 @@ from pathlib import Path from types import MappingProxyType from unittest.mock import AsyncMock +from unittest.mock import Mock from unittest.mock import patch from zipfile import ZipFile import numpy as np import pytest import xarray as xr +from cdsapi import Client from pydantic import SecretStr from ogs_riverger.efas.download_tools import AreaSelection from ogs_riverger.efas.download_tools import download_from_cdsapi from ogs_riverger.efas.download_tools import download_yearly_data_from_cdsapi +from ogs_riverger.efas.download_tools import EfasArchiveDownloader from ogs_riverger.efas.download_tools import EfasCEMSDataFormat from ogs_riverger.efas.download_tools import EfasCEMSDownloadFormat from ogs_riverger.efas.download_tools import EfasDataSource @@ -27,20 +30,21 @@ from ogs_riverger.efas.efas_manager import read_efas_data_files -class CDSClientMock: - """ - A mock that simulates the interface of `cdsapi.Client`, storing the - requests that would have been made to the API. - """ +def create_cds_client_mock(): + """Creates a Mock that simulates the interface of `cdsapi.Client`.""" + mock = Mock(spec=Client) + mock.requests = [] - def __init__(self): - self._requests = [] + def retrieve_side_effect(name: str, request, output_file): + output_file.touch() + mock.requests.append(request) - def retrieve(self, name: str, request, output_file) -> None: - self._requests.append(request) + def get_requests(): + return tuple(mock.requests) - def get_requests(self) -> tuple[dict, ...]: - return tuple(self._requests) + mock.retrieve.side_effect = retrieve_side_effect + mock.get_requests = get_requests + return mock @pytest.fixture @@ -188,7 +192,7 @@ def test_download_forecast_data_single_month(year, month, n_days, tmp_path): else: end_date = datetime(year, month + 1, 1) - timedelta(days=1) - client = CDSClientMock() + client = create_cds_client_mock() area = AreaSelection(north=1, west=2, south=3, east=4) # noinspection PyTypeChecker @@ -228,10 +232,9 @@ def test_download_forecast_data_two_months(tmp_path): start_date = datetime(2024, 3, 12) end_date = datetime(2024, 4, 18) - client = CDSClientMock() + client = create_cds_client_mock() area = AreaSelection(north=1, west=2, south=3, east=4) - # noinspection PyTypeChecker download_from_cdsapi( EfasDataSource.FORECAST, start_date, @@ -259,10 +262,9 @@ def test_download_yearly_data(year, tmp_path): WHEN the function is executed, THEN the function downloads the data using the correct cdsapi request. """ - client = CDSClientMock() + client = create_cds_client_mock() area = AreaSelection(north=1, west=2, south=3, east=4) - # noinspection PyTypeChecker download_yearly_data_from_cdsapi( EfasDataSource.HISTORICAL, year=year, @@ -808,3 +810,91 @@ def test_efas_from_cdsapi_download_and_read( assert "time" in efas_dataset.coords assert "dis06" in efas_dataset.variables + + +def test_efas_archive_downloader(tmp_path): + client = create_cds_client_mock() + area = AreaSelection(north=1, west=2, south=3, east=4) + + downloader = EfasArchiveDownloader( + data_dir=tmp_path, + service=EfasDataSource.HISTORICAL, + area=area, + cdsapi_client=client, + ) + + start_time = datetime(2023, 10, 18) + end_time = datetime(2024, 11, 14) + n_files = 14 + + downloaded_files = downloader.download_efas_archived_files( + start_time=start_time, end_time=end_time + ) + + assert len(downloaded_files) == n_files + + +def test_efas_archive_downloader_uses_cache(tmp_path): + client = create_cds_client_mock() + + downloader = EfasArchiveDownloader( + data_dir=tmp_path, + service=EfasDataSource.HISTORICAL, + cdsapi_client=client, + ) + + start_time = datetime(2023, 10, 18) + end_time = datetime(2024, 2, 8) + + downloader.download_efas_archived_files( + start_time=start_time, end_time=end_time + ) + # We performed 5 downloads, one for each month + assert client.retrieve.call_count == 5 + + new_end_time = datetime(2024, 3, 1) + downloaded_files = downloader.download_efas_archived_files( + start_time=start_time, end_time=new_end_time + ) + + # 6 files, from October 2023 to March 2024 + assert len(downloaded_files) == 7 + + # Besides the 5 files from the previous call, we download again february, + # and we download the first day of March + assert client.retrieve.call_count == 7 + + +def test_efas_archive_downloader_generates_the_cache_when_built(tmp_path): + client1 = create_cds_client_mock() + + downloader1 = EfasArchiveDownloader( + data_dir=tmp_path, + service=EfasDataSource.HISTORICAL, + cdsapi_client=client1, + ) + + start_time = datetime(2023, 10, 18) + end_time = datetime(2024, 2, 8) + + downloader1.download_efas_archived_files( + start_time=start_time, end_time=end_time + ) + + client2 = create_cds_client_mock() + downloader2 = EfasArchiveDownloader( + data_dir=tmp_path, + service=EfasDataSource.HISTORICAL, + cdsapi_client=client2, + ) + + new_end_time = datetime(2024, 3, 1) + downloaded_files = downloader2.download_efas_archived_files( + start_time=start_time, end_time=new_end_time + ) + + # 6 files, from October 2023 to March 2024 + assert len(downloaded_files) == 7 + + # We download again february, and we download the first day of March + assert client2.retrieve.call_count == 2 From d6cc72d8484ca86272951ae586f63bc12ae4d54f Mon Sep 17 00:00:00 2001 From: Stefano Piani Date: Thu, 5 Jun 2025 10:55:59 +0200 Subject: [PATCH 5/8] Restructuring code so that it is easier to parallelize --- src/ogs_riverger/efas/efas_manager.py | 504 +++++++++++++------------- 1 file changed, 259 insertions(+), 245 deletions(-) diff --git a/src/ogs_riverger/efas/efas_manager.py b/src/ogs_riverger/efas/efas_manager.py index 938c83b..ad92ac4 100644 --- a/src/ogs_riverger/efas/efas_manager.py +++ b/src/ogs_riverger/efas/efas_manager.py @@ -3,6 +3,7 @@ from collections.abc import Iterable from contextlib import ExitStack from datetime import datetime +from functools import partial from os import PathLike from pathlib import Path from tempfile import TemporaryDirectory @@ -209,8 +210,7 @@ def read_efas_data_files( compressed inside some zip archives. The function reads the different kinds of files, merges them, - applies the runoff factor described in the CSV configuration file (if - not specified otherwise), and returns an xArray dataset. + and returns an xArray dataset. Args: input_files: An iterable of paths to the input files @@ -227,87 +227,256 @@ def read_efas_data_files( It also contains a one-dimensional variable named "computation_date" that contains the date of the computation. """ - # This function is just a wrapper that decompresses the (optionally) - # zipped files and calls _read_unzipped_efas_files logger = logging.getLogger(f"{__name__}.{inspect.stack()[0][3]}") - # Is there any zip file? - zip_files = len([f for f in input_files if f.suffix.lower() == ".zip"]) > 0 - if not zip_files: - logger.debug( - "No zip files found in the list of files that must be read; " - "they will be read as they are: %s", - zip_files, + logger.debug("Reading the EFAS domain file") + efas_domain = xr.load_dataset(efas_domain_file) + + # Create a temporary directory to store the unzipped files + with TemporaryDirectory() as t: + logger.debug("Creating temporary directory %s", t) + + elaborate_file = partial( + _read_single_efas_file, + config_content=config_content, + efas_domain=efas_domain, + temp_dir=Path(t), ) - return _read_unzipped_efas_files( - input_files, - config_content, - efas_domain_file, + datasets = map(elaborate_file, input_files) + + logger.debug("Concatenating the datasets...") + concat_dataset = xr.concat(datasets, dim="time") + logger.debug("All EFAS files have been read") + logger.debug("Directory %s has been deleted", t) + + # Inside the EFAS files, we could have multiple forecasts for the same + # time of the model (because they have been produced on different runs). + # Here we select only the most reliable forecast (the one that is closest + # to the computation date) + logger.debug("Getting the best forecast (if needed)") + _, slicing_indices = _get_best_unique_element( + concat_dataset.time.values, concat_dataset.computation_date.values + ) + + efas_data = concat_dataset.isel(time=slicing_indices) + + return efas_data + + +def _read_raw_content_of_efas_file( + file_path: Path, + config_content: Iterable[RiverConfigElement], + efas_domain: xr.Dataset, + temp_dir: Path, +) -> xr.Dataset: + """ + Reads the raw content of an EFAS file. + + This function processes a given EFAS file, either zipped or unzipped, + and extracts its content into a usable xarray.Dataset format. + If the input file is a zip file, it ensures that it contains exactly one + file, extracts it into a temporary directory, and processes the extracted + file's content. If the file is not zipped, it directly processes the + content as it is. + + This function returns the raw content of the EFAS file in xarray.Dataset; + the only operation performed on the raw content is the selection of the + rivers' data (i.e., we select only the cells where there is a mouth of a + river). + + Args: + file_path: Path to the EFAS file to be read. It can be either a zipped + file containing one item or an unzipped data file. + config_content: Iterable containing the EFAS file configuration + elements. Each element describes the river to be considered and + the coordinates of its mouth. + efas_domain: The dataset stored into the file containing + the coordinates of the efas domain. + temp_dir: Path to the directory where temporary files or + decompressed directories will be created during the execution. + + Returns: + Parsed data from the EFAS file in xarray.Dataset format. + """ + logger = logging.getLogger(f"{__name__}.{inspect.stack()[0][3]}") + + if not file_path.suffix.lower() == ".zip": + logger.debug("File %s is not zipped; it will be read as is", file_path) + return _read_raw_content_of_unzipped_efas_file( + file_path, config_content, efas_domain ) + decompressed_dir = temp_dir / (file_path.stem + "___" + uuid1().hex) + decompressed_dir.mkdir(parents=False, exist_ok=False) logger.debug( - "Some EFAS files are zipped and must be decompressed before being " - "read." + "File %s will be unzipped zipped into directory %s", + file_path, + decompressed_dir, + ) + with ZipFile(file_path, "r") as zip_ref: + zip_ref.extractall(decompressed_dir) + if len(zip_ref.namelist()) > 1: + raise ValueError( + f"The zip file {file_path} contains more than one file; " + "only one file is expected" + ) + if len(zip_ref.namelist()) == 0: + raise ValueError( + f"The zip file {file_path} does not contain any file" + ) + data_file_name = zip_ref.namelist()[0] + zip_ref.extract(data_file_name, decompressed_dir) + decompressed_data_file = decompressed_dir / data_file_name + logger.debug( + "File %s that was stored inside %s has been decompressed into %s", + data_file_name, + file_path, + decompressed_data_file, + ) + return _read_raw_content_of_unzipped_efas_file( + decompressed_data_file, config_content, efas_domain ) - uncompressed_files = [] - with TemporaryDirectory() as t: - t = Path(t) - logger.debug("Temporary directory %s will be used", t) - for f in input_files: - if f.suffix.lower() != ".zip": - logger.debug( - "File %s will be read as is (already uncompressed)", f - ) - uncompressed_files.append(f) - continue - - decompressed_dir = t / (f.stem + "___" + uuid1().hex) - decompressed_dir.mkdir(parents=False, exist_ok=False) - logger.debug("Decompressing file %s into %s", f, decompressed_dir) - with ZipFile(f, "r") as zip_ref: - zip_ref.extractall(decompressed_dir) - if len(zip_ref.namelist()) > 1: - raise ValueError( - f"The zip file {f} contains more than one file; " - "only one file is expected" - ) - if len(zip_ref.namelist()) == 0: - raise ValueError( - f"The zip file {f} does not contain any file" - ) - data_file_name = zip_ref.namelist()[0] - zip_ref.extract(data_file_name, decompressed_dir) - decompressed_data_file = decompressed_dir / data_file_name - logger.debug( - "File %s has been decompressed into %s", - f, - decompressed_data_file, - ) - uncompressed_files.append(decompressed_data_file) - output_data = _read_unzipped_efas_files( - uncompressed_files, - config_content, - efas_domain_file, - ) +def _read_raw_content_of_unzipped_efas_file( + file_path: Path, + config_content: Iterable[RiverConfigElement], + efas_domain: xr.Dataset, +) -> xr.Dataset: + """ + Reads and processes the raw content of an unzipped EFAS file to extract + specific river data based on given configurations and EFAS domain. - logger.debug("Directory %s has been deleted", t) - return output_data + The function opens the specified file (which may be in GRIB or NetCDF + format) using Xarray and extracts river data by indexing latitude and + longitude positions. It validates that river positions are within the + domain of the file and reads and loads the corresponding data from the + dataset. + + Args: + file_path: The file path to the unzipped EFAS data file. + config_content: An iterable containing river configuration elements, + filtered for EFAS source type. + efas_domain: The dataset representing the EFAS domain, containing + latitude and longitude values. + + Returns: + The dataset containing the river data extracted and indexed based + on the configuration and EFAS domain. + + Raises: + InvalidEfasFile: If the domain of the file does not contain certain + rivers specified in the configuration. + """ + logger = logging.getLogger(f"{__name__}.{inspect.stack()[0][3]}") + + domain_latitudes = efas_domain.latitude.values + domain_longitudes = efas_domain.longitude.values + + river_configs = tuple( + c for c in config_content if c.data_source.type == "EFAS" + ) + n_rivers = len(river_configs) + # Get the ids, latitude and longitude of the rivers + river_ids = np.empty((n_rivers,), dtype=int) + river_names = [] + lon_indices_array = np.empty_like(river_ids) + lat_indices_array = np.empty_like(river_ids) + for i, river in enumerate(river_configs): + river_ids[i] = river.id + river_names.append(river.name) + lon_indices_array[i] = river.data_source.longitude_index + lat_indices_array[i] = river.data_source.latitude_index -def _read_single_efas_file(dataset: xr.Dataset) -> xr.Dataset: + # Transform the arrays with the indices of the latitude and the longitude + # into two DataArrays + lon_indices = xr.DataArray( + lon_indices_array, + dims=["id"], + coords={"id": river_ids}, + ) + lat_indices = xr.DataArray( + lat_indices_array, + dims=["id"], + coords={"id": river_ids}, + ) + + # This file could be a grib or a NetCDF; luckily, xarray supports both + logger.debug('Opening file "%s"', file_path) + + # We use Dask here (chunks={}) because it is way more efficient than + # standard xarray when executing the isel method. + # By setting "decode_timedelta=True" we ensure that the values of the + # step variable are decoded as timedelta64 objects (and we also + # silence a warning) + try: + with xr.open_dataset( + file_path, chunks={}, decode_timedelta=True + ) as single_ds: + dataset_latitudes = single_ds.latitude.values + dataset_longitudes = single_ds.longitude.values + + i_lat1, i_lat2 = _find_slice(domain_latitudes, dataset_latitudes) + i_lon1, i_lon2 = _find_slice(domain_longitudes, dataset_longitudes) + + # Check that the position of the rivers is coherent with the + # domain of the file we have downloaded + outside_lat = np.logical_or( + lat_indices < i_lat1, lat_indices >= i_lat2 + ) + outside_lon = np.logical_or( + lon_indices < i_lon1, lon_indices >= i_lon2 + ) + if np.any(outside_lat) or np.any(outside_lon): + river_outside_index = np.nonzero( + (outside_lon | outside_lat).values + )[0][0] + river_name = river_names[river_outside_index] + river_latitude = lat_indices.values[river_outside_index] + river_longitude = lon_indices.values[river_outside_index] + + lat_sorted = np.sort(dataset_latitudes) + lon_sorted = np.sort(dataset_longitudes) + raise InvalidEfasFile( + f'The domain of the file "{file_path}" (latitudes ' + f"from {lat_sorted[0]:.3f} to {lat_sorted[-1]:.3f} " + f"and longitudes from {lon_sorted[0]:.3f} to " + f"{lon_sorted[-1]:.3f}) does not contain the river " + f'"{river_name}", whose mouth has latitude ' + f"{domain_latitudes[river_latitude]:.3f} and " + f"longitude {domain_longitudes[river_longitude]:.3f}." + ) + except Exception as e: + raise type(e)(f"Error while trying to read file {file_path}") from e + + logger.debug("Retrieving rivers' data from the map") + file_content = single_ds.isel( + longitude=lon_indices - i_lon1, + latitude=lat_indices - i_lat1, + ).load() + return file_content + + +def _read_single_efas_file( + file_path: Path, + config_content: Iterable[RiverConfigElement], + efas_domain: xr.Dataset, + temp_dir: Path, +) -> xr.Dataset: """Determines the type of EFAS file and processes it accordingly. - This function examines the structure and contents of the input dataset to + This function reads the EFAS file content and examines its structure to identify its type. Based on the file characteristics, it calls the corresponding processing function, such as `_read_efas_operative_grib_file` for operative GRIB files or `_read_efas_historical_file` for historical NetCDF files. Args: - dataset (xr.Dataset): An xarray dataset containing the EFAS data to be - analyzed. + file_path: Path to the EFAS file to be processed. + config_content: Configuration elements for each river to be considered. + efas_domain: Dataset containing the EFAS domain coordinates. + temp_dir: Directory for temporary file operations. Returns: xr.Dataset: The processed EFAS data in a standard format. @@ -318,51 +487,64 @@ def _read_single_efas_file(dataset: xr.Dataset) -> xr.Dataset: """ logger = logging.getLogger(f"{__name__}.{inspect.stack()[0][3]}") - if "id" not in dataset.dims: + raw_content = _read_raw_content_of_efas_file( + file_path=file_path, + config_content=config_content, + efas_domain=efas_domain, + temp_dir=temp_dir, + ) + + # I don't know what the "surface" coordinate is, but it is unnecessary + if "surface" in raw_content: + logger.debug('Removing coordinate "surface"') + del raw_content["surface"] + + if "id" not in raw_content.dims: raise InvalidEfasFile( f'The current file has no "id" dimension; the current ' - f"dimensions are: {list(dataset.dims)}" + f"dimensions are: {list(raw_content.dims)}" ) - if len(dataset.dims) > 2: + if len(raw_content.dims) > 2: # Files with more than two dimensions usually are forecast files, where # we have two dimensions for the time; one is the time when the # computation has been performed, the other one is the time of the # model - if set(dataset.dims) == {"id", "step", "time"}: - return _read_efas_forecast_file(dataset) + if set(raw_content.dims) == {"id", "step", "time"}: + return _read_efas_forecast_file(raw_content) raise InvalidEfasFile( - f"The current file has unexpected dimensions: {list(dataset.dims)}" + "The current file has unexpected dimensions: " + f"{list(raw_content.dims)}" ) - if "time" in dataset.coords and len(dataset["time"].dims) == 0: + if "time" in raw_content.coords and len(raw_content["time"].dims) == 0: logger.debug( 'The file has a coordinate "time" that contains only one element' ) - if "step" in dataset.dims: + if "step" in raw_content.dims: logger.debug( 'The second dimension of the file (beside "id") is "step"; ' "This file will be considered an operative grib file" ) - return _read_efas_operative_grib_file(dataset) + return _read_efas_operative_grib_file(raw_content) - if "valid_time" in dataset.coords and "valid_time" in dataset.dims: + if "valid_time" in raw_content.coords and "valid_time" in raw_content.dims: logger.debug( 'The file has a coordinate "valid_time" that has the role of the ' "time; it is a netcdf historical file" ) - return _read_efas_historical_file(dataset) + return _read_efas_historical_file(raw_content) - if "time" in dataset.coords and "valid_time" in dataset.coords: + if "time" in raw_content.coords and "valid_time" in raw_content.coords: logger.debug( 'The file has a coordinate named "time"; usually this means that ' "is a GRIB historical file" ) - return _read_efas_historical_file(dataset) + return _read_efas_historical_file(raw_content) raise InvalidEfasFile( f"The EFAS file has a format that this software does not recognize:\n" - f"{dataset}" + f"{raw_content}" ) @@ -535,174 +717,6 @@ def _read_efas_historical_file(dataset): ) -def _read_unzipped_efas_files( - input_files: Iterable[Path], - config_content: Iterable[RiverConfigElement], - efas_domain_file: PathLike, -) -> xr.Dataset: - """Reads and processes unzipped EFAS (European Flood Awareness System) - files to extract data based on specified river locations, reshaping the - data into a consumable format. - - This function takes EFAS files, configuration content, and the EFAS domain - file as input. - For each file provided, it extracts data based on configured river indices, - validates that data aligns with the specified domain, and reshapes the - output. The resulting datasets are concatenated along the time dimension, - retaining the best forecast data, and returned as an xarray.Dataset. - - Args: - input_files: A collection of file paths indicating the - locations of the unzipped EFAS files to process. - config_content: An iterable that generates the configuration for - each EFAS river, including indices of river locations. - efas_domain_file: A path to the EFAS domain file used for latitude and - longitude referencing. - - Returns: - xr.Dataset: A concatenated dataset containing processed EFAS data, - reshaped and indexed by river and time. - - Raises: - InvalidEfasFile: If a file's domain does not match the specified - river indices, or if there are issues while reading a specific - EFAS file. - """ - logger = logging.getLogger(f"{__name__}.{inspect.stack()[0][3]}") - - input_files = tuple(input_files) - if len(input_files) == 0: - raise ValueError( - "No input files were provided. This function needs at least one " - "EFAS file to be processed." - ) - river_configs = tuple( - c for c in config_content if c.data_source.type == "EFAS" - ) - n_rivers = len(river_configs) - - # Get the ids, latitude and longitude of the rivers - river_ids = np.empty((n_rivers,), dtype=int) - river_names = [] - lon_indices_array = np.empty_like(river_ids) - lat_indices_array = np.empty_like(river_ids) - for i, river in enumerate(river_configs): - river_ids[i] = river.id - river_names.append(river.name) - lon_indices_array[i] = river.data_source.longitude_index - lat_indices_array[i] = river.data_source.latitude_index - - # Transform the arrays with the indices of the latitude and the longitude - # into two DataArrays - lon_indices = xr.DataArray( - lon_indices_array, - dims=["id"], - coords={"id": river_ids}, - ) - lat_indices = xr.DataArray( - lat_indices_array, - dims=["id"], - coords={"id": river_ids}, - ) - - # Open the file with the coordinates of the overall domain to understand - # which part of the domain we downloaded - with xr.open_dataset(efas_domain_file) as f: - domain_latitudes = f.latitude.values - domain_longitudes = f.longitude.values - - datasets = [] - for file_path in input_files: - # This file could be a grib or a NetCDF; luckily, xarray supports both - logger.debug('Opening file "%s"', file_path) - # We use Dask here (chunks={}) because it is way more efficient than - # standard xarray when executing the isel method. - # By setting "decode_timedelta=True" we ensure that the values of the - # step variable are decoded as timedelta64 objects (and we also - # silence a warning) - try: - with xr.open_dataset( - file_path, chunks={}, decode_timedelta=True - ) as single_ds: - dataset_latitudes = single_ds.latitude.values - dataset_longitudes = single_ds.longitude.values - - i_lat1, i_lat2 = _find_slice( - domain_latitudes, dataset_latitudes - ) - i_lon1, i_lon2 = _find_slice( - domain_longitudes, dataset_longitudes - ) - - # Check that the position of the rivers is coherent with the - # domain of the file we have downloaded - outside_lat = np.logical_or( - lat_indices < i_lat1, lat_indices >= i_lat2 - ) - outside_lon = np.logical_or( - lon_indices < i_lon1, lon_indices >= i_lon2 - ) - if np.any(outside_lat) or np.any(outside_lon): - river_outside_index = np.nonzero( - (outside_lon | outside_lat).values - )[0][0] - river_name = river_names[river_outside_index] - river_latitude = lat_indices.values[river_outside_index] - river_longitude = lon_indices.values[river_outside_index] - - lat_sorted = np.sort(dataset_latitudes) - lon_sorted = np.sort(dataset_longitudes) - raise InvalidEfasFile( - f'The domain of the file "{file_path}" (latitudes ' - f"from {lat_sorted[0]:.3f} to {lat_sorted[-1]:.3f} " - f"and longitudes from {lon_sorted[0]:.3f} to " - f"{lon_sorted[-1]:.3f}) does not contain the river " - f'"{river_name}", whose mouth has latitude ' - f"{domain_latitudes[river_latitude]:.3f} and " - f"longitude {domain_longitudes[river_longitude]:.3f}." - ) - except ValueError as e: - raise ValueError( - f"ValueError while trying to read file {file_path}" - ) from e - - logger.debug("Retrieving rivers' data from the map") - file_content = single_ds.isel( - longitude=lon_indices - i_lon1, - latitude=lat_indices - i_lat1, - ).load() - - # Remove the surface variable - if "surface" in file_content: - logger.debug('Removing coordinate "surface"') - del file_content["surface"] - - try: - original_dataset = _read_single_efas_file(file_content) - except InvalidEfasFile as exception: - raise InvalidEfasFile( - f'Error while reading EFAS file "{file_path}"' - ) from exception - - logger.debug( - "Adding a new dataset to the collection: %s file have been read", - len(datasets) + 1, - ) - datasets.append(original_dataset) - - logger.debug("Concatenating the datasets...") - concat_dataset = xr.concat(datasets, dim="time") - - logger.debug("Getting the best forecast (if needed)") - _, slicing_indices = _get_best_unique_element( - concat_dataset.time.values, concat_dataset.computation_date.values - ) - - efas_data = concat_dataset.isel(time=slicing_indices) - - return efas_data - - def generate_efas_climatology( annual_files: dict[int, Path], config_content: Iterable[RiverConfigElement], From 9511e955a494e3778bc555eb9db161e377fb9396 Mon Sep 17 00:00:00 2001 From: Stefano Piani Date: Thu, 5 Jun 2025 14:04:05 +0200 Subject: [PATCH 6/8] Added the possibility to read EFAS files using multiprocessing --- src/ogs_riverger/efas/efas_manager.py | 9 ++++- tests/test_efas.py | 54 +++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/src/ogs_riverger/efas/efas_manager.py b/src/ogs_riverger/efas/efas_manager.py index ad92ac4..bb8fa53 100644 --- a/src/ogs_riverger/efas/efas_manager.py +++ b/src/ogs_riverger/efas/efas_manager.py @@ -4,6 +4,7 @@ from contextlib import ExitStack from datetime import datetime from functools import partial +from multiprocessing.pool import Pool from os import PathLike from pathlib import Path from tempfile import TemporaryDirectory @@ -203,6 +204,7 @@ def read_efas_data_files( input_files: Iterable[Path], config_content: Iterable[RiverConfigElement], efas_domain_file: PathLike, + pool: None | Pool = None, ) -> xr.Dataset: """Reads and merges the content of multiple downloaded EFAS data files. @@ -218,6 +220,8 @@ def read_efas_data_files( river we must consider efas_domain_file: The path of a file containing the coordinates of the efas domain (generated by the `generate_efas_domain_file` function) + pool: A multiprocessing pool to use for parallel processing. If None, + the code will run sequentially. Returns: An Xarray dataset containing the EFAS rivers discharge. The @@ -242,7 +246,10 @@ def read_efas_data_files( efas_domain=efas_domain, temp_dir=Path(t), ) - datasets = map(elaborate_file, input_files) + if pool is None: + datasets = map(elaborate_file, input_files) + else: + datasets = pool.map(elaborate_file, input_files) logger.debug("Concatenating the datasets...") concat_dataset = xr.concat(datasets, dim="time") diff --git a/tests/test_efas.py b/tests/test_efas.py index bf53c0e..fd7cd63 100644 --- a/tests/test_efas.py +++ b/tests/test_efas.py @@ -1,7 +1,9 @@ from datetime import datetime from datetime import timedelta +from itertools import pairwise from itertools import product as cart_prod from logging import getLogger +from multiprocessing import Pool from pathlib import Path from types import MappingProxyType from unittest.mock import AsyncMock @@ -332,6 +334,58 @@ def test_read_efas_zipped_data(config_example, efas_domain_file, tmp_path): assert test_dataset.dis06.shape[1] == len(config_example) +def test_read_efas_zipped_data_in_parallel( + config_example, efas_domain_file, tmp_path +): + """ + GIVEN a set of EFAS files, + WHEN the function read_efas_data_files is called with a multiprocess pool, + THEN it returns a dataset with the correct dimensions and values. + """ + tmp_path = tmp_path / "test_read_efas_zipped_data" + tmp_path.mkdir(exist_ok=True) + + efas_domain = xr.load_dataset(efas_domain_file) + + start_date = datetime(2024, 1, 1) + dates = [start_date] + while dates[-1] < datetime(2024, 11, 30): + dates.append( + datetime( + year=dates[-1].year, + month=dates[-1].month + 1, + day=1, + ) + ) + + file_paths = [] + for d1, d2 in pairwise(dates): + file_generator = EfasLikeFileGenerator( + tmp_path / f"t_{d1.month:02}", + d1, + d2, + efas_domain=efas_domain, + ) + output_file = tmp_path / f"efas_{d1.month:02}.zip" + file_generator.create(output_file) + file_paths.append(output_file) + + with Pool(2) as p: + test_dataset = read_efas_data_files( + file_paths, + config_example, + efas_domain_file=efas_domain_file, + pool=p, + ) + output_dims = test_dataset.dis06.dims + + assert output_dims == ("time", "id") + + n_days = (dates[-1] - dates[0]).days + assert test_dataset.dis06.shape[0] == n_days * 4 + assert test_dataset.dis06.shape[1] == len(config_example) + + def test_generate_efas_climatology(config_example, efas_domain_file, tmp_path): """ GIVEN a set of yearly EFAS files, From f76b85d6fc9c8c29adf303226bd95891e2fbb376 Mon Sep 17 00:00:00 2001 From: Stefano Piani Date: Thu, 5 Jun 2025 18:37:56 +0200 Subject: [PATCH 7/8] Loading small EFAS files into memory to speed up selecting rivers --- src/ogs_riverger/efas/efas_manager.py | 96 +++++++++++++++++++++++---- 1 file changed, 82 insertions(+), 14 deletions(-) diff --git a/src/ogs_riverger/efas/efas_manager.py b/src/ogs_riverger/efas/efas_manager.py index bb8fa53..53f7ed9 100644 --- a/src/ogs_riverger/efas/efas_manager.py +++ b/src/ogs_riverger/efas/efas_manager.py @@ -8,10 +8,12 @@ from os import PathLike from pathlib import Path from tempfile import TemporaryDirectory +from typing import Any from typing import Literal from uuid import uuid1 from zipfile import ZipFile +import dask import numpy as np import xarray as xr @@ -412,15 +414,80 @@ def _read_raw_content_of_unzipped_efas_file( # This file could be a grib or a NetCDF; luckily, xarray supports both logger.debug('Opening file "%s"', file_path) - # We use Dask here (chunks={}) because it is way more efficient than - # standard xarray when executing the isel method. - # By setting "decode_timedelta=True" we ensure that the values of the - # step variable are decoded as timedelta64 objects (and we also - # silence a warning) + open_file_args: dict[str, Any] = {"decode_timedelta": True} + file_suffix = file_path.suffix.lower() + if file_suffix in (".grib", ".grb"): + logging.debug( + "The file %s is a grib file (suffix = %s); using cfgrib engine", + file_path, + file_suffix, + ) + open_file_args["engine"] = "cfgrib" + open_file_args["backend_kwargs"] = {"indexpath": ""} + else: + logging.debug( + "The file %s is a NetCDF file (suffix = %s)", + file_path, + file_suffix, + ) + try: - with xr.open_dataset( - file_path, chunks={}, decode_timedelta=True - ) as single_ds: + with ExitStack() as current_stack: + # File size of the file that we have to open in MegaBytes + file_size = file_path.stat().st_size // 1024 // 1024 + + # We decide how to open the file; EFAS files can be of several GB + # in size, but we only need a small portion of them (the + # rivers' data). + # Unfortunately, the data that we need is scattered around the + # file and, therefore, reading the data may be computationally + # expensive. We use the following strategy: if the file is smaller + # than 1024 MB, we simply load it into memory, and that's it. + # Otherwise, we open it, and we read its values from the file + # (which is more I/O expensive). If the file is a NetCDF file, it + # is important to specify that chunks = {} to split the result + # into chunks, and this is way more efficient than the standard + # Xarray implementation when executing the isel method. + # Finally, in any case, we decode the time variable as timedelta64 + # objects by using the flag decode_timedelta=True. + file_size_limit = 1024 # Mb + if file_size >= file_size_limit: + logger.debug( + "File size is %s MB; it will be opened using Xarray " + "open_dataset method", + file_size, + ) + + if open_file_args.get("engine", "netcdf") != "cfgrib": + open_file_args["chunks"] = {} + + logger.debug( + "Opening file %s using the following arguments: %s", + file_path, + open_file_args, + ) + single_ds = current_stack.enter_context( + xr.open_dataset( + file_path, + **open_file_args, + ) + ) + else: + logger.debug( + "File size is just %s MB; it will be opened using " + "Xarray load_dataset method", + file_size, + ) + logger.debug( + "Opening file %s using the following arguments: %s", + file_path, + open_file_args, + ) + single_ds = xr.load_dataset( + file_path, + **open_file_args, + ) + dataset_latitudes = single_ds.latitude.values dataset_longitudes = single_ds.longitude.values @@ -494,12 +561,13 @@ def _read_single_efas_file( """ logger = logging.getLogger(f"{__name__}.{inspect.stack()[0][3]}") - raw_content = _read_raw_content_of_efas_file( - file_path=file_path, - config_content=config_content, - efas_domain=efas_domain, - temp_dir=temp_dir, - ) + with dask.config.set(scheduler="synchronous"): + raw_content = _read_raw_content_of_efas_file( + file_path=file_path, + config_content=config_content, + efas_domain=efas_domain, + temp_dir=temp_dir, + ) # I don't know what the "surface" coordinate is, but it is unnecessary if "surface" in raw_content: From 8d51b9a81e440c801309d659e89e8c8c3b4f10c8 Mon Sep 17 00:00:00 2001 From: Stefano Piani Date: Thu, 5 Jun 2025 23:42:21 +0200 Subject: [PATCH 8/8] Added py.typed --- src/ogs_riverger/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/ogs_riverger/py.typed diff --git a/src/ogs_riverger/py.typed b/src/ogs_riverger/py.typed new file mode 100644 index 0000000..e69de29