diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index f0e62860..eadef9ab 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -25,7 +25,7 @@ # SOFTWARE. -from typing import Literal +from typing import Literal, TypeAlias import torch from torch import Tensor @@ -37,6 +37,8 @@ from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import Weighting +SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"] + class AlignedMTL(GramianWeightedAggregator): r""" @@ -58,10 +60,10 @@ class AlignedMTL(GramianWeightedAggregator): def __init__( self, pref_vector: Tensor | None = None, - scale_mode: Literal["min", "median", "rmse"] = "min", + scale_mode: SUPPORTED_SCALE_MODE = "min", ): self._pref_vector = pref_vector - self._scale_mode = scale_mode + self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode)) def __repr__(self) -> str: @@ -89,11 +91,11 @@ class AlignedMTLWeighting(Weighting[PSDMatrix]): def __init__( self, pref_vector: Tensor | None = None, - scale_mode: Literal["min", "median", "rmse"] = "min", + scale_mode: SUPPORTED_SCALE_MODE = "min", ): super().__init__() self._pref_vector = pref_vector - self._scale_mode = scale_mode + self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) def forward(self, gramian: PSDMatrix) -> Tensor: @@ -105,7 +107,7 @@ def forward(self, gramian: PSDMatrix) -> Tensor: @staticmethod def _compute_balance_transformation( - M: Tensor, scale_mode: Literal["min", "median", "rmse"] = "min" + M: Tensor, scale_mode: SUPPORTED_SCALE_MODE = "min" ) -> Tensor: lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig tol = torch.max(lambda_) * len(M) * torch.finfo().eps diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 202f204a..d91e32aa 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -1,12 +1,10 @@ -from typing import Literal - from torch import Tensor from torchjd._linalg import PSDMatrix, normalize, regularize from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting -from ._utils.dual_cone import project_weights +from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import Weighting @@ -34,12 +32,12 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: Literal["quadprog"] = "quadprog", + solver: SUPPORTED_SOLVER = "quadprog", ): self._pref_vector = pref_vector self._norm_eps = norm_eps self._reg_eps = reg_eps - self._solver = solver + self._solver: SUPPORTED_SOLVER = solver super().__init__( DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver) @@ -78,14 +76,14 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: Literal["quadprog"] = "quadprog", + solver: SUPPORTED_SOLVER = "quadprog", ): super().__init__() self._pref_vector = pref_vector self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) self.norm_eps = norm_eps self.reg_eps = reg_eps - self.solver = solver + self.solver: SUPPORTED_SOLVER = solver def forward(self, gramian: PSDMatrix) -> Tensor: u = self.weighting(gramian) diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index c7efb367..132b72e6 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -1,5 +1,3 @@ -from typing import Literal - import torch from torch import Tensor @@ -7,7 +5,7 @@ from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting -from ._utils.dual_cone import project_weights +from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import Weighting @@ -35,12 +33,12 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: Literal["quadprog"] = "quadprog", + solver: SUPPORTED_SOLVER = "quadprog", ): self._pref_vector = pref_vector self._norm_eps = norm_eps self._reg_eps = reg_eps - self._solver = solver + self._solver: SUPPORTED_SOLVER = solver super().__init__( UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver) @@ -79,14 +77,14 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: Literal["quadprog"] = "quadprog", + solver: SUPPORTED_SOLVER = "quadprog", ): super().__init__() self._pref_vector = pref_vector self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) self.norm_eps = norm_eps self.reg_eps = reg_eps - self.solver = solver + self.solver: SUPPORTED_SOLVER = solver def forward(self, gramian: PSDMatrix) -> Tensor: U = torch.diag(self.weighting(gramian)) diff --git a/src/torchjd/aggregation/_utils/dual_cone.py b/src/torchjd/aggregation/_utils/dual_cone.py index 539685be..b076366b 100644 --- a/src/torchjd/aggregation/_utils/dual_cone.py +++ b/src/torchjd/aggregation/_utils/dual_cone.py @@ -1,12 +1,14 @@ -from typing import Literal +from typing import Literal, TypeAlias import numpy as np import torch from qpsolvers import solve_qp from torch import Tensor +SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"] -def project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor: + +def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor: """ Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the rows of a matrix whose Gramian is provided. @@ -25,7 +27,7 @@ def project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor return torch.as_tensor(W, device=G.device, dtype=G.dtype) -def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: Literal["quadprog"]) -> np.ndarray: +def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: SUPPORTED_SOLVER) -> np.ndarray: r""" Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`, given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies