Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
9fec107
refactor(linalg): Add `PSDQuadraticForm` and `GeneralizedMatrix`.
PierreQuinton Jan 19, 2026
acf8e58
Merge branch 'main' into add-generalized-matrix-psd-matrix
PierreQuinton Jan 19, 2026
a744fa2
Sort items of `__all__` of `_linalg.__init__`
PierreQuinton Jan 19, 2026
23de54d
one line
PierreQuinton Jan 19, 2026
2bd603e
fix `is_psd_quadratic_form`
PierreQuinton Jan 19, 2026
d6f8375
remove outdated comment
PierreQuinton Jan 19, 2026
24d24bb
Add `assert_psd_quadratic_form` and TODOs for where to test it. I als…
PierreQuinton Jan 19, 2026
242cb55
fix is_psd_quadratic_form
PierreQuinton Jan 20, 2026
72a9a5f
Rename `PSDQuadraticForm` to `PSDGeneralizedMatrix`
PierreQuinton Jan 20, 2026
09df593
fix type of weighting in Flattening
PierreQuinton Jan 20, 2026
0a1d45c
Add parametrization of zero matrix for test_gramian_is_psd
PierreQuinton Jan 20, 2026
6f63182
Add test of the PSD property for functions in aggregation/_utils/gramian
PierreQuinton Jan 20, 2026
5a42ecd
rename test of equivariance accordingly
PierreQuinton Jan 20, 2026
0497f3a
Rename functions in `autogram/_gramian_utils` so that they don't incl…
PierreQuinton Jan 20, 2026
48df0a8
Test the PSD property on outputs of functions in `autogram/_gramian_u…
PierreQuinton Jan 20, 2026
40977f3
Remove internal checks of shapes of matrices
PierreQuinton Jan 20, 2026
92b975b
Remove uninformative shadowing of assertion error in assert_psd_*
PierreQuinton Jan 20, 2026
bda0a5f
Factorize `compute_gramian` from `forward_backward` by making the one…
PierreQuinton Jan 20, 2026
97bcf42
Revert "Factorize `compute_gramian` from `forward_backward` by making…
PierreQuinton Jan 20, 2026
03aebae
Generalizes `compute_gramian` to take a `GeneralizedMatrix` instead.
PierreQuinton Jan 20, 2026
ee54c09
Move `aggregation/_utils/gramian.py` to `_linalg/gramian.py`
PierreQuinton Jan 20, 2026
2b94d78
Merge branch 'main' into add-generalized-matrix-psd-matrix
ValerianRey Jan 20, 2026
e347075
Apply suggestions from code review
PierreQuinton Jan 21, 2026
3d9742c
Remove outdated comments
PierreQuinton Jan 21, 2026
f2d0d1b
Improve style
PierreQuinton Jan 21, 2026
5eafa74
Improve typing of `forward_backward.compute_gramian`
PierreQuinton Jan 21, 2026
d60e9fa
improve asserts
PierreQuinton Jan 21, 2026
f4d611b
Merge branch 'main' into add-generalized-matrix-psd-matrix
ValerianRey Jan 21, 2026
57af9f1
Merge branch 'main' into add-generalized-matrix-psd-matrix
ValerianRey Jan 21, 2026
a793693
Can parametrize number of dimensions to contract in `compute_gramian`
PierreQuinton Jan 22, 2026
7da352c
Remove GeneralizedMatrix
ValerianRey Jan 23, 2026
994932b
Rename PSDGeneralizedMatrix to PSDTensor
ValerianRey Jan 23, 2026
ab809c6
Add comment about using classes
ValerianRey Jan 23, 2026
47bf743
Remove useless overload of compute_gramian
ValerianRey Jan 23, 2026
55bc6f8
Rename matrix to t in compute_gramian
ValerianRey Jan 23, 2026
a80f3f6
Add overload for compute_gramian when t is matrix and contracted_dims…
ValerianRey Jan 23, 2026
09393cc
Stop expecting coverage for overload functions
ValerianRey Jan 23, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,9 @@ full = [

[tool.pytest.ini_options]
xfail_strict = true

[tool.coverage.report]
exclude_lines = [
"pragma: not covered",
"@overload",
]
16 changes: 13 additions & 3 deletions src/torchjd/_linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from .gramian import compute_gramian
from .matrix import Matrix, PSDMatrix
from ._gramian import compute_gramian, normalize, regularize
from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor

__all__ = ["compute_gramian", "Matrix", "PSDMatrix"]
__all__ = [
"compute_gramian",
"normalize",
"regularize",
"Matrix",
"PSDMatrix",
"PSDTensor",
"is_matrix",
"is_psd_matrix",
"is_psd_tensor",
]
70 changes: 70 additions & 0 deletions src/torchjd/_linalg/_gramian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Literal, cast, overload

import torch
from torch import Tensor

from ._matrix import Matrix, PSDMatrix, PSDTensor


@overload
def compute_gramian(t: Tensor) -> PSDMatrix:
    pass


@overload
def compute_gramian(t: Tensor, contracted_dims: Literal[-1]) -> PSDMatrix:
    pass


@overload
def compute_gramian(t: Matrix, contracted_dims: Literal[1]) -> PSDMatrix:
    pass


def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
    """
    Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of the input.

    :param t: The tensor whose gramian is computed.
    :param contracted_dims: Number of trailing dimensions to contract. A negative value instead
        counts the leading dimensions to preserve (the default ``-1`` preserves the first
        dimension and contracts all the others).
    """

    if contracted_dims < 0:
        contracted_dims += t.ndim
    n_preserved = t.ndim - contracted_dims
    # Send every preserved leading dimension to the mirrored trailing position, so that the
    # result's first half of dimensions matches its reversed second half (PSDTensor layout).
    preserved_dims = list(range(n_preserved))
    mirrored_dims = [t.ndim - 1 - dim for dim in preserved_dims]
    transposed = t.movedim(preserved_dims, mirrored_dims)
    gramian = torch.tensordot(t, transposed, dims=contracted_dims)
    return cast(PSDTensor, gramian)


def normalize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
    """
    Normalizes the gramian `G=AA^T` with respect to the Frobenius norm of `A`.

    Since `G = A A^T`, the squared Frobenius norm of `A` equals the trace of `G` (the sum of its
    diagonal elements). Dividing `G` by this trace therefore yields the gramian of the
    (Frobenius-)normalized `A`.

    :param gramian: The gramian matrix to normalize.
    :param eps: Threshold under which the trace is considered zero, in which case a zero matrix
        is returned instead of dividing by a vanishing value.
    """
    trace = gramian.diagonal().sum()
    if trace < eps:
        # Dividing by a (near-)zero norm would be numerically meaningless: fall back to zeros.
        return cast(PSDMatrix, torch.zeros_like(gramian))
    return cast(PSDMatrix, gramian / trace)


def regularize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
    """
    Adds a regularization term to the gramian to enforce positive definiteness.

    Numerical errors can leave `gramian` with slightly negative eigenvalue(s); adding a small
    multiple of the identity shifts every eigenvalue up by `eps`, which makes the result
    positive definite.

    :param gramian: The (square) gramian matrix to regularize.
    :param eps: Magnitude of the identity term added to the diagonal.
    """

    identity = torch.eye(gramian.shape[0], dtype=gramian.dtype, device=gramian.device)
    return cast(PSDMatrix, gramian + eps * identity)
40 changes: 40 additions & 0 deletions src/torchjd/_linalg/_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import TypeGuard

from torch import Tensor

# Note: we're using classes and inheritance instead of NewType because it's possible to have
# multiple inheritance but there is no type intersection. However, these classes should never be
# instantiated: they're only used for static type checking.


class Matrix(Tensor):
    """
    Tensor with exactly 2 dimensions.

    Never instantiated: exists only so that static type checkers can distinguish matrices from
    arbitrary tensors (narrowed via the `is_matrix` TypeGuard).
    """


class PSDTensor(Tensor):
    """
    Tensor representing a quadratic form. The first half of its dimensions matches the reversed
    second half of its dimensions (e.g. shape=[4, 3, 3, 4]), and its reshaping into a matrix should
    be positive semi-definite.

    Never instantiated: used only for static type checking (narrowed via `is_psd_tensor`).
    """


class PSDMatrix(PSDTensor, Matrix):
    """
    Positive semi-definite matrix.

    Combines `PSDTensor` and `Matrix` via multiple inheritance to emulate a type intersection;
    never instantiated (narrowed via `is_psd_matrix`).
    """


def is_matrix(t: Tensor) -> TypeGuard[Matrix]:
return t.ndim == 2


def is_psd_tensor(t: Tensor) -> TypeGuard[PSDTensor]:
half_dim = t.ndim // 2
return t.ndim % 2 == 0 and t.shape[:half_dim] == t.shape[: half_dim - 1 : -1]
# We do not check that t is PSD as it is expensive, but this must be checked in the tests of
# every function that uses this TypeGuard by using `assert_is_psd_tensor`.


def is_psd_matrix(t: Tensor) -> TypeGuard[PSDMatrix]:
return t.ndim == 2 and t.shape[0] == t.shape[1]
# We do not check that t is PSD as it is expensive, but this must be checked in the tests of
# every function that uses this TypeGuard, by using `assert_is_psd_matrix`.
9 changes: 0 additions & 9 deletions src/torchjd/_linalg/gramian.py

This file was deleted.

6 changes: 0 additions & 6 deletions src/torchjd/_linalg/matrix.py

This file was deleted.

14 changes: 6 additions & 8 deletions src/torchjd/aggregation/_aggregator_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torch import Tensor, nn

from torchjd._linalg import Matrix, PSDMatrix, compute_gramian
from torchjd._linalg import Matrix, PSDMatrix, compute_gramian, is_matrix

from ._weighting_bases import Weighting

Expand All @@ -18,20 +18,19 @@ def __init__(self):

@staticmethod
def _check_is_matrix(matrix: Tensor) -> None:
if len(matrix.shape) != 2:
if not is_matrix(matrix):
raise ValueError(
"Parameter `matrix` should be a tensor of dimension 2. Found `matrix.shape = "
f"{matrix.shape}`."
)

