Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/unit/autojac/test_mtl_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pytest import mark, raises
from torch.autograd import grad
from torch.testing import assert_close
from utils.tensors import rand_, randn_, tensor_
from utils.tensors import arange_, rand_, randn_, tensor_

from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad
from torchjd.autojac import mtl_backward
Expand Down Expand Up @@ -345,7 +345,7 @@ def test_various_feature_lists(shapes: list[tuple[int]]):
"""Tests that mtl_backward works correctly with various kinds of feature lists."""

p0 = tensor_([1.0, 2.0], requires_grad=True)
p1 = torch.arange(len(shapes), dtype=torch.float32, requires_grad=True)
p1 = arange_(len(shapes), dtype=torch.float32, requires_grad=True)
p2 = tensor_(5.0, requires_grad=True)

features = [rand_(shape) @ p0 for shape in shapes]
Expand Down
1 change: 1 addition & 0 deletions tests/utils/tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# for code written in the tests, while not affecting code written in src (what
# torch.set_default_device or what a too large `with torch.device(DEVICE)` context would have done).

arange_ = partial(torch.arange, device=DEVICE)
empty_ = partial(torch.empty, device=DEVICE)
eye_ = partial(torch.eye, device=DEVICE)
ones_ = partial(torch.ones, device=DEVICE)
Expand Down
Loading