55 changes: 7 additions & 48 deletions tests/unit/autogram/test_engine.py
@@ -57,28 +57,7 @@
 )
 from utils.tensors import make_tensors
 
-from torchjd.aggregation import (
-    IMTLG,
-    MGDA,
-    Aggregator,
-    AlignedMTL,
-    AlignedMTLWeighting,
-    DualProj,
-    DualProjWeighting,
-    IMTLGWeighting,
-    Mean,
-    MeanWeighting,
-    MGDAWeighting,
-    PCGrad,
-    PCGradWeighting,
-    Random,
-    RandomWeighting,
-    Sum,
-    SumWeighting,
-    UPGrad,
-    UPGradWeighting,
-    Weighting,
-)
+from torchjd.aggregation import UPGrad, UPGradWeighting
 from torchjd.autogram._engine import Engine
 from torchjd.autojac._transform import Diagonalize, Init, Jac, OrderedSet
 from torchjd.autojac._transform._aggregate import _Matrixify
@@ -126,35 +105,11 @@
     param(InstanceNormMobileNetV2, 2, marks=[mark.slow, mark.garbage_collect]),
 ]
 
-AGGREGATORS_AND_WEIGHTINGS: list[tuple[Aggregator, Weighting]] = [
-    (UPGrad(), UPGradWeighting()),
-    (AlignedMTL(), AlignedMTLWeighting()),
-    (DualProj(), DualProjWeighting()),
-    (IMTLG(), IMTLGWeighting()),
-    (Mean(), MeanWeighting()),
-    (MGDA(), MGDAWeighting()),
-    (PCGrad(), PCGradWeighting()),
-    (Random(), RandomWeighting()),
-    (Sum(), SumWeighting()),
-]
-
-try:
-    from torchjd.aggregation import CAGrad, CAGradWeighting
-
-    AGGREGATORS_AND_WEIGHTINGS.append((CAGrad(c=0.5), CAGradWeighting(c=0.5)))
-except ImportError:
-    pass
-
-WEIGHTINGS = [weighting for _, weighting in AGGREGATORS_AND_WEIGHTINGS]
-
 
 @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
-@mark.parametrize(["aggregator", "weighting"], AGGREGATORS_AND_WEIGHTINGS)
 def test_equivalence_autojac_autogram(
     architecture: type[ShapedModule],
     batch_size: int,
-    aggregator: Aggregator,
-    weighting: Weighting,
 ):
     """
     Tests that the autogram engine gives the same results as the autojac engine on IWRM for several
@@ -166,6 +121,9 @@ def test_equivalence_autojac_autogram(
     input_shapes = architecture.INPUT_SHAPES
     output_shapes = architecture.OUTPUT_SHAPES
 
+    weighting = UPGradWeighting()
+    aggregator = UPGrad()
+
     torch.manual_seed(0)
     model_autojac = architecture().to(device=DEVICE)
     torch.manual_seed(0)
@@ -262,9 +220,8 @@ def _non_empty_subsets(elements: set) -> list[set]:
     return [set(c) for r in range(1, len(elements) + 1) for c in combinations(elements, r)]
 
 
-@mark.parametrize("weighting", WEIGHTINGS)
 @mark.parametrize("gramian_module_names", _non_empty_subsets({"fc0", "fc1", "fc2", "fc3", "fc4"}))
-def test_partial_autogram(weighting: Weighting, gramian_module_names: set[str]):
+def test_partial_autogram(gramian_module_names: set[str]):
     """
     Tests that partial JD via the autogram engine works similarly as if the gramian was computed via
     the autojac engine.
@@ -276,6 +233,8 @@ def test_partial_autogram(weighting: Weighting, gramian_module_names: set[str]):
     architecture = SimpleBranched
     batch_size = 64
 
+    weighting = UPGradWeighting()
+
     input_shapes = architecture.INPUT_SHAPES
     output_shapes = architecture.OUTPUT_SHAPES
 