@abstractmethod
def forward(self, matrix: Tensor) -> Tensor:
def forward(self, matrix: Matrix) -> Tensor:
"""Computes the aggregation from the input matrix."""

# Override to make type hints and documentation more specific
def __call__(self, matrix: Tensor) -> Tensor:
"""Computes the aggregation from the input matrix and applies all registered hooks."""

Aggregator._check_is_matrix(matrix)
return super().__call__(matrix)

def __repr__(self) -> str:
Expand All @@ -54,7 +53,7 @@ def __init__(self, weighting: Weighting[Matrix]):
self.weighting = weighting

@staticmethod
def combine(matrix: Tensor, weights: Tensor) -> Tensor:
def combine(matrix: Matrix, weights: Tensor) -> Tensor:
"""
Aggregates a matrix by making a linear combination of its rows, using the provided vector of
weights.
Expand All @@ -63,8 +62,7 @@ def combine(matrix: Tensor, weights: Tensor) -> Tensor:
vector = weights @ matrix
return vector

def forward(self, matrix: Tensor) -> Tensor:
self._check_is_matrix(matrix)
def forward(self, matrix: Matrix) -> Tensor:
weights = self.weighting(matrix)
vector = self.combine(matrix, weights)
return vector
Expand Down
3 changes: 2 additions & 1 deletion src/torchjd/aggregation/_cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import torch
from torch import Tensor

from torchjd._linalg import normalize

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.gramian import normalize
from ._utils.non_differentiable import raise_non_differentiable_error


Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import torch
from torch import Tensor

from torchjd._linalg import Matrix

from ._aggregator_bases import Aggregator
from ._sum import SumWeighting
from ._utils.non_differentiable import raise_non_differentiable_error
Expand Down Expand Up @@ -56,7 +58,7 @@ def __init__(self, pref_vector: Tensor | None = None):
# This prevents computing gradients that can be very wrong.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

def forward(self, matrix: Tensor) -> Tensor:
def forward(self, matrix: Matrix) -> Tensor:
weights = self.weighting(matrix)
units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0)
best_direction = torch.linalg.pinv(units) @ weights
Expand Down
3 changes: 1 addition & 2 deletions src/torchjd/aggregation/_dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

from torch import Tensor

from torchjd._linalg import PSDMatrix
from torchjd._linalg import PSDMatrix, normalize, regularize

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.dual_cone import project_weights
from ._utils.gramian import normalize, regularize
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import Weighting
Expand Down
13 changes: 5 additions & 8 deletions src/torchjd/aggregation/_flattening.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from math import prod

from torch import Tensor

from torchjd._linalg.matrix import PSDMatrix
from torchjd._linalg import PSDTensor
from torchjd.aggregation._weighting_bases import GeneralizedWeighting, Weighting
from torchjd.autogram._gramian_utils import reshape_gramian
from torchjd.autogram._gramian_utils import flatten


class Flattening(GeneralizedWeighting):
Expand All @@ -22,15 +20,14 @@ class Flattening(GeneralizedWeighting):
:param weighting: The weighting to apply to the Gramian matrix.
"""

