From fe227418fc6e58be3c8e148e248359fe2e8bd1e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 25 Nov 2025 18:11:02 +0100 Subject: [PATCH] test: Turn into a package --- tests/__init__.py | 0 tests/conftest.py | 4 ++-- tests/plots/interactive_plotter.py | 2 +- tests/speed/autogram/grad_vs_jac_vs_gram.py | 10 +++++----- tests/unit/aggregation/_asserts.py | 2 +- tests/unit/aggregation/_inputs.py | 5 +++-- tests/unit/aggregation/_matrix_samplers.py | 3 ++- tests/unit/aggregation/_utils/test_dual_cone.py | 2 +- tests/unit/aggregation/_utils/test_pref_vector.py | 4 ++-- tests/unit/aggregation/test_aggregator_bases.py | 4 ++-- tests/unit/aggregation/test_cagrad.py | 4 ++-- tests/unit/aggregation/test_config.py | 2 +- tests/unit/aggregation/test_constant.py | 4 ++-- tests/unit/aggregation/test_dualproj.py | 2 +- tests/unit/aggregation/test_graddrop.py | 4 ++-- tests/unit/aggregation/test_imtl_g.py | 2 +- tests/unit/aggregation/test_krum.py | 4 ++-- tests/unit/aggregation/test_mgda.py | 2 +- tests/unit/aggregation/test_nash_mtl.py | 2 +- tests/unit/aggregation/test_pcgrad.py | 2 +- tests/unit/aggregation/test_trimmed_mean.py | 4 ++-- tests/unit/aggregation/test_upgrad.py | 2 +- tests/unit/autogram/test_edge_registry.py | 2 +- tests/unit/autogram/test_engine.py | 10 +++++----- tests/unit/autogram/test_gramian_utils.py | 4 ++-- tests/unit/autojac/_transform/test_accumulate.py | 4 ++-- tests/unit/autojac/_transform/test_aggregate.py | 6 +++--- tests/unit/autojac/_transform/test_base.py | 2 +- tests/unit/autojac/_transform/test_diagonalize.py | 4 ++-- tests/unit/autojac/_transform/test_grad.py | 4 ++-- tests/unit/autojac/_transform/test_init.py | 4 ++-- tests/unit/autojac/_transform/test_interactions.py | 4 ++-- tests/unit/autojac/_transform/test_jac.py | 4 ++-- tests/unit/autojac/_transform/test_select.py | 4 ++-- tests/unit/autojac/_transform/test_stack.py | 4 ++-- tests/unit/autojac/test_backward.py | 2 +- tests/unit/autojac/test_mtl_backward.py | 2 +- tests/unit/autojac/test_utils.py | 4 ++-- tests/utils/architectures.py | 5 +++-- tests/utils/contexts.py | 3 ++- tests/utils/forward_backwards.py | 4 ++-- tests/utils/tensors.py | 7 ++++--- 42 files changed, 79 insertions(+), 74 deletions(-) create mode 100644 tests/__init__.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py index 7afad8c4..b4a3996d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,11 +2,11 @@ from contextlib import nullcontext import torch -from device import DEVICE from pytest import RaisesExc, fixture, mark from torch import Tensor -from utils.architectures import ModuleFactory +from tests.device import DEVICE +from tests.utils.architectures import ModuleFactory from torchjd.aggregation import Aggregator, Weighting diff --git a/tests/plots/interactive_plotter.py b/tests/plots/interactive_plotter.py index 5a40e0b6..8fb8b466 100644 --- a/tests/plots/interactive_plotter.py +++ b/tests/plots/interactive_plotter.py @@ -7,8 +7,8 @@ import torch from dash import Dash, Input, Output, callback, dcc, html from plotly.graph_objs import Figure -from plots._utils import Plotter, angle_to_coord, coord_to_angle +from tests.plots._utils import Plotter, angle_to_coord, coord_to_angle from torchjd.aggregation import ( IMTLG, MGDA, diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index 7a6556f2..568df7fb 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -2,8 +2,9 @@ import time import torch -from device import DEVICE -from utils.architectures import ( + +from tests.device import DEVICE +from tests.utils.architectures import ( AlexNet, Cifar10Model, FreeParam, @@ -15,15 +16,14 @@ SqueezeNet, WithTransformerLarge, ) -from utils.forward_backwards import ( +from tests.utils.forward_backwards import ( autograd_forward_backward, autograd_gramian_forward_backward, autogram_forward_backward, autojac_forward_backward, make_mse_loss_fn, ) -from utils.tensors import make_inputs_and_targets - +from tests.utils.tensors import make_inputs_and_targets from torchjd.aggregation import Mean from torchjd.autogram import Engine diff --git a/tests/unit/aggregation/_asserts.py b/tests/unit/aggregation/_asserts.py index 332bd5b8..df6516fd 100644 --- a/tests/unit/aggregation/_asserts.py +++ b/tests/unit/aggregation/_asserts.py @@ -2,8 +2,8 @@ from pytest import raises from torch import Tensor from torch.testing import assert_close -from utils.tensors import rand_, randperm_ +from tests.utils.tensors import rand_, randperm_ from torchjd.aggregation import Aggregator from torchjd.aggregation._utils.non_differentiable import NonDifferentiableError diff --git a/tests/unit/aggregation/_inputs.py b/tests/unit/aggregation/_inputs.py index 8985665a..c792e586 100644 --- a/tests/unit/aggregation/_inputs.py +++ b/tests/unit/aggregation/_inputs.py @@ -1,6 +1,7 @@ import torch -from device import DEVICE -from utils.tensors import zeros_ + +from tests.device import DEVICE +from tests.utils.tensors import zeros_ from ._matrix_samplers import NonWeakSampler, NormalSampler, StrictlyWeakSampler, StrongSampler diff --git a/tests/unit/aggregation/_matrix_samplers.py b/tests/unit/aggregation/_matrix_samplers.py index 137ca2e6..e9104ad2 100644 --- a/tests/unit/aggregation/_matrix_samplers.py +++ b/tests/unit/aggregation/_matrix_samplers.py @@ -3,7 +3,8 @@ import torch from torch import Tensor from torch.nn.functional import normalize -from utils.tensors import randint_, randn_, randperm_, zeros_ + +from tests.utils.tensors import randint_, randn_, randperm_, zeros_ class MatrixSampler(ABC): diff --git a/tests/unit/aggregation/_utils/test_dual_cone.py b/tests/unit/aggregation/_utils/test_dual_cone.py index 3923d3f0..de9acd7f 100644 --- a/tests/unit/aggregation/_utils/test_dual_cone.py +++ b/tests/unit/aggregation/_utils/test_dual_cone.py @@ -2,8 +2,8 @@ import torch from pytest import mark, raises from torch.testing import assert_close -from utils.tensors import rand_, randn_ +from tests.utils.tensors import rand_, randn_ from torchjd.aggregation._utils.dual_cone import _project_weight_vector, project_weights diff --git a/tests/unit/aggregation/_utils/test_pref_vector.py b/tests/unit/aggregation/_utils/test_pref_vector.py index 159582dd..b1efc43f 100644 --- a/tests/unit/aggregation/_utils/test_pref_vector.py +++ b/tests/unit/aggregation/_utils/test_pref_vector.py @@ -2,9 +2,9 @@ from pytest import mark, raises from torch import Tensor -from utils.contexts import ExceptionContext -from utils.tensors import ones_ +from tests.utils.contexts import ExceptionContext +from tests.utils.tensors import ones_ from torchjd.aggregation._mean import MeanWeighting from torchjd.aggregation._utils.pref_vector import pref_vector_to_weighting diff --git a/tests/unit/aggregation/test_aggregator_bases.py b/tests/unit/aggregation/test_aggregator_bases.py index b08c37a8..291c47b5 100644 --- a/tests/unit/aggregation/test_aggregator_bases.py +++ b/tests/unit/aggregation/test_aggregator_bases.py @@ -2,9 +2,9 @@ from contextlib import nullcontext as does_not_raise from pytest import mark, raises -from utils.contexts import ExceptionContext -from utils.tensors import randn_ +from tests.utils.contexts import ExceptionContext +from tests.utils.tensors import randn_ from torchjd.aggregation import Aggregator diff --git a/tests/unit/aggregation/test_cagrad.py b/tests/unit/aggregation/test_cagrad.py index e604c087..73b3e41d 100644 --- a/tests/unit/aggregation/test_cagrad.py +++ b/tests/unit/aggregation/test_cagrad.py @@ -2,9 +2,9 @@ from pytest import mark, raises from torch import Tensor -from utils.contexts import ExceptionContext -from utils.tensors import ones_ +from tests.utils.contexts import ExceptionContext +from tests.utils.tensors import ones_ from torchjd.aggregation import CAGrad from ._asserts import assert_expected_structure, assert_non_conflicting, assert_non_differentiable diff --git a/tests/unit/aggregation/test_config.py b/tests/unit/aggregation/test_config.py index 69cc4af1..58b3e692 100644 --- a/tests/unit/aggregation/test_config.py +++ b/tests/unit/aggregation/test_config.py @@ -1,8 +1,8 @@ import torch from pytest import mark from torch import Tensor -from utils.tensors import ones_ +from tests.utils.tensors import ones_ from torchjd.aggregation import ConFIG from ._asserts import ( diff --git a/tests/unit/aggregation/test_constant.py b/tests/unit/aggregation/test_constant.py index dd860e23..f246582d 100644 --- a/tests/unit/aggregation/test_constant.py +++ b/tests/unit/aggregation/test_constant.py @@ -3,9 +3,9 @@ import torch from pytest import mark, raises from torch import Tensor -from utils.contexts import ExceptionContext -from utils.tensors import ones_, tensor_ +from tests.utils.contexts import ExceptionContext +from tests.utils.tensors import ones_, tensor_ from torchjd.aggregation import Constant from ._asserts import ( diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index ba34dd9e..4a6b53f0 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -1,8 +1,8 @@ import torch from pytest import mark from torch import Tensor -from utils.tensors import ones_ +from tests.utils.tensors import ones_ from torchjd.aggregation import DualProj from ._asserts import ( diff --git a/tests/unit/aggregation/test_graddrop.py b/tests/unit/aggregation/test_graddrop.py index 59e6e1ae..44f4279b 100644 --- a/tests/unit/aggregation/test_graddrop.py +++ b/tests/unit/aggregation/test_graddrop.py @@ -4,9 +4,9 @@ import torch from pytest import mark, raises from torch import Tensor -from utils.contexts import ExceptionContext -from utils.tensors import ones_ +from tests.utils.contexts import ExceptionContext +from tests.utils.tensors import ones_ from torchjd.aggregation import GradDrop from ._asserts import assert_expected_structure, assert_non_differentiable diff --git a/tests/unit/aggregation/test_imtl_g.py b/tests/unit/aggregation/test_imtl_g.py index e9ba838c..383bd3d3 100644 --- a/tests/unit/aggregation/test_imtl_g.py +++ b/tests/unit/aggregation/test_imtl_g.py @@ -1,8 +1,8 @@ from pytest import mark from torch import Tensor from torch.testing import assert_close -from utils.tensors import ones_, zeros_ +from tests.utils.tensors import ones_, zeros_ from torchjd.aggregation import IMTLG from ._asserts import ( diff --git a/tests/unit/aggregation/test_krum.py b/tests/unit/aggregation/test_krum.py index ff75e5cf..9d615db6 100644 --- a/tests/unit/aggregation/test_krum.py +++ b/tests/unit/aggregation/test_krum.py @@ -2,9 +2,9 @@ from pytest import mark, raises from torch import Tensor -from utils.contexts import ExceptionContext -from utils.tensors import ones_ +from tests.utils.contexts import ExceptionContext +from tests.utils.tensors import ones_ from torchjd.aggregation import Krum from ._asserts import assert_expected_structure diff --git a/tests/unit/aggregation/test_mgda.py b/tests/unit/aggregation/test_mgda.py index 41b07d93..6a1a096a 100644 --- a/tests/unit/aggregation/test_mgda.py +++ b/tests/unit/aggregation/test_mgda.py @@ -1,8 +1,8 @@ from pytest import mark from torch import Tensor from torch.testing import assert_close -from utils.tensors import ones_, randn_ +from tests.utils.tensors import ones_, randn_ from torchjd.aggregation import MGDA from torchjd.aggregation._mgda import MGDAWeighting from torchjd.aggregation._utils.gramian import compute_gramian diff --git a/tests/unit/aggregation/test_nash_mtl.py b/tests/unit/aggregation/test_nash_mtl.py index ce343e5d..68d027a9 100644 --- a/tests/unit/aggregation/test_nash_mtl.py +++ b/tests/unit/aggregation/test_nash_mtl.py @@ -1,8 +1,8 @@ from pytest import mark from torch import Tensor from torch.testing import assert_close -from utils.tensors import ones_, randn_ +from tests.utils.tensors import ones_, randn_ from torchjd.aggregation import NashMTL from ._asserts import assert_expected_structure, assert_non_differentiable diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index e96253ec..550782a6 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -1,8 +1,8 @@ from pytest import mark from torch import Tensor from torch.testing import assert_close -from utils.tensors import ones_, randn_ +from tests.utils.tensors import ones_, randn_ from torchjd.aggregation import PCGrad from torchjd.aggregation._pcgrad import PCGradWeighting from torchjd.aggregation._upgrad import UPGradWeighting diff --git a/tests/unit/aggregation/test_trimmed_mean.py b/tests/unit/aggregation/test_trimmed_mean.py index cdeb9398..c8bc41b1 100644 --- a/tests/unit/aggregation/test_trimmed_mean.py +++ b/tests/unit/aggregation/test_trimmed_mean.py @@ -2,9 +2,9 @@ from pytest import mark, raises from torch import Tensor -from utils.contexts import ExceptionContext -from utils.tensors import ones_ +from tests.utils.contexts import ExceptionContext +from tests.utils.tensors import ones_ from torchjd.aggregation import TrimmedMean from ._asserts import assert_expected_structure, assert_permutation_invariant diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index 4660da4b..261e9dd5 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -1,8 +1,8 @@ import torch from pytest import mark from torch import Tensor -from utils.tensors import ones_ +from tests.utils.tensors import ones_ from torchjd.aggregation import UPGrad from ._asserts import ( diff --git a/tests/unit/autogram/test_edge_registry.py b/tests/unit/autogram/test_edge_registry.py index 88d6da8c..92c9a220 100644 --- a/tests/unit/autogram/test_edge_registry.py +++ b/tests/unit/autogram/test_edge_registry.py @@ -1,6 +1,6 @@ from torch.autograd.graph import get_gradient_edge -from utils.tensors import randn_ +from tests.utils.tensors import randn_ from torchjd.autogram._edge_registry import EdgeRegistry diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 50135796..74806836 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -10,7 +10,8 @@ from torch.optim import SGD from torch.testing import assert_close from torch.utils._pytree import PyTree -from utils.architectures import ( + +from tests.utils.architectures import ( AlexNet, Cifar10Model, FreeParam, @@ -63,8 +64,8 @@ WithTransformer, WithTransformerLarge, ) -from utils.dict_assertions import assert_tensor_dicts_are_close -from utils.forward_backwards import ( +from tests.utils.dict_assertions import assert_tensor_dicts_are_close +from tests.utils.forward_backwards import ( CloneParams, autograd_forward_backward, autogram_forward_backward, @@ -77,8 +78,7 @@ reduce_to_scalar, reduce_to_vector, ) -from utils.tensors import make_inputs_and_targets, ones_, randn_, zeros_ - +from tests.utils.tensors import make_inputs_and_targets, ones_, randn_, zeros_ from torchjd.aggregation import UPGradWeighting from torchjd.autogram._engine import Engine from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py index fde35252..0360d565 100644 --- a/tests/unit/autogram/test_gramian_utils.py +++ b/tests/unit/autogram/test_gramian_utils.py @@ -1,8 +1,8 @@ from pytest import mark from torch.testing import assert_close -from utils.forward_backwards import compute_gramian -from utils.tensors import randn_ +from tests.utils.forward_backwards import compute_gramian +from tests.utils.tensors import randn_ from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian diff --git a/tests/unit/autojac/_transform/test_accumulate.py b/tests/unit/autojac/_transform/test_accumulate.py index 6c8bbbfe..b2963bb7 100644 --- a/tests/unit/autojac/_transform/test_accumulate.py +++ b/tests/unit/autojac/_transform/test_accumulate.py @@ -1,7 +1,7 @@ from pytest import mark, raises -from utils.dict_assertions import assert_tensor_dicts_are_close -from utils.tensors import ones_, tensor_, zeros_ +from tests.utils.dict_assertions import assert_tensor_dicts_are_close +from tests.utils.tensors import ones_, tensor_, zeros_ from torchjd.autojac._transform import Accumulate diff --git a/tests/unit/autojac/_transform/test_aggregate.py b/tests/unit/autojac/_transform/test_aggregate.py index 3b90297e..ffecbbfe 100644 --- a/tests/unit/autojac/_transform/test_aggregate.py +++ b/tests/unit/autojac/_transform/test_aggregate.py @@ -1,11 +1,11 @@ import math import torch -from device import DEVICE from pytest import mark, raises -from utils.dict_assertions import assert_tensor_dicts_are_close -from utils.tensors import rand_, tensor_, zeros_ +from tests.device import DEVICE +from tests.utils.dict_assertions import assert_tensor_dicts_are_close +from tests.utils.tensors import rand_, tensor_, zeros_ from torchjd.aggregation import Random from torchjd.autojac._transform import OrderedSet, RequirementError from torchjd.autojac._transform._aggregate import _AggregateMatrices, _Matrixify, _Reshape diff --git a/tests/unit/autojac/_transform/test_base.py b/tests/unit/autojac/_transform/test_base.py index 435b97a7..fa7de37a 100644 --- a/tests/unit/autojac/_transform/test_base.py +++ b/tests/unit/autojac/_transform/test_base.py @@ -1,7 +1,7 @@ from pytest import raises from torch import Tensor -from utils.tensors import empty_, randn_ +from tests.utils.tensors import empty_, randn_ from torchjd.autojac._transform._base import Conjunction, RequirementError, TensorDict, Transform diff --git a/tests/unit/autojac/_transform/test_diagonalize.py b/tests/unit/autojac/_transform/test_diagonalize.py index c1b30d31..03cf7348 100644 --- a/tests/unit/autojac/_transform/test_diagonalize.py +++ b/tests/unit/autojac/_transform/test_diagonalize.py @@ -1,8 +1,8 @@ import torch from pytest import raises -from utils.dict_assertions import assert_tensor_dicts_are_close -from utils.tensors import tensor_ +from tests.utils.dict_assertions import assert_tensor_dicts_are_close +from tests.utils.tensors import tensor_ from torchjd.autojac._transform import Diagonalize, OrderedSet, RequirementError diff --git a/tests/unit/autojac/_transform/test_grad.py b/tests/unit/autojac/_transform/test_grad.py index f834f73a..45939968 100644 --- a/tests/unit/autojac/_transform/test_grad.py +++ b/tests/unit/autojac/_transform/test_grad.py @@ -1,8 +1,8 @@ import torch from pytest import raises -from utils.dict_assertions import assert_tensor_dicts_are_close -from utils.tensors import tensor_ +from tests.utils.dict_assertions import assert_tensor_dicts_are_close +from tests.utils.tensors import tensor_ from torchjd.autojac._transform import Grad, OrderedSet, RequirementError diff --git a/tests/unit/autojac/_transform/test_init.py b/tests/unit/autojac/_transform/test_init.py index 38e4a29e..f84e9b2c 100644 --- a/tests/unit/autojac/_transform/test_init.py +++ b/tests/unit/autojac/_transform/test_init.py @@ -1,7 +1,7 @@ from pytest import raises -from utils.dict_assertions import assert_tensor_dicts_are_close -from utils.tensors import tensor_ +from tests.utils.dict_assertions import assert_tensor_dicts_are_close +from tests.utils.tensors import tensor_ from torchjd.autojac._transform import Init, RequirementError diff --git a/tests/unit/autojac/_transform/test_interactions.py b/tests/unit/autojac/_transform/test_interactions.py index 8a943e83..51ae9f44 100644 --- a/tests/unit/autojac/_transform/test_interactions.py +++ b/tests/unit/autojac/_transform/test_interactions.py @@ -1,9 +1,9 @@ import torch from pytest import raises from torch.testing import assert_close -from utils.dict_assertions import assert_tensor_dicts_are_close -from utils.tensors import tensor_, zeros_ +from tests.utils.dict_assertions import assert_tensor_dicts_are_close +from tests.utils.tensors import tensor_, zeros_ from torchjd.autojac._transform import ( Accumulate, Conjunction, diff --git a/tests/unit/autojac/_transform/test_jac.py b/tests/unit/autojac/_transform/test_jac.py index c00e43d2..4be590ab 100644 --- a/tests/unit/autojac/_transform/test_jac.py +++ b/tests/unit/autojac/_transform/test_jac.py @@ -1,8 +1,8 @@ import torch from pytest import mark, raises -from utils.dict_assertions import assert_tensor_dicts_are_close -from utils.tensors import eye_, ones_, tensor_, zeros_ +from tests.utils.dict_assertions import assert_tensor_dicts_are_close +from tests.utils.tensors import eye_, ones_, tensor_, zeros_ from torchjd.autojac._transform import Jac, OrderedSet, RequirementError diff --git a/tests/unit/autojac/_transform/test_select.py b/tests/unit/autojac/_transform/test_select.py index 041eefc9..8a969372 100644 --- a/tests/unit/autojac/_transform/test_select.py +++ b/tests/unit/autojac/_transform/test_select.py @@ -1,8 +1,8 @@ import torch from pytest import raises -from utils.dict_assertions import assert_tensor_dicts_are_close -from utils.tensors import tensor_ +from tests.utils.dict_assertions import assert_tensor_dicts_are_close +from tests.utils.tensors import tensor_ from torchjd.autojac._transform import RequirementError, Select diff --git a/tests/unit/autojac/_transform/test_stack.py b/tests/unit/autojac/_transform/test_stack.py index 0151fade..73264ba1 100644 --- a/tests/unit/autojac/_transform/test_stack.py +++ b/tests/unit/autojac/_transform/test_stack.py @@ -2,9 +2,9 @@ import torch from torch import Tensor -from utils.dict_assertions import assert_tensor_dicts_are_close -from utils.tensors import ones_, tensor_, zeros_ +from tests.utils.dict_assertions import assert_tensor_dicts_are_close +from tests.utils.tensors import ones_, tensor_, zeros_ from torchjd.autojac._transform import Stack, Transform from torchjd.autojac._transform._base import TensorDict diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 885a9c15..fd92cc03 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -2,8 +2,8 @@ from pytest import mark, raises from torch.autograd import grad from torch.testing import assert_close -from utils.tensors import randn_, tensor_ +from tests.utils.tensors import randn_, tensor_ from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad from torchjd.autojac import backward from torchjd.autojac._backward import _create_transform diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index c50c8430..5c1bcf85 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -2,8 +2,8 @@ from pytest import mark, raises from torch.autograd import grad from torch.testing import assert_close -from utils.tensors import rand_, randn_, tensor_ +from tests.utils.tensors import rand_, randn_, tensor_ from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad from torchjd.autojac import mtl_backward from torchjd.autojac._mtl_backward import _create_transform diff --git a/tests/unit/autojac/test_utils.py b/tests/unit/autojac/test_utils.py index ef1b539c..da02f5bb 100644 --- a/tests/unit/autojac/test_utils.py +++ b/tests/unit/autojac/test_utils.py @@ -1,8 +1,8 @@ -from device import DEVICE from pytest import mark, raises from torch.nn import Linear, MSELoss, ReLU, Sequential -from utils.tensors import randn_, tensor_ +from tests.device import DEVICE +from tests.utils.tensors import randn_, tensor_ from torchjd.autojac._utils import get_leaf_tensors diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 109b8e17..c92a1c66 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -2,11 +2,12 @@ import torch import torchvision -from device import DEVICE from torch import Tensor, nn from torch.nn import Flatten, ReLU from torch.utils._pytree import PyTree -from utils.contexts import fork_rng + +from tests.device import DEVICE +from tests.utils.contexts import fork_rng class ModuleFactory: diff --git a/tests/utils/contexts.py b/tests/utils/contexts.py index ef4c0ecf..f207d5fb 100644 --- a/tests/utils/contexts.py +++ b/tests/utils/contexts.py @@ -3,7 +3,8 @@ from typing import Any, TypeAlias import torch -from device import DEVICE + +from tests.device import DEVICE ExceptionContext: TypeAlias = AbstractContextManager[Exception | None] diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index 7baabdb7..a4417105 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -5,9 +5,9 @@ from torch.nn.functional import mse_loss from torch.utils._pytree import PyTree, tree_flatten, tree_map from torch.utils.hooks import RemovableHandle -from utils.architectures import get_in_out_shapes -from utils.contexts import fork_rng +from tests.utils.architectures import get_in_out_shapes +from tests.utils.contexts import fork_rng from torchjd.aggregation import Aggregator, Weighting from torchjd.autogram import Engine from torchjd.autojac import backward diff --git a/tests/utils/tensors.py b/tests/utils/tensors.py index 15251076..5e49df86 100644 --- a/tests/utils/tensors.py +++ b/tests/utils/tensors.py @@ -1,11 +1,12 @@ from functools import partial import torch -from device import DEVICE from torch import nn from torch.utils._pytree import PyTree, tree_map -from utils.architectures import get_in_out_shapes -from utils.contexts import fork_rng + +from tests.device import DEVICE +from tests.utils.architectures import get_in_out_shapes +from tests.utils.contexts import fork_rng # Curried calls to torch functions that require a device so that we automatically fix the device # for code written in the tests, while not affecting code written in src (what