diff --git a/src/torchjd/_utils/__init__.py b/src/torchjd/_utils/__init__.py new file mode 100644 index 00000000..6eaf806c --- /dev/null +++ b/src/torchjd/_utils/__init__.py @@ -0,0 +1,3 @@ +from .compute_gramian import Matrix, PSDMatrix, compute_gramian, compute_gramian_sum + +__all__ = ["compute_gramian", "compute_gramian_sum", "Matrix", "PSDMatrix"] diff --git a/src/torchjd/_utils/compute_gramian.py b/src/torchjd/_utils/compute_gramian.py new file mode 100644 index 00000000..1d3fdd59 --- /dev/null +++ b/src/torchjd/_utils/compute_gramian.py @@ -0,0 +1,25 @@ +from typing import Annotated, cast + +import torch +from torch import Tensor + +Matrix = Annotated[Tensor, "ndim=2"] +PSDMatrix = Annotated[Matrix, "Positive semi-definite"] + + +def compute_gramian(generalized_matrix: Tensor) -> PSDMatrix: + """ + Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given + generalized matrix. Specifically, this is equivalent to + + matrix = generalized_matrix.reshape([generalized_matrix.shape[0], -1]) + return matrix @ matrix.T + """ + dims = list(range(1, generalized_matrix.ndim)) + gramian = torch.tensordot(generalized_matrix, generalized_matrix, dims=(dims, dims)) + return cast(PSDMatrix, gramian) + + +def compute_gramian_sum(generalized_matrices: list[Tensor]) -> PSDMatrix: + gramian = sum([compute_gramian(matrix) for matrix in generalized_matrices]) + return cast(PSDMatrix, gramian) diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index c15e4502..95b37b70 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -2,8 +2,9 @@ from torch import Tensor, nn -from ._utils.gramian import compute_gramian -from ._weighting_bases import Matrix, PSDMatrix, Weighting +from torchjd._utils import Matrix, PSDMatrix, compute_gramian + +from ._weighting_bases import Weighting class Aggregator(nn.Module, ABC): @@ -79,3 +80,4 @@ class GramianWeightedAggregator(WeightedAggregator): def
__init__(self, weighting: Weighting[PSDMatrix]): super().__init__(weighting << compute_gramian) + self.psd_weighting = weighting diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index 3f0b7119..eb74295e 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -28,10 +28,12 @@ import torch from torch import Tensor +from torchjd._utils.compute_gramian import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class AlignedMTL(GramianWeightedAggregator): diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index 77c76a9c..ed114e36 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -1,7 +1,9 @@ from typing import cast +from torchjd._utils.compute_gramian import PSDMatrix + from ._utils.check_dependencies import check_dependencies_are_installed -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting check_dependencies_are_installed(["cvxpy", "clarabel"]) diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index 35345abb..00d3d97e 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -1,8 +1,10 @@ from torch import Tensor +from torchjd._utils.compute_gramian import Matrix + from ._aggregator_bases import WeightedAggregator from ._utils.str import vector_to_str -from ._weighting_bases import Matrix, Weighting +from ._weighting_bases import Weighting class Constant(WeightedAggregator): diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 68a48776..742a32cc 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ 
b/src/torchjd/aggregation/_dualproj.py @@ -2,13 +2,15 @@ from torch import Tensor +from torchjd._utils.compute_gramian import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._utils.dual_cone import project_weights from ._utils.gramian import normalize, regularize from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class DualProj(GramianWeightedAggregator): diff --git a/src/torchjd/aggregation/_flattening.py b/src/torchjd/aggregation/_flattening.py index 74da4a97..08559ca3 100644 --- a/src/torchjd/aggregation/_flattening.py +++ b/src/torchjd/aggregation/_flattening.py @@ -2,7 +2,8 @@ from torch import Tensor -from torchjd.aggregation._weighting_bases import GeneralizedWeighting, PSDMatrix, Weighting +from torchjd._utils.compute_gramian import PSDMatrix +from torchjd.aggregation._weighting_bases import GeneralizedWeighting, Weighting from torchjd.autogram._gramian_utils import reshape_gramian diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 52d86526..9d2c05fa 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -1,9 +1,11 @@ import torch from torch import Tensor +from torchjd._utils.compute_gramian import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class IMTLG(GramianWeightedAggregator): diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index cf14e577..dc119073 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -2,8 +2,10 @@ from torch import Tensor from torch.nn import functional as F 
+from torchjd._utils.compute_gramian import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class Krum(GramianWeightedAggregator): diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index 36600ace..50919909 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -1,8 +1,10 @@ import torch from torch import Tensor +from torchjd._utils.compute_gramian import Matrix + from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import Matrix, Weighting +from ._weighting_bases import Weighting class Mean(WeightedAggregator): diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index a5d7b2af..9b8c7e82 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -1,8 +1,10 @@ import torch from torch import Tensor +from torchjd._utils.compute_gramian import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class MGDA(GramianWeightedAggregator): diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 9d1393a5..c02b9c54 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -25,8 +25,10 @@ # mypy: ignore-errors +from torchjd._utils.compute_gramian import Matrix + from ._utils.check_dependencies import check_dependencies_are_installed -from ._weighting_bases import Matrix, Weighting +from ._weighting_bases import Weighting check_dependencies_are_installed(["cvxpy", "ecos"]) diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index d85a8748..d9338ad0 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -1,9 +1,11 @@ import torch from torch import Tensor +from 
torchjd._utils.compute_gramian import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class PCGrad(GramianWeightedAggregator): diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index ba89a82c..de3fc235 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -2,8 +2,10 @@ from torch import Tensor from torch.nn import functional as F +from torchjd._utils.compute_gramian import Matrix + from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import Matrix, Weighting +from ._weighting_bases import Weighting class Random(WeightedAggregator): diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index 868257fd..0efe354f 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -1,8 +1,10 @@ import torch from torch import Tensor +from torchjd._utils.compute_gramian import Matrix + from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import Matrix, Weighting +from ._weighting_bases import Weighting class Sum(WeightedAggregator): diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 750cc735..aaa3621d 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -3,13 +3,15 @@ import torch from torch import Tensor +from torchjd._utils.compute_gramian import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._utils.dual_cone import project_weights from ._utils.gramian import normalize, regularize from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import PSDMatrix, Weighting +from 
._weighting_bases import Weighting class UPGrad(GramianWeightedAggregator): diff --git a/src/torchjd/aggregation/_utils/gramian.py b/src/torchjd/aggregation/_utils/gramian.py index dfe16987..58fc11fa 100644 --- a/src/torchjd/aggregation/_utils/gramian.py +++ b/src/torchjd/aggregation/_utils/gramian.py @@ -2,14 +2,6 @@ from torch import Tensor -def compute_gramian(matrix: Tensor) -> Tensor: - """ - Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix. - """ - - return matrix @ matrix.T - - def normalize(gramian: Tensor, eps: float) -> Tensor: """ Normalizes the gramian `G=AA^T` with respect to the Frobenius norm of `A`. diff --git a/src/torchjd/aggregation/_utils/pref_vector.py b/src/torchjd/aggregation/_utils/pref_vector.py index 1f3efef0..a121e0f5 100644 --- a/src/torchjd/aggregation/_utils/pref_vector.py +++ b/src/torchjd/aggregation/_utils/pref_vector.py @@ -1,7 +1,8 @@ from torch import Tensor +from torchjd._utils.compute_gramian import Matrix from torchjd.aggregation._constant import ConstantWeighting -from torchjd.aggregation._weighting_bases import Matrix, Weighting +from torchjd.aggregation._weighting_bases import Weighting from .str import vector_to_str diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index 154c7a30..a78b4c55 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -2,15 +2,13 @@ from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Annotated, Generic, TypeVar +from typing import Generic, TypeVar from torch import Tensor, nn _T = TypeVar("_T", contravariant=True) _FnInputT = TypeVar("_FnInputT") _FnOutputT = TypeVar("_FnOutputT") -Matrix = Annotated[Tensor, "ndim=2"] -PSDMatrix = Annotated[Matrix, "Positive semi-definite"] class Weighting(Generic[_T], nn.Module, ABC): diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index 2bc62f21..92ba933d 100644
--- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -4,6 +4,7 @@ from torch import Tensor from torch.utils._pytree import PyTree +from torchjd._utils import compute_gramian from torchjd.autogram._jacobian_computer import JacobianComputer @@ -29,10 +30,6 @@ class JacobianBasedGramianComputer(GramianComputer, ABC): def __init__(self, jacobian_computer): self.jacobian_computer = jacobian_computer - @staticmethod - def _to_gramian(jacobian: Tensor) -> Tensor: - return jacobian @ jacobian.T - class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer): """ @@ -71,7 +68,7 @@ def __call__( self.remaining_counter -= 1 if self.remaining_counter == 0: - gramian = self._to_gramian(self.summed_jacobian) + gramian = compute_gramian(self.summed_jacobian) del self.summed_jacobian return gramian else: diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 8c85025f..105f6e0f 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -2,9 +2,12 @@ from typing import cast import torch -from torch import Tensor +from torch import Tensor, tensordot +from torchjd._utils import PSDMatrix, compute_gramian_sum from torchjd.aggregation import Aggregator +from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator +from torchjd.aggregation._weighting_bases import Weighting from ._accumulation import TensorWithJac, accumulate_grads @@ -73,10 +76,30 @@ def jac_to_grad( if not retain_jac: _free_jacs(tensors_) + if isinstance(aggregator, GramianWeightedAggregator): + gradients = _gramian_based(aggregator.psd_weighting, jacobians, tensors_) + else: + gradients = _jacobian_based(aggregator, jacobians, tensors_) + + accumulate_grads(tensors_, gradients) + + +def _gramian_based( + weighting: Weighting[PSDMatrix], jacobians: list[Tensor], tensors: list[TensorWithJac] +) -> list[Tensor]: + gramian = compute_gramian_sum(jacobians) + weights = 
weighting(gramian) + gradients = [tensordot(weights, jacobian, dims=1) for jacobian in jacobians] + return gradients + + +def _jacobian_based( + aggregator: Aggregator, jacobians: list[Tensor], tensors: list[TensorWithJac] +) -> list[Tensor]: jacobian_matrix = _unite_jacobians(jacobians) gradient_vector = aggregator(jacobian_matrix) - gradients = _disunite_gradient(gradient_vector, jacobians, tensors_) - accumulate_grads(tensors_, gradients) + gradients = _disunite_gradient(gradient_vector, jacobians, tensors) + return gradients def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: diff --git a/tests/unit/aggregation/test_mgda.py b/tests/unit/aggregation/test_mgda.py index 41b07d93..a63e1724 100644 --- a/tests/unit/aggregation/test_mgda.py +++ b/tests/unit/aggregation/test_mgda.py @@ -3,9 +3,9 @@ from torch.testing import assert_close from utils.tensors import ones_, randn_ +from torchjd._utils.compute_gramian import compute_gramian from torchjd.aggregation import MGDA from torchjd.aggregation._mgda import MGDAWeighting -from torchjd.aggregation._utils.gramian import compute_gramian from ._asserts import ( assert_expected_structure, diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index e96253ec..e4a14ed3 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -3,10 +3,10 @@ from torch.testing import assert_close from utils.tensors import ones_, randn_ +from torchjd._utils.compute_gramian import compute_gramian from torchjd.aggregation import PCGrad from torchjd.aggregation._pcgrad import PCGradWeighting from torchjd.aggregation._upgrad import UPGradWeighting -from torchjd.aggregation._utils.gramian import compute_gramian from ._asserts import assert_expected_structure, assert_non_differentiable from ._inputs import scaled_matrices, typical_matrices