From 56019f9b814490724b6c17a2e4f865c62c689220 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 1 Feb 2026 03:55:32 +0100 Subject: [PATCH 1/2] typing(aggregation): Add SUPPORTED_SOLVER type alias --- src/torchjd/aggregation/_dualproj.py | 8 +++----- src/torchjd/aggregation/_upgrad.py | 8 +++----- src/torchjd/aggregation/_utils/dual_cone.py | 8 +++++--- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 202f204a..a5ba7bc0 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -1,12 +1,10 @@ -from typing import Literal - from torch import Tensor from torchjd._linalg import PSDMatrix, normalize, regularize from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting -from ._utils.dual_cone import project_weights +from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import Weighting @@ -34,7 +32,7 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: Literal["quadprog"] = "quadprog", + solver: SUPPORTED_SOLVER = "quadprog", ): self._pref_vector = pref_vector self._norm_eps = norm_eps @@ -78,7 +76,7 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: Literal["quadprog"] = "quadprog", + solver: SUPPORTED_SOLVER = "quadprog", ): super().__init__() self._pref_vector = pref_vector diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index c7efb367..f95513ad 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -1,5 +1,3 @@ -from typing import Literal - import torch from torch import Tensor @@ -7,7 +5,7 @@ from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting -from ._utils.dual_cone import project_weights +from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import Weighting @@ -35,7 +33,7 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: Literal["quadprog"] = "quadprog", + solver: SUPPORTED_SOLVER = "quadprog", ): self._pref_vector = pref_vector self._norm_eps = norm_eps @@ -79,7 +77,7 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: Literal["quadprog"] = "quadprog", + solver: SUPPORTED_SOLVER = "quadprog", ): super().__init__() self._pref_vector = pref_vector diff --git a/src/torchjd/aggregation/_utils/dual_cone.py b/src/torchjd/aggregation/_utils/dual_cone.py index 539685be..b076366b 100644 --- a/src/torchjd/aggregation/_utils/dual_cone.py +++ b/src/torchjd/aggregation/_utils/dual_cone.py @@ -1,12 +1,14 @@ -from typing import Literal +from typing import Literal, TypeAlias import numpy as np import torch from qpsolvers import solve_qp from torch import Tensor +SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"] -def project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor: + +def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor: """ Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the rows of a matrix whose Gramian is provided. @@ -25,7 +27,7 @@ def project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor return torch.as_tensor(W, device=G.device, dtype=G.dtype) -def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: Literal["quadprog"]) -> np.ndarray: +def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: SUPPORTED_SOLVER) -> np.ndarray: r""" Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`, given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies From 82e2c2d8117aa480819523cb2925993286492264 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 1 Feb 2026 04:06:29 +0100 Subject: [PATCH 2/2] Fix Literal automatic type widening When an attribute is a Literal["something", "something else"], it's important to type both the parameter value of the constructor and the attribute itself in the constructor. Before, we only typed the parameter value. When doing that, the attribute is considered as a string that is initialized at "something" or "something else", but that may change in the future to "even something else". So it's just considered as a string in other methods, not as a Literal["something", "something else"]. This only happens with Literals of strings, so we don't need to do that for every attribute. This fixes 3 errors reported by Pylance. Also create TypeAlias SUPPORTED_SCALE_MODE (forgot to separate commits) --- src/torchjd/aggregation/_aligned_mtl.py | 14 ++++++++------ src/torchjd/aggregation/_dualproj.py | 4 ++-- src/torchjd/aggregation/_upgrad.py | 4 ++-- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index f0e62860..eadef9ab 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -25,7 +25,7 @@ # SOFTWARE. -from typing import Literal +from typing import Literal, TypeAlias import torch from torch import Tensor @@ -37,6 +37,8 @@ from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import Weighting +SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"] + class AlignedMTL(GramianWeightedAggregator): r""" @@ -58,10 +60,10 @@ class AlignedMTL(GramianWeightedAggregator): def __init__( self, pref_vector: Tensor | None = None, - scale_mode: Literal["min", "median", "rmse"] = "min", + scale_mode: SUPPORTED_SCALE_MODE = "min", ): self._pref_vector = pref_vector - self._scale_mode = scale_mode + self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode)) def __repr__(self) -> str: @@ -89,11 +91,11 @@ class AlignedMTLWeighting(Weighting[PSDMatrix]): def __init__( self, pref_vector: Tensor | None = None, - scale_mode: Literal["min", "median", "rmse"] = "min", + scale_mode: SUPPORTED_SCALE_MODE = "min", ): super().__init__() self._pref_vector = pref_vector - self._scale_mode = scale_mode + self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) def forward(self, gramian: PSDMatrix) -> Tensor: @@ -105,7 +107,7 @@ def forward(self, gramian: PSDMatrix) -> Tensor: @staticmethod def _compute_balance_transformation( - M: Tensor, scale_mode: Literal["min", "median", "rmse"] = "min" + M: Tensor, scale_mode: SUPPORTED_SCALE_MODE = "min" ) -> Tensor: lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig tol = torch.max(lambda_) * len(M) * torch.finfo().eps diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index a5ba7bc0..d91e32aa 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -37,7 +37,7 @@ def __init__( self._pref_vector = pref_vector self._norm_eps = norm_eps self._reg_eps = reg_eps - self._solver = solver + self._solver: SUPPORTED_SOLVER = solver super().__init__( DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver) @@ -83,7 +83,7 @@ def __init__( self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) self.norm_eps = norm_eps self.reg_eps = reg_eps - self.solver = solver + self.solver: SUPPORTED_SOLVER = solver def forward(self, gramian: PSDMatrix) -> Tensor: u = self.weighting(gramian) diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index f95513ad..132b72e6 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -38,7 +38,7 @@ def __init__( self._pref_vector = pref_vector self._norm_eps = norm_eps self._reg_eps = reg_eps - self._solver = solver + self._solver: SUPPORTED_SOLVER = solver super().__init__( UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver) @@ -84,7 +84,7 @@ def __init__( self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) self.norm_eps = norm_eps self.reg_eps = reg_eps - self.solver = solver + self.solver: SUPPORTED_SOLVER = solver def forward(self, gramian: PSDMatrix) -> Tensor: U = torch.diag(self.weighting(gramian))