Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/torchjd/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .compute_gramian import Matrix, PSDMatrix, compute_gramian, compute_gramian_sum

__all__ = ["compute_gramian", "compute_gramian_sum", "Matrix", "PSDMatrix"]
25 changes: 25 additions & 0 deletions src/torchjd/_utils/compute_gramian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Annotated, cast

import torch
from torch import Tensor

Matrix = Annotated[Tensor, "ndim=2"]
PSDMatrix = Annotated[Matrix, "Positive semi-definite"]


def compute_gramian(generalized_matrix: Tensor) -> PSDMatrix:
    """
    Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given
    generalized matrix: the input is flattened into a 2D matrix (keeping its first dimension)
    and multiplied by its own transpose.
    """
    # Flatten all trailing dimensions; the Gramian is taken over the first dimension only.
    first_dim = generalized_matrix.shape[0]
    matrix = generalized_matrix.reshape([first_dim, -1])
    return cast(PSDMatrix, matrix @ matrix.T)


def compute_gramian_sum(generalized_matrices: list[Tensor]) -> PSDMatrix:
    """
    Computes the sum of the Gramians of the given generalized matrices.

    All generalized matrices must share the same first dimension so that their Gramians have the
    same shape and can be summed.

    :raises ValueError: If ``generalized_matrices`` is empty. (A plain ``sum`` over an empty list
        would return the int ``0``, not a Tensor, and break downstream tensor code.)
    """
    if not generalized_matrices:
        raise ValueError("compute_gramian_sum requires at least one generalized matrix.")
    gramians = (compute_gramian(matrix) for matrix in generalized_matrices)
    total = next(gramians)
    for gramian in gramians:
        total = total + gramian
    return cast(PSDMatrix, total)
6 changes: 4 additions & 2 deletions src/torchjd/aggregation/_aggregator_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from torch import Tensor, nn

from ._utils.gramian import compute_gramian
from ._weighting_bases import Matrix, PSDMatrix, Weighting
from torchjd._utils import Matrix, PSDMatrix, compute_gramian

from ._weighting_bases import Weighting


class Aggregator(nn.Module, ABC):
Expand Down Expand Up @@ -79,3 +80,4 @@ class GramianWeightedAggregator(WeightedAggregator):

def __init__(self, weighting: Weighting[PSDMatrix]):
super().__init__(weighting << compute_gramian)
self.psd_weighting = weighting
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
import torch
from torch import Tensor

from torchjd._utils.compute_gramian import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import PSDMatrix, Weighting
from ._weighting_bases import Weighting


class AlignedMTL(GramianWeightedAggregator):
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_cagrad.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import cast

from torchjd._utils.compute_gramian import PSDMatrix

from ._utils.check_dependencies import check_dependencies_are_installed
from ._weighting_bases import PSDMatrix, Weighting
from ._weighting_bases import Weighting

check_dependencies_are_installed(["cvxpy", "clarabel"])

Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_constant.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from torch import Tensor

from torchjd._utils.compute_gramian import Matrix

from ._aggregator_bases import WeightedAggregator
from ._utils.str import vector_to_str
from ._weighting_bases import Matrix, Weighting
from ._weighting_bases import Weighting


class Constant(WeightedAggregator):
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from torch import Tensor

from torchjd._utils.compute_gramian import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.dual_cone import project_weights
from ._utils.gramian import normalize, regularize
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import PSDMatrix, Weighting
from ._weighting_bases import Weighting


class DualProj(GramianWeightedAggregator):
Expand Down
3 changes: 2 additions & 1 deletion src/torchjd/aggregation/_flattening.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from torch import Tensor

from torchjd.aggregation._weighting_bases import GeneralizedWeighting, PSDMatrix, Weighting
from torchjd._utils.compute_gramian import PSDMatrix
from torchjd.aggregation._weighting_bases import GeneralizedWeighting, Weighting
from torchjd.autogram._gramian_utils import reshape_gramian


Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_imtl_g.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch
from torch import Tensor

from torchjd._utils.compute_gramian import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import PSDMatrix, Weighting
from ._weighting_bases import Weighting


class IMTLG(GramianWeightedAggregator):
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_krum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from torch import Tensor
from torch.nn import functional as F

from torchjd._utils.compute_gramian import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._weighting_bases import PSDMatrix, Weighting
from ._weighting_bases import Weighting


class Krum(GramianWeightedAggregator):
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_mean.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import torch
from torch import Tensor

from torchjd._utils.compute_gramian import Matrix

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import Matrix, Weighting
from ._weighting_bases import Weighting


class Mean(WeightedAggregator):
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_mgda.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import torch
from torch import Tensor

from torchjd._utils.compute_gramian import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._weighting_bases import PSDMatrix, Weighting
from ._weighting_bases import Weighting


class MGDA(GramianWeightedAggregator):
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_nash_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@

# mypy: ignore-errors

from torchjd._utils.compute_gramian import Matrix

from ._utils.check_dependencies import check_dependencies_are_installed
from ._weighting_bases import Matrix, Weighting
from ._weighting_bases import Weighting

check_dependencies_are_installed(["cvxpy", "ecos"])

Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_pcgrad.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch
from torch import Tensor

from torchjd._utils.compute_gramian import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import PSDMatrix, Weighting
from ._weighting_bases import Weighting


class PCGrad(GramianWeightedAggregator):
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from torch import Tensor
from torch.nn import functional as F

from torchjd._utils.compute_gramian import Matrix

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import Matrix, Weighting
from ._weighting_bases import Weighting


class Random(WeightedAggregator):
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_sum.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import torch
from torch import Tensor

from torchjd._utils.compute_gramian import Matrix

from ._aggregator_bases import WeightedAggregator
from ._weighting_bases import Matrix, Weighting
from ._weighting_bases import Weighting


