-
Notifications
You must be signed in to change notification settings - Fork 15
Add scale_mode options to AlignedMTL #526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c02ebfb
88cb6c9
3559090
65674b5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,8 @@ | |
| # SOFTWARE. | ||
|
|
||
|
|
||
| from typing import Literal | ||
|
|
||
| import torch | ||
| from torch import Tensor | ||
|
|
||
|
|
@@ -44,18 +46,29 @@ class AlignedMTL(GramianWeightedAggregator): | |
|
|
||
| :param pref_vector: The preference vector to use. If not provided, defaults to | ||
| :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. | ||
| :param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses | ||
| the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"`` | ||
| uses the mean eigenvalue (as in the original implementation). | ||
|
|
||
| .. note:: | ||
| This implementation was adapted from the `official implementation | ||
| <https://github.com/SamsungLabs/MTL/tree/master/code/optim/aligned>`_. | ||
| """ | ||
|
|
||
| def __init__(self, pref_vector: Tensor | None = None): | ||
| def __init__( | ||
| self, | ||
| pref_vector: Tensor | None = None, | ||
| scale_mode: Literal["min", "median", "rmse"] = "min", | ||
| ): | ||
| self._pref_vector = pref_vector | ||
| super().__init__(AlignedMTLWeighting(pref_vector)) | ||
| self._scale_mode = scale_mode | ||
| super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode)) | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})" | ||
| return ( | ||
| f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, " | ||
| f"scale_mode={repr(self._scale_mode)})" | ||
| ) | ||
|
|
||
| def __str__(self) -> str: | ||
| return f"AlignedMTL{pref_vector_to_str_suffix(self._pref_vector)}" | ||
|
|
@@ -68,22 +81,32 @@ class AlignedMTLWeighting(Weighting[PSDMatrix]): | |
|
|
||
| :param pref_vector: The preference vector to use. If not provided, defaults to | ||
| :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. | ||
| :param scale_mode: The scaling mode used to build the balance transformation. ``"min"`` uses | ||
| the smallest eigenvalue (default), ``"median"`` uses the median eigenvalue, and ``"rmse"`` | ||
| uses the mean eigenvalue (as in the original implementation). | ||
| """ | ||
|
|
||
| def __init__(self, pref_vector: Tensor | None = None): | ||
| def __init__( | ||
| self, | ||
| pref_vector: Tensor | None = None, | ||
| scale_mode: Literal["min", "median", "rmse"] = "min", | ||
| ): | ||
| super().__init__() | ||
| self._pref_vector = pref_vector | ||
| self._scale_mode = scale_mode | ||
| self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) | ||
|
|
||
| def forward(self, gramian: PSDMatrix) -> Tensor: | ||
| w = self.weighting(gramian) | ||
| B = self._compute_balance_transformation(gramian) | ||
| B = self._compute_balance_transformation(gramian, self._scale_mode) | ||
| alpha = B @ w | ||
|
|
||
| return alpha | ||
|
|
||
| @staticmethod | ||
| def _compute_balance_transformation(M: Tensor) -> Tensor: | ||
| def _compute_balance_transformation( | ||
| M: Tensor, scale_mode: Literal["min", "median", "rmse"] = "min" | ||
| ) -> Tensor: | ||
| lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig | ||
| tol = torch.max(lambda_) * len(M) * torch.finfo().eps | ||
| rank = sum(lambda_ > tol) | ||
|
|
@@ -96,6 +119,17 @@ def _compute_balance_transformation(M: Tensor) -> Tensor: | |
| lambda_, V = lambda_[order][:rank], V[:, order][:, :rank] | ||
|
|
||
| sigma_inv = torch.diag(1 / lambda_.sqrt()) | ||
| lambda_R = lambda_[-1] | ||
| B = lambda_R.sqrt() * V @ sigma_inv @ V.T | ||
|
|
||
| if scale_mode == "min": | ||
| scale = lambda_[-1] | ||
| elif scale_mode == "median": | ||
| scale = torch.median(lambda_) | ||
| elif scale_mode == "rmse": | ||
| scale = lambda_.mean() | ||
|
Comment on lines
+125
to
+128
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you add code coverage for this? I think it should be enough to just parametrize the existing tests over `scale_mode`. |
||
| else: | ||
| raise ValueError( | ||
| f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'." | ||
| ) | ||
|
Comment on lines
+129
to
+132
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also need coverage for this. A simple test should be enough: from pytest import raises
from utils.tensors import ones_
def test_invalid_scale_mode():
aggregator = AlignedMTL(scale_mode="test")
matrix = ones_(3, 4)
with raises(ValueError):
aggregator(matrix) |
||
|
|
||
| B = scale.sqrt() * V @ sigma_inv @ V.T | ||
| return B | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a changelog entry (under "Unreleased") saying that it's now possible to provide a
`scale_mode` parameter for `AlignedMTL` and `AlignedMTLWeighting`?