diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index c50c8430..6c9f22e1 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -2,7 +2,7 @@ from pytest import mark, raises from torch.autograd import grad from torch.testing import assert_close -from utils.tensors import rand_, randn_, tensor_ +from utils.tensors import arange_, rand_, randn_, tensor_ from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad from torchjd.autojac import mtl_backward @@ -345,7 +345,7 @@ def test_various_feature_lists(shapes: list[tuple[int]]): """Tests that mtl_backward works correctly with various kinds of feature lists.""" p0 = tensor_([1.0, 2.0], requires_grad=True) - p1 = torch.arange(len(shapes), dtype=torch.float32, requires_grad=True) + p1 = arange_(len(shapes), dtype=torch.float32, requires_grad=True) p2 = tensor_(5.0, requires_grad=True) features = [rand_(shape) @ p0 for shape in shapes] diff --git a/tests/utils/tensors.py b/tests/utils/tensors.py index 15251076..553fc22f 100644 --- a/tests/utils/tensors.py +++ b/tests/utils/tensors.py @@ -11,6 +11,7 @@ # for code written in the tests, while not affecting code written in src (what # torch.set_default_device or what a too large `with torch.device(DEVICE)` context would have done). +arange_ = partial(torch.arange, device=DEVICE) empty_ = partial(torch.empty, device=DEVICE) eye_ = partial(torch.eye, device=DEVICE) ones_ = partial(torch.ones, device=DEVICE)