Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/torchjd/_linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .gramian import compute_gramian
from .matrix import Matrix, PSDMatrix

__all__ = ["compute_gramian", "Matrix", "PSDMatrix"]
9 changes: 9 additions & 0 deletions src/torchjd/_linalg/gramian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .matrix import Matrix, PSDMatrix


def compute_gramian(matrix: Matrix) -> PSDMatrix:
    """
    Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ ``G = M M^T``
    of a given matrix ``M``.
    """

    # The product of a real matrix with its own transpose is always positive semi-definite.
    transposed = matrix.T
    return matrix.matmul(transposed)
6 changes: 6 additions & 0 deletions src/torchjd/_linalg/matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from typing import Annotated

from torch import Tensor

# Documentation-only type aliases: the string metadata inside `Annotated` is not
# enforced at runtime (and is ignored by type checkers) — it records the intended
# contract for readers of the signatures that use these aliases.
Matrix = Annotated[Tensor, "ndim=2"]  # a 2-dimensional Tensor
PSDMatrix = Annotated[Matrix, "Positive semi-definite"]  # e.g. a Gramian G = M @ M.T
5 changes: 3 additions & 2 deletions src/torchjd/aggregation/_aggregator_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from torch import Tensor, nn

from ._utils.gramian import compute_gramian
from ._weighting_bases import Matrix, PSDMatrix, Weighting
from torchjd._linalg import Matrix, PSDMatrix, compute_gramian

from ._weighting_bases import Weighting


class Aggregator(nn.Module, ABC):
Expand Down
6 changes: 4 additions & 2 deletions src/torchjd/aggregation/_aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
import torch
from torch import Tensor

from torchjd._linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import PSDMatrix, Weighting
from ._weighting_bases import Weighting


class AlignedMTL(GramianWeightedAggregator):
Expand Down Expand Up @@ -73,7 +75,7 @@ def __init__(self, pref_vector: Tensor | None = None):
self._pref_vector = pref_vector
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())

def forward(self, gramian: Tensor) -> Tensor:
def forward(self, gramian: PSDMatrix) -> Tensor:
w = self.weighting(gramian)
B = self._compute_balance_transformation(gramian)
alpha = B @ w
Expand Down
6 changes: 4 additions & 2 deletions src/torchjd/aggregation/_cagrad.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import cast

from torchjd._linalg import PSDMatrix

from ._utils.check_dependencies import check_dependencies_are_installed
from ._weighting_bases import PSDMatrix, Weighting
from ._weighting_bases import Weighting

check_dependencies_are_installed(["cvxpy", "clarabel"])

Expand Down Expand Up @@ -73,7 +75,7 @@ def __init__(self, c: float, norm_eps: float = 0.0001):
self.c = c
self.norm_eps = norm_eps

def forward(self, gramian: Tensor) -> Tensor:
def forward(self, gramian: PSDMatrix) -> Tensor:
U, S, _ = torch.svd(normalize(gramian, self.norm_eps))

reduced_matrix = U @ S.sqrt().diag()
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_constant.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from torch import Tensor

from torchjd._linalg import Matrix

from ._aggregator_bases import WeightedAggregator
from ._utils.str import vector_to_str
from ._weighting_bases import Matrix, Weighting
from ._weighting_bases import Weighting


class Constant(WeightedAggregator):
Expand Down
6 changes: 4 additions & 2 deletions src/torchjd/aggregation/_dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from torch import Tensor

from torchjd._linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.dual_cone import project_weights
from ._utils.gramian import normalize, regularize
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import PSDMatrix, Weighting
from ._weighting_bases import Weighting