def __init__(self, weighting: Weighting[PSDMatrix]):
def __init__(self, weighting: Weighting):
super().__init__()
self.weighting = weighting

def forward(self, generalized_gramian: Tensor) -> Tensor:
def forward(self, generalized_gramian: PSDTensor) -> Tensor:
    """
    Computes a tensor of weights from the generalized gramian.

    The generalized gramian is flattened into a square gramian matrix, the wrapped weighting
    produces a weight vector from it, and that vector is reshaped back to the leading
    dimensions of the generalized gramian.
    """
    k = generalized_gramian.ndim // 2
    shape = generalized_gramian.shape[:k]
    # Removed the dead `m = prod(shape)` assignment: `flatten` takes no size argument, so `m`
    # was computed and never used (and `prod` is no longer imported by this module).
    square_gramian = flatten(generalized_gramian)
    weights_vector = self.weighting(square_gramian)
    weights = weights_vector.reshape(shape)
    return weights
5 changes: 3 additions & 2 deletions src/torchjd/aggregation/_graddrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
from torch import Tensor

from torchjd._linalg import Matrix

from ._aggregator_bases import Aggregator
from ._utils.non_differentiable import raise_non_differentiable_error

Expand Down Expand Up @@ -38,8 +40,7 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None):
# This prevents computing gradients that can be very wrong.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

