diff --git a/CHANGELOG.md b/CHANGELOG.md index d520e901..aa14a0a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user. ## [Unreleased] +### Added + +- Added a `scale_mode` parameter to `AlignedMTL` and `AlignedMTLWeighting`, allowing the user to choose + between `"min"`, `"median"`, and `"rmse"` scaling. + ### Changed - **BREAKING**: Removed from `backward` and `mtl_backward` the responsibility to aggregate the diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index 8e5dd0cc..40d3d9b1 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -25,6 +25,8 @@ # SOFTWARE. +from typing import Literal + import torch from torch import Tensor @@ -44,18 +46,29 @@ class AlignedMTL(GramianWeightedAggregator): :param pref_vector: The preference vector to use. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. + :param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses + the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"`` + uses the mean eigenvalue (as in the original implementation). .. note:: This implementation was adapted from the `official implementation `_. 
""" - def __init__(self, pref_vector: Tensor | None = None): + def __init__( + self, + pref_vector: Tensor | None = None, + scale_mode: Literal["min", "median", "rmse"] = "min", + ): self._pref_vector = pref_vector - super().__init__(AlignedMTLWeighting(pref_vector)) + self._scale_mode = scale_mode + super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode)) def __repr__(self) -> str: - return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})" + return ( + f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, " + f"scale_mode={repr(self._scale_mode)})" + ) def __str__(self) -> str: return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}" @@ -68,22 +81,32 @@ class AlignedMTLWeighting(Weighting[PSDMatrix]): :param pref_vector: The preference vector to use. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. + :param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses + the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"`` + uses the mean eigenvalue (as in the original implementation). 
""" - def __init__(self, pref_vector: Tensor | None = None): + def __init__( + self, + pref_vector: Tensor | None = None, + scale_mode: Literal["min", "median", "rmse"] = "min", + ): super().__init__() self._pref_vector = pref_vector + self._scale_mode = scale_mode self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) def forward(self, gramian: PSDMatrix) -> Tensor: w = self.weighting(gramian) - B = self._compute_balance_transformation(gramian) + B = self._compute_balance_transformation(gramian, self._scale_mode) alpha = B @ w return alpha @staticmethod - def _compute_balance_transformation(M: Tensor) -> Tensor: + def _compute_balance_transformation( + M: Tensor, scale_mode: Literal["min", "median", "rmse"] = "min" + ) -> Tensor: lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig tol = torch.max(lambda_) * len(M) * torch.finfo().eps rank = sum(lambda_ > tol) @@ -96,6 +119,17 @@ def _compute_balance_transformation(M: Tensor) -> Tensor: lambda_, V = lambda_[order][:rank], V[:, order][:, :rank] sigma_inv = torch.diag(1 / lambda_.sqrt()) - lambda_R = lambda_[-1] - B = lambda_R.sqrt() * V @ sigma_inv @ V.T + + if scale_mode == "min": + scale = lambda_[-1] + elif scale_mode == "median": + scale = torch.median(lambda_) + elif scale_mode == "rmse": + scale = lambda_.mean() + else: + raise ValueError( + f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'." 
+ ) + + B = scale.sqrt() * V @ sigma_inv @ V.T return B diff --git a/tests/unit/aggregation/test_aligned_mtl.py b/tests/unit/aggregation/test_aligned_mtl.py index 0be5f828..3580f80f 100644 --- a/tests/unit/aggregation/test_aligned_mtl.py +++ b/tests/unit/aggregation/test_aligned_mtl.py @@ -1,14 +1,20 @@ import torch -from pytest import mark +from pytest import mark, raises from torch import Tensor +from utils.tensors import ones_ from torchjd.aggregation import AlignedMTL from ._asserts import assert_expected_structure, assert_permutation_invariant from ._inputs import scaled_matrices, typical_matrices -scaled_pairs = [(AlignedMTL(), matrix) for matrix in scaled_matrices] -typical_pairs = [(AlignedMTL(), matrix) for matrix in typical_matrices] +aggregators = [ + AlignedMTL(), + AlignedMTL(scale_mode="median"), + AlignedMTL(scale_mode="rmse"), +] +scaled_pairs = [(aggregator, matrix) for aggregator in aggregators for matrix in scaled_matrices] +typical_pairs = [(aggregator, matrix) for aggregator in aggregators for matrix in typical_matrices] @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) @@ -23,9 +29,16 @@ def test_permutation_invariant(aggregator: AlignedMTL, matrix: Tensor): def test_representations(): A = AlignedMTL(pref_vector=None) - assert repr(A) == "AlignedMTL(pref_vector=None)" + assert repr(A) == "AlignedMTL(pref_vector=None, scale_mode='min')" assert str(A) == "AlignedMTL" A = AlignedMTL(pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu")) - assert repr(A) == "AlignedMTL(pref_vector=tensor([1., 2., 3.]))" + assert repr(A) == "AlignedMTL(pref_vector=tensor([1., 2., 3.]), scale_mode='min')" assert str(A) == "AlignedMTL([1., 2., 3.])" + + +def test_invalid_scale_mode(): + aggregator = AlignedMTL(scale_mode="test") # type: ignore[arg-type] + matrix = ones_(3, 4) + with raises(ValueError, match=r"Invalid scale_mode=.*Expected"): + aggregator(matrix)