diff --git a/src/torchjd/_linalg/__init__.py b/src/torchjd/_linalg/__init__.py new file mode 100644 index 00000000..0f49f7f9 --- /dev/null +++ b/src/torchjd/_linalg/__init__.py @@ -0,0 +1,4 @@ +from .gramian import compute_gramian +from .matrix import Matrix, PSDMatrix + +__all__ = ["compute_gramian", "Matrix", "PSDMatrix"] diff --git a/src/torchjd/_linalg/gramian.py b/src/torchjd/_linalg/gramian.py new file mode 100644 index 00000000..c54273ed --- /dev/null +++ b/src/torchjd/_linalg/gramian.py @@ -0,0 +1,9 @@ +from .matrix import Matrix, PSDMatrix + + +def compute_gramian(matrix: Matrix) -> PSDMatrix: + """ + Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix. + """ + + return matrix @ matrix.T diff --git a/src/torchjd/_linalg/matrix.py b/src/torchjd/_linalg/matrix.py new file mode 100644 index 00000000..2f211a6d --- /dev/null +++ b/src/torchjd/_linalg/matrix.py @@ -0,0 +1,6 @@ +from typing import Annotated + +from torch import Tensor + +Matrix = Annotated[Tensor, "ndim=2"] +PSDMatrix = Annotated[Matrix, "Positive semi-definite"] diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index c15e4502..6f716444 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -2,8 +2,9 @@ from torch import Tensor, nn -from ._utils.gramian import compute_gramian -from ._weighting_bases import Matrix, PSDMatrix, Weighting +from torchjd._linalg import Matrix, PSDMatrix, compute_gramian + +from ._weighting_bases import Weighting class Aggregator(nn.Module, ABC): diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index 3f0b7119..8e5dd0cc 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -28,10 +28,12 @@ import torch from torch import Tensor +from torchjd._linalg import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from 
._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class AlignedMTL(GramianWeightedAggregator): @@ -73,7 +75,7 @@ def __init__(self, pref_vector: Tensor | None = None): self._pref_vector = pref_vector self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) - def forward(self, gramian: Tensor) -> Tensor: + def forward(self, gramian: PSDMatrix) -> Tensor: w = self.weighting(gramian) B = self._compute_balance_transformation(gramian) alpha = B @ w diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index 77c76a9c..c690d6a9 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -1,7 +1,9 @@ from typing import cast +from torchjd._linalg import PSDMatrix + from ._utils.check_dependencies import check_dependencies_are_installed -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting check_dependencies_are_installed(["cvxpy", "clarabel"]) @@ -73,7 +75,7 @@ def __init__(self, c: float, norm_eps: float = 0.0001): self.c = c self.norm_eps = norm_eps - def forward(self, gramian: Tensor) -> Tensor: + def forward(self, gramian: PSDMatrix) -> Tensor: U, S, _ = torch.svd(normalize(gramian, self.norm_eps)) reduced_matrix = U @ S.sqrt().diag() diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index 35345abb..81404512 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -1,8 +1,10 @@ from torch import Tensor +from torchjd._linalg import Matrix + from ._aggregator_bases import WeightedAggregator from ._utils.str import vector_to_str -from ._weighting_bases import Matrix, Weighting +from ._weighting_bases import Weighting class Constant(WeightedAggregator): diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 
68a48776..85b55365 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -2,13 +2,15 @@ from torch import Tensor +from torchjd._linalg import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._utils.dual_cone import project_weights from ._utils.gramian import normalize, regularize from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class DualProj(GramianWeightedAggregator): @@ -86,7 +88,7 @@ def __init__( self.reg_eps = reg_eps self.solver = solver - def forward(self, gramian: Tensor) -> Tensor: + def forward(self, gramian: PSDMatrix) -> Tensor: u = self.weighting(gramian) G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) w = project_weights(u, G, self.solver) diff --git a/src/torchjd/aggregation/_flattening.py b/src/torchjd/aggregation/_flattening.py index 74da4a97..ea04040f 100644 --- a/src/torchjd/aggregation/_flattening.py +++ b/src/torchjd/aggregation/_flattening.py @@ -2,7 +2,8 @@ from torch import Tensor -from torchjd.aggregation._weighting_bases import GeneralizedWeighting, PSDMatrix, Weighting +from torchjd._linalg.matrix import PSDMatrix +from torchjd.aggregation._weighting_bases import GeneralizedWeighting, Weighting from torchjd.autogram._gramian_utils import reshape_gramian diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 52d86526..7c8369ee 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -1,9 +1,11 @@ import torch from torch import Tensor +from torchjd._linalg import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -from ._weighting_bases import PSDMatrix, Weighting +from 
._weighting_bases import Weighting class IMTLG(GramianWeightedAggregator): @@ -27,7 +29,7 @@ class IMTLGWeighting(Weighting[PSDMatrix]): :class:`~torchjd.aggregation.IMTLG`. """ - def forward(self, gramian: Tensor) -> Tensor: + def forward(self, gramian: PSDMatrix) -> Tensor: d = torch.sqrt(torch.diagonal(gramian)) v = torch.linalg.pinv(gramian) @ d v_sum = v.sum() diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index cf14e577..7b523360 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -2,8 +2,10 @@ from torch import Tensor from torch.nn import functional as F +from torchjd._linalg import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class Krum(GramianWeightedAggregator): @@ -59,7 +61,7 @@ def __init__(self, n_byzantine: int, n_selected: int = 1): self.n_byzantine = n_byzantine self.n_selected = n_selected - def forward(self, gramian: Tensor) -> Tensor: + def forward(self, gramian: PSDMatrix) -> Tensor: self._check_matrix_shape(gramian) gradient_norms_squared = torch.diagonal(gramian) distances_squared = ( @@ -78,7 +80,7 @@ def forward(self, gramian: Tensor) -> Tensor: return weights - def _check_matrix_shape(self, gramian: Tensor) -> None: + def _check_matrix_shape(self, gramian: PSDMatrix) -> None: min_rows = self.n_byzantine + 3 if gramian.shape[0] < min_rows: raise ValueError( diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index 36600ace..f739e966 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -1,8 +1,10 @@ import torch from torch import Tensor +from torchjd._linalg import Matrix + from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import Matrix, Weighting +from ._weighting_bases import Weighting class Mean(WeightedAggregator): diff --git a/src/torchjd/aggregation/_mgda.py 
b/src/torchjd/aggregation/_mgda.py index a5d7b2af..868f5263 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -1,8 +1,10 @@ import torch from torch import Tensor +from torchjd._linalg import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class MGDA(GramianWeightedAggregator): @@ -40,7 +42,7 @@ def __init__(self, epsilon: float = 0.001, max_iters: int = 100): self.epsilon = epsilon self.max_iters = max_iters - def forward(self, gramian: Tensor) -> Tensor: + def forward(self, gramian: PSDMatrix) -> Tensor: """ This is the Frank-Wolfe solver in Algorithm 2 of `Multi-Task Learning as Multi-Objective Optimization diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 9d1393a5..43939e3b 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -25,8 +25,10 @@ # mypy: ignore-errors +from torchjd._linalg import Matrix + from ._utils.check_dependencies import check_dependencies_are_installed -from ._weighting_bases import Matrix, Weighting +from ._weighting_bases import Weighting check_dependencies_are_installed(["cvxpy", "ecos"]) diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index d85a8748..59923649 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -1,9 +1,11 @@ import torch from torch import Tensor +from torchjd._linalg import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class PCGrad(GramianWeightedAggregator): @@ -25,7 +27,7 @@ class PCGradWeighting(Weighting[PSDMatrix]): :class:`~torchjd.aggregation.PCGrad`. 
""" - def forward(self, gramian: Tensor) -> Tensor: + def forward(self, gramian: PSDMatrix) -> Tensor: # Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration device = gramian.device dtype = gramian.dtype diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index ba89a82c..2f2e330c 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -2,8 +2,10 @@ from torch import Tensor from torch.nn import functional as F +from torchjd._linalg import Matrix + from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import Matrix, Weighting +from ._weighting_bases import Weighting class Random(WeightedAggregator): diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index 868257fd..da33512a 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -1,8 +1,10 @@ import torch from torch import Tensor +from torchjd._linalg import Matrix + from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import Matrix, Weighting +from ._weighting_bases import Weighting class Sum(WeightedAggregator): diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 750cc735..384af71c 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -3,13 +3,15 @@ import torch from torch import Tensor +from torchjd._linalg import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._utils.dual_cone import project_weights from ._utils.gramian import normalize, regularize from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class UPGrad(GramianWeightedAggregator): @@ -87,7 +89,7 @@ def __init__( self.reg_eps = 
reg_eps self.solver = solver - def forward(self, gramian: Tensor) -> Tensor: + def forward(self, gramian: PSDMatrix) -> Tensor: U = torch.diag(self.weighting(gramian)) G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) W = project_weights(U, G, self.solver) diff --git a/src/torchjd/aggregation/_utils/gramian.py b/src/torchjd/aggregation/_utils/gramian.py index dfe16987..81637e42 100644 --- a/src/torchjd/aggregation/_utils/gramian.py +++ b/src/torchjd/aggregation/_utils/gramian.py @@ -1,16 +1,9 @@ import torch -from torch import Tensor - -def compute_gramian(matrix: Tensor) -> Tensor: - """ - Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix. - """ - - return matrix @ matrix.T +from torchjd._linalg.matrix import PSDMatrix -def normalize(gramian: Tensor, eps: float) -> Tensor: +def normalize(gramian: PSDMatrix, eps: float) -> PSDMatrix: """ Normalizes the gramian `G=AA^T` with respect to the Frobenius norm of `A`. @@ -25,7 +18,7 @@ def normalize(gramian: Tensor, eps: float) -> Tensor: return gramian / squared_frobenius_norm -def regularize(gramian: Tensor, eps: float) -> Tensor: +def regularize(gramian: PSDMatrix, eps: float) -> PSDMatrix: """ Adds a regularization term to the gramian to enforce positive definiteness. 
diff --git a/src/torchjd/aggregation/_utils/pref_vector.py b/src/torchjd/aggregation/_utils/pref_vector.py index 1f3efef0..efdfefb1 100644 --- a/src/torchjd/aggregation/_utils/pref_vector.py +++ b/src/torchjd/aggregation/_utils/pref_vector.py @@ -1,7 +1,8 @@ from torch import Tensor +from torchjd._linalg.matrix import Matrix from torchjd.aggregation._constant import ConstantWeighting -from torchjd.aggregation._weighting_bases import Matrix, Weighting +from torchjd.aggregation._weighting_bases import Weighting from .str import vector_to_str diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index 154c7a30..a78b4c55 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -2,15 +2,13 @@ from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Annotated, Generic, TypeVar +from typing import Generic, TypeVar from torch import Tensor, nn _T = TypeVar("_T", contravariant=True) _FnInputT = TypeVar("_FnInputT") _FnOutputT = TypeVar("_FnOutputT") -Matrix = Annotated[Tensor, "ndim=2"] -PSDMatrix = Annotated[Matrix, "Positive semi-definite"] class Weighting(Generic[_T], nn.Module, ABC): diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 643845fc..ffe45bfe 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -4,6 +4,8 @@ from torch import Tensor, nn, vmap from torch.autograd.graph import get_gradient_edge +from torchjd._linalg.matrix import PSDMatrix + from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms @@ -232,6 +234,7 @@ def _check_module_is_compatible(self, module: nn.Module) -> None: f"`batch_dim=None` when creating the engine." ) + # Currently, the type PSDMatrix is hidden from users, so Tensor is correct. 
def compute_gramian(self, output: Tensor) -> Tensor: r""" Computes the Gramian of the Jacobian of ``output`` with respect to the direct parameters of @@ -305,7 +308,7 @@ def compute_gramian(self, output: Tensor) -> Tensor: return gramian - def _compute_square_gramian(self, output: Tensor, has_non_batch_dim: bool) -> Tensor: + def _compute_square_gramian(self, output: Tensor, has_non_batch_dim: bool) -> PSDMatrix: leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)})) def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: @@ -330,6 +333,6 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: # If the gramian were None, then leaf_targets would be empty, so autograd.grad would # have failed. So gramian is necessarily a valid Tensor here. - gramian = cast(Tensor, self._gramian_accumulator.gramian) + gramian = cast(PSDMatrix, self._gramian_accumulator.gramian) return gramian diff --git a/src/torchjd/autogram/_gramian_accumulator.py b/src/torchjd/autogram/_gramian_accumulator.py index 2c9405bf..5d2e6c0f 100644 --- a/src/torchjd/autogram/_gramian_accumulator.py +++ b/src/torchjd/autogram/_gramian_accumulator.py @@ -1,6 +1,6 @@ from typing import Optional -from torch import Tensor +from torchjd._linalg.matrix import PSDMatrix class GramianAccumulator: @@ -13,19 +13,19 @@ class GramianAccumulator: """ def __init__(self) -> None: - self._gramian: Optional[Tensor] = None + self._gramian: Optional[PSDMatrix] = None def reset(self) -> None: self._gramian = None - def accumulate_gramian(self, gramian: Tensor) -> None: + def accumulate_gramian(self, gramian: PSDMatrix) -> None: if self._gramian is not None: self._gramian.add_(gramian) else: self._gramian = gramian @property - def gramian(self) -> Optional[Tensor]: + def gramian(self) -> Optional[PSDMatrix]: """ Get the Gramian matrix accumulated so far. 
diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index 2bc62f21..e5012024 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -4,6 +4,8 @@ from torch import Tensor from torch.utils._pytree import PyTree +from torchjd._linalg import compute_gramian +from torchjd._linalg.matrix import PSDMatrix from torchjd.autogram._jacobian_computer import JacobianComputer @@ -15,7 +17,7 @@ def __call__( grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree], - ) -> Optional[Tensor]: + ) -> Optional[PSDMatrix]: """Compute what we can for a module and optionally return the gramian if it's ready.""" def track_forward_call(self) -> None: @@ -29,10 +31,6 @@ class JacobianBasedGramianComputer(GramianComputer, ABC): def __init__(self, jacobian_computer): self.jacobian_computer = jacobian_computer - @staticmethod - def _to_gramian(jacobian: Tensor) -> Tensor: - return jacobian @ jacobian.T - class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer): """ @@ -58,7 +56,7 @@ def __call__( grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree], - ) -> Optional[Tensor]: + ) -> Optional[PSDMatrix]: """Compute what we can for a module and optionally return the gramian if it's ready.""" jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs) @@ -71,7 +69,7 @@ def __call__( self.remaining_counter -= 1 if self.remaining_counter == 0: - gramian = self._to_gramian(self.summed_jacobian) + gramian = compute_gramian(self.summed_jacobian) del self.summed_jacobian return gramian else: diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index 3bbfd062..d61bc8d9 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -1,7 +1,7 @@ -from torch import Tensor +from torchjd._linalg.matrix import PSDMatrix -def 
reshape_gramian(gramian: Tensor, half_shape: list[int]) -> Tensor: +def reshape_gramian(gramian: PSDMatrix, half_shape: list[int]) -> PSDMatrix: """ Reshapes a Gramian to a provided shape. The reshape of the first half of the target dimensions must be done from the left, while the reshape of the second half must be done from the right. @@ -21,7 +21,7 @@ def reshape_gramian(gramian: Tensor, half_shape: list[int]) -> Tensor: return _revert_last_dims(_revert_last_dims(gramian).reshape(half_shape + half_shape)) -def _revert_last_dims(gramian: Tensor) -> Tensor: +def _revert_last_dims(gramian: PSDMatrix) -> PSDMatrix: """Inverts the order of the last half of the dimensions of the input generalized Gramian.""" half_ndim = gramian.ndim // 2 @@ -29,7 +29,9 @@ def _revert_last_dims(gramian: Tensor) -> Tensor: return gramian.movedim(last_dims, last_dims[::-1]) -def movedim_gramian(gramian: Tensor, half_source: list[int], half_destination: list[int]) -> Tensor: +def movedim_gramian( + gramian: PSDMatrix, half_source: list[int], half_destination: list[int] +) -> PSDMatrix: """ Moves the dimensions of a Gramian from some source dimensions to destination dimensions. 
This must be done simultaneously on the first half of the dimensions and on the second half of the diff --git a/tests/unit/aggregation/test_mgda.py b/tests/unit/aggregation/test_mgda.py index 41b07d93..2d1fe068 100644 --- a/tests/unit/aggregation/test_mgda.py +++ b/tests/unit/aggregation/test_mgda.py @@ -3,9 +3,9 @@ from torch.testing import assert_close from utils.tensors import ones_, randn_ +from torchjd._linalg import compute_gramian from torchjd.aggregation import MGDA from torchjd.aggregation._mgda import MGDAWeighting -from torchjd.aggregation._utils.gramian import compute_gramian from ._asserts import ( assert_expected_structure, diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index e96253ec..d55d87c0 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -3,10 +3,10 @@ from torch.testing import assert_close from utils.tensors import ones_, randn_ +from torchjd._linalg import compute_gramian from torchjd.aggregation import PCGrad from torchjd.aggregation._pcgrad import PCGradWeighting from torchjd.aggregation._upgrad import UPGradWeighting -from torchjd.aggregation._utils.gramian import compute_gramian from ._asserts import assert_expected_structure, assert_non_differentiable from ._inputs import scaled_matrices, typical_matrices diff --git a/tests/unit/linalg/__init__.py b/tests/unit/linalg/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/linalg/test_gramian.py b/tests/unit/linalg/test_gramian.py new file mode 100644 index 00000000..4f0e5456 --- /dev/null +++ b/tests/unit/linalg/test_gramian.py @@ -0,0 +1,20 @@ +from pytest import mark +from utils.asserts import assert_psd_matrix +from utils.tensors import randn_ + +from torchjd._linalg.gramian import compute_gramian + + +@mark.parametrize( + "shape", + [ + [3, 1], + [4, 4], + [4, 3], + [6, 7], + ], +) +def test_gramian_is_psd(shape: list[int]): + matrix = randn_(shape) + gramian = 
compute_gramian(matrix) + assert_psd_matrix(gramian) diff --git a/tests/utils/asserts.py b/tests/utils/asserts.py index 09f2520f..3828aa5e 100644 --- a/tests/utils/asserts.py +++ b/tests/utils/asserts.py @@ -3,6 +3,7 @@ import torch from torch.testing import assert_close +from torchjd._linalg.matrix import PSDMatrix from torchjd.autojac._accumulation import TensorWithJac @@ -33,3 +34,15 @@ def assert_has_no_grad(t: torch.Tensor) -> None: def assert_grad_close(t: torch.Tensor, expected_grad: torch.Tensor, **kwargs) -> None: assert t.grad is not None assert_close(t.grad, expected_grad, **kwargs) + + +def assert_psd_matrix(matrix: PSDMatrix, **kwargs) -> None: + + assert_close(matrix, matrix.mH, **kwargs, msg="Matrix is not symmetric/Hermitian") + + eig_vals = torch.linalg.eigvalsh(matrix) + expected_eig_vals = eig_vals.clamp(min=0.0) + + assert_close( + eig_vals, expected_eig_vals, **kwargs, msg="Matrix has significant negative eigenvalues" + )