From 79c0f45b2fe8e41ffa4a20f43af36a478a530bb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 21 Dec 2025 16:16:42 +0100 Subject: [PATCH 1/2] test: Add dtype setting --- tests/conftest.py | 2 +- tests/{device.py => settings.py} | 16 ++++++++++++++++ tests/speed/autogram/grad_vs_jac_vs_gram.py | 2 +- tests/unit/aggregation/_inputs.py | 2 +- tests/unit/aggregation/_matrix_samplers.py | 3 ++- .../unit/autojac/_transform/test_aggregate.py | 2 +- tests/unit/autojac/test_mtl_backward.py | 3 ++- tests/unit/autojac/test_utils.py | 8 ++++---- tests/utils/architectures.py | 4 ++-- tests/utils/contexts.py | 2 +- tests/utils/tensors.py | 19 +++++++++++-------- 11 files changed, 42 insertions(+), 21 deletions(-) rename tests/{device.py => settings.py} (51%) diff --git a/tests/conftest.py b/tests/conftest.py index 7afad8c4..a2921520 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,8 +2,8 @@ from contextlib import nullcontext import torch -from device import DEVICE from pytest import RaisesExc, fixture, mark +from settings import DEVICE from torch import Tensor from utils.architectures import ModuleFactory diff --git a/tests/device.py b/tests/settings.py similarity index 51% rename from tests/device.py rename to tests/settings.py index 7be2c75c..c032e194 100644 --- a/tests/device.py +++ b/tests/settings.py @@ -14,3 +14,19 @@ raise ValueError('Requested device "cuda:0" but cuda is not available.') DEVICE = torch.device(_device_str) + + +_POSSIBLE_TEST_DTYPES = {"float32", "float64"} + +try: + _dtype_str = os.environ["PYTEST_TORCH_DTYPE"] +except KeyError: + _dtype_str = "float32" # Default to float32 if environment variable not set + +if _dtype_str not in _POSSIBLE_TEST_DTYPES: + raise ValueError( + f"Invalid value of environment variable PYTEST_TORCH_DTYPE: {_dtype_str}.\n" + f"Possible values: {_POSSIBLE_TEST_DTYPES}." 
+ ) + +DTYPE = getattr(torch, _dtype_str) # "float32" => torch.float32 diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index 64d33911..e6d9233a 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -1,7 +1,7 @@ import gc import torch -from device import DEVICE +from settings import DEVICE from utils.architectures import ( AlexNet, Cifar10Model, diff --git a/tests/unit/aggregation/_inputs.py b/tests/unit/aggregation/_inputs.py index 8985665a..4a18246a 100644 --- a/tests/unit/aggregation/_inputs.py +++ b/tests/unit/aggregation/_inputs.py @@ -1,5 +1,5 @@ import torch -from device import DEVICE +from settings import DEVICE from utils.tensors import zeros_ from ._matrix_samplers import NonWeakSampler, NormalSampler, StrictlyWeakSampler, StrongSampler diff --git a/tests/unit/aggregation/_matrix_samplers.py b/tests/unit/aggregation/_matrix_samplers.py index 137ca2e6..f24e2028 100644 --- a/tests/unit/aggregation/_matrix_samplers.py +++ b/tests/unit/aggregation/_matrix_samplers.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod import torch +from settings import DTYPE from torch import Tensor from torch.nn.functional import normalize from utils.tensors import randint_, randn_, randperm_, zeros_ @@ -9,7 +10,7 @@ class MatrixSampler(ABC): """Abstract base class for sampling matrices of a given shape, rank and dtype.""" - def __init__(self, m: int, n: int, rank: int, dtype: torch.dtype = torch.float32): + def __init__(self, m: int, n: int, rank: int, dtype: torch.dtype = DTYPE): self._check_params(m, n, rank, dtype) self.m = m self.n = n diff --git a/tests/unit/autojac/_transform/test_aggregate.py b/tests/unit/autojac/_transform/test_aggregate.py index 3b90297e..5beaed20 100644 --- a/tests/unit/autojac/_transform/test_aggregate.py +++ b/tests/unit/autojac/_transform/test_aggregate.py @@ -1,8 +1,8 @@ import math import torch -from device import DEVICE from pytest 
import mark, raises +from settings import DEVICE from utils.dict_assertions import assert_tensor_dicts_are_close from utils.tensors import rand_, tensor_, zeros_ diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index 6c9f22e1..86595f92 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -1,5 +1,6 @@ import torch from pytest import mark, raises +from settings import DTYPE from torch.autograd import grad from torch.testing import assert_close from utils.tensors import arange_, rand_, randn_, tensor_ @@ -345,7 +346,7 @@ def test_various_feature_lists(shapes: list[tuple[int]]): """Tests that mtl_backward works correctly with various kinds of feature lists.""" p0 = tensor_([1.0, 2.0], requires_grad=True) - p1 = arange_(len(shapes), dtype=torch.float32, requires_grad=True) + p1 = arange_(len(shapes), dtype=DTYPE, requires_grad=True) p2 = tensor_(5.0, requires_grad=True) features = [rand_(shape) @ p0 for shape in shapes] diff --git a/tests/unit/autojac/test_utils.py b/tests/unit/autojac/test_utils.py index ef1b539c..a7036d7c 100644 --- a/tests/unit/autojac/test_utils.py +++ b/tests/unit/autojac/test_utils.py @@ -1,5 +1,5 @@ -from device import DEVICE from pytest import mark, raises +from settings import DEVICE, DTYPE from torch.nn import Linear, MSELoss, ReLU, Sequential from utils.tensors import randn_, tensor_ @@ -85,7 +85,7 @@ def test_get_leaf_tensors_model(): x = randn_(16, 10) y = randn_(16, 1) - model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1)).to(device=DEVICE) + model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1)).to(device=DEVICE, dtype=DTYPE) loss_fn = MSELoss(reduction="none") y_hat = model(x) @@ -104,8 +104,8 @@ def test_get_leaf_tensors_model_excluded_2(): x = randn_(16, 10) z = randn_(16, 1) - model1 = Sequential(Linear(10, 5), ReLU()).to(device=DEVICE) - model2 = Linear(5, 1).to(device=DEVICE) + model1 = Sequential(Linear(10, 5), 
ReLU()).to(device=DEVICE, dtype=DTYPE) + model2 = Linear(5, 1).to(device=DEVICE, dtype=DTYPE) loss_fn = MSELoss(reduction="none") y = model1(x) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 109b8e17..a74c7de8 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -2,7 +2,7 @@ import torch import torchvision -from device import DEVICE +from settings import DEVICE, DTYPE from torch import Tensor, nn from torch.nn import Flatten, ReLU from torch.utils._pytree import PyTree @@ -17,7 +17,7 @@ def __init__(self, architecture: type[nn.Module], *args, **kwargs): def __call__(self) -> nn.Module: with fork_rng(seed=0): - return self.architecture(*self.args, **self.kwargs).to(device=DEVICE) + return self.architecture(*self.args, **self.kwargs).to(device=DEVICE, dtype=DTYPE) def __str__(self) -> str: args_string = ", ".join([str(arg) for arg in self.args]) diff --git a/tests/utils/contexts.py b/tests/utils/contexts.py index ef4c0ecf..dc508130 100644 --- a/tests/utils/contexts.py +++ b/tests/utils/contexts.py @@ -3,7 +3,7 @@ from typing import Any, TypeAlias import torch -from device import DEVICE +from settings import DEVICE ExceptionContext: TypeAlias = AbstractContextManager[Exception | None] diff --git a/tests/utils/tensors.py b/tests/utils/tensors.py index 553fc22f..6d8066dc 100644 --- a/tests/utils/tensors.py +++ b/tests/utils/tensors.py @@ -1,7 +1,7 @@ from functools import partial import torch -from device import DEVICE +from settings import DEVICE, DTYPE from torch import nn from torch.utils._pytree import PyTree, tree_map from utils.architectures import get_in_out_shapes @@ -11,16 +11,19 @@ # for code written in the tests, while not affecting code written in src (what # torch.set_default_device or what a too large `with torch.device(DEVICE)` context would have done). +# Default dtype is most likely int. 
arange_ = partial(torch.arange, device=DEVICE) -empty_ = partial(torch.empty, device=DEVICE) -eye_ = partial(torch.eye, device=DEVICE) -ones_ = partial(torch.ones, device=DEVICE) -rand_ = partial(torch.rand, device=DEVICE) randint_ = partial(torch.randint, device=DEVICE) -randn_ = partial(torch.randn, device=DEVICE) randperm_ = partial(torch.randperm, device=DEVICE) -tensor_ = partial(torch.tensor, device=DEVICE) -zeros_ = partial(torch.zeros, device=DEVICE) + +# Default dtype is most likely float. Set it to the right kind of float. +empty_ = partial(torch.empty, device=DEVICE, dtype=DTYPE) +eye_ = partial(torch.eye, device=DEVICE, dtype=DTYPE) +ones_ = partial(torch.ones, device=DEVICE, dtype=DTYPE) +rand_ = partial(torch.rand, device=DEVICE, dtype=DTYPE) +randn_ = partial(torch.randn, device=DEVICE, dtype=DTYPE) +tensor_ = partial(torch.tensor, device=DEVICE, dtype=DTYPE) +zeros_ = partial(torch.zeros, device=DEVICE, dtype=DTYPE) def make_inputs_and_targets(model: nn.Module, batch_size: int) -> tuple[PyTree, PyTree]: From d8019b9b0edfdebaadbb6fe1558c6dd656397613 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 21 Dec 2025 17:29:51 +0100 Subject: [PATCH 2/2] Fix test tolerances --- tests/unit/aggregation/test_dualproj.py | 2 +- tests/unit/aggregation/test_upgrad.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index ba34dd9e..0f4407d2 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -27,7 +27,7 @@ def test_expected_structure(aggregator: DualProj, matrix: Tensor): @mark.parametrize(["aggregator", "matrix"], typical_pairs) def test_non_conflicting(aggregator: DualProj, matrix: Tensor): - assert_non_conflicting(aggregator, matrix, atol=5e-05, rtol=5e-05) + assert_non_conflicting(aggregator, matrix, atol=1e-04, rtol=1e-04) @mark.parametrize(["aggregator", "matrix"], typical_pairs) diff 
--git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index 4660da4b..b13aa971 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -28,7 +28,7 @@ def test_expected_structure(aggregator: UPGrad, matrix: Tensor): @mark.parametrize(["aggregator", "matrix"], typical_pairs) def test_non_conflicting(aggregator: UPGrad, matrix: Tensor): - assert_non_conflicting(aggregator, matrix, atol=3e-04, rtol=3e-04) + assert_non_conflicting(aggregator, matrix, atol=4e-04, rtol=4e-04) @mark.parametrize(["aggregator", "matrix"], typical_pairs) @@ -38,7 +38,7 @@ def test_permutation_invariant(aggregator: UPGrad, matrix: Tensor): @mark.parametrize(["aggregator", "matrix"], typical_pairs) def test_linear_under_scaling(aggregator: UPGrad, matrix: Tensor): - assert_linear_under_scaling(aggregator, matrix, n_runs=5, atol=3e-02, rtol=3e-02) + assert_linear_under_scaling(aggregator, matrix, n_runs=5, atol=4e-02, rtol=4e-02) @mark.parametrize(["aggregator", "matrix"], non_strong_pairs)