diff --git a/CHANGELOG.md b/CHANGELOG.md index a50687cb23..ae9b568374 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ Code freeze date: YYYY-MM-DD ### Added +- Better type hints and overloads signatures for ImpactFuncSet [#1250](https://github.com/CLIMADA-project/climada_python/pull/1250) + ### Changed - Updated Impact Calculation Tutorial (`doc.climada_engine_Impact.ipynb`) [#1095](https://github.com/CLIMADA-project/climada_python/pull/1095). diff --git a/climada/entity/impact_funcs/base.py b/climada/entity/impact_funcs/base.py index c51540d573..3642eed7a8 100644 --- a/climada/entity/impact_funcs/base.py +++ b/climada/entity/impact_funcs/base.py @@ -189,7 +189,7 @@ def from_step_impf( haz_type: str, mdd: tuple[float, float] = (0, 1), paa: tuple[float, float] = (1, 1), - impf_id: int = 1, + impf_id: int | str = 1, **kwargs, ): """Step function type impact function. @@ -207,7 +207,7 @@ def from_step_impf( (min, max) mdd values. The default is (0, 1) paa: tuple(float, float) (min, max) paa values. The default is (1, 1) - impf_id : int, optional, default=1 + impf_id : int|str, optional, default=1 impact function id kwargs : keyword arguments passed to ImpactFunc() @@ -250,7 +250,7 @@ def from_sigmoid_impf( k: float, x0: float, haz_type: str, - impf_id: int = 1, + impf_id: int | str = 1, **kwargs, ): r"""Sigmoid type impact function hinging on three parameter. @@ -320,7 +320,7 @@ def from_poly_s_shape( scale: float, exponent: float, haz_type: str, - impf_id: int = 1, + impf_id: int | str = 1, **kwargs, ): r"""S-shape polynomial impact function hinging on four parameter. diff --git a/climada/entity/impact_funcs/impact_func_set.py b/climada/entity/impact_funcs/impact_func_set.py index 030f73f2be..b6f4cf73d7 100755 --- a/climada/entity/impact_funcs/impact_func_set.py +++ b/climada/entity/impact_funcs/impact_func_set.py @@ -24,7 +24,7 @@ import copy import logging from itertools import repeat -from typing import Iterable, Optional +from typing import Iterable, Optional, Union, overload import matplotlib.pyplot as plt import numpy as np @@ -119,7 +119,7 @@ def clear(self): """Reinitialize attributes.""" self._data = dict() # {hazard_type : {id:ImpactFunc}} - def append(self, func): + def append(self, func: ImpactFunc): """Append a ImpactFunc. Overwrite existing if same id and haz_type. Parameters @@ -141,7 +141,9 @@ def append(self, func): self._data[func.haz_type] = dict() self._data[func.haz_type][func.id] = func - def remove_func(self, haz_type=None, fun_id=None): + def remove_func( + self, haz_type: Optional[str] = None, fun_id: Optional[str | int] = None + ): """Remove impact function(s) with provided hazard type and/or id. If no input provided, all impact functions are removed. @@ -173,7 +175,29 @@ def remove_func(self, haz_type=None, fun_id=None): else: self._data = dict() - def get_func(self, haz_type=None, fun_id=None): + @overload + def get_func( + self, haz_type: None = None, fun_id: None = None + ) -> dict[str, dict[Union[int, str], ImpactFunc]]: ... + + @overload + def get_func( + self, haz_type: None = ..., fun_id: int | str = ... + ) -> list[ImpactFunc]: ... + + @overload + def get_func( + self, haz_type: str = ..., fun_id: None = None + ) -> list[ImpactFunc]: ... + + @overload + def get_func(self, haz_type: str = ..., fun_id: int | str = ...) -> ImpactFunc: ... + + def get_func( + self, haz_type: Optional[str] = None, fun_id: Optional[int | str] = None + ) -> Union[ + ImpactFunc, list[ImpactFunc], dict[str, dict[Union[int, str], ImpactFunc]] + ]: """Get ImpactFunc(s) of input hazard type and/or id. If no input provided, all impact functions are returned. @@ -209,7 +233,7 @@ def get_func(self, haz_type=None, fun_id=None): else: return self._data - def get_hazard_types(self, fun_id=None): + def get_hazard_types(self, fun_id: Optional[str | int] = None) -> list[str]: """Get impact functions hazard types contained for the id provided. Return all hazard types if no input id. @@ -231,7 +255,15 @@ def get_hazard_types(self, fun_id=None): haz_types.append(vul_haz) return haz_types - def get_ids(self, haz_type=None): + @overload + def get_ids(self, haz_type: None = None) -> dict[str, list[str | int]]: ... + + @overload + def get_ids(self, haz_type: str) -> list[int | str]: ... + + def get_ids( + self, haz_type: Optional[str] = None + ) -> dict[str, list[str | int]] | list[int | str]: """Get impact functions ids contained for the hazard type provided. Return all ids for each hazard type if no input hazard type. @@ -256,7 +288,9 @@ def get_ids(self, haz_type=None): except KeyError: return list() - def size(self, haz_type=None, fun_id=None): + def size( + self, haz_type: Optional[str] = None, fun_id: Optional[str | int] = None + ) -> int: """Get number of impact functions contained with input hazard type and /or id. If no input provided, get total number of impact functions. @@ -279,6 +313,7 @@ def size(self, haz_type=None, fun_id=None): return 1 if (haz_type is not None) or (fun_id is not None): return len(self.get_func(haz_type, fun_id)) + return sum(len(vul_list) for vul_list in self.get_ids().values()) def check(self): @@ -300,7 +335,7 @@ def check(self): ) vul.check() - def extend(self, impact_funcs): + def extend(self, impact_funcs: "ImpactFuncSet"): """Append impact functions of input ImpactFuncSet to current ImpactFuncSet. Overwrite ImpactFunc if same id and haz_type. @@ -323,7 +358,13 @@ def extend(self, impact_funcs): for _, vul in vul_dict.items(): self.append(vul) - def plot(self, haz_type=None, fun_id=None, axis=None, **kwargs): + def plot( + self, + haz_type: Optional[str] = None, + fun_id: Optional[str | int] = None, + axis=None, + **kwargs, + ): """Plot impact functions of selected hazard (all if not provided) and selected function id (all if not provided).