class Sum(WeightedAggregator):
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import torch
from torch import Tensor

from torchjd._utils.compute_gramian import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.dual_cone import project_weights
from ._utils.gramian import normalize, regularize
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import PSDMatrix, Weighting
from ._weighting_bases import Weighting


class UPGrad(GramianWeightedAggregator):
Expand Down
8 changes: 0 additions & 8 deletions src/torchjd/aggregation/_utils/gramian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,6 @@
from torch import Tensor


def compute_gramian(matrix: Tensor) -> Tensor:
"""
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix.
"""

return matrix @ matrix.T


def normalize(gramian: Tensor, eps: float) -> Tensor:
"""
Normalizes the gramian `G=AA^T` with respect to the Frobenius norm of `A`.
Expand Down
3 changes: 2 additions & 1 deletion src/torchjd/aggregation/_utils/pref_vector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from torch import Tensor

from torchjd._utils.compute_gramian import Matrix
from torchjd.aggregation._constant import ConstantWeighting
from torchjd.aggregation._weighting_bases import Matrix, Weighting
from torchjd.aggregation._weighting_bases import Weighting

from .str import vector_to_str

Expand Down
4 changes: 1 addition & 3 deletions src/torchjd/aggregation/_weighting_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Annotated, Generic, TypeVar
from typing import Generic, TypeVar

from torch import Tensor, nn

_T = TypeVar("_T", contravariant=True)
_FnInputT = TypeVar("_FnInputT")
_FnOutputT = TypeVar("_FnOutputT")
Matrix = Annotated[Tensor, "ndim=2"]
PSDMatrix = Annotated[Matrix, "Positive semi-definite"]


class Weighting(Generic[_T], nn.Module, ABC):
Expand Down
7 changes: 2 additions & 5 deletions src/torchjd/autogram/_gramian_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch import Tensor
from torch.utils._pytree import PyTree

from torchjd._utils import compute_gramian
from torchjd.autogram._jacobian_computer import JacobianComputer


Expand All @@ -29,10 +30,6 @@ class JacobianBasedGramianComputer(GramianComputer, ABC):
def __init__(self, jacobian_computer):
self.jacobian_computer = jacobian_computer

@staticmethod
def _to_gramian(jacobian: Tensor) -> Tensor:
return jacobian @ jacobian.T


class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
"""
Expand Down Expand Up @@ -71,7 +68,7 @@ def __call__(
self.remaining_counter -= 1

if self.remaining_counter == 0:
gramian = self._to_gramian(self.summed_jacobian)
gramian = compute_gramian(self.summed_jacobian)
del self.summed_jacobian
return gramian
else:
Expand Down
29 changes: 26 additions & 3 deletions src/torchjd/autojac/_jac_to_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
from typing import cast

import torch
from torch import Tensor
from torch import Tensor, tensordot

from torchjd._utils import PSDMatrix, compute_gramian_sum
from torchjd.aggregation import Aggregator
from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator
from torchjd.aggregation._weighting_bases import Weighting

from ._accumulation import TensorWithJac, accumulate_grads

Expand Down Expand Up @@ -73,10 +76,30 @@ def jac_to_grad(
if not retain_jac:
_free_jacs(tensors_)

if isinstance(aggregator, GramianWeightedAggregator):
gradients = _gramian_based(aggregator.psd_weighting, jacobians, tensors_)
else:
gradients = _jacobian_based(aggregator, jacobians, tensors_)

accumulate_grads(tensors_, gradients)


def _gramian_based(
    weighting: Weighting[PSDMatrix], jacobians: list[Tensor], tensors: list[TensorWithJac]
) -> list[Tensor]:
    """
    Computes one gradient per Jacobian by applying the weights that ``weighting`` extracts from
    the summed Gramian of all Jacobians.

    The ``tensors`` parameter is unused here; it is kept so the signature parallels
    ``_jacobian_based``.
    """
    weights = weighting(compute_gramian_sum(jacobians))
    # Contract the weights with the first dimension of each Jacobian.
    return [tensordot(weights, jacobian, dims=1) for jacobian in jacobians]


def _jacobian_based(
    aggregator: Aggregator, jacobians: list[Tensor], tensors: list[TensorWithJac]
) -> list[Tensor]:
    """
    Aggregates the Jacobians into one gradient per tensor by uniting them into a single matrix,
    aggregating that matrix into a gradient vector, and splitting the vector back apart.

    Note: the copied span contained stale diff lines referencing the undefined name ``tensors_``
    and a duplicate ``accumulate_grads`` call; only the valid final lines are kept.
    """
    jacobian_matrix = _unite_jacobians(jacobians)
    gradient_vector = aggregator(jacobian_matrix)
    gradients = _disunite_gradient(gradient_vector, jacobians, tensors)
    return gradients


def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/aggregation/test_mgda.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from torch.testing import assert_close
from utils.tensors import ones_, randn_

from torchjd._utils.compute_gramian import compute_gramian
from torchjd.aggregation import MGDA
from torchjd.aggregation._mgda import MGDAWeighting
from torchjd.aggregation._utils.gramian import compute_gramian

from ._asserts import (
assert_expected_structure,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/aggregation/test_pcgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from torch.testing import assert_close
from utils.tensors import ones_, randn_

from torchjd._utils.compute_gramian import compute_gramian
from torchjd.aggregation import PCGrad
from torchjd.aggregation._pcgrad import PCGradWeighting
from torchjd.aggregation._upgrad import UPGradWeighting
from torchjd.aggregation._utils.gramian import compute_gramian

from ._asserts import assert_expected_structure, assert_non_differentiable
from ._inputs import scaled_matrices, typical_matrices
Expand Down
Loading