Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user.

## [Unreleased]

### Added

- Added a `scale_mode` parameter to `AlignedMTL` and `AlignedMTLWeighting`, making it possible
  to choose between `"min"`, `"median"`, and `"rmse"` scaling.

### Changed

- **BREAKING**: Removed from `backward` and `mtl_backward` the responsibility to aggregate the
Expand Down
50 changes: 42 additions & 8 deletions src/torchjd/aggregation/_aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
# SOFTWARE.


from typing import Literal

import torch
from torch import Tensor

Expand All @@ -44,18 +46,29 @@ class AlignedMTL(GramianWeightedAggregator):

:param pref_vector: The preference vector to use. If not provided, defaults to
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
:param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses
the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"``
uses the mean eigenvalue (as in the original implementation).

.. note::
This implementation was adapted from the `official implementation
<https://github.com/SamsungLabs/MTL/tree/master/code/optim/aligned>`_.
"""

def __init__(
    self,
    pref_vector: Tensor | None = None,
    scale_mode: Literal["min", "median", "rmse"] = "min",
):
    # Keep both constructor arguments on the instance so that __repr__ can
    # report them faithfully; the actual work is delegated to the weighting.
    self._pref_vector = pref_vector
    self._scale_mode = scale_mode
    super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode))

def __repr__(self) -> str:
    # Mirror the constructor signature (pref_vector and scale_mode) so the
    # representation stays unambiguous now that there are two parameters.
    return (
        f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, "
        f"scale_mode={repr(self._scale_mode)})"
    )

def __str__(self) -> str:
    # Short display name, optionally followed by the preference vector,
    # e.g. "AlignedMTL" or "AlignedMTL([1., 2., 3.])".
    suffix = pref_vector_to_str_suffix(self._pref_vector)
    return "AlignedMTL" + suffix
Expand All @@ -68,22 +81,32 @@ class AlignedMTLWeighting(Weighting[PSDMatrix]):

:param pref_vector: The preference vector to use. If not provided, defaults to
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
:param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses
the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"``
uses the mean eigenvalue (as in the original implementation).
"""

def __init__(
    self,
    pref_vector: Tensor | None = None,
    scale_mode: Literal["min", "median", "rmse"] = "min",
):
    super().__init__()
    self._pref_vector = pref_vector
    self._scale_mode = scale_mode
    # When no preference vector is supplied, fall back to uniform (mean)
    # weighting over the tasks.
    self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())

def forward(self, gramian: PSDMatrix) -> Tensor:
    # Base weights derived from the preference vector (or mean weighting).
    w = self.weighting(gramian)
    # Balance transformation built from the Gramian's eigendecomposition,
    # scaled according to the configured scale_mode.
    B = self._compute_balance_transformation(gramian, self._scale_mode)
    alpha = B @ w
    return alpha

@staticmethod
def _compute_balance_transformation(M: Tensor) -> Tensor:
def _compute_balance_transformation(
M: Tensor, scale_mode: Literal["min", "median", "rmse"] = "min"
) -> Tensor:
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
rank = sum(lambda_ > tol)
Expand All @@ -96,6 +119,17 @@ def _compute_balance_transformation(M: Tensor) -> Tensor:
lambda_, V = lambda_[order][:rank], V[:, order][:, :rank]

sigma_inv = torch.diag(1 / lambda_.sqrt())
lambda_R = lambda_[-1]
B = lambda_R.sqrt() * V @ sigma_inv @ V.T

if scale_mode == "min":
scale = lambda_[-1]
elif scale_mode == "median":
scale = torch.median(lambda_)
elif scale_mode == "rmse":
scale = lambda_.mean()
else:
raise ValueError(
f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'."
)

B = scale.sqrt() * V @ sigma_inv @ V.T
return B
22 changes: 18 additions & 4 deletions tests/unit/aggregation/test_aligned_mtl.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import torch
from pytest import mark
from pytest import mark, raises
from torch import Tensor
from utils.tensors import ones_

from torchjd.aggregation import AlignedMTL

from ._asserts import assert_expected_structure, assert_permutation_invariant
from ._inputs import scaled_matrices, typical_matrices

scaled_pairs = [(AlignedMTL(), matrix) for matrix in scaled_matrices]
aggregators = [
AlignedMTL(),
AlignedMTL(scale_mode="median"),
AlignedMTL(scale_mode="rmse"),
]
scaled_pairs = [(aggregator, matrix) for aggregator in aggregators for matrix in scaled_matrices]
# NOTE: test_permutation_invariant appears to fail on GPU when scale_mode is "median" or
# "rmse", so only the default aggregator is used for the typical (permutation) pairs below.
typical_pairs = [(AlignedMTL(), matrix) for matrix in typical_matrices]


Expand All @@ -23,9 +30,16 @@ def test_permutation_invariant(aggregator: AlignedMTL, matrix: Tensor):

def test_representations():
A = AlignedMTL(pref_vector=None)
assert repr(A) == "AlignedMTL(pref_vector=None)"
assert repr(A) == "AlignedMTL(pref_vector=None, scale_mode='min')"
assert str(A) == "AlignedMTL"

A = AlignedMTL(pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"))
assert repr(A) == "AlignedMTL(pref_vector=tensor([1., 2., 3.]))"
assert repr(A) == "AlignedMTL(pref_vector=tensor([1., 2., 3.]), scale_mode='min')"
assert str(A) == "AlignedMTL([1., 2., 3.])"


def test_invalid_scale_mode():
    # An unrecognized scale_mode is only rejected when the aggregator is
    # actually applied to a matrix, not at construction time.
    bad_aggregator = AlignedMTL(scale_mode="test")  # type: ignore[arg-type]
    with raises(ValueError, match=r"Invalid scale_mode=.*Expected"):
        bad_aggregator(ones_(3, 4))
Loading