From deb38e87291f15923e1c0c8f67b61fcdd304a9c0 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 12 Sep 2025 08:17:00 +0200 Subject: [PATCH 01/20] Implement `_compute_gramian_with_autograd` and add `test_gramian_equivalence_autograd_autogram`. The tests of autogram are not independent from autojac because we still have the full forward_backward phase integration test. --- tests/unit/autogram/test_engine.py | 79 ++++++++++++++++++++++++------ 1 file changed, 63 insertions(+), 16 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index ba4ec06a..558a7a8b 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -3,8 +3,9 @@ import pytest import torch from pytest import mark, param -from torch import nn +from torch import Tensor, nn, vmap from torch.optim import SGD +from torch.testing import assert_close from unit.conftest import DEVICE from utils.architectures import ( AlexNet, @@ -59,8 +60,6 @@ from torchjd.aggregation import UPGrad, UPGradWeighting from torchjd.autogram._engine import Engine -from torchjd.autojac._transform import Diagonalize, Init, Jac, OrderedSet -from torchjd.autojac._transform._aggregate import _Matrixify PARAMETRIZATIONS = [ (OverlyNested, 32), @@ -106,6 +105,64 @@ ] +def _compute_gramian_with_autograd( + output: Tensor, inputs: list[nn.Parameter], retain_graph: bool = False +) -> Tensor: + filtered_inputs = [input for input in inputs if input.requires_grad] + + def get_vjp(grad_outputs: Tensor) -> list[Tensor]: + grads = torch.autograd.grad( + output, + filtered_inputs, + grad_outputs=grad_outputs, + retain_graph=retain_graph, + allow_unused=True, + ) + return [grad for grad in grads if grad is not None] + + jacobians = vmap(get_vjp)(torch.diag(torch.ones_like(output))) + jacobian_matrices = [jacobian.reshape([jacobian.shape[0], -1]) for jacobian in jacobians] + gramian = sum([jacobian @ jacobian.T for jacobian in jacobian_matrices]) + + return gramian + + 
+@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) +def test_gramian_equivalence_autograd_autogram( + architecture: type[ShapedModule], + batch_size: int, +): + """ + Tests that the autograd and the autogram engines compute the same gramian. + """ + + input_shapes = architecture.INPUT_SHAPES + output_shapes = architecture.OUTPUT_SHAPES + + torch.manual_seed(0) + model_autograd = architecture().to(device=DEVICE) + torch.manual_seed(0) + model_autogram = architecture().to(device=DEVICE) + + engine = Engine(model_autogram.modules()) + + inputs = make_tensors(batch_size, input_shapes) + targets = make_tensors(batch_size, output_shapes) + loss_fn = make_mse_loss_fn(targets) + + torch.random.manual_seed(0) # Fix randomness for random aggregators and random models + output = model_autograd(inputs) + losses = loss_fn(output) + autograd_gramian = _compute_gramian_with_autograd(losses, list(model_autograd.parameters())) + + torch.random.manual_seed(0) # Fix randomness for random weightings and random models + output = model_autogram(inputs) + losses = loss_fn(output) + autogram_gramian = engine.compute_gramian(losses) + + assert_close(autogram_gramian, autograd_gramian) + + @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) def test_equivalence_autojac_autogram( architecture: type[ShapedModule], @@ -247,23 +304,13 @@ def test_partial_autogram(gramian_module_names: set[str]): output = model(input) losses = loss_fn(output) - losses_ = OrderedSet(losses) - - init = Init(losses_) - diag = Diagonalize(losses_) gramian_modules = [model.get_submodule(name) for name in gramian_module_names] - gramian_params = OrderedSet({}) + gramian_params = [] for m in gramian_modules: - gramian_params += OrderedSet(m.parameters()) - - jac = Jac(losses_, OrderedSet(gramian_params), None, True) - mat = _Matrixify() - transform = mat << jac << diag << init + gramian_params += list(m.parameters()) - jacobian_matrices = transform({}) - jacobian_matrix = 
torch.cat(list(jacobian_matrices.values()), dim=1) - gramian = jacobian_matrix @ jacobian_matrix.T + gramian = _compute_gramian_with_autograd(losses, gramian_params) torch.manual_seed(0) losses.backward(weighting(gramian)) From d1b28f17cddbbe570aa6eee241e3e8e8dc1f0acf Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 12 Sep 2025 08:46:34 +0200 Subject: [PATCH 02/20] Add `retain_graph=True` in `test_partial_autogram` --- tests/unit/autogram/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 558a7a8b..d4aeffe5 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -310,7 +310,7 @@ def test_partial_autogram(gramian_module_names: set[str]): for m in gramian_modules: gramian_params += list(m.parameters()) - gramian = _compute_gramian_with_autograd(losses, gramian_params) + gramian = _compute_gramian_with_autograd(losses, gramian_params, retain_graph=True) torch.manual_seed(0) losses.backward(weighting(gramian)) From db159215a6e089f27c121c3b98d0ee6e83174615 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 12 Sep 2025 10:29:55 +0200 Subject: [PATCH 03/20] Move `compute_gramian_with_autograd` to utils --- tests/unit/autogram/test_engine.py | 29 ++++--------------------- tests/utils/autograd_compute_gramian.py | 24 ++++++++++++++++++++ 2 files changed, 28 insertions(+), 25 deletions(-) create mode 100644 tests/utils/autograd_compute_gramian.py diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index d4aeffe5..db98d3ad 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -3,7 +3,7 @@ import pytest import torch from pytest import mark, param -from torch import Tensor, nn, vmap +from torch import nn from torch.optim import SGD from torch.testing import assert_close from unit.conftest import DEVICE @@ -49,6 +49,7 @@ WithSideEffect, 
WithSomeFrozenModule, ) +from utils.autograd_compute_gramian import compute_gramian_with_autograd from utils.dict_assertions import assert_tensor_dicts_are_close from utils.forward_backwards import ( autograd_forward_backward, @@ -105,28 +106,6 @@ ] -def _compute_gramian_with_autograd( - output: Tensor, inputs: list[nn.Parameter], retain_graph: bool = False -) -> Tensor: - filtered_inputs = [input for input in inputs if input.requires_grad] - - def get_vjp(grad_outputs: Tensor) -> list[Tensor]: - grads = torch.autograd.grad( - output, - filtered_inputs, - grad_outputs=grad_outputs, - retain_graph=retain_graph, - allow_unused=True, - ) - return [grad for grad in grads if grad is not None] - - jacobians = vmap(get_vjp)(torch.diag(torch.ones_like(output))) - jacobian_matrices = [jacobian.reshape([jacobian.shape[0], -1]) for jacobian in jacobians] - gramian = sum([jacobian @ jacobian.T for jacobian in jacobian_matrices]) - - return gramian - - @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) def test_gramian_equivalence_autograd_autogram( architecture: type[ShapedModule], @@ -153,7 +132,7 @@ def test_gramian_equivalence_autograd_autogram( torch.random.manual_seed(0) # Fix randomness for random aggregators and random models output = model_autograd(inputs) losses = loss_fn(output) - autograd_gramian = _compute_gramian_with_autograd(losses, list(model_autograd.parameters())) + autograd_gramian = compute_gramian_with_autograd(losses, list(model_autograd.parameters())) torch.random.manual_seed(0) # Fix randomness for random weightings and random models output = model_autogram(inputs) @@ -310,7 +289,7 @@ def test_partial_autogram(gramian_module_names: set[str]): for m in gramian_modules: gramian_params += list(m.parameters()) - gramian = _compute_gramian_with_autograd(losses, gramian_params, retain_graph=True) + gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True) torch.manual_seed(0) losses.backward(weighting(gramian)) diff 
--git a/tests/utils/autograd_compute_gramian.py b/tests/utils/autograd_compute_gramian.py new file mode 100644 index 00000000..3ae04e6b --- /dev/null +++ b/tests/utils/autograd_compute_gramian.py @@ -0,0 +1,24 @@ +import torch +from torch import Tensor, nn, vmap + + +def compute_gramian_with_autograd( + output: Tensor, inputs: list[nn.Parameter], retain_graph: bool = False +) -> Tensor: + filtered_inputs = [input for input in inputs if input.requires_grad] + + def get_vjp(grad_outputs: Tensor) -> list[Tensor]: + grads = torch.autograd.grad( + output, + filtered_inputs, + grad_outputs=grad_outputs, + retain_graph=retain_graph, + allow_unused=True, + ) + return [grad for grad in grads if grad is not None] + + jacobians = vmap(get_vjp)(torch.diag(torch.ones_like(output))) + jacobian_matrices = [jacobian.reshape([jacobian.shape[0], -1]) for jacobian in jacobians] + gramian = sum([jacobian @ jacobian.T for jacobian in jacobian_matrices]) + + return gramian From 390e747b3da87a6f4439483f5ccdc8b14dcee644 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 12 Sep 2025 10:41:20 +0200 Subject: [PATCH 04/20] Remove usage of `autojac` in tests of `autogram` --- tests/speed/autogram/grad_vs_jac_vs_gram.py | 4 +-- tests/unit/autogram/test_engine.py | 34 ++++++++++----------- tests/utils/forward_backwards.py | 12 +++++--- 3 files changed, 26 insertions(+), 24 deletions(-) diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index 881d0321..a16a8eab 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -15,8 +15,8 @@ ) from utils.forward_backwards import ( autograd_forward_backward, + autograd_gramian_forward_backward, autogram_forward_backward, - autojac_forward_backward, make_mse_loss_fn, ) from utils.tensors import make_tensors @@ -61,7 +61,7 @@ def init_fn_autograd(): fn_autograd() def fn_autojac(): - autojac_forward_backward(model, inputs, loss_fn, A) + 
autograd_gramian_forward_backward(model, inputs, loss_fn, A) def init_fn_autojac(): torch.cuda.empty_cache() diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index db98d3ad..49f2695e 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -53,13 +53,13 @@ from utils.dict_assertions import assert_tensor_dicts_are_close from utils.forward_backwards import ( autograd_forward_backward, + autograd_gramian_forward_backward, autogram_forward_backward, - autojac_forward_backward, make_mse_loss_fn, ) from utils.tensors import make_tensors -from torchjd.aggregation import UPGrad, UPGradWeighting +from torchjd.aggregation import UPGradWeighting from torchjd.autogram._engine import Engine PARAMETRIZATIONS = [ @@ -143,12 +143,12 @@ def test_gramian_equivalence_autograd_autogram( @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) -def test_equivalence_autojac_autogram( +def test_equivalence_autograd_autogram( architecture: type[ShapedModule], batch_size: int, ): """ - Tests that the autogram engine gives the same results as the autojac engine on IWRM for several + Tests that the autogram engine gives the same results as the autograd engine on IWRM for several JD steps. 
""" @@ -158,15 +158,14 @@ def test_equivalence_autojac_autogram( output_shapes = architecture.OUTPUT_SHAPES weighting = UPGradWeighting() - aggregator = UPGrad() torch.manual_seed(0) - model_autojac = architecture().to(device=DEVICE) + model_autograd = architecture().to(device=DEVICE) torch.manual_seed(0) model_autogram = architecture().to(device=DEVICE) engine = Engine(model_autogram.modules()) - optimizer_autojac = SGD(model_autojac.parameters(), lr=1e-7) + optimizer_autograd = SGD(model_autograd.parameters(), lr=1e-7) optimizer_autogram = SGD(model_autogram.parameters(), lr=1e-7) for i in range(n_iter): @@ -175,9 +174,11 @@ def test_equivalence_autojac_autogram( loss_fn = make_mse_loss_fn(targets) torch.random.manual_seed(0) # Fix randomness for random aggregators and random models - autojac_forward_backward(model_autojac, inputs, loss_fn, aggregator) + autograd_gramian_forward_backward( + model_autograd, inputs, list(model_autograd.parameters()), loss_fn, weighting + ) expected_grads = { - name: p.grad for name, p in model_autojac.named_parameters() if p.grad is not None + name: p.grad for name, p in model_autograd.named_parameters() if p.grad is not None } torch.random.manual_seed(0) # Fix randomness for random weightings and random models @@ -188,8 +189,8 @@ def test_equivalence_autojac_autogram( assert_tensor_dicts_are_close(grads, expected_grads) - optimizer_autojac.step() - model_autojac.zero_grad() + optimizer_autograd.step() + model_autograd.zero_grad() optimizer_autogram.step() model_autogram.zero_grad() @@ -206,7 +207,6 @@ def test_autograd_while_modules_are_hooked(architecture: type[ShapedModule], bat output_shapes = architecture.OUTPUT_SHAPES W = UPGradWeighting() - A = UPGrad() input = make_tensors(batch_size, input_shapes) targets = make_tensors(batch_size, output_shapes) loss_fn = make_mse_loss_fn(targets) @@ -215,8 +215,8 @@ def test_autograd_while_modules_are_hooked(architecture: type[ShapedModule], bat model = architecture().to(device=DEVICE) 
torch.manual_seed(0) # Fix randomness for random models - autojac_forward_backward(model, input, loss_fn, A) - autojac_grads = { + autograd_gramian_forward_backward(model, input, list(model.parameters()), loss_fn, W) + autograd_gramian_grads = { name: p.grad.clone() for name, p in model.named_parameters() if p.grad is not None } model.zero_grad() @@ -230,12 +230,12 @@ def test_autograd_while_modules_are_hooked(architecture: type[ShapedModule], bat torch.manual_seed(0) model_autogram = architecture().to(device=DEVICE) - # Hook modules and verify that we're equivalent to autojac when using the engine + # Hook modules and verify that we're equivalent to autograd when using the engine engine = Engine(model_autogram.modules()) torch.manual_seed(0) # Fix randomness for random models autogram_forward_backward(model_autogram, engine, W, input, loss_fn) grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None} - assert_tensor_dicts_are_close(grads, autojac_grads) + assert_tensor_dicts_are_close(grads, autograd_gramian_grads) model_autogram.zero_grad() # Verify that even with the hooked modules, autograd works normally when not using the engine. @@ -260,7 +260,7 @@ def _non_empty_subsets(elements: set) -> list[set]: def test_partial_autogram(gramian_module_names: set[str]): """ Tests that partial JD via the autogram engine works similarly as if the gramian was computed via - the autojac engine. + the autograd engine. Note that this test is a bit redundant now that we have the Engine interface, because it now just compares two ways of computing the Gramian, which is independant of the idea of partial JD. 
diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index 53ad3753..7540c841 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -4,10 +4,10 @@ from torch import Tensor, nn from torch.nn.functional import mse_loss from torch.utils._pytree import PyTree, tree_flatten, tree_map +from utils.autograd_compute_gramian import compute_gramian_with_autograd -from torchjd.aggregation import Aggregator, Weighting +from torchjd.aggregation import Weighting from torchjd.autogram import Engine -from torchjd.autojac import backward def autograd_forward_backward( @@ -19,14 +19,16 @@ def autograd_forward_backward( losses.sum().backward() -def autojac_forward_backward( +def autograd_gramian_forward_backward( model: nn.Module, inputs: PyTree, + params: list[nn.Parameter], loss_fn: Callable[[PyTree], Tensor], - aggregator: Aggregator, + weighting: Weighting, ) -> None: losses = _forward_pass(model, inputs, loss_fn) - backward(losses, aggregator=aggregator) + gramian = compute_gramian_with_autograd(losses, params, retain_graph=True) + losses.backward(weighting(gramian)) def autogram_forward_backward( From 3fff2a5c4822511218342b56ca1dbe607c80df55 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 12 Sep 2025 10:56:06 +0200 Subject: [PATCH 05/20] Change `test_equivalence_autograd_autogram` to `test_IWRM_steps_with_autogram` --- tests/unit/autogram/test_engine.py | 26 ++------------------------ 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 49f2695e..2cd26127 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -143,13 +143,12 @@ def test_gramian_equivalence_autograd_autogram( @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) -def test_equivalence_autograd_autogram( +def test_IWRM_steps_with_autogram( architecture: type[ShapedModule], batch_size: int, ): """ - Tests that 
the autogram engine gives the same results as the autograd engine on IWRM for several - JD steps. + Tests that the autogram engine doesn't yield any error during several IWRM iterations. """ n_iter = 3 @@ -159,13 +158,9 @@ def test_equivalence_autograd_autogram( weighting = UPGradWeighting() - torch.manual_seed(0) - model_autograd = architecture().to(device=DEVICE) - torch.manual_seed(0) model_autogram = architecture().to(device=DEVICE) engine = Engine(model_autogram.modules()) - optimizer_autograd = SGD(model_autograd.parameters(), lr=1e-7) optimizer_autogram = SGD(model_autogram.parameters(), lr=1e-7) for i in range(n_iter): @@ -173,24 +168,7 @@ def test_equivalence_autograd_autogram( targets = make_tensors(batch_size, output_shapes) loss_fn = make_mse_loss_fn(targets) - torch.random.manual_seed(0) # Fix randomness for random aggregators and random models - autograd_gramian_forward_backward( - model_autograd, inputs, list(model_autograd.parameters()), loss_fn, weighting - ) - expected_grads = { - name: p.grad for name, p in model_autograd.named_parameters() if p.grad is not None - } - - torch.random.manual_seed(0) # Fix randomness for random weightings and random models autogram_forward_backward(model_autogram, engine, weighting, inputs, loss_fn) - grads = { - name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None - } - - assert_tensor_dicts_are_close(grads, expected_grads) - - optimizer_autograd.step() - model_autograd.zero_grad() optimizer_autogram.step() model_autogram.zero_grad() From 1c5f32ec3ded4f623834ec1b434bda9cb83146ef Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 12 Sep 2025 16:48:22 +0200 Subject: [PATCH 06/20] Remove computation of gradients in `test_partial_autogram` --- tests/unit/autogram/test_engine.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 2cd26127..08a64a65 100644 --- 
a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -247,8 +247,6 @@ def test_partial_autogram(gramian_module_names: set[str]): architecture = SimpleBranched batch_size = 64 - weighting = UPGradWeighting() - input_shapes = architecture.INPUT_SHAPES output_shapes = architecture.OUTPUT_SHAPES @@ -267,23 +265,16 @@ def test_partial_autogram(gramian_module_names: set[str]): for m in gramian_modules: gramian_params += list(m.parameters()) - gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True) + autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True) torch.manual_seed(0) - losses.backward(weighting(gramian)) - - expected_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None} - model.zero_grad() engine = Engine(gramian_modules) output = model(input) losses = loss_fn(output) gramian = engine.compute_gramian(losses) - torch.manual_seed(0) - losses.backward(weighting(gramian)) - grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None} - assert_tensor_dicts_are_close(grads, expected_grads) + assert_close(gramian, autograd_gramian) @mark.parametrize("architecture", [WithRNN, WithModuleTrackingRunningStats]) From 3165ea08f56085c7591beaa52c5217fda6588aaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 13 Sep 2025 19:14:06 +0200 Subject: [PATCH 07/20] Fix function name --- tests/unit/autogram/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 08a64a65..076e8738 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -143,7 +143,7 @@ def test_gramian_equivalence_autograd_autogram( @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) -def test_IWRM_steps_with_autogram( +def test_iwrm_steps_with_autogram( architecture: type[ShapedModule], batch_size: 
int, ): From 6dd2ec03426c4e148b1777ce01fc1b05e465ff00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 13 Sep 2025 19:14:51 +0200 Subject: [PATCH 08/20] Fix variable names --- tests/unit/autogram/test_engine.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 076e8738..8aa09977 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -158,20 +158,20 @@ def test_iwrm_steps_with_autogram( weighting = UPGradWeighting() - model_autogram = architecture().to(device=DEVICE) + model = architecture().to(device=DEVICE) - engine = Engine(model_autogram.modules()) - optimizer_autogram = SGD(model_autogram.parameters(), lr=1e-7) + engine = Engine(model.modules()) + optimizer = SGD(model.parameters(), lr=1e-7) for i in range(n_iter): inputs = make_tensors(batch_size, input_shapes) targets = make_tensors(batch_size, output_shapes) loss_fn = make_mse_loss_fn(targets) - autogram_forward_backward(model_autogram, engine, weighting, inputs, loss_fn) + autogram_forward_backward(model, engine, weighting, inputs, loss_fn) - optimizer_autogram.step() - model_autogram.zero_grad() + optimizer.step() + model.zero_grad() @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) From 5c132be9744b56aaa25d0abad22b8aae547c8a0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 13 Sep 2025 19:15:47 +0200 Subject: [PATCH 09/20] Add docstring to compute_gramian_with_autograd --- tests/utils/autograd_compute_gramian.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/utils/autograd_compute_gramian.py b/tests/utils/autograd_compute_gramian.py index 3ae04e6b..c71d330f 100644 --- a/tests/utils/autograd_compute_gramian.py +++ b/tests/utils/autograd_compute_gramian.py @@ -5,6 +5,11 @@ def compute_gramian_with_autograd( output: Tensor, inputs: list[nn.Parameter], retain_graph: bool = False ) -> 
Tensor: + """ + Computes the Gramian of the Jacobian of the outputs with respect to the inputs using vmapped + calls to the autograd engine. + """ + filtered_inputs = [input for input in inputs if input.requires_grad] def get_vjp(grad_outputs: Tensor) -> list[Tensor]: From 7d22521b0fb9557444d131e559a3605cfcf13e1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 13 Sep 2025 19:16:30 +0200 Subject: [PATCH 10/20] Avoid using sum() --- tests/utils/autograd_compute_gramian.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/autograd_compute_gramian.py b/tests/utils/autograd_compute_gramian.py index c71d330f..3c624016 100644 --- a/tests/utils/autograd_compute_gramian.py +++ b/tests/utils/autograd_compute_gramian.py @@ -24,6 +24,6 @@ def get_vjp(grad_outputs: Tensor) -> list[Tensor]: jacobians = vmap(get_vjp)(torch.diag(torch.ones_like(output))) jacobian_matrices = [jacobian.reshape([jacobian.shape[0], -1]) for jacobian in jacobians] - gramian = sum([jacobian @ jacobian.T for jacobian in jacobian_matrices]) + gramian = torch.sum(torch.stack([jacobian @ jacobian.T for jacobian in jacobian_matrices])) return gramian From 7c0d4b34d0b5ebe4964e64eeb954762c9291aaf6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 13 Sep 2025 19:19:38 +0200 Subject: [PATCH 11/20] Revert removal of autojac_forward_backward and change in compare_autograd_autojac_and_autogram_speed --- tests/speed/autogram/grad_vs_jac_vs_gram.py | 4 ++-- tests/utils/forward_backwards.py | 13 ++++++++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index a16a8eab..881d0321 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -15,8 +15,8 @@ ) from utils.forward_backwards import ( autograd_forward_backward, - autograd_gramian_forward_backward, autogram_forward_backward, + 
autojac_forward_backward, make_mse_loss_fn, ) from utils.tensors import make_tensors @@ -61,7 +61,7 @@ def init_fn_autograd(): fn_autograd() def fn_autojac(): - autograd_gramian_forward_backward(model, inputs, loss_fn, A) + autojac_forward_backward(model, inputs, loss_fn, A) def init_fn_autojac(): torch.cuda.empty_cache() diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index 7540c841..f5655267 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -6,8 +6,9 @@ from torch.utils._pytree import PyTree, tree_flatten, tree_map from utils.autograd_compute_gramian import compute_gramian_with_autograd -from torchjd.aggregation import Weighting +from torchjd.aggregation import Aggregator, Weighting from torchjd.autogram import Engine +from torchjd.autojac import backward def autograd_forward_backward( @@ -19,6 +20,16 @@ def autograd_forward_backward( losses.sum().backward() +def autojac_forward_backward( + model: nn.Module, + inputs: PyTree, + loss_fn: Callable[[PyTree], Tensor], + aggregator: Aggregator, +) -> None: + losses = _forward_pass(model, inputs, loss_fn) + backward(losses, aggregator=aggregator) + + def autograd_gramian_forward_backward( model: nn.Module, inputs: PyTree, From e81732a2a3609535551b35b63c8dcbd9824054a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 13 Sep 2025 19:22:00 +0200 Subject: [PATCH 12/20] Move compute_gramian_with_autograd to forward_backwards It's a form of backward pass so I think it's fine to have it here + it's one fewer file. 
--- tests/unit/autogram/test_engine.py | 2 +- tests/utils/autograd_compute_gramian.py | 29 ------------------------ tests/utils/forward_backwards.py | 30 +++++++++++++++++++++++-- 3 files changed, 29 insertions(+), 32 deletions(-) delete mode 100644 tests/utils/autograd_compute_gramian.py diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 8aa09977..6e62e202 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -49,12 +49,12 @@ WithSideEffect, WithSomeFrozenModule, ) -from utils.autograd_compute_gramian import compute_gramian_with_autograd from utils.dict_assertions import assert_tensor_dicts_are_close from utils.forward_backwards import ( autograd_forward_backward, autograd_gramian_forward_backward, autogram_forward_backward, + compute_gramian_with_autograd, make_mse_loss_fn, ) from utils.tensors import make_tensors diff --git a/tests/utils/autograd_compute_gramian.py b/tests/utils/autograd_compute_gramian.py deleted file mode 100644 index 3c624016..00000000 --- a/tests/utils/autograd_compute_gramian.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch -from torch import Tensor, nn, vmap - - -def compute_gramian_with_autograd( - output: Tensor, inputs: list[nn.Parameter], retain_graph: bool = False -) -> Tensor: - """ - Computes the Gramian of the Jacobian of the outputs with respect to the inputs using vmapped - calls to the autograd engine. 
- """ - - filtered_inputs = [input for input in inputs if input.requires_grad] - - def get_vjp(grad_outputs: Tensor) -> list[Tensor]: - grads = torch.autograd.grad( - output, - filtered_inputs, - grad_outputs=grad_outputs, - retain_graph=retain_graph, - allow_unused=True, - ) - return [grad for grad in grads if grad is not None] - - jacobians = vmap(get_vjp)(torch.diag(torch.ones_like(output))) - jacobian_matrices = [jacobian.reshape([jacobian.shape[0], -1]) for jacobian in jacobians] - gramian = torch.sum(torch.stack([jacobian @ jacobian.T for jacobian in jacobian_matrices])) - - return gramian diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index f5655267..29f6c66e 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -1,10 +1,9 @@ from typing import Callable import torch -from torch import Tensor, nn +from torch import Tensor, nn, vmap from torch.nn.functional import mse_loss from torch.utils._pytree import PyTree, tree_flatten, tree_map -from utils.autograd_compute_gramian import compute_gramian_with_autograd from torchjd.aggregation import Aggregator, Weighting from torchjd.autogram import Engine @@ -93,3 +92,30 @@ def reshape_raw_losses(raw_losses: Tensor) -> Tensor: return raw_losses.unsqueeze(1) else: return raw_losses.flatten(start_dim=1) + + +def compute_gramian_with_autograd( + output: Tensor, inputs: list[nn.Parameter], retain_graph: bool = False +) -> Tensor: + """ + Computes the Gramian of the Jacobian of the outputs with respect to the inputs using vmapped + calls to the autograd engine. 
+ """ + + filtered_inputs = [input for input in inputs if input.requires_grad] + + def get_vjp(grad_outputs: Tensor) -> list[Tensor]: + grads = torch.autograd.grad( + output, + filtered_inputs, + grad_outputs=grad_outputs, + retain_graph=retain_graph, + allow_unused=True, + ) + return [grad for grad in grads if grad is not None] + + jacobians = vmap(get_vjp)(torch.diag(torch.ones_like(output))) + jacobian_matrices = [jacobian.reshape([jacobian.shape[0], -1]) for jacobian in jacobians] + gramian = torch.sum(torch.stack([jacobian @ jacobian.T for jacobian in jacobian_matrices])) + + return gramian From eefaff38724d9ca24569692c1cef0594c40b1e5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 13 Sep 2025 19:28:09 +0200 Subject: [PATCH 13/20] Revert breaking everything with sum change --- tests/utils/forward_backwards.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index 29f6c66e..28611409 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -116,6 +116,6 @@ def get_vjp(grad_outputs: Tensor) -> list[Tensor]: jacobians = vmap(get_vjp)(torch.diag(torch.ones_like(output))) jacobian_matrices = [jacobian.reshape([jacobian.shape[0], -1]) for jacobian in jacobians] - gramian = torch.sum(torch.stack([jacobian @ jacobian.T for jacobian in jacobian_matrices])) + gramian = sum([jacobian @ jacobian.T for jacobian in jacobian_matrices]) return gramian From b22dcb0e1dc82b344a40d84377ffeebacbf19600 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 13 Sep 2025 19:35:22 +0200 Subject: [PATCH 14/20] Add autograd_gramian speed test in grad_vs_jac_vs_gram Also reduce batch sizes, otherwise too high cuda memory usage. GPU has lost his will to live after so much time spent rendering graphics for world of warcraft I guess. 
--- tests/speed/autogram/grad_vs_jac_vs_gram.py | 27 +++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index 881d0321..f3e974f0 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -1,3 +1,4 @@ +import gc import time import torch @@ -15,6 +16,7 @@ ) from utils.forward_backwards import ( autograd_forward_backward, + autograd_gramian_forward_backward, autogram_forward_backward, autojac_forward_backward, make_mse_loss_fn, @@ -31,8 +33,8 @@ (AlexNet, 8), (InstanceNormResNet18, 16), (GroupNormMobileNetV3Small, 16), - (SqueezeNet, 16), - (InstanceNormMobileNetV2, 8), + (SqueezeNet, 4), + (InstanceNormMobileNetV2, 2), ] @@ -58,13 +60,23 @@ def fn_autograd(): def init_fn_autograd(): torch.cuda.empty_cache() + gc.collect() fn_autograd() + def fn_autograd_gramian(): + autograd_gramian_forward_backward(model, inputs, list(model.parameters()), loss_fn, W) + + def init_fn_autograd_gramian(): + torch.cuda.empty_cache() + gc.collect() + fn_autograd_gramian() + def fn_autojac(): autojac_forward_backward(model, inputs, loss_fn, A) def init_fn_autojac(): torch.cuda.empty_cache() + gc.collect() fn_autojac() def fn_autogram(): @@ -72,6 +84,7 @@ def fn_autogram(): def init_fn_autogram(): torch.cuda.empty_cache() + gc.collect() fn_autogram() def optionally_cuda_sync(): @@ -91,6 +104,16 @@ def post_fn(): print(autograd_times) print() + autograd_gramian_times = torch.tensor( + time_call(fn_autograd_gramian, init_fn_autograd_gramian, pre_fn, post_fn, n_runs) + ) + print( + f"autograd gramian times (avg = {autograd_gramian_times.mean():.5f}, std = " + f"{autograd_gramian_times.std():.5f}" + ) + print(autograd_gramian_times) + print() + autojac_times = torch.tensor(time_call(fn_autojac, init_fn_autojac, pre_fn, post_fn, n_runs)) print(f"autojac times (avg = {autojac_times.mean():.5f}, std = 
{autojac_times.std():.5f}") print(autojac_times) From eace7ad3b8c9f353d8537918ea6d775e6d48fc0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 15 Sep 2025 14:42:31 +0200 Subject: [PATCH 15/20] Fix outdated comment --- tests/unit/autogram/test_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 6e62e202..1373a4a3 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -129,12 +129,12 @@ def test_gramian_equivalence_autograd_autogram( targets = make_tensors(batch_size, output_shapes) loss_fn = make_mse_loss_fn(targets) - torch.random.manual_seed(0) # Fix randomness for random aggregators and random models + torch.random.manual_seed(0) # Fix randomness for random models output = model_autograd(inputs) losses = loss_fn(output) autograd_gramian = compute_gramian_with_autograd(losses, list(model_autograd.parameters())) - torch.random.manual_seed(0) # Fix randomness for random weightings and random models + torch.random.manual_seed(0) # Fix randomness for random models output = model_autogram(inputs) losses = loss_fn(output) autogram_gramian = engine.compute_gramian(losses) From 451bd3eee4b3959c811fbbb6c15522e83213066e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 15 Sep 2025 14:49:56 +0200 Subject: [PATCH 16/20] Fix tolerance of test_gramian_equivalence_autograd_autogram This should make the test pass even on MacOS --- tests/unit/autogram/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 1373a4a3..8ae01ace 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -139,7 +139,7 @@ def test_gramian_equivalence_autograd_autogram( losses = loss_fn(output) autogram_gramian = engine.compute_gramian(losses) - assert_close(autogram_gramian, 
autograd_gramian) + assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=1e-5) @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) From fc9a1c9035bff3459e80e1294f4992144afb29fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 15 Sep 2025 14:52:00 +0200 Subject: [PATCH 17/20] Remove outdated docstring --- tests/unit/autogram/test_engine.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 8ae01ace..271bde5b 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -239,9 +239,6 @@ def test_partial_autogram(gramian_module_names: set[str]): """ Tests that partial JD via the autogram engine works similarly as if the gramian was computed via the autograd engine. - - Note that this test is a bit redundant now that we have the Engine interface, because it now - just compares two ways of computing the Gramian, which is independant of the idea of partial JD. """ architecture = SimpleBranched From 33cf529931d91f689a058091748ee43e0145f29a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 15 Sep 2025 14:54:07 +0200 Subject: [PATCH 18/20] Fix docstring --- tests/unit/autogram/test_engine.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 271bde5b..01b6764d 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -147,9 +147,7 @@ def test_iwrm_steps_with_autogram( architecture: type[ShapedModule], batch_size: int, ): - """ - Tests that the autogram engine doesn't yield any error during several IWRM iterations. 
- """ + """Tests that the autogram engine doesn't raise any error during several IWRM iterations.""" n_iter = 3 From 2057e055037ff72a8099403e56b507fa2bdc92ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 15 Sep 2025 15:08:02 +0200 Subject: [PATCH 19/20] Simplify test_autograd_while_modules_are_hooked --- tests/unit/autogram/test_engine.py | 46 +++++++++++------------------- 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 01b6764d..2374b776 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -52,7 +52,6 @@ from utils.dict_assertions import assert_tensor_dicts_are_close from utils.forward_backwards import ( autograd_forward_backward, - autograd_gramian_forward_backward, autogram_forward_backward, compute_gramian_with_autograd, make_mse_loss_fn, @@ -173,56 +172,45 @@ def test_iwrm_steps_with_autogram( @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) -def test_autograd_while_modules_are_hooked(architecture: type[ShapedModule], batch_size: int): +@mark.parametrize("compute_gramian", [False, True]) +def test_autograd_while_modules_are_hooked( + architecture: type[ShapedModule], batch_size: int, compute_gramian: bool +): """ Tests that the hooks added when constructing the engine do not interfere with a simple autograd call. 
""" - input_shapes = architecture.INPUT_SHAPES - output_shapes = architecture.OUTPUT_SHAPES - - W = UPGradWeighting() - input = make_tensors(batch_size, input_shapes) - targets = make_tensors(batch_size, output_shapes) + input = make_tensors(batch_size, architecture.INPUT_SHAPES) + targets = make_tensors(batch_size, architecture.OUTPUT_SHAPES) loss_fn = make_mse_loss_fn(targets) torch.manual_seed(0) model = architecture().to(device=DEVICE) - - torch.manual_seed(0) # Fix randomness for random models - autograd_gramian_forward_backward(model, input, list(model.parameters()), loss_fn, W) - autograd_gramian_grads = { - name: p.grad.clone() for name, p in model.named_parameters() if p.grad is not None - } - model.zero_grad() + torch.manual_seed(0) + model_autogram = architecture().to(device=DEVICE) torch.manual_seed(0) # Fix randomness for random models autograd_forward_backward(model, input, loss_fn) - autograd_grads = { - name: p.grad.clone() for name, p in model.named_parameters() if p.grad is not None - } + autograd_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None} - torch.manual_seed(0) - model_autogram = architecture().to(device=DEVICE) - - # Hook modules and verify that we're equivalent to autograd when using the engine + # Hook modules and optionally compute the Gramian engine = Engine(model_autogram.modules()) - torch.manual_seed(0) # Fix randomness for random models - autogram_forward_backward(model_autogram, engine, W, input, loss_fn) - grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None} - assert_tensor_dicts_are_close(grads, autograd_gramian_grads) - model_autogram.zero_grad() + if compute_gramian: + torch.manual_seed(0) # Fix randomness for random models + output = model_autogram(input) + losses = loss_fn(output) + _ = engine.compute_gramian(losses) # Verify that even with the hooked modules, autograd works normally when not using the engine. 
# Results should be the same as a normal call to autograd, and no time should be spent computing # the gramian at all. torch.manual_seed(0) # Fix randomness for random models autograd_forward_backward(model_autogram, input, loss_fn) - assert engine._gramian_accumulator.gramian is None grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None} + assert_tensor_dicts_are_close(grads, autograd_grads) - model_autogram.zero_grad() + assert engine._gramian_accumulator.gramian is None def _non_empty_subsets(elements: set) -> list[set]: From e5698f9474be0a7bc8d89bcd8d5cd37c3d909e69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 15 Sep 2025 15:14:48 +0200 Subject: [PATCH 20/20] Restructure file * Move / rename functions * Improve docstrings * Reformat --- tests/unit/autogram/test_engine.py | 108 +++++++++++++---------------- 1 file changed, 50 insertions(+), 58 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 2374b776..7ccfc3e8 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -106,13 +106,8 @@ @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) -def test_gramian_equivalence_autograd_autogram( - architecture: type[ShapedModule], - batch_size: int, -): - """ - Tests that the autograd and the autogram engines compute the same gramian. - """ +def test_compute_gramian(architecture: type[ShapedModule], batch_size: int): + """Tests that the autograd and the autogram engines compute the same gramian.""" input_shapes = architecture.INPUT_SHAPES output_shapes = architecture.OUTPUT_SHAPES @@ -141,11 +136,55 @@ def test_gramian_equivalence_autograd_autogram( assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=1e-5) +def _non_empty_subsets(elements: set) -> list[set]: + """ + Generates the list of subsets of the given set, excluding the empty set. 
+ """ + return [set(c) for r in range(1, len(elements) + 1) for c in combinations(elements, r)] + + +@mark.parametrize("gramian_module_names", _non_empty_subsets({"fc0", "fc1", "fc2", "fc3", "fc4"})) +def test_compute_partial_gramian(gramian_module_names: set[str]): + """ + Tests that the autograd and the autogram engines compute the same gramian when only a subset of + the model parameters is specified. + """ + + architecture = SimpleBranched + batch_size = 64 + + input_shapes = architecture.INPUT_SHAPES + output_shapes = architecture.OUTPUT_SHAPES + + input = make_tensors(batch_size, input_shapes) + targets = make_tensors(batch_size, output_shapes) + loss_fn = make_mse_loss_fn(targets) + + torch.manual_seed(0) + model = architecture().to(device=DEVICE) + + output = model(input) + losses = loss_fn(output) + + gramian_modules = [model.get_submodule(name) for name in gramian_module_names] + gramian_params = [] + for m in gramian_modules: + gramian_params += list(m.parameters()) + + autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True) + torch.manual_seed(0) + + engine = Engine(gramian_modules) + + output = model(input) + losses = loss_fn(output) + gramian = engine.compute_gramian(losses) + + assert_close(gramian, autograd_gramian) + + @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) -def test_iwrm_steps_with_autogram( - architecture: type[ShapedModule], - batch_size: int, -): +def test_iwrm_steps_with_autogram(architecture: type[ShapedModule], batch_size: int): """Tests that the autogram engine doesn't raise any error during several IWRM iterations.""" n_iter = 3 @@ -213,53 +252,6 @@ def test_autograd_while_modules_are_hooked( assert engine._gramian_accumulator.gramian is None -def _non_empty_subsets(elements: set) -> list[set]: - """ - Generates the list of subsets of the given set, excluding the empty set. 
- """ - return [set(c) for r in range(1, len(elements) + 1) for c in combinations(elements, r)] - - -@mark.parametrize("gramian_module_names", _non_empty_subsets({"fc0", "fc1", "fc2", "fc3", "fc4"})) -def test_partial_autogram(gramian_module_names: set[str]): - """ - Tests that partial JD via the autogram engine works similarly as if the gramian was computed via - the autograd engine. - """ - - architecture = SimpleBranched - batch_size = 64 - - input_shapes = architecture.INPUT_SHAPES - output_shapes = architecture.OUTPUT_SHAPES - - input = make_tensors(batch_size, input_shapes) - targets = make_tensors(batch_size, output_shapes) - loss_fn = make_mse_loss_fn(targets) - - torch.manual_seed(0) - model = architecture().to(device=DEVICE) - - output = model(input) - losses = loss_fn(output) - - gramian_modules = [model.get_submodule(name) for name in gramian_module_names] - gramian_params = [] - for m in gramian_modules: - gramian_params += list(m.parameters()) - - autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True) - torch.manual_seed(0) - - engine = Engine(gramian_modules) - - output = model(input) - losses = loss_fn(output) - gramian = engine.compute_gramian(losses) - - assert_close(gramian, autograd_gramian) - - @mark.parametrize("architecture", [WithRNN, WithModuleTrackingRunningStats]) def test_incompatible_modules(architecture: type[nn.Module]): """Tests that the engine cannot be constructed with incompatible modules."""