Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from contextlib import nullcontext

import torch
from device import DEVICE
from pytest import RaisesExc, fixture, mark
from settings import DEVICE
from torch import Tensor
from utils.architectures import ModuleFactory

Expand Down
16 changes: 16 additions & 0 deletions tests/device.py → tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,19 @@
raise ValueError('Requested device "cuda:0" but cuda is not available.')

DEVICE = torch.device(_device_str)


# Test dtype configuration: the PYTEST_TORCH_DTYPE environment variable selects
# the floating-point dtype used by the test suite, exposed here as DTYPE.
_POSSIBLE_TEST_DTYPES = {"float32", "float64"}

# Idiomatic default lookup: equivalent to try/except KeyError around
# os.environ[...], defaulting to "float32" when the variable is not set.
_dtype_str = os.environ.get("PYTEST_TORCH_DTYPE", "float32")

if _dtype_str not in _POSSIBLE_TEST_DTYPES:
    raise ValueError(
        f"Invalid value of environment variable PYTEST_TORCH_DTYPE: {_dtype_str}.\n"
        f"Possible values: {_POSSIBLE_TEST_DTYPES}."
    )

DTYPE = getattr(torch, _dtype_str)  # "float32" => torch.float32
2 changes: 1 addition & 1 deletion tests/speed/autogram/grad_vs_jac_vs_gram.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import gc

import torch
from device import DEVICE
from settings import DEVICE
from utils.architectures import (
AlexNet,
Cifar10Model,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/aggregation/_inputs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from device import DEVICE
from settings import DEVICE
from utils.tensors import zeros_

from ._matrix_samplers import NonWeakSampler, NormalSampler, StrictlyWeakSampler, StrongSampler
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/aggregation/_matrix_samplers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod

import torch
from settings import DTYPE
from torch import Tensor
from torch.nn.functional import normalize
from utils.tensors import randint_, randn_, randperm_, zeros_
Expand All @@ -9,7 +10,7 @@
class MatrixSampler(ABC):
"""Abstract base class for sampling matrices of a given shape, rank and dtype."""

def __init__(self, m: int, n: int, rank: int, dtype: torch.dtype = torch.float32):
def __init__(self, m: int, n: int, rank: int, dtype: torch.dtype = DTYPE):
self._check_params(m, n, rank, dtype)
self.m = m
self.n = n
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/aggregation/test_dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_expected_structure(aggregator: DualProj, matrix: Tensor):

@mark.parametrize(["aggregator", "matrix"], typical_pairs)
def test_non_conflicting(aggregator: DualProj, matrix: Tensor):
assert_non_conflicting(aggregator, matrix, atol=5e-05, rtol=5e-05)
assert_non_conflicting(aggregator, matrix, atol=1e-04, rtol=1e-04)


@mark.parametrize(["aggregator", "matrix"], typical_pairs)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/aggregation/test_upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_expected_structure(aggregator: UPGrad, matrix: Tensor):

@mark.parametrize(["aggregator", "matrix"], typical_pairs)
def test_non_conflicting(aggregator: UPGrad, matrix: Tensor):
assert_non_conflicting(aggregator, matrix, atol=3e-04, rtol=3e-04)
assert_non_conflicting(aggregator, matrix, atol=4e-04, rtol=4e-04)


@mark.parametrize(["aggregator", "matrix"], typical_pairs)
Expand All @@ -38,7 +38,7 @@ def test_permutation_invariant(aggregator: UPGrad, matrix: Tensor):

@mark.parametrize(["aggregator", "matrix"], typical_pairs)
def test_linear_under_scaling(aggregator: UPGrad, matrix: Tensor):
assert_linear_under_scaling(aggregator, matrix, n_runs=5, atol=3e-02, rtol=3e-02)
assert_linear_under_scaling(aggregator, matrix, n_runs=5, atol=4e-02, rtol=4e-02)


@mark.parametrize(["aggregator", "matrix"], non_strong_pairs)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/autojac/_transform/test_aggregate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import math

import torch
from device import DEVICE
from pytest import mark, raises
from settings import DEVICE
from utils.dict_assertions import assert_tensor_dicts_are_close
from utils.tensors import rand_, tensor_, zeros_

Expand Down
3 changes: 2 additions & 1 deletion tests/unit/autojac/test_mtl_backward.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
from pytest import mark, raises
from settings import DTYPE
from torch.autograd import grad
from torch.testing import assert_close
from utils.tensors import arange_, rand_, randn_, tensor_
Expand Down Expand Up @@ -345,7 +346,7 @@ def test_various_feature_lists(shapes: list[tuple[int]]):
"""Tests that mtl_backward works correctly with various kinds of feature lists."""

p0 = tensor_([1.0, 2.0], requires_grad=True)
p1 = arange_(len(shapes), dtype=torch.float32, requires_grad=True)
p1 = arange_(len(shapes), dtype=DTYPE, requires_grad=True)
p2 = tensor_(5.0, requires_grad=True)

features = [rand_(shape) @ p0 for shape in shapes]
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/autojac/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from device import DEVICE
from pytest import mark, raises
from settings import DEVICE, DTYPE
from torch.nn import Linear, MSELoss, ReLU, Sequential
from utils.tensors import randn_, tensor_

Expand Down Expand Up @@ -85,7 +85,7 @@ def test_get_leaf_tensors_model():
x = randn_(16, 10)
y = randn_(16, 1)

model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1)).to(device=DEVICE)
model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1)).to(device=DEVICE, dtype=DTYPE)
loss_fn = MSELoss(reduction="none")

