From e757e316e32c4d37c8b95863217ba34e722c5e2b Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 27 Aug 2025 16:51:21 +0200 Subject: [PATCH 001/114] Add non-batched support Add reshape of jacobian for the scalar output case. Fix reshape of the Gramian, we for the last half of the dimensions, we need to reshape in the same order as the first, then we move the dimensions. We could in principle create a `reshape_gramian` function that does this, as well as a `move_dim_gramian` Add a test of values for all four cases of having a batched/non-batched dimension. Tests or reshape/move-dim should work should go in another test. Remove some tests that do not test anything more than `test_gramian_is_correct`. Add `_gramian_utils.py` which contains helper to `reshape` and `movedim` on a Gramian. Add `generate_vmap_rule = True` for `JacobianAccumulator`. This allows vmaping the forward phase. This enables having several Engines defined on the same module. Add `test_reshape_equivariance` Add tests to verify that gramian utils yields the correct quadratic forms. Add tests to verify that gramian utils yields the correct quadratic forms. Add `test_movedim_equivariance` Fix warning. Fix warning. Remove handles from `ModuleHookManager` Change `batched_dims` to a single optional `batched_dim`. Fix movedim in `compute_gramian` and add `test_movedim_equivariance` Remove `grad_output`, can be added later, but should be `jac_output` instead. Make modules with incompatible batched operations are compatible with non-batched autogram. Fix doc tests Provide the autograd vjp for when no dimension is batched. This enables having a single forward in that case which should be faster. Make VJPs into Callable classes. --- src/torchjd/autogram/_engine.py | 72 ++++++- src/torchjd/autogram/_gramian_utils.py | 39 ++++ src/torchjd/autogram/_module_hook_manager.py | 41 +++- src/torchjd/autogram/_vjp.py | 98 ++++++--- tests/doc/test_autogram.py | 2 +- tests/doc/test_rst.py | 4 +- tests/speed/autogram/grad_vs_jac_vs_gram.py | 2 +- tests/unit/autogram/test_engine.py | 206 ++++++++++++++++++- tests/unit/autogram/test_gramian_utils.py | 85 ++++++++ 9 files changed, 494 insertions(+), 55 deletions(-) create mode 100644 src/torchjd/autogram/_gramian_utils.py create mode 100644 tests/unit/autogram/test_gramian_utils.py diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 5aabcdc9..cc5e53df 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -2,11 +2,12 @@ from typing import cast import torch -from torch import Tensor, nn +from torch import Tensor, nn, vmap from torch.autograd.graph import get_gradient_edge from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator +from ._gramian_utils import movedim_gramian, reshape_gramian from ._module_hook_manager import ModuleHookManager _INCOMPATIBLE_MODULE_TYPES = ( @@ -57,6 +58,9 @@ class Engine: :param modules: A collection of modules whose direct (non-recursive) parameters will contribute to the Gramian of the Jacobian. + :param is_batched: If a dimension is batched, then many intermediary jacobians are block + diagonal, which allows for a substancial memory optimization by backpropagating a squashed + Jacobian instead. If the only dimension of the losses vector is batched. Default to True. .. admonition:: Example @@ -79,7 +83,7 @@ class Engine: >>> >>> criterion = MSELoss(reduction="none") >>> weighting = UPGradWeighting() - >>> engine = Engine(model.modules()) + >>> engine = Engine(model.modules(), (0,)) >>> >>> for input, target in zip(inputs, targets): >>> output = model(input).squeeze(dim=1) # shape: [16] @@ -127,10 +131,17 @@ class Engine: `_ layers. """ - def __init__(self, modules: Iterable[nn.Module]): + def __init__( + self, + modules: Iterable[nn.Module], + batched_dim: int | None = None, + ): self._gramian_accumulator = GramianAccumulator() self._target_edges = EdgeRegistry() - self._module_hook_manager = ModuleHookManager(self._target_edges, self._gramian_accumulator) + self._batched_dim = batched_dim + self._module_hook_manager = ModuleHookManager( + self._target_edges, self._gramian_accumulator, batched_dim is not None + ) self._hook_modules(modules) @@ -143,9 +154,8 @@ def _hook_modules(self, modules: Iterable[nn.Module]) -> None: self._check_module_is_compatible(module) self._module_hook_manager.hook_module(module) - @staticmethod - def _check_module_is_compatible(module: nn.Module) -> None: - if isinstance(module, _INCOMPATIBLE_MODULE_TYPES): + def _check_module_is_compatible(self, module: nn.Module) -> None: + if self._batched_dim is not None and isinstance(module, _INCOMPATIBLE_MODULE_TYPES): raise ValueError( f"Found a module of type {type(module)}, which is incompatible with the autogram " f"engine. The incompatible module types are {_INCOMPATIBLE_MODULE_TYPES} (and their" @@ -166,12 +176,39 @@ def compute_gramian(self, output: Tensor) -> Tensor: ``modules``. :param output: The vector to differentiate. Must be a 1-D tensor. + :param grad_output: The tangents for the differentiation. Default to a vector of 1s of the + same shape as `output`. """ - reshaped_output = output.reshape([-1]) - return self._compute_square_gramian(reshaped_output) + if self._batched_dim is not None: + # move batched dim to the end + ordered_output = torch.movedim(output, self._batched_dim, -1) + ordered_shape = list(ordered_output.shape) + has_non_batched_dim = len(ordered_shape) > 1 + target_shape = [ordered_shape[-1]] + else: + ordered_output = output + ordered_shape = list(ordered_output.shape) + has_non_batched_dim = len(ordered_shape) > 0 + target_shape = [] - def _compute_square_gramian(self, output: Tensor) -> Tensor: + if has_non_batched_dim: + target_shape = [-1] + target_shape + + reshaped_output = ordered_output.reshape(target_shape) + + flat_gramian = self._compute_square_gramian(reshaped_output, has_non_batched_dim) + + unordered_gramian = reshape_gramian(flat_gramian, ordered_shape) + + if self._batched_dim is not None: + gramian = movedim_gramian(unordered_gramian, [-1], [self._batched_dim]) + else: + gramian = unordered_gramian + + return gramian + + def _compute_square_gramian(self, output: Tensor, has_non_batched_dim: bool) -> Tensor: self._module_hook_manager.gramian_accumulation_phase = True leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)})) @@ -184,7 +221,20 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: retain_graph=True, ) - _ = differentiation(torch.ones_like(output)) + if has_non_batched_dim: + # There is one non-batched dimension, it is the first one + non_batched_dim_len = output.shape[0] + jac_output_shape = [output.shape[0]] + list(output.shape) + + # Need to batch `grad_output` over the first dimension + jac_output = torch.zeros(jac_output_shape, device=output.device, dtype=output.dtype) + for i in range(non_batched_dim_len): + jac_output[i, i] = 1 + + _ = vmap(differentiation)(jac_output) + else: + grad_output = torch.ones_like(output) + _ = differentiation(grad_output) # If the gramian were None, then leaf_targets would be empty, so autograd.grad would # have failed. So gramian is necessarily a valid Tensor here. diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py new file mode 100644 index 00000000..19affdb2 --- /dev/null +++ b/src/torchjd/autogram/_gramian_utils.py @@ -0,0 +1,39 @@ +from torch import Tensor + + +def reshape_gramian(gramian: Tensor, shape: list[int]) -> Tensor: + """ + Reshapes a Gramian to a provided shape. As a Gramian is quadratic form, the reshape of the first + half of the target dimensions must be done from the left, while the reshape of the second half + must be done from the right. + :param gramian: Gramian to reshape + :param shape: First half of the target shape, the shape of the output is therefore + `shape + shape[::-1]`. + """ + + target_ndim = len(shape) + unordered_shape = shape + shape + unordered_gramian = gramian.reshape(unordered_shape) + last_dims = [target_ndim + i for i in range(target_ndim)] + return unordered_gramian.movedim(last_dims, last_dims[::-1]) + + +def movedim_gramian(gramian: Tensor, source: list[int], destination: list[int]) -> Tensor: + """ + Moves the dimensions of a Gramian from some source dimensions to destination dimensions. As a + Gramian is quadratic form, moving dimension must be done simultaneously on the first half of the + dimensions and on the second half of the dimensions reversed. + :param gramian: Gramian to reshape. + :param source: Source dimensions, should be in the range [0, gramian.ndim/2]. Should be unique + :param destination: Destination dimensions, should be in the range [0, gramian.ndim/2]. Should + be unique and should have the same size as `source`. + """ + + length = gramian.ndim // 2 + source = [i if 0 <= i else i + length for i in source] + destination = [i if 0 <= i else i + length for i in destination] + + last_index = gramian.ndim - 1 + source_dims = source + [last_index - i for i in source] + destination_dims = destination + [last_index - i for i in destination] + return gramian.movedim(source_dims, destination_dims) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index e5087bb0..0b51ec1e 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -3,12 +3,11 @@ import torch from torch import Tensor, nn from torch.autograd.graph import get_gradient_edge -from torch.utils._pytree import PyTree, TreeSpec, tree_flatten, tree_unflatten -from torch.utils.hooks import RemovableHandle as TorchRemovableHandle +from torch.utils._pytree import PyTree, TreeSpec, tree_flatten, tree_map, tree_unflatten from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator -from ._vjp import get_functional_vjp +from ._vjp import AutogradVJP, FunctionalVJP # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see @@ -32,11 +31,12 @@ def __init__( self, target_edges: EdgeRegistry, gramian_accumulator: GramianAccumulator, + has_batch_dim: bool, ): self._target_edges = target_edges self._gramian_accumulator = gramian_accumulator + self._has_batch_dim = has_batch_dim self.gramian_accumulation_phase = False - self._handles: list[TorchRemovableHandle] = [] def hook_module(self, module: nn.Module) -> None: """ @@ -70,8 +70,7 @@ def module_hook(_: nn.Module, args: PyTree, output: PyTree) -> PyTree: return self._apply_jacobian_accumulator(module, args, tree_spec, flat_outputs) - handle = module.register_forward_hook(module_hook) - self._handles.append(handle) + _ = module.register_forward_hook(module_hook) def _apply_jacobian_accumulator( self, @@ -80,21 +79,45 @@ def _apply_jacobian_accumulator( tree_spec: TreeSpec, flat_outputs: list[Tensor], ) -> PyTree: - vjp = torch.vmap(get_functional_vjp(module)) + + if self._has_batch_dim: + vjp = torch.vmap(FunctionalVJP(module)) + else: + vjp = AutogradVJP(module, flat_outputs) class AccumulateJacobian(torch.autograd.Function): @staticmethod def forward(*flat_grad_outputs: Tensor) -> None: + # There is no non-batched dimension grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec) jacobians = vjp(grad_outputs, args) self._gramian_accumulator.accumulate_path_jacobians( { - module.get_parameter(param_name): jacobian + module.get_parameter(param_name): jacobian.reshape( + [-1] + list(module.get_parameter(param_name).shape) + ) for param_name, jacobian in jacobians.items() } ) + @staticmethod + def vmap(info, in_dims, *flat_jac_outputs: Tensor) -> tuple[None, None]: + # There is a non-batched dimension + jac_outputs = tree_unflatten(flat_jac_outputs, tree_spec) + # We do not vmap over the args for the non-batched dimension + in_dims = (tree_unflatten(in_dims, tree_spec), tree_map(lambda _: None, args)) + jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) + self._gramian_accumulator.accumulate_path_jacobians( + { + module.get_parameter(param_name): jacobian.reshape( + [-1] + list(module.get_parameter(param_name).shape) + ) + for param_name, jacobian in jacobians.items() + } + ) + return None, None + @staticmethod def setup_context(*_): pass @@ -108,6 +131,8 @@ class JacobianAccumulator(torch.autograd.Function): toggle mechanism to activate only during the Gramian accumulation phase. """ + generate_vmap_rule = True + @staticmethod def forward(*xs: Tensor) -> tuple[Tensor, ...]: return tuple([x.detach() for x in xs]) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index b4bea046..80be32a0 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -1,9 +1,10 @@ -from collections.abc import Callable +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence import torch from torch import Tensor, nn from torch.nn import Parameter -from torch.utils._pytree import PyTree, tree_map_only +from torch.utils._pytree import PyTree, tree_flatten, tree_map_only, tree_unflatten # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see @@ -14,19 +15,41 @@ # still support older versions of PyTorch where pytree is protected). -def get_functional_vjp(module: nn.Module) -> Callable[[PyTree, PyTree], dict[str, Tensor]]: +class VJP(ABC): """ - Create a VJP function for a module's forward pass with respect to its parameters. The returned - function takes both the input and the cotangents that can be vmaped jointly in both terms to - avoid providing to block diagonal jacobians. + Represents a VJP function for a module's forward pass with respect to its parameters using the + func api. :params module: The module to differentiate. - :returns: VJP function that takes cotangents and inputs and returns dictionary of names of + """ + + def __init__(self, module: nn.Module): + self.module = module + self.named_parameters = dict(module.named_parameters(recurse=False)) + + @abstractmethod + def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]: + """ + VJP function that takes cotangents and inputs and returns dictionary of names of parameters (as given by `module.named_parameters.keys()`) to gradients of the parameters for the given cotangents at the given inputs. + """ + + +class FunctionalVJP(VJP): """ + Represents a VJP function for a module's forward pass with respect to its parameters using the + func api. The __call__ function takes both the inputs and the cotangents that can be vmaped + jointly in both terms to avoid providing to block diagonal jacobians. The disadvantage of using + this method is that it computes the forward phase. - def get_vjp(grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]: + :params module: The module to differentiate. + """ + + def __init__(self, module: nn.Module): + super().__init__(module) + + def __call__(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]: # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a # "batch" of 1 activation (or grad_output). This is because some layers (e.g. # nn.Flatten) do not work equivalently if they're provided with a batch or with @@ -39,30 +62,51 @@ def get_vjp(grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]: # primals (tuple), here the functional has a single primal which is # dict(module.named_parameters()). We therefore take the 0'th element to obtain # the dict of gradients w.r.t. the module's named_parameters. - return _vjp_from_module(module, inputs_j)(grad_outputs_j)[0] + return self._vjp_from_module(inputs_j)(grad_outputs_j)[0] - return get_vjp + def _vjp_from_module(self, inputs: PyTree) -> Callable[[PyTree], tuple[dict[str, Tensor]]]: + """ + Create a VJP function for a module's forward pass with respect to its parameters. + Returns a function that computes vector-Jacobian products for the module's parameters given + fixed inputs. Only parameters with requires_grad=True are included in the differentiation. -def _vjp_from_module( - module: nn.Module, inputs: PyTree -) -> Callable[[PyTree], tuple[dict[str, Tensor]]]: - """ - Create a VJP function for a module's forward pass with respect to its parameters. + :param inputs: Fixed inputs to the module for the VJP computation. + :returns: VJP function that takes cotangents and returns parameter gradients. + """ + requires_grad_named_params = { + k: v for k, v in self.named_parameters.items() if v.requires_grad + } + no_requires_grad_named_params = { + k: v for k, v in self.named_parameters.items() if not v.requires_grad + } - Returns a function that computes vector-Jacobian products for the module's parameters given - fixed inputs. Only parameters with requires_grad=True are included in the differentiation. + def functional_model_call(primals: dict[str, Parameter]) -> Tensor: + all_state = { + **primals, + **dict(self.module.named_buffers()), + **no_requires_grad_named_params, + } + return torch.func.functional_call(self.module, all_state, inputs) - :param module: The module to differentiate. - :param inputs: Fixed inputs to the module for the VJP computation. - :returns: VJP function that takes cotangents and returns parameter gradients. + return torch.func.vjp(functional_model_call, requires_grad_named_params)[1] + + +class AutogradVJP(VJP): + """ + Represents a VJP function for a module's forward pass with respect to its parameters using the + autograd engine. The __call__ function takes both the inputs and the cotangents but ignores the + inputs. The main advantage of using this method is that it doesn't require computing the forward + phase. """ - named_params = dict(module.named_parameters(recurse=False)) - requires_grad_named_params = {k: v for k, v in named_params.items() if v.requires_grad} - no_requires_grad_named_params = {k: v for k, v in named_params.items() if not v.requires_grad} - def functional_model_call(primals: dict[str, Parameter]) -> Tensor: - all_state = {**primals, **dict(module.named_buffers()), **no_requires_grad_named_params} - return torch.func.functional_call(module, all_state, inputs) + def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): + super().__init__(module) + self.outputs = outputs + self.parameters, self.tree_spec = tree_flatten(dict(self.named_parameters)) - return torch.func.vjp(functional_model_call, requires_grad_named_params)[1] + def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: + grads = torch.autograd.grad( + self.outputs, self.parameters, tree_flatten(grad_outputs)[0], retain_graph=True + ) + return tree_unflatten(grads, self.tree_spec) diff --git a/tests/doc/test_autogram.py b/tests/doc/test_autogram.py index be55a1ae..724d81e2 100644 --- a/tests/doc/test_autogram.py +++ b/tests/doc/test_autogram.py @@ -18,7 +18,7 @@ def test_engine(): criterion = MSELoss(reduction="none") weighting = UPGradWeighting() - engine = Engine(model.modules()) + engine = Engine(model.modules(), 0) for input, target in zip(inputs, targets): output = model(input).squeeze(dim=1) # shape: [16] diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 4f2b9351..86400146 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -136,7 +136,7 @@ def test_autogram(): params = model.parameters() optimizer = SGD(params, lr=0.1) weighting = UPGradWeighting() - engine = Engine(model.modules()) + engine = Engine(model.modules(), 0) for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] @@ -325,7 +325,7 @@ def test_partial_jd(): # Create the autogram engine that will compute the Gramian of the # Jacobian with respect to the two last Linear layers' parameters. - engine = Engine(model[2:].modules()) + engine = Engine(model[2:].modules(), 0) params = model.parameters() optimizer = SGD(params, lr=0.1) diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index 881d0321..ad26892a 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -96,7 +96,7 @@ def post_fn(): print(autojac_times) print() - engine = Engine(model.modules()) + engine = Engine(model.modules(), (0,)) autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs)) print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}") print(autogram_times) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index be7b593b..994a9ff5 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -1,10 +1,13 @@ from itertools import combinations +from math import prod import pytest import torch from pytest import mark, param from torch import nn +from torch.nn import Linear from torch.optim import SGD +from torch.testing import assert_close from unit.conftest import DEVICE from utils.architectures import ( AlexNet, @@ -55,7 +58,7 @@ autojac_forward_backward, make_mse_loss_fn, ) -from utils.tensors import make_tensors +from utils.tensors import make_tensors, ones_, randn_, zeros_ from torchjd.aggregation import ( IMTLG, @@ -80,6 +83,7 @@ Weighting, ) from torchjd.autogram._engine import Engine +from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian from torchjd.autojac._transform import Diagonalize, Init, Jac, OrderedSet from torchjd.autojac._transform._aggregate import _Matrixify @@ -171,7 +175,7 @@ def test_equivalence_autojac_autogram( torch.manual_seed(0) model_autogram = architecture().to(device=DEVICE) - engine = Engine(model_autogram.modules()) + engine = Engine(model_autogram.modules(), 0) optimizer_autojac = SGD(model_autojac.parameters(), lr=1e-7) optimizer_autogram = SGD(model_autogram.parameters(), lr=1e-7) @@ -237,7 +241,7 @@ def test_autograd_while_modules_are_hooked(architecture: type[ShapedModule], bat model_autogram = architecture().to(device=DEVICE) # Hook modules and verify that we're equivalent to autojac when using the engine - engine = Engine(model_autogram.modules()) + engine = Engine(model_autogram.modules(), 0) torch.manual_seed(0) # Fix randomness for random models autogram_forward_backward(model_autogram, engine, W, input, loss_fn) grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None} @@ -311,7 +315,7 @@ def test_partial_autogram(weighting: Weighting, gramian_module_names: set[str]): expected_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None} model.zero_grad() - engine = Engine(gramian_modules) + engine = Engine(gramian_modules, 0) output = model(input) losses = loss_fn(output) @@ -330,4 +334,196 @@ def test_incompatible_modules(architecture: type[nn.Module]): model = architecture().to(device=DEVICE) with pytest.raises(ValueError): - _ = Engine(model.modules()) + _ = Engine(model.modules(), 0) + + +@mark.parametrize("shape", [(1, 3), (7, 15), (27, 15)]) +@mark.parametrize("batch_size", [None, 3, 16, 32]) +@mark.parametrize("reduce_output", [True, False]) +def test_gramian_is_correct(shape: tuple[int, int], batch_size: int, reduce_output: bool): + """ + Tests that the Gramian computed by then `Engine` equals to a manual computation of the expected + Gramian. + """ + + is_batched = batch_size is not None + + if is_batched: + batched_dims = 0 + input_dim = [batch_size, shape[0]] + else: + batched_dims = None + input_dim = [shape[0]] + + model = Linear(shape[0], shape[1]) + engine = Engine([model], batched_dims) + + input = randn_(input_dim) + output = model(input) + if reduce_output: + output = torch.sum(output, dim=-1) + + assert output.ndim == int(not reduce_output) + int(is_batched) + + gramian = engine.compute_gramian(output) + + # compute the expected gramian + output_shape = list(output.shape) + initial_jacobian = torch.diag(ones_(output.numel())).reshape(output_shape + output_shape) + + if reduce_output: + initial_jacobian = initial_jacobian.unsqueeze(-1).repeat( + ([1] * initial_jacobian.ndim) + [shape[1]] + ) + if not is_batched: + initial_jacobian = initial_jacobian.unsqueeze(-2) + input = input.unsqueeze(0) + + assert initial_jacobian.shape[-2] == (1 if batch_size is None else batch_size) + assert initial_jacobian.shape[-1] == shape[1] + assert initial_jacobian.shape[:-2] == output.shape + + assert input.shape[0] == (1 if batch_size is None else batch_size) + assert input.shape[1] == shape[0] + + # If k is the batch_size (1 if None) and n the input size and m the output size, then + # - input has shape `[k, n]` + # - initial_jacobian has shape `output.shape + `[k, m]` + + # The partial (batched) jacobian of outputs w.r.t. weights is of shape `[k, m, m, n]`, whe + # multiplied (along 2 dims) by initial_jacobian this yields the jacobian of the weights of shape + # `output.shape + [m, n]`. The partial jacobian itself is block diagonal with diagonal defined + # by `partial_weight_jacobian[i, j, j] = input[i]` (other elements are 0). + + partial_weight_jacobian = zeros_([input.shape[0], shape[1], shape[1], shape[0]]) + for j in range(shape[1]): + partial_weight_jacobian[:, j, j, :] = input + weight_jacobian = torch.tensordot( + initial_jacobian, partial_weight_jacobian, dims=([-2, -1], [0, 1]) + ) + weight_gramian = torch.tensordot(weight_jacobian, weight_jacobian, dims=([-2, -1], [-2, -1])) + if weight_gramian.ndim == 4: + weight_gramian = weight_gramian.movedim((-2), (-1)) + + # The partial (batched) jacobian of outputs w.r.t. bias is of shape `[k, m, m]`, when multiplied + # (along 2 dims) by initial_jacobian this yields the jacobian of the bias of shape + # `output.shape + [m]`. The partial jacobian itself is block diagonal with diagonal defined by + # `partial_bias_jacobian[i, j, j] = 1` (other elements are 0). + partial_bias_jacobian = zeros_([input.shape[0], shape[1], shape[1]]) + for j in range(shape[1]): + partial_bias_jacobian[:, j, j] = 1.0 + bias_jacobian = torch.tensordot( + initial_jacobian, partial_bias_jacobian, dims=([-2, -1], [0, 1]) + ) + bias_gramian = torch.tensordot(bias_jacobian, bias_jacobian, dims=([-1], [-1])) + if bias_gramian.ndim == 4: + bias_gramian = bias_gramian.movedim(-2, -1) + + expected_gramian = weight_gramian + bias_gramian + + assert_close(gramian, expected_gramian) + + +@mark.parametrize( + "shape", + [ + [1, 2, 2, 3], + [7, 3, 2, 5], + [27, 6, 7], + ], +) +def test_reshape_equivariance(shape: list[int]): + """ + Test equivariance of `compute_gramian` under reshape operation. More precisely, if we reshape + the `output` to some `shape`, then the result is the same as reshaping the Gramian to the + corresponding shape. + """ + + input_size = shape[0] + output_size = prod(shape[1:]) + + model = Linear(input_size, output_size) + engine1 = Engine([model]) + engine2 = Engine([model]) + + input = randn_([input_size]) + output = model(input) + + reshaped_output = output.reshape(shape[1:]) + + gramian = engine1.compute_gramian(output) + reshaped_gramian = engine2.compute_gramian(reshaped_output) + + expected_reshaped_gramian = reshape_gramian(gramian, shape[1:]) + + assert_close(reshaped_gramian, expected_reshaped_gramian) + + +@mark.parametrize( + ["shape", "source", "destination"], + [ + ([50, 2, 2, 3], [0, 2], [1, 0]), + ([60, 3, 2, 5], [1], [2]), + ([30, 6, 7], [0, 1], [1, 0]), + ], +) +def test_movedim_equivariance(shape: list[int], source: list[int], destination: list[int]): + """ + Test equivariance of `compute_gramian` under movedim operation. More precisely, if we movedim + the `output` on some dimensions, then the result is the same as movedim on the Gramian with the + corresponding dimensions. + """ + + input_size = shape[0] + output_size = prod(shape[1:]) + + model = Linear(input_size, output_size) + engine1 = Engine([model]) + engine2 = Engine([model]) + + input = randn_([input_size]) + output = model(input).reshape(shape[1:]) + + moved_output = output.movedim(source, destination) + + gramian = engine1.compute_gramian(output) + moved_gramian = engine2.compute_gramian(moved_output) + + expected_moved_gramian = movedim_gramian(gramian, source, destination) + + assert_close(moved_gramian, expected_moved_gramian) + + +@mark.parametrize( + ["shape", "batched_dim"], + [ + ([2, 5, 3, 2], 2), + ([3, 2, 5], 1), + ([6, 3], 0), + ([4, 3, 2], 1), + ], +) +def test_batched_non_batched_equivalence(shape: list[int], batched_dim: int): + """ + Tests that for a vector with some batched dimensions, the gramian is the same if we use the + appropriate `batched_dims` or if we don't use any. + """ + + non_batched_shape = [shape[i] for i in range(len(shape)) if i != batched_dim] + input_size = prod(non_batched_shape) + batch_size = shape[batched_dim] + output_size = input_size + + model = Linear(input_size, output_size) + engine1 = Engine([model], batched_dim) + engine2 = Engine([model]) + + input = randn_([batch_size, input_size]) + output = model(input) + output = output.reshape([batch_size] + non_batched_shape) + output = output.movedim(0, batched_dim) + + gramian1 = engine1.compute_gramian(output) + gramian2 = engine2.compute_gramian(output) + + assert_close(gramian1, gramian2) diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py new file mode 100644 index 00000000..583f44b4 --- /dev/null +++ b/tests/unit/autogram/test_gramian_utils.py @@ -0,0 +1,85 @@ +from math import prod + +import torch +from pytest import mark +from torch import Tensor +from torch.testing import assert_close +from utils.tensors import rand_ + +from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian + + +def compute_quadratic_form(gramian: Tensor, vector: Tensor) -> Tensor: + """ + Compute the quadratic form x^T G x when the provided Gramian and vector may have multiple + dimensions. + """ + indices = list(range(vector.ndim)) + linear_form = torch.tensordot(vector, gramian, dims=(indices, indices)) + return torch.tensordot(linear_form, vector, dims=(indices[::-1], indices)) + + +@mark.parametrize( + "shape", + [ + [50, 2, 2, 3], + [60, 3, 2, 5], + [30, 6, 7], + ], +) +def test_quadratic_form_invariance_to_reshape(shape: list[int]): + """ + When reshaping a Gramian, we expect it to represent the same quadratic form that now applies to + reshaped inputs. So the mapping x -> x^T G x commutes with reshaping x, G and then computing the + corresponding quadratic form. + """ + + flat_dim = prod(shape[1:]) + iterations = 20 + + matrix = rand_([flat_dim, shape[0]]) + gramian = matrix @ matrix.T + reshaped_gramian = reshape_gramian(gramian, shape[1:]) + + for _ in range(iterations): + vector = rand_([flat_dim]) + reshaped_vector = vector.reshape(shape[1:]) + + quadratic_form = vector @ gramian @ vector + reshaped_quadratic_form = compute_quadratic_form(reshaped_gramian, reshaped_vector) + + assert_close(reshaped_quadratic_form, quadratic_form) + + +@mark.parametrize( + ["shape", "source", "destination"], + [ + ([50, 2, 2, 3], [0, 2], [1, 0]), + ([60, 3, 2, 5], [1], [2]), + ([30, 6, 7], [0, 1], [1, 0]), + ], +) +def test_quadratic_form_invariance_to_movedim( + shape: list[int], source: list[int], destination: list[int] +): + """ + When moving dims on a Gramian, we expect it to represent the same quadratic form that now + applies to inputs with moved dims. So the mapping x -> x^T G x commutes with moving dims x, G + and then computing the quadratic form with those. + """ + + flat_dim = prod(shape[1:]) + iterations = 20 + + matrix = rand_([flat_dim, shape[0]]) + gramian = reshape_gramian(matrix @ matrix.T, shape[1:]) + moved_gramian = movedim_gramian(gramian, source, destination) + + for _ in range(iterations): + vector = rand_(shape[1:]) + moved_vector = vector.movedim(source, destination) + + quadratic_form = compute_quadratic_form(gramian, vector) + moved_quadratic_form = compute_quadratic_form(moved_gramian, moved_vector) + + assert_close(moved_quadratic_form, quadratic_form) From 8f6e5ee4d8affa0264da7d78acc5a6467365b459 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 3 Sep 2025 13:47:18 +0200 Subject: [PATCH 002/114] Update src/torchjd/autogram/_vjp.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com> --- src/torchjd/autogram/_vjp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 80be32a0..68463e47 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -17,8 +17,7 @@ class VJP(ABC): """ - Represents a VJP function for a module's forward pass with respect to its parameters using the - func api. + Represents an abstract VJP function for a module's forward pass with respect to its parameters. :params module: The module to differentiate. """ From d507daa90430c63210e72ec62cd2738b3b20c349 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 3 Sep 2025 13:55:04 +0200 Subject: [PATCH 003/114] Update tests/unit/autogram/test_engine.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com> --- tests/unit/autogram/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 994a9ff5..003f2aaa 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -342,7 +342,7 @@ def test_incompatible_modules(architecture: type[nn.Module]): @mark.parametrize("reduce_output", [True, False]) def test_gramian_is_correct(shape: tuple[int, int], batch_size: int, reduce_output: bool): """ - Tests that the Gramian computed by then `Engine` equals to a manual computation of the expected + Tests that the Gramian computed by the `Engine` equals to a manual computation of the expected Gramian. """ From d60328ec2329c809cd21c9151e861570c2b6f56f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 14:52:22 +0200 Subject: [PATCH 004/114] Fix batched_dim param in some Engine usages --- src/torchjd/autogram/_engine.py | 2 +- tests/speed/autogram/grad_vs_jac_vs_gram.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index cc5e53df..daf0ded3 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -83,7 +83,7 @@ class Engine: >>> >>> criterion = MSELoss(reduction="none") >>> weighting = UPGradWeighting() - >>> engine = Engine(model.modules(), (0,)) + >>> engine = Engine(model.modules(), 0) >>> >>> for input, target in zip(inputs, targets): >>> output = model(input).squeeze(dim=1) # shape: [16] diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index ad26892a..a2e5bcf9 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -96,7 +96,7 @@ def post_fn(): print(autojac_times) print() - engine = Engine(model.modules(), (0,)) + engine = Engine(model.modules(), 0) autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs)) print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}") print(autogram_times) From 4103e4d5a270b376c6524a6ba0c4e03a48a7024b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 14:54:47 +0200 Subject: [PATCH 005/114] Make batched_dim parameter name explicit when creating Engine --- src/torchjd/autogram/_engine.py | 2 +- tests/doc/test_autogram.py | 2 +- tests/doc/test_rst.py | 4 ++-- tests/speed/autogram/grad_vs_jac_vs_gram.py | 2 +- tests/unit/autogram/test_engine.py | 12 ++++++------ 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index daf0ded3..a7b8f28e 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -83,7 +83,7 @@ class Engine: >>> >>> criterion = MSELoss(reduction="none") >>> weighting = UPGradWeighting() - >>> engine = Engine(model.modules(), 0) + >>> engine = Engine(model.modules(), batched_dim=0) >>> >>> for input, target in zip(inputs, targets): >>> output = model(input).squeeze(dim=1) # shape: [16] diff --git a/tests/doc/test_autogram.py b/tests/doc/test_autogram.py index 724d81e2..e0e3117f 100644 --- a/tests/doc/test_autogram.py +++ b/tests/doc/test_autogram.py @@ -18,7 +18,7 @@ def test_engine(): criterion = MSELoss(reduction="none") weighting = UPGradWeighting() - engine = Engine(model.modules(), 0) + engine = Engine(model.modules(), batched_dim=0) for input, target in zip(inputs, targets): output = model(input).squeeze(dim=1) # shape: [16] diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 86400146..d7be637e 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -136,7 +136,7 @@ def test_autogram(): params = model.parameters() optimizer = SGD(params, lr=0.1) weighting = UPGradWeighting() - engine = Engine(model.modules(), 0) + engine = Engine(model.modules(), batched_dim=0) for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] @@ -325,7 +325,7 @@ def test_partial_jd(): # Create the autogram engine that will compute the Gramian of the # Jacobian with respect to the two last Linear layers' parameters. - engine = Engine(model[2:].modules(), 0) + engine = Engine(model[2:].modules(), batched_dim=0) params = model.parameters() optimizer = SGD(params, lr=0.1) diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index a2e5bcf9..7080373c 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -96,7 +96,7 @@ def post_fn(): print(autojac_times) print() - engine = Engine(model.modules(), 0) + engine = Engine(model.modules(), batched_dim=0) autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs)) print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}") print(autogram_times) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 003f2aaa..111c83bb 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -175,7 +175,7 @@ def test_equivalence_autojac_autogram( torch.manual_seed(0) model_autogram = architecture().to(device=DEVICE) - engine = Engine(model_autogram.modules(), 0) + engine = Engine(model_autogram.modules(), batched_dim=0) optimizer_autojac = SGD(model_autojac.parameters(), lr=1e-7) optimizer_autogram = SGD(model_autogram.parameters(), lr=1e-7) @@ -241,7 +241,7 @@ def test_autograd_while_modules_are_hooked(architecture: type[ShapedModule], bat model_autogram = architecture().to(device=DEVICE) # Hook modules and verify that we're equivalent to autojac when using the engine - engine = Engine(model_autogram.modules(), 0) + engine = Engine(model_autogram.modules(), batched_dim=0) torch.manual_seed(0) # Fix randomness for random models autogram_forward_backward(model_autogram, engine, W, input, loss_fn) grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None} @@ -315,7 +315,7 @@ def test_partial_autogram(weighting: Weighting, gramian_module_names: set[str]): expected_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None} model.zero_grad() - engine = Engine(gramian_modules, 0) + engine = Engine(gramian_modules, batched_dim=0) output = model(input) losses = loss_fn(output) @@ -334,7 +334,7 @@ def test_incompatible_modules(architecture: type[nn.Module]): model = architecture().to(device=DEVICE) with pytest.raises(ValueError): - _ = Engine(model.modules(), 0) + _ = Engine(model.modules(), batched_dim=0) @mark.parametrize("shape", [(1, 3), (7, 15), (27, 15)]) @@ -356,7 +356,7 @@ def test_gramian_is_correct(shape: tuple[int, int], batch_size: int, reduce_outp input_dim = [shape[0]] model = Linear(shape[0], shape[1]) - engine = Engine([model], batched_dims) + engine = Engine([model], batched_dim=batched_dims) input = randn_(input_dim) output = model(input) @@ -515,7 +515,7 @@ def test_batched_non_batched_equivalence(shape: list[int], batched_dim: int): output_size = input_size model = Linear(input_size, output_size) - engine1 = Engine([model], batched_dim) + engine1 = Engine([model], batched_dim=batched_dim) engine2 = Engine([model]) input = randn_([batch_size, input_size]) From 4fadac6012e2479100a1333460f24cc6a538fe21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 14:55:13 +0200 Subject: [PATCH 006/114] Rename batched_dims to batched_dim in test_gramian_is_correct --- tests/unit/autogram/test_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 111c83bb..b83b87c9 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -349,14 +349,14 @@ def test_gramian_is_correct(shape: tuple[int, int], batch_size: int, reduce_outp is_batched = batch_size is not None if is_batched: - batched_dims = 0 + batched_dim = 0 input_dim = [batch_size, shape[0]] else: - batched_dims = None + batched_dim = None input_dim = [shape[0]] model = Linear(shape[0], shape[1]) - engine = Engine([model], batched_dim=batched_dims) + engine = Engine([model], batched_dim=batched_dim) input = randn_(input_dim) output = model(input) From 5c98ee8eed4cb722f5b65fa048e960ddcc1e0ca4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 14:58:58 +0200 Subject: [PATCH 007/114] Fix parameter description of batched_dim --- src/torchjd/autogram/_engine.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index a7b8f28e..0951277e 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -58,9 +58,10 @@ class Engine: :param modules: A collection of modules whose direct (non-recursive) parameters will contribute to the Gramian of the Jacobian. - :param is_batched: If a dimension is batched, then many intermediary jacobians are block - diagonal, which allows for a substancial memory optimization by backpropagating a squashed - Jacobian instead. If the only dimension of the losses vector is batched. Default to True. + :param batched_dim: If the modules work with batches and process each batch element + independently, then many intermediary jacobians are sparse (block-diagonal), which allows + for a substancial memory optimization by backpropagating a squashed Jacobian instead. This + parameter indicates the batch dimension, if any. Defaults to None. .. admonition:: Example From e5e6e9d7eb9a8f9efd2ce627192231ab1d692771 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:03:06 +0200 Subject: [PATCH 008/114] Improve error message in _check_module_is_compatible --- src/torchjd/autogram/_engine.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 0951277e..db1700a2 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -159,8 +159,11 @@ def _check_module_is_compatible(self, module: nn.Module) -> None: if self._batched_dim is not None and isinstance(module, _INCOMPATIBLE_MODULE_TYPES): raise ValueError( f"Found a module of type {type(module)}, which is incompatible with the autogram " - f"engine. The incompatible module types are {_INCOMPATIBLE_MODULE_TYPES} (and their" - " subclasses)." + f"engine when `batched_dim` is not `None`. The incompatible module types are " + f"{_INCOMPATIBLE_MODULE_TYPES} (and their subclasses). The recommended fix is to " + f"replace incompatible layers by something else (e.g. BatchNorm by InstanceNorm), " + f"but if you really can't and performance not a priority, you may also just set" + f"`batch_dim=None` when creating the engine." ) if isinstance(module, _TRACK_RUNNING_STATS_MODULE_TYPES) and module.track_running_stats: From c6c6a3f06d39f107727fc44191d7c8820a574a43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:04:10 +0200 Subject: [PATCH 009/114] Remove parameter description of removed parameter grad_output in compute_gramian --- src/torchjd/autogram/_engine.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index db1700a2..3c23b2aa 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -180,8 +180,6 @@ def compute_gramian(self, output: Tensor) -> Tensor: ``modules``. :param output: The vector to differentiate. Must be a 1-D tensor. - :param grad_output: The tangents for the differentiation. Default to a vector of 1s of the - same shape as `output`. """ if self._batched_dim is not None: From af3373bbb0556ac47b0f27c8d957e1126ed6b586 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:05:24 +0200 Subject: [PATCH 010/114] Rename flat_gramian to square_gramian Not the final name I think, but at least it's consistent with the method name --- src/torchjd/autogram/_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 3c23b2aa..45bf4340 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -199,9 +199,9 @@ def compute_gramian(self, output: Tensor) -> Tensor: reshaped_output = ordered_output.reshape(target_shape) - flat_gramian = self._compute_square_gramian(reshaped_output, has_non_batched_dim) + square_gramian = self._compute_square_gramian(reshaped_output, has_non_batched_dim) - unordered_gramian = reshape_gramian(flat_gramian, ordered_shape) + unordered_gramian = reshape_gramian(square_gramian, ordered_shape) if self._batched_dim is not None: gramian = movedim_gramian(unordered_gramian, [-1], [self._batched_dim]) From aeb8f3b1154a723ea06f03bd9cbd29e2a8918f35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:07:21 +0200 Subject: [PATCH 011/114] Remove redundant cast to dict in AutogradVJP --- src/torchjd/autogram/_vjp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 68463e47..0f4bcd67 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -102,7 +102,7 @@ class AutogradVJP(VJP): def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): super().__init__(module) self.outputs = outputs - self.parameters, self.tree_spec = tree_flatten(dict(self.named_parameters)) + self.parameters, self.tree_spec = tree_flatten(self.named_parameters) def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: grads = torch.autograd.grad( From d5b868564d9390cc6e4727bd1044b6d34318a797 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:11:47 +0200 Subject: [PATCH 012/114] Rename info to _ in AccumulateJacobian.vmap --- src/torchjd/autogram/_module_hook_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 0b51ec1e..ad98ab44 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -102,7 +102,7 @@ def forward(*flat_grad_outputs: Tensor) -> None: ) @staticmethod - def vmap(info, in_dims, *flat_jac_outputs: Tensor) -> tuple[None, None]: + def vmap(_, in_dims, *flat_jac_outputs: Tensor) -> tuple[None, None]: # There is a non-batched dimension jac_outputs = tree_unflatten(flat_jac_outputs, tree_spec) # We do not vmap over the args for the non-batched dimension From b789e358be69dee7075f4251a0f34ef42bba7749 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:12:20 +0200 Subject: [PATCH 013/114] Type-hint in_dims as PyTree in AccumulateJacobian.vmap --- src/torchjd/autogram/_module_hook_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index ad98ab44..b12d087f 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -102,7 +102,7 @@ def forward(*flat_grad_outputs: Tensor) -> None: ) @staticmethod - def vmap(_, in_dims, *flat_jac_outputs: Tensor) -> tuple[None, None]: + def vmap(_, in_dims: PyTree, *flat_jac_outputs: Tensor) -> tuple[None, None]: # There is a non-batched dimension jac_outputs = tree_unflatten(flat_jac_outputs, tree_spec) # We do not vmap over the args for the non-batched dimension From e949dffed0d91efc8fafd02354026a1bc67f5788 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:15:12 +0200 Subject: [PATCH 014/114] Rename tree_spec to output_spec in ModuleHookManager --- src/torchjd/autogram/_module_hook_manager.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index b12d087f..04919306 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -50,7 +50,7 @@ def module_hook(_: nn.Module, args: PyTree, output: PyTree) -> PyTree: if self.gramian_accumulation_phase: return output - flat_outputs, tree_spec = tree_flatten(output) + flat_outputs, output_spec = tree_flatten(output) if not any(isinstance(t, Tensor) for t in flat_outputs): # This can happen only if a module returns no Tensor, for instance some niche usage @@ -68,7 +68,7 @@ def module_hook(_: nn.Module, args: PyTree, output: PyTree) -> PyTree: index = cast(int, preference.argmin().item()) self._target_edges.register(get_gradient_edge(flat_outputs[index])) - return self._apply_jacobian_accumulator(module, args, tree_spec, flat_outputs) + return self._apply_jacobian_accumulator(module, args, output_spec, flat_outputs) _ = module.register_forward_hook(module_hook) @@ -76,7 +76,7 @@ def _apply_jacobian_accumulator( self, module: nn.Module, args: PyTree, - tree_spec: TreeSpec, + output_spec: TreeSpec, flat_outputs: list[Tensor], ) -> PyTree: @@ -90,7 +90,7 @@ class AccumulateJacobian(torch.autograd.Function): @staticmethod def forward(*flat_grad_outputs: Tensor) -> None: # There is no non-batched dimension - grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec) + grad_outputs = tree_unflatten(flat_grad_outputs, output_spec) jacobians = vjp(grad_outputs, args) self._gramian_accumulator.accumulate_path_jacobians( { @@ -104,9 +104,9 @@ def forward(*flat_grad_outputs: Tensor) -> None: @staticmethod def vmap(_, in_dims: PyTree, *flat_jac_outputs: Tensor) -> tuple[None, None]: # There is a non-batched dimension - jac_outputs = tree_unflatten(flat_jac_outputs, tree_spec) + jac_outputs = tree_unflatten(flat_jac_outputs, output_spec) # We do not vmap over the args for the non-batched dimension - in_dims = (tree_unflatten(in_dims, tree_spec), tree_map(lambda _: None, args)) + in_dims = (tree_unflatten(in_dims, output_spec), tree_map(lambda _: None, args)) jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) self._gramian_accumulator.accumulate_path_jacobians( { @@ -150,4 +150,4 @@ def backward(ctx, *flat_grad_outputs: Tensor): return flat_grad_outputs - return tree_unflatten(JacobianAccumulator.apply(*flat_outputs), tree_spec) + return tree_unflatten(JacobianAccumulator.apply(*flat_outputs), output_spec) From 16f6aa9d3bbc5544bf1ad875c52f03c0d42fd92f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:15:26 +0200 Subject: [PATCH 015/114] Rename self.tree_spec to self.param_spec in AutogradVJP --- src/torchjd/autogram/_vjp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 0f4bcd67..25450a98 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -102,10 +102,10 @@ class AutogradVJP(VJP): def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): super().__init__(module) self.outputs = outputs - self.parameters, self.tree_spec = tree_flatten(self.named_parameters) + self.parameters, self.param_spec = tree_flatten(self.named_parameters) def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: grads = torch.autograd.grad( self.outputs, self.parameters, tree_flatten(grad_outputs)[0], retain_graph=True ) - return tree_unflatten(grads, self.tree_spec) + return tree_unflatten(grads, self.param_spec) From ab134a82e8ce04c28a2ed48d4df4cbfe3c3cc565 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:21:58 +0200 Subject: [PATCH 016/114] Add example in comment of reshape_gramian --- src/torchjd/autogram/_gramian_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index 19affdb2..ef3338a2 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -11,6 +11,10 @@ def reshape_gramian(gramian: Tensor, shape: list[int]) -> Tensor: `shape + shape[::-1]`. """ + # Example: `gramian` of shape [24, 24] and `shape` of [4, 3, 2]: + # - The `unordered_gramian` will be of shape [4, 3, 2, 4, 3, 2] + # - The returned gramian will be of shape [4, 3, 2, 2, 3, 4] + target_ndim = len(shape) unordered_shape = shape + shape unordered_gramian = gramian.reshape(unordered_shape) From 8d66d57cd0cf096af773755a6e7aec1d50ce4375 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 15:23:57 +0200 Subject: [PATCH 017/114] Improve variable names in compute_quadratic_form --- tests/unit/autogram/test_gramian_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py index 583f44b4..1bc1ed0e 100644 --- a/tests/unit/autogram/test_gramian_utils.py +++ b/tests/unit/autogram/test_gramian_utils.py @@ -9,14 +9,14 @@ from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian -def compute_quadratic_form(gramian: Tensor, vector: Tensor) -> Tensor: +def compute_quadratic_form(generalized_gramian: Tensor, x: Tensor) -> Tensor: """ - Compute the quadratic form x^T G x when the provided Gramian and vector may have multiple + Compute the quadratic form x^T G x when the provided generalized Gramian and x may have multiple dimensions. """ - indices = list(range(vector.ndim)) - linear_form = torch.tensordot(vector, gramian, dims=(indices, indices)) - return torch.tensordot(linear_form, vector, dims=(indices[::-1], indices)) + indices = list(range(x.ndim)) + linear_form = torch.tensordot(x, generalized_gramian, dims=(indices, indices)) + return torch.tensordot(linear_form, x, dims=(indices[::-1], indices)) @mark.parametrize( From 4bdc4008a55b57724096a3f33c50411d746997c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 4 Sep 2025 16:22:41 +0200 Subject: [PATCH 018/114] Revamp documentation of compute_gramian --- src/torchjd/autogram/_engine.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 45bf4340..06d9d249 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -175,11 +175,30 @@ def _check_module_is_compatible(self, module: nn.Module) -> None: ) def compute_gramian(self, output: Tensor) -> Tensor: - """ - Compute the Gramian of the Jacobian of ``output`` with respect the direct parameters of all - ``modules``. - - :param output: The vector to differentiate. Must be a 1-D tensor. + r""" + Computes the Gramian of the Jacobian of ``output`` with respect to the direct parameters of + all ``modules``. + + .. note:: + This function doesn't require ``output`` to be a vector. For example, if ``output`` is + a matrix of shape :math:`[m_1, m_2]`, its Jacobian :math:`J` with respect to the + parameters will be of shape :math:`[m_1, m_2, n]`, where :math:`n` is the number of + parameters in the model. This is what we call a `generalized Jacobian`. The + corresponding Gramian :math:`G = J J^\top` will be of shape :math:`[m_1, m_2, m_2, m_1]`. + This is what we call a `generalized Gramian`. The number of dimensions of the returned + generalized Gramian will always be twice that of the ``output``. + + A few examples: + - 0D (scalar) ``output``: 0D Gramian (this can be used to efficiently compute the + norm of the gradient of ``output``). + - 1D (vector) ``output``: 2D Gramian (this is the standard setting of Jacobian + descent). + - 2D (matrix) ``output``: 4D Gramian (this can happen when combining IWRM and + multi-task learning, as each sample in the batch has one loss per task). + - etc. + + :param output: The tensor of arbitrary shape to differentiate. The shape of the returned + Gramian depends on the shape of this output, as explained in the note above. """ if self._batched_dim is not None: From 33216c12fb4291cf98ff47131a8393215aa20c9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Fri, 5 Sep 2025 16:43:30 +0200 Subject: [PATCH 019/114] Update src/torchjd/autogram/_engine.py Co-authored-by: Pierre Quinton --- src/torchjd/autogram/_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 06d9d249..7578dfd8 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -190,7 +190,7 @@ def compute_gramian(self, output: Tensor) -> Tensor: A few examples: - 0D (scalar) ``output``: 0D Gramian (this can be used to efficiently compute the - norm of the gradient of ``output``). + squared norm of the gradient of ``output``). - 1D (vector) ``output``: 2D Gramian (this is the standard setting of Jacobian descent). - 2D (matrix) ``output``: 4D Gramian (this can happen when combining IWRM and From c090dfe2ff457d99a3390f93584fa5a2c6b163c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 16:40:53 +0200 Subject: [PATCH 020/114] Add ... indexing in jac_output for code clarity --- src/torchjd/autogram/_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 7578dfd8..8b1dfc33 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -250,7 +250,7 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: # Need to batch `grad_output` over the first dimension jac_output = torch.zeros(jac_output_shape, device=output.device, dtype=output.dtype) for i in range(non_batched_dim_len): - jac_output[i, i] = 1 + jac_output[i, i, ...] = 1 _ = vmap(differentiation)(jac_output) else: From 3466a3c93ba27ace5da3e23ed94be97081569f48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 16:47:19 +0200 Subject: [PATCH 021/114] Fix formatting of docstring --- src/torchjd/autogram/_engine.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 8b1dfc33..a3c4e90f 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -183,10 +183,11 @@ def compute_gramian(self, output: Tensor) -> Tensor: This function doesn't require ``output`` to be a vector. For example, if ``output`` is a matrix of shape :math:`[m_1, m_2]`, its Jacobian :math:`J` with respect to the parameters will be of shape :math:`[m_1, m_2, n]`, where :math:`n` is the number of - parameters in the model. This is what we call a `generalized Jacobian`. The - corresponding Gramian :math:`G = J J^\top` will be of shape :math:`[m_1, m_2, m_2, m_1]`. - This is what we call a `generalized Gramian`. The number of dimensions of the returned - generalized Gramian will always be twice that of the ``output``. + parameters in the model. This is what we call a generalized Jacobian. The + corresponding Gramian :math:`G = J J^\top` will be of shape + :math:`[m_1, m_2, m_2, m_1]`. This is what we call a `generalized Gramian`. The number + of dimensions of the returned generalized Gramian will always be twice that of the + ``output``. A few examples: - 0D (scalar) ``output``: 0D Gramian (this can be used to efficiently compute the From d2bab7a1723758466171e89df6b4b35803b746f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 16:54:30 +0200 Subject: [PATCH 022/114] Add more parametrizations to test_reshape_equivariance --- tests/unit/autogram/test_engine.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index b83b87c9..4f2c1aa9 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -430,6 +430,14 @@ def test_gramian_is_correct(shape: tuple[int, int], batch_size: int, reduce_outp [1, 2, 2, 3], [7, 3, 2, 5], [27, 6, 7], + [3, 2, 1, 1], + [3, 2, 1], + [3, 2], + [3], + [1, 1, 1, 1], + [1, 1, 1], + [1, 1], + [1], ], ) def test_reshape_equivariance(shape: list[int]): From 79e0609e5ff426c6df1bed6832ee0a7d5f472b19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 16:59:57 +0200 Subject: [PATCH 023/114] Improve parametrization of test_movedim_equivariance --- tests/unit/autogram/test_engine.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 4f2c1aa9..a3bc8748 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -473,6 +473,11 @@ def test_reshape_equivariance(shape: list[int]): ([50, 2, 2, 3], [0, 2], [1, 0]), ([60, 3, 2, 5], [1], [2]), ([30, 6, 7], [0, 1], [1, 0]), + ([3, 2], [0], [0]), + ([3], [], []), + ([3, 2, 1], [1, 0], [0, 1]), + ([4, 3, 2], [], []), + ([1, 1, 1], [1, 0], [0, 1]), ], ) def test_movedim_equivariance(shape: list[int], source: list[int], destination: list[int]): From c40f44c90720af2b67d7831e3f5ee2eeaaeec8b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 17:03:43 +0200 Subject: [PATCH 024/114] Improve parametrization of test_batched_non_batched_equivalence --- tests/unit/autogram/test_engine.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index a3bc8748..abf0f159 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -514,6 +514,12 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: ([3, 2, 5], 1), ([6, 3], 0), ([4, 3, 2], 1), + ([1, 1, 1], 0), + ([1, 1, 1], 1), + ([1, 1, 1], 2), + ([1, 1], 0), + ([1], 0), + ([4, 3, 1], 2), ], ) def test_batched_non_batched_equivalence(shape: list[int], batched_dim: int): From 2fe3b51ae70aabc8da6374f7b3c483c3a4910071 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 17:05:55 +0200 Subject: [PATCH 025/114] Add comment in compute_gramian --- src/torchjd/autogram/_engine.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index a3c4e90f..f289c6b6 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -218,6 +218,11 @@ def compute_gramian(self, output: Tensor) -> Tensor: target_shape = [-1] + target_shape reshaped_output = ordered_output.reshape(target_shape) + # There are four different cases for the shape of reshaped_output: + # - Not batched and not non-batched: scalar of shape [] + # - Batched only: vector of shape [batch_size] + # - Non-batched only: vector of shape [dim] + # - Batched and non-batched: matrix of shape [dim, batch_size] square_gramian = self._compute_square_gramian(reshaped_output, has_non_batched_dim) From d82a8f622d98c87ff3c868d92e985ffc6c716b25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 17:12:15 +0200 Subject: [PATCH 026/114] Improve clarity of reshape_gramian --- src/torchjd/autogram/_gramian_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index ef3338a2..355199c2 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -13,13 +13,14 @@ def reshape_gramian(gramian: Tensor, shape: list[int]) -> Tensor: # Example: `gramian` of shape [24, 24] and `shape` of [4, 3, 2]: # - The `unordered_gramian` will be of shape [4, 3, 2, 4, 3, 2] - # - The returned gramian will be of shape [4, 3, 2, 2, 3, 4] + # - The `last_dims` will be [3, 4, 5] and `last_dims[::-1]` will be [5, 4, 3] + # - The `reordered_gramian` will be of shape [4, 3, 2, 2, 3, 4] target_ndim = len(shape) - unordered_shape = shape + shape - unordered_gramian = gramian.reshape(unordered_shape) + unordered_gramian = gramian.reshape(shape + shape) last_dims = [target_ndim + i for i in range(target_ndim)] - return unordered_gramian.movedim(last_dims, last_dims[::-1]) + reordered_gramian = unordered_gramian.movedim(last_dims, last_dims[::-1]) + return reordered_gramian def movedim_gramian(gramian: Tensor, source: list[int], destination: list[int]) -> Tensor: From c846ed73382e2617713595c6f19211b05804765f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 17:30:50 +0200 Subject: [PATCH 027/114] Improve clarity of movedim_gramian --- src/torchjd/autogram/_gramian_utils.py | 30 +++++++++++++++++--------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index 355199c2..f82b4ece 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -29,16 +29,26 @@ def movedim_gramian(gramian: Tensor, source: list[int], destination: list[int]) Gramian is quadratic form, moving dimension must be done simultaneously on the first half of the dimensions and on the second half of the dimensions reversed. :param gramian: Gramian to reshape. - :param source: Source dimensions, should be in the range [0, gramian.ndim/2]. Should be unique - :param destination: Destination dimensions, should be in the range [0, gramian.ndim/2]. Should - be unique and should have the same size as `source`. + :param source: Source dimensions, that should be in the range + [-gramian.ndim//2, gramian.ndim//2[. Its elements should be unique. + :param destination: Destination dimensions, that should be in the range + [-gramian.ndim//2, gramian.ndim//2[. It should have the same size as `source`, and its + elements should be unique. """ - length = gramian.ndim // 2 - source = [i if 0 <= i else i + length for i in source] - destination = [i if 0 <= i else i + length for i in destination] + # Example: `gramian` of shape [4, 3, 2, 2, 3, 4], `source` of [-2, 2] and destination of [0, 1]: + # - `source_` will be [1, 2] and `destination_` will be [0, 1] + # - `mirrored_source` will be [1, 2, 4, 3] and `mirrored_destination` will be [0, 1, 5, 4] + # - The `moved_gramian` will be of shape [3, 2, 4, 4, 2, 3] - last_index = gramian.ndim - 1 - source_dims = source + [last_index - i for i in source] - destination_dims = destination + [last_index - i for i in destination] - return gramian.movedim(source_dims, destination_dims) + # Map everything to the range [0, gramian.ndim//2[ + length = gramian.ndim // 2 + source_ = [i if 0 <= i else i + length for i in source] + destination_ = [i if 0 <= i else i + length for i in destination] + + # Mirror the source and destination and use the result to move the dimensions of the gramian + last_dim = gramian.ndim - 1 + mirrored_source = source_ + [last_dim - i for i in source_] + mirrored_destination = destination_ + [last_dim - i for i in destination_] + moved_gramian = gramian.movedim(mirrored_source, mirrored_destination) + return moved_gramian From c2fd0a05cabe6eb55d9e7f3317b8bac8aef4452d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 18:03:36 +0200 Subject: [PATCH 028/114] Revert removal of _handles in ModuleHookManager --- src/torchjd/autogram/_module_hook_manager.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 04919306..9d3ab66a 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -4,6 +4,7 @@ from torch import Tensor, nn from torch.autograd.graph import get_gradient_edge from torch.utils._pytree import PyTree, TreeSpec, tree_flatten, tree_map, tree_unflatten +from torch.utils.hooks import RemovableHandle as TorchRemovableHandle from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator @@ -37,6 +38,7 @@ def __init__( self._gramian_accumulator = gramian_accumulator self._has_batch_dim = has_batch_dim self.gramian_accumulation_phase = False + self._handles: list[TorchRemovableHandle] = [] def hook_module(self, module: nn.Module) -> None: """ @@ -70,7 +72,8 @@ def module_hook(_: nn.Module, args: PyTree, output: PyTree) -> PyTree: return self._apply_jacobian_accumulator(module, args, output_spec, flat_outputs) - _ = module.register_forward_hook(module_hook) + handle = module.register_forward_hook(module_hook) + self._handles.append(handle) def _apply_jacobian_accumulator( self, From 1cbff18cf7507566ce915af8105e31309c2388c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 18:54:44 +0200 Subject: [PATCH 029/114] Add more edge cases to test_quadratic_form_invariance_to_reshape --- tests/unit/autogram/test_gramian_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py index 1bc1ed0e..eefd076a 100644 --- a/tests/unit/autogram/test_gramian_utils.py +++ b/tests/unit/autogram/test_gramian_utils.py @@ -25,6 +25,13 @@ def compute_quadratic_form(generalized_gramian: Tensor, x: Tensor) -> Tensor: [50, 2, 2, 3], [60, 3, 2, 5], [30, 6, 7], + [4, 3, 1], + [4, 1, 1], + [1, 1, 1], + [4, 1], + [4], + [1, 1], + [1], ], ) def test_quadratic_form_invariance_to_reshape(shape: list[int]): From 693856019cec3df0007409663256692cd6d24424 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 18:56:44 +0200 Subject: [PATCH 030/114] Add more edge cases to test_quadratic_form_invariance_to_movedim --- tests/unit/autogram/test_gramian_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py index eefd076a..3ea8aa3d 100644 --- a/tests/unit/autogram/test_gramian_utils.py +++ b/tests/unit/autogram/test_gramian_utils.py @@ -64,6 +64,13 @@ def test_quadratic_form_invariance_to_reshape(shape: list[int]): ([50, 2, 2, 3], [0, 2], [1, 0]), ([60, 3, 2, 5], [1], [2]), ([30, 6, 7], [0, 1], [1, 0]), + ([4, 3, 1], [0, 1], [1, 0]), + ([4, 1, 1], [1], [0]), + ([1, 1, 1], [], []), + ([4, 1], [0], [0]), + ([4], [], []), + ([1, 1], [], []), + ([1], [], []), ], ) def test_quadratic_form_invariance_to_movedim( From 1b565580751f8b566eba7060f1da6a81dbd2c18e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 19:02:45 +0200 Subject: [PATCH 031/114] Factorize code into _make_path_jacobians and use for-loop --- src/torchjd/autogram/_module_hook_manager.py | 29 +++++++++----------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 9d3ab66a..036c1bc2 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -95,14 +95,8 @@ def forward(*flat_grad_outputs: Tensor) -> None: # There is no non-batched dimension grad_outputs = tree_unflatten(flat_grad_outputs, output_spec) jacobians = vjp(grad_outputs, args) - self._gramian_accumulator.accumulate_path_jacobians( - { - module.get_parameter(param_name): jacobian.reshape( - [-1] + list(module.get_parameter(param_name).shape) - ) - for param_name, jacobian in jacobians.items() - } - ) + path_jacobians = AccumulateJacobian._make_path_jacobians(jacobians) + self._gramian_accumulator.accumulate_path_jacobians(path_jacobians) @staticmethod def vmap(_, in_dims: PyTree, *flat_jac_outputs: Tensor) -> tuple[None, None]: @@ -111,16 +105,19 @@ def vmap(_, in_dims: PyTree, *flat_jac_outputs: Tensor) -> tuple[None, None]: # We do not vmap over the args for the non-batched dimension in_dims = (tree_unflatten(in_dims, output_spec), tree_map(lambda _: None, args)) jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) - self._gramian_accumulator.accumulate_path_jacobians( - { - module.get_parameter(param_name): jacobian.reshape( - [-1] + list(module.get_parameter(param_name).shape) - ) - for param_name, jacobian in jacobians.items() - } - ) + path_jacobians = AccumulateJacobian._make_path_jacobians(jacobians) + self._gramian_accumulator.accumulate_path_jacobians(path_jacobians) return None, None + @staticmethod + def _make_path_jacobians(jacobians: dict[str, Tensor]) -> dict[Tensor, Tensor]: + path_jacobians: dict[Tensor, Tensor] = {} + for param_name, jacobian in jacobians.items(): + key = module.get_parameter(param_name) + value = jacobian.reshape([-1] + list(key.shape)) + path_jacobians[key] = value + return path_jacobians + @staticmethod def setup_context(*_): pass From 1d78e15f590e033efb117d95f06e565fe742e088 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 5 Sep 2025 19:19:24 +0200 Subject: [PATCH 032/114] Fix model not being moved to cuda in new tests --- tests/unit/autogram/test_engine.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index abf0f159..e4bcf24f 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -355,7 +355,7 @@ def test_gramian_is_correct(shape: tuple[int, int], batch_size: int, reduce_outp batched_dim = None input_dim = [shape[0]] - model = Linear(shape[0], shape[1]) + model = Linear(shape[0], shape[1]).to(device=DEVICE) engine = Engine([model], batched_dim=batched_dim) input = randn_(input_dim) @@ -450,7 +450,7 @@ def test_reshape_equivariance(shape: list[int]): input_size = shape[0] output_size = prod(shape[1:]) - model = Linear(input_size, output_size) + model = Linear(input_size, output_size).to(device=DEVICE) engine1 = Engine([model]) engine2 = Engine([model]) @@ -490,7 +490,7 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: input_size = shape[0] output_size = prod(shape[1:]) - model = Linear(input_size, output_size) + model = Linear(input_size, output_size).to(device=DEVICE) engine1 = Engine([model]) engine2 = Engine([model]) @@ -533,7 +533,7 @@ def test_batched_non_batched_equivalence(shape: list[int], batched_dim: int): batch_size = shape[batched_dim] output_size = input_size - model = Linear(input_size, output_size) + model = Linear(input_size, output_size).to(device=DEVICE) engine1 = Engine([model], batched_dim=batched_dim) engine2 = Engine([model]) From 91d739b31e8350cb7bc65ca76213a272cafe6484 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 6 Sep 2025 15:49:32 +0200 Subject: [PATCH 033/114] Make test_equivalence_autojac_autogram also work with non-batched engine At this point, 3 architectures fail: SomeFrozenParam, SomeUnusedParam and MultiOutputWithFrozenBranch --- tests/unit/autogram/test_engine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index e4bcf24f..aca8e06e 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -154,11 +154,13 @@ @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) @mark.parametrize(["aggregator", "weighting"], AGGREGATORS_AND_WEIGHTINGS) +@mark.parametrize("batched_engine", [False, True]) def test_equivalence_autojac_autogram( architecture: type[ShapedModule], batch_size: int, aggregator: Aggregator, weighting: Weighting, + batched_engine: bool, ): """ Tests that the autogram engine gives the same results as the autojac engine on IWRM for several @@ -175,7 +177,7 @@ def test_equivalence_autojac_autogram( torch.manual_seed(0) model_autogram = architecture().to(device=DEVICE) - engine = Engine(model_autogram.modules(), batched_dim=0) + engine = Engine(model_autogram.modules(), batched_dim=0 if batched_engine else None) optimizer_autojac = SGD(model_autojac.parameters(), lr=1e-7) optimizer_autogram = SGD(model_autogram.parameters(), lr=1e-7) From 92b7b1aee2ce6c734d3008b7025126f68c2a4506 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 6 Sep 2025 15:25:08 +0200 Subject: [PATCH 034/114] Make separation between trainable and frozen params in VJP and rename variables This fixes non-batched engien on SomeFrozenParams architecture --- src/torchjd/autogram/_vjp.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 25450a98..86bbe299 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -24,7 +24,9 @@ class VJP(ABC): def __init__(self, module: nn.Module): self.module = module - self.named_parameters = dict(module.named_parameters(recurse=False)) + named_parameters = dict(module.named_parameters(recurse=False)) + self.trainable_params = {k: v for k, v in named_parameters.items() if v.requires_grad} + self.frozen_params = {k: v for k, v in named_parameters.items() if not v.requires_grad} @abstractmethod def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]: @@ -73,22 +75,16 @@ def _vjp_from_module(self, inputs: PyTree) -> Callable[[PyTree], tuple[dict[str, :param inputs: Fixed inputs to the module for the VJP computation. :returns: VJP function that takes cotangents and returns parameter gradients. """ - requires_grad_named_params = { - k: v for k, v in self.named_parameters.items() if v.requires_grad - } - no_requires_grad_named_params = { - k: v for k, v in self.named_parameters.items() if not v.requires_grad - } def functional_model_call(primals: dict[str, Parameter]) -> Tensor: all_state = { **primals, **dict(self.module.named_buffers()), - **no_requires_grad_named_params, + **self.frozen_params, } return torch.func.functional_call(self.module, all_state, inputs) - return torch.func.vjp(functional_model_call, requires_grad_named_params)[1] + return torch.func.vjp(functional_model_call, self.trainable_params)[1] class AutogradVJP(VJP): @@ -102,10 +98,10 @@ class AutogradVJP(VJP): def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): super().__init__(module) self.outputs = outputs - self.parameters, self.param_spec = tree_flatten(self.named_parameters) + self.trainable_params, self.param_spec = tree_flatten(self.trainable_params) def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: grads = torch.autograd.grad( - self.outputs, self.parameters, tree_flatten(grad_outputs)[0], retain_graph=True + self.outputs, self.trainable_params, tree_flatten(grad_outputs)[0], retain_graph=True ) return tree_unflatten(grads, self.param_spec) From 9b92cf7444da7d117bf24dd91af0046f39850a7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 6 Sep 2025 15:30:41 +0200 Subject: [PATCH 035/114] Add allow_unused=True and materialize_grads=True in call to autograd.grad in AutogradVJP This fixes non-batched engine on SomeUnusedParam --- src/torchjd/autogram/_vjp.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 86bbe299..760682cf 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -102,6 +102,11 @@ def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: grads = torch.autograd.grad( - self.outputs, self.trainable_params, tree_flatten(grad_outputs)[0], retain_graph=True + self.outputs, + self.trainable_params, + tree_flatten(grad_outputs)[0], + retain_graph=True, + allow_unused=True, + materialize_grads=True, ) return tree_unflatten(grads, self.param_spec) From 7950bb45c2baed239783dcb67e96207728b9a545 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 6 Sep 2025 15:49:14 +0200 Subject: [PATCH 036/114] Stop trying to differentiate outputs that dont require grad in AutogradVJP This fixes non-batched engine on MultiOutputWithFrozenBranch --- src/torchjd/autogram/_vjp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 760682cf..15027502 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -98,13 +98,15 @@ class AutogradVJP(VJP): def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): super().__init__(module) self.outputs = outputs + self.mask = [output.requires_grad for output in self.outputs] self.trainable_params, self.param_spec = tree_flatten(self.trainable_params) def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: + flat_grad_outputs = tree_flatten(grad_outputs)[0] grads = torch.autograd.grad( - self.outputs, + [t for t, requires_grad in zip(self.outputs, self.mask) if requires_grad], self.trainable_params, - tree_flatten(grad_outputs)[0], + [t for t, requires_grad in zip(flat_grad_outputs, self.mask) if requires_grad], retain_graph=True, allow_unused=True, materialize_grads=True, From 4dc3f7c7b0130f8f481337cab0b2022e6a4b434f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 6 Sep 2025 15:55:24 +0200 Subject: [PATCH 037/114] Fix variable name --- src/torchjd/autogram/_vjp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 15027502..4a93f69a 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -99,13 +99,13 @@ def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): super().__init__(module) self.outputs = outputs self.mask = [output.requires_grad for output in self.outputs] - self.trainable_params, self.param_spec = tree_flatten(self.trainable_params) + self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params) def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: flat_grad_outputs = tree_flatten(grad_outputs)[0] grads = torch.autograd.grad( [t for t, requires_grad in zip(self.outputs, self.mask) if requires_grad], - self.trainable_params, + self.flat_trainable_params, [t for t, requires_grad in zip(flat_grad_outputs, self.mask) if requires_grad], retain_graph=True, allow_unused=True, From 23cb0a89f7ed7ee563c956228639caab4dc77c13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 6 Sep 2025 16:04:09 +0200 Subject: [PATCH 038/114] Rename jacobians to generalized_jacobians Maybe not a definitive name, but I think it's more clear --- src/torchjd/autogram/_module_hook_manager.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 036c1bc2..7660bc04 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -94,8 +94,8 @@ class AccumulateJacobian(torch.autograd.Function): def forward(*flat_grad_outputs: Tensor) -> None: # There is no non-batched dimension grad_outputs = tree_unflatten(flat_grad_outputs, output_spec) - jacobians = vjp(grad_outputs, args) - path_jacobians = AccumulateJacobian._make_path_jacobians(jacobians) + generalized_jacobians = vjp(grad_outputs, args) + path_jacobians = AccumulateJacobian._make_path_jacobians(generalized_jacobians) self._gramian_accumulator.accumulate_path_jacobians(path_jacobians) @staticmethod @@ -104,18 +104,20 @@ def vmap(_, in_dims: PyTree, *flat_jac_outputs: Tensor) -> tuple[None, None]: jac_outputs = tree_unflatten(flat_jac_outputs, output_spec) # We do not vmap over the args for the non-batched dimension in_dims = (tree_unflatten(in_dims, output_spec), tree_map(lambda _: None, args)) - jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) - path_jacobians = AccumulateJacobian._make_path_jacobians(jacobians) + generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) + path_jacobians = AccumulateJacobian._make_path_jacobians(generalized_jacobians) self._gramian_accumulator.accumulate_path_jacobians(path_jacobians) return None, None @staticmethod - def _make_path_jacobians(jacobians: dict[str, Tensor]) -> dict[Tensor, Tensor]: + def _make_path_jacobians( + generalized_jacobians: dict[str, Tensor], + ) -> dict[Tensor, Tensor]: path_jacobians: dict[Tensor, Tensor] = {} - for param_name, jacobian in jacobians.items(): + for param_name, generalized_jacobian in generalized_jacobians.items(): key = module.get_parameter(param_name) - value = jacobian.reshape([-1] + list(key.shape)) - path_jacobians[key] = value + jacobian = generalized_jacobian.reshape([-1] + list(key.shape)) + path_jacobians[key] = jacobian return path_jacobians @staticmethod From 1169b854c4e44dc13717974ed4944c9e9f097667 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 6 Sep 2025 17:08:20 +0200 Subject: [PATCH 039/114] Replace torch.movedim by tensor.movedim --- src/torchjd/autogram/_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index f289c6b6..dc67f38d 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -204,7 +204,7 @@ def compute_gramian(self, output: Tensor) -> Tensor: if self._batched_dim is not None: # move batched dim to the end - ordered_output = torch.movedim(output, self._batched_dim, -1) + ordered_output = output.movedim(self._batched_dim, -1) ordered_shape = list(ordered_output.shape) has_non_batched_dim = len(ordered_shape) > 1 target_shape = [ordered_shape[-1]] From 59c957c4ba0455ef8d5e3dc7873f525b5cd11e1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 6 Sep 2025 17:17:32 +0200 Subject: [PATCH 040/114] Add batch_size variable in compute_gramian * Small improvement of clarity --- src/torchjd/autogram/_engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index dc67f38d..9254aa86 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -206,8 +206,9 @@ def compute_gramian(self, output: Tensor) -> Tensor: # move batched dim to the end ordered_output = output.movedim(self._batched_dim, -1) ordered_shape = list(ordered_output.shape) + batch_size = ordered_shape[-1] has_non_batched_dim = len(ordered_shape) > 1 - target_shape = [ordered_shape[-1]] + target_shape = [batch_size] else: ordered_output = output ordered_shape = list(ordered_output.shape) From 9a01a12bd1ff1ce66fa50649d6b2d89acdea73d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 9 Sep 2025 16:08:24 +0200 Subject: [PATCH 041/114] Add GeneralizedWeighting --- docs/source/docs/aggregation/index.rst | 5 +++++ src/torchjd/aggregation/__init__.py | 2 +- src/torchjd/aggregation/_weighting_bases.py | 25 +++++++++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index e5cfef44..87c7d75c 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -17,6 +17,11 @@ Abstract base classes :undoc-members: :exclude-members: forward +.. autoclass:: torchjd.aggregation.GeneralizedWeighting + :members: + :undoc-members: + :exclude-members: forward + .. toctree:: :hidden: diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 8d1a9432..fdd4c2a1 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -55,7 +55,7 @@ from ._utils.check_dependencies import ( OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError, ) -from ._weighting_bases import Weighting +from ._weighting_bases import GeneralizedWeighting, Weighting try: from ._cagrad import CAGrad, CAGradWeighting diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index daa1bdca..154c7a30 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -52,3 +52,28 @@ def __init__(self, weighting: Weighting[_FnOutputT], fn: Callable[[_T], _FnOutpu def forward(self, stat: _T) -> Tensor: return self.weighting(self.fn(stat)) + + +class GeneralizedWeighting(nn.Module, ABC): + r""" + Abstract base class for all weightings that operate on generalized Gramians. It has the role of + extracting a tensor of weights of dimension :math:`m_1 \times \dots \times m_k` from a + generalized Gramian of dimension + :math:`m_1 \times \dots \times m_k \times m_k \times \dots \times m_1`. + """ + + def __init__(self): + super().__init__() + + @abstractmethod + def forward(self, generalized_gramian: Tensor) -> Tensor: + """Computes the vector of weights from the input generalized Gramian.""" + + # Override to make type hints and documentation more specific + def __call__(self, generalized_gramian: Tensor) -> Tensor: + """ + Computes the tensor of weights from the input generalized Gramian and applies all registered + hooks. + """ + + return super().__call__(generalized_gramian) From 31812a11d103fd88b4e1fe71d77fc7f937dfa0ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 9 Sep 2025 16:23:32 +0200 Subject: [PATCH 042/114] Add FakeGeneralizedWeighting in tests --- tests/unit/autogram/test_engine.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index fe7d1742..83093a22 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -4,7 +4,7 @@ import pytest import torch from pytest import mark, param -from torch import nn +from torch import Tensor, nn from torch.nn import Linear from torch.optim import SGD from torch.testing import assert_close @@ -60,7 +60,7 @@ ) from utils.tensors import make_tensors, ones_, randn_, zeros_ -from torchjd.aggregation import UPGrad, UPGradWeighting +from torchjd.aggregation import GeneralizedWeighting, UPGrad, UPGradWeighting from torchjd.autogram._engine import Engine from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian from torchjd.autojac._transform import Diagonalize, Init, Jac, OrderedSet @@ -507,3 +507,22 @@ def test_batched_non_batched_equivalence(shape: list[int], batched_dim: int): gramian2 = engine2.compute_gramian(output) assert_close(gramian1, gramian2) + + +class FakeGeneralizedWeighting(GeneralizedWeighting): + """ + Fake GeneralizedWeighting flattening the Gramian and using UPGradWeighting on it. Could be + removed when we implement a proper FlatteningGeneralizedWeighting.""" + + def __init__(self): + super().__init__() + self.weighting = UPGradWeighting() + + def forward(self, generalized_gramian: Tensor) -> Tensor: + k = generalized_gramian.ndim // 2 + shape = generalized_gramian.shape[:k] + m = prod(shape) + square_gramian = reshape_gramian(generalized_gramian, [m]) + weights_vector = self.weighting(square_gramian) + weights = weights_vector.reshape(shape) + return weights From ee16f9b3db3494d213bc4d39323f3bbcf8aff64a Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 9 Sep 2025 17:47:03 +0200 Subject: [PATCH 043/114] reshape_gramian can now take generalized gramians as inputs, its shape can also contain (at most) one element set to -1, the size of that dimension is deduced from the total number of elements --- src/torchjd/autogram/_gramian_utils.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index f82b4ece..4379b7e0 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -1,3 +1,5 @@ +from math import prod + from torch import Tensor @@ -16,11 +18,18 @@ def reshape_gramian(gramian: Tensor, shape: list[int]) -> Tensor: # - The `last_dims` will be [3, 4, 5] and `last_dims[::-1]` will be [5, 4, 3] # - The `reordered_gramian` will be of shape [4, 3, 2, 2, 3, 4] - target_ndim = len(shape) - unordered_gramian = gramian.reshape(shape + shape) - last_dims = [target_ndim + i for i in range(target_ndim)] - reordered_gramian = unordered_gramian.movedim(last_dims, last_dims[::-1]) - return reordered_gramian + automatic_dimensions = [i for i in range(len(shape)) if shape[i] == -1] + if len(automatic_dimensions) == 1: + index = automatic_dimensions[0] + current_shape = gramian.shape[: len(gramian.shape) // 2] + numel = prod(current_shape) + specified_numel = -prod(shape) # shape[index] == -1, this is the product of all other dims + shape[index] = numel // specified_numel + + unordered_intput_gramian = _revert_last_dims(gramian) + unordered_output_gramian = unordered_intput_gramian.reshape(shape + shape) + reordered_output_gramian = _revert_last_dims(unordered_output_gramian) + return reordered_output_gramian def movedim_gramian(gramian: Tensor, source: list[int], destination: list[int]) -> Tensor: @@ -52,3 +61,9 @@ def movedim_gramian(gramian: Tensor, source: list[int], destination: list[int]) mirrored_destination = destination_ + [last_dim - i for i in destination_] moved_gramian = gramian.movedim(mirrored_source, mirrored_destination) return moved_gramian + + +def _revert_last_dims(generalized_gramian: Tensor) -> Tensor: + input_ndim = len(generalized_gramian.shape) // 2 + last_dims = [input_ndim + i for i in range(input_ndim)] + return generalized_gramian.movedim(last_dims, last_dims[::-1]) From 0a2555581aa010928001a62fe378e06ee895ee21 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 10 Sep 2025 11:53:41 +0200 Subject: [PATCH 044/114] Implement `HierarchicalWeighting` Weighting, needs testing. --- .../aggregation/hierachical_weighting.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 src/torchjd/aggregation/hierachical_weighting.py diff --git a/src/torchjd/aggregation/hierachical_weighting.py b/src/torchjd/aggregation/hierachical_weighting.py new file mode 100644 index 00000000..088d5dde --- /dev/null +++ b/src/torchjd/aggregation/hierachical_weighting.py @@ -0,0 +1,105 @@ +import torch +from torch import Tensor + +from ..autogram._gramian_utils import reshape_gramian +from ._weighting_bases import GeneralizedWeighting, Weighting + + +class HierarchicalWeighting(GeneralizedWeighting): + """ + Hierarchically reduces a generalized Gramian using a sequence of weighting functions. + + Applies multiple weightings in sequence to a generalized Gramian ``G`` of shape + ``[n₁, ..., nₖ, nₖ, ..., n₁]``. It first applies the initial weighting to the innermost diagonal + Gramians, contracts those dimensions to form a smaller generalized Gramian, and repeats the + process with subsequent weightings. The final returned weights are chosen so that contracting + the original Gramian directly with these weights produces the same quadratic form as applying + the reductions step by step. + + :param weightings: A list of weighting callables, one for each hierarchical reduction step. + """ + + def __init__(self, weightings: list[Weighting]): + super().__init__() + self.weightings = weightings + self.n_dims = len(weightings) + + def forward(self, generalized_gramian: Tensor) -> Tensor: + + assert len(self.weightings) * 2 == len(generalized_gramian.shape) # temporary + + weighting = self.weightings[0] + dim_size = generalized_gramian.shape[0] + reshaped_gramian = reshape_gramian(generalized_gramian, [-1, dim_size]) + weights = _compute_weights(weighting, reshaped_gramian) + generalized_gramian = _contract_gramian(reshaped_gramian, weights) + + for i in range(self.n_dim): + weighting = self.weightings[i] + dim_size = generalized_gramian.shape[i] + reshaped_gramian = reshape_gramian(generalized_gramian, [-1, dim_size]) + temp_weights = _compute_weights(weighting, reshaped_gramian) + generalized_gramian = _contract_gramian(reshaped_gramian, temp_weights) + weights = _scale_weights(weights, temp_weights) + + return weights + + +def _compute_weights(weighting: Weighting, generalized_gramian: Tensor) -> Tensor: + """ + Apply a weighting to each diagonal Gramian in a generalized Gramian. + + For a generalized Gramian ``G`` of shape ``[m, n, n, m]``, this extracts each diagonal Gramian + ``G[j, :, :, j]`` of shape ``[n, n]`` for ``j`` in ``[m]`` and applies the provided weighting. + The resulting weights are stacked into a tensor of shape ``[m, n]``. + + :param weighting: Callable that maps a Gramian of shape ``[n, n]`` to weights of shape ``[n]``. + :param generalized_gramian: Tensor of shape ``[m, n, n, m]`` containing the generalized Gramian. + :returns: Tensor of shape ``[m, n]`` containing the computed weights for each diagonal Gramian. + """ + + weights = torch.zeros( + generalized_gramian[:2], device=generalized_gramian.device, dtype=generalized_gramian.dtype + ) + for i in range(generalized_gramian.shape[0]): + weights[i] = weighting(generalized_gramian[i, :, :, i]) + return weights + + +def _contract_gramian(generalized_gramian: Tensor, weights: Tensor) -> Tensor: + r""" + Compute a partial quadratic form by contracting a generalized Gramian with weight vectors on + both sides. + + Given a generalized Gramian ``G`` of shape ``[m, n, n, m]`` and weights ``w`` of shape + ``[m, n]``, this function computes a Gramian ``G'`` of shape ``[m, m]`` where + + .. math:: + + G'[i, j] = \sum_{k, l=1}^n w[i, k] G[i, k, l, j] w[j, l]. + + This can be viewed as forming a quadratic form with respect to the two innermost dimensions of + ``G``. + + :param generalized_gramian: Tensor of shape ``[m, n, n, m]`` representing the generalized + Gramian. + :param weights: Tensor of shape ``[m, n]`` containing weight vectors to contract with the + Gramian. + :returns: Tensor of shape ``[m, m]`` containing the contracted Gramian, i.e. the partial + quadratic form. + """ + left_product = torch.einsum("ij,ijkl->ikl", weights, generalized_gramian) + return torch.einsum("ij,ijl->il", weights, left_product) + + +def _scale_weights(weights: Tensor, scalings: Tensor) -> Tensor: + """ + Scale a tensor along its leading dimensions by broadcasting scaling factors. + + :param weights: Tensor of shape [n₁, ..., nₖ, nₖ₊₁, ..., nₚ]. + :param scalings: Tensor of shape [n₁, ..., nₖ] providing scaling factors for the leading + dimensions of ``weights``. + :returns: Tensor of the same shape as ``weights``, where each slice + ``weights[i₁, ..., iₖ, :, ..., :]`` is multiplied by ``scalings[i₁, ..., iₖ]``. + """ + return weights * scalings[(...,) + (None,) * (weights.ndim - scalings.ndim)] From 16ee3e31dfa9af4bde68c2e8367c7699fdd8b455 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 10 Sep 2025 16:52:29 +0200 Subject: [PATCH 045/114] Revert "Implement `HierarchicalWeighting` Weighting, needs testing." This reverts commit 0a2555581aa010928001a62fe378e06ee895ee21. --- .../aggregation/hierachical_weighting.py | 105 ------------------ 1 file changed, 105 deletions(-) delete mode 100644 src/torchjd/aggregation/hierachical_weighting.py diff --git a/src/torchjd/aggregation/hierachical_weighting.py b/src/torchjd/aggregation/hierachical_weighting.py deleted file mode 100644 index 088d5dde..00000000 --- a/src/torchjd/aggregation/hierachical_weighting.py +++ /dev/null @@ -1,105 +0,0 @@ -import torch -from torch import Tensor - -from ..autogram._gramian_utils import reshape_gramian -from ._weighting_bases import GeneralizedWeighting, Weighting - - -class HierarchicalWeighting(GeneralizedWeighting): - """ - Hierarchically reduces a generalized Gramian using a sequence of weighting functions. - - Applies multiple weightings in sequence to a generalized Gramian ``G`` of shape - ``[n₁, ..., nₖ, nₖ, ..., n₁]``. It first applies the initial weighting to the innermost diagonal - Gramians, contracts those dimensions to form a smaller generalized Gramian, and repeats the - process with subsequent weightings. The final returned weights are chosen so that contracting - the original Gramian directly with these weights produces the same quadratic form as applying - the reductions step by step. - - :param weightings: A list of weighting callables, one for each hierarchical reduction step. - """ - - def __init__(self, weightings: list[Weighting]): - super().__init__() - self.weightings = weightings - self.n_dims = len(weightings) - - def forward(self, generalized_gramian: Tensor) -> Tensor: - - assert len(self.weightings) * 2 == len(generalized_gramian.shape) # temporary - - weighting = self.weightings[0] - dim_size = generalized_gramian.shape[0] - reshaped_gramian = reshape_gramian(generalized_gramian, [-1, dim_size]) - weights = _compute_weights(weighting, reshaped_gramian) - generalized_gramian = _contract_gramian(reshaped_gramian, weights) - - for i in range(self.n_dim): - weighting = self.weightings[i] - dim_size = generalized_gramian.shape[i] - reshaped_gramian = reshape_gramian(generalized_gramian, [-1, dim_size]) - temp_weights = _compute_weights(weighting, reshaped_gramian) - generalized_gramian = _contract_gramian(reshaped_gramian, temp_weights) - weights = _scale_weights(weights, temp_weights) - - return weights - - -def _compute_weights(weighting: Weighting, generalized_gramian: Tensor) -> Tensor: - """ - Apply a weighting to each diagonal Gramian in a generalized Gramian. - - For a generalized Gramian ``G`` of shape ``[m, n, n, m]``, this extracts each diagonal Gramian - ``G[j, :, :, j]`` of shape ``[n, n]`` for ``j`` in ``[m]`` and applies the provided weighting. - The resulting weights are stacked into a tensor of shape ``[m, n]``. - - :param weighting: Callable that maps a Gramian of shape ``[n, n]`` to weights of shape ``[n]``. - :param generalized_gramian: Tensor of shape ``[m, n, n, m]`` containing the generalized Gramian. - :returns: Tensor of shape ``[m, n]`` containing the computed weights for each diagonal Gramian. - """ - - weights = torch.zeros( - generalized_gramian[:2], device=generalized_gramian.device, dtype=generalized_gramian.dtype - ) - for i in range(generalized_gramian.shape[0]): - weights[i] = weighting(generalized_gramian[i, :, :, i]) - return weights - - -def _contract_gramian(generalized_gramian: Tensor, weights: Tensor) -> Tensor: - r""" - Compute a partial quadratic form by contracting a generalized Gramian with weight vectors on - both sides. - - Given a generalized Gramian ``G`` of shape ``[m, n, n, m]`` and weights ``w`` of shape - ``[m, n]``, this function computes a Gramian ``G'`` of shape ``[m, m]`` where - - .. math:: - - G'[i, j] = \sum_{k, l=1}^n w[i, k] G[i, k, l, j] w[j, l]. - - This can be viewed as forming a quadratic form with respect to the two innermost dimensions of - ``G``. - - :param generalized_gramian: Tensor of shape ``[m, n, n, m]`` representing the generalized - Gramian. - :param weights: Tensor of shape ``[m, n]`` containing weight vectors to contract with the - Gramian. - :returns: Tensor of shape ``[m, m]`` containing the contracted Gramian, i.e. the partial - quadratic form. - """ - left_product = torch.einsum("ij,ijkl->ikl", weights, generalized_gramian) - return torch.einsum("ij,ijl->il", weights, left_product) - - -def _scale_weights(weights: Tensor, scalings: Tensor) -> Tensor: - """ - Scale a tensor along its leading dimensions by broadcasting scaling factors. - - :param weights: Tensor of shape [n₁, ..., nₖ, nₖ₊₁, ..., nₚ]. - :param scalings: Tensor of shape [n₁, ..., nₖ] providing scaling factors for the leading - dimensions of ``weights``. - :returns: Tensor of the same shape as ``weights``, where each slice - ``weights[i₁, ..., iₖ, :, ..., :]`` is multiplied by ``scalings[i₁, ..., iₖ]``. - """ - return weights * scalings[(...,) + (None,) * (weights.ndim - scalings.ndim)] From f7955a64c4872dd4da73952b384b4eab8f567c6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 10 Sep 2025 17:05:20 +0200 Subject: [PATCH 046/114] Remove FakeGeneralizedWeighting --- tests/unit/autogram/test_engine.py | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 83093a22..fe7d1742 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -4,7 +4,7 @@ import pytest import torch from pytest import mark, param -from torch import Tensor, nn +from torch import nn from torch.nn import Linear from torch.optim import SGD from torch.testing import assert_close @@ -60,7 +60,7 @@ ) from utils.tensors import make_tensors, ones_, randn_, zeros_ -from torchjd.aggregation import GeneralizedWeighting, UPGrad, UPGradWeighting +from torchjd.aggregation import UPGrad, UPGradWeighting from torchjd.autogram._engine import Engine from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian from torchjd.autojac._transform import Diagonalize, Init, Jac, OrderedSet @@ -507,22 +507,3 @@ def test_batched_non_batched_equivalence(shape: list[int], batched_dim: int): gramian2 = engine2.compute_gramian(output) assert_close(gramian1, gramian2) - - -class FakeGeneralizedWeighting(GeneralizedWeighting): - """ - Fake GeneralizedWeighting flattening the Gramian and using UPGradWeighting on it. Could be - removed when we implement a proper FlatteningGeneralizedWeighting.""" - - def __init__(self): - super().__init__() - self.weighting = UPGradWeighting() - - def forward(self, generalized_gramian: Tensor) -> Tensor: - k = generalized_gramian.ndim // 2 - shape = generalized_gramian.shape[:k] - m = prod(shape) - square_gramian = reshape_gramian(generalized_gramian, [m]) - weights_vector = self.weighting(square_gramian) - weights = weights_vector.reshape(shape) - return weights From 98853aed171c7d95e21553856325df70d6f03747 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 10 Sep 2025 17:05:33 +0200 Subject: [PATCH 047/114] Add Flattening --- docs/source/docs/aggregation/flattening.rst | 9 +++++++ docs/source/docs/aggregation/index.rst | 1 + src/torchjd/aggregation/__init__.py | 1 + src/torchjd/aggregation/_flattening.py | 30 +++++++++++++++++++++ 4 files changed, 41 insertions(+) create mode 100644 docs/source/docs/aggregation/flattening.rst create mode 100644 src/torchjd/aggregation/_flattening.py diff --git a/docs/source/docs/aggregation/flattening.rst b/docs/source/docs/aggregation/flattening.rst new file mode 100644 index 00000000..b4ac237d --- /dev/null +++ b/docs/source/docs/aggregation/flattening.rst @@ -0,0 +1,9 @@ +:hide-toc: + +Flattening +========== + +.. autoclass:: torchjd.aggregation.Flattening + :members: + :undoc-members: + :exclude-members: forward diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 87c7d75c..c15d5980 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -33,6 +33,7 @@ Abstract base classes config.rst constant.rst dualproj.rst + flattening.rst graddrop.rst imtl_g.rst krum.rst diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index fdd4c2a1..09355341 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -42,6 +42,7 @@ from ._config import ConFIG from ._constant import Constant, ConstantWeighting from ._dualproj import DualProj, DualProjWeighting +from ._flattening import Flattening from ._graddrop import GradDrop from ._imtl_g import IMTLG, IMTLGWeighting from ._krum import Krum, KrumWeighting diff --git a/src/torchjd/aggregation/_flattening.py b/src/torchjd/aggregation/_flattening.py new file mode 100644 index 00000000..558d428c --- /dev/null +++ b/src/torchjd/aggregation/_flattening.py @@ -0,0 +1,30 @@ +from math import prod + +from torch import Tensor + +from torchjd.aggregation._weighting_bases import GeneralizedWeighting, PSDMatrix, Weighting +from torchjd.autogram._gramian_utils import reshape_gramian + + +class Flattening(GeneralizedWeighting): + """ + :class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting` flattening the Gramian, + extracting a vector of weights from it using a + :class:`~torchjd.aggregation._weighting_bases.Weighting`, and returning the reshaped tensor of + weights. + + :param weighting: The weighting to apply to the Gramian matrix. + """ + + def __init__(self, weighting: Weighting[PSDMatrix]): + super().__init__() + self.weighting = weighting + + def forward(self, generalized_gramian: Tensor) -> Tensor: + k = generalized_gramian.ndim // 2 + shape = generalized_gramian.shape[:k] + m = prod(shape) + square_gramian = reshape_gramian(generalized_gramian, [m]) + weights_vector = self.weighting(square_gramian) + weights = weights_vector.reshape(shape) + return weights From fc37e7d7a6db9b951d1d40aa32346fdf3d485544 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 10 Sep 2025 19:06:03 +0200 Subject: [PATCH 048/114] Add indications about shapes in documentation of aggregation/__init__.py --- src/torchjd/aggregation/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 09355341..9de24564 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -20,7 +20,8 @@ :class:`Aggregators ` and :class:`Weightings ` are callables that take a Jacobian matrix or a Gramian matrix as inputs, respectively. The following example shows how to use UPGrad to either -aggregate a Jacobian or obtain the weights from the Gramian of the Jacobian. +aggregate a Jacobian (of shape ``[m, n]``, where ``m`` is the number of objectives and ``n`` is the +number of parameters), or obtain the weights from the Gramian of the Jacobian (of shape ``[m, m]``). >>> from torch import tensor >>> from torchjd.aggregation import UPGrad, UPGradWeighting From 58fe74126b1d15fc48cd3ab5a968011c2e88811d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 10 Sep 2025 19:07:15 +0200 Subject: [PATCH 049/114] Add GeneralizedWeighting usage example and doctest --- src/torchjd/aggregation/__init__.py | 21 +++++++++++++++++++++ tests/doc/test_aggregation.py | 14 ++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 9de24564..72b7d98c 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -36,6 +36,27 @@ >>> weights = weighting(gramian) >>> weights tensor([1.1109, 0.7894]) + +When dealing with a more general tensor of objectives, of shape ``[m_1, ..., m_k]`` (i.e. not +necessarily a simple vector), the Jacobian will be of shape ``[m_1, ..., m_k, n]``, and its Gramian +will be called a `generalized Gramian`, of shape ``[m_1, ..., m_k, m_k, ..., m_1]``. One can use a +:class:`GeneralizedWeighting` to extract +a tensor of weights (of shape ``[m_1, ..., m_k]``) from such a generalized Gramian. The simplest +:class:`GeneralizedWeighting` is +:class:`Flattening`: it simply "flattens" the +generalized Gramian into a square matrix, applies a normal weighting to it to obtain a vector of +weights, and returns the reshaped tensor of weights. + +>>> from torch import ones +>>> from torchjd.aggregation import Flattening, UPGradWeighting +>>> +>>> weighting = Flattening(UPGradWeighting()) +>>> # Generate a generalized Gramian filled with ones, for the sake of the example +>>> generalized_gramian = ones((2, 3, 3, 2)) +>>> weights = weighting(generalized_gramian) +>>> weights +tensor([[0.1667, 0.1667, 0.1667], + [0.1667, 0.1667, 0.1667]]) """ from ._aggregator_bases import Aggregator diff --git a/tests/doc/test_aggregation.py b/tests/doc/test_aggregation.py index 373918ff..75ef0ddc 100644 --- a/tests/doc/test_aggregation.py +++ b/tests/doc/test_aggregation.py @@ -1,5 +1,6 @@ """This file contains the test corresponding to the usage example of Aggregator and Weighting.""" +import torch from torch.testing import assert_close @@ -19,3 +20,16 @@ def test_aggregation_and_weighting(): weights = weighting(gramian) assert_close(weights, tensor([1.1109, 0.7894]), rtol=0, atol=1e-4) + + +def test_generalized_weighting(): + from torch import ones + + from torchjd.aggregation import Flattening, UPGradWeighting + + weighting = Flattening(UPGradWeighting()) + # Generate a generalized Gramian filled with ones, for the sake of the example + generalized_gramian = ones((2, 3, 3, 2)) + weights = weighting(generalized_gramian) + + assert_close(weights, torch.full((2, 3), 0.1667), rtol=0, atol=1e-4) From cc7f78e74423aaea5e72aaa294b6d1e8dfad99c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 10 Sep 2025 19:55:02 +0200 Subject: [PATCH 050/114] Add basic usage example of IWMTL --- docs/source/examples/index.rst | 5 ++++ docs/source/examples/iwmtl.rst | 43 ++++++++++++++++++++++++++++++++++ tests/doc/test_rst.py | 39 ++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+) create mode 100644 docs/source/examples/iwmtl.rst diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst index db34dad5..49c5c1f4 100644 --- a/docs/source/examples/index.rst +++ b/docs/source/examples/index.rst @@ -18,6 +18,10 @@ This section contains some usage examples for TorchJD. - :doc:`Multi-Task Learning (MTL) ` provides an example of multi-task learning where Jacobian descent is used to optimize the vector of per-task losses of a multi-task model, using the dedicated backpropagation function :doc:`mtl_backward <../docs/autojac/mtl_backward>`. +- :doc:`Instance-Wise Multi-Task Learning (IWMTL) ` shows how to combine multi-task learning + with instance-wise risk minimization: one loss per task and per element of the batch, using the + :doc:`autogram.Engine <../docs/autogram/engine>` and a :doc:`GeneralizedWeighting + <../docs/aggregation/index>`. - :doc:`Recurrent Neural Network (RNN) ` shows how to apply Jacobian descent to RNN training, with one loss per output sequence element. - :doc:`Monitoring Aggregations ` shows how to monitor the aggregation performed by the @@ -34,6 +38,7 @@ This section contains some usage examples for TorchJD. iwrm.rst partial_jd.rst mtl.rst + iwmtl.rst rnn.rst monitoring.rst lightning_integration.rst diff --git a/docs/source/examples/iwmtl.rst b/docs/source/examples/iwmtl.rst new file mode 100644 index 00000000..dc70ae34 --- /dev/null +++ b/docs/source/examples/iwmtl.rst @@ -0,0 +1,43 @@ +Instance-Wise Multi-Task Learning (IWMTL) +========================================= + +TODO + +.. code-block:: python + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import Flattening, UPGradWeighting + from torchjd.autogram import Engine + + shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) + task1_module = Linear(3, 1) + task2_module = Linear(3, 1) + params = [ + *shared_module.parameters(), + *task1_module.parameters(), + *task2_module.parameters(), + ] + + loss_fn = MSELoss(reduction="none") + optimizer = SGD(params, lr=0.1) + weighting = Flattening(UPGradWeighting()) + engine = Engine(shared_module.modules(), batched_dim=1) + + inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 + task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task + task2_targets = torch.randn(8, 16) # 8 batches of 16 targets for the second task + + for input, target1, target2 in zip(inputs, task1_targets, task2_targets): + features = shared_module(input) + output1 = task1_module(features).squeeze(1) + output2 = task2_module(features).squeeze(1) + losses = torch.stack([loss_fn(output1, target1), loss_fn(output2, target2)]) + gramian = engine.compute_gramian(losses) + weights = weighting(gramian) + + optimizer.zero_grad() + losses.backward(weights) + optimizer.step() diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index d7be637e..fb2c9fb7 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -307,6 +307,45 @@ def test_mtl(): optimizer.step() +def test_iwmtl(): + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import Flattening, UPGradWeighting + from torchjd.autogram import Engine + + shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) + task1_module = Linear(3, 1) + task2_module = Linear(3, 1) + params = [ + *shared_module.parameters(), + *task1_module.parameters(), + *task2_module.parameters(), + ] + + loss_fn = MSELoss(reduction="none") + optimizer = SGD(params, lr=0.1) + weighting = Flattening(UPGradWeighting()) + engine = Engine(shared_module.modules(), batched_dim=1) + + inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 + task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task + task2_targets = torch.randn(8, 16) # 8 batches of 16 targets for the second task + + for input, target1, target2 in zip(inputs, task1_targets, task2_targets): + features = shared_module(input) + output1 = task1_module(features).squeeze(1) + output2 = task2_module(features).squeeze(1) + losses = torch.stack([loss_fn(output1, target1), loss_fn(output2, target2)]) + gramian = engine.compute_gramian(losses) + weights = weighting(gramian) + + optimizer.zero_grad() + losses.backward(weights) + optimizer.step() + + def test_partial_jd(): import torch from torch.nn import Linear, MSELoss, ReLU, Sequential From e8085925171886a1e24625b55c5fdeb9112f1f29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 11 Sep 2025 15:21:25 +0200 Subject: [PATCH 051/114] Improve clarity of iwmtl usage example --- docs/source/examples/iwmtl.rst | 24 ++++++++++++++++-------- tests/doc/test_rst.py | 23 +++++++++++++++-------- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/docs/source/examples/iwmtl.rst b/docs/source/examples/iwmtl.rst index dc70ae34..c78aa8f3 100644 --- a/docs/source/examples/iwmtl.rst +++ b/docs/source/examples/iwmtl.rst @@ -4,6 +4,7 @@ Instance-Wise Multi-Task Learning (IWMTL) TODO .. code-block:: python + :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 41-42 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -21,23 +22,30 @@ TODO *task2_module.parameters(), ] - loss_fn = MSELoss(reduction="none") optimizer = SGD(params, lr=0.1) + mse = MSELoss(reduction="none") weighting = Flattening(UPGradWeighting()) - engine = Engine(shared_module.modules(), batched_dim=1) + engine = Engine(shared_module.modules(), batched_dim=0) inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task task2_targets = torch.randn(8, 16) # 8 batches of 16 targets for the second task for input, target1, target2 in zip(inputs, task1_targets, task2_targets): - features = shared_module(input) - output1 = task1_module(features).squeeze(1) - output2 = task2_module(features).squeeze(1) - losses = torch.stack([loss_fn(output1, target1), loss_fn(output2, target2)]) - gramian = engine.compute_gramian(losses) - weights = weighting(gramian) + features = shared_module(input) # shape: [16, 3] + out1 = task1_module(features).squeeze(1) # shape: [16] + out2 = task2_module(features).squeeze(1) # shape: [16] + + # Compute the matrix of losses: one loss per element of the batch and per task + losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1) # shape: [16, 2] + + # Compute the gramian (inner products between pairs of gradients of the losses) + gramian = engine.compute_gramian(losses) # shape: [16, 2, 2, 16] + + # Obtain the weights that lead to no conflict between reweighted gradients + weights = weighting(gramian) # shape [16, 2] optimizer.zero_grad() + # Do the standard backward pass, but weighted using the obtained weights losses.backward(weights) optimizer.step() diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index fb2c9fb7..dfc82cea 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -324,24 +324,31 @@ def test_iwmtl(): *task2_module.parameters(), ] - loss_fn = MSELoss(reduction="none") optimizer = SGD(params, lr=0.1) + mse = MSELoss(reduction="none") weighting = Flattening(UPGradWeighting()) - engine = Engine(shared_module.modules(), batched_dim=1) + engine = Engine(shared_module.modules(), batched_dim=0) inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task task2_targets = torch.randn(8, 16) # 8 batches of 16 targets for the second task for input, target1, target2 in zip(inputs, task1_targets, task2_targets): - features = shared_module(input) - output1 = task1_module(features).squeeze(1) - output2 = task2_module(features).squeeze(1) - losses = torch.stack([loss_fn(output1, target1), loss_fn(output2, target2)]) - gramian = engine.compute_gramian(losses) - weights = weighting(gramian) + features = shared_module(input) # shape: [16, 3] + out1 = task1_module(features).squeeze(1) # shape: [16] + out2 = task2_module(features).squeeze(1) # shape: [16] + + # Compute the matrix of losses: one loss per element of the batch and per task + losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1) # shape: [16, 2] + + # Compute the gramian (inner products between pairs of gradients of the losses) + gramian = engine.compute_gramian(losses) # shape: [16, 2, 2, 16] + + # Obtain the weights that lead to no conflict between reweighted gradients + weights = weighting(gramian) # shape [16, 2] optimizer.zero_grad() + # Do the standard backward pass, but weighted using the obtained weights losses.backward(weights) optimizer.step() From 064271ab052e496366b74f515e20bc4977a36432 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 11 Sep 2025 15:35:37 +0200 Subject: [PATCH 052/114] Add explanation to the IWMTL example --- docs/source/examples/iwmtl.rst | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/source/examples/iwmtl.rst b/docs/source/examples/iwmtl.rst index c78aa8f3..2f577579 100644 --- a/docs/source/examples/iwmtl.rst +++ b/docs/source/examples/iwmtl.rst @@ -1,7 +1,13 @@ Instance-Wise Multi-Task Learning (IWMTL) ========================================= -TODO +When training a model with multiple tasks, the gradients of the individual tasks are likely to +sometimes conflict. This is particularly true when looking at the individual (per-sample) gradients. +The :doc:`autogram engine <../docs/autogram/engine>` can be used to efficiently compute the Gramian +of the Jacobian of the matrix of per-sample and per-task losses. Weights can then be extracted from +this Gramian to reweight the gradients and resolve conflict entirely. + +The following example shows how to do that. .. code-block:: python :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 41-42 From 4fb4917967927cbc3c6fcdd7d2aee15fd6f3aaaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 11 Sep 2025 15:36:01 +0200 Subject: [PATCH 053/114] Add link to IWMTL in Engine.compute_gramian docstring --- src/torchjd/autogram/_engine.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 9254aa86..486215b9 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -194,8 +194,9 @@ def compute_gramian(self, output: Tensor) -> Tensor: squared norm of the gradient of ``output``). - 1D (vector) ``output``: 2D Gramian (this is the standard setting of Jacobian descent). - - 2D (matrix) ``output``: 4D Gramian (this can happen when combining IWRM and - multi-task learning, as each sample in the batch has one loss per task). + - 2D (matrix) ``output``: 4D Gramian (this can be used for :doc:`Instance-Wise + Multi-Task Learning (IWMTL) <../../examples/iwmtl>`, as each sample in the batch + has one loss per task). - etc. :param output: The tensor of arbitrary shape to differentiate. The shape of the returned From 483bda1517b2edd53186f7118d1250c28b3e254c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 11 Sep 2025 19:34:58 +0200 Subject: [PATCH 054/114] Remove support for automatic dimension in reshape_gramian --- src/torchjd/autogram/_gramian_utils.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index 4379b7e0..40e60121 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -1,5 +1,3 @@ -from math import prod - from torch import Tensor @@ -18,14 +16,6 @@ def reshape_gramian(gramian: Tensor, shape: list[int]) -> Tensor: # - The `last_dims` will be [3, 4, 5] and `last_dims[::-1]` will be [5, 4, 3] # - The `reordered_gramian` will be of shape [4, 3, 2, 2, 3, 4] - automatic_dimensions = [i for i in range(len(shape)) if shape[i] == -1] - if len(automatic_dimensions) == 1: - index = automatic_dimensions[0] - current_shape = gramian.shape[: len(gramian.shape) // 2] - numel = prod(current_shape) - specified_numel = -prod(shape) # shape[index] == -1, this is the product of all other dims - shape[index] = numel // specified_numel - unordered_intput_gramian = _revert_last_dims(gramian) unordered_output_gramian = unordered_intput_gramian.reshape(shape + shape) reordered_output_gramian = _revert_last_dims(unordered_output_gramian) From 64be0f5f0ab67535b783e8fc64ca5bc6e86a604a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 11 Sep 2025 19:39:26 +0200 Subject: [PATCH 055/114] Improve examples in reshape_gramian --- src/torchjd/autogram/_gramian_utils.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index 40e60121..2fa921f0 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -11,15 +11,14 @@ def reshape_gramian(gramian: Tensor, shape: list[int]) -> Tensor: `shape + shape[::-1]`. """ - # Example: `gramian` of shape [24, 24] and `shape` of [4, 3, 2]: - # - The `unordered_gramian` will be of shape [4, 3, 2, 4, 3, 2] - # - The `last_dims` will be [3, 4, 5] and `last_dims[::-1]` will be [5, 4, 3] - # - The `reordered_gramian` will be of shape [4, 3, 2, 2, 3, 4] - - unordered_intput_gramian = _revert_last_dims(gramian) - unordered_output_gramian = unordered_intput_gramian.reshape(shape + shape) - reordered_output_gramian = _revert_last_dims(unordered_output_gramian) - return reordered_output_gramian + # Example 1: `gramian` of shape [4, 3, 2, 2, 3, 4] and `shape` of [8, 3]: + # [4, 3, 2, 2, 3, 4] -(movedim)-> [4, 3, 2, 4, 3, 2] -(reshape)-> [8, 3, 8, 3] -(movedim)-> + # [8, 3, 3, 8] + # + # Example 2: `gramian` of shape [24, 24] and `shape` of [4, 3, 2]: + # [24, 24] -(movedim)-> [24, 24] -(reshape)-> [4, 3, 2, 4, 3, 2] -(movedim)-> [4, 3, 2, 2, 3, 4] + + return _revert_last_dims(_revert_last_dims(gramian).reshape(shape + shape)) def movedim_gramian(gramian: Tensor, source: list[int], destination: list[int]) -> Tensor: From e22a095ec7652d878411e24fd023962bf5b116d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 11 Sep 2025 19:48:21 +0200 Subject: [PATCH 056/114] Clean up gramian utils --- src/torchjd/autogram/_gramian_utils.py | 58 ++++++++++++++------------ 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index 2fa921f0..10929741 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -1,13 +1,14 @@ from torch import Tensor -def reshape_gramian(gramian: Tensor, shape: list[int]) -> Tensor: +def reshape_gramian(gramian: Tensor, half_shape: list[int]) -> Tensor: """ Reshapes a Gramian to a provided shape. As a Gramian is quadratic form, the reshape of the first half of the target dimensions must be done from the left, while the reshape of the second half must be done from the right. - :param gramian: Gramian to reshape - :param shape: First half of the target shape, the shape of the output is therefore + + :param gramian: Gramian to reshape. Can be a generalized Gramian. + :param half_shape: First half of the target shape, the shape of the output is therefore `shape + shape[::-1]`. """ @@ -18,41 +19,46 @@ def reshape_gramian(gramian: Tensor, shape: list[int]) -> Tensor: # Example 2: `gramian` of shape [24, 24] and `shape` of [4, 3, 2]: # [24, 24] -(movedim)-> [24, 24] -(reshape)-> [4, 3, 2, 4, 3, 2] -(movedim)-> [4, 3, 2, 2, 3, 4] - return _revert_last_dims(_revert_last_dims(gramian).reshape(shape + shape)) + return _revert_last_dims(_revert_last_dims(gramian).reshape(half_shape + half_shape)) + + +def _revert_last_dims(gramian: Tensor) -> Tensor: + """Inverts the order of the last half of the dimensions of the input generalized Gramian.""" + + half_ndim = len(gramian.shape) // 2 + last_dims = [half_ndim + i for i in range(half_ndim)] + return gramian.movedim(last_dims, last_dims[::-1]) -def movedim_gramian(gramian: Tensor, source: list[int], destination: list[int]) -> Tensor: +def movedim_gramian(gramian: Tensor, half_source: list[int], half_destination: list[int]) -> Tensor: """ Moves the dimensions of a Gramian from some source dimensions to destination dimensions. As a Gramian is quadratic form, moving dimension must be done simultaneously on the first half of the dimensions and on the second half of the dimensions reversed. - :param gramian: Gramian to reshape. - :param source: Source dimensions, that should be in the range - [-gramian.ndim//2, gramian.ndim//2[. Its elements should be unique. - :param destination: Destination dimensions, that should be in the range - [-gramian.ndim//2, gramian.ndim//2[. It should have the same size as `source`, and its + + :param gramian: Gramian to reshape. Can be a generalized Gramian. + :param half_source: Source dimensions, that should be in the range [-gramian.ndim//2, + gramian.ndim//2[. Its elements should be unique. + :param half_destination: Destination dimensions, that should be in the range + [-gramian.ndim//2, gramian.ndim//2[. It should have the same size as `half_source`, and its elements should be unique. """ - # Example: `gramian` of shape [4, 3, 2, 2, 3, 4], `source` of [-2, 2] and destination of [0, 1]: - # - `source_` will be [1, 2] and `destination_` will be [0, 1] - # - `mirrored_source` will be [1, 2, 4, 3] and `mirrored_destination` will be [0, 1, 5, 4] + # Example: `gramian` of shape [4, 3, 2, 2, 3, 4], `half_source` of [-2, 2] and + # `half_destination` of [0, 1]: + # - `half_source_` will be [1, 2] and `half_destination_` will be [0, 1] + # - `source` will be [1, 2, 4, 3] and `destination` will be [0, 1, 5, 4] # - The `moved_gramian` will be of shape [3, 2, 4, 4, 2, 3] # Map everything to the range [0, gramian.ndim//2[ - length = gramian.ndim // 2 - source_ = [i if 0 <= i else i + length for i in source] - destination_ = [i if 0 <= i else i + length for i in destination] + half_ndim = gramian.ndim // 2 + half_source_ = [i if 0 <= i else i + half_ndim for i in half_source] + half_destination_ = [i if 0 <= i else i + half_ndim for i in half_destination] - # Mirror the source and destination and use the result to move the dimensions of the gramian + # Mirror the half source and the half destination and use the result to move the dimensions of + # the gramian last_dim = gramian.ndim - 1 - mirrored_source = source_ + [last_dim - i for i in source_] - mirrored_destination = destination_ + [last_dim - i for i in destination_] - moved_gramian = gramian.movedim(mirrored_source, mirrored_destination) + source = half_source_ + [last_dim - i for i in half_source_] + destination = half_destination_ + [last_dim - i for i in half_destination_] + moved_gramian = gramian.movedim(source, destination) return moved_gramian - - -def _revert_last_dims(generalized_gramian: Tensor) -> Tensor: - input_ndim = len(generalized_gramian.shape) // 2 - last_dims = [input_ndim + i for i in range(input_ndim)] - return generalized_gramian.movedim(last_dims, last_dims[::-1]) From e3758cf6efa3510b28d4248ab396bdb2c470935d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 11 Sep 2025 22:36:12 +0200 Subject: [PATCH 057/114] Fix typo --- docs/source/examples/iwmtl.rst | 2 +- tests/doc/test_rst.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/examples/iwmtl.rst b/docs/source/examples/iwmtl.rst index 2f577579..ed13c62e 100644 --- a/docs/source/examples/iwmtl.rst +++ b/docs/source/examples/iwmtl.rst @@ -49,7 +49,7 @@ The following example shows how to do that. gramian = engine.compute_gramian(losses) # shape: [16, 2, 2, 16] # Obtain the weights that lead to no conflict between reweighted gradients - weights = weighting(gramian) # shape [16, 2] + weights = weighting(gramian) # shape: [16, 2] optimizer.zero_grad() # Do the standard backward pass, but weighted using the obtained weights diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index dfc82cea..c97feaf6 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -345,7 +345,7 @@ def test_iwmtl(): gramian = engine.compute_gramian(losses) # shape: [16, 2, 2, 16] # Obtain the weights that lead to no conflict between reweighted gradients - weights = weighting(gramian) # shape [16, 2] + weights = weighting(gramian) # shape: [16, 2] optimizer.zero_grad() # Do the standard backward pass, but weighted using the obtained weights From 119e0058f113ecf8800677d6f37705ec353f293b Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 12 Sep 2025 07:42:48 +0200 Subject: [PATCH 058/114] Improve `gramian_utils` tests. --- tests/unit/autogram/test_gramian_utils.py | 117 +++++++++------------- 1 file changed, 48 insertions(+), 69 deletions(-) diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py index 3ea8aa3d..8d8648cf 100644 --- a/tests/unit/autogram/test_gramian_utils.py +++ b/tests/unit/autogram/test_gramian_utils.py @@ -1,99 +1,78 @@ -from math import prod - import torch from pytest import mark from torch import Tensor from torch.testing import assert_close -from utils.tensors import rand_ +from utils.tensors import randn_ from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian -def compute_quadratic_form(generalized_gramian: Tensor, x: Tensor) -> Tensor: +def _compute_gramian(matrix: Tensor) -> Tensor: """ - Compute the quadratic form x^T G x when the provided generalized Gramian and x may have multiple - dimensions. + Contracts the last dimension of matrix to make it into a Gramian. """ - indices = list(range(x.ndim)) - linear_form = torch.tensordot(x, generalized_gramian, dims=(indices, indices)) - return torch.tensordot(linear_form, x, dims=(indices[::-1], indices)) + + indices = list(range(len(matrix.shape))) + transposed_matrix = matrix.movedim(indices, indices[::-1]) + return torch.tensordot(matrix, transposed_matrix, dims=([-1], [0])) @mark.parametrize( - "shape", + ["original_shape", "target_shape"], [ - [50, 2, 2, 3], - [60, 3, 2, 5], - [30, 6, 7], - [4, 3, 1], - [4, 1, 1], - [1, 1, 1], - [4, 1], - [4], - [1, 1], - [1], + ([], []), + ([], [1, 1]), + ([1], []), + ([12], [2, 3, 2]), + ([12], [4, 3]), + ([12], [12]), + ([4, 3], [12]), + ([4, 3], [2, 3, 2]), + ([4, 3], [3, 4]), + ([4, 3], [4, 3]), + ([6, 7, 9], [378]), + ([6, 7, 9], [9, 42]), + ([6, 7, 9], [2, 7, 27]), + ([6, 7, 9], [6, 7, 9]), ], ) -def test_quadratic_form_invariance_to_reshape(shape: list[int]): - """ - When reshaping a Gramian, we expect it to represent the same quadratic form that now applies to - reshaped inputs. So the mapping x -> x^T G x commutes with reshaping x, G and then computing the - corresponding quadratic form. - """ +def test_gramian_equivariance_reshape(original_shape: list[int], target_shape: list[int]): + original_matrix = randn_(original_shape + [2]) + target_matrix = original_matrix.reshape(target_shape + [2]) - flat_dim = prod(shape[1:]) - iterations = 20 + original_gramian = _compute_gramian(original_matrix) + target_gramian = _compute_gramian(target_matrix) - matrix = rand_([flat_dim, shape[0]]) - gramian = matrix @ matrix.T - reshaped_gramian = reshape_gramian(gramian, shape[1:]) + reshaped_gramian = reshape_gramian(original_gramian, target_shape) - for _ in range(iterations): - vector = rand_([flat_dim]) - reshaped_vector = vector.reshape(shape[1:]) - - quadratic_form = vector @ gramian @ vector - reshaped_quadratic_form = compute_quadratic_form(reshaped_gramian, reshaped_vector) - - assert_close(reshaped_quadratic_form, quadratic_form) + assert_close(reshaped_gramian, target_gramian) @mark.parametrize( ["shape", "source", "destination"], [ - ([50, 2, 2, 3], [0, 2], [1, 0]), - ([60, 3, 2, 5], [1], [2]), - ([30, 6, 7], [0, 1], [1, 0]), - ([4, 3, 1], [0, 1], [1, 0]), - ([4, 1, 1], [1], [0]), - ([1, 1, 1], [], []), - ([4, 1], [0], [0]), - ([4], [], []), - ([1, 1], [], []), + ([], [], []), + ([1], [0], [0]), ([1], [], []), + ([1, 1], [], []), + ([1, 1], [1], [0]), + ([6, 7], [1], [0]), + ([3, 1], [0, 1], [1, 0]), + ([1, 1, 1], [], []), + ([3, 2, 5], [], []), + ([1, 1, 1], [2], [0]), + ([3, 2, 5], [1], [2]), + ([2, 2, 3], [0, 2], [1, 0]), + ([2, 2, 3], [0, 2, 1], [1, 0, 2]), ], ) -def test_quadratic_form_invariance_to_movedim( - shape: list[int], source: list[int], destination: list[int] -): - """ - When moving dims on a Gramian, we expect it to represent the same quadratic form that now - applies to inputs with moved dims. So the mapping x -> x^T G x commutes with moving dims x, G - and then computing the quadratic form with those. - """ - - flat_dim = prod(shape[1:]) - iterations = 20 - - matrix = rand_([flat_dim, shape[0]]) - gramian = reshape_gramian(matrix @ matrix.T, shape[1:]) - moved_gramian = movedim_gramian(gramian, source, destination) +def test_gramian_equivariance_movedim(shape: list[int], source: list[int], destination: list[int]): + original_matrix = randn_(shape + [2]) + target_matrix = original_matrix.movedim(source, destination) - for _ in range(iterations): - vector = rand_(shape[1:]) - moved_vector = vector.movedim(source, destination) + original_gramian = _compute_gramian(original_matrix) + target_gramian = _compute_gramian(target_matrix) - quadratic_form = compute_quadratic_form(gramian, vector) - moved_quadratic_form = compute_quadratic_form(moved_gramian, moved_vector) + moveddim_gramian = movedim_gramian(original_gramian, source, destination) - assert_close(moved_quadratic_form, quadratic_form) + assert_close(moveddim_gramian, target_gramian) From 639d3c38fdf1f6f1067fdc4b85b561a0748ad39a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Fri, 12 Sep 2025 15:39:28 +0200 Subject: [PATCH 059/114] Update tests/unit/autogram/test_gramian_utils.py --- tests/unit/autogram/test_gramian_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py index 8d8648cf..b0faa902 100644 --- a/tests/unit/autogram/test_gramian_utils.py +++ b/tests/unit/autogram/test_gramian_utils.py @@ -12,7 +12,7 @@ def _compute_gramian(matrix: Tensor) -> Tensor: Contracts the last dimension of matrix to make it into a Gramian. """ - indices = list(range(len(matrix.shape))) + indices = list(range(matrix.ndim)) transposed_matrix = matrix.movedim(indices, indices[::-1]) return torch.tensordot(matrix, transposed_matrix, dims=([-1], [0])) From f014e05b574eb8c476f68cc8e978ec8475dbd642 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 15 Sep 2025 15:35:15 +0200 Subject: [PATCH 060/114] Improve test names and docstrings in test_gramian_utils.py --- tests/unit/autogram/test_gramian_utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py index b0faa902..31f8c213 100644 --- a/tests/unit/autogram/test_gramian_utils.py +++ b/tests/unit/autogram/test_gramian_utils.py @@ -8,9 +8,7 @@ def _compute_gramian(matrix: Tensor) -> Tensor: - """ - Contracts the last dimension of matrix to make it into a Gramian. - """ + """Contracts the last dimension of matrix to make it into a Gramian.""" indices = list(range(matrix.ndim)) transposed_matrix = matrix.movedim(indices, indices[::-1]) @@ -36,7 +34,9 @@ def _compute_gramian(matrix: Tensor) -> Tensor: ([6, 7, 9], [6, 7, 9]), ], ) -def test_gramian_equivariance_reshape(original_shape: list[int], target_shape: list[int]): +def test_reshape_gramian(original_shape: list[int], target_shape: list[int]): + """Tests that reshape_gramian is such that _compute_gramian is equivariant to a reshape.""" + original_matrix = randn_(original_shape + [2]) target_matrix = original_matrix.reshape(target_shape + [2]) @@ -66,7 +66,9 @@ def test_gramian_equivariance_reshape(original_shape: list[int], target_shape: l ([2, 2, 3], [0, 2, 1], [1, 0, 2]), ], ) -def test_gramian_equivariance_movedim(shape: list[int], source: list[int], destination: list[int]): +def test_movedim_gramian(shape: list[int], source: list[int], destination: list[int]): + """Tests that movedim_gramian is such that _compute_gramian is equivariant to a movedim.""" + original_matrix = randn_(shape + [2]) target_matrix = original_matrix.movedim(source, destination) From ebe8a294266fd7f9bb1d59cd8d6756d2abe9626f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 15 Sep 2025 16:55:27 +0200 Subject: [PATCH 061/114] Rename batched_dim to batch_dim --- docs/source/examples/iwmtl.rst | 2 +- src/torchjd/autogram/_engine.py | 44 ++++++++++----------- tests/doc/test_autogram.py | 2 +- tests/doc/test_rst.py | 6 +-- tests/speed/autogram/grad_vs_jac_vs_gram.py | 2 +- tests/unit/autogram/test_engine.py | 24 +++++------ 6 files changed, 40 insertions(+), 40 deletions(-) diff --git a/docs/source/examples/iwmtl.rst b/docs/source/examples/iwmtl.rst index ed13c62e..c4665b4b 100644 --- a/docs/source/examples/iwmtl.rst +++ b/docs/source/examples/iwmtl.rst @@ -31,7 +31,7 @@ The following example shows how to do that. optimizer = SGD(params, lr=0.1) mse = MSELoss(reduction="none") weighting = Flattening(UPGradWeighting()) - engine = Engine(shared_module.modules(), batched_dim=0) + engine = Engine(shared_module.modules(), batch_dim=0) inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index ccc60dad..e6cefa4b 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -58,10 +58,10 @@ class Engine: :param modules: A collection of modules whose direct (non-recursive) parameters will contribute to the Gramian of the Jacobian. - :param batched_dim: If the modules work with batches and process each batch element - independently, then many intermediary jacobians are sparse (block-diagonal), which allows - for a substancial memory optimization by backpropagating a squashed Jacobian instead. This - parameter indicates the batch dimension, if any. Defaults to None. + :param batch_dim: If the modules work with batches and process each batch element independently, + then many intermediary jacobians are sparse (block-diagonal), which allows for a substancial + memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates + the batch dimension, if any. Defaults to None. .. admonition:: Example @@ -84,7 +84,7 @@ class Engine: >>> >>> criterion = MSELoss(reduction="none") >>> weighting = UPGradWeighting() - >>> engine = Engine(model.modules(), batched_dim=0) + >>> engine = Engine(model.modules(), batch_dim=0) >>> >>> for input, target in zip(inputs, targets): >>> output = model(input).squeeze(dim=1) # shape: [16] @@ -135,13 +135,13 @@ class Engine: def __init__( self, modules: Iterable[nn.Module], - batched_dim: int | None = None, + batch_dim: int | None = None, ): self._gramian_accumulator = GramianAccumulator() self._target_edges = EdgeRegistry() - self._batched_dim = batched_dim + self._batch_dim = batch_dim self._module_hook_manager = ModuleHookManager( - self._target_edges, self._gramian_accumulator, batched_dim is not None + self._target_edges, self._gramian_accumulator, batch_dim is not None ) self._hook_modules(modules) @@ -156,10 +156,10 @@ def _hook_modules(self, modules: Iterable[nn.Module]) -> None: self._module_hook_manager.hook_module(module) def _check_module_is_compatible(self, module: nn.Module) -> None: - if self._batched_dim is not None and isinstance(module, _INCOMPATIBLE_MODULE_TYPES): + if self._batch_dim is not None and isinstance(module, _INCOMPATIBLE_MODULE_TYPES): raise ValueError( f"Found a module of type {type(module)}, which is incompatible with the autogram " - f"engine when `batched_dim` is not `None`. The incompatible module types are " + f"engine when `batch_dim` is not `None`. The incompatible module types are " f"{_INCOMPATIBLE_MODULE_TYPES} (and their subclasses). The recommended fix is to " f"replace incompatible layers by something else (e.g. BatchNorm by InstanceNorm), " f"but if you really can't and performance not a priority, you may also just set" @@ -203,20 +203,20 @@ def compute_gramian(self, output: Tensor) -> Tensor: Gramian depends on the shape of this output, as explained in the note above. """ - if self._batched_dim is not None: + if self._batch_dim is not None: # move batched dim to the end - ordered_output = output.movedim(self._batched_dim, -1) + ordered_output = output.movedim(self._batch_dim, -1) ordered_shape = list(ordered_output.shape) batch_size = ordered_shape[-1] - has_non_batched_dim = len(ordered_shape) > 1 + has_non_batch_dim = len(ordered_shape) > 1 target_shape = [batch_size] else: ordered_output = output ordered_shape = list(ordered_output.shape) - has_non_batched_dim = len(ordered_shape) > 0 + has_non_batch_dim = len(ordered_shape) > 0 target_shape = [] - if has_non_batched_dim: + if has_non_batch_dim: target_shape = [-1] + target_shape reshaped_output = ordered_output.reshape(target_shape) @@ -229,7 +229,7 @@ def compute_gramian(self, output: Tensor) -> Tensor: self._module_hook_manager.gramian_accumulation_phase = True try: - square_gramian = self._compute_square_gramian(reshaped_output, has_non_batched_dim) + square_gramian = self._compute_square_gramian(reshaped_output, has_non_batch_dim) finally: # Reset everything that has a state, even if the previous call raised an exception self._module_hook_manager.gramian_accumulation_phase = False @@ -238,14 +238,14 @@ def compute_gramian(self, output: Tensor) -> Tensor: unordered_gramian = reshape_gramian(square_gramian, ordered_shape) - if self._batched_dim is not None: - gramian = movedim_gramian(unordered_gramian, [-1], [self._batched_dim]) + if self._batch_dim is not None: + gramian = movedim_gramian(unordered_gramian, [-1], [self._batch_dim]) else: gramian = unordered_gramian return gramian - def _compute_square_gramian(self, output: Tensor, has_non_batched_dim: bool) -> Tensor: + def _compute_square_gramian(self, output: Tensor, has_non_batch_dim: bool) -> Tensor: leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)})) def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: @@ -256,14 +256,14 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: retain_graph=True, ) - if has_non_batched_dim: + if has_non_batch_dim: # There is one non-batched dimension, it is the first one - non_batched_dim_len = output.shape[0] + non_batch_dim_len = output.shape[0] jac_output_shape = [output.shape[0]] + list(output.shape) # Need to batch `grad_output` over the first dimension jac_output = torch.zeros(jac_output_shape, device=output.device, dtype=output.dtype) - for i in range(non_batched_dim_len): + for i in range(non_batch_dim_len): jac_output[i, i, ...] = 1 _ = vmap(differentiation)(jac_output) diff --git a/tests/doc/test_autogram.py b/tests/doc/test_autogram.py index e0e3117f..815ef315 100644 --- a/tests/doc/test_autogram.py +++ b/tests/doc/test_autogram.py @@ -18,7 +18,7 @@ def test_engine(): criterion = MSELoss(reduction="none") weighting = UPGradWeighting() - engine = Engine(model.modules(), batched_dim=0) + engine = Engine(model.modules(), batch_dim=0) for input, target in zip(inputs, targets): output = model(input).squeeze(dim=1) # shape: [16] diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index c97feaf6..adc9b2d4 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -136,7 +136,7 @@ def test_autogram(): params = model.parameters() optimizer = SGD(params, lr=0.1) weighting = UPGradWeighting() - engine = Engine(model.modules(), batched_dim=0) + engine = Engine(model.modules(), batch_dim=0) for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] @@ -327,7 +327,7 @@ def test_iwmtl(): optimizer = SGD(params, lr=0.1) mse = MSELoss(reduction="none") weighting = Flattening(UPGradWeighting()) - engine = Engine(shared_module.modules(), batched_dim=0) + engine = Engine(shared_module.modules(), batch_dim=0) inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task @@ -371,7 +371,7 @@ def test_partial_jd(): # Create the autogram engine that will compute the Gramian of the # Jacobian with respect to the two last Linear layers' parameters. - engine = Engine(model[2:].modules(), batched_dim=0) + engine = Engine(model[2:].modules(), batch_dim=0) params = model.parameters() optimizer = SGD(params, lr=0.1) diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index cd99bcd6..a41b2905 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -119,7 +119,7 @@ def post_fn(): print(autojac_times) print() - engine = Engine(model.modules(), batched_dim=0) + engine = Engine(model.modules(), batch_dim=0) autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs)) print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}") print(autogram_times) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 6a06d4d9..814b58e2 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -177,7 +177,7 @@ def test_compute_partial_gramian(gramian_module_names: set[str]): autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True) torch.manual_seed(0) - engine = Engine(gramian_modules, batched_dim=0) + engine = Engine(gramian_modules, batch_dim=0) output = model(input) losses = loss_fn(output) @@ -262,7 +262,7 @@ def test_incompatible_modules(architecture: type[nn.Module]): model = architecture().to(device=DEVICE) with pytest.raises(ValueError): - _ = Engine(model.modules(), batched_dim=0) + _ = Engine(model.modules(), batch_dim=0) @mark.parametrize("shape", [(1, 3), (7, 15), (27, 15)]) @@ -277,14 +277,14 @@ def test_gramian_is_correct(shape: tuple[int, int], batch_size: int, reduce_outp is_batched = batch_size is not None if is_batched: - batched_dim = 0 + batch_dim = 0 input_dim = [batch_size, shape[0]] else: - batched_dim = None + batch_dim = None input_dim = [shape[0]] model = Linear(shape[0], shape[1]).to(device=DEVICE) - engine = Engine([model], batched_dim=batched_dim) + engine = Engine([model], batch_dim=batch_dim) input = randn_(input_dim) output = model(input) @@ -436,7 +436,7 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: @mark.parametrize( - ["shape", "batched_dim"], + ["shape", "batch_dim"], [ ([2, 5, 3, 2], 2), ([3, 2, 5], 1), @@ -450,25 +450,25 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: ([4, 3, 1], 2), ], ) -def test_batched_non_batched_equivalence(shape: list[int], batched_dim: int): +def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): """ Tests that for a vector with some batched dimensions, the gramian is the same if we use the - appropriate `batched_dims` or if we don't use any. + appropriate `batch_dims` or if we don't use any. """ - non_batched_shape = [shape[i] for i in range(len(shape)) if i != batched_dim] + non_batched_shape = [shape[i] for i in range(len(shape)) if i != batch_dim] input_size = prod(non_batched_shape) - batch_size = shape[batched_dim] + batch_size = shape[batch_dim] output_size = input_size model = Linear(input_size, output_size).to(device=DEVICE) - engine1 = Engine([model], batched_dim=batched_dim) + engine1 = Engine([model], batch_dim=batch_dim) engine2 = Engine([model]) input = randn_([batch_size, input_size]) output = model(input) output = output.reshape([batch_size] + non_batched_shape) - output = output.movedim(0, batched_dim) + output = output.movedim(0, batch_dim) gramian1 = engine1.compute_gramian(output) gramian2 = engine2.compute_gramian(output) From 4112216140a98c5a74fbbfde37d2acfbf7f9abe9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 15 Sep 2025 17:16:56 +0200 Subject: [PATCH 062/114] Fix typo --- tests/unit/autogram/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 814b58e2..f6117aa1 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -453,7 +453,7 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): """ Tests that for a vector with some batched dimensions, the gramian is the same if we use the - appropriate `batch_dims` or if we don't use any. + appropriate `batch_dim` or if we don't use any. """ non_batched_shape = [shape[i] for i in range(len(shape)) if i != batch_dim] From 676f93e3b6bb2b6dfac70c6a03d1f951e7b1b8e1 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 22 Sep 2025 09:48:20 +0200 Subject: [PATCH 063/114] Resolve typing of tuple in setup_context. --- src/torchjd/autogram/_module_hook_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 8b915449..eb08be2c 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -165,10 +165,12 @@ def forward( ) -> tuple[Tensor, ...]: return tuple([x.detach() for x in xs]) + # For Python version > 3.10, the type of `inputs` should become + # tuple[BoolRef, TreeSpec, VJPType, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]] @staticmethod def setup_context( ctx, - inputs: tuple[BoolRef, TreeSpec, VJPType, PyTree, GramianAccumulator, nn.Module], + inputs: tuple, _, ): ctx.gramian_accumulation_phase = inputs[0] From 0d21d60be865be0b1e98667c5da3f93d91f5259c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 22 Sep 2025 20:32:57 +0200 Subject: [PATCH 064/114] Add WithBatchNorm and improve test_incompatible_modules --- tests/unit/autogram/test_engine.py | 15 ++++++++++++--- tests/utils/architectures.py | 14 ++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index c3081dee..ceb5cdb9 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -44,6 +44,7 @@ SomeUnusedOutput, SomeUnusedParam, SqueezeNet, + WithBatchNorm, WithBuffered, WithModuleTrackingRunningStats, WithNoTensorOutput, @@ -255,14 +256,22 @@ def test_autograd_while_modules_are_hooked( assert engine._gramian_accumulator.gramian is None -@mark.parametrize("architecture", [WithRNN, WithModuleTrackingRunningStats]) -def test_incompatible_modules(architecture: type[nn.Module]): +@mark.parametrize( + ["architecture", "batch_dim"], + [ + (WithModuleTrackingRunningStats, 0), + (WithModuleTrackingRunningStats, None), + (WithRNN, 0), + (WithBatchNorm, 0), + ], +) +def test_incompatible_modules(architecture: type[nn.Module], batch_dim: int | None): """Tests that the engine cannot be constructed with incompatible modules.""" model = architecture().to(device=DEVICE) with pytest.raises(ValueError): - _ = Engine(model.modules(), batch_dim=0) + _ = Engine(model.modules(), batch_dim=batch_dim) @mark.parametrize("shape", [(1, 3), (7, 15), (27, 15)]) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 6a3d4c97..66d52323 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -687,6 +687,20 @@ def forward(self, input: Tensor) -> Tensor: return self.instance_norm(input) +class WithBatchNorm(ShapedModule): + """Simple model containing a BatchNorm layer.""" + + INPUT_SHAPES = (3, 6, 6) + OUTPUT_SHAPES = (3, 6, 6) + + def __init__(self): + super().__init__() + self.batch_norm = nn.BatchNorm2d(3, affine=True, track_running_stats=False) + + def forward(self, input: Tensor) -> Tensor: + return self.batch_norm(input) + + class FreeParam(ShapedModule): """ Model that contains a free (i.e. not contained in a submodule) parameter, that is used at the From 269247d24f0b52298b2e7e7350495c13c77fc227 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 22 Sep 2025 20:34:41 +0200 Subject: [PATCH 065/114] Add batch_dim parameter to tests when possible --- tests/unit/autogram/test_engine.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index ceb5cdb9..6f400525 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -110,7 +110,8 @@ @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) -def test_compute_gramian(architecture: type[ShapedModule], batch_size: int): +@mark.parametrize("batch_dim", [0, None]) +def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batch_dim: int | None): """Tests that the autograd and the autogram engines compute the same gramian.""" input_shapes = architecture.INPUT_SHAPES @@ -121,7 +122,7 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int): torch.manual_seed(0) model_autogram = architecture().to(device=DEVICE) - engine = Engine(model_autogram.modules()) + engine = Engine(model_autogram.modules(), batch_dim=batch_dim) inputs = make_tensors(batch_size, input_shapes) targets = make_tensors(batch_size, output_shapes) @@ -148,7 +149,8 @@ def _non_empty_subsets(elements: set) -> list[set]: @mark.parametrize("gramian_module_names", _non_empty_subsets({"fc0", "fc1", "fc2", "fc3", "fc4"})) -def test_compute_partial_gramian(gramian_module_names: set[str]): +@mark.parametrize("batch_dim", [0, None]) +def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int | None): """ Tests that the autograd and the autogram engines compute the same gramian when only a subset of the model parameters is specified. @@ -178,7 +180,7 @@ def test_compute_partial_gramian(gramian_module_names: set[str]): autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True) torch.manual_seed(0) - engine = Engine(gramian_modules, batch_dim=0) + engine = Engine(gramian_modules, batch_dim=batch_dim) output = model(input) losses = loss_fn(output) @@ -188,7 +190,10 @@ def test_compute_partial_gramian(gramian_module_names: set[str]): @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) -def test_iwrm_steps_with_autogram(architecture: type[ShapedModule], batch_size: int): +@mark.parametrize("batch_dim", [0, None]) +def test_iwrm_steps_with_autogram( + architecture: type[ShapedModule], batch_size: int, batch_dim: int | None +): """Tests that the autogram engine doesn't raise any error during several IWRM iterations.""" n_iter = 3 @@ -200,7 +205,7 @@ def test_iwrm_steps_with_autogram(architecture: type[ShapedModule], batch_size: model = architecture().to(device=DEVICE) - engine = Engine(model.modules()) + engine = Engine(model.modules(), batch_dim=batch_dim) optimizer = SGD(model.parameters(), lr=1e-7) for i in range(n_iter): @@ -216,8 +221,9 @@ def test_iwrm_steps_with_autogram(architecture: type[ShapedModule], batch_size: @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) @mark.parametrize("compute_gramian", [False, True]) +@mark.parametrize("batch_dim", [0, None]) def test_autograd_while_modules_are_hooked( - architecture: type[ShapedModule], batch_size: int, compute_gramian: bool + architecture: type[ShapedModule], batch_size: int, compute_gramian: bool, batch_dim: int | None ): """ Tests that the hooks added when constructing the engine do not interfere with a simple autograd @@ -238,7 +244,7 @@ def test_autograd_while_modules_are_hooked( autograd_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None} # Hook modules and optionally compute the Gramian - engine = Engine(model_autogram.modules()) + engine = Engine(model_autogram.modules(), batch_dim=batch_dim) if compute_gramian: torch.manual_seed(0) # Fix randomness for random models output = model_autogram(input) @@ -377,7 +383,8 @@ def test_gramian_is_correct(shape: tuple[int, int], batch_size: int, reduce_outp [1], ], ) -def test_reshape_equivariance(shape: list[int]): +@mark.parametrize("batch_dim", [0, None]) +def test_reshape_equivariance(shape: list[int], batch_dim: int | None): """ Test equivariance of `compute_gramian` under reshape operation. More precisely, if we reshape the `output` to some `shape`, then the result is the same as reshaping the Gramian to the @@ -472,7 +479,7 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): model = Linear(input_size, output_size).to(device=DEVICE) engine1 = Engine([model], batch_dim=batch_dim) - engine2 = Engine([model]) + engine2 = Engine([model], batch_dim=None) input = randn_([batch_size, input_size]) output = model(input) From 437143721247b6dc3d33e547a037f74cb4269b77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 23 Sep 2025 19:08:25 +0200 Subject: [PATCH 066/114] Import Callable from collections.abc --- tests/utils/forward_backwards.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index 28611409..42ea9412 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import torch from torch import Tensor, nn, vmap From a06cd7af1d627049c6accb0483327ab375eeb694 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 23 Sep 2025 19:20:56 +0200 Subject: [PATCH 067/114] Change loss_fn to not handle reduction itself --- tests/unit/autogram/test_engine.py | 11 +++--- tests/utils/forward_backwards.py | 55 ++++++++++++++++++------------ 2 files changed, 39 insertions(+), 27 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 6f400525..945c43b5 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -58,6 +58,7 @@ autogram_forward_backward, compute_gramian_with_autograd, make_mse_loss_fn, + reduce_to_vector, ) from utils.tensors import make_tensors, ones_, randn_, zeros_ @@ -130,12 +131,12 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc torch.random.manual_seed(0) # Fix randomness for random models output = model_autograd(inputs) - losses = loss_fn(output) + losses = reduce_to_vector(loss_fn(output)) autograd_gramian = compute_gramian_with_autograd(losses, list(model_autograd.parameters())) torch.random.manual_seed(0) # Fix randomness for random models output = model_autogram(inputs) - losses = loss_fn(output) + losses = reduce_to_vector(loss_fn(output)) autogram_gramian = engine.compute_gramian(losses) assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=1e-5) @@ -170,7 +171,7 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int model = architecture().to(device=DEVICE) output = model(input) - losses = loss_fn(output) + losses = reduce_to_vector(loss_fn(output)) gramian_modules = [model.get_submodule(name) for name in gramian_module_names] gramian_params = [] @@ -183,7 +184,7 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int engine = Engine(gramian_modules, batch_dim=batch_dim) output = model(input) - losses = loss_fn(output) + losses = reduce_to_vector(loss_fn(output)) gramian = engine.compute_gramian(losses) assert_close(gramian, autograd_gramian) @@ -248,7 +249,7 @@ def test_autograd_while_modules_are_hooked( if compute_gramian: torch.manual_seed(0) # Fix randomness for random models output = model_autogram(input) - losses = loss_fn(output) + losses = reduce_to_vector(loss_fn(output)) _ = engine.compute_gramian(losses) # Verify that even with the hooked modules, autograd works normally when not using the engine. diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index 42ea9412..bea15bfb 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -13,7 +13,7 @@ def autograd_forward_backward( model: nn.Module, inputs: PyTree, - loss_fn: Callable[[PyTree], Tensor], + loss_fn: Callable[[PyTree], list[Tensor]], ) -> None: losses = _forward_pass(model, inputs, loss_fn) losses.sum().backward() @@ -22,7 +22,7 @@ def autograd_forward_backward( def autojac_forward_backward( model: nn.Module, inputs: PyTree, - loss_fn: Callable[[PyTree], Tensor], + loss_fn: Callable[[PyTree], list[Tensor]], aggregator: Aggregator, ) -> None: losses = _forward_pass(model, inputs, loss_fn) @@ -33,7 +33,7 @@ def autograd_gramian_forward_backward( model: nn.Module, inputs: PyTree, params: list[nn.Parameter], - loss_fn: Callable[[PyTree], Tensor], + loss_fn: Callable[[PyTree], list[Tensor]], weighting: Weighting, ) -> None: losses = _forward_pass(model, inputs, loss_fn) @@ -46,45 +46,56 @@ def autogram_forward_backward( engine: Engine, weighting: Weighting, inputs: PyTree, - loss_fn: Callable[[PyTree], Tensor], + loss_fn: Callable[[PyTree], list[Tensor]], ) -> None: losses = _forward_pass(model, inputs, loss_fn) gramian = engine.compute_gramian(losses) losses.backward(weighting(gramian)) -def _forward_pass(model: nn.Module, inputs: PyTree, loss_fn: Callable[[PyTree], Tensor]) -> PyTree: +def _forward_pass( + model: nn.Module, inputs: PyTree, loss_fn: Callable[[PyTree], list[Tensor]] +) -> PyTree: output = model(inputs) assert tree_map(lambda t: t.shape[1:], output) == model.OUTPUT_SHAPES - losses = loss_fn(output) + loss_tensors = loss_fn(output) + losses = reduce_to_vector(loss_tensors) return losses -def make_mse_loss_fn(targets: PyTree) -> Callable[[PyTree], Tensor]: - def mse_loss_fn(outputs: PyTree) -> Tensor: +def make_mse_loss_fn(targets: PyTree) -> Callable[[PyTree], list[Tensor]]: + def mse_loss_fn(outputs: PyTree) -> list[Tensor]: flat_outputs, _ = tree_flatten(outputs) flat_targets, _ = tree_flatten(targets) - # For each (output_i, target_i) pair, compute the MSE at each coordinate and store it in - # a matrix of shape [batch_size, dim_i], where dim_i is the number of elements of - # output_i and target_i. Concatenate them along dim=1 to obtain a matrix of MSEs of - # shape [batch_size, dim], where dim is the total number of elements of the outputs. - # Then, reduce this into a vector of losses of size [batch_size], by applying the mean - # along dim=1. - losses = torch.concatenate( - [ - reshape_raw_losses(mse_loss(output, target, reduction="none")) - for output, target in zip(flat_outputs, flat_targets) - ], - dim=1, - ).mean(dim=1) - return losses + loss_tensors = [ + mse_loss(output, target, reduction="none") + for output, target in zip(flat_outputs, flat_targets) + ] + + return loss_tensors return mse_loss_fn +def reduce_to_first_tensor(loss_tensors: list[Tensor]) -> Tensor: + return loss_tensors[0] + + +def reduce_to_matrix(loss_tensors: list[Tensor]) -> Tensor: + return torch.concat([reshape_raw_losses(t) for t in loss_tensors], dim=1) + + +def reduce_to_vector(loss_tensors: list[Tensor]) -> Tensor: + return reduce_to_matrix(loss_tensors).mean(dim=1) + + +def reduce_to_scalar(loss_tensors: list[Tensor]) -> Tensor: + return reduce_to_matrix(loss_tensors).mean() + + def reshape_raw_losses(raw_losses: Tensor) -> Tensor: assert raw_losses.ndim > 0 From ad0fb1d43954324b609388596b9aabbc22615ed5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 23 Sep 2025 19:22:18 +0200 Subject: [PATCH 068/114] Add test_compute_gramian_various_output_shapes --- tests/unit/autogram/test_engine.py | 74 +++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 945c43b5..c69844b4 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -1,10 +1,11 @@ +from collections.abc import Callable from itertools import combinations from math import prod import pytest import torch from pytest import mark, param -from torch import nn +from torch import Tensor, nn from torch.nn import Linear from torch.optim import SGD from torch.testing import assert_close @@ -58,6 +59,9 @@ autogram_forward_backward, compute_gramian_with_autograd, make_mse_loss_fn, + reduce_to_first_tensor, + reduce_to_matrix, + reduce_to_scalar, reduce_to_vector, ) from utils.tensors import make_tensors, ones_, randn_, zeros_ @@ -142,6 +146,74 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=1e-5) +@mark.parametrize("batch_size", [1, 3, 16]) +@mark.parametrize( + ["reduction", "movedim_source", "movedim_destination", "batch_dim"], + [ + # 0D + (reduce_to_scalar, [], [], None), # () + # 1D + (reduce_to_vector, [], [], 0), # (batch_size,) + (reduce_to_vector, [], [], None), # (batch_size,) + # 2D + (reduce_to_matrix, [], [], 0), # (batch_size, d1 * d2) + (reduce_to_matrix, [], [], None), # (batch_size, d1 * d2) + (reduce_to_matrix, [0], [1], 1), # (d1 * d2, batch_size) + (reduce_to_matrix, [0], [1], None), # (d1 * d2, batch_size) + # 3D + (reduce_to_first_tensor, [], [], 0), # (batch_size, d1, d2) + (reduce_to_first_tensor, [], [], None), # (batch_size, d1, d2) + (reduce_to_first_tensor, [0], [1], 1), # (d1, batch_size, d2) + (reduce_to_first_tensor, [0], [1], None), # (d1, batch_size, d2) + (reduce_to_first_tensor, [0], [2], 2), # (d2, d1, batch_size) + (reduce_to_first_tensor, [0], [2], None), # (d2, d1, batch_size) + ], +) +def test_compute_gramian_various_output_shapes( + batch_size: int | None, + reduction: Callable[[list[Tensor]], Tensor], + batch_dim: int | None, + movedim_source: list[int], + movedim_destination: list[int], +): + """ + Tests that the autograd and the autogram engines compute the same gramian when the output can + have various different shapes, and can be batched in any of its dimensions. + """ + + architecture = Ndim2Output + input_shapes = architecture.INPUT_SHAPES + output_shapes = architecture.OUTPUT_SHAPES + + torch.manual_seed(0) + model_autograd = architecture().to(device=DEVICE) + torch.manual_seed(0) + model_autogram = architecture().to(device=DEVICE) + + engine = Engine(model_autogram.modules(), batch_dim=batch_dim) + + inputs = make_tensors(batch_size, input_shapes) + targets = make_tensors(batch_size, output_shapes) + loss_fn = make_mse_loss_fn(targets) + + torch.random.manual_seed(0) # Fix randomness for random models + output = model_autograd(inputs) + losses = reduction(loss_fn(output)) + reshaped_losses = torch.movedim(losses, movedim_source, movedim_destination) + # Go back to a vector so that compute_gramian_with_autograd works + loss_vector = reshaped_losses.reshape([-1]) + autograd_gramian = compute_gramian_with_autograd(loss_vector, list(model_autograd.parameters())) + expected_gramian = reshape_gramian(autograd_gramian, list(reshaped_losses.shape)) + + torch.random.manual_seed(0) # Fix randomness for random models + output = model_autogram(inputs) + losses = reduction(loss_fn(output)) + reshaped_losses = torch.movedim(losses, movedim_source, movedim_destination) + autogram_gramian = engine.compute_gramian(reshaped_losses) + + assert_close(autogram_gramian, expected_gramian, rtol=1e-4, atol=1e-5) + + def _non_empty_subsets(elements: set) -> list[set]: """ Generates the list of subsets of the given set, excluding the empty set. From a4ca5f81384317c2bda3f8fc6e9d3086c9902e8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 23 Sep 2025 19:25:20 +0200 Subject: [PATCH 069/114] Move compute_gramian to test utils --- tests/unit/autogram/test_gramian_utils.py | 23 +++++++---------------- tests/utils/forward_backwards.py | 8 ++++++++ 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py index 31f8c213..fde35252 100644 --- a/tests/unit/autogram/test_gramian_utils.py +++ b/tests/unit/autogram/test_gramian_utils.py @@ -1,20 +1,11 @@ -import torch from pytest import mark -from torch import Tensor from torch.testing import assert_close +from utils.forward_backwards import compute_gramian from utils.tensors import randn_ from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian -def _compute_gramian(matrix: Tensor) -> Tensor: - """Contracts the last dimension of matrix to make it into a Gramian.""" - - indices = list(range(matrix.ndim)) - transposed_matrix = matrix.movedim(indices, indices[::-1]) - return torch.tensordot(matrix, transposed_matrix, dims=([-1], [0])) - - @mark.parametrize( ["original_shape", "target_shape"], [ @@ -35,13 +26,13 @@ def _compute_gramian(matrix: Tensor) -> Tensor: ], ) def test_reshape_gramian(original_shape: list[int], target_shape: list[int]): - """Tests that reshape_gramian is such that _compute_gramian is equivariant to a reshape.""" + """Tests that reshape_gramian is such that compute_gramian is equivariant to a reshape.""" original_matrix = randn_(original_shape + [2]) target_matrix = original_matrix.reshape(target_shape + [2]) - original_gramian = _compute_gramian(original_matrix) - target_gramian = _compute_gramian(target_matrix) + original_gramian = compute_gramian(original_matrix) + target_gramian = compute_gramian(target_matrix) reshaped_gramian = reshape_gramian(original_gramian, target_shape) @@ -67,13 +58,13 @@ def test_reshape_gramian(original_shape: list[int], target_shape: list[int]): ], ) def test_movedim_gramian(shape: list[int], source: list[int], destination: list[int]): - """Tests that movedim_gramian is such that _compute_gramian is equivariant to a movedim.""" + """Tests that movedim_gramian is such that compute_gramian is equivariant to a movedim.""" original_matrix = randn_(shape + [2]) target_matrix = original_matrix.movedim(source, destination) - original_gramian = _compute_gramian(original_matrix) - target_gramian = _compute_gramian(target_matrix) + original_gramian = compute_gramian(original_matrix) + target_gramian = compute_gramian(target_matrix) moveddim_gramian = movedim_gramian(original_gramian, source, destination) diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index bea15bfb..1eb07092 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -130,3 +130,11 @@ def get_vjp(grad_outputs: Tensor) -> list[Tensor]: gramian = sum([jacobian @ jacobian.T for jacobian in jacobian_matrices]) return gramian + + +def compute_gramian(matrix: Tensor) -> Tensor: + """Contracts the last dimension of matrix to make it into a Gramian.""" + + indices = list(range(matrix.ndim)) + transposed_matrix = matrix.movedim(indices, indices[::-1]) + return torch.tensordot(matrix, transposed_matrix, dims=([-1], [0])) From dfefb24e01b259b7a3554e5ef1cb19f117022b0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 23 Sep 2025 19:25:57 +0200 Subject: [PATCH 070/114] Revamp test_gramian_is_correct into test_compute_gramian_manual --- tests/unit/autogram/test_engine.py | 87 ++++++------------------------ 1 file changed, 15 insertions(+), 72 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index c69844b4..7c56168a 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -57,6 +57,7 @@ from utils.forward_backwards import ( autograd_forward_backward, autogram_forward_backward, + compute_gramian, compute_gramian_with_autograd, make_mse_loss_fn, reduce_to_first_tensor, @@ -353,88 +354,30 @@ def test_incompatible_modules(architecture: type[nn.Module], batch_dim: int | No _ = Engine(model.modules(), batch_dim=batch_dim) -@mark.parametrize("shape", [(1, 3), (7, 15), (27, 15)]) -@mark.parametrize("batch_size", [None, 3, 16, 32]) -@mark.parametrize("reduce_output", [True, False]) -def test_gramian_is_correct(shape: tuple[int, int], batch_size: int, reduce_output: bool): +def test_compute_gramian_manual(): """ Tests that the Gramian computed by the `Engine` equals to a manual computation of the expected Gramian. """ - is_batched = batch_size is not None + in_dims = 18 + out_dims = 25 - if is_batched: - batch_dim = 0 - input_dim = [batch_size, shape[0]] - else: - batch_dim = None - input_dim = [shape[0]] - - model = Linear(shape[0], shape[1]).to(device=DEVICE) - engine = Engine([model], batch_dim=batch_dim) + torch.manual_seed(0) + model = Linear(in_dims, out_dims).to(device=DEVICE) + engine = Engine(model.modules(), batch_dim=None) - input = randn_(input_dim) + input = randn_(in_dims) output = model(input) - if reduce_output: - output = torch.sum(output, dim=-1) - - assert output.ndim == int(not reduce_output) + int(is_batched) - gramian = engine.compute_gramian(output) - # compute the expected gramian - output_shape = list(output.shape) - initial_jacobian = torch.diag(ones_(output.numel())).reshape(output_shape + output_shape) - - if reduce_output: - initial_jacobian = initial_jacobian.unsqueeze(-1).repeat( - ([1] * initial_jacobian.ndim) + [shape[1]] - ) - if not is_batched: - initial_jacobian = initial_jacobian.unsqueeze(-2) - input = input.unsqueeze(0) - - assert initial_jacobian.shape[-2] == (1 if batch_size is None else batch_size) - assert initial_jacobian.shape[-1] == shape[1] - assert initial_jacobian.shape[:-2] == output.shape - - assert input.shape[0] == (1 if batch_size is None else batch_size) - assert input.shape[1] == shape[0] - - # If k is the batch_size (1 if None) and n the input size and m the output size, then - # - input has shape `[k, n]` - # - initial_jacobian has shape `output.shape + `[k, m]` - - # The partial (batched) jacobian of outputs w.r.t. weights is of shape `[k, m, m, n]`, whe - # multiplied (along 2 dims) by initial_jacobian this yields the jacobian of the weights of shape - # `output.shape + [m, n]`. The partial jacobian itself is block diagonal with diagonal defined - # by `partial_weight_jacobian[i, j, j] = input[i]` (other elements are 0). - - partial_weight_jacobian = zeros_([input.shape[0], shape[1], shape[1], shape[0]]) - for j in range(shape[1]): - partial_weight_jacobian[:, j, j, :] = input - weight_jacobian = torch.tensordot( - initial_jacobian, partial_weight_jacobian, dims=([-2, -1], [0, 1]) - ) - weight_gramian = torch.tensordot(weight_jacobian, weight_jacobian, dims=([-2, -1], [-2, -1])) - if weight_gramian.ndim == 4: - weight_gramian = weight_gramian.movedim((-2), (-1)) - - # The partial (batched) jacobian of outputs w.r.t. bias is of shape `[k, m, m]`, when multiplied - # (along 2 dims) by initial_jacobian this yields the jacobian of the bias of shape - # `output.shape + [m]`. The partial jacobian itself is block diagonal with diagonal defined by - # `partial_bias_jacobian[i, j, j] = 1` (other elements are 0). - partial_bias_jacobian = zeros_([input.shape[0], shape[1], shape[1]]) - for j in range(shape[1]): - partial_bias_jacobian[:, j, j] = 1.0 - bias_jacobian = torch.tensordot( - initial_jacobian, partial_bias_jacobian, dims=([-2, -1], [0, 1]) - ) - bias_gramian = torch.tensordot(bias_jacobian, bias_jacobian, dims=([-1], [-1])) - if bias_gramian.ndim == 4: - bias_gramian = bias_gramian.movedim(-2, -1) - + # Compute the expected gramian + weight_jacobian = zeros_([out_dims, model.weight.numel()]) + for j in range(out_dims): + weight_jacobian[j, j * in_dims : (j + 1) * in_dims] = input + weight_gramian = compute_gramian(weight_jacobian) + bias_jacobian = torch.diag(ones_(out_dims)) + bias_gramian = compute_gramian(bias_jacobian) expected_gramian = weight_gramian + bias_gramian assert_close(gramian, expected_gramian) From 3c8a217764628f3fc17893eca6d035a772550733 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 24 Sep 2025 01:27:36 +0200 Subject: [PATCH 071/114] Add link to IWMTL in main doc page --- docs/source/index.rst | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 8bf079b1..cbcafd9f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -38,8 +38,8 @@ per-task losses has to be minimized. To start using TorchJD for multi-task learn Another more interesting application is to consider separately the loss of each element in the batch. This is what we define as :doc:`Instance-Wise Risk Minimization ` (IWRM). -For IWRM, in many cases, there exists an algorithm that is both equivalent to Jacobian descent, and -much more efficient. This algorithm, called Gramian-based Jacobian descent, consists in computing +There exists an algorithm that is in many cases equivalent to Jacobian descent, and much more +efficient. This algorithm, called Gramian-based Jacobian descent, consists in computing the Gramian of the Jacobian iteratively during the backward pass (without ever storing the full Jacobian in memory), weighting the losses using the information of the Gramian, and then computing the gradient of the obtained weighted loss. The iterative computation of the Gramian corresponds to @@ -48,6 +48,11 @@ Algorithm 3 of documentation and usage example of this algorithm is provided in :doc:`autogram.Engine `. +The primary usage of the autogram engine is to compute the Gramian of the Jacobian very efficiently +for :doc:`IWRM `. It can also be used when considering one loss per element of the +batch and per task, in the context of multi-task learning. We call this :doc:`Instance-Wise Risk +Multi-Task Learning ` (IWMTL). + TorchJD is open-source, under MIT License. The source code is available on `GitHub `_. From 805292fee2ea900831274898c65727f583ebac2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 24 Sep 2025 01:38:32 +0200 Subject: [PATCH 072/114] Improve explanation about generalized weightings --- src/torchjd/aggregation/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 3a83e960..19dc060f 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -43,8 +43,9 @@ a tensor of weights (of shape ``[m_1, ..., m_k]``) from such a generalized Gramian. The simplest :class:`GeneralizedWeighting` is :class:`Flattening`: it simply "flattens" the -generalized Gramian into a square matrix, applies a normal weighting to it to obtain a vector of -weights, and returns the reshaped tensor of weights. +generalized Gramian into a square Gramian matrix (of shape ``[m_1 * ... * m_k, m_1 * ... * m_k]``), +applies a normal weighting to it to obtain a vector of weights, and returns the reshaped tensor of +weights. >>> from torch import ones >>> from torchjd.aggregation import Flattening, UPGradWeighting From c2c0988c66117151e7c478bd2b693307558f8029 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 24 Sep 2025 01:38:51 +0200 Subject: [PATCH 073/114] Add note in iwmtl --- docs/source/examples/iwmtl.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/source/examples/iwmtl.rst b/docs/source/examples/iwmtl.rst index c4665b4b..923da7b4 100644 --- a/docs/source/examples/iwmtl.rst +++ b/docs/source/examples/iwmtl.rst @@ -55,3 +55,11 @@ The following example shows how to do that. # Do the standard backward pass, but weighted using the obtained weights losses.backward(weights) optimizer.step() + +.. note:: + In this example, the tensor of losses is a matrix rather than a vector. The gramian is thus a + 4D tensor rather than a matrix, and a + :class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting`, such as + :class:`~torchjd.aggregation._flattening.Flattening`, has to be used to extract a matrix of + weights from it. More information about ``GeneralizedWeighting`` can be found in the + :doc:`../../docs/aggregation/index` page. From 0e5d0c3d707a1551d10f3df25593b8ed332bdead Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 24 Sep 2025 01:41:06 +0200 Subject: [PATCH 074/114] Emphasize generalized jacobian --- src/torchjd/autogram/_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index c660cc26..16fb950e 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -183,7 +183,7 @@ def compute_gramian(self, output: Tensor) -> Tensor: This function doesn't require ``output`` to be a vector. For example, if ``output`` is a matrix of shape :math:`[m_1, m_2]`, its Jacobian :math:`J` with respect to the parameters will be of shape :math:`[m_1, m_2, n]`, where :math:`n` is the number of - parameters in the model. This is what we call a generalized Jacobian. The + parameters in the model. This is what we call a `generalized Jacobian`. The corresponding Gramian :math:`G = J J^\top` will be of shape :math:`[m_1, m_2, m_2, m_1]`. This is what we call a `generalized Gramian`. The number of dimensions of the returned generalized Gramian will always be twice that of the From c1985ed57485bba073bfff1b97699ab03fa2bbde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 24 Sep 2025 01:58:28 +0200 Subject: [PATCH 075/114] Improve explanation of batch_dim in Engine --- src/torchjd/autogram/_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 16fb950e..6a487dae 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -61,7 +61,7 @@ class Engine: :param batch_dim: If the modules work with batches and process each batch element independently, then many intermediary jacobians are sparse (block-diagonal), which allows for a substancial memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates - the batch dimension, if any. Defaults to None. + the batch dimension of the output tensor, if any. Defaults to None. .. admonition:: Example From f3d8e0e89f19d47bc7d292ec85c64e22bdfd9dc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 24 Sep 2025 02:26:14 +0200 Subject: [PATCH 076/114] Rename compute_gramian to use_engine when it's a boolean --- tests/unit/autogram/test_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 7c56168a..a32832c9 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -294,10 +294,10 @@ def test_iwrm_steps_with_autogram( @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) -@mark.parametrize("compute_gramian", [False, True]) +@mark.parametrize("use_engine", [False, True]) @mark.parametrize("batch_dim", [0, None]) def test_autograd_while_modules_are_hooked( - architecture: type[ShapedModule], batch_size: int, compute_gramian: bool, batch_dim: int | None + architecture: type[ShapedModule], batch_size: int, use_engine: bool, batch_dim: int | None ): """ Tests that the hooks added when constructing the engine do not interfere with a simple autograd @@ -319,7 +319,7 @@ def test_autograd_while_modules_are_hooked( # Hook modules and optionally compute the Gramian engine = Engine(model_autogram.modules(), batch_dim=batch_dim) - if compute_gramian: + if use_engine: torch.manual_seed(0) # Fix randomness for random models output = model_autogram(input) losses = reduce_to_vector(loss_fn(output)) From c3ef149d689fa2c64ecedabe82063784bfca5c50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 24 Sep 2025 02:30:52 +0200 Subject: [PATCH 077/114] Extract _assert_gramian_is_equivalent_to_autograd from test_compute_gramian, and add test_compute_gramian_with_weird_modules * Now it's much cleaner: xfails are isolated in the second test, and tests are only expected to fail when batch_dim=0 for those modules. * This means that we no longer have xfail tests that pass. * The conclusion is that BatchNorm works when batch_dim=None, and so do Random modules and modules with side effects. This is good news. --- tests/unit/autogram/test_engine.py | 44 +++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index a32832c9..239c9eb0 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -96,7 +96,6 @@ (SomeFrozenParam, 32), (MultiOutputWithFrozenBranch, 32), (WithSomeFrozenModule, 32), - param(WithSideEffect, 32, marks=mark.xfail), (SomeUnusedOutput, 32), (Ndim0Output, 32), (Ndim1Output, 32), @@ -105,21 +104,18 @@ (Ndim4Output, 32), (FreeParam, 32), (NoFreeParam, 32), - param(Randomness, 32, marks=mark.xfail), - param(Cifar10Model, 16, marks=[mark.slow]), - param(AlexNet, 2, marks=[mark.slow]), - param(InstanceNormResNet18, 4, marks=[mark.slow]), - param(GroupNormMobileNetV3Small, 3, marks=[mark.slow]), - param(SqueezeNet, 8, marks=[mark.slow]), - param(InstanceNormMobileNetV2, 2, marks=[mark.slow]), + param(Cifar10Model, 16, marks=mark.slow), + param(AlexNet, 2, marks=mark.slow), + param(InstanceNormResNet18, 4, marks=mark.slow), + param(GroupNormMobileNetV3Small, 3, marks=mark.slow), + param(SqueezeNet, 8, marks=mark.slow), + param(InstanceNormMobileNetV2, 2, marks=mark.slow), ] -@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) -@mark.parametrize("batch_dim", [0, None]) -def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batch_dim: int | None): - """Tests that the autograd and the autogram engines compute the same gramian.""" - +def _assert_gramian_is_equivalent_to_autograd( + architecture: type[ShapedModule], batch_size: int, batch_dim: int | None +): input_shapes = architecture.INPUT_SHAPES output_shapes = architecture.OUTPUT_SHAPES @@ -147,6 +143,28 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=1e-5) +@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) +@mark.parametrize("batch_dim", [0, None]) +def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batch_dim: int | None): + """Tests that the autograd and the autogram engines compute the same gramian.""" + + _assert_gramian_is_equivalent_to_autograd(architecture, batch_size, batch_dim) + + +@mark.parametrize("architecture", [WithBatchNorm, WithSideEffect, Randomness]) +@mark.parametrize("batch_size", [1, 3, 32]) +@mark.parametrize("batch_dim", [param(0, marks=mark.xfail), None]) +def test_compute_gramian_with_weird_modules( + architecture: type[ShapedModule], batch_size: int, batch_dim: int | None +): + """ + Tests that compute_gramian works even with some problematic modules when batch_dim is None. It + is expected to fail on those when the engine uses the batched optimization (when batch_dim=0). + """ + + _assert_gramian_is_equivalent_to_autograd(architecture, batch_size, batch_dim) + + @mark.parametrize("batch_size", [1, 3, 16]) @mark.parametrize( ["reduction", "movedim_source", "movedim_destination", "batch_dim"], From 54dcbe009e1849b75f28edcb9cf577413dbecfd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 24 Sep 2025 02:41:04 +0200 Subject: [PATCH 078/114] Stop making track_running_stats=True incompatible with autogram when batch_dim is None. --- src/torchjd/autogram/_engine.py | 38 +++++++++++++++++---------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 6a487dae..7228b714 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -10,7 +10,7 @@ from ._gramian_utils import movedim_gramian, reshape_gramian from ._module_hook_manager import ModuleHookManager -_INCOMPATIBLE_MODULE_TYPES = ( +_MODULES_INCOMPATIBLE_WITH_BATCHED = ( nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, @@ -156,23 +156,25 @@ def _hook_modules(self, modules: Iterable[nn.Module]) -> None: self._module_hook_manager.hook_module(module) def _check_module_is_compatible(self, module: nn.Module) -> None: - if self._batch_dim is not None and isinstance(module, _INCOMPATIBLE_MODULE_TYPES): - raise ValueError( - f"Found a module of type {type(module)}, which is incompatible with the autogram " - f"engine when `batch_dim` is not `None`. The incompatible module types are " - f"{_INCOMPATIBLE_MODULE_TYPES} (and their subclasses). The recommended fix is to " - f"replace incompatible layers by something else (e.g. BatchNorm by InstanceNorm), " - f"but if you really can't and performance not a priority, you may also just set" - f"`batch_dim=None` when creating the engine." - ) - - if isinstance(module, _TRACK_RUNNING_STATS_MODULE_TYPES) and module.track_running_stats: - raise ValueError( - f"Found a module of type {type(module)}, with `track_running_stats=True`, which is " - "incompatible with the autogram engine due to performing in-place operations on " - "tensors and having side-effects during the forward pass. Try setting " - "`track_running_stats` to `False`." - ) + if self._batch_dim is not None: + if isinstance(module, _MODULES_INCOMPATIBLE_WITH_BATCHED): + raise ValueError( + f"Found a module of type {type(module)}, which is incompatible with the " + f"autogram engine when `batch_dim` is not `None`. The incompatible module types" + f" are {_MODULES_INCOMPATIBLE_WITH_BATCHED} (and their subclasses). The " + f"recommended fix is to replace incompatible layers by something else (e.g. " + f"BatchNorm by InstanceNorm). If you really can't and performance not a " + f"priority, you may also just set `batch_dim=None` when creating the engine." + ) + if isinstance(module, _TRACK_RUNNING_STATS_MODULE_TYPES) and module.track_running_stats: + raise ValueError( + f"Found a module of type {type(module)}, with `track_running_stats=True`, which" + f" is incompatible with the autogram engine when `batch_dim` is not `None`, due" + f" to performing in-place operations on tensors and having side-effects during " + f"the forward pass. Try setting `track_running_stats` to `False`. If you really" + f" can't and performance not a priority, you may also just set `batch_dim=None`" + f" when creating the engine." + ) def compute_gramian(self, output: Tensor) -> Tensor: r""" From 7bbf3a8d784ca59a3acbb130a392305db21b1818 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 24 Sep 2025 02:41:32 +0200 Subject: [PATCH 079/114] Test WithModuleTrackingRunningStats as a module that should work when batch_dim is None --- tests/unit/autogram/test_engine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 239c9eb0..868ef62b 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -151,7 +151,9 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc _assert_gramian_is_equivalent_to_autograd(architecture, batch_size, batch_dim) -@mark.parametrize("architecture", [WithBatchNorm, WithSideEffect, Randomness]) +@mark.parametrize( + "architecture", [WithBatchNorm, WithSideEffect, Randomness, WithModuleTrackingRunningStats] +) @mark.parametrize("batch_size", [1, 3, 32]) @mark.parametrize("batch_dim", [param(0, marks=mark.xfail), None]) def test_compute_gramian_with_weird_modules( From d627bcd697f3b7bbe737a372002ecf9e612caa1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 24 Sep 2025 02:47:03 +0200 Subject: [PATCH 080/114] Improve explanation of the limitations of the engine --- src/torchjd/autogram/_engine.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 7228b714..f0f3f931 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -97,7 +97,8 @@ class Engine: >>> optimizer.step() .. warning:: - To use this engine, the modules should respect a few conditions: + When providing an non-None ``batch_dim``, all provided modules must respect a few + conditions: * They should treat the elements of the batch independently. Most common layers respect this, but for example `BatchNorm @@ -118,10 +119,6 @@ class Engine: trainable parameters. It is, however, perfectly fine for random modules to have child modules that have trainable parameters, so if you have a random module with some direct parameters, a simple fix is to wrap these parameters into a child module. - * For maximum efficiency, they should ideally not contain both direct trainable parameters - and child modules, especially if those direct trainable parameters are used before the - child modules. You can always wrap those direct trainable parameters into another child - module to avoid the slow-down. If you're building your own architecture, respecting those criterions should be quite easy. However, if you're using an existing architecture, you may have to modify it to make it @@ -130,6 +127,15 @@ class Engine: `GroupNorm `_ or `InstanceNorm2d `_ layers. + + The alternative is to use ``batch_dim=None``, but it's not recommended since it will + increase computation time, often by a lot. + + .. note:: + For maximum efficiency, modules should ideally not contain both direct trainable + parameters and child modules, especially if those direct trainable parameters are used + before the child modules. You can always wrap those direct trainable parameters into + another child module to avoid the slow-down. """ def __init__( From de521eb5a5ef733343b38e126002a7f807f9f230 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 24 Sep 2025 02:47:53 +0200 Subject: [PATCH 081/114] Fix test_incompatible_modules --- tests/unit/autogram/test_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 868ef62b..a753dadb 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -360,7 +360,6 @@ def test_autograd_while_modules_are_hooked( ["architecture", "batch_dim"], [ (WithModuleTrackingRunningStats, 0), - (WithModuleTrackingRunningStats, None), (WithRNN, 0), (WithBatchNorm, 0), ], From 493e40b288cbfeacc5554f5f91c71e6d3bb53920 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Wed, 24 Sep 2025 12:46:43 +0200 Subject: [PATCH 082/114] Update src/torchjd/autogram/_gramian_utils.py Co-authored-by: Pierre Quinton --- src/torchjd/autogram/_gramian_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index 10929741..fe7a26e6 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -32,9 +32,9 @@ def _revert_last_dims(gramian: Tensor) -> Tensor: def movedim_gramian(gramian: Tensor, half_source: list[int], half_destination: list[int]) -> Tensor: """ - Moves the dimensions of a Gramian from some source dimensions to destination dimensions. As a - Gramian is quadratic form, moving dimension must be done simultaneously on the first half of the - dimensions and on the second half of the dimensions reversed. + Moves the dimensions of a Gramian from some source dimensions to destination dimensions. This + must be done simultaneously on the first half of the dimensions and on the second half of the + dimensions reversed. :param gramian: Gramian to reshape. Can be a generalized Gramian. :param half_source: Source dimensions, that should be in the range [-gramian.ndim//2, From 3d1049b0914ae5a1c1d5de73c938a21431419b6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Wed, 24 Sep 2025 12:48:18 +0200 Subject: [PATCH 083/114] Update docs/source/examples/iwmtl.rst Co-authored-by: Pierre Quinton --- docs/source/examples/iwmtl.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/examples/iwmtl.rst b/docs/source/examples/iwmtl.rst index 923da7b4..ce075b06 100644 --- a/docs/source/examples/iwmtl.rst +++ b/docs/source/examples/iwmtl.rst @@ -2,7 +2,7 @@ Instance-Wise Multi-Task Learning (IWMTL) ========================================= When training a model with multiple tasks, the gradients of the individual tasks are likely to -sometimes conflict. This is particularly true when looking at the individual (per-sample) gradients. +conflict. This is particularly true when looking at the individual (per-sample) gradients. The :doc:`autogram engine <../docs/autogram/engine>` can be used to efficiently compute the Gramian of the Jacobian of the matrix of per-sample and per-task losses. Weights can then be extracted from this Gramian to reweight the gradients and resolve conflict entirely. From be7e6ac02eba5a11ef4bcd48e8a03a7309ad9f89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Wed, 24 Sep 2025 12:50:08 +0200 Subject: [PATCH 084/114] Update docs/source/index.rst Co-authored-by: Pierre Quinton --- docs/source/index.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index cbcafd9f..d9e07e4d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -38,8 +38,8 @@ per-task losses has to be minimized. To start using TorchJD for multi-task learn Another more interesting application is to consider separately the loss of each element in the batch. This is what we define as :doc:`Instance-Wise Risk Minimization ` (IWRM). -There exists an algorithm that is in many cases equivalent to Jacobian descent, and much more -efficient. This algorithm, called Gramian-based Jacobian descent, consists in computing +The Gramian-based Jacobian descent algorithm provides a very efficient alternative way of +performing Jacobian descent. It consists in computing the Gramian of the Jacobian iteratively during the backward pass (without ever storing the full Jacobian in memory), weighting the losses using the information of the Gramian, and then computing the gradient of the obtained weighted loss. The iterative computation of the Gramian corresponds to From 629ea6007e31b422de47386b8f018f3ddbef1a2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Wed, 24 Sep 2025 12:52:14 +0200 Subject: [PATCH 085/114] Update src/torchjd/autogram/_engine.py Co-authored-by: Pierre Quinton --- src/torchjd/autogram/_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index f0f3f931..35f43632 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -272,7 +272,7 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: # Need to batch `grad_output` over the first dimension jac_output = torch.zeros(jac_output_shape, device=output.device, dtype=output.dtype) for i in range(non_batch_dim_len): - jac_output[i, i, ...] = 1 + jac_output[i, i, ...] = 1.0 _ = vmap(differentiation)(jac_output) else: From 54005d34c9cd7be6f02eae6cf2e5ccf93e81be8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Wed, 24 Sep 2025 12:52:28 +0200 Subject: [PATCH 086/114] Update src/torchjd/autogram/_engine.py Co-authored-by: Pierre Quinton --- src/torchjd/autogram/_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 35f43632..781b3133 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -269,7 +269,6 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: non_batch_dim_len = output.shape[0] jac_output_shape = [output.shape[0]] + list(output.shape) - # Need to batch `grad_output` over the first dimension jac_output = torch.zeros(jac_output_shape, device=output.device, dtype=output.dtype) for i in range(non_batch_dim_len): jac_output[i, i, ...] = 1.0 From e831674707acf28e649eea0d6344cd1e7cda29b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Wed, 24 Sep 2025 12:55:50 +0200 Subject: [PATCH 087/114] Update src/torchjd/autogram/_gramian_utils.py Co-authored-by: Pierre Quinton --- src/torchjd/autogram/_gramian_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index fe7a26e6..631ea192 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -3,9 +3,8 @@ def reshape_gramian(gramian: Tensor, half_shape: list[int]) -> Tensor: """ - Reshapes a Gramian to a provided shape. As a Gramian is quadratic form, the reshape of the first - half of the target dimensions must be done from the left, while the reshape of the second half - must be done from the right. + Reshapes a Gramian to a provided shape. The reshape of the first half of the target dimensions must + be done from the left, while the reshape of the second half must be done from the right. :param gramian: Gramian to reshape. Can be a generalized Gramian. :param half_shape: First half of the target shape, the shape of the output is therefore From 94f9c9892f78f5d9fddac15a3b61aa08ed212fea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Wed, 24 Sep 2025 12:56:38 +0200 Subject: [PATCH 088/114] Update docs/source/index.rst Co-authored-by: Pierre Quinton --- docs/source/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index d9e07e4d..942edfb9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -48,7 +48,7 @@ Algorithm 3 of documentation and usage example of this algorithm is provided in :doc:`autogram.Engine `. -The primary usage of the autogram engine is to compute the Gramian of the Jacobian very efficiently +The original usage of the autogram engine is to compute the Gramian of the Jacobian very efficiently for :doc:`IWRM `. It can also be used when considering one loss per element of the batch and per task, in the context of multi-task learning. We call this :doc:`Instance-Wise Risk Multi-Task Learning ` (IWMTL). From 587bd0e5370a99ef0b7fd8b3268a9d5936e7dbf1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 24 Sep 2025 13:56:11 +0200 Subject: [PATCH 089/114] Improve formulation in Engine docstring --- src/torchjd/autogram/_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 781b3133..0d9be9dd 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -129,7 +129,7 @@ class Engine: `_ layers. The alternative is to use ``batch_dim=None``, but it's not recommended since it will - increase computation time, often by a lot. + increase memory usage by a lot and thus typically slow down computation. .. note:: For maximum efficiency, modules should ideally not contain both direct trainable From 149f842b3cad2a10e3bf8964a6d6315f64c3fd7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 24 Sep 2025 13:57:57 +0200 Subject: [PATCH 090/114] Move param description before the note --- src/torchjd/autogram/_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 0d9be9dd..28f61910 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -187,6 +187,9 @@ def compute_gramian(self, output: Tensor) -> Tensor: Computes the Gramian of the Jacobian of ``output`` with respect to the direct parameters of all ``modules``. + :param output: The tensor of arbitrary shape to differentiate. The shape of the returned + Gramian depends on the shape of this output. + .. note:: This function doesn't require ``output`` to be a vector. For example, if ``output`` is a matrix of shape :math:`[m_1, m_2]`, its Jacobian :math:`J` with respect to the @@ -206,9 +209,6 @@ def compute_gramian(self, output: Tensor) -> Tensor: Multi-Task Learning (IWMTL) <../../examples/iwmtl>`, as each sample in the batch has one loss per task). - etc. - - :param output: The tensor of arbitrary shape to differentiate. The shape of the returned - Gramian depends on the shape of this output, as explained in the note above. """ if self._batch_dim is not None: From 6fb284efcafe168a4e6e10a3f16cbfa3748c356e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 24 Sep 2025 14:04:11 +0200 Subject: [PATCH 091/114] Use a single loop over module.named_parameters instead of 3 * making a dict out of the iterator counts as 1 already. --- src/torchjd/autogram/_vjp.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 212021a1..967ab753 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -27,9 +27,14 @@ class VJP(ABC): def __init__(self, module: nn.Module): self.module = module - named_parameters = dict(module.named_parameters(recurse=False)) - self.trainable_params = {k: v for k, v in named_parameters.items() if v.requires_grad} - self.frozen_params = {k: v for k, v in named_parameters.items() if not v.requires_grad} + self.trainable_params = dict[str, Parameter]() + self.frozen_params = dict[str, Parameter]() + + for name, param in module.named_parameters(recurse=False): + if param.requires_grad: + self.trainable_params[name] = param + else: + self.frozen_params[name] = param @abstractmethod def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]: From 7ab3ae899587cea579f11eda4f02a36d298d875e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 00:08:31 +0200 Subject: [PATCH 092/114] Update changelog.md --- CHANGELOG.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index de73934b..033f8f5f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,8 @@ changes that do not affect the user. - Added the `autogram` package, with the `autogram.Engine`. This is an implementation of Algorithm 3 from [Jacobian Descent for Multi-Objective Optimization](https://arxiv.org/pdf/2406.16232), - optimized for batched computations, as in IWRM. + optimized for batched computations, as in IWRM. Generalized Gramians can also be obtained by using + the autogram engine on a tensor of losses of arbitrary shape. - For all `Aggregator`s based on the weighting of the Gramian of the Jacobian, made their `Weighting` class public. It can be used directly on a Gramian (computed via the `autogram.Engine`) to extract some weights. The list of new public classes is: @@ -29,8 +30,11 @@ changes that do not affect the user. - `PCGradWeighting` - `RandomWeighting` - `SumWeighting` +- Added `GeneralizedWeighting` (base class) and `Flattening` (implementation) to extract tensors of + weights from generalized Gramians. - Added usage example for IWRM with autogram. - Added usage example for IWRM with partial autogram. +- Added usage example for IWMTL with autogram. ### Changed From 95beff5e556b912a9b0561fa40ad9767a4fda79f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 00:48:13 +0200 Subject: [PATCH 093/114] Revamp documentation of Engine --- src/torchjd/autogram/_engine.py | 92 +++++++++++++++++++-------------- tests/doc/test_autogram.py | 4 +- 2 files changed, 57 insertions(+), 39 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 28f61910..73e84f76 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -45,56 +45,72 @@ class Engine: """ - Used for computing the Gramian of the Jacobian of some vector with respect to the direct - parameters of all provided modules. - - After this object is constructed, the outputs of the provided modules will have an extended - computation graph that allows to compute efficiently the Gramian of the Jacobian of the - per-sample losses with respect to the model parameters. - - This Gramian can then be used to extract weights from this Gramian using the provided - ``weighting`` which in turn can be backpropagated for a normal backward pass. This is the - reverse Gramian accumulation algorithm. + Engine to compute the Gramian of the Jacobian of some tensor with respect to the direct + parameters of all provided modules. It is based on Algorithm 3 of `Jacobian Descent For + Multi-Objective Optimization `_ but goes even further: + + * It works for any computation graph (not just sequential models). + * It is optimized for batched computations (as long as ``batch_dim`` is specified). + * It supports any shape of tensor to differentiate (not just a vector of losses). For more + details about this, look at :meth:`Engine.compute_gramian`. + + As explained in Section 6 of `Jacobian Descent For Multi-Objective Optimization + `_, most :class:`Aggregators + ` combine the rows of the Jacobian using some + weights that depend only on the Gramian of the Jacobian. Because of that, the typical usage of + the autogram engine is to directly compute this Gramian, extract weights from it (with a + :class:`~torchjd.aggregation._weighting_bases.Weighting`), and use those weights to + backpropagate the losses. This is equivalent to doing a step of standard Jacobian descent using + :func:`torchjd.autojac.backward`. :param modules: A collection of modules whose direct (non-recursive) parameters will contribute to the Gramian of the Jacobian. :param batch_dim: If the modules work with batches and process each batch element independently, then many intermediary jacobians are sparse (block-diagonal), which allows for a substancial memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates - the batch dimension of the output tensor, if any. Defaults to None. + the batch dimension of the output tensor, if any. Defaults to ``None``. .. admonition:: Example Train a model using Gramian-based Jacobian descent. - >>> import torch - >>> from torch.nn import Linear, MSELoss, ReLU, Sequential - >>> from torch.optim import SGD - >>> - >>> from torchjd.aggregation import UPGradWeighting - >>> from torchjd.autogram import Engine - >>> - >>> # Generate data (8 batches of 16 examples of dim 5) for the sake of the example - >>> inputs = torch.randn(8, 16, 5) - >>> targets = torch.randn(8, 16) - >>> - >>> model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) - >>> optimizer = SGD(model.parameters()) - >>> - >>> criterion = MSELoss(reduction="none") - >>> weighting = UPGradWeighting() - >>> engine = Engine(model.modules(), batch_dim=0) - >>> - >>> for input, target in zip(inputs, targets): - >>> output = model(input).squeeze(dim=1) # shape: [16] - >>> losses = criterion(output, target) # shape: [16] - >>> - >>> optimizer.zero_grad() - >>> gramian = engine.compute_gramian(losses) # shape: [16, 16] - >>> weights = weighting(gramian) # shape: [16] - >>> losses.backward(weights) - >>> optimizer.step() + .. code-block:: python + :emphasize-lines: 5-6, 15-16, 18-19, 26-28 + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import UPGradWeighting + from torchjd.autogram import Engine + + # Generate data (8 batches of 16 examples of dim 5) for the sake of the example + inputs = torch.randn(8, 16, 5) + targets = torch.randn(8, 16) + + model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) + optimizer = SGD(model.parameters()) + + criterion = MSELoss(reduction="none") # Important to use reduction="none" + weighting = UPGradWeighting() + + # Create the engine before the backward pass, and only once. + engine = Engine(model.modules(), batch_dim=0) + + for input, target in zip(inputs, targets): + output = model(input).squeeze(dim=1) # shape: [16] + losses = criterion(output, target) # shape: [16] + + optimizer.zero_grad() + gramian = engine.compute_gramian(losses) # shape: [16, 16] + weights = weighting(gramian) # shape: [16] + losses.backward(weights) + optimizer.step() + + This is equivalent to just calling ``torchjd.autojac.backward(losses)``. However, since the + Jacobian never has to be entirely in memory, it is often much more memory-efficient, and + thus typically faster, to use the Gramian-based approach. .. warning:: When providing an non-None ``batch_dim``, all provided modules must respect a few diff --git a/tests/doc/test_autogram.py b/tests/doc/test_autogram.py index 815ef315..06bc0589 100644 --- a/tests/doc/test_autogram.py +++ b/tests/doc/test_autogram.py @@ -16,8 +16,10 @@ def test_engine(): model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1)) optimizer = SGD(model.parameters()) - criterion = MSELoss(reduction="none") + criterion = MSELoss(reduction="none") # Important to use reduction="none" weighting = UPGradWeighting() + + # Create the engine before the backward pass, and only once. engine = Engine(model.modules(), batch_dim=0) for input, target in zip(inputs, targets): From f9c387966b7a25a9a687d4c42eedee6b648b2473 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 01:36:16 +0200 Subject: [PATCH 094/114] Move test_iwmtl --- tests/doc/test_rst.py | 92 +++++++++++++++++++++---------------------- 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index adc9b2d4..db4df51e 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -72,6 +72,52 @@ def test_basic_usage(): optimizer.step() +def test_iwmtl(): + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import Flattening, UPGradWeighting + from torchjd.autogram import Engine + + shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) + task1_module = Linear(3, 1) + task2_module = Linear(3, 1) + params = [ + *shared_module.parameters(), + *task1_module.parameters(), + *task2_module.parameters(), + ] + + optimizer = SGD(params, lr=0.1) + mse = MSELoss(reduction="none") + weighting = Flattening(UPGradWeighting()) + engine = Engine(shared_module.modules(), batch_dim=0) + + inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 + task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task + task2_targets = torch.randn(8, 16) # 8 batches of 16 targets for the second task + + for input, target1, target2 in zip(inputs, task1_targets, task2_targets): + features = shared_module(input) # shape: [16, 3] + out1 = task1_module(features).squeeze(1) # shape: [16] + out2 = task2_module(features).squeeze(1) # shape: [16] + + # Compute the matrix of losses: one loss per element of the batch and per task + losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1) # shape: [16, 2] + + # Compute the gramian (inner products between pairs of gradients of the losses) + gramian = engine.compute_gramian(losses) # shape: [16, 2, 2, 16] + + # Obtain the weights that lead to no conflict between reweighted gradients + weights = weighting(gramian) # shape: [16, 2] + + optimizer.zero_grad() + # Do the standard backward pass, but weighted using the obtained weights + losses.backward(weights) + optimizer.step() + + def test_iwrm(): def test_autograd(): import torch @@ -307,52 +353,6 @@ def test_mtl(): optimizer.step() -def test_iwmtl(): - import torch - from torch.nn import Linear, MSELoss, ReLU, Sequential - from torch.optim import SGD - - from torchjd.aggregation import Flattening, UPGradWeighting - from torchjd.autogram import Engine - - shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) - task1_module = Linear(3, 1) - task2_module = Linear(3, 1) - params = [ - *shared_module.parameters(), - *task1_module.parameters(), - *task2_module.parameters(), - ] - - optimizer = SGD(params, lr=0.1) - mse = MSELoss(reduction="none") - weighting = Flattening(UPGradWeighting()) - engine = Engine(shared_module.modules(), batch_dim=0) - - inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 - task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task - task2_targets = torch.randn(8, 16) # 8 batches of 16 targets for the second task - - for input, target1, target2 in zip(inputs, task1_targets, task2_targets): - features = shared_module(input) # shape: [16, 3] - out1 = task1_module(features).squeeze(1) # shape: [16] - out2 = task2_module(features).squeeze(1) # shape: [16] - - # Compute the matrix of losses: one loss per element of the batch and per task - losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1) # shape: [16, 2] - - # Compute the gramian (inner products between pairs of gradients of the losses) - gramian = engine.compute_gramian(losses) # shape: [16, 2, 2, 16] - - # Obtain the weights that lead to no conflict between reweighted gradients - weights = weighting(gramian) # shape: [16, 2] - - optimizer.zero_grad() - # Do the standard backward pass, but weighted using the obtained weights - losses.backward(weights) - optimizer.step() - - def test_partial_jd(): import torch from torch.nn import Linear, MSELoss, ReLU, Sequential From a0493da45205785fb96dc707e04834236c4de190 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 01:50:03 +0200 Subject: [PATCH 095/114] Fix iwrm and partial_jd usage examples --- docs/source/examples/iwrm.rst | 2 +- docs/source/examples/partial_jd.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/examples/iwrm.rst b/docs/source/examples/iwrm.rst index 965b46b9..d25e0165 100644 --- a/docs/source/examples/iwrm.rst +++ b/docs/source/examples/iwrm.rst @@ -129,7 +129,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac params = model.parameters() optimizer = SGD(params, lr=0.1) weighting = UPGradWeighting() - engine = Engine(model.modules()) + engine = Engine(model.modules(), batch_dim=0) for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] diff --git a/docs/source/examples/partial_jd.rst b/docs/source/examples/partial_jd.rst index a2c93839..c8d9c781 100644 --- a/docs/source/examples/partial_jd.rst +++ b/docs/source/examples/partial_jd.rst @@ -33,7 +33,7 @@ first ``Linear`` layer, thereby reducing memory usage and computation time. # Create the autogram engine that will compute the Gramian of the # Jacobian with respect to the two last Linear layers' parameters. - engine = Engine(model[2:].modules()) + engine = Engine(model[2:].modules(), batch_dim=0) params = model.parameters() optimizer = SGD(params, lr=0.1) From fcea26d5584297c674a97f08be9271fe958acdbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 01:54:28 +0200 Subject: [PATCH 096/114] Be always specific about the value of batch_dim --- tests/unit/autogram/test_engine.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index a753dadb..9a000fe0 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -430,8 +430,8 @@ def test_reshape_equivariance(shape: list[int], batch_dim: int | None): output_size = prod(shape[1:]) model = Linear(input_size, output_size).to(device=DEVICE) - engine1 = Engine([model]) - engine2 = Engine([model]) + engine1 = Engine([model], batch_dim=None) + engine2 = Engine([model], batch_dim=None) input = randn_([input_size]) output = model(input) @@ -470,8 +470,8 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: output_size = prod(shape[1:]) model = Linear(input_size, output_size).to(device=DEVICE) - engine1 = Engine([model]) - engine2 = Engine([model]) + engine1 = Engine([model], batch_dim=None) + engine2 = Engine([model], batch_dim=None) input = randn_([input_size]) output = model(input).reshape(shape[1:]) From 5eaa33cd2d8f7674e911afe8ae3c80710f00db2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 01:55:09 +0200 Subject: [PATCH 097/114] Make batch_dim default to 0 --- src/torchjd/autogram/_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 73e84f76..7bd67953 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -68,7 +68,7 @@ class Engine: :param batch_dim: If the modules work with batches and process each batch element independently, then many intermediary jacobians are sparse (block-diagonal), which allows for a substancial memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates - the batch dimension of the output tensor, if any. Defaults to ``None``. + the batch dimension of the output tensor, if any. Defaults to 0. .. admonition:: Example @@ -157,7 +157,7 @@ class Engine: def __init__( self, modules: Iterable[nn.Module], - batch_dim: int | None = None, + batch_dim: int | None = 0, ): self._gramian_accumulator = GramianAccumulator() self._target_edges = EdgeRegistry() From 1c284a98e47e986e8afd49ae7daf8e4b013f57eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 14:47:47 +0200 Subject: [PATCH 098/114] Fix typos --- src/torchjd/autogram/_engine.py | 2 +- src/torchjd/autogram/_gramian_utils.py | 4 ++-- src/torchjd/autogram/_module_hook_manager.py | 2 +- src/torchjd/autogram/_vjp.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 7bd67953..04da6226 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -66,7 +66,7 @@ class Engine: :param modules: A collection of modules whose direct (non-recursive) parameters will contribute to the Gramian of the Jacobian. :param batch_dim: If the modules work with batches and process each batch element independently, - then many intermediary jacobians are sparse (block-diagonal), which allows for a substancial + then many intermediary Jacobians are sparse (block-diagonal), which allows for a substantial memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates the batch dimension of the output tensor, if any. Defaults to 0. diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index 631ea192..956f42bd 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -3,8 +3,8 @@ def reshape_gramian(gramian: Tensor, half_shape: list[int]) -> Tensor: """ - Reshapes a Gramian to a provided shape. The reshape of the first half of the target dimensions must - be done from the left, while the reshape of the second half must be done from the right. + Reshapes a Gramian to a provided shape. The reshape of the first half of the target dimensions + must be done from the left, while the reshape of the second half must be done from the right. :param gramian: Gramian to reshape. Can be a generalized Gramian. :param half_shape: First half of the target shape, the shape of the output is therefore diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index eb08be2c..72a3f066 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -225,7 +225,7 @@ def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree: self.gramian_accumulator.track_parameter_paths(requires_grad_params) # We only care about running the JacobianAccumulator node, so we need one of its child - # edges (the edges of the original ouputs of the model) as target. For memory + # edges (the edges of the original outputs of the model) as target. For memory # efficiency, we select the smallest one (that requires grad). inf = float("inf") preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs]) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 967ab753..bb07e275 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -48,7 +48,7 @@ def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]: class FunctionalVJP(VJP): """ Represents a VJP function for a module's forward pass with respect to its parameters using the - func api. The __call__ function takes both the inputs and the cotangents that can be vmaped + func api. The __call__ function takes both the inputs and the cotangents that can be vmapped jointly in both terms to avoid providing to block diagonal jacobians. The disadvantage of using this method is that it computes the forward phase. From 2b9ef727e1f00f0681eb469e2fcba163a21a4592 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 15:10:56 +0200 Subject: [PATCH 099/114] Remove default value of batch_dim --- src/torchjd/autogram/_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 04da6226..bdfc5d93 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -68,7 +68,7 @@ class Engine: :param batch_dim: If the modules work with batches and process each batch element independently, then many intermediary Jacobians are sparse (block-diagonal), which allows for a substantial memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates - the batch dimension of the output tensor, if any. Defaults to 0. + the batch dimension of the output tensor, if any. .. admonition:: Example @@ -157,7 +157,7 @@ class Engine: def __init__( self, modules: Iterable[nn.Module], - batch_dim: int | None = 0, + batch_dim: int | None, ): self._gramian_accumulator = GramianAccumulator() self._target_edges = EdgeRegistry() From b0c92cc8cdd80a92b5c865195d67030477f1eee8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 15:43:47 +0200 Subject: [PATCH 100/114] Improve documentation of Flattening --- src/torchjd/aggregation/_flattening.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/torchjd/aggregation/_flattening.py b/src/torchjd/aggregation/_flattening.py index 558d428c..74da4a97 100644 --- a/src/torchjd/aggregation/_flattening.py +++ b/src/torchjd/aggregation/_flattening.py @@ -8,11 +8,16 @@ class Flattening(GeneralizedWeighting): """ - :class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting` flattening the Gramian, - extracting a vector of weights from it using a + :class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting` flattening the generalized + Gramian into a square matrix, extracting a vector of weights from it using a :class:`~torchjd.aggregation._weighting_bases.Weighting`, and returning the reshaped tensor of weights. + For instance, when applied to a generalized Gramian of shape ``[2, 3, 3, 2]``, it would flatten + it into a square Gramian matrix of shape ``[6, 6]``, apply the weighting on it to get a vector + of weights of shape ``[6]``, and then return this vector reshaped into a matrix of shape + ``[2, 3]``. + :param weighting: The weighting to apply to the Gramian matrix. """ From fa0b9e1c0a171f1744e4a22ecf23f34dfcaf605e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 16:01:51 +0200 Subject: [PATCH 101/114] Improve formulation in main documentation page * This implies that there are other usages for the autogram engine, yet to be discovered. --- docs/source/index.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 942edfb9..6da0f4d5 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -49,9 +49,9 @@ documentation and usage example of this algorithm is provided in :doc:`autogram.Engine `. The original usage of the autogram engine is to compute the Gramian of the Jacobian very efficiently -for :doc:`IWRM `. It can also be used when considering one loss per element of the -batch and per task, in the context of multi-task learning. We call this :doc:`Instance-Wise Risk -Multi-Task Learning ` (IWMTL). +for :doc:`IWRM `. Another direct application is when considering one loss per element +of the batch and per task, in the context of multi-task learning. We call this +:doc:`Instance-Wise Risk Multi-Task Learning ` (IWMTL). TorchJD is open-source, under MIT License. The source code is available on `GitHub `_. From 7c3e9ab30f3ce5e9bb479538413decfb11c46bba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 16:11:10 +0200 Subject: [PATCH 102/114] Fix mistake in Engine docstring --- src/torchjd/autogram/_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index bdfc5d93..84001aa5 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -108,9 +108,9 @@ class Engine: losses.backward(weights) optimizer.step() - This is equivalent to just calling ``torchjd.autojac.backward(losses)``. However, since the - Jacobian never has to be entirely in memory, it is often much more memory-efficient, and - thus typically faster, to use the Gramian-based approach. + This is equivalent to just calling ``torchjd.autojac.backward(losses, UPGrad())``. However, + since the Jacobian never has to be entirely in memory, it is often much more + memory-efficient, and thus typically faster, to use the Gramian-based approach. .. warning:: When providing an non-None ``batch_dim``, all provided modules must respect a few From 2f92a1dc66a8707d0ed459f9625e69b40958c130 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 16:16:38 +0200 Subject: [PATCH 103/114] Fix typo in error message --- src/torchjd/autogram/_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 84001aa5..488966f9 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -185,7 +185,7 @@ def _check_module_is_compatible(self, module: nn.Module) -> None: f"autogram engine when `batch_dim` is not `None`. The incompatible module types" f" are {_MODULES_INCOMPATIBLE_WITH_BATCHED} (and their subclasses). The " f"recommended fix is to replace incompatible layers by something else (e.g. " - f"BatchNorm by InstanceNorm). If you really can't and performance not a " + f"BatchNorm by InstanceNorm). If you really can't and performance is not a " f"priority, you may also just set `batch_dim=None` when creating the engine." ) if isinstance(module, _TRACK_RUNNING_STATS_MODULE_TYPES) and module.track_running_stats: @@ -194,8 +194,8 @@ def _check_module_is_compatible(self, module: nn.Module) -> None: f" is incompatible with the autogram engine when `batch_dim` is not `None`, due" f" to performing in-place operations on tensors and having side-effects during " f"the forward pass. Try setting `track_running_stats` to `False`. If you really" - f" can't and performance not a priority, you may also just set `batch_dim=None`" - f" when creating the engine." + f" can't and performance is not a priority, you may also just set " + f"`batch_dim=None` when creating the engine." ) def compute_gramian(self, output: Tensor) -> Tensor: From 89e9b4359fdff7618deba331ef14f21b5728e73f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 16:26:46 +0200 Subject: [PATCH 104/114] Minor style improvement --- src/torchjd/autogram/_gramian_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index 956f42bd..3bbfd062 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -24,7 +24,7 @@ def reshape_gramian(gramian: Tensor, half_shape: list[int]) -> Tensor: def _revert_last_dims(gramian: Tensor) -> Tensor: """Inverts the order of the last half of the dimensions of the input generalized Gramian.""" - half_ndim = len(gramian.shape) // 2 + half_ndim = gramian.ndim // 2 last_dims = [half_ndim + i for i in range(half_ndim)] return gramian.movedim(last_dims, last_dims[::-1]) From 65cf31781a0d12306c83d9d7f1bf4b8709ca0a3e Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 25 Sep 2025 17:05:50 +0200 Subject: [PATCH 105/114] Fix inspection issues: Grammar --- src/torchjd/autogram/_engine.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 488966f9..6a3fbe03 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -113,7 +113,7 @@ class Engine: memory-efficient, and thus typically faster, to use the Gramian-based approach. .. warning:: - When providing an non-None ``batch_dim``, all provided modules must respect a few + When providing a non-None ``batch_dim``, all provided modules must respect a few conditions: * They should treat the elements of the batch independently. Most common layers respect @@ -129,14 +129,14 @@ class Engine: open an issue if you need extra focus on this). * They should not perform in-place operations on tensors (for instance you should not use ``track_running_stats=True`` in normalization layers). - * They should not have side-effects during the forward pass (since their forward pass will - be called twice, the side-effects could be different from what's expected). + * They should not have side effects during the forward pass (since their forward pass will + be called twice, the side effects could be different from what's expected). * If they have some randomness during the forward pass, they should not have direct trainable parameters. It is, however, perfectly fine for random modules to have child modules that have trainable parameters, so if you have a random module with some direct parameters, a simple fix is to wrap these parameters into a child module. - If you're building your own architecture, respecting those criterions should be quite easy. + If you're building your own architecture, respecting those criteria should be quite easy. However, if you're using an existing architecture, you may have to modify it to make it compatible with the autogram engine. For instance, you may want to replace `BatchNorm2d `_ layers by From 52688ea607437509febdcf554fee27cde306d89a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 18:28:21 +0200 Subject: [PATCH 106/114] Add comment for VJPType --- src/torchjd/autogram/_vjp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index bb07e275..01ad313c 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -15,6 +15,7 @@ # still support older versions of PyTorch where pytree is protected). +# This includes vmapped VJPs, which are not of type VJP. VJPType = Callable[[PyTree, PyTree], dict[str, Tensor]] From 2ce6a5464f84ac8033b9f5cf0bc0b0d2516adb63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 18:40:44 +0200 Subject: [PATCH 107/114] Improve VJP docstrings --- src/torchjd/autogram/_vjp.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 01ad313c..6f775c06 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -40,9 +40,8 @@ def __init__(self, module: nn.Module): @abstractmethod def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]: """ - VJP function that takes cotangents and inputs and returns dictionary of names of - parameters (as given by `module.named_parameters.keys()`) to gradients of the parameters - for the given cotangents at the given inputs. + Computes and returns the dictionary of parameter names to their gradients for the given + grad_outputs (cotangents) and at the given inputs. """ @@ -51,7 +50,7 @@ class FunctionalVJP(VJP): Represents a VJP function for a module's forward pass with respect to its parameters using the func api. The __call__ function takes both the inputs and the cotangents that can be vmapped jointly in both terms to avoid providing to block diagonal jacobians. The disadvantage of using - this method is that it computes the forward phase. + this method is that it makes an extra forward pass. :params module: The module to differentiate. """ @@ -100,8 +99,8 @@ class AutogradVJP(VJP): """ Represents a VJP function for a module's forward pass with respect to its parameters using the autograd engine. The __call__ function takes both the inputs and the cotangents but ignores the - inputs. The main advantage of using this method is that it doesn't require computing the forward - phase. + inputs. The main advantage of using this method is that it doesn't require making an extra + forward pass. """ def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): From 6e96f1417d2c40ec986545b3767b7e18b3666621 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 18:54:13 +0200 Subject: [PATCH 108/114] Use one less loop when making the outputs and grad_outputs for outputs that require grad in AutogradVJP --- src/torchjd/autogram/_vjp.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 6f775c06..17311664 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -111,10 +111,19 @@ def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: flat_grad_outputs = tree_flatten(grad_outputs)[0] + + # Only differentiate outputs that require grad. We only need their grad_outputs. + outputs_ = list[Tensor]() + grad_outputs_ = list[Tensor]() + for output, grad_output, requires_grad in zip(self.outputs, flat_grad_outputs, self.mask): + if requires_grad: + outputs_.append(output) + grad_outputs_.append(grad_output) + grads = torch.autograd.grad( - [t for t, requires_grad in zip(self.outputs, self.mask) if requires_grad], + outputs_, self.flat_trainable_params, - [t for t, requires_grad in zip(flat_grad_outputs, self.mask) if requires_grad], + grad_outputs_, retain_graph=True, allow_unused=True, materialize_grads=True, From d83d93721cbeb7c7edbfd6af3bc57c311a755aa7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 19:34:35 +0200 Subject: [PATCH 109/114] Only store outputs that require grad in AutogradVJP --- src/torchjd/autogram/_vjp.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 17311664..b05a7408 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -105,23 +105,25 @@ class AutogradVJP(VJP): def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): super().__init__(module) - self.outputs = outputs - self.mask = [output.requires_grad for output in self.outputs] + + self.outputs_that_require_grad = list[Tensor]() + self.mask = list[bool]() + for output in outputs: + requires_grad = output.requires_grad + if requires_grad: + self.outputs_that_require_grad.append(output) + self.mask.append(requires_grad) + self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params) def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: flat_grad_outputs = tree_flatten(grad_outputs)[0] - # Only differentiate outputs that require grad. We only need their grad_outputs. - outputs_ = list[Tensor]() - grad_outputs_ = list[Tensor]() - for output, grad_output, requires_grad in zip(self.outputs, flat_grad_outputs, self.mask): - if requires_grad: - outputs_.append(output) - grad_outputs_.append(grad_output) + # Only keep the grad_outputs corresponding to outputs that require grad. + grad_outputs_ = [grad_output for grad_output, rg in zip(flat_grad_outputs, self.mask) if rg] grads = torch.autograd.grad( - outputs_, + self.outputs_that_require_grad, self.flat_trainable_params, grad_outputs_, retain_graph=True, From 2b71a0bdcbf93549b6c216c616f56939c1dcff3e Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 25 Sep 2025 20:02:52 +0200 Subject: [PATCH 110/114] check if any output has require_grad in Hook.__call__ --- src/torchjd/autogram/_module_hook_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 72a3f066..032f8a05 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -216,9 +216,9 @@ def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree: flat_outputs, output_spec = tree_flatten(output) - if not any(isinstance(t, Tensor) for t in flat_outputs): - # This can happen only if a module returns no Tensor, for instance some niche usage - # such as a module that prints something. + if not any(isinstance(t, Tensor) for t in flat_outputs if t.requires_grad): + # This can happen only if a module returns no Tensor with a graph, for instance some + # niche usage such as a module that prints something. return output requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad] From ab28a09abcd3b651e72af6165a6b165cb6fee401 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 20:15:55 +0200 Subject: [PATCH 111/114] Add SomeFrozenParamAndUnusedTrainableParam edge case (failing tests) --- tests/unit/autogram/test_engine.py | 2 ++ tests/utils/architectures.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 9a000fe0..50e9c688 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -42,6 +42,7 @@ SingleInputPyTreeOutput, SIPOBranched, SomeFrozenParam, + SomeFrozenParamAndUnusedTrainableParam, SomeUnusedOutput, SomeUnusedParam, SqueezeNet, @@ -96,6 +97,7 @@ (SomeFrozenParam, 32), (MultiOutputWithFrozenBranch, 32), (WithSomeFrozenModule, 32), + (SomeFrozenParamAndUnusedTrainableParam, 32), (SomeUnusedOutput, 32), (Ndim0Output, 32), (Ndim1Output, 32), diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 66d52323..a5bc905c 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -498,6 +498,24 @@ def forward(self, input: Tensor): return self.all_frozen(input) + self.non_frozen(input**2 / 5.0) +class SomeFrozenParamAndUnusedTrainableParam(ShapedModule): + """ + Module that has a frozen param (requires_grad=False) and a non-frozen param (requires_grad= + True), but the non-frozen param is also unused. + """ + + INPUT_SHAPES = (50,) + OUTPUT_SHAPES = (10,) + + def __init__(self): + super().__init__() + self.frozen_param = nn.Parameter(torch.randn(50, 10), requires_grad=False) + self.non_frozen_param = nn.Parameter(torch.randn(50, 10)) + + def forward(self, input: Tensor): + return input @ self.frozen_param + + class MultiOutputWithFrozenBranch(ShapedModule): """ Module that has two outputs: one comes from a frozen parameter, so it will only require grad From bec9906e3ae9669bd7b2dbfcc49e68e44e209b7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 20:35:04 +0200 Subject: [PATCH 112/114] Fix condition in hook --- src/torchjd/autogram/_module_hook_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 032f8a05..a88ad983 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -216,9 +216,9 @@ def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree: flat_outputs, output_spec = tree_flatten(output) - if not any(isinstance(t, Tensor) for t in flat_outputs if t.requires_grad): - # This can happen only if a module returns no Tensor with a graph, for instance some - # niche usage such as a module that prints something. + if not any(isinstance(t, Tensor) and t.requires_grad for t in flat_outputs): + # This can happen only if a module has a trainable param but outputs no tensor that + # require grad return output requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad] From 595e0d70a6ce34078c6b52ff818e31143b05a8a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 20:35:17 +0200 Subject: [PATCH 113/114] Fix testing archi --- tests/unit/autogram/test_engine.py | 4 ++-- tests/utils/architectures.py | 30 +++++++++++++++++++++++------- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 50e9c688..cfe012e7 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -36,13 +36,13 @@ PyTreeInputPyTreeOutput, PyTreeInputSingleOutput, Randomness, + RequiresGradOfSchrodinger, ShapedModule, SimpleBranched, SimpleParamReuse, SingleInputPyTreeOutput, SIPOBranched, SomeFrozenParam, - SomeFrozenParamAndUnusedTrainableParam, SomeUnusedOutput, SomeUnusedParam, SqueezeNet, @@ -97,7 +97,7 @@ (SomeFrozenParam, 32), (MultiOutputWithFrozenBranch, 32), (WithSomeFrozenModule, 32), - (SomeFrozenParamAndUnusedTrainableParam, 32), + (RequiresGradOfSchrodinger, 32), (SomeUnusedOutput, 32), (Ndim0Output, 32), (Ndim1Output, 32), diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index a5bc905c..c77ba2a1 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -498,22 +498,38 @@ def forward(self, input: Tensor): return self.all_frozen(input) + self.non_frozen(input**2 / 5.0) -class SomeFrozenParamAndUnusedTrainableParam(ShapedModule): +class RequiresGradOfSchrodinger(ShapedModule): """ - Module that has a frozen param (requires_grad=False) and a non-frozen param (requires_grad= - True), but the non-frozen param is also unused. + Wtf? """ INPUT_SHAPES = (50,) - OUTPUT_SHAPES = (10,) + OUTPUT_SHAPES = (3,) + + class SomeFrozenParamAndUnusedTrainableParam(ShapedModule): + """ + Module that has a frozen param (requires_grad=False) and a non-frozen param (requires_grad= + True), but the non-frozen param is also unused. + """ + + INPUT_SHAPES = (50,) + OUTPUT_SHAPES = (10,) + + def __init__(self): + super().__init__() + self.frozen_param = nn.Parameter(torch.randn(50, 10), requires_grad=False) + self.non_frozen_param = nn.Parameter(torch.randn(50, 10)) + + def forward(self, input: Tensor): + return input @ self.frozen_param def __init__(self): super().__init__() - self.frozen_param = nn.Parameter(torch.randn(50, 10), requires_grad=False) - self.non_frozen_param = nn.Parameter(torch.randn(50, 10)) + self.weird_module = self.SomeFrozenParamAndUnusedTrainableParam() + self.normal_module = nn.Linear(10, 3) def forward(self, input: Tensor): - return input @ self.frozen_param + return self.normal_module(self.weird_module(input)) class MultiOutputWithFrozenBranch(ShapedModule): From 4aee476c9ed8d9309f9695fb9dd01b3d52bd9207 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 25 Sep 2025 20:39:21 +0200 Subject: [PATCH 114/114] Fix docstring of RequiresGradOfSchrodinger --- tests/utils/architectures.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index c77ba2a1..d72782f1 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -500,7 +500,9 @@ def forward(self, input: Tensor): class RequiresGradOfSchrodinger(ShapedModule): """ - Wtf? + Model that contains a module whose output will not require grad despite containing a param that + requires grad (so it will be hooked). The final output of the model will require grad, though, + because another normal module is used on the output of the first module. """ INPUT_SHAPES = (50,)