From e757e316e32c4d37c8b95863217ba34e722c5e2b Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 27 Aug 2025 16:51:21 +0200 Subject: [PATCH 01/44] Add non-batched support Add reshape of jacobian for the scalar output case. Fix reshape of the Gramian, we for the last half of the dimensions, we need to reshape in the same order as the first, then we move the dimensions. We could in principle create a `reshape_gramian` function that does this, as well as a `move_dim_gramian` Add a test of values for all four cases of having a batched/non-batched dimension. Tests or reshape/move-dim should work should go in another test. Remove some tests that do not test anything more than `test_gramian_is_correct`. Add `_gramian_utils.py` which contains helper to `reshape` and `movedim` on a Gramian. Add `generate_vmap_rule = True` for `JacobianAccumulator`. This allows vmaping the forward phase. This enables having several Engines defined on the same module. Add `test_reshape_equivariance` Add tests to verify that gramian utils yields the correct quadratic forms. Add tests to verify that gramian utils yields the correct quadratic forms. Add `test_movedim_equivariance` Fix warning. Fix warning. Remove handles from `ModuleHookManager` Change `batched_dims` to a single optional `batched_dim`. Fix movedim in `compute_gramian` and add `test_movedim_equivariance` Remove `grad_output`, can be added later, but should be `jac_output` instead. Make modules with incompatible batched operations are compatible with non-batched autogram. Fix doc tests Provide the autograd vjp for when no dimension is batched. This enables having a single forward in that case which should be faster. Make VJPs into Callable classes. --- src/torchjd/autogram/_engine.py | 72 ++++++- src/torchjd/autogram/_gramian_utils.py | 39 ++++ src/torchjd/autogram/_module_hook_manager.py | 41 +++- src/torchjd/autogram/_vjp.py | 98 ++++++--- tests/doc/test_autogram.py | 2 +- tests/doc/test_rst.py | 4 +- tests/speed/autogram/grad_vs_jac_vs_gram.py | 2 +- tests/unit/autogram/test_engine.py | 206 ++++++++++++++++++- tests/unit/autogram/test_gramian_utils.py | 85 ++++++++ 9 files changed, 494 insertions(+), 55 deletions(-) create mode 100644 src/torchjd/autogram/_gramian_utils.py create mode 100644 tests/unit/autogram/test_gramian_utils.py diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 5aabcdc9..cc5e53df 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -2,11 +2,12 @@ from typing import cast import torch -from torch import Tensor, nn +from torch import Tensor, nn, vmap from torch.autograd.graph import get_gradient_edge from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator +from ._gramian_utils import movedim_gramian, reshape_gramian from ._module_hook_manager import ModuleHookManager _INCOMPATIBLE_MODULE_TYPES = ( @@ -57,6 +58,9 @@ class Engine: :param modules: A collection of modules whose direct (non-recursive) parameters will contribute to the Gramian of the Jacobian. + :param is_batched: If a dimension is batched, then many intermediary jacobians are block + diagonal, which allows for a substancial memory optimization by backpropagating a squashed + Jacobian instead. If the only dimension of the losses vector is batched. Default to True. .. admonition:: Example @@ -79,7 +83,7 @@ class Engine: >>> >>> criterion = MSELoss(reduction="none") >>> weighting = UPGradWeighting() - >>> engine = Engine(model.modules()) + >>> engine = Engine(model.modules(), (0,)) >>> >>> for input, target in zip(inputs, targets): >>> output = model(input).squeeze(dim=1) # shape: [16] @@ -127,10 +131,17 @@ class Engine: `_ layers. """ - def __init__(self, modules: Iterable[nn.Module]): + def __init__( + self, + modules: Iterable[nn.Module], + batched_dim: int | None = None, + ): self._gramian_accumulator = GramianAccumulator() self._target_edges = EdgeRegistry() - self._module_hook_manager = ModuleHookManager(self._target_edges, self._gramian_accumulator) + self._batched_dim = batched_dim + self._module_hook_manager = ModuleHookManager( + self._target_edges, self._gramian_accumulator, batched_dim is not None + ) self._hook_modules(modules) @@ -143,9 +154,8 @@ def _hook_modules(self, modules: Iterable[nn.Module]) -> None: self._check_module_is_compatible(module) self._module_hook_manager.hook_module(module) - @staticmethod - def _check_module_is_compatible(module: nn.Module) -> None: - if isinstance(module, _INCOMPATIBLE_MODULE_TYPES): + def _check_module_is_compatible(self, module: nn.Module) -> None: + if self._batched_dim is not None and isinstance(module, _INCOMPATIBLE_MODULE_TYPES): raise ValueError( f"Found a module of type {type(module)}, which is incompatible with the autogram " f"engine. The incompatible module types are {_INCOMPATIBLE_MODULE_TYPES} (and their" @@ -166,12 +176,39 @@ def compute_gramian(self, output: Tensor) -> Tensor: ``modules``. :param output: The vector to differentiate. Must be a 1-D tensor. + :param grad_output: The tangents for the differentiation. Default to a vector of 1s of the + same shape as `output`. """ - reshaped_output = output.reshape([-1]) - return self._compute_square_gramian(reshaped_output) + if self._batched_dim is not None: + # move batched dim to the end + ordered_output = torch.movedim(output, self._batched_dim, -1) + ordered_shape = list(ordered_output.shape) + has_non_batched_dim = len(ordered_shape) > 1 + target_shape = [ordered_shape[-1]] + else: + ordered_output = output + ordered_shape = list(ordered_output.shape) + has_non_batched_dim = len(ordered_shape) > 0 + target_shape = [] - def _compute_square_gramian(self, output: Tensor) -> Tensor: + if has_non_batched_dim: + target_shape = [-1] + target_shape + + reshaped_output = ordered_output.reshape(target_shape) + + flat_gramian = self._compute_square_gramian(reshaped_output, has_non_batched_dim) + + unordered_gramian = reshape_gramian(flat_gramian, ordered_shape) + + if self._batched_dim is not None: + gramian = movedim_gramian(unordered_gramian, [-1], [self._batched_dim]) + else: + gramian = unordered_gramian + + return gramian + + def _compute_square_gramian(self, output: Tensor, has_non_batched_dim: bool) -> Tensor: self._module_hook_manager.gramian_accumulation_phase = True leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)})) @@ -184,7 +221,20 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: retain_graph=True, ) - _ = differentiation(torch.ones_like(output)) + if has_non_batched_dim: + # There is one non-batched dimension, it is the first one + non_batched_dim_len = output.shape[0] + jac_output_shape = [output.shape[0]] + list(output.shape) + + # Need to batch `grad_output` over the first dimension + jac_output = torch.zeros(jac_output_shape, device=output.device, dtype=output.dtype) + for i in range(non_batched_dim_len): + jac_output[i, i] = 1 + + _ = vmap(differentiation)(jac_output) + else: + grad_output = torch.ones_like(output) + _ = differentiation(grad_output) # If the gramian were None, then leaf_targets would be empty, so autograd.grad would # have failed. So gramian is necessarily a valid Tensor here. diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py new file mode 100644 index 00000000..19affdb2 --- /dev/null +++ b/src/torchjd/autogram/_gramian_utils.py @@ -0,0 +1,39 @@ +from torch import Tensor + + +def reshape_gramian(gramian: Tensor, shape: list[int]) -> Tensor: + """ + Reshapes a Gramian to a provided shape. As a Gramian is quadratic form, the reshape of the first + half of the target dimensions must be done from the left, while the reshape of the second half + must be done from the right. + :param gramian: Gramian to reshape + :param shape: First half of the target shape, the shape of the output is therefore + `shape + shape[::-1]`. + """ + + target_ndim = len(shape) + unordered_shape = shape + shape + unordered_gramian = gramian.reshape(unordered_shape) + last_dims = [target_ndim + i for i in range(target_ndim)] + return unordered_gramian.movedim(last_dims, last_dims[::-1]) + + +def movedim_gramian(gramian: Tensor, source: list[int], destination: list[int]) -> Tensor: + """ + Moves the dimensions of a Gramian from some source dimensions to destination dimensions. As a + Gramian is quadratic form, moving dimension must be done simultaneously on the first half of the + dimensions and on the second half of the dimensions reversed. + :param gramian: Gramian to reshape. + :param source: Source dimensions, should be in the range [0, gramian.ndim/2]. Should be unique + :param destination: Destination dimensions, should be in the range [0, gramian.ndim/2]. Should + be unique and should have the same size as `source`. + """ + + length = gramian.ndim // 2 + source = [i if 0 <= i else i + length for i in source] + destination = [i if 0 <= i else i + length for i in destination] + + last_index = gramian.ndim - 1 + source_dims = source + [last_index - i for i in source] + destination_dims = destination + [last_index - i for i in destination] + return gramian.movedim(source_dims, destination_dims) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index e5087bb0..0b51ec1e 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -3,12 +3,11 @@ import torch from torch import Tensor, nn from torch.autograd.graph import get_gradient_edge -from torch.utils._pytree import PyTree, TreeSpec, tree_flatten, tree_unflatten -from torch.utils.hooks import RemovableHandle as TorchRemovableHandle +from torch.utils._pytree import PyTree, TreeSpec, tree_flatten, tree_map, tree_unflatten from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator -from ._vjp import get_functional_vjp +from ._vjp import AutogradVJP, FunctionalVJP # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see @@ -32,11 +31,12 @@ def __init__( self, target_edges: EdgeRegistry, gramian_accumulator: GramianAccumulator, + has_batch_dim: bool, ): self._target_edges = target_edges self._gramian_accumulator = gramian_accumulator + self._has_batch_dim = has_batch_dim self.gramian_accumulation_phase = False - self._handles: list[TorchRemovableHandle] = [] def hook_module(self, module: nn.Module) -> None: """ @@ -70,8 +70,7 @@ def module_hook(_: nn.Module, args: PyTree, output: PyTree) -> PyTree: return self._apply_jacobian_accumulator(module, args, tree_spec, flat_outputs) - handle = module.register_forward_hook(module_hook) - self._handles.append(handle) + _ = module.register_forward_hook(module_hook) def _apply_jacobian_accumulator( self, @@ -80,21 +79,45 @@ def _apply_jacobian_accumulator( tree_spec: TreeSpec, flat_outputs: list[Tensor], ) -> PyTree: - vjp = torch.vmap(get_functional_vjp(module)) + + if self._has_batch_dim: + vjp = torch.vmap(FunctionalVJP(module)) + else: + vjp = AutogradVJP(module, flat_outputs) class AccumulateJacobian(torch.autograd.Function): @staticmethod def forward(*flat_grad_outputs: Tensor) -> None: + # There is no non-batched dimension grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec) jacobians = vjp(grad_outputs, args) self._gramian_accumulator.accumulate_path_jacobians( { - module.get_parameter(param_name): jacobian + module.get_parameter(param_name): jacobian.reshape( + [-1] + list(module.get_parameter(param_name).shape) + ) for param_name, jacobian in jacobians.items() } ) + @staticmethod + def vmap(info, in_dims, *flat_jac_outputs: Tensor) -> tuple[None, None]: + # There is a non-batched dimension + jac_outputs = tree_unflatten(flat_jac_outputs, tree_spec) + # We do not vmap over the args for the non-batched dimension + in_dims = (tree_unflatten(in_dims, tree_spec), tree_map(lambda _: None, args)) + jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) + self._gramian_accumulator.accumulate_path_jacobians( + { + module.get_parameter(param_name): jacobian.reshape( + [-1] + list(module.get_parameter(param_name).shape) + ) + for param_name, jacobian in jacobians.items() + } + ) + return None, None + @staticmethod def setup_context(*_): pass @@ -108,6 +131,8 @@ class JacobianAccumulator(torch.autograd.Function): toggle mechanism to activate only during the Gramian accumulation phase. """ + generate_vmap_rule = True + @staticmethod def forward(*xs: Tensor) -> tuple[Tensor, ...]: return tuple([x.detach() for x in xs]) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index b4bea046..80be32a0 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -1,9 +1,10 @@ -from collections.abc import Callable +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence import torch from torch import Tensor, nn from torch.nn import Parameter -from torch.utils._pytree import PyTree, tree_map_only +from torch.utils._pytree import PyTree, tree_flatten, tree_map_only, tree_unflatten # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see @@ -14,19 +15,41 @@ # still support older versions of PyTorch where pytree is protected). -def get_functional_vjp(module: nn.Module) -> Callable[[PyTree, PyTree], dict[str, Tensor]]: +class VJP(ABC): """ - Create a VJP function for a module's forward pass with respect to its parameters. The returned - function takes both the input and the cotangents that can be vmaped jointly in both terms to - avoid providing to block diagonal jacobians. + Represents a VJP function for a module's forward pass with respect to its parameters using the + func api. :params module: The module to differentiate. - :returns: VJP function that takes cotangents and inputs and returns dictionary of names of + """ + + def __init__(self, module: nn.Module): + self.module = module + self.named_parameters = dict(module.named_parameters(recurse=False)) + + @abstractmethod + def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]: + """ + VJP function that takes cotangents and inputs and returns dictionary of names of parameters (as given by `module.named_parameters.keys()`) to gradients of the parameters for the given cotangents at the given inputs. + """ + + +class FunctionalVJP(VJP): """ + Represents a VJP function for a module's forward pass with respect to its parameters using the + func api. The __call__ function takes both the inputs and the cotangents that can be vmaped + jointly in both terms to avoid providing to block diagonal jacobians. The disadvantage of using + this method is that it computes the forward phase. - def get_vjp(grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]: + :params module: The module to differentiate. + """ + + def __init__(self, module: nn.Module): + super().__init__(module) + + def __call__(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]: # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a # "batch" of 1 activation (or grad_output). This is because some layers (e.g. # nn.Flatten) do not work equivalently if they're provided with a batch or with @@ -39,30 +62,51 @@ def get_vjp(grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]: # primals (tuple), here the functional has a single primal which is # dict(module.named_parameters()). We therefore take the 0'th element to obtain # the dict of gradients w.r.t. the module's named_parameters. - return _vjp_from_module(module, inputs_j)(grad_outputs_j)[0] + return self._vjp_from_module(inputs_j)(grad_outputs_j)[0] - return get_vjp + def _vjp_from_module(self, inputs: PyTree) -> Callable[[PyTree], tuple[dict[str, Tensor]]]: + """ + Create a VJP function for a module's forward pass with respect to its parameters. + Returns a function that computes vector-Jacobian products for the module's parameters given + fixed inputs. Only parameters with requires_grad=True are included in the differentiation. -def _vjp_from_module( - module: nn.Module, inputs: PyTree -) -> Callable[[PyTree], tuple[dict[str, Tensor]]]: - """ - Create a VJP function for a module's forward pass with respect to its parameters. + :param inputs: Fixed inputs to the module for the VJP computation. + :returns: VJP function that takes cotangents and returns parameter gradients. + """ + requires_grad_named_params = { + k: v for k, v in self.named_parameters.items() if v.requires_grad + } + no_requires_grad_named_params = { + k: v for k, v in self.named_parameters.items() if not v.requires_grad + } - Returns a function that computes vector-Jacobian products for the module's parameters given - fixed inputs. Only parameters with requires_grad=True are included in the differentiation. + def functional_model_call(primals: dict[str, Parameter]) -> Tensor: + all_state = { + **primals, + **dict(self.module.named_buffers()), + **no_requires_grad_named_params, + } + return torch.func.functional_call(self.module, all_state, inputs) - :param module: The module to differentiate. - :param inputs: Fixed inputs to the module for the VJP computation. - :returns: VJP function that takes cotangents and returns parameter gradients. + return torch.func.vjp(functional_model_call, requires_grad_named_params)[1] + + +class AutogradVJP(VJP): + """ + Represents a VJP function for a module's forward pass with respect to its parameters using the + autograd engine. The __call__ function takes both the inputs and the cotangents but ignores the + inputs. The main advantage of using this method is that it doesn't require computing the forward + phase. """ - named_params = dict(module.named_parameters(recurse=False)) - requires_grad_named_params = {k: v for k, v in named_params.items() if v.requires_grad} - no_requires_grad_named_params = {k: v for k, v in named_params.items() if not v.requires_grad} - def functional_model_call(primals: dict[str, Parameter]) -> Tensor: - all_state = {**primals, **dict(module.named_buffers()), **no_requires_grad_named_params} - return torch.func.functional_call(module, all_state, inputs) + def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): + super().__init__(module) + self.outputs = outputs + self.parameters, self.tree_spec = tree_flatten(dict(self.named_parameters)) - return torch.func.vjp(functional_model_call, requires_grad_named_params)[1] + def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: + grads = torch.autograd.grad( + self.outputs, self.parameters, tree_flatten(grad_outputs)[0], retain_graph=True + ) + return tree_unflatten(grads, self.tree_spec) diff --git a/tests/doc/test_autogram.py b/tests/doc/test_autogram.py index be55a1ae..724d81e2 100644 --- a/tests/doc/test_autogram.py +++ b/tests/doc/test_autogram.py @@ -18,7 +18,7 @@ def test_engine(): criterion = MSELoss(reduction="none") weighting = UPGradWeighting() - engine = Engine(model.modules()) + engine = Engine(model.modules(), 0) for input, target in zip(inputs, targets): output = model(input).squeeze(dim=1) # shape: [16] diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 4f2b9351..86400146 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -136,7 +136,7 @@ def test_autogram(): params = model.parameters() optimizer = SGD(params, lr=0.1) weighting = UPGradWeighting() - engine = Engine(model.modules()) + engine = Engine(model.modules(), 0) for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] @@ -325,7 +325,7 @@ def test_partial_jd(): # Create the autogram engine that will compute the Gramian of the # Jacobian with respect to the two last Linear layers' parameters. - engine = Engine(model[2:].modules()) + engine = Engine(model[2:].modules(), 0) params = model.parameters() optimizer = SGD(params, lr=0.1) diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index 881d0321..ad26892a 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -96,7 +96,7 @@ def post_fn(): print(autojac_times) print() - engine = Engine(model.modules()) + engine = Engine(model.modules(), (0,)) autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs)) print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}") print(autogram_times) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index be7b593b..994a9ff5 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -1,10 +1,13 @@ from itertools import combinations +from math import prod import pytest import torch from pytest import mark, param from torch import nn +from torch.nn import Linear from torch.optim import SGD +from torch.testing import assert_close from unit.conftest import DEVICE from utils.architectures import ( AlexNet, @@ -55,7 +58,7 @@ autojac_forward_backward, make_mse_loss_fn, ) -from utils.tensors import make_tensors +from utils.tensors import make_tensors, ones_, randn_, zeros_ from torchjd.aggregation import ( IMTLG, @@ -80,6 +83,7 @@ Weighting, ) from torchjd.autogram._engine import Engine +from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian from torchjd.autojac._transform import Diagonalize, Init, Jac, OrderedSet from torchjd.autojac._transform._aggregate import _Matrixify @@ -171,7 +175,7 @@ def test_equivalence_autojac_autogram( torch.manual_seed(0) model_autogram = architecture().to(device=DEVICE) - engine = Engine(model_autogram.modules()) + engine = Engine(model_autogram.modules(), 0) optimizer_autojac = SGD(model_autojac.parameters(), lr=1e-7) optimizer_autogram = SGD(model_autogram.parameters(), lr=1e-7) @@ -237,7 +241,7 @@ def test_autograd_while_modules_are_hooked(architecture: type[ShapedModule], bat model_autogram = architecture().to(device=DEVICE) # Hook modules and verify that we're equivalent to autojac when using the engine - engine = Engine(model_autogram.modules()) + engine = Engine(model_autogram.modules(), 0) torch.manual_seed(0) # Fix randomness for random models autogram_forward_backward(model_autogram, engine, W, input, loss_fn) grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None} @@ -311,7 +315,7 @@ def test_partial_autogram(weighting: Weighting, gramian_module_names: set[str]): expected_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None} model.zero_grad() - engine = Engine(gramian_modules) + engine = Engine(gramian_modules, 0) output = model(input) losses = loss_fn(output) @@ -330,4 +334,196 @@ def test_incompatible_modules(architecture: type[nn.Module]): model = architecture().to(device=DEVICE) with pytest.raises(ValueError): - _ = Engine(model.modules()) + _ = Engine(model.modules(), 0) + + +@mark.parametrize("shape", [(1, 3), (7, 15), (27, 15)]) +@mark.parametrize("batch_size", [None, 3, 16, 32]) +@mark.parametrize("reduce_output", [True, False]) +def test_gramian_is_correct(shape: tuple[int, int], batch_size: int, reduce_output: bool): + """ + Tests that the Gramian computed by then `Engine` equals to a manual computation of the expected + Gramian. + """ + + is_batched = batch_size is not None + + if is_batched: + batched_dims = 0 + input_dim = [batch_size, shape[0]] + else: + batched_dims = None + input_dim = [shape[0]] + + model = Linear(shape[0], shape[1]) + engine = Engine([model], batched_dims) + + input = randn_(input_dim) + output = model(input) + if reduce_output: + output = torch.sum(output, dim=-1) + + assert output.ndim == int(not reduce_output) + int(is_batched) + + gramian = engine.compute_gramian(output) + + # compute the expected gramian + output_shape = list(output.shape) + initial_jacobian = torch.diag(ones_(output.numel())).reshape(output_shape + output_shape) + + if reduce_output: + initial_jacobian = initial_jacobian.unsqueeze(-1).repeat( + ([1] * initial_jacobian.ndim) + [shape[1]] + ) + if not is_batched: + initial_jacobian = initial_jacobian.unsqueeze(-2) + input = input.unsqueeze(0) + + assert initial_jacobian.shape[-2] == (1 if batch_size is None else batch_size) + assert initial_jacobian.shape[-1] == shape[1] + assert initial_jacobian.shape[:-2] == output.shape + + assert input.shape[0] == (1 if batch_size is None else batch_size) + assert input.shape[1] == shape[0] + + # If k is the batch_size (1 if None) and n the input size and m the output size, then + # - input has shape `[k, n]` + # - initial_jacobian has shape `output.shape + `[k, m]` + + # The partial (batched) jacobian of outputs w.r.t. weights is of shape `[k, m, m, n]`, whe + # multiplied (along 2 dims) by initial_jacobian this yields the jacobian of the weights of shape + # `output.shape + [m, n]`. The partial jacobian itself is block diagonal with diagonal defined + # by `partial_weight_jacobian[i, j, j] = input[i]` (other elements are 0). + + partial_weight_jacobian = zeros_([input.shape[0], shape[1], shape[1], shape[0]]) + for j in range(shape[1]): + partial_weight_jacobian[:, j, j, :] = input + weight_jacobian = torch.tensordot( + initial_jacobian, partial_weight_jacobian, dims=([-2, -1], [0, 1]) + ) + weight_gramian = torch.tensordot(weight_jacobian, weight_jacobian, dims=([-2, -1], [-2, -1])) + if weight_gramian.ndim == 4: + weight_gramian = weight_gramian.movedim((-2), (-1)) + + # The partial (batched) jacobian of outputs w.r.t. bias is of shape `[k, m, m]`, when multiplied + # (along 2 dims) by initial_jacobian this yields the jacobian of the bias of shape + # `output.shape + [m]`. The partial jacobian itself is block diagonal with diagonal defined by + # `partial_bias_jacobian[i, j, j] = 1` (other elements are 0). + partial_bias_jacobian = zeros_([input.shape[0], shape[1], shape[1]]) + for j in range(shape[1]): + partial_bias_jacobian[:, j, j] = 1.0 + bias_jacobian = torch.tensordot( + initial_jacobian, partial_bias_jacobian, dims=([-2, -1], [0, 1]) + ) + bias_gramian = torch.tensordot(bias_jacobian, bias_jacobian, dims=([-1], [-1])) + if bias_gramian.ndim == 4: + bias_gramian = bias_gramian.movedim(-2, -1) + + expected_gramian = weight_gramian + bias_gramian + + assert_close(gramian, expected_gramian) + + +@mark.parametrize( + "shape", + [ + [1, 2, 2, 3], + [7, 3, 2, 5], + [27, 6, 7], + ], +) +def test_reshape_equivariance(shape: list[int]): + """ + Test equivariance of `compute_gramian` under reshape operation. More precisely, if we reshape + the `output` to some `shape`, then the result is the same as reshaping the Gramian to the + corresponding shape. + """ + + input_size = shape[0] + output_size = prod(shape[1:]) + + model = Linear(input_size, output_size) + engine1 = Engine([model]) + engine2 = Engine([model]) + + input = randn_([input_size]) + output = model(input) + + reshaped_output = output.reshape(shape[1:]) + + gramian = engine1.compute_gramian(output) + reshaped_gramian = engine2.compute_gramian(reshaped_output) + + expected_reshaped_gramian = reshape_gramian(gramian, shape[1:]) + + assert_close(reshaped_gramian, expected_reshaped_gramian) + + +@mark.parametrize( + ["shape", "source", "destination"], + [ + ([50, 2, 2, 3], [0, 2], [1, 0]), + ([60, 3, 2, 5], [1], [2]), + ([30, 6, 7], [0, 1], [1, 0]), + ], +) +def test_movedim_equivariance(shape: list[int], source: list[int], destination: list[int]): + """ + Test equivariance of `compute_gramian` under movedim operation. More precisely, if we movedim + the `output` on some dimensions, then the result is the same as movedim on the Gramian with the + corresponding dimensions. + """ + + input_size = shape[0] + output_size = prod(shape[1:]) + + model = Linear(input_size, output_size) + engine1 = Engine([model]) + engine2 = Engine([model]) + + input = randn_([input_size]) + output = model(input).reshape(shape[1:]) + + moved_output = output.movedim(source, destination) + + gramian = engine1.compute_gramian(output) + moved_gramian = engine2.compute_gramian(moved_output) + + expected_moved_gramian = movedim_gramian(gramian, source, destination) + + assert_close(moved_gramian, expected_moved_gramian) + + +@mark.parametrize( + ["shape", "batched_dim"], + [ + ([2, 5, 3, 2], 2), + ([3, 2, 5], 1), + ([6, 3], 0), + ([4, 3, 2], 1), + ], +) +def test_batched_non_batched_equivalence(shape: list[int], batched_dim: int): + """ + Tests that for a vector with some batched dimensions, the gramian is the same if we use the + appropriate `batched_dims` or if we don't use any. + """ + + non_batched_shape = [shape[i] for i in range(len(shape)) if i != batched_dim] + input_size = prod(non_batched_shape) + batch_size = shape[batched_dim] + output_size = input_size + + model = Linear(input_size, output_size) + engine1 = Engine([model], batched_dim) + engine2 = Engine([model]) + + input = randn_([batch_size, input_size]) + output = model(input) + output = output.reshape([batch_size] + non_batched_shape) + output = output.movedim(0, batched_dim) + + gramian1 = engine1.compute_gramian(output) + gramian2 = engine2.compute_gramian(output) + + assert_close(gramian1, gramian2) diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py new file mode 100644 index 00000000..583f44b4 --- /dev/null +++ b/tests/unit/autogram/test_gramian_utils.py @@ -0,0 +1,85 @@ +from math import prod + +import torch +from pytest import mark +from torch import Tensor +from torch.testing import assert_close +from utils.tensors import rand_ + +from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian + + +def compute_quadratic_form(gramian: Tensor, vector: Tensor) -> Tensor: + """ + Compute the quadratic form x^T G x when the provided Gramian and vector may have multiple + dimensions. + """ + indices = list(range(vector.ndim)) + linear_form = torch.tensordot(vector, gramian, dims=(indices, indices)) + return torch.tensordot(linear_form, vector, dims=(indices[::-1], indices)) + + +@mark.parametrize( + "shape", + [ + [50, 2, 2, 3], + [60, 3, 2, 5], + [30, 6, 7], + ], +) +def test_quadratic_form_invariance_to_reshape(shape: list[int]): + """ + When reshaping a Gramian, we expect it to represent the same quadratic form that now applies to + reshaped inputs. So the mapping x -> x^T G x commutes with reshaping x, G and then computing the + corresponding quadratic form. + """ + + flat_dim = prod(shape[1:]) + iterations = 20 + + matrix = rand_([flat_dim, shape[0]]) + gramian = matrix @ matrix.T + reshaped_gramian = reshape_gramian(gramian, shape[1:]) + + for _ in range(iterations): + vector = rand_([flat_dim]) + reshaped_vector = vector.reshape(shape[1:]) + + quadratic_form = vector @ gramian @ vector + reshaped_quadratic_form = compute_quadratic_form(reshaped_gramian, reshaped_vector) + + assert_close(reshaped_quadratic_form, quadratic_form) + + +@mark.parametrize( + ["shape", "source", "destination"], + [ + ([50, 2, 2, 3], [0, 2], [1, 0]), + ([60, 3, 2, 5], [1], [2]), + ([30, 6, 7], [0, 1], [1, 0]), + ], +) +def test_quadratic_form_invariance_to_movedim( + shape: list[int], source: list[int], destination: list[int] +): + """ + When moving dims on a Gramian, we expect it to represent the same quadratic form that now + applies to inputs with moved dims. So the mapping x -> x^T G x commutes with moving dims x, G + and then computing the quadratic form with those. + """ + + flat_dim = prod(shape[1:]) + iterations = 20 + + matrix = rand_([flat_dim, shape[0]]) + gramian = reshape_gramian(matrix @ matrix.T, shape[1:]) + moved_gramian = movedim_gramian(gramian, source, destination) + + for _ in range(iterations): + vector = rand_(shape[1:]) + moved_vector = vector.movedim(source, destination) + + quadratic_form = compute_quadratic_form(gramian, vector) + moved_quadratic_form = compute_quadratic_form(moved_gramian, moved_vector) + + assert_close(moved_quadratic_form, quadratic_form) From 8f6e5ee4d8affa0264da7d78acc5a6467365b459 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 3 Sep 2025 13:47:18 +0200 Subject: [PATCH 02/44] Update src/torchjd/autogram/_vjp.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com> --- src/torchjd/autogram/_vjp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 80be32a0..68463e47 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -17,8 +17,7 @@ class VJP(ABC): """ - Represents a VJP function for a module's forward pass with respect to its parameters using the - func api. + Represents an abstract VJP function for a module's forward pass with respect to its parameters. :params module: The module to differentiate. """ From d507daa90430c63210e72ec62cd2738b3b20c349 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 3 Sep 2025 13:55:04 +0200 Subject: [PATCH 03/44] Update tests/unit/autogram/test_engine.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com> --- tests/unit/autogram/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 994a9ff5..003f2aaa 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -342,7 +342,7 @@ def test_incompatible_modules(architecture: type[nn.Module]): @mark.parametrize("reduce_output", [True, False]) def test_gramian_is_correct(shape: tuple[int, int], batch_size: int, reduce_output: bool): """ - Tests that the Gramian computed by then `Engine` equals to a manual computation of the expected + Tests that the Gramian computed by the `Engine` equals to a manual computation of the expected Gramian. """ From d60328ec2329c809cd21c9151e861570c2b6f56f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 14:52:22 +0200 Subject: [PATCH 04/44] Fix batched_dim param in some Engine usages --- src/torchjd/autogram/_engine.py | 2 +- tests/speed/autogram/grad_vs_jac_vs_gram.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index cc5e53df..daf0ded3 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -83,7 +83,7 @@ class Engine: >>> >>> criterion = MSELoss(reduction="none") >>> weighting = UPGradWeighting() - >>> engine = Engine(model.modules(), (0,)) + >>> engine = Engine(model.modules(), 0) >>> >>> for input, target in zip(inputs, targets): >>> output = model(input).squeeze(dim=1) # shape: [16] diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index ad26892a..a2e5bcf9 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -96,7 +96,7 @@ def post_fn(): print(autojac_times) print() - engine = Engine(model.modules(), (0,)) + engine = Engine(model.modules(), 0) autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs)) print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}") print(autogram_times) From 4103e4d5a270b376c6524a6ba0c4e03a48a7024b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 14:54:47 +0200 Subject: [PATCH 05/44] Make batched_dim parameter name explicit when creating Engine --- src/torchjd/autogram/_engine.py | 2 +- tests/doc/test_autogram.py | 2 +- tests/doc/test_rst.py | 4 ++-- tests/speed/autogram/grad_vs_jac_vs_gram.py | 2 +- tests/unit/autogram/test_engine.py | 12 ++++++------ 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index daf0ded3..a7b8f28e 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -83,7 +83,7 @@ class Engine: >>> >>> criterion = MSELoss(reduction="none") >>> weighting = UPGradWeighting() - >>> engine = Engine(model.modules(), 0) + >>> engine = Engine(model.modules(), batched_dim=0) >>> >>> for input, target in zip(inputs, targets): >>> output = model(input).squeeze(dim=1) # shape: [16] diff --git a/tests/doc/test_autogram.py b/tests/doc/test_autogram.py index 724d81e2..e0e3117f 100644 --- a/tests/doc/test_autogram.py +++ b/tests/doc/test_autogram.py @@ -18,7 +18,7 @@ def test_engine(): criterion = MSELoss(reduction="none") weighting = UPGradWeighting() - engine = Engine(model.modules(), 0) + engine = Engine(model.modules(), batched_dim=0) for input, target in zip(inputs, targets): output = model(input).squeeze(dim=1) # shape: [16] diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 86400146..d7be637e 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -136,7 +136,7 @@ def test_autogram(): params = model.parameters() optimizer = SGD(params, lr=0.1) weighting = UPGradWeighting() - engine = Engine(model.modules(), 0) + engine = Engine(model.modules(), batched_dim=0) for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] @@ -325,7 +325,7 @@ def test_partial_jd(): # Create the autogram engine that will compute the Gramian of the # Jacobian with respect to the two last Linear layers' parameters. - engine = Engine(model[2:].modules(), 0) + engine = Engine(model[2:].modules(), batched_dim=0) params = model.parameters() optimizer = SGD(params, lr=0.1) diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index a2e5bcf9..7080373c 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -96,7 +96,7 @@ def post_fn(): print(autojac_times) print() - engine = Engine(model.modules(), 0) + engine = Engine(model.modules(), batched_dim=0) autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs)) print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}") print(autogram_times) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 003f2aaa..111c83bb 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -175,7 +175,7 @@ def test_equivalence_autojac_autogram( torch.manual_seed(0) model_autogram = architecture().to(device=DEVICE) - engine = Engine(model_autogram.modules(), 0) + engine = Engine(model_autogram.modules(), batched_dim=0) optimizer_autojac = SGD(model_autojac.parameters(), lr=1e-7) optimizer_autogram = SGD(model_autogram.parameters(), lr=1e-7) @@ -241,7 +241,7 @@ def test_autograd_while_modules_are_hooked(architecture: type[ShapedModule], bat model_autogram = architecture().to(device=DEVICE) # Hook modules and verify that we're equivalent to autojac when using the engine - engine = Engine(model_autogram.modules(), 0) + engine = Engine(model_autogram.modules(), batched_dim=0) torch.manual_seed(0) # Fix randomness for random models autogram_forward_backward(model_autogram, engine, W, input, loss_fn) grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None} @@ -315,7 +315,7 @@ def test_partial_autogram(weighting: Weighting, gramian_module_names: set[str]): expected_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None} model.zero_grad() - engine = Engine(gramian_modules, 0) + engine = Engine(gramian_modules, batched_dim=0) output = model(input) losses = loss_fn(output) @@ -334,7 +334,7 @@ def test_incompatible_modules(architecture: type[nn.Module]): model = architecture().to(device=DEVICE) with pytest.raises(ValueError): - _ = Engine(model.modules(), 0) + _ = Engine(model.modules(), batched_dim=0) @mark.parametrize("shape", [(1, 3), (7, 15), (27, 15)]) @@ -356,7 +356,7 @@ def test_gramian_is_correct(shape: tuple[int, int], batch_size: int, reduce_outp input_dim = [shape[0]] model = Linear(shape[0], shape[1]) - engine = Engine([model], batched_dims) + engine = Engine([model], batched_dim=batched_dims) input = randn_(input_dim) output = model(input) @@ -515,7 +515,7 @@ def test_batched_non_batched_equivalence(shape: list[int], batched_dim: int): output_size = input_size model = Linear(input_size, output_size) - engine1 = Engine([model], batched_dim) + engine1 = Engine([model], batched_dim=batched_dim) engine2 = Engine([model]) input = randn_([batch_size, input_size]) From 4fadac6012e2479100a1333460f24cc6a538fe21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 14:55:13 +0200 Subject: [PATCH 06/44] Rename batched_dims to batched_dim in test_gramian_is_correct --- tests/unit/autogram/test_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 111c83bb..b83b87c9 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -349,14 +349,14 @@ def test_gramian_is_correct(shape: tuple[int, int], batch_size: int, reduce_outp is_batched = batch_size is not None if is_batched: - batched_dims = 0 + batched_dim = 0 input_dim = [batch_size, shape[0]] else: - batched_dims = None + batched_dim = None input_dim = [shape[0]] model = Linear(shape[0], shape[1]) - engine = Engine([model], batched_dim=batched_dims) + engine = Engine([model], batched_dim=batched_dim) input = randn_(input_dim) output = model(input) From 5c98ee8eed4cb722f5b65fa048e960ddcc1e0ca4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 14:58:58 +0200 Subject: [PATCH 07/44] Fix parameter description of batched_dim --- src/torchjd/autogram/_engine.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index a7b8f28e..0951277e 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -58,9 +58,10 @@ class Engine: :param modules: A collection of modules whose direct (non-recursive) parameters will contribute to the Gramian of the Jacobian. - :param is_batched: If a dimension is batched, then many intermediary jacobians are block - diagonal, which allows for a substancial memory optimization by backpropagating a squashed - Jacobian instead. If the only dimension of the losses vector is batched. Default to True. + :param batched_dim: If the modules work with batches and process each batch element + independently, then many intermediary jacobians are sparse (block-diagonal), which allows + for a substancial memory optimization by backpropagating a squashed Jacobian instead. This + parameter indicates the batch dimension, if any. Defaults to None. .. admonition:: Example From e5e6e9d7eb9a8f9efd2ce627192231ab1d692771 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:03:06 +0200 Subject: [PATCH 08/44] Improve error message in _check_module_is_compatible --- src/torchjd/autogram/_engine.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 0951277e..db1700a2 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -159,8 +159,11 @@ def _check_module_is_compatible(self, module: nn.Module) -> None: if self._batched_dim is not None and isinstance(module, _INCOMPATIBLE_MODULE_TYPES): raise ValueError( f"Found a module of type {type(module)}, which is incompatible with the autogram " - f"engine. The incompatible module types are {_INCOMPATIBLE_MODULE_TYPES} (and their" - " subclasses)." + f"engine when `batched_dim` is not `None`. The incompatible module types are " + f"{_INCOMPATIBLE_MODULE_TYPES} (and their subclasses). The recommended fix is to " + f"replace incompatible layers by something else (e.g. BatchNorm by InstanceNorm), " + f"but if you really can't and performance not a priority, you may also just set" + f"`batch_dim=None` when creating the engine." ) if isinstance(module, _TRACK_RUNNING_STATS_MODULE_TYPES) and module.track_running_stats: From c6c6a3f06d39f107727fc44191d7c8820a574a43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:04:10 +0200 Subject: [PATCH 09/44] Remove parameter description of removed parameter grad_output in compute_gramian --- src/torchjd/autogram/_engine.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index db1700a2..3c23b2aa 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -180,8 +180,6 @@ def compute_gramian(self, output: Tensor) -> Tensor: ``modules``. :param output: The vector to differentiate. Must be a 1-D tensor. - :param grad_output: The tangents for the differentiation. Default to a vector of 1s of the - same shape as `output`. """ if self._batched_dim is not None: From af3373bbb0556ac47b0f27c8d957e1126ed6b586 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:05:24 +0200 Subject: [PATCH 10/44] Rename flat_gramian to square_gramian Not the final name I think, but at least it's consistent with the method name --- src/torchjd/autogram/_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 3c23b2aa..45bf4340 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -199,9 +199,9 @@ def compute_gramian(self, output: Tensor) -> Tensor: reshaped_output = ordered_output.reshape(target_shape) - flat_gramian = self._compute_square_gramian(reshaped_output, has_non_batched_dim) + square_gramian = self._compute_square_gramian(reshaped_output, has_non_batched_dim) - unordered_gramian = reshape_gramian(flat_gramian, ordered_shape) + unordered_gramian = reshape_gramian(square_gramian, ordered_shape) if self._batched_dim is not None: gramian = movedim_gramian(unordered_gramian, [-1], [self._batched_dim]) From aeb8f3b1154a723ea06f03bd9cbd29e2a8918f35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:07:21 +0200 Subject: [PATCH 11/44] Remove redundant cast to dict in AutogradVJP --- src/torchjd/autogram/_vjp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 68463e47..0f4bcd67 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -102,7 +102,7 @@ class AutogradVJP(VJP): def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): super().__init__(module) self.outputs = outputs - self.parameters, self.tree_spec = tree_flatten(dict(self.named_parameters)) + self.parameters, self.tree_spec = tree_flatten(self.named_parameters) def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: grads = torch.autograd.grad( From d5b868564d9390cc6e4727bd1044b6d34318a797 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:11:47 +0200 Subject: [PATCH 12/44] Rename info to _ in AccumulateJacobian.vmap --- src/torchjd/autogram/_module_hook_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 0b51ec1e..ad98ab44 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -102,7 +102,7 @@ def forward(*flat_grad_outputs: Tensor) -> None: ) @staticmethod - def vmap(info, in_dims, *flat_jac_outputs: Tensor) -> tuple[None, None]: + def vmap(_, in_dims, *flat_jac_outputs: Tensor) -> tuple[None, None]: # There is a non-batched dimension jac_outputs = tree_unflatten(flat_jac_outputs, tree_spec) # We do not vmap over the args for the non-batched dimension From b789e358be69dee7075f4251a0f34ef42bba7749 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:12:20 +0200 Subject: [PATCH 13/44] Type-hint in_dims as PyTree in AccumulateJacobian.vmap --- src/torchjd/autogram/_module_hook_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index ad98ab44..b12d087f 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -102,7 +102,7 @@ def forward(*flat_grad_outputs: Tensor) -> None: ) @staticmethod - def vmap(_, in_dims, *flat_jac_outputs: Tensor) -> tuple[None, None]: + def vmap(_, in_dims: PyTree, *flat_jac_outputs: Tensor) -> tuple[None, None]: # There is a non-batched dimension jac_outputs = tree_unflatten(flat_jac_outputs, tree_spec) # We do not vmap over the args for the non-batched dimension From e949dffed0d91efc8fafd02354026a1bc67f5788 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:15:12 +0200 Subject: [PATCH 14/44] Rename tree_spec to output_spec in ModuleHookManager --- src/torchjd/autogram/_module_hook_manager.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index b12d087f..04919306 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -50,7 +50,7 @@ def module_hook(_: nn.Module, args: PyTree, output: PyTree) -> PyTree: if self.gramian_accumulation_phase: return output - flat_outputs, tree_spec = tree_flatten(output) + flat_outputs, output_spec = tree_flatten(output) if not any(isinstance(t, Tensor) for t in flat_outputs): # This can happen only if a module returns no Tensor, for instance some niche usage @@ -68,7 +68,7 @@ def module_hook(_: nn.Module, args: PyTree, output: PyTree) -> PyTree: index = cast(int, preference.argmin().item()) self._target_edges.register(get_gradient_edge(flat_outputs[index])) - return self._apply_jacobian_accumulator(module, args, tree_spec, flat_outputs) + return self._apply_jacobian_accumulator(module, args, output_spec, flat_outputs) _ = module.register_forward_hook(module_hook) @@ -76,7 +76,7 @@ def _apply_jacobian_accumulator( self, module: nn.Module, args: PyTree, - tree_spec: TreeSpec, + output_spec: TreeSpec, flat_outputs: list[Tensor], ) -> PyTree: @@ -90,7 +90,7 @@ class AccumulateJacobian(torch.autograd.Function): @staticmethod def forward(*flat_grad_outputs: Tensor) -> None: # There is no non-batched dimension - grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec) + grad_outputs = tree_unflatten(flat_grad_outputs, output_spec) jacobians = vjp(grad_outputs, args) self._gramian_accumulator.accumulate_path_jacobians( { @@ -104,9 +104,9 @@ def forward(*flat_grad_outputs: Tensor) -> None: @staticmethod def vmap(_, in_dims: PyTree, *flat_jac_outputs: Tensor) -> tuple[None, None]: # There is a non-batched dimension - jac_outputs = tree_unflatten(flat_jac_outputs, tree_spec) + jac_outputs = tree_unflatten(flat_jac_outputs, output_spec) # We do not vmap over the args for the non-batched dimension - in_dims = (tree_unflatten(in_dims, tree_spec), tree_map(lambda _: None, args)) + in_dims = (tree_unflatten(in_dims, output_spec), tree_map(lambda _: None, args)) jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) self._gramian_accumulator.accumulate_path_jacobians( { @@ -150,4 +150,4 @@ def backward(ctx, *flat_grad_outputs: Tensor): return flat_grad_outputs - return tree_unflatten(JacobianAccumulator.apply(*flat_outputs), tree_spec) + return tree_unflatten(JacobianAccumulator.apply(*flat_outputs), output_spec) From 16f6aa9d3bbc5544bf1ad875c52f03c0d42fd92f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:15:26 +0200 Subject: [PATCH 15/44] Rename self.tree_spec to self.param_spec in AutogradVJP --- src/torchjd/autogram/_vjp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 0f4bcd67..25450a98 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -102,10 +102,10 @@ class AutogradVJP(VJP): def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): super().__init__(module) self.outputs = outputs - self.parameters, self.tree_spec = tree_flatten(self.named_parameters) + self.parameters, self.param_spec = tree_flatten(self.named_parameters) def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: grads = torch.autograd.grad( self.outputs, self.parameters, tree_flatten(grad_outputs)[0], retain_graph=True ) - return tree_unflatten(grads, self.tree_spec) + return tree_unflatten(grads, self.param_spec) From ab134a82e8ce04c28a2ed48d4df4cbfe3c3cc565 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:21:58 +0200 Subject: [PATCH 16/44] Add example in comment of reshape_gramian --- src/torchjd/autogram/_gramian_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index 19affdb2..ef3338a2 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -11,6 +11,10 @@ def reshape_gramian(gramian: Tensor, shape: list[int]) -> Tensor: `shape + shape[::-1]`. """ + # Example: `gramian` of shape [24, 24] and `shape` of [4, 3, 2]: + # - The `unordered_gramian` will be of shape [4, 3, 2, 4, 3, 2] + # - The returned gramian will be of shape [4, 3, 2, 2, 3, 4] + target_ndim = len(shape) unordered_shape = shape + shape unordered_gramian = gramian.reshape(unordered_shape) From 8d66d57cd0cf096af773755a6e7aec1d50ce4375 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:23:57 +0200 Subject: [PATCH 17/44] Improve variable names in compute_quadratic_form --- tests/unit/autogram/test_gramian_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py index 583f44b4..1bc1ed0e 100644 --- a/tests/unit/autogram/test_gramian_utils.py +++ b/tests/unit/autogram/test_gramian_utils.py @@ -9,14 +9,14 @@ from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian -def compute_quadratic_form(gramian: Tensor, vector: Tensor) -> Tensor: +def compute_quadratic_form(generalized_gramian: Tensor, x: Tensor) -> Tensor: """ - Compute the quadratic form x^T G x when the provided Gramian and vector may have multiple + Compute the quadratic form x^T G x when the provided generalized Gramian and x may have multiple dimensions. """ - indices = list(range(vector.ndim)) - linear_form = torch.tensordot(vector, gramian, dims=(indices, indices)) - return torch.tensordot(linear_form, vector, dims=(indices[::-1], indices)) + indices = list(range(x.ndim)) + linear_form = torch.tensordot(x, generalized_gramian, dims=(indices, indices)) + return torch.tensordot(linear_form, x, dims=(indices[::-1], indices)) @mark.parametrize( From 4bdc4008a55b57724096a3f33c50411d746997c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 16:22:41 +0200 Subject: [PATCH 18/44] Revamp documentation of compute_gramian --- src/torchjd/autogram/_engine.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 45bf4340..06d9d249 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -175,11 +175,30 @@ def _check_module_is_compatible(self, module: nn.Module) -> None: ) def compute_gramian(self, output: Tensor) -> Tensor: - """ - Compute the Gramian of the Jacobian of ``output`` with respect the direct parameters of all - ``modules``. - - :param output: The vector to differentiate. Must be a 1-D tensor. + r""" + Computes the Gramian of the Jacobian of ``output`` with respect to the direct parameters of + all ``modules``. + + .. note:: + This function doesn't require ``output`` to be a vector. For example, if ``output`` is + a matrix of shape :math:`[m_1, m_2]`, its Jacobian :math:`J` with respect to the + parameters will be of shape :math:`[m_1, m_2, n]`, where :math:`n` is the number of + parameters in the model. This is what we call a `generalized Jacobian`. The + corresponding Gramian :math:`G = J J^\top` will be of shape :math:`[m_1, m_2, m_2, m_1]`. + This is what we call a `generalized Gramian`. The number of dimensions of the returned + generalized Gramian will always be twice that of the ``output``. + + A few examples: + - 0D (scalar) ``output``: 0D Gramian (this can be used to efficiently compute the + norm of the gradient of ``output``). + - 1D (vector) ``output``: 2D Gramian (this is the standard setting of Jacobian + descent). + - 2D (matrix) ``output``: 4D Gramian (this can happen when combining IWRM and + multi-task learning, as each sample in the batch has one loss per task). + - etc. + + :param output: The tensor of arbitrary shape to differentiate. The shape of the returned + Gramian depends on the shape of this output, as explained in the note above. """ if self._batched_dim is not None: From 33216c12fb4291cf98ff47131a8393215aa20c9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Fri, 5 Sep 2025 16:43:30 +0200 Subject: [PATCH 19/44] Update src/torchjd/autogram/_engine.py Co-authored-by: Pierre Quinton --- src/torchjd/autogram/_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 06d9d249..7578dfd8 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -190,7 +190,7 @@ def compute_gramian(self, output: Tensor) -> Tensor: A few examples: - 0D (scalar) ``output``: 0D Gramian (this can be used to efficiently compute the - norm of the gradient of ``output``). + squared norm of the gradient of ``output``). - 1D (vector) ``output``: 2D Gramian (this is the standard setting of Jacobian descent). - 2D (matrix) ``output``: 4D Gramian (this can happen when combining IWRM and From c090dfe2ff457d99a3390f93584fa5a2c6b163c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 16:40:53 +0200 Subject: [PATCH 20/44] Add ... indexing in jac_output for code clarity --- src/torchjd/autogram/_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 7578dfd8..8b1dfc33 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -250,7 +250,7 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: # Need to batch `grad_output` over the first dimension jac_output = torch.zeros(jac_output_shape, device=output.device, dtype=output.dtype) for i in range(non_batched_dim_len): - jac_output[i, i] = 1 + jac_output[i, i, ...] = 1 _ = vmap(differentiation)(jac_output) else: From 3466a3c93ba27ace5da3e23ed94be97081569f48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 16:47:19 +0200 Subject: [PATCH 21/44] Fix formatting of docstring --- src/torchjd/autogram/_engine.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 8b1dfc33..a3c4e90f 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -183,10 +183,11 @@ def compute_gramian(self, output: Tensor) -> Tensor: This function doesn't require ``output`` to be a vector. For example, if ``output`` is a matrix of shape :math:`[m_1, m_2]`, its Jacobian :math:`J` with respect to the parameters will be of shape :math:`[m_1, m_2, n]`, where :math:`n` is the number of - parameters in the model. This is what we call a `generalized Jacobian`. The - corresponding Gramian :math:`G = J J^\top` will be of shape :math:`[m_1, m_2, m_2, m_1]`. - This is what we call a `generalized Gramian`. The number of dimensions of the returned - generalized Gramian will always be twice that of the ``output``. + parameters in the model. This is what we call a generalized Jacobian. The + corresponding Gramian :math:`G = J J^\top` will be of shape + :math:`[m_1, m_2, m_2, m_1]`. This is what we call a `generalized Gramian`. The number + of dimensions of the returned generalized Gramian will always be twice that of the + ``output``. A few examples: - 0D (scalar) ``output``: 0D Gramian (this can be used to efficiently compute the From d2bab7a1723758466171e89df6b4b35803b746f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 16:54:30 +0200 Subject: [PATCH 22/44] Add more parametrizations to test_reshape_equivariance --- tests/unit/autogram/test_engine.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index b83b87c9..4f2c1aa9 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -430,6 +430,14 @@ def test_gramian_is_correct(shape: tuple[int, int], batch_size: int, reduce_outp [1, 2, 2, 3], [7, 3, 2, 5], [27, 6, 7], + [3, 2, 1, 1], + [3, 2, 1], + [3, 2], + [3], + [1, 1, 1, 1], + [1, 1, 1], + [1, 1], + [1], ], ) def test_reshape_equivariance(shape: list[int]): From 79e0609e5ff426c6df1bed6832ee0a7d5f472b19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 16:59:57 +0200 Subject: [PATCH 23/44] Improve parametrization of test_movedim_equivariance --- tests/unit/autogram/test_engine.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 4f2c1aa9..a3bc8748 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -473,6 +473,11 @@ def test_reshape_equivariance(shape: list[int]): ([50, 2, 2, 3], [0, 2], [1, 0]), ([60, 3, 2, 5], [1], [2]), ([30, 6, 7], [0, 1], [1, 0]), + ([3, 2], [0], [0]), + ([3], [], []), + ([3, 2, 1], [1, 0], [0, 1]), + ([4, 3, 2], [], []), + ([1, 1, 1], [1, 0], [0, 1]), ], ) def test_movedim_equivariance(shape: list[int], source: list[int], destination: list[int]): From c40f44c90720af2b67d7831e3f5ee2eeaaeec8b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 17:03:43 +0200 Subject: [PATCH 24/44] Improve parametrization of test_batched_non_batched_equivalence --- tests/unit/autogram/test_engine.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index a3bc8748..abf0f159 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -514,6 +514,12 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: ([3, 2, 5], 1), ([6, 3], 0), ([4, 3, 2], 1), + ([1, 1, 1], 0), + ([1, 1, 1], 1), + ([1, 1, 1], 2), + ([1, 1], 0), + ([1], 0), + ([4, 3, 1], 2), ], ) def test_batched_non_batched_equivalence(shape: list[int], batched_dim: int): From 2fe3b51ae70aabc8da6374f7b3c483c3a4910071 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 17:05:55 +0200 Subject: [PATCH 25/44] Add comment in compute_gramian --- src/torchjd/autogram/_engine.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index a3c4e90f..f289c6b6 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -218,6 +218,11 @@ def compute_gramian(self, output: Tensor) -> Tensor: target_shape = [-1] + target_shape reshaped_output = ordered_output.reshape(target_shape) + # There are four different cases for the shape of reshaped_output: + # - Not batched and not non-batched: scalar of shape [] + # - Batched only: vector of shape [batch_size] + # - Non-batched only: vector of shape [dim] + # - Batched and non-batched: matrix of shape [dim, batch_size] square_gramian = self._compute_square_gramian(reshaped_output, has_non_batched_dim) From d82a8f622d98c87ff3c868d92e985ffc6c716b25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 17:12:15 +0200 Subject: [PATCH 26/44] Improve clarity of reshape_gramian --- src/torchjd/autogram/_gramian_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index ef3338a2..355199c2 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -13,13 +13,14 @@ def reshape_gramian(gramian: Tensor, shape: list[int]) -> Tensor: # Example: `gramian` of shape [24, 24] and `shape` of [4, 3, 2]: # - The `unordered_gramian` will be of shape [4, 3, 2, 4, 3, 2] - # - The returned gramian will be of shape [4, 3, 2, 2, 3, 4] + # - The `last_dims` will be [3, 4, 5] and `last_dims[::-1]` will be [5, 4, 3] + # - The `reordered_gramian` will be of shape [4, 3, 2, 2, 3, 4] target_ndim = len(shape) - unordered_shape = shape + shape - unordered_gramian = gramian.reshape(unordered_shape) + unordered_gramian = gramian.reshape(shape + shape) last_dims = [target_ndim + i for i in range(target_ndim)] - return unordered_gramian.movedim(last_dims, last_dims[::-1]) + reordered_gramian = unordered_gramian.movedim(last_dims, last_dims[::-1]) + return reordered_gramian def movedim_gramian(gramian: Tensor, source: list[int], destination: list[int]) -> Tensor: From c846ed73382e2617713595c6f19211b05804765f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 17:30:50 +0200 Subject: [PATCH 27/44] Improve clarity of movedim_gramian --- src/torchjd/autogram/_gramian_utils.py | 30 +++++++++++++++++--------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index 355199c2..f82b4ece 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -29,16 +29,26 @@ def movedim_gramian(gramian: Tensor, source: list[int], destination: list[int]) Gramian is quadratic form, moving dimension must be done simultaneously on the first half of the dimensions and on the second half of the dimensions reversed. :param gramian: Gramian to reshape. - :param source: Source dimensions, should be in the range [0, gramian.ndim/2]. Should be unique - :param destination: Destination dimensions, should be in the range [0, gramian.ndim/2]. Should - be unique and should have the same size as `source`. + :param source: Source dimensions, that should be in the range + [-gramian.ndim//2, gramian.ndim//2[. Its elements should be unique. + :param destination: Destination dimensions, that should be in the range + [-gramian.ndim//2, gramian.ndim//2[. It should have the same size as `source`, and its + elements should be unique. """ - length = gramian.ndim // 2 - source = [i if 0 <= i else i + length for i in source] - destination = [i if 0 <= i else i + length for i in destination] + # Example: `gramian` of shape [4, 3, 2, 2, 3, 4], `source` of [-2, 2] and destination of [0, 1]: + # - `source_` will be [1, 2] and `destination_` will be [0, 1] + # - `mirrored_source` will be [1, 2, 4, 3] and `mirrored_destination` will be [0, 1, 5, 4] + # - The `moved_gramian` will be of shape [3, 2, 4, 4, 2, 3] - last_index = gramian.ndim - 1 - source_dims = source + [last_index - i for i in source] - destination_dims = destination + [last_index - i for i in destination] - return gramian.movedim(source_dims, destination_dims) + # Map everything to the range [0, gramian.ndim//2[ + length = gramian.ndim // 2 + source_ = [i if 0 <= i else i + length for i in source] + destination_ = [i if 0 <= i else i + length for i in destination] + + # Mirror the source and destination and use the result to move the dimensions of the gramian + last_dim = gramian.ndim - 1 + mirrored_source = source_ + [last_dim - i for i in source_] + mirrored_destination = destination_ + [last_dim - i for i in destination_] + moved_gramian = gramian.movedim(mirrored_source, mirrored_destination) + return moved_gramian From c2fd0a05cabe6eb55d9e7f3317b8bac8aef4452d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 18:03:36 +0200 Subject: [PATCH 28/44] Revert removal of _handles in ModuleHookManager --- src/torchjd/autogram/_module_hook_manager.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 04919306..9d3ab66a 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -4,6 +4,7 @@ from torch import Tensor, nn from torch.autograd.graph import get_gradient_edge from torch.utils._pytree import PyTree, TreeSpec, tree_flatten, tree_map, tree_unflatten +from torch.utils.hooks import RemovableHandle as TorchRemovableHandle from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator @@ -37,6 +38,7 @@ def __init__( self._gramian_accumulator = gramian_accumulator self._has_batch_dim = has_batch_dim self.gramian_accumulation_phase = False + self._handles: list[TorchRemovableHandle] = [] def hook_module(self, module: nn.Module) -> None: """ @@ -70,7 +72,8 @@ def module_hook(_: nn.Module, args: PyTree, output: PyTree) -> PyTree: return self._apply_jacobian_accumulator(module, args, output_spec, flat_outputs) - _ = module.register_forward_hook(module_hook) + handle = module.register_forward_hook(module_hook) + self._handles.append(handle) def _apply_jacobian_accumulator( self, From 1cbff18cf7507566ce915af8105e31309c2388c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 18:54:44 +0200 Subject: [PATCH 29/44] Add more edge cases to test_quadratic_form_invariance_to_reshape --- tests/unit/autogram/test_gramian_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py index 1bc1ed0e..eefd076a 100644 --- a/tests/unit/autogram/test_gramian_utils.py +++ b/tests/unit/autogram/test_gramian_utils.py @@ -25,6 +25,13 @@ def compute_quadratic_form(generalized_gramian: Tensor, x: Tensor) -> Tensor: [50, 2, 2, 3], [60, 3, 2, 5], [30, 6, 7], + [4, 3, 1], + [4, 1, 1], + [1, 1, 1], + [4, 1], + [4], + [1, 1], + [1], ], ) def test_quadratic_form_invariance_to_reshape(shape: list[int]): From 693856019cec3df0007409663256692cd6d24424 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 18:56:44 +0200 Subject: [PATCH 30/44] Add more edge cases to test_quadratic_form_invariance_to_movedim --- tests/unit/autogram/test_gramian_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py index eefd076a..3ea8aa3d 100644 --- a/tests/unit/autogram/test_gramian_utils.py +++ b/tests/unit/autogram/test_gramian_utils.py @@ -64,6 +64,13 @@ def test_quadratic_form_invariance_to_reshape(shape: list[int]): ([50, 2, 2, 3], [0, 2], [1, 0]), ([60, 3, 2, 5], [1], [2]), ([30, 6, 7], [0, 1], [1, 0]), + ([4, 3, 1], [0, 1], [1, 0]), + ([4, 1, 1], [1], [0]), + ([1, 1, 1], [], []), + ([4, 1], [0], [0]), + ([4], [], []), + ([1, 1], [], []), + ([1], [], []), ], ) def test_quadratic_form_invariance_to_movedim( From 1b565580751f8b566eba7060f1da6a81dbd2c18e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 19:02:45 +0200 Subject: [PATCH 31/44] Factorize code into _make_path_jacobians and use for-loop --- src/torchjd/autogram/_module_hook_manager.py | 29 +++++++++----------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 9d3ab66a..036c1bc2 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -95,14 +95,8 @@ def forward(*flat_grad_outputs: Tensor) -> None: # There is no non-batched dimension grad_outputs = tree_unflatten(flat_grad_outputs, output_spec) jacobians = vjp(grad_outputs, args) - self._gramian_accumulator.accumulate_path_jacobians( - { - module.get_parameter(param_name): jacobian.reshape( - [-1] + list(module.get_parameter(param_name).shape) - ) - for param_name, jacobian in jacobians.items() - } - ) + path_jacobians = AccumulateJacobian._make_path_jacobians(jacobians) + self._gramian_accumulator.accumulate_path_jacobians(path_jacobians) @staticmethod def vmap(_, in_dims: PyTree, *flat_jac_outputs: Tensor) -> tuple[None, None]: @@ -111,16 +105,19 @@ def vmap(_, in_dims: PyTree, *flat_jac_outputs: Tensor) -> tuple[None, None]: # We do not vmap over the args for the non-batched dimension in_dims = (tree_unflatten(in_dims, output_spec), tree_map(lambda _: None, args)) jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) - self._gramian_accumulator.accumulate_path_jacobians( - { - module.get_parameter(param_name): jacobian.reshape( - [-1] + list(module.get_parameter(param_name).shape) - ) - for param_name, jacobian in jacobians.items() - } - ) + path_jacobians = AccumulateJacobian._make_path_jacobians(jacobians) + self._gramian_accumulator.accumulate_path_jacobians(path_jacobians) return None, None + @staticmethod + def _make_path_jacobians(jacobians: dict[str, Tensor]) -> dict[Tensor, Tensor]: + path_jacobians: dict[Tensor, Tensor] = {} + for param_name, jacobian in jacobians.items(): + key = module.get_parameter(param_name) + value = jacobian.reshape([-1] + list(key.shape)) + path_jacobians[key] = value + return path_jacobians + @staticmethod def setup_context(*_): pass From 1d78e15f590e033efb117d95f06e565fe742e088 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 19:19:24 +0200 Subject: [PATCH 32/44] Fix model not being moved to cuda in new tests --- tests/unit/autogram/test_engine.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index abf0f159..e4bcf24f 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -355,7 +355,7 @@ def test_gramian_is_correct(shape: tuple[int, int], batch_size: int, reduce_outp batched_dim = None input_dim = [shape[0]] - model = Linear(shape[0], shape[1]) + model = Linear(shape[0], shape[1]).to(device=DEVICE) engine = Engine([model], batched_dim=batched_dim) input = randn_(input_dim) @@ -450,7 +450,7 @@ def test_reshape_equivariance(shape: list[int]): input_size = shape[0] output_size = prod(shape[1:]) - model = Linear(input_size, output_size) + model = Linear(input_size, output_size).to(device=DEVICE) engine1 = Engine([model]) engine2 = Engine([model]) @@ -490,7 +490,7 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: input_size = shape[0] output_size = prod(shape[1:]) - model = Linear(input_size, output_size) + model = Linear(input_size, output_size).to(device=DEVICE) engine1 = Engine([model]) engine2 = Engine([model]) @@ -533,7 +533,7 @@ def test_batched_non_batched_equivalence(shape: list[int], batched_dim: int): batch_size = shape[batched_dim] output_size = input_size - model = Linear(input_size, output_size) + model = Linear(input_size, output_size).to(device=DEVICE) engine1 = Engine([model], batched_dim=batched_dim) engine2 = Engine([model]) From 91d739b31e8350cb7bc65ca76213a272cafe6484 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 6 Sep 2025 15:49:32 +0200 Subject: [PATCH 33/44] Make test_equivalence_autojac_autogram also work with non-batched engine At this point, 3 architectures fail: SomeFrozenParam, SomeUnusedParam and MultiOutputWithFrozenBranch --- tests/unit/autogram/test_engine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index e4bcf24f..aca8e06e 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -154,11 +154,13 @@ @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) @mark.parametrize(["aggregator", "weighting"], AGGREGATORS_AND_WEIGHTINGS) +@mark.parametrize("batched_engine", [False, True]) def test_equivalence_autojac_autogram( architecture: type[ShapedModule], batch_size: int, aggregator: Aggregator, weighting: Weighting, + batched_engine: bool, ): """ Tests that the autogram engine gives the same results as the autojac engine on IWRM for several @@ -175,7 +177,7 @@ def test_equivalence_autojac_autogram( torch.manual_seed(0) model_autogram = architecture().to(device=DEVICE) - engine = Engine(model_autogram.modules(), batched_dim=0) + engine = Engine(model_autogram.modules(), batched_dim=0 if batched_engine else None) optimizer_autojac = SGD(model_autojac.parameters(), lr=1e-7) optimizer_autogram = SGD(model_autogram.parameters(), lr=1e-7) From 92b7b1aee2ce6c734d3008b7025126f68c2a4506 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 6 Sep 2025 15:25:08 +0200 Subject: [PATCH 34/44] Make separation between trainable and frozen params in VJP and rename variables This fixes non-batched engien on SomeFrozenParams architecture --- src/torchjd/autogram/_vjp.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 25450a98..86bbe299 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -24,7 +24,9 @@ class VJP(ABC): def __init__(self, module: nn.Module): self.module = module - self.named_parameters = dict(module.named_parameters(recurse=False)) + named_parameters = dict(module.named_parameters(recurse=False)) + self.trainable_params = {k: v for k, v in named_parameters.items() if v.requires_grad} + self.frozen_params = {k: v for k, v in named_parameters.items() if not v.requires_grad} @abstractmethod def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]: @@ -73,22 +75,16 @@ def _vjp_from_module(self, inputs: PyTree) -> Callable[[PyTree], tuple[dict[str, :param inputs: Fixed inputs to the module for the VJP computation. :returns: VJP function that takes cotangents and returns parameter gradients. """ - requires_grad_named_params = { - k: v for k, v in self.named_parameters.items() if v.requires_grad - } - no_requires_grad_named_params = { - k: v for k, v in self.named_parameters.items() if not v.requires_grad - } def functional_model_call(primals: dict[str, Parameter]) -> Tensor: all_state = { **primals, **dict(self.module.named_buffers()), - **no_requires_grad_named_params, + **self.frozen_params, } return torch.func.functional_call(self.module, all_state, inputs) - return torch.func.vjp(functional_model_call, requires_grad_named_params)[1] + return torch.func.vjp(functional_model_call, self.trainable_params)[1] class AutogradVJP(VJP): @@ -102,10 +98,10 @@ class AutogradVJP(VJP): def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): super().__init__(module) self.outputs = outputs - self.parameters, self.param_spec = tree_flatten(self.named_parameters) + self.trainable_params, self.param_spec = tree_flatten(self.trainable_params) def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: grads = torch.autograd.grad( - self.outputs, self.parameters, tree_flatten(grad_outputs)[0], retain_graph=True + self.outputs, self.trainable_params, tree_flatten(grad_outputs)[0], retain_graph=True ) return tree_unflatten(grads, self.param_spec) From 9b92cf7444da7d117bf24dd91af0046f39850a7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 6 Sep 2025 15:30:41 +0200 Subject: [PATCH 35/44] Add allow_unused=True and materialize_grads=True in call to autograd.grad in AutogradVJP This fixes non-batched engine on SomeUnusedParam --- src/torchjd/autogram/_vjp.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 86bbe299..760682cf 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -102,6 +102,11 @@ def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: grads = torch.autograd.grad( - self.outputs, self.trainable_params, tree_flatten(grad_outputs)[0], retain_graph=True + self.outputs, + self.trainable_params, + tree_flatten(grad_outputs)[0], + retain_graph=True, + allow_unused=True, + materialize_grads=True, ) return tree_unflatten(grads, self.param_spec) From 7950bb45c2baed239783dcb67e96207728b9a545 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 6 Sep 2025 15:49:14 +0200 Subject: [PATCH 36/44] Stop trying to differentiate outputs that dont require grad in AutogradVJP This fixes non-batched engine on MultiOutputWithFrozenBranch --- src/torchjd/autogram/_vjp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 760682cf..15027502 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -98,13 +98,15 @@ class AutogradVJP(VJP): def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): super().__init__(module) self.outputs = outputs + self.mask = [output.requires_grad for output in self.outputs] self.trainable_params, self.param_spec = tree_flatten(self.trainable_params) def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: + flat_grad_outputs = tree_flatten(grad_outputs)[0] grads = torch.autograd.grad( - self.outputs, + [t for t, requires_grad in zip(self.outputs, self.mask) if requires_grad], self.trainable_params, - tree_flatten(grad_outputs)[0], + [t for t, requires_grad in zip(flat_grad_outputs, self.mask) if requires_grad], retain_graph=True, allow_unused=True, materialize_grads=True, From 4dc3f7c7b0130f8f481337cab0b2022e6a4b434f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 6 Sep 2025 15:55:24 +0200 Subject: [PATCH 37/44] Fix variable name --- src/torchjd/autogram/_vjp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 15027502..4a93f69a 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -99,13 +99,13 @@ def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): super().__init__(module) self.outputs = outputs self.mask = [output.requires_grad for output in self.outputs] - self.trainable_params, self.param_spec = tree_flatten(self.trainable_params) + self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params) def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: flat_grad_outputs = tree_flatten(grad_outputs)[0] grads = torch.autograd.grad( [t for t, requires_grad in zip(self.outputs, self.mask) if requires_grad], - self.trainable_params, + self.flat_trainable_params, [t for t, requires_grad in zip(flat_grad_outputs, self.mask) if requires_grad], retain_graph=True, allow_unused=True, From 23cb0a89f7ed7ee563c956228639caab4dc77c13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 6 Sep 2025 16:04:09 +0200 Subject: [PATCH 38/44] Rename jacobians to generalized_jacobians Maybe not a definitive name, but I think it's more clear --- src/torchjd/autogram/_module_hook_manager.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 036c1bc2..7660bc04 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -94,8 +94,8 @@ class AccumulateJacobian(torch.autograd.Function): def forward(*flat_grad_outputs: Tensor) -> None: # There is no non-batched dimension grad_outputs = tree_unflatten(flat_grad_outputs, output_spec) - jacobians = vjp(grad_outputs, args) - path_jacobians = AccumulateJacobian._make_path_jacobians(jacobians) + generalized_jacobians = vjp(grad_outputs, args) + path_jacobians = AccumulateJacobian._make_path_jacobians(generalized_jacobians) self._gramian_accumulator.accumulate_path_jacobians(path_jacobians) @staticmethod @@ -104,18 +104,20 @@ def vmap(_, in_dims: PyTree, *flat_jac_outputs: Tensor) -> tuple[None, None]: jac_outputs = tree_unflatten(flat_jac_outputs, output_spec) # We do not vmap over the args for the non-batched dimension in_dims = (tree_unflatten(in_dims, output_spec), tree_map(lambda _: None, args)) - jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) - path_jacobians = AccumulateJacobian._make_path_jacobians(jacobians) + generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) + path_jacobians = AccumulateJacobian._make_path_jacobians(generalized_jacobians) self._gramian_accumulator.accumulate_path_jacobians(path_jacobians) return None, None @staticmethod - def _make_path_jacobians(jacobians: dict[str, Tensor]) -> dict[Tensor, Tensor]: + def _make_path_jacobians( + generalized_jacobians: dict[str, Tensor], + ) -> dict[Tensor, Tensor]: path_jacobians: dict[Tensor, Tensor] = {} - for param_name, jacobian in jacobians.items(): + for param_name, generalized_jacobian in generalized_jacobians.items(): key = module.get_parameter(param_name) - value = jacobian.reshape([-1] + list(key.shape)) - path_jacobians[key] = value + jacobian = generalized_jacobian.reshape([-1] + list(key.shape)) + path_jacobians[key] = jacobian return path_jacobians @staticmethod From 1169b854c4e44dc13717974ed4944c9e9f097667 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 6 Sep 2025 17:08:20 +0200 Subject: [PATCH 39/44] Replace torch.movedim by tensor.movedim --- src/torchjd/autogram/_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index f289c6b6..dc67f38d 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -204,7 +204,7 @@ def compute_gramian(self, output: Tensor) -> Tensor: if self._batched_dim is not None: # move batched dim to the end - ordered_output = torch.movedim(output, self._batched_dim, -1) + ordered_output = output.movedim(self._batched_dim, -1) ordered_shape = list(ordered_output.shape) has_non_batched_dim = len(ordered_shape) > 1 target_shape = [ordered_shape[-1]] From 59c957c4ba0455ef8d5e3dc7873f525b5cd11e1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 6 Sep 2025 17:17:32 +0200 Subject: [PATCH 40/44] Add batch_size variable in compute_gramian * Small improvement of clarity --- src/torchjd/autogram/_engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index dc67f38d..9254aa86 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -206,8 +206,9 @@ def compute_gramian(self, output: Tensor) -> Tensor: # move batched dim to the end ordered_output = output.movedim(self._batched_dim, -1) ordered_shape = list(ordered_output.shape) + batch_size = ordered_shape[-1] has_non_batched_dim = len(ordered_shape) > 1 - target_shape = [ordered_shape[-1]] + target_shape = [batch_size] else: ordered_output = output ordered_shape = list(ordered_output.shape) From 9a01a12bd1ff1ce66fa50649d6b2d89acdea73d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 9 Sep 2025 16:08:24 +0200 Subject: [PATCH 41/44] Add GeneralizedWeighting --- docs/source/docs/aggregation/index.rst | 5 +++++ src/torchjd/aggregation/__init__.py | 2 +- src/torchjd/aggregation/_weighting_bases.py | 25 +++++++++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index e5cfef44..87c7d75c 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -17,6 +17,11 @@ Abstract base classes :undoc-members: :exclude-members: forward +.. autoclass:: torchjd.aggregation.GeneralizedWeighting + :members: + :undoc-members: + :exclude-members: forward + .. toctree:: :hidden: diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 8d1a9432..fdd4c2a1 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -55,7 +55,7 @@ from ._utils.check_dependencies import ( OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError, ) -from ._weighting_bases import Weighting +from ._weighting_bases import GeneralizedWeighting, Weighting try: from ._cagrad import CAGrad, CAGradWeighting diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index daa1bdca..154c7a30 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -52,3 +52,28 @@ def __init__(self, weighting: Weighting[_FnOutputT], fn: Callable[[_T], _FnOutpu def forward(self, stat: _T) -> Tensor: return self.weighting(self.fn(stat)) + + +class GeneralizedWeighting(nn.Module, ABC): + r""" + Abstract base class for all weightings that operate on generalized Gramians. It has the role of + extracting a tensor of weights of dimension :math:`m_1 \times \dots \times m_k` from a + generalized Gramian of dimension + :math:`m_1 \times \dots \times m_k \times m_k \times \dots \times m_1`. + """ + + def __init__(self): + super().__init__() + + @abstractmethod + def forward(self, generalized_gramian: Tensor) -> Tensor: + """Computes the vector of weights from the input generalized Gramian.""" + + # Override to make type hints and documentation more specific + def __call__(self, generalized_gramian: Tensor) -> Tensor: + """ + Computes the tensor of weights from the input generalized Gramian and applies all registered + hooks. + """ + + return super().__call__(generalized_gramian) From 31812a11d103fd88b4e1fe71d77fc7f937dfa0ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 9 Sep 2025 16:23:32 +0200 Subject: [PATCH 42/44] Add FakeGeneralizedWeighting in tests --- tests/unit/autogram/test_engine.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index fe7d1742..83093a22 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -4,7 +4,7 @@ import pytest import torch from pytest import mark, param -from torch import nn +from torch import Tensor, nn from torch.nn import Linear from torch.optim import SGD from torch.testing import assert_close @@ -60,7 +60,7 @@ ) from utils.tensors import make_tensors, ones_, randn_, zeros_ -from torchjd.aggregation import UPGrad, UPGradWeighting +from torchjd.aggregation import GeneralizedWeighting, UPGrad, UPGradWeighting from torchjd.autogram._engine import Engine from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian from torchjd.autojac._transform import Diagonalize, Init, Jac, OrderedSet @@ -507,3 +507,22 @@ def test_batched_non_batched_equivalence(shape: list[int], batched_dim: int): gramian2 = engine2.compute_gramian(output) assert_close(gramian1, gramian2) + + +class FakeGeneralizedWeighting(GeneralizedWeighting): + """ + Fake GeneralizedWeighting flattening the Gramian and using UPGradWeighting on it. Could be + removed when we implement a proper FlatteningGeneralizedWeighting.""" + + def __init__(self): + super().__init__() + self.weighting = UPGradWeighting() + + def forward(self, generalized_gramian: Tensor) -> Tensor: + k = generalized_gramian.ndim // 2 + shape = generalized_gramian.shape[:k] + m = prod(shape) + square_gramian = reshape_gramian(generalized_gramian, [m]) + weights_vector = self.weighting(square_gramian) + weights = weights_vector.reshape(shape) + return weights From ee16f9b3db3494d213bc4d39323f3bbcf8aff64a Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 9 Sep 2025 17:47:03 +0200 Subject: [PATCH 43/44] reshape_gramian can now take generalized gramians as inputs, its shape can also contain (at most) one element set to -1, the size of that dimension is deduced from the total number of elements --- src/torchjd/autogram/_gramian_utils.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index f82b4ece..4379b7e0 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -1,3 +1,5 @@ +from math import prod + from torch import Tensor @@ -16,11 +18,18 @@ def reshape_gramian(gramian: Tensor, shape: list[int]) -> Tensor: # - The `last_dims` will be [3, 4, 5] and `last_dims[::-1]` will be [5, 4, 3] # - The `reordered_gramian` will be of shape [4, 3, 2, 2, 3, 4] - target_ndim = len(shape) - unordered_gramian = gramian.reshape(shape + shape) - last_dims = [target_ndim + i for i in range(target_ndim)] - reordered_gramian = unordered_gramian.movedim(last_dims, last_dims[::-1]) - return reordered_gramian + automatic_dimensions = [i for i in range(len(shape)) if shape[i] == -1] + if len(automatic_dimensions) == 1: + index = automatic_dimensions[0] + current_shape = gramian.shape[: len(gramian.shape) // 2] + numel = prod(current_shape) + specified_numel = -prod(shape) # shape[index] == -1, this is the product of all other dims + shape[index] = numel // specified_numel + + unordered_intput_gramian = _revert_last_dims(gramian) + unordered_output_gramian = unordered_intput_gramian.reshape(shape + shape) + reordered_output_gramian = _revert_last_dims(unordered_output_gramian) + return reordered_output_gramian def movedim_gramian(gramian: Tensor, source: list[int], destination: list[int]) -> Tensor: @@ -52,3 +61,9 @@ def movedim_gramian(gramian: Tensor, source: list[int], destination: list[int]) mirrored_destination = destination_ + [last_dim - i for i in destination_] moved_gramian = gramian.movedim(mirrored_source, mirrored_destination) return moved_gramian + + +def _revert_last_dims(generalized_gramian: Tensor) -> Tensor: + input_ndim = len(generalized_gramian.shape) // 2 + last_dims = [input_ndim + i for i in range(input_ndim)] + return generalized_gramian.movedim(last_dims, last_dims[::-1]) From 0a2555581aa010928001a62fe378e06ee895ee21 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 10 Sep 2025 11:53:41 +0200 Subject: [PATCH 44/44] Implement `HierarchicalWeighting` Weighting, needs testing. --- .../aggregation/hierachical_weighting.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 src/torchjd/aggregation/hierachical_weighting.py diff --git a/src/torchjd/aggregation/hierachical_weighting.py b/src/torchjd/aggregation/hierachical_weighting.py new file mode 100644 index 00000000..088d5dde --- /dev/null +++ b/src/torchjd/aggregation/hierachical_weighting.py @@ -0,0 +1,105 @@ +import torch +from torch import Tensor + +from ..autogram._gramian_utils import reshape_gramian +from ._weighting_bases import GeneralizedWeighting, Weighting + + +class HierarchicalWeighting(GeneralizedWeighting): + """ + Hierarchically reduces a generalized Gramian using a sequence of weighting functions. + + Applies multiple weightings in sequence to a generalized Gramian ``G`` of shape + ``[n₁, ..., nₖ, nₖ, ..., n₁]``. It first applies the initial weighting to the innermost diagonal + Gramians, contracts those dimensions to form a smaller generalized Gramian, and repeats the + process with subsequent weightings. The final returned weights are chosen so that contracting + the original Gramian directly with these weights produces the same quadratic form as applying + the reductions step by step. + + :param weightings: A list of weighting callables, one for each hierarchical reduction step. + """ + + def __init__(self, weightings: list[Weighting]): + super().__init__() + self.weightings = weightings + self.n_dims = len(weightings) + + def forward(self, generalized_gramian: Tensor) -> Tensor: + + assert len(self.weightings) * 2 == len(generalized_gramian.shape) # temporary + + weighting = self.weightings[0] + dim_size = generalized_gramian.shape[0] + reshaped_gramian = reshape_gramian(generalized_gramian, [-1, dim_size]) + weights = _compute_weights(weighting, reshaped_gramian) + generalized_gramian = _contract_gramian(reshaped_gramian, weights) + + for i in range(self.n_dim): + weighting = self.weightings[i] + dim_size = generalized_gramian.shape[i] + reshaped_gramian = reshape_gramian(generalized_gramian, [-1, dim_size]) + temp_weights = _compute_weights(weighting, reshaped_gramian) + generalized_gramian = _contract_gramian(reshaped_gramian, temp_weights) + weights = _scale_weights(weights, temp_weights) + + return weights + + +def _compute_weights(weighting: Weighting, generalized_gramian: Tensor) -> Tensor: + """ + Apply a weighting to each diagonal Gramian in a generalized Gramian. + + For a generalized Gramian ``G`` of shape ``[m, n, n, m]``, this extracts each diagonal Gramian + ``G[j, :, :, j]`` of shape ``[n, n]`` for ``j`` in ``[m]`` and applies the provided weighting. + The resulting weights are stacked into a tensor of shape ``[m, n]``. + + :param weighting: Callable that maps a Gramian of shape ``[n, n]`` to weights of shape ``[n]``. + :param generalized_gramian: Tensor of shape ``[m, n, n, m]`` containing the generalized Gramian. + :returns: Tensor of shape ``[m, n]`` containing the computed weights for each diagonal Gramian. + """ + + weights = torch.zeros( + generalized_gramian[:2], device=generalized_gramian.device, dtype=generalized_gramian.dtype + ) + for i in range(generalized_gramian.shape[0]): + weights[i] = weighting(generalized_gramian[i, :, :, i]) + return weights + + +def _contract_gramian(generalized_gramian: Tensor, weights: Tensor) -> Tensor: + r""" + Compute a partial quadratic form by contracting a generalized Gramian with weight vectors on + both sides. + + Given a generalized Gramian ``G`` of shape ``[m, n, n, m]`` and weights ``w`` of shape + ``[m, n]``, this function computes a Gramian ``G'`` of shape ``[m, m]`` where + + .. math:: + + G'[i, j] = \sum_{k, l=1}^n w[i, k] G[i, k, l, j] w[j, l]. + + This can be viewed as forming a quadratic form with respect to the two innermost dimensions of + ``G``. + + :param generalized_gramian: Tensor of shape ``[m, n, n, m]`` representing the generalized + Gramian. + :param weights: Tensor of shape ``[m, n]`` containing weight vectors to contract with the + Gramian. + :returns: Tensor of shape ``[m, m]`` containing the contracted Gramian, i.e. the partial + quadratic form. + """ + left_product = torch.einsum("ij,ijkl->ikl", weights, generalized_gramian) + return torch.einsum("ij,ijl->il", weights, left_product) + + +def _scale_weights(weights: Tensor, scalings: Tensor) -> Tensor: + """ + Scale a tensor along its leading dimensions by broadcasting scaling factors. + + :param weights: Tensor of shape [n₁, ..., nₖ, nₖ₊₁, ..., nₚ]. + :param scalings: Tensor of shape [n₁, ..., nₖ] providing scaling factors for the leading + dimensions of ``weights``. + :returns: Tensor of the same shape as ``weights``, where each slice + ``weights[i₁, ..., iₖ, :, ..., :]`` is multiplied by ``scalings[i₁, ..., iₖ]``. + """ + return weights * scalings[(...,) + (None,) * (weights.ndim - scalings.ndim)]