diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index be7b593b..ba4ec06a 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -57,28 +57,7 @@ ) from utils.tensors import make_tensors -from torchjd.aggregation import ( - IMTLG, - MGDA, - Aggregator, - AlignedMTL, - AlignedMTLWeighting, - DualProj, - DualProjWeighting, - IMTLGWeighting, - Mean, - MeanWeighting, - MGDAWeighting, - PCGrad, - PCGradWeighting, - Random, - RandomWeighting, - Sum, - SumWeighting, - UPGrad, - UPGradWeighting, - Weighting, -) +from torchjd.aggregation import UPGrad, UPGradWeighting from torchjd.autogram._engine import Engine from torchjd.autojac._transform import Diagonalize, Init, Jac, OrderedSet from torchjd.autojac._transform._aggregate import _Matrixify @@ -126,35 +105,11 @@ param(InstanceNormMobileNetV2, 2, marks=[mark.slow, mark.garbage_collect]), ] -AGGREGATORS_AND_WEIGHTINGS: list[tuple[Aggregator, Weighting]] = [ - (UPGrad(), UPGradWeighting()), - (AlignedMTL(), AlignedMTLWeighting()), - (DualProj(), DualProjWeighting()), - (IMTLG(), IMTLGWeighting()), - (Mean(), MeanWeighting()), - (MGDA(), MGDAWeighting()), - (PCGrad(), PCGradWeighting()), - (Random(), RandomWeighting()), - (Sum(), SumWeighting()), -] - -try: - from torchjd.aggregation import CAGrad, CAGradWeighting - - AGGREGATORS_AND_WEIGHTINGS.append((CAGrad(c=0.5), CAGradWeighting(c=0.5))) -except ImportError: - pass - -WEIGHTINGS = [weighting for _, weighting in AGGREGATORS_AND_WEIGHTINGS] - @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) -@mark.parametrize(["aggregator", "weighting"], AGGREGATORS_AND_WEIGHTINGS) def test_equivalence_autojac_autogram( architecture: type[ShapedModule], batch_size: int, - aggregator: Aggregator, - weighting: Weighting, ): """ Tests that the autogram engine gives the same results as the autojac engine on IWRM for several @@ -166,6 +121,9 @@ def test_equivalence_autojac_autogram( input_shapes = architecture.INPUT_SHAPES output_shapes = architecture.OUTPUT_SHAPES + weighting = UPGradWeighting() + aggregator = UPGrad() + torch.manual_seed(0) model_autojac = architecture().to(device=DEVICE) torch.manual_seed(0) @@ -262,9 +220,8 @@ def _non_empty_subsets(elements: set) -> list[set]: return [set(c) for r in range(1, len(elements) + 1) for c in combinations(elements, r)] -@mark.parametrize("weighting", WEIGHTINGS) @mark.parametrize("gramian_module_names", _non_empty_subsets({"fc0", "fc1", "fc2", "fc3", "fc4"})) -def test_partial_autogram(weighting: Weighting, gramian_module_names: set[str]): +def test_partial_autogram(gramian_module_names: set[str]): """ Tests that partial JD via the autogram engine works similarly as if the gramian was computed via the autojac engine. @@ -276,6 +233,8 @@ def test_partial_autogram(weighting: Weighting, gramian_module_names: set[str]): architecture = SimpleBranched batch_size = 64 + weighting = UPGradWeighting() + input_shapes = architecture.INPUT_SHAPES output_shapes = architecture.OUTPUT_SHAPES