y_hat = model(x)
Expand All @@ -104,8 +104,8 @@ def test_get_leaf_tensors_model_excluded_2():
x = randn_(16, 10)
z = randn_(16, 1)

model1 = Sequential(Linear(10, 5), ReLU()).to(device=DEVICE)
model2 = Linear(5, 1).to(device=DEVICE)
model1 = Sequential(Linear(10, 5), ReLU()).to(device=DEVICE, dtype=DTYPE)
model2 = Linear(5, 1).to(device=DEVICE, dtype=DTYPE)
loss_fn = MSELoss(reduction="none")

y = model1(x)
Expand Down
4 changes: 2 additions & 2 deletions tests/utils/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
import torchvision
from device import DEVICE
from settings import DEVICE, DTYPE
from torch import Tensor, nn
from torch.nn import Flatten, ReLU
from torch.utils._pytree import PyTree
Expand All @@ -17,7 +17,7 @@ def __init__(self, architecture: type[nn.Module], *args, **kwargs):

def __call__(self) -> nn.Module:
with fork_rng(seed=0):
return self.architecture(*self.args, **self.kwargs).to(device=DEVICE)
return self.architecture(*self.args, **self.kwargs).to(device=DEVICE, dtype=DTYPE)

def __str__(self) -> str:
args_string = ", ".join([str(arg) for arg in self.args])
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, TypeAlias

import torch
from device import DEVICE
from settings import DEVICE

ExceptionContext: TypeAlias = AbstractContextManager[Exception | None]

Expand Down
19 changes: 11 additions & 8 deletions tests/utils/tensors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import partial

import torch
from device import DEVICE
from settings import DEVICE, DTYPE
from torch import nn
from torch.utils._pytree import PyTree, tree_map
from utils.architectures import get_in_out_shapes
Expand All @@ -11,16 +11,19 @@
# for code written in the tests, while not affecting code written in src (what
# torch.set_default_device or what a too large `with torch.device(DEVICE)` context would have done).

# Default dtype is most likely int.
arange_ = partial(torch.arange, device=DEVICE)
empty_ = partial(torch.empty, device=DEVICE)
eye_ = partial(torch.eye, device=DEVICE)
ones_ = partial(torch.ones, device=DEVICE)
rand_ = partial(torch.rand, device=DEVICE)
randint_ = partial(torch.randint, device=DEVICE)
randn_ = partial(torch.randn, device=DEVICE)
randperm_ = partial(torch.randperm, device=DEVICE)
tensor_ = partial(torch.tensor, device=DEVICE)
zeros_ = partial(torch.zeros, device=DEVICE)

# Default dtype is most likely float. Set it to the right kind of float.
empty_ = partial(torch.empty, device=DEVICE, dtype=DTYPE)
eye_ = partial(torch.eye, device=DEVICE, dtype=DTYPE)
ones_ = partial(torch.ones, device=DEVICE, dtype=DTYPE)
rand_ = partial(torch.rand, device=DEVICE, dtype=DTYPE)
randn_ = partial(torch.randn, device=DEVICE, dtype=DTYPE)
tensor_ = partial(torch.tensor, device=DEVICE, dtype=DTYPE)
zeros_ = partial(torch.zeros, device=DEVICE, dtype=DTYPE)


def make_inputs_and_targets(model: nn.Module, batch_size: int) -> tuple[PyTree, PyTree]:
Expand Down
Loading