diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py
index 5188ba86..f4121558 100644
--- a/tests/speed/autogram/grad_vs_jac_vs_gram.py
+++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py
@@ -10,10 +10,11 @@
     GroupNormMobileNetV3Small,
     InstanceNormMobileNetV2,
     InstanceNormResNet18,
+    ModuleFactory,
     NoFreeParam,
-    ShapedModule,
     SqueezeNet,
     WithTransformerLarge,
+    get_in_out_shapes,
 )
 from utils.forward_backwards import (
     autograd_forward_backward,
@@ -28,33 +29,30 @@
 from torchjd.autogram import Engine
 
 PARAMETRIZATIONS = [
-    (WithTransformerLarge, 8),
-    (FreeParam, 64),
-    (NoFreeParam, 64),
-    (Cifar10Model, 64),
-    (AlexNet, 8),
-    (InstanceNormResNet18, 16),
-    (GroupNormMobileNetV3Small, 16),
-    (SqueezeNet, 4),
-    (InstanceNormMobileNetV2, 2),
+    (ModuleFactory(WithTransformerLarge), 8),
+    (ModuleFactory(FreeParam), 64),
+    (ModuleFactory(NoFreeParam), 64),
+    (ModuleFactory(Cifar10Model), 64),
+    (ModuleFactory(AlexNet), 8),
+    (ModuleFactory(InstanceNormResNet18), 16),
+    (ModuleFactory(GroupNormMobileNetV3Small), 16),
+    (ModuleFactory(SqueezeNet), 4),
+    (ModuleFactory(InstanceNormMobileNetV2), 2),
 ]
 
 
-def compare_autograd_autojac_and_autogram_speed(architecture: type[ShapedModule], batch_size: int):
-    input_shapes = architecture.INPUT_SHAPES
-    output_shapes = architecture.OUTPUT_SHAPES
+def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_size: int):
+    model = factory()
+    input_shapes, output_shapes = get_in_out_shapes(model)
 
     inputs = make_tensors(batch_size, input_shapes)
     targets = make_tensors(batch_size, output_shapes)
     loss_fn = make_mse_loss_fn(targets)
 
-    model = architecture().to(device=DEVICE)
-
     A = Mean()
     W = A.weighting
 
     print(
-        f"\nTimes for forward + backward on {architecture.__name__} with BS={batch_size}, A={A}"
-        f" on {DEVICE}."
+        f"\nTimes for forward + backward on {factory} with BS={batch_size}, A={A} on {DEVICE}."
) def fn_autograd(): @@ -148,8 +146,8 @@ def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> def main(): - for architecture, batch_size in PARAMETRIZATIONS: - compare_autograd_autojac_and_autogram_speed(architecture, batch_size) + for factory, batch_size in PARAMETRIZATIONS: + compare_autograd_autojac_and_autogram_speed(factory, batch_size) print("\n") diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 7007a9c1..00a3855d 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -5,11 +5,10 @@ import pytest import torch from pytest import mark, param -from torch import Tensor, nn +from torch import Tensor from torch.nn import Linear from torch.optim import SGD from torch.testing import assert_close -from unit.conftest import DEVICE from utils.architectures import ( AlexNet, Cifar10Model, @@ -22,6 +21,7 @@ MISOBranched, ModelAlsoUsingSubmoduleParamsDirectly, ModelUsingSubmoduleParamsDirectly, + ModuleFactory, ModuleReuse, MultiInputMultiOutput, MultiInputSingleOutput, @@ -39,7 +39,6 @@ PyTreeInputSingleOutput, Randomness, RequiresGradOfSchrodinger, - ShapedModule, SimpleBranched, SimpleParamReuse, SingleInputPyTreeOutput, @@ -64,6 +63,7 @@ WithSomeFrozenModule, WithTransformer, WithTransformerLarge, + get_in_out_shapes, ) from utils.dict_assertions import assert_tensor_dicts_are_close from utils.forward_backwards import ( @@ -84,54 +84,58 @@ from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian PARAMETRIZATIONS = [ - (OverlyNested, 32), - (MultiInputSingleOutput, 32), - (MultiInputMultiOutput, 32), - (SingleInputPyTreeOutput, 32), - (PyTreeInputSingleOutput, 32), - (PyTreeInputPyTreeOutput, 32), - (SimpleBranched, 32), - (SimpleBranched, SimpleBranched.INPUT_SHAPES[0]), # Edge case: batch_size = input dim - (MIMOBranched, 32), - (MISOBranched, 32), - (SIPOBranched, 32), - (PISOBranched, 32), - (PIPOBranched, 1), - (PIPOBranched, 2), - (PIPOBranched, 32), - (WithNoTensorOutput, 32), - (WithBuffered, 32), - (SimpleParamReuse, 32), - (ModuleReuse, 32), - (SomeUnusedParam, 32), - (SomeFrozenParam, 32), - (MultiOutputWithFrozenBranch, 32), - (WithSomeFrozenModule, 32), - (RequiresGradOfSchrodinger, 32), - (SomeUnusedOutput, 32), - (Ndim0Output, 32), - (Ndim1Output, 32), - (Ndim2Output, 32), - (Ndim3Output, 32), - (Ndim4Output, 32), - (WithDropout, 32), - (WithModuleWithStringArg, 32), - (WithModuleWithHybridPyTreeArg, 32), - (WithModuleWithStringOutput, 32), - (WithModuleWithStringKwarg, 32), - (WithModuleWithHybridPyTreeKwarg, 32), - (WithMultiHeadAttention, 32), - param(WithTransformer, 32, marks=mark.filterwarnings("ignore:There is a performance drop")), - (FreeParam, 32), - (NoFreeParam, 32), - param(Cifar10Model, 16, marks=mark.slow), - param(AlexNet, 2, marks=mark.slow), - param(InstanceNormResNet18, 4, marks=mark.slow), - param(GroupNormMobileNetV3Small, 3, marks=mark.slow), - param(SqueezeNet, 8, marks=mark.slow), - param(InstanceNormMobileNetV2, 2, marks=mark.slow), + (ModuleFactory(OverlyNested), 32), + (ModuleFactory(MultiInputSingleOutput), 32), + (ModuleFactory(MultiInputMultiOutput), 32), + (ModuleFactory(SingleInputPyTreeOutput), 32), + (ModuleFactory(PyTreeInputSingleOutput), 32), + (ModuleFactory(PyTreeInputPyTreeOutput), 32), + (ModuleFactory(SimpleBranched), 32), + (ModuleFactory(SimpleBranched), SimpleBranched.INPUT_SHAPES[0]), # Edge case: bs = input dim + (ModuleFactory(MIMOBranched), 32), + (ModuleFactory(MISOBranched), 32), + 
(ModuleFactory(SIPOBranched), 32), + (ModuleFactory(PISOBranched), 32), + (ModuleFactory(PIPOBranched), 1), + (ModuleFactory(PIPOBranched), 2), + (ModuleFactory(PIPOBranched), 32), + (ModuleFactory(WithNoTensorOutput), 32), + (ModuleFactory(WithBuffered), 32), + (ModuleFactory(SimpleParamReuse), 32), + (ModuleFactory(ModuleReuse), 32), + (ModuleFactory(SomeUnusedParam), 32), + (ModuleFactory(SomeFrozenParam), 32), + (ModuleFactory(MultiOutputWithFrozenBranch), 32), + (ModuleFactory(WithSomeFrozenModule), 32), + (ModuleFactory(RequiresGradOfSchrodinger), 32), + (ModuleFactory(SomeUnusedOutput), 32), + (ModuleFactory(Ndim0Output), 32), + (ModuleFactory(Ndim1Output), 32), + (ModuleFactory(Ndim2Output), 32), + (ModuleFactory(Ndim3Output), 32), + (ModuleFactory(Ndim4Output), 32), + (ModuleFactory(WithDropout), 32), + (ModuleFactory(WithModuleWithStringArg), 32), + (ModuleFactory(WithModuleWithHybridPyTreeArg), 32), + (ModuleFactory(WithModuleWithStringOutput), 32), + (ModuleFactory(WithModuleWithStringKwarg), 32), + (ModuleFactory(WithModuleWithHybridPyTreeKwarg), 32), + (ModuleFactory(WithMultiHeadAttention), 32), param( - WithTransformerLarge, + ModuleFactory(WithTransformer), + 32, + marks=mark.filterwarnings("ignore:There is a performance drop"), + ), + (ModuleFactory(FreeParam), 32), + (ModuleFactory(NoFreeParam), 32), + param(ModuleFactory(Cifar10Model), 16, marks=mark.slow), + param(ModuleFactory(AlexNet), 2, marks=mark.slow), + param(ModuleFactory(InstanceNormResNet18), 4, marks=mark.slow), + param(ModuleFactory(GroupNormMobileNetV3Small), 3, marks=mark.slow), + param(ModuleFactory(SqueezeNet), 8, marks=mark.slow), + param(ModuleFactory(InstanceNormMobileNetV2), 2, marks=mark.slow), + param( + ModuleFactory(WithTransformerLarge), 8, marks=[mark.slow, mark.filterwarnings("ignore:There is a performance drop")], ), @@ -139,15 +143,10 @@ def _assert_gramian_is_equivalent_to_autograd( - architecture: type[ShapedModule], batch_size: int, batch_dim: int | None + factory: ModuleFactory, batch_size: int, batch_dim: int | None ): - input_shapes = architecture.INPUT_SHAPES - output_shapes = architecture.OUTPUT_SHAPES - - torch.manual_seed(0) - model_autograd = architecture().to(device=DEVICE) - torch.manual_seed(0) - model_autogram = architecture().to(device=DEVICE) + model_autograd, model_autogram = factory(), factory() + input_shapes, output_shapes = get_in_out_shapes(model_autograd) engine = Engine(model_autogram, batch_dim=batch_dim) @@ -168,57 +167,57 @@ def _assert_gramian_is_equivalent_to_autograd( assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=3e-5) -@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) +@mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS) @mark.parametrize("batch_dim", [0, None]) -def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batch_dim: int | None): +def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int | None): """Tests that the autograd and the autogram engines compute the same gramian.""" - _assert_gramian_is_equivalent_to_autograd(architecture, batch_size, batch_dim) + _assert_gramian_is_equivalent_to_autograd(factory, batch_size, batch_dim) @mark.parametrize( - "architecture", + "factory", [ - WithBatchNorm, - WithSideEffect, - Randomness, - WithModuleTrackingRunningStats, - param(WithRNN, marks=mark.xfail_if_cuda), + ModuleFactory(WithBatchNorm), + ModuleFactory(WithSideEffect), + ModuleFactory(Randomness), + ModuleFactory(WithModuleTrackingRunningStats), + 
param(ModuleFactory(WithRNN), marks=mark.xfail_if_cuda), ], ) @mark.parametrize("batch_size", [1, 3, 32]) @mark.parametrize("batch_dim", [param(0, marks=mark.xfail), None]) def test_compute_gramian_with_weird_modules( - architecture: type[ShapedModule], batch_size: int, batch_dim: int | None + factory: ModuleFactory, batch_size: int, batch_dim: int | None ): """ Tests that compute_gramian works even with some problematic modules when batch_dim is None. It is expected to fail on those when the engine uses the batched optimization (when batch_dim=0). """ - _assert_gramian_is_equivalent_to_autograd(architecture, batch_size, batch_dim) + _assert_gramian_is_equivalent_to_autograd(factory, batch_size, batch_dim) @mark.xfail @mark.parametrize( - "architecture", + "factory", [ - ModelUsingSubmoduleParamsDirectly, - ModelAlsoUsingSubmoduleParamsDirectly, - InterModuleParamReuse, + ModuleFactory(ModelUsingSubmoduleParamsDirectly), + ModuleFactory(ModelAlsoUsingSubmoduleParamsDirectly), + ModuleFactory(InterModuleParamReuse), ], ) @mark.parametrize("batch_size", [1, 3, 32]) @mark.parametrize("batch_dim", [0, None]) def test_compute_gramian_unsupported_architectures( - architecture: type[ShapedModule], batch_size: int, batch_dim: int | None + factory: ModuleFactory, batch_size: int, batch_dim: int | None ): """ Tests compute_gramian on some architectures that are known to be unsupported. It is expected to fail. """ - _assert_gramian_is_equivalent_to_autograd(architecture, batch_size, batch_dim) + _assert_gramian_is_equivalent_to_autograd(factory, batch_size, batch_dim) @mark.parametrize("batch_size", [1, 3, 16]) @@ -256,14 +255,9 @@ def test_compute_gramian_various_output_shapes( have various different shapes, and can be batched in any of its dimensions. """ - architecture = Ndim2Output - input_shapes = architecture.INPUT_SHAPES - output_shapes = architecture.OUTPUT_SHAPES - - torch.manual_seed(0) - model_autograd = architecture().to(device=DEVICE) - torch.manual_seed(0) - model_autogram = architecture().to(device=DEVICE) + factory = ModuleFactory(Ndim2Output) + model_autograd, model_autogram = factory(), factory() + input_shapes, output_shapes = get_in_out_shapes(model_autograd) engine = Engine(model_autogram, batch_dim=batch_dim) @@ -304,19 +298,15 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int the model parameters is specified. 
""" - architecture = SimpleBranched + factory = ModuleFactory(SimpleBranched) + model = factory() + input_shapes, output_shapes = get_in_out_shapes(model) batch_size = 64 - input_shapes = architecture.INPUT_SHAPES - output_shapes = architecture.OUTPUT_SHAPES - input = make_tensors(batch_size, input_shapes) targets = make_tensors(batch_size, output_shapes) loss_fn = make_mse_loss_fn(targets) - torch.manual_seed(0) - model = architecture().to(device=DEVICE) - output = model(input) losses = reduce_to_vector(loss_fn(output)) @@ -337,22 +327,18 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int assert_close(gramian, autograd_gramian) -@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) +@mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS) @mark.parametrize("batch_dim", [0, None]) -def test_iwrm_steps_with_autogram( - architecture: type[ShapedModule], batch_size: int, batch_dim: int | None -): +def test_iwrm_steps_with_autogram(factory: ModuleFactory, batch_size: int, batch_dim: int | None): """Tests that the autogram engine doesn't raise any error during several IWRM iterations.""" n_iter = 3 - input_shapes = architecture.INPUT_SHAPES - output_shapes = architecture.OUTPUT_SHAPES + model = factory() + input_shapes, output_shapes = get_in_out_shapes(model) weighting = UPGradWeighting() - model = architecture().to(device=DEVICE) - engine = Engine(model, batch_dim=batch_dim) optimizer = SGD(model.parameters(), lr=1e-7) @@ -367,25 +353,23 @@ def test_iwrm_steps_with_autogram( model.zero_grad() -@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) +@mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS) @mark.parametrize("use_engine", [False, True]) @mark.parametrize("batch_dim", [0, None]) def test_autograd_while_modules_are_hooked( - architecture: type[ShapedModule], batch_size: int, use_engine: bool, batch_dim: int | None + factory: ModuleFactory, batch_size: int, use_engine: bool, batch_dim: int | None ): """ Tests that the hooks added when constructing the engine do not interfere with a simple autograd call. 
""" - input = make_tensors(batch_size, architecture.INPUT_SHAPES) - targets = make_tensors(batch_size, architecture.OUTPUT_SHAPES) - loss_fn = make_mse_loss_fn(targets) + model, model_autogram = factory(), factory() + input_shapes, output_shapes = get_in_out_shapes(model) - torch.manual_seed(0) - model = architecture().to(device=DEVICE) - torch.manual_seed(0) - model_autogram = architecture().to(device=DEVICE) + input = make_tensors(batch_size, input_shapes) + targets = make_tensors(batch_size, output_shapes) + loss_fn = make_mse_loss_fn(targets) torch.manual_seed(0) # Fix randomness for random models autograd_forward_backward(model, input, loss_fn) @@ -411,18 +395,17 @@ def test_autograd_while_modules_are_hooked( @mark.parametrize( - ["architecture", "batch_dim"], + ["factory", "batch_dim"], [ - (WithModuleTrackingRunningStats, 0), - (WithRNN, 0), - (WithBatchNorm, 0), + (ModuleFactory(WithModuleTrackingRunningStats), 0), + (ModuleFactory(WithRNN), 0), + (ModuleFactory(WithBatchNorm), 0), ], ) -def test_incompatible_modules(architecture: type[nn.Module], batch_dim: int | None): +def test_incompatible_modules(factory: ModuleFactory, batch_dim: int | None): """Tests that the engine cannot be constructed with incompatible modules.""" - model = architecture().to(device=DEVICE) - + model = factory() with pytest.raises(ValueError): _ = Engine(model, batch_dim=batch_dim) @@ -436,8 +419,8 @@ def test_compute_gramian_manual(): in_dims = 18 out_dims = 25 - torch.manual_seed(0) - model = Linear(in_dims, out_dims).to(device=DEVICE) + factory = ModuleFactory(Linear, in_dims, out_dims) + model = factory() engine = Engine(model, batch_dim=None) input = randn_(in_dims) @@ -482,10 +465,8 @@ def test_reshape_equivariance(shape: list[int]): input_size = shape[0] output_size = prod(shape[1:]) - torch.manual_seed(0) - model1 = Linear(input_size, output_size).to(device=DEVICE) - torch.manual_seed(0) - model2 = Linear(input_size, output_size).to(device=DEVICE) + factory = ModuleFactory(Linear, input_size, output_size) + model1, model2 = factory(), factory() engine1 = Engine(model1, batch_dim=None) engine2 = Engine(model2, batch_dim=None) @@ -524,10 +505,8 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: input_size = shape[0] output_size = prod(shape[1:]) - torch.manual_seed(0) - model1 = Linear(input_size, output_size).to(device=DEVICE) - torch.manual_seed(0) - model2 = Linear(input_size, output_size).to(device=DEVICE) + factory = ModuleFactory(Linear, input_size, output_size) + model1, model2 = factory(), factory() engine1 = Engine(model1, batch_dim=None) engine2 = Engine(model2, batch_dim=None) @@ -569,10 +548,8 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): batch_size = shape[batch_dim] output_size = input_size - torch.manual_seed(0) - model1 = Linear(input_size, output_size).to(device=DEVICE) - torch.manual_seed(0) - model2 = Linear(input_size, output_size).to(device=DEVICE) + factory = ModuleFactory(Linear, input_size, output_size) + model1, model2 = factory(), factory() engine1 = Engine(model1, batch_dim=batch_dim) engine2 = Engine(model2, batch_dim=None) @@ -587,8 +564,8 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): assert_close(gramian1, gramian2) -@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) -def test_batched_non_batched_equivalence_2(architecture: ShapedModule, batch_size: int): +@mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS) +def 
test_batched_non_batched_equivalence_2(factory: ModuleFactory, batch_size: int):
     """
     Same as test_batched_non_batched_equivalence but on real architectures, and thus only between
     batch_dim=0 and batch_dim=None.
@@ -597,13 +574,8 @@ def test_batched_non_batched_equivalence_2(architecture: ShapedModule, batch_siz
     that the get_used_params does not work for some module of the architecture.
     """
 
-    input_shapes = architecture.INPUT_SHAPES
-    output_shapes = architecture.OUTPUT_SHAPES
-
-    torch.manual_seed(0)
-    model_0 = architecture().to(device=DEVICE)
-    torch.manual_seed(0)
-    model_none = architecture().to(device=DEVICE)
+    model_0, model_none = factory(), factory()
+    input_shapes, output_shapes = get_in_out_shapes(model_0)
 
     engine_0 = Engine(model_0, batch_dim=0)
     engine_none = Engine(model_none, batch_dim=None)
diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py
index d5760d03..c31704cd 100644
--- a/tests/utils/architectures.py
+++ b/tests/utils/architectures.py
@@ -5,6 +5,33 @@
 from torch import Tensor, nn
 from torch.nn import Flatten, ReLU
 from torch.utils._pytree import PyTree
+from unit.conftest import DEVICE
+
+
+class ModuleFactory:
+    """
+    Factory that deterministically instantiates a module on DEVICE. Each call re-seeds a forked
+    RNG stream, so successive calls return identically initialized copies of the same model.
+    """
+
+    def __init__(self, architecture: type[nn.Module], *args, **kwargs):
+        self.architecture = architecture
+        self.args = args
+        self.kwargs = kwargs
+
+    def __call__(self) -> nn.Module:
+        # Fork the RNG state (including the CUDA RNG when running on GPU) so that the
+        # seeding below does not leak into the caller's random stream.
+        devices = [DEVICE] if DEVICE.type == "cuda" else []
+        with torch.random.fork_rng(devices=devices, device_type=DEVICE.type):
+            torch.random.manual_seed(0)
+            return self.architecture(*self.args, **self.kwargs).to(device=DEVICE)
+
+    def __str__(self) -> str:
+        args_string = ", ".join([str(arg) for arg in self.args])
+        kwargs_string = ", ".join([f"{key}={value}" for key, value in self.kwargs.items()])
+        optional_comma = "" if args_string == "" or kwargs_string == "" else ", "
+        return f"{self.architecture.__name__}({args_string}{optional_comma}{kwargs_string})"
 
 
 class ShapedModule(nn.Module):
@@ -21,6 +48,13 @@ def __init_subclass__(cls):
             raise TypeError(f"{cls.__name__} must define OUTPUT_SHAPES")
 
 
+def get_in_out_shapes(module: nn.Module) -> tuple[PyTree, PyTree]:
+    if isinstance(module, ShapedModule):
+        return module.INPUT_SHAPES, module.OUTPUT_SHAPES
+    else:
+        raise ValueError(f"Unknown input / output shapes of module {module}")
+
+
 class OverlyNested(ShapedModule):
     """Model that contains many unnecessary levels of nested modules."""
 
diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py
index 1eb07092..5a9e13a9 100644
--- a/tests/utils/forward_backwards.py
+++ b/tests/utils/forward_backwards.py
@@ -4,6 +4,7 @@
 from torch import Tensor, nn, vmap
 from torch.nn.functional import mse_loss
 from torch.utils._pytree import PyTree, tree_flatten, tree_map
+from utils.architectures import get_in_out_shapes
 
 from torchjd.aggregation import Aggregator, Weighting
 from torchjd.autogram import Engine
@@ -58,7 +59,8 @@
 ) -> PyTree:
     output = model(inputs)
 
-    assert tree_map(lambda t: t.shape[1:], output) == model.OUTPUT_SHAPES
+    _, expected_output_shapes = get_in_out_shapes(model)
+    assert tree_map(lambda t: t.shape[1:], output) == expected_output_shapes
 
     loss_tensors = loss_fn(output)
     losses = reduce_to_vector(loss_tensors)
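
For reference, a minimal sketch of how the new ModuleFactory is meant to be used. This snippet is illustrative only, not part of the patch; Linear(18, 25) mirrors the test_compute_gramian_manual case above:

    from torch.nn import Linear

    from utils.architectures import ModuleFactory

    # Positional and keyword constructor arguments are forwarded to the architecture,
    # so plain torch.nn modules can be wrapped without defining a ShapedModule subclass.
    factory = ModuleFactory(Linear, 18, 25)

    # Each call re-seeds a forked RNG stream, so both models below start from identical
    # parameters, without disturbing the surrounding random state.
    model_1, model_2 = factory(), factory()

    # __str__ renders the wrapped constructor call, giving readable pytest parametrization ids.
    print(factory)  # Linear(18, 25)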