class DualProj(GramianWeightedAggregator):
Expand Down Expand Up @@ -86,7 +88,7 @@ def __init__(
self.reg_eps = reg_eps
self.solver = solver

def forward(self, gramian: Tensor) -> Tensor:
def forward(self, gramian: PSDMatrix) -> Tensor:
u = self.weighting(gramian)
G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
w = project_weights(u, G, self.solver)
Expand Down
3 changes: 2 additions & 1 deletion src/torchjd/aggregation/_flattening.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from torch import Tensor

from torchjd.aggregation._weighting_bases import GeneralizedWeighting, PSDMatrix, Weighting
from torchjd._linalg.matrix import PSDMatrix
from torchjd.aggregation._weighting_bases import GeneralizedWeighting, Weighting
from torchjd.autogram._gramian_utils import reshape_gramian


Expand Down
6 changes: 4 additions & 2 deletions src/torchjd/aggregation/_imtl_g.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch
from torch import Tensor

from torchjd._linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import PSDMatrix, Weighting
from ._weighting_bases import Weighting


class IMTLG(GramianWeightedAggregator):
Expand All @@ -27,7 +29,7 @@ class IMTLGWeighting(Weighting[PSDMatrix]):
:class:`~torchjd.aggregation.IMTLG`.
"""

def forward(self, gramian: Tensor) -> Tensor:
def forward(self, gramian: PSDMatrix) -> Tensor:
d = torch.sqrt(torch.diagonal(gramian))
v = torch.linalg.pinv(gramian) @ d
v_sum = v.sum()
Expand Down
8 changes: 5 additions & 3 deletions src/torchjd/aggregation/_krum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from torch import Tensor
from torch.nn import functional as F

from torchjd._linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._weighting_bases import PSDMatrix, Weighting
from ._weighting_bases import Weighting


class Krum(GramianWeightedAggregator):
Expand Down Expand Up @@ -59,7 +61,7 @@ def __init__(self, n_byzantine: int, n_selected: int = 1):
self.n_byzantine = n_byzantine
self.n_selected = n_selected

def forward(self, gramian: Tensor) -> Tensor:
def forward(self, gramian: PSDMatrix) -> Tensor:
self._check_matrix_shape(gramian)
gradient_norms_squared = torch.diagonal(gramian)
distances_squared = (
Expand All @@ -78,7 +80,7 @@ def forward(self, gramian: Tensor) -> Tensor:

return weights

def _check_matrix_shape(self, gramian: Tensor) -> None:
def _check_matrix_shape(self, gramian: PSDMatrix) -> None:
min_rows = self.n_byzantine + 3
if gramian.shape[0] < min_rows:
raise ValueError(
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_mean.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import torch
from torch import Tensor

from torchjd._linalg import Matrix

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import Matrix, Weighting
from ._weighting_bases import Weighting


class Mean(WeightedAggregator):
Expand Down
6 changes: 4 additions & 2 deletions src/torchjd/aggregation/_mgda.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import torch
from torch import Tensor

from torchjd._linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._weighting_bases import PSDMatrix, Weighting
from ._weighting_bases import Weighting


class MGDA(GramianWeightedAggregator):
Expand Down Expand Up @@ -40,7 +42,7 @@ def __init__(self, epsilon: float = 0.001, max_iters: int = 100):
self.epsilon = epsilon
self.max_iters = max_iters

def forward(self, gramian: Tensor) -> Tensor:
def forward(self, gramian: PSDMatrix) -> Tensor:
"""
This is the Frank-Wolfe solver in Algorithm 2 of `Multi-Task Learning as Multi-Objective
Optimization
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_nash_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@

# mypy: ignore-errors

from torchjd._linalg import Matrix

from ._utils.check_dependencies import check_dependencies_are_installed
from ._weighting_bases import Matrix, Weighting
from ._weighting_bases import Weighting

check_dependencies_are_installed(["cvxpy", "ecos"])

Expand Down
6 changes: 4 additions & 2 deletions src/torchjd/aggregation/_pcgrad.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch
from torch import Tensor

from torchjd._linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import PSDMatrix, Weighting
from ._weighting_bases import Weighting


class PCGrad(GramianWeightedAggregator):
Expand All @@ -25,7 +27,7 @@ class PCGradWeighting(Weighting[PSDMatrix]):
:class:`~torchjd.aggregation.PCGrad`.
"""

def forward(self, gramian: Tensor) -> Tensor:
def forward(self, gramian: PSDMatrix) -> Tensor:
# Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration
device = gramian.device
dtype = gramian.dtype
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from torch import Tensor
from torch.nn import functional as F

from torchjd._linalg import Matrix

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import Matrix, Weighting
from ._weighting_bases import Weighting


class Random(WeightedAggregator):
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_sum.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import torch
from torch import Tensor

from torchjd._linalg import Matrix

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import Matrix, Weighting
from ._weighting_bases import Weighting


class Sum(WeightedAggregator):
Expand Down
6 changes: 4 additions & 2 deletions src/torchjd/aggregation/_upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import torch
from torch import Tensor

from torchjd._linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.dual_cone import project_weights
from ._utils.gramian import normalize, regularize
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import PSDMatrix, Weighting
from ._weighting_bases import Weighting


class UPGrad(GramianWeightedAggregator):
Expand Down Expand Up @@ -87,7 +89,7 @@ def __init__(
self.reg_eps = reg_eps
self.solver = solver

def forward(self, gramian: Tensor) -> Tensor:
def forward(self, gramian: PSDMatrix) -> Tensor:
U = torch.diag(self.weighting(gramian))
G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
W = project_weights(U, G, self.solver)
Expand Down
13 changes: 3 additions & 10 deletions src/torchjd/aggregation/_utils/gramian.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
import torch
from torch import Tensor


def compute_gramian(matrix: Tensor) -> Tensor:
"""
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix.
"""

return matrix @ matrix.T
from torchjd._linalg.matrix import PSDMatrix


def normalize(gramian: Tensor, eps: float) -> Tensor:
def normalize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
"""
Normalizes the gramian `G=AA^T` with respect to the Frobenius norm of `A`.

Expand All @@ -25,7 +18,7 @@ def normalize(gramian: Tensor, eps: float) -> Tensor:
return gramian / squared_frobenius_norm


def regularize(gramian: Tensor, eps: float) -> Tensor:
def regularize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
"""
Adds a regularization term to the gramian to enforce positive definiteness.

Expand Down
3 changes: 2 additions & 1 deletion src/torchjd/aggregation/_utils/pref_vector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from torch import Tensor

from torchjd._linalg.matrix import Matrix
from torchjd.aggregation._constant import ConstantWeighting
from torchjd.aggregation._weighting_bases import Matrix, Weighting
from torchjd.aggregation._weighting_bases import Weighting

from .str import vector_to_str

Expand Down
4 changes: 1 addition & 3 deletions src/torchjd/aggregation/_weighting_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Annotated, Generic, TypeVar
from typing import Generic, TypeVar

from torch import Tensor, nn

_T = TypeVar("_T", contravariant=True)
_FnInputT = TypeVar("_FnInputT")
_FnOutputT = TypeVar("_FnOutputT")
Matrix = Annotated[Tensor, "ndim=2"]
PSDMatrix = Annotated[Matrix, "Positive semi-definite"]


class Weighting(Generic[_T], nn.Module, ABC):
Expand Down
7 changes: 5 additions & 2 deletions src/torchjd/autogram/_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from torch import Tensor, nn, vmap
from torch.autograd.graph import get_gradient_edge

from torchjd._linalg.matrix import PSDMatrix

from ._edge_registry import EdgeRegistry
from ._gramian_accumulator import GramianAccumulator
from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
Expand Down Expand Up @@ -232,6 +234,7 @@ def _check_module_is_compatible(self, module: nn.Module) -> None:
f"`batch_dim=None` when creating the engine."
)

# Currently, the type PSDMatrix is hidden from users, so Tensor is correct.
def compute_gramian(self, output: Tensor) -> Tensor:
r"""
Computes the Gramian of the Jacobian of ``output`` with respect to the direct parameters of
Expand Down Expand Up @@ -305,7 +308,7 @@ def compute_gramian(self, output: Tensor) -> Tensor:

return gramian

def _compute_square_gramian(self, output: Tensor, has_non_batch_dim: bool) -> Tensor:
def _compute_square_gramian(self, output: Tensor, has_non_batch_dim: bool) -> PSDMatrix:
leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)}))

def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
Expand All @@ -330,6 +333,6 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:

# If the gramian were None, then leaf_targets would be empty, so autograd.grad would
# have failed. So gramian is necessarily a valid Tensor here.
gramian = cast(Tensor, self._gramian_accumulator.gramian)
gramian = cast(PSDMatrix, self._gramian_accumulator.gramian)

return gramian
Loading
Loading