Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added tests/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from contextlib import nullcontext

import torch
from device import DEVICE
from pytest import RaisesExc, fixture, mark
from torch import Tensor
from utils.architectures import ModuleFactory

from tests.device import DEVICE
from tests.utils.architectures import ModuleFactory
from torchjd.aggregation import Aggregator, Weighting


Expand Down
2 changes: 1 addition & 1 deletion tests/plots/interactive_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch
from dash import Dash, Input, Output, callback, dcc, html
from plotly.graph_objs import Figure
from plots._utils import Plotter, angle_to_coord, coord_to_angle

from tests.plots._utils import Plotter, angle_to_coord, coord_to_angle
from torchjd.aggregation import (
IMTLG,
MGDA,
Expand Down
10 changes: 5 additions & 5 deletions tests/speed/autogram/grad_vs_jac_vs_gram.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import time

import torch
from device import DEVICE
from utils.architectures import (

from tests.device import DEVICE
from tests.utils.architectures import (
AlexNet,
Cifar10Model,
FreeParam,
Expand All @@ -15,15 +16,14 @@
SqueezeNet,
WithTransformerLarge,
)
from utils.forward_backwards import (
from tests.utils.forward_backwards import (
autograd_forward_backward,
autograd_gramian_forward_backward,
autogram_forward_backward,
autojac_forward_backward,
make_mse_loss_fn,
)
from utils.tensors import make_inputs_and_targets

from tests.utils.tensors import make_inputs_and_targets
from torchjd.aggregation import Mean
from torchjd.autogram import Engine

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/aggregation/_asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from pytest import raises
from torch import Tensor
from torch.testing import assert_close
from utils.tensors import rand_, randperm_

from tests.utils.tensors import rand_, randperm_
from torchjd.aggregation import Aggregator
from torchjd.aggregation._utils.non_differentiable import NonDifferentiableError

Expand Down
5 changes: 3 additions & 2 deletions tests/unit/aggregation/_inputs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from device import DEVICE
from utils.tensors import zeros_

from tests.device import DEVICE
from tests.utils.tensors import zeros_

from ._matrix_samplers import NonWeakSampler, NormalSampler, StrictlyWeakSampler, StrongSampler

Expand Down
3 changes: 2 additions & 1 deletion tests/unit/aggregation/_matrix_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import torch
from torch import Tensor
from torch.nn.functional import normalize
from utils.tensors import randint_, randn_, randperm_, zeros_

from tests.utils.tensors import randint_, randn_, randperm_, zeros_


class MatrixSampler(ABC):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/aggregation/_utils/test_dual_cone.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import torch
from pytest import mark, raises
from torch.testing import assert_close
from utils.tensors import rand_, randn_

from tests.utils.tensors import rand_, randn_
from torchjd.aggregation._utils.dual_cone import _project_weight_vector, project_weights


Expand Down
4 changes: 2 additions & 2 deletions tests/unit/aggregation/_utils/test_pref_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from pytest import mark, raises
from torch import Tensor
from utils.contexts import ExceptionContext
from utils.tensors import ones_

from tests.utils.contexts import ExceptionContext
from tests.utils.tensors import ones_
from torchjd.aggregation._mean import MeanWeighting
from torchjd.aggregation._utils.pref_vector import pref_vector_to_weighting

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/aggregation/test_aggregator_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from contextlib import nullcontext as does_not_raise

from pytest import mark, raises
from utils.contexts import ExceptionContext
from utils.tensors import randn_

from tests.utils.contexts import ExceptionContext
from tests.utils.tensors import randn_
from torchjd.aggregation import Aggregator


Expand Down
4 changes: 2 additions & 2 deletions tests/unit/aggregation/test_cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from pytest import mark, raises
from torch import Tensor
from utils.contexts import ExceptionContext
from utils.tensors import ones_

from tests.utils.contexts import ExceptionContext
from tests.utils.tensors import ones_
from torchjd.aggregation import CAGrad

from ._asserts import assert_expected_structure, assert_non_conflicting, assert_non_differentiable
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/aggregation/test_config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
from pytest import mark
from torch import Tensor
from utils.tensors import ones_

from tests.utils.tensors import ones_
from torchjd.aggregation import ConFIG

from ._asserts import (
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/aggregation/test_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import torch
from pytest import mark, raises
from torch import Tensor
from utils.contexts import ExceptionContext
from utils.tensors import ones_, tensor_

from tests.utils.contexts import ExceptionContext
from tests.utils.tensors import ones_, tensor_
from torchjd.aggregation import Constant

from ._asserts import (
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/aggregation/test_dualproj.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
from pytest import mark
from torch import Tensor
from utils.tensors import ones_

from tests.utils.tensors import ones_
from torchjd.aggregation import DualProj

from ._asserts import (
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/aggregation/test_graddrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import torch
from pytest import mark, raises
from torch import Tensor
from utils.contexts import ExceptionContext
from utils.tensors import ones_

from tests.utils.contexts import ExceptionContext
from tests.utils.tensors import ones_
from torchjd.aggregation import GradDrop

from ._asserts import assert_expected_structure, assert_non_differentiable
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/aggregation/test_imtl_g.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from pytest import mark
from torch import Tensor
from torch.testing import assert_close
from utils.tensors import ones_, zeros_

from tests.utils.tensors import ones_, zeros_
from torchjd.aggregation import IMTLG

from ._asserts import (
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/aggregation/test_krum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from pytest import mark, raises
from torch import Tensor
from utils.contexts import ExceptionContext
from utils.tensors import ones_

from tests.utils.contexts import ExceptionContext
from tests.utils.tensors import ones_
from torchjd.aggregation import Krum

from ._asserts import assert_expected_structure
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/aggregation/test_mgda.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from pytest import mark
from torch import Tensor
from torch.testing import assert_close
from utils.tensors import ones_, randn_

from tests.utils.tensors import ones_, randn_
from torchjd.aggregation import MGDA
from torchjd.aggregation._mgda import MGDAWeighting
from torchjd.aggregation._utils.gramian import compute_gramian
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/aggregation/test_nash_mtl.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from pytest import mark
from torch import Tensor
from torch.testing import assert_close
from utils.tensors import ones_, randn_

from tests.utils.tensors import ones_, randn_
from torchjd.aggregation import NashMTL

from ._asserts import assert_expected_structure, assert_non_differentiable
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/aggregation/test_pcgrad.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from pytest import mark
from torch import Tensor
from torch.testing import assert_close
from utils.tensors import ones_, randn_

from tests.utils.tensors import ones_, randn_
from torchjd.aggregation import PCGrad
from torchjd.aggregation._pcgrad import PCGradWeighting
from torchjd.aggregation._upgrad import UPGradWeighting
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/aggregation/test_trimmed_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from pytest import mark, raises
from torch import Tensor
from utils.contexts import ExceptionContext
from utils.tensors import ones_

from tests.utils.contexts import ExceptionContext
from tests.utils.tensors import ones_
from torchjd.aggregation import TrimmedMean

from ._asserts import assert_expected_structure, assert_permutation_invariant
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/aggregation/test_upgrad.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
from pytest import mark
from torch import Tensor
from utils.tensors import ones_

from tests.utils.tensors import ones_
from torchjd.aggregation import UPGrad

from ._asserts import (
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/autogram/test_edge_registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from torch.autograd.graph import get_gradient_edge
from utils.tensors import randn_

from tests.utils.tensors import randn_
from torchjd.autogram._edge_registry import EdgeRegistry


Expand Down
10 changes: 5 additions & 5 deletions tests/unit/autogram/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from torch.optim import SGD
from torch.testing import assert_close
from torch.utils._pytree import PyTree
from utils.architectures import (

from tests.utils.architectures import (
AlexNet,
Cifar10Model,
FreeParam,
Expand Down Expand Up @@ -63,8 +64,8 @@
WithTransformer,
WithTransformerLarge,
)
from utils.dict_assertions import assert_tensor_dicts_are_close
from utils.forward_backwards import (
from tests.utils.dict_assertions import assert_tensor_dicts_are_close
from tests.utils.forward_backwards import (
CloneParams,
autograd_forward_backward,
autogram_forward_backward,
Expand All @@ -77,8 +78,7 @@
reduce_to_scalar,
reduce_to_vector,
)
from utils.tensors import make_inputs_and_targets, ones_, randn_, zeros_

from tests.utils.tensors import make_inputs_and_targets, ones_, randn_, zeros_
from torchjd.aggregation import UPGradWeighting
from torchjd.autogram._engine import Engine
from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/autogram/test_gramian_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from pytest import mark
from torch.testing import assert_close
from utils.forward_backwards import compute_gramian
from utils.tensors import randn_

from tests.utils.forward_backwards import compute_gramian
from tests.utils.tensors import randn_
from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian


Expand Down
4 changes: 2 additions & 2 deletions tests/unit/autojac/_transform/test_accumulate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pytest import mark, raises
from utils.dict_assertions import assert_tensor_dicts_are_close
from utils.tensors import ones_, tensor_, zeros_

from tests.utils.dict_assertions import assert_tensor_dicts_are_close
from tests.utils.tensors import ones_, tensor_, zeros_
from torchjd.autojac._transform import Accumulate


Expand Down
6 changes: 3 additions & 3 deletions tests/unit/autojac/_transform/test_aggregate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import math

import torch
from device import DEVICE
from pytest import mark, raises
from utils.dict_assertions import assert_tensor_dicts_are_close
from utils.tensors import rand_, tensor_, zeros_

from tests.device import DEVICE
from tests.utils.dict_assertions import assert_tensor_dicts_are_close
from tests.utils.tensors import rand_, tensor_, zeros_
from torchjd.aggregation import Random
from torchjd.autojac._transform import OrderedSet, RequirementError
from torchjd.autojac._transform._aggregate import _AggregateMatrices, _Matrixify, _Reshape
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/autojac/_transform/test_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pytest import raises
from torch import Tensor
from utils.tensors import empty_, randn_

from tests.utils.tensors import empty_, randn_
from torchjd.autojac._transform._base import Conjunction, RequirementError, TensorDict, Transform


Expand Down
4 changes: 2 additions & 2 deletions tests/unit/autojac/_transform/test_diagonalize.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
from pytest import raises
from utils.dict_assertions import assert_tensor_dicts_are_close
from utils.tensors import tensor_

from tests.utils.dict_assertions import assert_tensor_dicts_are_close
from tests.utils.tensors import tensor_
from torchjd.autojac._transform import Diagonalize, OrderedSet, RequirementError


Expand Down
4 changes: 2 additions & 2 deletions tests/unit/autojac/_transform/test_grad.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
from pytest import raises
from utils.dict_assertions import assert_tensor_dicts_are_close
from utils.tensors import tensor_

from tests.utils.dict_assertions import assert_tensor_dicts_are_close
from tests.utils.tensors import tensor_
from torchjd.autojac._transform import Grad, OrderedSet, RequirementError


Expand Down
4 changes: 2 additions & 2 deletions tests/unit/autojac/_transform/test_init.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pytest import raises
from utils.dict_assertions import assert_tensor_dicts_are_close
from utils.tensors import tensor_

from tests.utils.dict_assertions import assert_tensor_dicts_are_close
from tests.utils.tensors import tensor_
from torchjd.autojac._transform import Init, RequirementError


Expand Down
4 changes: 2 additions & 2 deletions tests/unit/autojac/_transform/test_interactions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import torch
from pytest import raises
from torch.testing import assert_close
from utils.dict_assertions import assert_tensor_dicts_are_close
from utils.tensors import tensor_, zeros_

from tests.utils.dict_assertions import assert_tensor_dicts_are_close
from tests.utils.tensors import tensor_, zeros_
from torchjd.autojac._transform import (
Accumulate,
Conjunction,
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/autojac/_transform/test_jac.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
from pytest import mark, raises
from utils.dict_assertions import assert_tensor_dicts_are_close
from utils.tensors import eye_, ones_, tensor_, zeros_

from tests.utils.dict_assertions import assert_tensor_dicts_are_close
from tests.utils.tensors import eye_, ones_, tensor_, zeros_
from torchjd.autojac._transform import Jac, OrderedSet, RequirementError


Expand Down
4 changes: 2 additions & 2 deletions tests/unit/autojac/_transform/test_select.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
from pytest import raises
from utils.dict_assertions import assert_tensor_dicts_are_close
from utils.tensors import tensor_

from tests.utils.dict_assertions import assert_tensor_dicts_are_close
from tests.utils.tensors import tensor_
from torchjd.autojac._transform import RequirementError, Select


Expand Down
Loading
Loading