From c02ebfb9fb78555eaba1e308d8083041b474e848 Mon Sep 17 00:00:00 2001 From: rkhosrowshahi Date: Fri, 23 Jan 2026 10:12:19 -0500 Subject: [PATCH 1/4] feat: Add scale_mode options to AlignedMTL Expose median and rmse scaling modes for the balance transformation to match original behavior while keeping min as default. --- src/torchjd/aggregation/_aligned_mtl.py | 50 ++++++++++++++++++---- tests/unit/aggregation/test_aligned_mtl.py | 4 +- 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index 8e5dd0cc..40d3d9b1 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -25,6 +25,8 @@ # SOFTWARE. +from typing import Literal + import torch from torch import Tensor @@ -44,18 +46,29 @@ class AlignedMTL(GramianWeightedAggregator): :param pref_vector: The preference vector to use. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. + :param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses + the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"`` + uses the mean eigenvalue (as in the original implementation). .. note:: This implementation was adapted from the `official implementation `_. """ - def __init__(self, pref_vector: Tensor | None = None): + def __init__( + self, + pref_vector: Tensor | None = None, + scale_mode: Literal["min", "median", "rmse"] = "min", + ): self._pref_vector = pref_vector - super().__init__(AlignedMTLWeighting(pref_vector)) + self._scale_mode = scale_mode + super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode)) def __repr__(self) -> str: - return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})" + return ( + f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, " + f"scale_mode={repr(self._scale_mode)})" + ) def __str__(self) -> str: return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}" @@ -68,22 +81,32 @@ class AlignedMTLWeighting(Weighting[PSDMatrix]): :param pref_vector: The preference vector to use. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. + :param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses + the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"`` + uses the mean eigenvalue (as in the original implementation). """ - def __init__(self, pref_vector: Tensor | None = None): + def __init__( + self, + pref_vector: Tensor | None = None, + scale_mode: Literal["min", "median", "rmse"] = "min", + ): super().__init__() self._pref_vector = pref_vector + self._scale_mode = scale_mode self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) def forward(self, gramian: PSDMatrix) -> Tensor: w = self.weighting(gramian) - B = self._compute_balance_transformation(gramian) + B = self._compute_balance_transformation(gramian, self._scale_mode) alpha = B @ w return alpha @staticmethod - def _compute_balance_transformation(M: Tensor) -> Tensor: + def _compute_balance_transformation( + M: Tensor, scale_mode: Literal["min", "median", "rmse"] = "min" + ) -> Tensor: lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig tol = torch.max(lambda_) * len(M) * torch.finfo().eps rank = sum(lambda_ > tol) @@ -96,6 +119,17 @@ def _compute_balance_transformation(M: Tensor) -> Tensor: lambda_, V = lambda_[order][:rank], V[:, order][:, :rank] sigma_inv = torch.diag(1 / lambda_.sqrt()) - lambda_R = lambda_[-1] - B = lambda_R.sqrt() * V @ sigma_inv @ V.T + + if scale_mode == "min": + scale = lambda_[-1] + elif scale_mode == "median": + scale = torch.median(lambda_) + elif scale_mode == "rmse": + scale = lambda_.mean() + else: + raise ValueError( + f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'." + ) + + B = scale.sqrt() * V @ sigma_inv @ V.T return B diff --git a/tests/unit/aggregation/test_aligned_mtl.py b/tests/unit/aggregation/test_aligned_mtl.py index 0be5f828..3f87fd8c 100644 --- a/tests/unit/aggregation/test_aligned_mtl.py +++ b/tests/unit/aggregation/test_aligned_mtl.py @@ -23,9 +23,9 @@ def test_permutation_invariant(aggregator: AlignedMTL, matrix: Tensor): def test_representations(): A = AlignedMTL(pref_vector=None) - assert repr(A) == "AlignedMTL(pref_vector=None)" + assert repr(A) == "AlignedMTL(pref_vector=None, scale_mode='min')" assert str(A) == "AlignedMTL" A = AlignedMTL(pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu")) - assert repr(A) == "AlignedMTL(pref_vector=tensor([1., 2., 3.]))" + assert repr(A) == "AlignedMTL(pref_vector=tensor([1., 2., 3.]), scale_mode='min')" assert str(A) == "AlignedMTL([1., 2., 3.])" From 88cb6c9bda1774b76f225fa8d3757d4c1702f167 Mon Sep 17 00:00:00 2001 From: rkhosrowshahi Date: Fri, 23 Jan 2026 11:32:28 -0500 Subject: [PATCH 2/4] test: Cover AlignedMTL scale_mode variants Add coverage for median/rmse modes and invalid scale_mode handling. --- tests/unit/aggregation/test_aligned_mtl.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/unit/aggregation/test_aligned_mtl.py b/tests/unit/aggregation/test_aligned_mtl.py index 3f87fd8c..f22bc110 100644 --- a/tests/unit/aggregation/test_aligned_mtl.py +++ b/tests/unit/aggregation/test_aligned_mtl.py @@ -1,14 +1,21 @@ import torch -from pytest import mark +from pytest import mark, raises from torch import Tensor from torchjd.aggregation import AlignedMTL +from utils.tensors import ones_ + from ._asserts import assert_expected_structure, assert_permutation_invariant from ._inputs import scaled_matrices, typical_matrices -scaled_pairs = [(AlignedMTL(), matrix) for matrix in scaled_matrices] -typical_pairs = [(AlignedMTL(), matrix) for matrix in typical_matrices] +aggregators = [ + AlignedMTL(), + AlignedMTL(scale_mode="median"), + AlignedMTL(scale_mode="rmse"), +] +scaled_pairs = [(aggregator, matrix) for aggregator in aggregators for matrix in scaled_matrices] +typical_pairs = [(aggregator, matrix) for aggregator in aggregators for matrix in typical_matrices] @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) @@ -29,3 +36,10 @@ def test_representations(): A = AlignedMTL(pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu")) assert repr(A) == "AlignedMTL(pref_vector=tensor([1., 2., 3.]), scale_mode='min')" assert str(A) == "AlignedMTL([1., 2., 3.])" + + +def test_invalid_scale_mode(): + aggregator = AlignedMTL(scale_mode="test") # type: ignore[arg-type] + matrix = ones_(3, 4) + with raises(ValueError, match=r"Invalid scale_mode=.*Expected"): + aggregator(matrix) From 355909031fd54cbf7bd8660a5423b147706e2348 Mon Sep 17 00:00:00 2001 From: rkhosrowshahi Date: Fri, 23 Jan 2026 11:33:27 -0500 Subject: [PATCH 3/4] docs: Add changelog entry for AlignedMTL scale_mode Document the new scale_mode parameter for AlignedMTL and AlignedMTLWeighting. --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d520e901..aa14a0a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user. ## [Unreleased] +### Added + +- Added a `scale_mode` parameter to `AlignedMTL` and `AlignedMTLWeighting`, allowing to choose + between `"min"`, `"median"`, and `"rmse"` scaling. + ### Changed - **BREAKING**: Removed from `backward` and `mtl_backward` the responsibility to aggregate the From 65674b53fd93ae69d684216d55053c55976efcfc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 23 Jan 2026 16:33:48 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unit/aggregation/test_aligned_mtl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit/aggregation/test_aligned_mtl.py b/tests/unit/aggregation/test_aligned_mtl.py index f22bc110..3580f80f 100644 --- a/tests/unit/aggregation/test_aligned_mtl.py +++ b/tests/unit/aggregation/test_aligned_mtl.py @@ -1,11 +1,10 @@ import torch from pytest import mark, raises from torch import Tensor +from utils.tensors import ones_ from torchjd.aggregation import AlignedMTL -from utils.tensors import ones_ - from ._asserts import assert_expected_structure, assert_permutation_invariant from ._inputs import scaled_matrices, typical_matrices