From 07867c0fa80477408c157a89c9404063c680dadb Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 15 Jan 2026 10:39:09 +0100 Subject: [PATCH 1/4] factorise computing gramians --- src/torchjd/_utils/__init__.py | 3 +++ src/torchjd/_utils/compute_gramian.py | 15 +++++++++++++++ src/torchjd/aggregation/_aggregator_bases.py | 3 ++- src/torchjd/aggregation/_utils/gramian.py | 8 -------- src/torchjd/autogram/_gramian_computer.py | 7 ++----- tests/unit/aggregation/test_mgda.py | 2 +- tests/unit/aggregation/test_pcgrad.py | 2 +- 7 files changed, 24 insertions(+), 16 deletions(-) create mode 100644 src/torchjd/_utils/__init__.py create mode 100644 src/torchjd/_utils/compute_gramian.py diff --git a/src/torchjd/_utils/__init__.py b/src/torchjd/_utils/__init__.py new file mode 100644 index 00000000..5aba58ab --- /dev/null +++ b/src/torchjd/_utils/__init__.py @@ -0,0 +1,3 @@ +from .compute_gramian import compute_gramian + +__all__ = ["compute_gramian"] diff --git a/src/torchjd/_utils/compute_gramian.py b/src/torchjd/_utils/compute_gramian.py new file mode 100644 index 00000000..e2745595 --- /dev/null +++ b/src/torchjd/_utils/compute_gramian.py @@ -0,0 +1,15 @@ +import torch +from torch import Tensor + + +def compute_gramian(generalized_matrix: Tensor) -> Tensor: + """ + Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given + generalized matrix.
Specifically, this is equivalent to + + matrix = generalized_matrix.reshape([generalized_matrix.shape[0], -1]) + return matrix @ matrix.T + """ + dims = list(range(1, generalized_matrix.ndim)) + gramian = torch.tensordot(generalized_matrix, generalized_matrix, dims=(dims, dims)) + return gramian diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index c15e4502..853f720a 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -2,7 +2,8 @@ from torch import Tensor, nn -from ._utils.gramian import compute_gramian +from torchjd._utils import compute_gramian + from ._weighting_bases import Matrix, PSDMatrix, Weighting diff --git a/src/torchjd/aggregation/_utils/gramian.py b/src/torchjd/aggregation/_utils/gramian.py index dfe16987..58fc11fa 100644 --- a/src/torchjd/aggregation/_utils/gramian.py +++ b/src/torchjd/aggregation/_utils/gramian.py @@ -2,14 +2,6 @@ from torch import Tensor -def compute_gramian(matrix: Tensor) -> Tensor: - """ - Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix. - """ - - return matrix @ matrix.T - - def normalize(gramian: Tensor, eps: float) -> Tensor: """ Normalizes the gramian `G=AA^T` with respect to the Frobenius norm of `A`.
diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index 2bc62f21..92ba933d 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -4,6 +4,7 @@ from torch import Tensor from torch.utils._pytree import PyTree +from torchjd._utils import compute_gramian from torchjd.autogram._jacobian_computer import JacobianComputer @@ -29,10 +30,6 @@ class JacobianBasedGramianComputer(GramianComputer, ABC): def __init__(self, jacobian_computer): self.jacobian_computer = jacobian_computer - @staticmethod - def _to_gramian(jacobian: Tensor) -> Tensor: - return jacobian @ jacobian.T - class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer): """ @@ -71,7 +68,7 @@ def __call__( self.remaining_counter -= 1 if self.remaining_counter == 0: - gramian = self._to_gramian(self.summed_jacobian) + gramian = compute_gramian(self.summed_jacobian) del self.summed_jacobian return gramian else: diff --git a/tests/unit/aggregation/test_mgda.py b/tests/unit/aggregation/test_mgda.py index 41b07d93..a63e1724 100644 --- a/tests/unit/aggregation/test_mgda.py +++ b/tests/unit/aggregation/test_mgda.py @@ -3,9 +3,9 @@ from torch.testing import assert_close from utils.tensors import ones_, randn_ +from torchjd._utils.compute_gramian import compute_gramian from torchjd.aggregation import MGDA from torchjd.aggregation._mgda import MGDAWeighting -from torchjd.aggregation._utils.gramian import compute_gramian from ._asserts import ( assert_expected_structure, diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index e96253ec..e4a14ed3 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -3,10 +3,10 @@ from torch.testing import assert_close from utils.tensors import ones_, randn_ +from torchjd._utils.compute_gramian import compute_gramian from torchjd.aggregation import PCGrad from torchjd.aggregation._pcgrad 
import PCGradWeighting from torchjd.aggregation._upgrad import UPGradWeighting -from torchjd.aggregation._utils.gramian import compute_gramian from ._asserts import assert_expected_structure, assert_non_differentiable from ._inputs import scaled_matrices, typical_matrices From bede311be7865f6274a44bf6a8bbe6f93269ed8b Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 15 Jan 2026 10:47:44 +0100 Subject: [PATCH 2/4] Extract the Aggregator based logic of `jac_to_grad` --- src/torchjd/autojac/_jac_to_grad.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 8c85025f..2bc74c88 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -73,10 +73,18 @@ def jac_to_grad( if not retain_jac: _free_jacs(tensors_) + gradients = _jacobian_based(aggregator, jacobians, tensors_) + + accumulate_grads(tensors_, gradients) + + +def _jacobian_based( + aggregator: Aggregator, jacobians: list[Tensor], tensors: list[TensorWithJac] +) -> list[Tensor]: jacobian_matrix = _unite_jacobians(jacobians) gradient_vector = aggregator(jacobian_matrix) - gradients = _disunite_gradient(gradient_vector, jacobians, tensors_) - accumulate_grads(tensors_, gradients) + gradients = _disunite_gradient(gradient_vector, jacobians, tensors) + return gradients def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: From 9ee68f88fd9c5e845312abe4475d099f4aad4b03 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 15 Jan 2026 11:15:00 +0100 Subject: [PATCH 3/4] several things: - moves Matrix and PSDMatrix to compute_gramian (not best position probably, but should be in _utils) - Change return type of compute_gramian to PSDMatrix - Add compute_gramian_sum (note that the responsibility of casting to PSDMatrix is given to _utils now). - add _gramian_based version of jac_to_grad.
Note that we could put the tensordot(weights, jacobian, dims=1) in _utils as a weight_generalize_matrix method. --- src/torchjd/_utils/__init__.py | 4 ++-- src/torchjd/_utils/compute_gramian.py | 14 ++++++++++++-- src/torchjd/aggregation/_aggregator_bases.py | 4 +++- src/torchjd/aggregation/_aligned_mtl.py | 3 ++- src/torchjd/aggregation/_cagrad.py | 3 ++- src/torchjd/aggregation/_constant.py | 3 ++- src/torchjd/aggregation/_dualproj.py | 3 ++- src/torchjd/aggregation/_flattening.py | 3 ++- src/torchjd/aggregation/_imtl_g.py | 3 ++- src/torchjd/aggregation/_krum.py | 3 ++- src/torchjd/aggregation/_mean.py | 3 ++- src/torchjd/aggregation/_mgda.py | 3 ++- src/torchjd/aggregation/_nash_mtl.py | 3 ++- src/torchjd/aggregation/_pcgrad.py | 3 ++- src/torchjd/aggregation/_random.py | 3 ++- src/torchjd/aggregation/_sum.py | 3 ++- src/torchjd/aggregation/_upgrad.py | 3 ++- src/torchjd/aggregation/_utils/pref_vector.py | 3 ++- src/torchjd/aggregation/_weighting_bases.py | 4 +--- src/torchjd/autojac/_jac_to_grad.py | 19 +++++++++++++++++-- 20 files changed, 65 insertions(+), 25 deletions(-) diff --git a/src/torchjd/_utils/__init__.py b/src/torchjd/_utils/__init__.py index 5aba58ab..6eaf806c 100644 --- a/src/torchjd/_utils/__init__.py +++ b/src/torchjd/_utils/__init__.py @@ -1,3 +1,3 @@ -from .compute_gramian import compute_gramian +from .compute_gramian import Matrix, PSDMatrix, compute_gramian, compute_gramian_sum -__all__ = ["compute_gramian"] +__all__ = ["compute_gramian", "compute_gramian_sum", "Matrix", "PSDMatrix"] diff --git a/src/torchjd/_utils/compute_gramian.py b/src/torchjd/_utils/compute_gramian.py index e2745595..1d3fdd59 100644 --- a/src/torchjd/_utils/compute_gramian.py +++ b/src/torchjd/_utils/compute_gramian.py @@ -1,8 +1,13 @@ +from typing import Annotated, cast + import torch from torch import Tensor +Matrix = Annotated[Tensor, "ndim=2"] +PSDMatrix = Annotated[Matrix, "Positive semi-definite"] + -def compute_gramian(generalized_matrix: Tensor) -> Tensor: 
+def compute_gramian(generalized_matrix: Tensor) -> PSDMatrix: """ Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given generalized matrix. Specifically, this is equivalent to @@ -12,4 +17,9 @@ def compute_gramian(generalized_matrix: Tensor) -> Tensor: """ dims = list(range(1, generalized_matrix.ndim)) gramian = torch.tensordot(generalized_matrix, generalized_matrix, dims=(dims, dims)) - return gramian + return cast(PSDMatrix, gramian) + + +def compute_gramian_sum(generalized_matrices: list[Tensor]) -> PSDMatrix: + gramian = sum([compute_gramian(matrix) for matrix in generalized_matrices]) + return cast(PSDMatrix, gramian) diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 853f720a..44356383 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -4,7 +4,8 @@ from torchjd._utils import compute_gramian -from ._weighting_bases import Matrix, PSDMatrix, Weighting +from .._utils.compute_gramian import Matrix, PSDMatrix +from ._weighting_bases import Weighting class Aggregator(nn.Module, ABC): @@ -80,3 +81,4 @@ class GramianWeightedAggregator(WeightedAggregator): def __init__(self, weighting: Weighting[PSDMatrix]): super().__init__(weighting << compute_gramian) + self.psd_weighting = weighting diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index 3f0b7119..928513c8 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -28,10 +28,11 @@ import torch from torch import Tensor +from .._utils.compute_gramian import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class AlignedMTL(GramianWeightedAggregator): diff --git a/src/torchjd/aggregation/_cagrad.py
b/src/torchjd/aggregation/_cagrad.py index 77c76a9c..772825d5 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -1,7 +1,8 @@ from typing import cast +from .._utils.compute_gramian import PSDMatrix from ._utils.check_dependencies import check_dependencies_are_installed -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting check_dependencies_are_installed(["cvxpy", "clarabel"]) diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index 35345abb..746a0c60 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -1,8 +1,9 @@ from torch import Tensor +from .._utils.compute_gramian import Matrix from ._aggregator_bases import WeightedAggregator from ._utils.str import vector_to_str -from ._weighting_bases import Matrix, Weighting +from ._weighting_bases import Weighting class Constant(WeightedAggregator): diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 68a48776..ce03e20e 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -2,13 +2,14 @@ from torch import Tensor +from .._utils.compute_gramian import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._utils.dual_cone import project_weights from ._utils.gramian import normalize, regularize from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class DualProj(GramianWeightedAggregator): diff --git a/src/torchjd/aggregation/_flattening.py b/src/torchjd/aggregation/_flattening.py index 74da4a97..08559ca3 100644 --- a/src/torchjd/aggregation/_flattening.py +++ b/src/torchjd/aggregation/_flattening.py @@ -2,7 +2,8 @@ from torch import Tensor 
-from torchjd.aggregation._weighting_bases import GeneralizedWeighting, PSDMatrix, Weighting +from torchjd._utils.compute_gramian import PSDMatrix +from torchjd.aggregation._weighting_bases import GeneralizedWeighting, Weighting from torchjd.autogram._gramian_utils import reshape_gramian diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 52d86526..7c213524 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -1,9 +1,10 @@ import torch from torch import Tensor +from .._utils.compute_gramian import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class IMTLG(GramianWeightedAggregator): diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index cf14e577..c8bd5af9 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -2,8 +2,9 @@ from torch import Tensor from torch.nn import functional as F +from .._utils.compute_gramian import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class Krum(GramianWeightedAggregator): diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index 36600ace..988a4049 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -1,8 +1,9 @@ import torch from torch import Tensor +from .._utils.compute_gramian import Matrix from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import Matrix, Weighting +from ._weighting_bases import Weighting class Mean(WeightedAggregator): diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index a5d7b2af..e8b70ac8 100644 --- a/src/torchjd/aggregation/_mgda.py +++ 
b/src/torchjd/aggregation/_mgda.py @@ -1,8 +1,9 @@ import torch from torch import Tensor +from .._utils.compute_gramian import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class MGDA(GramianWeightedAggregator): diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 9d1393a5..83f93ae6 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -25,8 +25,9 @@ # mypy: ignore-errors +from .._utils.compute_gramian import Matrix from ._utils.check_dependencies import check_dependencies_are_installed -from ._weighting_bases import Matrix, Weighting +from ._weighting_bases import Weighting check_dependencies_are_installed(["cvxpy", "ecos"]) diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index d85a8748..627f988e 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -1,9 +1,10 @@ import torch from torch import Tensor +from .._utils.compute_gramian import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class PCGrad(GramianWeightedAggregator): diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index ba89a82c..4df46c59 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -2,8 +2,9 @@ from torch import Tensor from torch.nn import functional as F +from .._utils.compute_gramian import Matrix from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import Matrix, Weighting +from ._weighting_bases import Weighting class Random(WeightedAggregator): diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index 868257fd..a20af56e 
100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -1,8 +1,9 @@ import torch from torch import Tensor +from .._utils.compute_gramian import Matrix from ._aggregator_bases import WeightedAggregator -from ._weighting_bases import Matrix, Weighting +from ._weighting_bases import Weighting class Sum(WeightedAggregator): diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 750cc735..3d3beea1 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -3,13 +3,14 @@ import torch from torch import Tensor +from .._utils.compute_gramian import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._utils.dual_cone import project_weights from ._utils.gramian import normalize, regularize from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from ._weighting_bases import PSDMatrix, Weighting +from ._weighting_bases import Weighting class UPGrad(GramianWeightedAggregator): diff --git a/src/torchjd/aggregation/_utils/pref_vector.py b/src/torchjd/aggregation/_utils/pref_vector.py index 1f3efef0..a121e0f5 100644 --- a/src/torchjd/aggregation/_utils/pref_vector.py +++ b/src/torchjd/aggregation/_utils/pref_vector.py @@ -1,7 +1,8 @@ from torch import Tensor +from torchjd._utils.compute_gramian import Matrix from torchjd.aggregation._constant import ConstantWeighting -from torchjd.aggregation._weighting_bases import Matrix, Weighting +from torchjd.aggregation._weighting_bases import Weighting from .str import vector_to_str diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index 154c7a30..a78b4c55 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -2,15 +2,13 @@ from abc import ABC, abstractmethod from collections.abc 
import Callable -from typing import Annotated, Generic, TypeVar +from typing import Generic, TypeVar from torch import Tensor, nn _T = TypeVar("_T", contravariant=True) _FnInputT = TypeVar("_FnInputT") _FnOutputT = TypeVar("_FnOutputT") -Matrix = Annotated[Tensor, "ndim=2"] -PSDMatrix = Annotated[Matrix, "Positive semi-definite"] class Weighting(Generic[_T], nn.Module, ABC): diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 2bc74c88..105f6e0f 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -2,9 +2,12 @@ from typing import cast import torch -from torch import Tensor +from torch import Tensor, tensordot +from torchjd._utils import PSDMatrix, compute_gramian_sum from torchjd.aggregation import Aggregator +from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator +from torchjd.aggregation._weighting_bases import Weighting from ._accumulation import TensorWithJac, accumulate_grads @@ -73,11 +76,23 @@ def jac_to_grad( if not retain_jac: _free_jacs(tensors_) - gradients = _jacobian_based(aggregator, jacobians, tensors_) + if isinstance(aggregator, GramianWeightedAggregator): + gradients = _gramian_based(aggregator.psd_weighting, jacobians, tensors_) + else: + gradients = _jacobian_based(aggregator, jacobians, tensors_) accumulate_grads(tensors_, gradients) +def _gramian_based( + weighting: Weighting[PSDMatrix], jacobians: list[Tensor], tensors: list[TensorWithJac] +) -> list[Tensor]: + gramian = compute_gramian_sum(jacobians) + weights = weighting(gramian) + gradients = [tensordot(weights, jacobian, dims=1) for jacobian in jacobians] + return gradients + + def _jacobian_based( aggregator: Aggregator, jacobians: list[Tensor], tensors: list[TensorWithJac] ) -> list[Tensor]: From 9531e3c22de5106027c90e51bd0c36f887f24673 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 15 Jan 2026 11:34:48 +0100 Subject: [PATCH 4/4] Improve imports --- 
src/torchjd/aggregation/_aggregator_bases.py | 3 +-- src/torchjd/aggregation/_aligned_mtl.py | 3 ++- src/torchjd/aggregation/_cagrad.py | 3 ++- src/torchjd/aggregation/_constant.py | 3 ++- src/torchjd/aggregation/_dualproj.py | 3 ++- src/torchjd/aggregation/_imtl_g.py | 3 ++- src/torchjd/aggregation/_krum.py | 3 ++- src/torchjd/aggregation/_mean.py | 3 ++- src/torchjd/aggregation/_mgda.py | 3 ++- src/torchjd/aggregation/_nash_mtl.py | 3 ++- src/torchjd/aggregation/_pcgrad.py | 3 ++- src/torchjd/aggregation/_random.py | 3 ++- src/torchjd/aggregation/_sum.py | 3 ++- src/torchjd/aggregation/_upgrad.py | 3 ++- 14 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 44356383..95b37b70 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -2,9 +2,8 @@ from torch import Tensor, nn -from torchjd._utils import compute_gramian +from torchjd._utils import Matrix, PSDMatrix, compute_gramian -from .._utils.compute_gramian import Matrix, PSDMatrix from ._weighting_bases import Weighting diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index 928513c8..eb74295e 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -28,7 +28,8 @@ import torch from torch import Tensor -from .._utils.compute_gramian import PSDMatrix +from torchjd._utils.compute_gramian import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index 772825d5..ed114e36 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -1,6 +1,7 @@ from typing import cast -from .._utils.compute_gramian import PSDMatrix +from 
torchjd._utils.compute_gramian import PSDMatrix + from ._utils.check_dependencies import check_dependencies_are_installed from ._weighting_bases import Weighting diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index 746a0c60..00d3d97e 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -1,6 +1,7 @@ from torch import Tensor -from .._utils.compute_gramian import Matrix +from torchjd._utils.compute_gramian import Matrix + from ._aggregator_bases import WeightedAggregator from ._utils.str import vector_to_str from ._weighting_bases import Weighting diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index ce03e20e..742a32cc 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -2,7 +2,8 @@ from torch import Tensor -from .._utils.compute_gramian import PSDMatrix +from torchjd._utils.compute_gramian import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._utils.dual_cone import project_weights diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 7c213524..9d2c05fa 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -1,7 +1,8 @@ import torch from torch import Tensor -from .._utils.compute_gramian import PSDMatrix +from torchjd._utils.compute_gramian import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error from ._weighting_bases import Weighting diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index c8bd5af9..dc119073 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -2,7 +2,8 @@ from torch import Tensor from torch.nn import functional as F -from .._utils.compute_gramian import PSDMatrix +from 
torchjd._utils.compute_gramian import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._weighting_bases import Weighting diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index 988a4049..50919909 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -1,7 +1,8 @@ import torch from torch import Tensor -from .._utils.compute_gramian import Matrix +from torchjd._utils.compute_gramian import Matrix + from ._aggregator_bases import WeightedAggregator from ._weighting_bases import Weighting diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index e8b70ac8..9b8c7e82 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -1,7 +1,8 @@ import torch from torch import Tensor -from .._utils.compute_gramian import PSDMatrix +from torchjd._utils.compute_gramian import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._weighting_bases import Weighting diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 83f93ae6..c02b9c54 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -25,7 +25,8 @@ # mypy: ignore-errors -from .._utils.compute_gramian import Matrix +from torchjd._utils.compute_gramian import Matrix + from ._utils.check_dependencies import check_dependencies_are_installed from ._weighting_bases import Weighting diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index 627f988e..d9338ad0 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -1,7 +1,8 @@ import torch from torch import Tensor -from .._utils.compute_gramian import PSDMatrix +from torchjd._utils.compute_gramian import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error from ._weighting_bases import 
Weighting diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index 4df46c59..de3fc235 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -2,7 +2,8 @@ from torch import Tensor from torch.nn import functional as F -from .._utils.compute_gramian import Matrix +from torchjd._utils.compute_gramian import Matrix + from ._aggregator_bases import WeightedAggregator from ._weighting_bases import Weighting diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index a20af56e..0efe354f 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -1,7 +1,8 @@ import torch from torch import Tensor -from .._utils.compute_gramian import Matrix +from torchjd._utils.compute_gramian import Matrix + from ._aggregator_bases import WeightedAggregator from ._weighting_bases import Weighting diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 3d3beea1..aaa3621d 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -3,7 +3,8 @@ import torch from torch import Tensor -from .._utils.compute_gramian import PSDMatrix +from torchjd._utils.compute_gramian import PSDMatrix + from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._utils.dual_cone import project_weights