diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index 881d0321..f3e974f0 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -1,3 +1,4 @@ +import gc import time import torch @@ -15,6 +16,7 @@ ) from utils.forward_backwards import ( autograd_forward_backward, + autograd_gramian_forward_backward, autogram_forward_backward, autojac_forward_backward, make_mse_loss_fn, @@ -31,8 +33,8 @@ (AlexNet, 8), (InstanceNormResNet18, 16), (GroupNormMobileNetV3Small, 16), - (SqueezeNet, 16), - (InstanceNormMobileNetV2, 8), + (SqueezeNet, 4), + (InstanceNormMobileNetV2, 2), ] @@ -58,13 +60,23 @@ def fn_autograd(): def init_fn_autograd(): torch.cuda.empty_cache() + gc.collect() fn_autograd() + def fn_autograd_gramian(): + autograd_gramian_forward_backward(model, inputs, list(model.parameters()), loss_fn, W) + + def init_fn_autograd_gramian(): + torch.cuda.empty_cache() + gc.collect() + fn_autograd_gramian() + def fn_autojac(): autojac_forward_backward(model, inputs, loss_fn, A) def init_fn_autojac(): torch.cuda.empty_cache() + gc.collect() fn_autojac() def fn_autogram(): @@ -72,6 +84,7 @@ def fn_autogram(): def init_fn_autogram(): torch.cuda.empty_cache() + gc.collect() fn_autogram() def optionally_cuda_sync(): @@ -91,6 +104,16 @@ def post_fn(): print(autograd_times) print() + autograd_gramian_times = torch.tensor( + time_call(fn_autograd_gramian, init_fn_autograd_gramian, pre_fn, post_fn, n_runs) + ) + print( + f"autograd gramian times (avg = {autograd_gramian_times.mean():.5f}, std = " + f"{autograd_gramian_times.std():.5f}" + ) + print(autograd_gramian_times) + print() + autojac_times = torch.tensor(time_call(fn_autojac, init_fn_autojac, pre_fn, post_fn, n_runs)) print(f"autojac times (avg = {autojac_times.mean():.5f}, std = {autojac_times.std():.5f}") print(autojac_times) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index ba4ec06a..7ccfc3e8 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -5,6 +5,7 @@ from pytest import mark, param from torch import nn from torch.optim import SGD +from torch.testing import assert_close from unit.conftest import DEVICE from utils.architectures import ( AlexNet, @@ -52,15 +53,13 @@ from utils.forward_backwards import ( autograd_forward_backward, autogram_forward_backward, - autojac_forward_backward, + compute_gramian_with_autograd, make_mse_loss_fn, ) from utils.tensors import make_tensors -from torchjd.aggregation import UPGrad, UPGradWeighting +from torchjd.aggregation import UPGradWeighting from torchjd.autogram._engine import Engine -from torchjd.autojac._transform import Diagonalize, Init, Jac, OrderedSet -from torchjd.autojac._transform._aggregate import _Matrixify PARAMETRIZATIONS = [ (OverlyNested, 32), @@ -107,110 +106,34 @@ @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) -def test_equivalence_autojac_autogram( - architecture: type[ShapedModule], - batch_size: int, -): - """ - Tests that the autogram engine gives the same results as the autojac engine on IWRM for several - JD steps. 
- """ - - n_iter = 3 +def test_compute_gramian(architecture: type[ShapedModule], batch_size: int): + """Tests that the autograd and the autogram engines compute the same gramian.""" input_shapes = architecture.INPUT_SHAPES output_shapes = architecture.OUTPUT_SHAPES - weighting = UPGradWeighting() - aggregator = UPGrad() - torch.manual_seed(0) - model_autojac = architecture().to(device=DEVICE) + model_autograd = architecture().to(device=DEVICE) torch.manual_seed(0) model_autogram = architecture().to(device=DEVICE) engine = Engine(model_autogram.modules()) - optimizer_autojac = SGD(model_autojac.parameters(), lr=1e-7) - optimizer_autogram = SGD(model_autogram.parameters(), lr=1e-7) - - for i in range(n_iter): - inputs = make_tensors(batch_size, input_shapes) - targets = make_tensors(batch_size, output_shapes) - loss_fn = make_mse_loss_fn(targets) - - torch.random.manual_seed(0) # Fix randomness for random aggregators and random models - autojac_forward_backward(model_autojac, inputs, loss_fn, aggregator) - expected_grads = { - name: p.grad for name, p in model_autojac.named_parameters() if p.grad is not None - } - - torch.random.manual_seed(0) # Fix randomness for random weightings and random models - autogram_forward_backward(model_autogram, engine, weighting, inputs, loss_fn) - grads = { - name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None - } - - assert_tensor_dicts_are_close(grads, expected_grads) - - optimizer_autojac.step() - model_autojac.zero_grad() - - optimizer_autogram.step() - model_autogram.zero_grad() - - -@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) -def test_autograd_while_modules_are_hooked(architecture: type[ShapedModule], batch_size: int): - """ - Tests that the hooks added when constructing the engine do not interfere with a simple autograd - call. 
- """ - - input_shapes = architecture.INPUT_SHAPES - output_shapes = architecture.OUTPUT_SHAPES - W = UPGradWeighting() - A = UPGrad() - input = make_tensors(batch_size, input_shapes) + inputs = make_tensors(batch_size, input_shapes) targets = make_tensors(batch_size, output_shapes) loss_fn = make_mse_loss_fn(targets) - torch.manual_seed(0) - model = architecture().to(device=DEVICE) - - torch.manual_seed(0) # Fix randomness for random models - autojac_forward_backward(model, input, loss_fn, A) - autojac_grads = { - name: p.grad.clone() for name, p in model.named_parameters() if p.grad is not None - } - model.zero_grad() - - torch.manual_seed(0) # Fix randomness for random models - autograd_forward_backward(model, input, loss_fn) - autograd_grads = { - name: p.grad.clone() for name, p in model.named_parameters() if p.grad is not None - } - - torch.manual_seed(0) - model_autogram = architecture().to(device=DEVICE) + torch.random.manual_seed(0) # Fix randomness for random models + output = model_autograd(inputs) + losses = loss_fn(output) + autograd_gramian = compute_gramian_with_autograd(losses, list(model_autograd.parameters())) - # Hook modules and verify that we're equivalent to autojac when using the engine - engine = Engine(model_autogram.modules()) - torch.manual_seed(0) # Fix randomness for random models - autogram_forward_backward(model_autogram, engine, W, input, loss_fn) - grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None} - assert_tensor_dicts_are_close(grads, autojac_grads) - model_autogram.zero_grad() + torch.random.manual_seed(0) # Fix randomness for random models + output = model_autogram(inputs) + losses = loss_fn(output) + autogram_gramian = engine.compute_gramian(losses) - # Verify that even with the hooked modules, autograd works normally when not using the engine. - # Results should be the same as a normal call to autograd, and no time should be spent computing - # the gramian at all. - torch.manual_seed(0) # Fix randomness for random models - autograd_forward_backward(model_autogram, input, loss_fn) - assert engine._gramian_accumulator.gramian is None - grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None} - assert_tensor_dicts_are_close(grads, autograd_grads) - model_autogram.zero_grad() + assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=1e-5) def _non_empty_subsets(elements: set) -> list[set]: @@ -221,20 +144,15 @@ def _non_empty_subsets(elements: set) -> list[set]: @mark.parametrize("gramian_module_names", _non_empty_subsets({"fc0", "fc1", "fc2", "fc3", "fc4"})) -def test_partial_autogram(gramian_module_names: set[str]): +def test_compute_partial_gramian(gramian_module_names: set[str]): """ - Tests that partial JD via the autogram engine works similarly as if the gramian was computed via - the autojac engine. - - Note that this test is a bit redundant now that we have the Engine interface, because it now - just compares two ways of computing the Gramian, which is independant of the idea of partial JD. + Tests that the autograd and the autogram engines compute the same gramian when only a subset of + the model parameters is specified. 
""" architecture = SimpleBranched batch_size = 64 - weighting = UPGradWeighting() - input_shapes = architecture.INPUT_SHAPES output_shapes = architecture.OUTPUT_SHAPES @@ -247,39 +165,91 @@ def test_partial_autogram(gramian_module_names: set[str]): output = model(input) losses = loss_fn(output) - losses_ = OrderedSet(losses) - - init = Init(losses_) - diag = Diagonalize(losses_) gramian_modules = [model.get_submodule(name) for name in gramian_module_names] - gramian_params = OrderedSet({}) + gramian_params = [] for m in gramian_modules: - gramian_params += OrderedSet(m.parameters()) + gramian_params += list(m.parameters()) - jac = Jac(losses_, OrderedSet(gramian_params), None, True) - mat = _Matrixify() - transform = mat << jac << diag << init - - jacobian_matrices = transform({}) - jacobian_matrix = torch.cat(list(jacobian_matrices.values()), dim=1) - gramian = jacobian_matrix @ jacobian_matrix.T + autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True) torch.manual_seed(0) - losses.backward(weighting(gramian)) - - expected_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None} - model.zero_grad() engine = Engine(gramian_modules) output = model(input) losses = loss_fn(output) gramian = engine.compute_gramian(losses) + + assert_close(gramian, autograd_gramian) + + +@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) +def test_iwrm_steps_with_autogram(architecture: type[ShapedModule], batch_size: int): + """Tests that the autogram engine doesn't raise any error during several IWRM iterations.""" + + n_iter = 3 + + input_shapes = architecture.INPUT_SHAPES + output_shapes = architecture.OUTPUT_SHAPES + + weighting = UPGradWeighting() + + model = architecture().to(device=DEVICE) + + engine = Engine(model.modules()) + optimizer = SGD(model.parameters(), lr=1e-7) + + for i in range(n_iter): + inputs = make_tensors(batch_size, input_shapes) + targets = make_tensors(batch_size, output_shapes) + loss_fn = make_mse_loss_fn(targets) + + autogram_forward_backward(model, engine, weighting, inputs, loss_fn) + + optimizer.step() + model.zero_grad() + + +@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) +@mark.parametrize("compute_gramian", [False, True]) +def test_autograd_while_modules_are_hooked( + architecture: type[ShapedModule], batch_size: int, compute_gramian: bool +): + """ + Tests that the hooks added when constructing the engine do not interfere with a simple autograd + call. 
+ """ + + input = make_tensors(batch_size, architecture.INPUT_SHAPES) + targets = make_tensors(batch_size, architecture.OUTPUT_SHAPES) + loss_fn = make_mse_loss_fn(targets) + + torch.manual_seed(0) + model = architecture().to(device=DEVICE) torch.manual_seed(0) - losses.backward(weighting(gramian)) + model_autogram = architecture().to(device=DEVICE) + + torch.manual_seed(0) # Fix randomness for random models + autograd_forward_backward(model, input, loss_fn) + autograd_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None} + + # Hook modules and optionally compute the Gramian + engine = Engine(model_autogram.modules()) + if compute_gramian: + torch.manual_seed(0) # Fix randomness for random models + output = model_autogram(input) + losses = loss_fn(output) + _ = engine.compute_gramian(losses) - grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None} - assert_tensor_dicts_are_close(grads, expected_grads) + # Verify that even with the hooked modules, autograd works normally when not using the engine. + # Results should be the same as a normal call to autograd, and no time should be spent computing + # the gramian at all. + torch.manual_seed(0) # Fix randomness for random models + autograd_forward_backward(model_autogram, input, loss_fn) + grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None} + + assert_tensor_dicts_are_close(grads, autograd_grads) + assert engine._gramian_accumulator.gramian is None @mark.parametrize("architecture", [WithRNN, WithModuleTrackingRunningStats]) diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index 53ad3753..28611409 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -1,7 +1,7 @@ from typing import Callable import torch -from torch import Tensor, nn +from torch import Tensor, nn, vmap from torch.nn.functional import mse_loss from torch.utils._pytree import PyTree, tree_flatten, tree_map @@ -29,6 +29,18 @@ def autojac_forward_backward( backward(losses, aggregator=aggregator) +def autograd_gramian_forward_backward( + model: nn.Module, + inputs: PyTree, + params: list[nn.Parameter], + loss_fn: Callable[[PyTree], Tensor], + weighting: Weighting, +) -> None: + losses = _forward_pass(model, inputs, loss_fn) + gramian = compute_gramian_with_autograd(losses, params, retain_graph=True) + losses.backward(weighting(gramian)) + + def autogram_forward_backward( model: nn.Module, engine: Engine, @@ -80,3 +92,30 @@ def reshape_raw_losses(raw_losses: Tensor) -> Tensor: return raw_losses.unsqueeze(1) else: return raw_losses.flatten(start_dim=1) + + +def compute_gramian_with_autograd( + output: Tensor, inputs: list[nn.Parameter], retain_graph: bool = False +) -> Tensor: + """ + Computes the Gramian of the Jacobian of the outputs with respect to the inputs using vmapped + calls to the autograd engine. + """ + + filtered_inputs = [input for input in inputs if input.requires_grad] + + def get_vjp(grad_outputs: Tensor) -> list[Tensor]: + grads = torch.autograd.grad( + output, + filtered_inputs, + grad_outputs=grad_outputs, + retain_graph=retain_graph, + allow_unused=True, + ) + return [grad for grad in grads if grad is not None] + + jacobians = vmap(get_vjp)(torch.diag(torch.ones_like(output))) + jacobian_matrices = [jacobian.reshape([jacobian.shape[0], -1]) for jacobian in jacobians] + gramian = sum([jacobian @ jacobian.T for jacobian in jacobian_matrices]) + + return gramian