From b9b15ec159a5fda7c0bbf95c3f156c39a841759e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 16 Dec 2025 17:11:13 +0100 Subject: [PATCH 01/12] test: Add possibility to test on MPS --- tests/device.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/device.py b/tests/device.py index 7be2c75c..da4f20d9 100644 --- a/tests/device.py +++ b/tests/device.py @@ -2,15 +2,34 @@ import torch +_POSSIBLE_TEST_DEVICES = {"cpu", "cuda:0", "mps"} + try: _device_str = os.environ["PYTEST_TORCH_DEVICE"] except KeyError: _device_str = "cpu" # Default to cpu if environment variable not set -if _device_str != "cuda:0" and _device_str != "cpu": - raise ValueError(f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}") +if _device_str not in _POSSIBLE_TEST_DEVICES: + raise ValueError( + f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}.\n" + f"Possible devices: {_POSSIBLE_TEST_DEVICES}" + ) if _device_str == "cuda:0" and not torch.cuda.is_available(): raise ValueError('Requested device "cuda:0" but cuda is not available.') +if _device_str == "mps": + # Check that MPS is available (following https://docs.pytorch.org/docs/stable/notes/mps.html) + if not torch.backends.mps.is_available(): + if not torch.backends.mps.is_built(): + raise ValueError( + "MPS not available because the current PyTorch install was not built with MPS " + "enabled." + ) + else: + raise ValueError( + "MPS not available because the current MacOS version is not 12.3+ and/or you do not" + " have an MPS-enabled device on this machine." 
+ ) + DEVICE = torch.device(_device_str) From f4b0e351e02c4a3ec8c1634036e99b8bfb86ce33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 16 Dec 2025 17:11:28 +0100 Subject: [PATCH 02/12] Update CONTRIBUTING.md to explain how to test on MPS --- CONTRIBUTING.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 267b69b3..f3623650 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -87,6 +87,12 @@ uv run pre-commit install CUBLAS_WORKSPACE_CONFIG=:4096:8 PYTEST_TORCH_DEVICE=cuda:0 uv run pytest tests/unit ``` + - If you work on a macOS device that supports the Metal Performance Shaders (MPS) backend, you + can check that the unit tests pass on it: + ```bash + PYTEST_TORCH_DEVICE=mps uv run pytest tests/unit + ``` + - To check that the usage examples from docstrings and `.rst` files are correct, we test their behavior in `tests/doc`. To run these tests, do: ```bash From c04af90e528ba71906ab9aca08dd196d2f55c458 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 16 Dec 2025 17:19:22 +0100 Subject: [PATCH 03/12] Change tests.yml to also test on MPS when device is MacOS --- .github/workflows/tests.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 570f6dcc..e7faf0e9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -19,6 +19,11 @@ jobs: matrix: python-version: ['3.10', '3.11', '3.12', '3.13', '3.14'] os: [ubuntu-latest, macOS-latest, windows-latest] + device: [cpu] + include: + # Add the MPS device when the OS is macOS-latest. 
+ - device: mps + os: macOS-latest steps: - uses: actions/checkout@v4 From e29471470396b39cc4340ab6b8fbfee740e5eeac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 16 Dec 2025 17:27:44 +0100 Subject: [PATCH 04/12] fix ci --- .github/workflows/tests.yml | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e7faf0e9..ae0a5964 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -19,11 +19,13 @@ jobs: matrix: python-version: ['3.10', '3.11', '3.12', '3.13', '3.14'] os: [ubuntu-latest, macOS-latest, windows-latest] - device: [cpu] - include: - # Add the MPS device when the OS is macOS-latest. - - device: mps - os: macOS-latest + device: [cpu, mps] + exclude: + # Only test on MPS when the OS is macOS-latest. + - os: ubuntu-latest + device: mps + - os: windows-latest + device: mps steps: - uses: actions/checkout@v4 From 2ec4df53c61c75613cc1d8bb8c1b5daae22d3945 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 16 Dec 2025 17:29:03 +0100 Subject: [PATCH 05/12] fix ci --- .github/workflows/tests.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ae0a5964..2087afdf 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,7 +36,9 @@ jobs: - name: Install default (with full options) and test dependencies run: uv pip install --python-version=${{ matrix.python-version }} -e '.[full]' --group test - name: Run unit and doc tests with coverage report - run: uv run pytest -W error tests/unit tests/doc --cov=src --cov-report=xml + run: | + PYTEST_TORCH_DEVICE=${{ matrix.device }} uv run pytest -W error tests/unit tests/doc + --cov=src --cov-report=xml - name: Upload results to Codecov uses: codecov/codecov-action@v4 with: From e7822724452c44cf7a8d721be7f36a0ec2a4a860 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 16 Dec 2025 17:31:51 +0100 Subject: [PATCH 06/12] fix ci --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2087afdf..18dac455 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -37,7 +37,7 @@ jobs: run: uv pip install --python-version=${{ matrix.python-version }} -e '.[full]' --group test - name: Run unit and doc tests with coverage report run: | - PYTEST_TORCH_DEVICE=${{ matrix.device }} uv run pytest -W error tests/unit tests/doc + PYTEST_TORCH_DEVICE=${{ matrix.device }} uv run pytest -W error tests/unit tests/doc \ --cov=src --cov-report=xml - name: Upload results to Codecov uses: codecov/codecov-action@v4 From 3d485ff71660830b369e581d93106a0220b57a8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 16 Dec 2025 17:38:55 +0100 Subject: [PATCH 07/12] fix ci --- .github/workflows/tests.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 18dac455..53e58c54 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -37,8 +37,7 @@ jobs: run: uv pip install --python-version=${{ matrix.python-version }} -e '.[full]' --group test - name: Run unit and doc tests with coverage report run: | - PYTEST_TORCH_DEVICE=${{ matrix.device }} uv run pytest -W error tests/unit tests/doc \ - --cov=src --cov-report=xml + PYTEST_TORCH_DEVICE=${{ matrix.device }} uv run pytest -W error tests/unit tests/doc --cov=src --cov-report=xml - name: Upload results to Codecov uses: codecov/codecov-action@v4 with: From 346d6668ba0d0de986d05fbaefa14b0c67bb0f5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 16 Dec 2025 17:49:30 +0100 Subject: [PATCH 08/12] fix ci --- .github/workflows/tests.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git 
a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 53e58c54..a0802db1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,8 +36,9 @@ jobs: - name: Install default (with full options) and test dependencies run: uv pip install --python-version=${{ matrix.python-version }} -e '.[full]' --group test - name: Run unit and doc tests with coverage report - run: | - PYTEST_TORCH_DEVICE=${{ matrix.device }} uv run pytest -W error tests/unit tests/doc --cov=src --cov-report=xml + env: + PYTEST_TORCH_DEVICE: ${{ matrix.torch_device }} + run: uv run pytest -W error tests/unit tests/doc --cov=src --cov-report=xml - name: Upload results to Codecov uses: codecov/codecov-action@v4 with: From 59cd35e6e5c64ddc4e03cbce8f90abcf652dd650 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 16 Dec 2025 17:50:45 +0100 Subject: [PATCH 09/12] fix ci --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a0802db1..e7057d6b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -37,7 +37,7 @@ jobs: run: uv pip install --python-version=${{ matrix.python-version }} -e '.[full]' --group test - name: Run unit and doc tests with coverage report env: - PYTEST_TORCH_DEVICE: ${{ matrix.torch_device }} + PYTEST_TORCH_DEVICE: ${{ matrix.device }} run: uv run pytest -W error tests/unit tests/doc --cov=src --cov-report=xml - name: Upload results to Codecov uses: codecov/codecov-action@v4 From 87c0eb1468f9eed6da93c207ccda7e29f5bbf8c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 16 Dec 2025 18:00:00 +0100 Subject: [PATCH 10/12] Add try/except around linalg.qr for MPS --- tests/unit/aggregation/_matrix_samplers.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/unit/aggregation/_matrix_samplers.py b/tests/unit/aggregation/_matrix_samplers.py index 
137ca2e6..40db71a8 100644 --- a/tests/unit/aggregation/_matrix_samplers.py +++ b/tests/unit/aggregation/_matrix_samplers.py @@ -167,5 +167,14 @@ def _sample_semi_orthonormal_complement(Q: Tensor, rng: torch.Generator | None = # project A onto the orthogonal complement of Q A_proj = A - Q @ (Q.T @ A) - Q_prime, _ = torch.linalg.qr(A_proj) + try: + Q_prime, _ = torch.linalg.qr(A_proj) + except NotImplementedError: + # This will happen on MPS until they add support for aten::linalg_qr.out + # See status in https://github.com/pytorch/pytorch/issues/141287 + # In this case, perform the qr on CPU and move back to the original device + original_device = A_proj.device + Q_prime, _ = torch.linalg.qr(A_proj.to(device="cpu")) + Q_prime = Q_prime.to(device=original_device) + return Q_prime From 9abcb8534e08bc258303d71ec7ab507dde10699d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 16 Dec 2025 18:32:33 +0100 Subject: [PATCH 11/12] Add try/except around linalg.eigh for MPS --- src/torchjd/aggregation/_aligned_mtl.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index 3f0b7119..7fe291dc 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -82,7 +82,17 @@ def forward(self, gramian: Tensor) -> Tensor: @staticmethod def _compute_balance_transformation(M: Tensor) -> Tensor: - lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig + try: + lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig + except NotImplementedError: + # This will happen on MPS until they add support for aten::_linalg_eigh.eigenvalues + # See status in https://github.com/pytorch/pytorch/issues/141287 + # In this case, perform the eigh on CPU and move back to the original device + original_device = M.device + lambda_, V = torch.linalg.eigh(M.cpu(), UPLO="U") + lambda_ 
= lambda_.to(device=original_device) + V = V.to(device=original_device) + tol = torch.max(lambda_) * len(M) * torch.finfo().eps rank = sum(lambda_ > tol) From 32aab4faca758a482e97a37a9dcc36d9082a7d90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 16 Dec 2025 18:40:37 +0100 Subject: [PATCH 12/12] Fix fork_rng for non-cuda and non-cpu device --- tests/utils/contexts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/contexts.py b/tests/utils/contexts.py index ef4c0ecf..8971ff58 100644 --- a/tests/utils/contexts.py +++ b/tests/utils/contexts.py @@ -10,7 +10,7 @@ @contextmanager def fork_rng(seed: int = 0) -> Generator[Any, None, None]: - devices = [DEVICE] if DEVICE.type == "cuda" else [] + devices = [] if DEVICE.type == "cpu" else [DEVICE] with torch.random.fork_rng(devices=devices, device_type=DEVICE.type) as ctx: torch.manual_seed(seed) yield ctx