def forward(self, matrix: Tensor) -> Tensor:
self._check_is_matrix(matrix)
def forward(self, matrix: Matrix) -> Tensor:
self._check_matrix_has_enough_rows(matrix)

if matrix.shape[0] == 0 or matrix.shape[1] == 0:
Expand Down
4 changes: 3 additions & 1 deletion src/torchjd/aggregation/_pcgrad.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import cast

import torch
from torch import Tensor

Expand Down Expand Up @@ -32,7 +34,7 @@ def forward(self, gramian: PSDMatrix) -> Tensor:
device = gramian.device
dtype = gramian.dtype
cpu = torch.device("cpu")
gramian = gramian.to(device=cpu)
gramian = cast(PSDMatrix, gramian.to(device=cpu))

dimension = gramian.shape[0]
weights = torch.zeros(dimension, device=cpu, dtype=dtype)
Expand Down
1 change: 0 additions & 1 deletion src/torchjd/aggregation/_trimmed_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def __init__(self, trim_number: int):
self.trim_number = trim_number

def forward(self, matrix: Tensor) -> Tensor:
self._check_is_matrix(matrix)
self._check_matrix_has_enough_rows(matrix)

n_rows = matrix.shape[0]
Expand Down
3 changes: 1 addition & 2 deletions src/torchjd/aggregation/_upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import torch
from torch import Tensor

from torchjd._linalg import PSDMatrix
from torchjd._linalg import PSDMatrix, normalize, regularize

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.dual_cone import project_weights
from ._utils.gramian import normalize, regularize
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import Weighting
Expand Down
33 changes: 0 additions & 33 deletions src/torchjd/aggregation/_utils/gramian.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_utils/pref_vector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from torch import Tensor

from torchjd._linalg.matrix import Matrix
from torchjd._linalg import Matrix
from torchjd.aggregation._constant import ConstantWeighting
from torchjd.aggregation._weighting_bases import Weighting

Expand Down
Loading
Loading