Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user.

## [Unreleased]

### Added

- Added a `scale_mode` parameter to `AlignedMTL` and `AlignedMTLWeighting`, making it possible
  to choose between `"min"`, `"median"`, and `"rmse"` scaling.

### Changed

- **BREAKING**: Removed from `backward` and `mtl_backward` the responsibility to aggregate the
Expand Down
50 changes: 42 additions & 8 deletions src/torchjd/aggregation/_aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
# SOFTWARE.


from typing import Literal

import torch
from torch import Tensor

Expand All @@ -44,18 +46,29 @@ class AlignedMTL(GramianWeightedAggregator):

:param pref_vector: The preference vector to use. If not provided, defaults to
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
:param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses
the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"``
uses the mean eigenvalue (as in the original implementation).

.. note::
This implementation was adapted from the `official implementation
<https://github.com/SamsungLabs/MTL/tree/master/code/optim/aligned>`_.
"""

def __init__(
    self,
    pref_vector: Tensor | None = None,
    scale_mode: Literal["min", "median", "rmse"] = "min",
):
    """Initialize the aggregator.

    :param pref_vector: Preference vector forwarded to ``AlignedMTLWeighting``; ``None``
        selects that class's default (uniform) weighting.
    :param scale_mode: Eigenvalue statistic used to scale the balance transformation;
        one of ``"min"``, ``"median"`` or ``"rmse"``.
    """
    # Stored only so that __repr__ can reproduce the constructor arguments; the
    # actual computation is delegated to the wrapped weighting.
    self._pref_vector = pref_vector
    self._scale_mode = scale_mode
    super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode))

def __repr__(self) -> str:
return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"
return (
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, "
f"scale_mode={repr(self._scale_mode)})"
)

def __str__(self) -> str:
    """Return a short human-readable name, e.g. ``AlignedMTL([1., 2., 3.])``."""
    suffix = pref_vector_to_str_suffix(self._pref_vector)
    return "AlignedMTL" + suffix
Expand All @@ -68,22 +81,32 @@ class AlignedMTLWeighting(Weighting[PSDMatrix]):

:param pref_vector: The preference vector to use. If not provided, defaults to
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
:param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses
the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"``
uses the mean eigenvalue (as in the original implementation).
"""

def __init__(
    self,
    pref_vector: Tensor | None = None,
    scale_mode: Literal["min", "median", "rmse"] = "min",
):
    """Initialize the weighting.

    :param pref_vector: Preference vector used to weight the tasks; ``None`` falls back
        to a uniform (mean) weighting.
    :param scale_mode: Eigenvalue statistic used to scale the balance transformation;
        one of ``"min"``, ``"median"`` or ``"rmse"``.
    """
    super().__init__()
    self._pref_vector = pref_vector
    self._scale_mode = scale_mode
    # Project helper: turns the optional preference vector into a Weighting object.
    self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())

def forward(self, gramian: PSDMatrix) -> Tensor:
    """Compute the AlignedMTL task weights from the Gramian of the task gradients."""
    base_weights = self.weighting(gramian)
    balance = self._compute_balance_transformation(gramian, self._scale_mode)
    return balance @ base_weights

@staticmethod
def _compute_balance_transformation(M: Tensor) -> Tensor:
def _compute_balance_transformation(
M: Tensor, scale_mode: Literal["min", "median", "rmse"] = "min"
) -> Tensor:
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
rank = sum(lambda_ > tol)
Expand All @@ -96,6 +119,17 @@ def _compute_balance_transformation(M: Tensor) -> Tensor:
lambda_, V = lambda_[order][:rank], V[:, order][:, :rank]

sigma_inv = torch.diag(1 / lambda_.sqrt())
lambda_R = lambda_[-1]
B = lambda_R.sqrt() * V @ sigma_inv @ V.T

if scale_mode == "min":
scale = lambda_[-1]
elif scale_mode == "median":
scale = torch.median(lambda_)
elif scale_mode == "rmse":
scale = lambda_.mean()
Comment on lines +125 to +128
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add code coverage for this? I think it should be enough to just parametrize test_expected_structure and test_permutation_invariant (from tests/unit/aggregation/test_aligned_mtl.py) with AlignedMTL(scale_mode="median") and AlignedMTL(scale_mode="rmse") (on top of the already existing AlignedMTL()).

else:
raise ValueError(
f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'."
)
Comment on lines +129 to +132
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also need coverage for this. A simple test should be enough:

from pytest import raises
from utils.tensors import ones_

def test_invalid_scale_mode():
    aggregator = AlignedMTL(scale_mode="test")
    matrix = ones_(3, 4)
    with raises(ValueError):
        aggregator(matrix)


B = scale.sqrt() * V @ sigma_inv @ V.T
return B
23 changes: 18 additions & 5 deletions tests/unit/aggregation/test_aligned_mtl.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import torch
from pytest import mark
from pytest import mark, raises
from torch import Tensor
from utils.tensors import ones_

from torchjd.aggregation import AlignedMTL

from ._asserts import assert_expected_structure, assert_permutation_invariant
from ._inputs import scaled_matrices, typical_matrices

# One aggregator per supported scale mode, so the structural tests below cover all three.
aggregators = [
    AlignedMTL(),
    AlignedMTL(scale_mode="median"),
    AlignedMTL(scale_mode="rmse"),
]
# Cartesian product: every aggregator is paired with every input matrix.
scaled_pairs = [(aggregator, matrix) for aggregator in aggregators for matrix in scaled_matrices]
typical_pairs = [(aggregator, matrix) for aggregator in aggregators for matrix in typical_matrices]


@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
Expand All @@ -23,9 +29,16 @@ def test_permutation_invariant(aggregator: AlignedMTL, matrix: Tensor):

def test_representations():
    """Check that repr includes both pref_vector and scale_mode, and str stays short."""
    A = AlignedMTL(pref_vector=None)
    assert repr(A) == "AlignedMTL(pref_vector=None, scale_mode='min')"
    assert str(A) == "AlignedMTL"

    A = AlignedMTL(pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"))
    assert repr(A) == "AlignedMTL(pref_vector=tensor([1., 2., 3.]), scale_mode='min')"
    assert str(A) == "AlignedMTL([1., 2., 3.])"


def test_invalid_scale_mode():
    """An unsupported scale_mode should surface as a ValueError at aggregation time."""
    matrix = ones_(3, 4)
    aggregator = AlignedMTL(scale_mode="test")  # type: ignore[arg-type]
    with raises(ValueError, match=r"Invalid scale_mode=.*Expected"):
        aggregator(matrix)