Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/torchjd/aggregation/_aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# SOFTWARE.


from typing import Literal
from typing import Literal, TypeAlias

import torch
from torch import Tensor
Expand All @@ -37,6 +37,8 @@
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import Weighting

SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"]


class AlignedMTL(GramianWeightedAggregator):
r"""
Expand All @@ -58,10 +60,10 @@ class AlignedMTL(GramianWeightedAggregator):
def __init__(
self,
pref_vector: Tensor | None = None,
scale_mode: Literal["min", "median", "rmse"] = "min",
scale_mode: SUPPORTED_SCALE_MODE = "min",
):
self._pref_vector = pref_vector
self._scale_mode = scale_mode
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode))

def __repr__(self) -> str:
Expand Down Expand Up @@ -89,11 +91,11 @@ class AlignedMTLWeighting(Weighting[PSDMatrix]):
def __init__(
self,
pref_vector: Tensor | None = None,
scale_mode: Literal["min", "median", "rmse"] = "min",
scale_mode: SUPPORTED_SCALE_MODE = "min",
):
super().__init__()
self._pref_vector = pref_vector
self._scale_mode = scale_mode
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())

def forward(self, gramian: PSDMatrix) -> Tensor:
Expand All @@ -105,7 +107,7 @@ def forward(self, gramian: PSDMatrix) -> Tensor:

@staticmethod
def _compute_balance_transformation(
M: Tensor, scale_mode: Literal["min", "median", "rmse"] = "min"
M: Tensor, scale_mode: SUPPORTED_SCALE_MODE = "min"
) -> Tensor:
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
Expand Down
12 changes: 5 additions & 7 deletions src/torchjd/aggregation/_dualproj.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from typing import Literal

from torch import Tensor

from torchjd._linalg import PSDMatrix, normalize, regularize

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.dual_cone import project_weights
from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import Weighting
Expand Down Expand Up @@ -34,12 +32,12 @@ def __init__(
pref_vector: Tensor | None = None,
norm_eps: float = 0.0001,
reg_eps: float = 0.0001,
solver: Literal["quadprog"] = "quadprog",
solver: SUPPORTED_SOLVER = "quadprog",
):
self._pref_vector = pref_vector
self._norm_eps = norm_eps
self._reg_eps = reg_eps
self._solver = solver
self._solver: SUPPORTED_SOLVER = solver

super().__init__(
DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver)
Expand Down Expand Up @@ -78,14 +76,14 @@ def __init__(
pref_vector: Tensor | None = None,
norm_eps: float = 0.0001,
reg_eps: float = 0.0001,
solver: Literal["quadprog"] = "quadprog",
solver: SUPPORTED_SOLVER = "quadprog",
):
super().__init__()
self._pref_vector = pref_vector
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
self.norm_eps = norm_eps
self.reg_eps = reg_eps
self.solver = solver
self.solver: SUPPORTED_SOLVER = solver

def forward(self, gramian: PSDMatrix) -> Tensor:
u = self.weighting(gramian)
Expand Down
12 changes: 5 additions & 7 deletions src/torchjd/aggregation/_upgrad.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from typing import Literal

import torch
from torch import Tensor

from torchjd._linalg import PSDMatrix, normalize, regularize

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.dual_cone import project_weights
from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import Weighting
Expand Down Expand Up @@ -35,12 +33,12 @@ def __init__(
pref_vector: Tensor | None = None,
norm_eps: float = 0.0001,
reg_eps: float = 0.0001,
solver: Literal["quadprog"] = "quadprog",
solver: SUPPORTED_SOLVER = "quadprog",
):
self._pref_vector = pref_vector
self._norm_eps = norm_eps
self._reg_eps = reg_eps
self._solver = solver
self._solver: SUPPORTED_SOLVER = solver

super().__init__(
UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver)
Expand Down Expand Up @@ -79,14 +77,14 @@ def __init__(
pref_vector: Tensor | None = None,
norm_eps: float = 0.0001,
reg_eps: float = 0.0001,
solver: Literal["quadprog"] = "quadprog",
solver: SUPPORTED_SOLVER = "quadprog",
):
super().__init__()
self._pref_vector = pref_vector
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
self.norm_eps = norm_eps
self.reg_eps = reg_eps
self.solver = solver
self.solver: SUPPORTED_SOLVER = solver

def forward(self, gramian: PSDMatrix) -> Tensor:
U = torch.diag(self.weighting(gramian))
Expand Down
8 changes: 5 additions & 3 deletions src/torchjd/aggregation/_utils/dual_cone.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Literal
from typing import Literal, TypeAlias

import numpy as np
import torch
from qpsolvers import solve_qp
from torch import Tensor

SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"]

def project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor:

def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor:
"""
Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the
rows of a matrix whose Gramian is provided.
Expand All @@ -25,7 +27,7 @@ def project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor
return torch.as_tensor(W, device=G.device, dtype=G.dtype)


def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: Literal["quadprog"]) -> np.ndarray:
def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: SUPPORTED_SOLVER) -> np.ndarray:
r"""
Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`,
given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies
Expand Down