From 586e17ac73c77704a81e7a2832e44e719779786e Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sun, 12 Oct 2025 14:19:27 +0200 Subject: [PATCH 01/32] Make `GramianAccumulator` track paths to Modules rathre than parameters. --- src/torchjd/autogram/_gramian_accumulator.py | 64 +++++++++----------- src/torchjd/autogram/_module_hook_manager.py | 5 +- 2 files changed, 31 insertions(+), 38 deletions(-) diff --git a/src/torchjd/autogram/_gramian_accumulator.py b/src/torchjd/autogram/_gramian_accumulator.py index 6d89fc18..db8ffb5c 100644 --- a/src/torchjd/autogram/_gramian_accumulator.py +++ b/src/torchjd/autogram/_gramian_accumulator.py @@ -1,9 +1,9 @@ from collections import Counter -from collections.abc import Iterable from typing import Optional import torch -from torch import Tensor +from torch import Tensor, nn +from torch.utils._pytree import PyTree, tree_flatten class GramianAccumulator: @@ -17,56 +17,50 @@ class GramianAccumulator: def __init__(self) -> None: self._gramian: Optional[Tensor] = None - self._summed_jacobians = dict[Tensor, Tensor]() - self._path_counter = Counter[Tensor]() + self._summed_jacobians = dict[nn.Module, list[Tensor]]() + self._path_counter = Counter[nn.Module]() def reset(self) -> None: self._gramian = None self._summed_jacobians = {} self._path_counter = Counter() - def track_parameter_paths(self, parameters: Iterable[Tensor]) -> None: + def track_module_paths(self, module: nn.Module) -> None: """ - Register parameters and count their paths in the computational graph. + Register module and count its paths in the computational graph. - :param parameters: Parameter tensors to track. Duplicates increase path count. + :param module: Module to track. Duplicates increase path count. 
""" - self._path_counter.update(parameters) + self._path_counter.update([module]) - def accumulate_path_jacobians(self, path_jacobians: dict[Tensor, Tensor]) -> None: + def accumulate_path_jacobians(self, module: nn.Module, jacobians: PyTree) -> None: """ - Add path Jacobians for multiple parameters. + Add Jacobians corresponding to a module. - :param path_jacobians: Dictionary mapping parameters to Jacobian tensors of a single path. + :param module: The module. + :param jacobians: Dictionary mapping parameters to Jacobian tensors of a single path. """ - for parameter, jacobian in path_jacobians.items(): - self._accumulate_path_jacobian(parameter, jacobian) - - def _accumulate_path_jacobian(self, parameter: Tensor, jacobian: Tensor) -> None: - """ - Add path Jacobian for a parameter. In case the full Jacobian is computed, accumulate its - Gramian. - - :param parameter: The parameter. - :param jacobian: path Jacobian with respect to the parameter. - """ - if parameter in self._summed_jacobians: - self._summed_jacobians[parameter] += jacobian + flat_jacobians = tree_flatten(jacobians)[0] + if module in self._summed_jacobians: + self._summed_jacobians[module] = [ + a + b for a, b in zip(self._summed_jacobians[module], flat_jacobians) + ] else: - self._summed_jacobians[parameter] = jacobian - self._path_counter.subtract([parameter]) - if self._path_counter[parameter] == 0: - self._accumulate_gramian(parameter) - del self._path_counter[parameter] - del self._summed_jacobians[parameter] - - def _accumulate_gramian(self, parameter: Tensor) -> None: + self._summed_jacobians[module] = flat_jacobians + self._path_counter.subtract([module]) + if self._path_counter[module] == 0: + for jacobian in self._summed_jacobians[module]: + self._accumulate_one_jacobian_in_gramian(jacobian) + del self._path_counter[module] + del self._summed_jacobians[module] + + def _accumulate_one_jacobian_in_gramian(self, jacobian: Tensor) -> None: """ - Compute the Gramian of the full Jacobian and 
accumulate it. + Compute the Gramian of a Jacobian and accumulate it. - :param parameter: Parameter whose full Jacobian is available. + :param jacobian: the Jacobian. """ - full_jacobian_matrix = torch.flatten(self._summed_jacobians[parameter], start_dim=1) + full_jacobian_matrix = torch.flatten(jacobian, start_dim=1) if self._gramian is not None: self._gramian.addmm_(full_jacobian_matrix, full_jacobian_matrix.T) else: diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index e39d1b25..96f75244 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -125,8 +125,7 @@ def __call__( # require grad return outputs - rg_params = [p for p in module.parameters(recurse=True) if p.requires_grad] - self.gramian_accumulator.track_parameter_paths(rg_params) + self.gramian_accumulator.track_module_paths(module) # We only care about running the JacobianAccumulator node, so we need one of its child # edges (the edges of the original outputs of the model) as target. 
For memory @@ -213,7 +212,7 @@ def backward(ctx, *grad_outputs: Tensor) -> tuple: ctx.module, *grad_outputs, ) - ctx.gramian_accumulator.accumulate_path_jacobians(path_jacobians) + ctx.gramian_accumulator.accumulate_path_jacobians(ctx.module, path_jacobians) return None, None, None, None, None, None, *grad_outputs From 7e901e2ce72e3792a70025c71ed98a7e896bd1aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Oct 2025 00:29:48 +0200 Subject: [PATCH 02/32] Improve docstrings --- src/torchjd/autogram/_gramian_accumulator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autogram/_gramian_accumulator.py b/src/torchjd/autogram/_gramian_accumulator.py index db8ffb5c..8dc5c27b 100644 --- a/src/torchjd/autogram/_gramian_accumulator.py +++ b/src/torchjd/autogram/_gramian_accumulator.py @@ -26,16 +26,16 @@ def reset(self) -> None: self._path_counter = Counter() def track_module_paths(self, module: nn.Module) -> None: - """ - Register module and count its paths in the computational graph. + """Increment the usage count of the provided module. - :param module: Module to track. Duplicates increase path count. + :param module: The module. """ + self._path_counter.update([module]) def accumulate_path_jacobians(self, module: nn.Module, jacobians: PyTree) -> None: """ - Add Jacobians corresponding to a module. + Add the Jacobians corresponding to all usages of a module. :param module: The module. :param jacobians: Dictionary mapping parameters to Jacobian tensors of a single path. 
From 41d3b0b4d46f0a6ade461ca985e92f3d7382b992 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Oct 2025 00:33:32 +0200 Subject: [PATCH 03/32] Make InterModuleParamReuse xfail --- tests/unit/autogram/test_engine.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 31f63559..ac3b3ab7 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -102,7 +102,6 @@ (WithNoTensorOutput, 32), (WithBuffered, 32), (SimpleParamReuse, 32), - (InterModuleParamReuse, 32), (ModuleReuse, 32), (SomeUnusedParam, 32), (SomeFrozenParam, 32), @@ -202,7 +201,12 @@ def test_compute_gramian_with_weird_modules( @mark.xfail @mark.parametrize( - "architecture", [ModelUsingSubmoduleParamsDirectly, ModelAlsoUsingSubmoduleParamsDirectly] + "architecture", + [ + ModelUsingSubmoduleParamsDirectly, + ModelAlsoUsingSubmoduleParamsDirectly, + InterModuleParamReuse, + ], ) @mark.parametrize("batch_size", [1, 3, 32]) @mark.parametrize("batch_dim", [0, None]) From 07be8d64d4f6d622f18ef8036e8fbf73f7c68b64 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 13 Oct 2025 08:48:50 +0200 Subject: [PATCH 04/32] Fix GramianAccumulator tests. 
--- .../unit/autogram/test_gramian_accumulator.py | 47 ++++++++++++------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/tests/unit/autogram/test_gramian_accumulator.py b/tests/unit/autogram/test_gramian_accumulator.py index 373f1b7c..160e8553 100644 --- a/tests/unit/autogram/test_gramian_accumulator.py +++ b/tests/unit/autogram/test_gramian_accumulator.py @@ -1,10 +1,15 @@ from pytest import mark +from torch import nn from torch.testing import assert_close from utils.tensors import randn_, zeros_ from torchjd.autogram._gramian_accumulator import GramianAccumulator +class FakeModule(nn.Module): + pass + + @mark.parametrize( ["shapes", "number_of_jacobians"], [ @@ -16,9 +21,10 @@ def test_adding_jacobians_one_by_one(shapes: list[list[int]], number_of_jacobian batch_size = 10 gramian_accumulator = GramianAccumulator() - keys = [randn_(shape) for shape in shapes] + keys = [FakeModule() for _ in shapes] for key, n in zip(keys, number_of_jacobians): - gramian_accumulator.track_parameter_paths([key] * n) + for _ in range(n): + gramian_accumulator.track_module_paths(key) expected_gramian = zeros_([batch_size, batch_size]) @@ -27,7 +33,7 @@ def test_adding_jacobians_one_by_one(shapes: list[list[int]], number_of_jacobian cumulated_jacobian = zeros_(batched_shape) for i in range(n): jacobian = randn_(batched_shape) - gramian_accumulator.accumulate_path_jacobians({key: jacobian}) + gramian_accumulator.accumulate_path_jacobians(key, [jacobian]) cumulated_jacobian += jacobian jacobian_matrix = cumulated_jacobian.reshape([batch_size, -1]) expected_gramian.addmm_(jacobian_matrix, jacobian_matrix.T) @@ -48,21 +54,29 @@ def test_adding_jacobians_lots_by_lots(shapes: list[list[int]]): batch_size = 10 gramian_accumulator = GramianAccumulator() - keys = [randn_(shape) for shape in shapes] - for i in range(number_of_jacobians): - gramian_accumulator.track_parameter_paths(keys) + keys = [FakeModule() for _ in shapes] + for key in keys: + for i in 
range(number_of_jacobians): + gramian_accumulator.track_module_paths(key) expected_gramian = zeros_([batch_size, batch_size]) - cumulated_jacobians = {key: zeros_([batch_size] + shape) for key, shape in zip(keys, shapes)} + cumulated_jacobians = { + key: [zeros_([batch_size] + shape)] * number_of_jacobians + for key, shape in zip(keys, shapes) + } for i in range(number_of_jacobians): - jacobians = {key: randn_([batch_size] + shape) for key, shape in zip(keys, shapes)} - gramian_accumulator.accumulate_path_jacobians(jacobians) - for key, jacobian in jacobians.items(): - cumulated_jacobians[key] += jacobian + jacobian_dict = { + key: [randn_([batch_size] + shape) for _ in range(number_of_jacobians)] + for key, shape in zip(keys, shapes) + } + for key, jacobians in jacobian_dict.items(): + gramian_accumulator.accumulate_path_jacobians(key, jacobian_dict[key]) + cumulated_jacobians[key] = [a + b for a, b in zip(jacobians, cumulated_jacobians[key])] for cumulated_jacobian in cumulated_jacobians.values(): - jacobian_matrix = cumulated_jacobian.reshape([batch_size, -1]) - expected_gramian.addmm_(jacobian_matrix, jacobian_matrix.T) + for jacobian in cumulated_jacobian: + jacobian_matrix = jacobian.reshape([batch_size, -1]) + expected_gramian.addmm_(jacobian_matrix, jacobian_matrix.T) gramian = gramian_accumulator.gramian assert_close(gramian, expected_gramian) @@ -84,14 +98,15 @@ def test_internal_dicts_are_cleaned(shapes: list[list[int]], number_of_jacobians batch_size = 10 gramian_accumulator = GramianAccumulator() - keys = [randn_(shape) for shape in shapes] + keys = [FakeModule() for shape in shapes] for key, n in zip(keys, number_of_jacobians): - gramian_accumulator.track_parameter_paths([key] * n) + for _ in range(n): + gramian_accumulator.track_module_paths(key) for key, shape, n in zip(keys, shapes, number_of_jacobians): batched_shape = [batch_size] + shape for i in range(n): jacobian = randn_(batched_shape) - gramian_accumulator.accumulate_path_jacobians({key: 
jacobian}) + gramian_accumulator.accumulate_path_jacobians(key, [jacobian]) assert key not in gramian_accumulator._summed_jacobians.keys() assert key not in gramian_accumulator._path_counter.keys() From a2cec0b405097dd01778281a7ad0c6efde715160 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 13 Oct 2025 08:59:54 +0200 Subject: [PATCH 05/32] Make `_make_path_jacobians` return a list rather than a dict. --- src/torchjd/autogram/_module_hook_manager.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 96f75244..5604c01e 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -226,7 +226,7 @@ def forward( kwargs: dict[str, PyTree], module: nn.Module, *grad_outputs: Tensor, - ) -> dict[Tensor, Tensor]: + ) -> list[Tensor]: # There is no non-batched dimension generalized_jacobians = vjp(grad_outputs, args, kwargs) path_jacobians = ComputeModuleJacobians._make_path_jacobians(module, generalized_jacobians) @@ -241,7 +241,7 @@ def vmap( kwargs: dict[str, PyTree], module: nn.Module, *jac_outputs: Tensor, - ) -> tuple[dict[Tensor, Tensor], None]: + ) -> tuple[list[Tensor], None]: # There is a non-batched dimension # We do not vmap over the args for the non-batched dimension in_dims = (in_dims[4:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs)) @@ -253,12 +253,12 @@ def vmap( def _make_path_jacobians( module: nn.Module, generalized_jacobians: dict[str, Tensor], - ) -> dict[Tensor, Tensor]: - path_jacobians: dict[Tensor, Tensor] = {} + ) -> list[Tensor]: + path_jacobians: list[Tensor] = [] for param_name, generalized_jacobian in generalized_jacobians.items(): key = module.get_parameter(param_name) jacobian = generalized_jacobian.reshape([-1] + list(key.shape)) - path_jacobians[key] = jacobian + path_jacobians.append(jacobian) return path_jacobians @staticmethod From 
f8856bafa94ea80090b1afd8dab93ff81588f6cb Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 13 Oct 2025 09:18:11 +0200 Subject: [PATCH 06/32] Make `accumulate_path_jacobian` take a `list[Tensor]` of jacobians --- src/torchjd/autogram/_gramian_accumulator.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/torchjd/autogram/_gramian_accumulator.py b/src/torchjd/autogram/_gramian_accumulator.py index 8dc5c27b..ad842901 100644 --- a/src/torchjd/autogram/_gramian_accumulator.py +++ b/src/torchjd/autogram/_gramian_accumulator.py @@ -3,7 +3,6 @@ import torch from torch import Tensor, nn -from torch.utils._pytree import PyTree, tree_flatten class GramianAccumulator: @@ -33,20 +32,19 @@ def track_module_paths(self, module: nn.Module) -> None: self._path_counter.update([module]) - def accumulate_path_jacobians(self, module: nn.Module, jacobians: PyTree) -> None: + def accumulate_path_jacobians(self, module: nn.Module, jacobians: list[Tensor]) -> None: """ Add the Jacobians corresponding to all usages of a module. :param module: The module. - :param jacobians: Dictionary mapping parameters to Jacobian tensors of a single path. + :param jacobians: List of Jacobian tensors of a single path. """ - flat_jacobians = tree_flatten(jacobians)[0] if module in self._summed_jacobians: self._summed_jacobians[module] = [ - a + b for a, b in zip(self._summed_jacobians[module], flat_jacobians) + a + b for a, b in zip(self._summed_jacobians[module], jacobians) ] else: - self._summed_jacobians[module] = flat_jacobians + self._summed_jacobians[module] = jacobians self._path_counter.subtract([module]) if self._path_counter[module] == 0: for jacobian in self._summed_jacobians[module]: From a9d752ee1b1d0c26876f5d46ed6c2c46f5eec325 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 13 Oct 2025 10:52:40 +0200 Subject: [PATCH 07/32] Rename `VJP` into `JacobianComputer` as it does not compute vector jacobian product but rather generalized Jacobian. 
--- .../{_vjp.py => _jacobian_computer.py} | 8 ++++---- src/torchjd/autogram/_module_hook_manager.py | 18 +++++++++++------- 2 files changed, 15 insertions(+), 11 deletions(-) rename src/torchjd/autogram/{_vjp.py => _jacobian_computer.py} (96%) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_jacobian_computer.py similarity index 96% rename from src/torchjd/autogram/_vjp.py rename to src/torchjd/autogram/_jacobian_computer.py index 86df495b..5b2884ee 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -15,7 +15,7 @@ # still support older versions of PyTorch where pytree is protected). -class VJP(ABC): +class JacobianComputer(ABC): """Represents an abstract VJP function.""" @abstractmethod @@ -28,7 +28,7 @@ def __call__( """ -class ModuleVJP(VJP, ABC): +class ModuleJacobianComputer(JacobianComputer, ABC): """ Represents an abstract VJP function for a module's forward pass with respect to its parameters. @@ -48,7 +48,7 @@ def __init__(self, module: nn.Module): self.frozen_params[name] = param -class FunctionalVJP(ModuleVJP): +class FunctionalJacobianComputer(ModuleJacobianComputer): """ Represents a VJP function for a module's forward pass with respect to its parameters using the functional differentiation API. This requires to use vmap, so it's not compatible with @@ -98,7 +98,7 @@ def functional_model_call(rg_params: dict[str, Parameter]) -> list[Tensor]: return vjp_func(grad_outputs_j_)[0] -class AutogradVJP(ModuleVJP): +class AutogradJacobianComputer(ModuleJacobianComputer): """ Represents a VJP function for a module's forward pass with respect to its parameters using the autograd engine. 
The __call__ function takes both the inputs and the cotangents but ignores the diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 5604c01e..d6fdf2e5 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -9,7 +9,11 @@ from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator -from ._vjp import VJP, AutogradVJP, FunctionalVJP +from ._jacobian_computer import ( + AutogradJacobianComputer, + FunctionalJacobianComputer, + JacobianComputer, +) # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see @@ -134,15 +138,15 @@ def __call__( index = cast(int, preference.argmin().item()) self.target_edges.register(get_gradient_edge(rg_outputs[index])) - vjp: VJP + vjp: JacobianComputer if self.has_batch_dim: rg_output_in_dims = (0,) * len(rg_outputs) arg_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs) in_dims = (rg_output_in_dims, arg_in_dims, kwargs_in_dims) - vjp = FunctionalVJP(module, in_dims) + vjp = FunctionalJacobianComputer(module, in_dims) else: - vjp = AutogradVJP(module, rg_outputs) + vjp = AutogradJacobianComputer(module, rg_outputs) autograd_fn_rg_outputs = JacobianAccumulator.apply( self.gramian_accumulation_phase, @@ -174,7 +178,7 @@ class JacobianAccumulator(torch.autograd.Function): @staticmethod def forward( gramian_accumulation_phase: BoolRef, - vjp: VJP, + vjp: JacobianComputer, args: tuple[PyTree, ...], kwargs: dict[str, PyTree], gramian_accumulator: GramianAccumulator, @@ -221,7 +225,7 @@ class ComputeModuleJacobians(torch.autograd.Function): @staticmethod def forward( - vjp: VJP, + vjp: JacobianComputer, args: tuple[PyTree, ...], kwargs: dict[str, PyTree], module: nn.Module, @@ -236,7 +240,7 @@ def forward( def vmap( _, in_dims: tuple, # tuple[None, 
tuple[PyTree, ...], dict[str, PyTree], None, *tuple[int | None, ...]] - vjp: VJP, + vjp: JacobianComputer, args: tuple[PyTree, ...], kwargs: dict[str, PyTree], module: nn.Module, From 89cdc86c25217e134f8440aec1cf30773d8456a2 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 13 Oct 2025 11:06:56 +0200 Subject: [PATCH 08/32] functional_call can take a list of dict. --- src/torchjd/autogram/_jacobian_computer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index 5b2884ee..249fd702 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -80,11 +80,11 @@ def _call_on_one_instance( grad_outputs_j_ = [x.unsqueeze(0) for x in grad_outputs_j] def functional_model_call(rg_params: dict[str, Parameter]) -> list[Tensor]: - all_state = { - **rg_params, - **dict(self.module.named_buffers()), - **self.frozen_params, - } + all_state = [ + rg_params, + dict(self.module.named_buffers()), + self.frozen_params, + ] output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j) flat_outputs = tree_flatten(output)[0] rg_outputs = [t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad] From 3cc2cdf13deb82a4d8726930edc3787c5476b3c3 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 13 Oct 2025 11:07:39 +0200 Subject: [PATCH 09/32] Fix typing --- src/torchjd/autogram/_jacobian_computer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index 249fd702..20250510 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -82,7 +82,7 @@ def _call_on_one_instance( def functional_model_call(rg_params: dict[str, Parameter]) -> list[Tensor]: all_state = [ rg_params, - dict(self.module.named_buffers()), + dict[str, 
Tensor](self.module.named_buffers()), self.frozen_params, ] output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j) From 6138a72e00e4428a2ed3f0d5a53e355e60c65443 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 13 Oct 2025 11:14:26 +0200 Subject: [PATCH 10/32] Fix typing yet again --- src/torchjd/autogram/_jacobian_computer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index 20250510..e8b76068 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from collections.abc import Sequence +from typing import cast import torch from torch import Tensor, nn @@ -81,9 +82,9 @@ def _call_on_one_instance( def functional_model_call(rg_params: dict[str, Parameter]) -> list[Tensor]: all_state = [ - rg_params, - dict[str, Tensor](self.module.named_buffers()), - self.frozen_params, + cast(dict[str, Tensor], rg_params), + dict(self.module.named_buffers()), + cast(dict[str, Tensor], self.frozen_params), ] output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j) flat_outputs = tree_flatten(output)[0] From 1a1426ccd9dba1f54b40483f5b3fda67de17ec94 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 13 Oct 2025 16:44:03 +0200 Subject: [PATCH 11/32] Makes `JacobianComputer` return a Jacobian Matrix. Adapt `ComputeModuleJacobians`, `JacobianAccumulator` and `GramianAccumulator` accordingly. 
--- src/torchjd/autogram/_gramian_accumulator.py | 24 +++---- src/torchjd/autogram/_jacobian_computer.py | 43 ++++++----- src/torchjd/autogram/_module_hook_manager.py | 54 ++++++-------- .../unit/autogram/test_gramian_accumulator.py | 72 +++++-------------- 4 files changed, 72 insertions(+), 121 deletions(-) diff --git a/src/torchjd/autogram/_gramian_accumulator.py b/src/torchjd/autogram/_gramian_accumulator.py index ad842901..5e3ef2d6 100644 --- a/src/torchjd/autogram/_gramian_accumulator.py +++ b/src/torchjd/autogram/_gramian_accumulator.py @@ -16,7 +16,7 @@ class GramianAccumulator: def __init__(self) -> None: self._gramian: Optional[Tensor] = None - self._summed_jacobians = dict[nn.Module, list[Tensor]]() + self._summed_jacobians = dict[nn.Module, Tensor]() self._path_counter = Counter[nn.Module]() def reset(self) -> None: @@ -32,37 +32,33 @@ def track_module_paths(self, module: nn.Module) -> None: self._path_counter.update([module]) - def accumulate_path_jacobians(self, module: nn.Module, jacobians: list[Tensor]) -> None: + def accumulate_path_jacobian(self, module: nn.Module, jacobian_matrix: Tensor) -> None: """ Add the Jacobians corresponding to all usages of a module. :param module: The module. - :param jacobians: List of Jacobian tensors of a single path. + :param jacobian_matrix: Jacobian tensors of a single path. 
""" if module in self._summed_jacobians: - self._summed_jacobians[module] = [ - a + b for a, b in zip(self._summed_jacobians[module], jacobians) - ] + self._summed_jacobians[module] += jacobian_matrix else: - self._summed_jacobians[module] = jacobians + self._summed_jacobians[module] = jacobian_matrix self._path_counter.subtract([module]) if self._path_counter[module] == 0: - for jacobian in self._summed_jacobians[module]: - self._accumulate_one_jacobian_in_gramian(jacobian) + self._accumulate_one_jacobian_in_gramian(self._summed_jacobians[module]) del self._path_counter[module] del self._summed_jacobians[module] - def _accumulate_one_jacobian_in_gramian(self, jacobian: Tensor) -> None: + def _accumulate_one_jacobian_in_gramian(self, jacobian_matrix: Tensor) -> None: """ Compute the Gramian of a Jacobian and accumulate it. - :param jacobian: the Jacobian. + :param jacobian_matrix: the Jacobian. """ - full_jacobian_matrix = torch.flatten(jacobian, start_dim=1) if self._gramian is not None: - self._gramian.addmm_(full_jacobian_matrix, full_jacobian_matrix.T) + self._gramian.addmm_(jacobian_matrix, jacobian_matrix.T) else: - self._gramian = torch.mm(full_jacobian_matrix, full_jacobian_matrix.T) + self._gramian = torch.mm(jacobian_matrix, jacobian_matrix.T) @property def gramian(self) -> Optional[Tensor]: diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index e8b76068..0d32b1b5 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -5,7 +5,7 @@ import torch from torch import Tensor, nn from torch.nn import Parameter -from torch.utils._pytree import PyTree, tree_flatten, tree_map_only, tree_unflatten +from torch.utils._pytree import PyTree, tree_flatten, tree_map_only # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see @@ -17,21 +17,21 @@ class JacobianComputer(ABC): - """Represents an abstract VJP function.""" + 
"""Represents an abstract function that computes Jacobians.""" @abstractmethod def __call__( self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree] - ) -> dict[str, Tensor]: + ) -> Tensor: """ - Computes and returns the dictionary of parameter names to their gradients for the given - grad_outputs (cotangents) and at the given inputs. + Computes and returns the Jacobian. The output must be a matrix (2D Tensor). """ class ModuleJacobianComputer(JacobianComputer, ABC): """ - Represents an abstract VJP function for a module's forward pass with respect to its parameters. + Represents an abstract function that computes Jacobians for a module's forward pass with respect + to its parameters. :params module: The module to differentiate. """ @@ -51,9 +51,10 @@ def __init__(self, module: nn.Module): class FunctionalJacobianComputer(ModuleJacobianComputer): """ - Represents a VJP function for a module's forward pass with respect to its parameters using the - functional differentiation API. This requires to use vmap, so it's not compatible with - every module, and it requires to have an extra forward pass to create the vjp function. + Represents a function that computes Jacobians for a module's forward pass with respect to its + parameters using the functional differentiation API. This requires to use vmap, so it's not + compatible with every module, and it requires to have an extra forward pass to create the vjp + function. 
""" def __init__(self, module: nn.Module, in_dims: tuple[PyTree, ...]): @@ -62,7 +63,7 @@ def __init__(self, module: nn.Module, in_dims: tuple[PyTree, ...]): def __call__( self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree] - ) -> dict[str, Tensor]: + ) -> Tensor: return self.vmapped_vjp(grad_outputs, args, kwargs) def _call_on_one_instance( @@ -70,7 +71,7 @@ def _call_on_one_instance( grad_outputs_j: tuple[Tensor, ...], args_j: tuple[PyTree, ...], kwargs_j: dict[str, PyTree], - ) -> dict[str, Tensor]: + ) -> Tensor: # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a # "batch" of 1 activation (or grad_output). This is because some layers (e.g. # nn.Flatten) do not work equivalently if they're provided with a batch or with @@ -96,26 +97,28 @@ def functional_model_call(rg_params: dict[str, Parameter]) -> list[Tensor]: # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the # functional has a single primal which is dict(module.named_parameters()). We therefore take # the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters. - return vjp_func(grad_outputs_j_)[0] + gradients = vjp_func(grad_outputs_j_)[0] + gradient = torch.cat([t.reshape(-1) for t in gradients.values()]) + return gradient class AutogradJacobianComputer(ModuleJacobianComputer): """ - Represents a VJP function for a module's forward pass with respect to its parameters using the - autograd engine. The __call__ function takes both the inputs and the cotangents but ignores the - inputs. The main advantage of using this method is that it doesn't require making an extra - forward pass. + Represents a function that computes Jacobians for a module's forward pass with respect to its + parameters using the autograd engine. The __call__ function takes both the inputs and the + cotangents but ignores the inputs. 
The main advantage of using this method is that it doesn't + require making an extra forward pass. """ def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]): super().__init__(module) self.rg_outputs = rg_outputs - self.flat_rg_params, self.param_spec = tree_flatten(self.rg_params) + self.flat_rg_params, _ = tree_flatten(self.rg_params) def __call__( self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...], __: dict[str, PyTree] - ) -> dict[str, Tensor]: + ) -> Tensor: grads = torch.autograd.grad( self.rg_outputs, self.flat_rg_params, @@ -124,4 +127,6 @@ def __call__( allow_unused=True, materialize_grads=True, ) - return tree_unflatten(grads, self.param_spec) + flattened_grads = torch.cat([g.reshape(-1) for g in grads]) + jacobian = flattened_grads.unsqueeze(0) + return jacobian diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index d6fdf2e5..23576de9 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -138,19 +138,19 @@ def __call__( index = cast(int, preference.argmin().item()) self.target_edges.register(get_gradient_edge(rg_outputs[index])) - vjp: JacobianComputer + jacobian_computer: JacobianComputer if self.has_batch_dim: rg_output_in_dims = (0,) * len(rg_outputs) arg_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs) in_dims = (rg_output_in_dims, arg_in_dims, kwargs_in_dims) - vjp = FunctionalJacobianComputer(module, in_dims) + jacobian_computer = FunctionalJacobianComputer(module, in_dims) else: - vjp = AutogradJacobianComputer(module, rg_outputs) + jacobian_computer = AutogradJacobianComputer(module, rg_outputs) autograd_fn_rg_outputs = JacobianAccumulator.apply( self.gramian_accumulation_phase, - vjp, + jacobian_computer, args, kwargs, self.gramian_accumulator, @@ -178,7 +178,7 @@ class 
JacobianAccumulator(torch.autograd.Function): @staticmethod def forward( gramian_accumulation_phase: BoolRef, - vjp: JacobianComputer, + jacobian_computer: JacobianComputer, args: tuple[PyTree, ...], kwargs: dict[str, PyTree], gramian_accumulator: GramianAccumulator, @@ -188,7 +188,7 @@ def forward( return tuple(t.detach() for t in rg_tensors) # For Python version > 3.10, the type of `inputs` should become - # tuple[BoolRef, VJP, tuple[PyTree, ...], dict[str, PyTree], GramianAccumulator, nn.Module, *tuple[Tensor, ...]] + # tuple[BoolRef, JacobianComputer, tuple[PyTree, ...], dict[str, PyTree], GramianAccumulator, nn.Module, *tuple[Tensor, ...]] @staticmethod def setup_context( ctx, @@ -196,7 +196,7 @@ def setup_context( _, ): ctx.gramian_accumulation_phase = inputs[0] - ctx.vjp = inputs[1] + ctx.jacobian_computer = inputs[1] ctx.args = inputs[2] ctx.kwargs = inputs[3] ctx.gramian_accumulator = inputs[4] @@ -209,14 +209,14 @@ def backward(ctx, *grad_outputs: Tensor) -> tuple: if not ctx.gramian_accumulation_phase: return None, None, None, None, None, None, *grad_outputs - path_jacobians = ComputeModuleJacobians.apply( - ctx.vjp, + path_jacobian = ComputeModuleJacobians.apply( + ctx.jacobian_computer, ctx.args, ctx.kwargs, ctx.module, *grad_outputs, ) - ctx.gramian_accumulator.accumulate_path_jacobians(ctx.module, path_jacobians) + ctx.gramian_accumulator.accumulate_path_jacobian(ctx.module, path_jacobian) return None, None, None, None, None, None, *grad_outputs @@ -225,45 +225,35 @@ class ComputeModuleJacobians(torch.autograd.Function): @staticmethod def forward( - vjp: JacobianComputer, + jacobian_computer: JacobianComputer, args: tuple[PyTree, ...], kwargs: dict[str, PyTree], module: nn.Module, *grad_outputs: Tensor, - ) -> list[Tensor]: + ) -> Tensor: # There is no non-batched dimension - generalized_jacobians = vjp(grad_outputs, args, kwargs) - path_jacobians = ComputeModuleJacobians._make_path_jacobians(module, generalized_jacobians) - return path_jacobians + 
jacobian = jacobian_computer(grad_outputs, args, kwargs) + return jacobian @staticmethod def vmap( _, in_dims: tuple, # tuple[None, tuple[PyTree, ...], dict[str, PyTree], None, *tuple[int | None, ...]] - vjp: JacobianComputer, + jacobian_computer: JacobianComputer, args: tuple[PyTree, ...], kwargs: dict[str, PyTree], module: nn.Module, *jac_outputs: Tensor, - ) -> tuple[list[Tensor], None]: + ) -> tuple[Tensor, None]: # There is a non-batched dimension # We do not vmap over the args for the non-batched dimension in_dims = (in_dims[4:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs)) - generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args, kwargs) - path_jacobians = ComputeModuleJacobians._make_path_jacobians(module, generalized_jacobians) - return path_jacobians, None - - @staticmethod - def _make_path_jacobians( - module: nn.Module, - generalized_jacobians: dict[str, Tensor], - ) -> list[Tensor]: - path_jacobians: list[Tensor] = [] - for param_name, generalized_jacobian in generalized_jacobians.items(): - key = module.get_parameter(param_name) - jacobian = generalized_jacobian.reshape([-1] + list(key.shape)) - path_jacobians.append(jacobian) - return path_jacobians + generalized_jacobian = torch.vmap(jacobian_computer, in_dims=in_dims)( + jac_outputs, args, kwargs + ) + shape = generalized_jacobian.shape + jacobian = generalized_jacobian.reshape([shape[0] * shape[1], -1]) + return jacobian, None @staticmethod def setup_context(*_) -> None: diff --git a/tests/unit/autogram/test_gramian_accumulator.py b/tests/unit/autogram/test_gramian_accumulator.py index 160e8553..ef2f2072 100644 --- a/tests/unit/autogram/test_gramian_accumulator.py +++ b/tests/unit/autogram/test_gramian_accumulator.py @@ -11,29 +11,29 @@ class FakeModule(nn.Module): @mark.parametrize( - ["shapes", "number_of_jacobians"], + ["sizes", "number_of_jacobians"], [ - ([[3, 4, 5], [7, 5]], [3, 7]), - ([[3], [7, 5, 8], [2, 3]], [0, 7, 1]), + ([4, 7], [3, 7]), + ([3, 
8, 4], [0, 7, 1]), ], ) -def test_adding_jacobians_one_by_one(shapes: list[list[int]], number_of_jacobians: list[int]): +def test_adding_jacobians_one_by_one(sizes: list[int], number_of_jacobians: list[int]): batch_size = 10 gramian_accumulator = GramianAccumulator() - keys = [FakeModule() for _ in shapes] + keys = [FakeModule() for _ in sizes] for key, n in zip(keys, number_of_jacobians): for _ in range(n): gramian_accumulator.track_module_paths(key) expected_gramian = zeros_([batch_size, batch_size]) - for key, shape, n in zip(keys, shapes, number_of_jacobians): - batched_shape = [batch_size] + shape + for key, size, n in zip(keys, sizes, number_of_jacobians): + batched_shape = [batch_size, size] cumulated_jacobian = zeros_(batched_shape) for i in range(n): jacobian = randn_(batched_shape) - gramian_accumulator.accumulate_path_jacobians(key, [jacobian]) + gramian_accumulator.accumulate_path_jacobian(key, jacobian) cumulated_jacobian += jacobian jacobian_matrix = cumulated_jacobian.reshape([batch_size, -1]) expected_gramian.addmm_(jacobian_matrix, jacobian_matrix.T) @@ -42,71 +42,31 @@ def test_adding_jacobians_one_by_one(shapes: list[list[int]], number_of_jacobian assert_close(gramian, expected_gramian, rtol=5e-06, atol=2e-05) -@mark.parametrize( - "shapes", - [ - [[3, 4, 5], [7, 5]], - [[3], [7, 5, 8], [2, 3]], - ], -) -def test_adding_jacobians_lots_by_lots(shapes: list[list[int]]): - number_of_jacobians = 4 - batch_size = 10 - gramian_accumulator = GramianAccumulator() - - keys = [FakeModule() for _ in shapes] - for key in keys: - for i in range(number_of_jacobians): - gramian_accumulator.track_module_paths(key) - - expected_gramian = zeros_([batch_size, batch_size]) - - cumulated_jacobians = { - key: [zeros_([batch_size] + shape)] * number_of_jacobians - for key, shape in zip(keys, shapes) - } - for i in range(number_of_jacobians): - jacobian_dict = { - key: [randn_([batch_size] + shape) for _ in range(number_of_jacobians)] - for key, shape in zip(keys, 
shapes) - } - for key, jacobians in jacobian_dict.items(): - gramian_accumulator.accumulate_path_jacobians(key, jacobian_dict[key]) - cumulated_jacobians[key] = [a + b for a, b in zip(jacobians, cumulated_jacobians[key])] - for cumulated_jacobian in cumulated_jacobians.values(): - for jacobian in cumulated_jacobian: - jacobian_matrix = jacobian.reshape([batch_size, -1]) - expected_gramian.addmm_(jacobian_matrix, jacobian_matrix.T) - - gramian = gramian_accumulator.gramian - assert_close(gramian, expected_gramian) - - def test_returns_none_if_no_jacobian_were_provided(): gramian_accumulator = GramianAccumulator() assert gramian_accumulator.gramian is None @mark.parametrize( - ["shapes", "number_of_jacobians"], + ["sizes", "number_of_jacobians"], [ - ([[3, 4, 5], [7, 5]], [3, 7]), - ([[3], [7, 5, 8], [2, 3]], [0, 7, 1]), + ([5, 7], [3, 7]), + ([3, 8, 4], [0, 7, 1]), ], ) -def test_internal_dicts_are_cleaned(shapes: list[list[int]], number_of_jacobians: list[int]): +def test_internal_dicts_are_cleaned(sizes: list[int], number_of_jacobians: list[int]): batch_size = 10 gramian_accumulator = GramianAccumulator() - keys = [FakeModule() for shape in shapes] + keys = [FakeModule() for shape in sizes] for key, n in zip(keys, number_of_jacobians): for _ in range(n): gramian_accumulator.track_module_paths(key) - for key, shape, n in zip(keys, shapes, number_of_jacobians): - batched_shape = [batch_size] + shape + for key, size, n in zip(keys, sizes, number_of_jacobians): + batched_shape = [batch_size, size] for i in range(n): jacobian = randn_(batched_shape) - gramian_accumulator.accumulate_path_jacobians(key, [jacobian]) + gramian_accumulator.accumulate_path_jacobian(key, jacobian) assert key not in gramian_accumulator._summed_jacobians.keys() assert key not in gramian_accumulator._path_counter.keys() From 6750ffafa088085cdb378bf55b98c7e094974069 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 13 Oct 2025 16:47:19 +0200 Subject: [PATCH 12/32] Improve docstring. 
--- src/torchjd/autogram/_gramian_accumulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_gramian_accumulator.py b/src/torchjd/autogram/_gramian_accumulator.py index 5e3ef2d6..755b53ad 100644 --- a/src/torchjd/autogram/_gramian_accumulator.py +++ b/src/torchjd/autogram/_gramian_accumulator.py @@ -34,7 +34,7 @@ def track_module_paths(self, module: nn.Module) -> None: def accumulate_path_jacobian(self, module: nn.Module, jacobian_matrix: Tensor) -> None: """ - Add the Jacobians corresponding to all usages of a module. + Add the Jacobian corresponding to a call to a module. :param module: The module. :param jacobian_matrix: Jacobian tensors of a single path. From 22fd4e73031e3de246651d91420fbe7b9fc9ed18 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 13 Oct 2025 16:49:13 +0200 Subject: [PATCH 13/32] Apparently, if a in_dim corresponding to a PyTree is set to `None`, it is considered to be `None` for all the PyTree. --- src/torchjd/autogram/_module_hook_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 23576de9..d5bca59d 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -247,7 +247,7 @@ def vmap( ) -> tuple[Tensor, None]: # There is a non-batched dimension # We do not vmap over the args for the non-batched dimension - in_dims = (in_dims[4:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs)) + in_dims = (in_dims[4:], None, None) generalized_jacobian = torch.vmap(jacobian_computer, in_dims=in_dims)( jac_outputs, args, kwargs ) From 13248cb7df746cef8f687aed356c97ed6b6667cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Oct 2025 17:17:57 +0200 Subject: [PATCH 14/32] Fix precision of tests (for cuda testing) --- tests/unit/autogram/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index c16ffa1f..1319476b 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -165,7 +165,7 @@ def _assert_gramian_is_equivalent_to_autograd( losses = reduce_to_vector(loss_fn(output)) autogram_gramian = engine.compute_gramian(losses) - assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=1e-5) + assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=3e-5) @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) From b2b10d0ce12d2df0b2067d5f2df5e838645f269f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Oct 2025 17:28:32 +0200 Subject: [PATCH 15/32] Remove ModuleJacobianComputer --- src/torchjd/autogram/_jacobian_computer.py | 27 +++++++++------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index 0d32b1b5..4f0bc822 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -17,21 +17,8 @@ class JacobianComputer(ABC): - """Represents an abstract function that computes Jacobians.""" - - @abstractmethod - def __call__( - self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree] - ) -> Tensor: - """ - Computes and returns the Jacobian. The output must be a matrix (2D Tensor). - """ - - -class ModuleJacobianComputer(JacobianComputer, ABC): """ - Represents an abstract function that computes Jacobians for a module's forward pass with respect - to its parameters. + Abstract class to computes Jacobians for a module's forward pass with respect to its parameters. :params module: The module to differentiate. 
""" @@ -48,8 +35,16 @@ def __init__(self, module: nn.Module): else: self.frozen_params[name] = param + @abstractmethod + def __call__( + self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree] + ) -> Tensor: + """ + Computes and returns the Jacobian. The output must be a matrix (2D Tensor). + """ + -class FunctionalJacobianComputer(ModuleJacobianComputer): +class FunctionalJacobianComputer(JacobianComputer): """ Represents a function that computes Jacobians for a module's forward pass with respect to its parameters using the functional differentiation API. This requires to use vmap, so it's not @@ -102,7 +97,7 @@ def functional_model_call(rg_params: dict[str, Parameter]) -> list[Tensor]: return gradient -class AutogradJacobianComputer(ModuleJacobianComputer): +class AutogradJacobianComputer(JacobianComputer): """ Represents a function that computes Jacobians for a module's forward pass with respect to its parameters using the autograd engine. The __call__ function takes both the inputs and the From 5641382b3795051743d4bf1cea808ac14f4910dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Oct 2025 18:44:12 +0200 Subject: [PATCH 16/32] Make JacobianComputers take rg_outputs at call * This makes us able to instantiate them outside of the hook, and thus to potentially give them a state. 
--- src/torchjd/autogram/_jacobian_computer.py | 43 ++++++++++++-------- src/torchjd/autogram/_module_hook_manager.py | 23 ++++++----- 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index 4f0bc822..aeb909fb 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -5,7 +5,7 @@ import torch from torch import Tensor, nn from torch.nn import Parameter -from torch.utils._pytree import PyTree, tree_flatten, tree_map_only +from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see @@ -37,7 +37,11 @@ def __init__(self, module: nn.Module): @abstractmethod def __call__( - self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree] + self, + grad_outputs: tuple[Tensor, ...], + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + rg_outputs: Sequence[Tensor], ) -> Tensor: """ Computes and returns the Jacobian. The output must be a matrix (2D Tensor). @@ -52,14 +56,20 @@ class FunctionalJacobianComputer(JacobianComputer): function. 
""" - def __init__(self, module: nn.Module, in_dims: tuple[PyTree, ...]): - super().__init__(module) - self.vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims) - def __call__( - self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree] + self, + grad_outputs: tuple[Tensor, ...], + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + rg_outputs: Sequence[Tensor], ) -> Tensor: - return self.vmapped_vjp(grad_outputs, args, kwargs) + rg_output_in_dims = (0,) * len(rg_outputs) + arg_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) + kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs) + in_dims = (rg_output_in_dims, arg_in_dims, kwargs_in_dims) + vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims) + + return vmapped_vjp(grad_outputs, args, kwargs) def _call_on_one_instance( self, @@ -105,18 +115,17 @@ class AutogradJacobianComputer(JacobianComputer): require making an extra forward pass. 
""" - def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]): - super().__init__(module) - - self.rg_outputs = rg_outputs - self.flat_rg_params, _ = tree_flatten(self.rg_params) - def __call__( - self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...], __: dict[str, PyTree] + self, + grad_outputs: tuple[Tensor, ...], + _: tuple[PyTree, ...], + __: dict[str, PyTree], + rg_outputs: Sequence[Tensor], ) -> Tensor: + flat_rg_params, _ = tree_flatten(self.rg_params) grads = torch.autograd.grad( - self.rg_outputs, - self.flat_rg_params, + rg_outputs, + flat_rg_params, grad_outputs, retain_graph=True, allow_unused=True, diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index d5bca59d..ff4c8f10 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -1,10 +1,11 @@ import weakref +from collections.abc import Sequence from typing import cast import torch from torch import Tensor, nn from torch.autograd.graph import get_gradient_edge -from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_unflatten +from torch.utils._pytree import PyTree, tree_flatten, tree_unflatten from torch.utils.hooks import RemovableHandle as TorchRemovableHandle from ._edge_registry import EdgeRegistry @@ -140,13 +141,9 @@ def __call__( jacobian_computer: JacobianComputer if self.has_batch_dim: - rg_output_in_dims = (0,) * len(rg_outputs) - arg_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) - kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs) - in_dims = (rg_output_in_dims, arg_in_dims, kwargs_in_dims) - jacobian_computer = FunctionalJacobianComputer(module, in_dims) + jacobian_computer = FunctionalJacobianComputer(module) else: - jacobian_computer = AutogradJacobianComputer(module, rg_outputs) + jacobian_computer = AutogradJacobianComputer(module) autograd_fn_rg_outputs = JacobianAccumulator.apply( 
self.gramian_accumulation_phase, @@ -201,6 +198,7 @@ def setup_context( ctx.kwargs = inputs[3] ctx.gramian_accumulator = inputs[4] ctx.module = inputs[5] + ctx.rg_outputs = inputs[6:] @staticmethod def backward(ctx, *grad_outputs: Tensor) -> tuple: @@ -213,6 +211,7 @@ def backward(ctx, *grad_outputs: Tensor) -> tuple: ctx.jacobian_computer, ctx.args, ctx.kwargs, + ctx.rg_outputs, ctx.module, *grad_outputs, ) @@ -228,28 +227,30 @@ def forward( jacobian_computer: JacobianComputer, args: tuple[PyTree, ...], kwargs: dict[str, PyTree], + rg_outputs: Sequence[Tensor], module: nn.Module, *grad_outputs: Tensor, ) -> Tensor: # There is no non-batched dimension - jacobian = jacobian_computer(grad_outputs, args, kwargs) + jacobian = jacobian_computer(grad_outputs, args, kwargs, rg_outputs) return jacobian @staticmethod def vmap( _, - in_dims: tuple, # tuple[None, tuple[PyTree, ...], dict[str, PyTree], None, *tuple[int | None, ...]] + in_dims: tuple, # tuple[None, tuple[PyTree, ...], dict[str, PyTree], Sequence[int], None, *tuple[int | None, ...]] jacobian_computer: JacobianComputer, args: tuple[PyTree, ...], kwargs: dict[str, PyTree], + rg_outputs: Sequence[Tensor], module: nn.Module, *jac_outputs: Tensor, ) -> tuple[Tensor, None]: # There is a non-batched dimension # We do not vmap over the args for the non-batched dimension - in_dims = (in_dims[4:], None, None) + in_dims = (in_dims[5:], None, None, None) generalized_jacobian = torch.vmap(jacobian_computer, in_dims=in_dims)( - jac_outputs, args, kwargs + jac_outputs, args, kwargs, rg_outputs ) shape = generalized_jacobian.shape jacobian = generalized_jacobian.reshape([shape[0] * shape[1], -1]) From ea12581c9610e8062f8e6314836973ece64cd50f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Oct 2025 18:52:27 +0200 Subject: [PATCH 17/32] Move JacobianComputer construction to Engine * This will allow us to keep a reference to them in the engine and to reset them as needed when we add a state for them 
--- src/torchjd/autogram/_engine.py | 21 +++++++++++++---- src/torchjd/autogram/_module_hook_manager.py | 24 +++++--------------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 48888562..dbd84cf6 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -7,6 +7,11 @@ from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator from ._gramian_utils import movedim_gramian, reshape_gramian +from ._jacobian_computer import ( + AutogradJacobianComputer, + FunctionalJacobianComputer, + JacobianComputer, +) from ._module_hook_manager import ModuleHookManager _MODULES_INCOMPATIBLE_WITH_BATCHED = ( @@ -179,9 +184,7 @@ def __init__( self._gramian_accumulator = GramianAccumulator() self._target_edges = EdgeRegistry() self._batch_dim = batch_dim - self._module_hook_manager = ModuleHookManager( - self._target_edges, self._gramian_accumulator, batch_dim is not None - ) + self._module_hook_manager = ModuleHookManager(self._target_edges, self._gramian_accumulator) for module in modules: self._hook_module_recursively(module) @@ -189,11 +192,21 @@ def __init__( def _hook_module_recursively(self, module: nn.Module) -> None: if any(p.requires_grad for p in module.parameters(recurse=False)): self._check_module_is_compatible(module) - self._module_hook_manager.hook_module(module) + jacobian_computer = self.make_jacobian_computer(module) + self._module_hook_manager.hook_module(module, jacobian_computer) else: for child in module.children(): self._hook_module_recursively(child) + def make_jacobian_computer(self, module: nn.Module) -> JacobianComputer: + jacobian_computer: JacobianComputer + if self._batch_dim is not None: + jacobian_computer = FunctionalJacobianComputer(module) + else: + jacobian_computer = AutogradJacobianComputer(module) + + return jacobian_computer + def _check_module_is_compatible(self, module: nn.Module) -> None: if 
self._batch_dim is not None: if isinstance(module, _MODULES_INCOMPATIBLE_WITH_BATCHED): diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index ff4c8f10..a5fb48c4 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -10,11 +10,7 @@ from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator -from ._jacobian_computer import ( - AutogradJacobianComputer, - FunctionalJacobianComputer, - JacobianComputer, -) +from ._jacobian_computer import JacobianComputer # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see @@ -38,11 +34,9 @@ def __init__( self, target_edges: EdgeRegistry, gramian_accumulator: GramianAccumulator, - has_batch_dim: bool, ): self._target_edges = target_edges self._gramian_accumulator = gramian_accumulator - self._has_batch_dim = has_batch_dim self.gramian_accumulation_phase = BoolRef(False) self._handles: list[TorchRemovableHandle] = [] @@ -56,7 +50,7 @@ def __init__( # seems to be a better practice (and it only works if the function to call is static). self._finalizer = weakref.finalize(self, ModuleHookManager.remove_hooks, self._handles) - def hook_module(self, module: nn.Module) -> None: + def hook_module(self, module: nn.Module, jacobian_computer: JacobianComputer) -> None: """ Add a module hook used to insert Jacobian accumulation nodes into the backward graph. 
@@ -68,7 +62,7 @@ def hook_module(self, module: nn.Module) -> None: self.gramian_accumulation_phase, self._target_edges, self._gramian_accumulator, - self._has_batch_dim, + jacobian_computer, ) self._handles.append(module.register_forward_hook(hook, with_kwargs=True)) @@ -99,12 +93,12 @@ def __init__( gramian_accumulation_phase: BoolRef, target_edges: EdgeRegistry, gramian_accumulator: GramianAccumulator, - has_batch_dim: bool, + jacobian_computer: JacobianComputer, ): self.gramian_accumulation_phase = gramian_accumulation_phase self.target_edges = target_edges self.gramian_accumulator = gramian_accumulator - self.has_batch_dim = has_batch_dim + self.jacobian_computer = jacobian_computer def __call__( self, @@ -139,15 +133,9 @@ def __call__( index = cast(int, preference.argmin().item()) self.target_edges.register(get_gradient_edge(rg_outputs[index])) - jacobian_computer: JacobianComputer - if self.has_batch_dim: - jacobian_computer = FunctionalJacobianComputer(module) - else: - jacobian_computer = AutogradJacobianComputer(module) - autograd_fn_rg_outputs = JacobianAccumulator.apply( self.gramian_accumulation_phase, - jacobian_computer, + self.jacobian_computer, args, kwargs, self.gramian_accumulator, From ab280966880eb14eaf657a6bc0b9a7afcf305d4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Oct 2025 19:06:44 +0200 Subject: [PATCH 18/32] Add GramianComputer and subclasses --- src/torchjd/autogram/_gramian_computer.py | 96 +++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 src/torchjd/autogram/_gramian_computer.py diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py new file mode 100644 index 00000000..f64225db --- /dev/null +++ b/src/torchjd/autogram/_gramian_computer.py @@ -0,0 +1,96 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import Optional + +from torch import Tensor +from torch.utils._pytree import PyTree + +from 
torchjd.autogram._jacobian_computer import JacobianComputer + + +class GramianComputer(ABC): + @abstractmethod + def __call__( + self, + grad_outputs: tuple[Tensor, ...], + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + rg_outputs: Sequence[Tensor], + ) -> Optional[Tensor]: + """Compute what we can for a module and optionally return the gramian if it's ready.""" + + def track_forward_call(self) -> None: + """Track that the module's forward was called. Necessary in some implementations.""" + + def reset(self): + """Reset state if any. Necessary in some implementations.""" + + +class JacobianBasedGramianComputer(GramianComputer, ABC): + def __init__(self, jacobian_computer): + self.jacobian_computer = jacobian_computer + + @staticmethod + def _to_gramian(jacobian: Tensor) -> Tensor: + return jacobian @ jacobian.T + + +class JacobianBasedGramianComputerWithoutCrossTerms(JacobianBasedGramianComputer): + """ + Stateful GramianComputer that waits for all usages to be counted before returning the gramian. + """ + + def __call__( + self, + grad_outputs: tuple[Tensor, ...], + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + rg_outputs: Sequence[Tensor], + ) -> Tensor: + """Compute what we can for a module and optionally return the gramian if it's ready.""" + + jacobian_matrix = self.jacobian_computer(grad_outputs, args, kwargs, rg_outputs) + return self._to_gramian(jacobian_matrix) + + +class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer): + """ + Stateful GramianComputer that waits for all usages to be counted before returning the gramian. 
+ """ + + def __init__(self, jacobian_computer: JacobianComputer): + super().__init__(jacobian_computer) + self.remaining_counter = 0 + self.summed_jacobian = None + + def reset(self) -> None: + self.remaining_counter = 0 + self.summed_jacobian = None + + def track_forward_call(self) -> None: + self.remaining_counter += 1 + + def __call__( + self, + grad_outputs: tuple[Tensor, ...], + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + rg_outputs: Sequence[Tensor], + ) -> Optional[Tensor]: + """Compute what we can for a module and optionally return the gramian if it's ready.""" + + jacobian_matrix = self.jacobian_computer(grad_outputs, args, kwargs, rg_outputs) + + if self.summed_jacobian is None: + self.summed_jacobian = jacobian_matrix + else: + self.summed_jacobian += jacobian_matrix + + self.remaining_counter -= 1 + + if self.remaining_counter == 0: + gramian = self.summed_jacobian @ self.summed_jacobian.T + del self.summed_jacobian + return gramian + else: + return None From 47551fe83b0594e98a98b8e20500e300eaae4c53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Oct 2025 19:33:00 +0200 Subject: [PATCH 19/32] Use GramianComputer instead of JacobianComputer --- src/torchjd/autogram/_engine.py | 14 ++- src/torchjd/autogram/_gramian_accumulator.py | 44 +--------- src/torchjd/autogram/_gramian_computer.py | 56 +++++++++++- src/torchjd/autogram/_module_hook_manager.py | 85 ++++--------------- .../unit/autogram/test_gramian_accumulator.py | 72 ---------------- 5 files changed, 84 insertions(+), 187 deletions(-) delete mode 100644 tests/unit/autogram/test_gramian_accumulator.py diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index dbd84cf6..c9657dac 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -6,6 +6,7 @@ from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator +from ._gramian_computer import GramianComputer, 
JacobianBasedGramianComputerWithCrossTerms from ._gramian_utils import movedim_gramian, reshape_gramian from ._jacobian_computer import ( AutogradJacobianComputer, @@ -185,6 +186,7 @@ def __init__( self._target_edges = EdgeRegistry() self._batch_dim = batch_dim self._module_hook_manager = ModuleHookManager(self._target_edges, self._gramian_accumulator) + self._gramian_computers = dict[nn.Module, GramianComputer]() for module in modules: self._hook_module_recursively(module) @@ -192,20 +194,22 @@ def __init__( def _hook_module_recursively(self, module: nn.Module) -> None: if any(p.requires_grad for p in module.parameters(recurse=False)): self._check_module_is_compatible(module) - jacobian_computer = self.make_jacobian_computer(module) - self._module_hook_manager.hook_module(module, jacobian_computer) + gramian_computer = self.make_gramian_computer(module) + self._gramian_computers[module] = gramian_computer + self._module_hook_manager.hook_module(module, gramian_computer) else: for child in module.children(): self._hook_module_recursively(child) - def make_jacobian_computer(self, module: nn.Module) -> JacobianComputer: + def make_gramian_computer(self, module: nn.Module) -> GramianComputer: jacobian_computer: JacobianComputer if self._batch_dim is not None: jacobian_computer = FunctionalJacobianComputer(module) else: jacobian_computer = AutogradJacobianComputer(module) + gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer) - return jacobian_computer + return gramian_computer def _check_module_is_compatible(self, module: nn.Module) -> None: if self._batch_dim is not None: @@ -289,6 +293,8 @@ def compute_gramian(self, output: Tensor) -> Tensor: self._module_hook_manager.gramian_accumulation_phase.value = False self._gramian_accumulator.reset() self._target_edges.reset() + for gramian_computer in self._gramian_computers.values(): + gramian_computer.reset() unordered_gramian = reshape_gramian(square_gramian, ordered_shape) diff --git 
a/src/torchjd/autogram/_gramian_accumulator.py b/src/torchjd/autogram/_gramian_accumulator.py index 755b53ad..2c9405bf 100644 --- a/src/torchjd/autogram/_gramian_accumulator.py +++ b/src/torchjd/autogram/_gramian_accumulator.py @@ -1,8 +1,6 @@ -from collections import Counter from typing import Optional -import torch -from torch import Tensor, nn +from torch import Tensor class GramianAccumulator: @@ -16,49 +14,15 @@ class GramianAccumulator: def __init__(self) -> None: self._gramian: Optional[Tensor] = None - self._summed_jacobians = dict[nn.Module, Tensor]() - self._path_counter = Counter[nn.Module]() def reset(self) -> None: self._gramian = None - self._summed_jacobians = {} - self._path_counter = Counter() - def track_module_paths(self, module: nn.Module) -> None: - """Increment the usage count of the provided module. - - :param module: The module. - """ - - self._path_counter.update([module]) - - def accumulate_path_jacobian(self, module: nn.Module, jacobian_matrix: Tensor) -> None: - """ - Add the Jacobian corresponding to a call to a module. - - :param module: The module. - :param jacobian_matrix: Jacobian tensors of a single path. - """ - if module in self._summed_jacobians: - self._summed_jacobians[module] += jacobian_matrix - else: - self._summed_jacobians[module] = jacobian_matrix - self._path_counter.subtract([module]) - if self._path_counter[module] == 0: - self._accumulate_one_jacobian_in_gramian(self._summed_jacobians[module]) - del self._path_counter[module] - del self._summed_jacobians[module] - - def _accumulate_one_jacobian_in_gramian(self, jacobian_matrix: Tensor) -> None: - """ - Compute the Gramian of a Jacobian and accumulate it. - - :param jacobian_matrix: the Jacobian. 
- """ + def accumulate_gramian(self, gramian: Tensor) -> None: if self._gramian is not None: - self._gramian.addmm_(jacobian_matrix, jacobian_matrix.T) + self._gramian.add_(gramian) else: - self._gramian = torch.mm(jacobian_matrix, jacobian_matrix.T) + self._gramian = gramian @property def gramian(self) -> Optional[Tensor]: diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index f64225db..d4666cd9 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -2,6 +2,7 @@ from collections.abc import Sequence from typing import Optional +import torch from torch import Tensor from torch.utils._pytree import PyTree @@ -30,11 +31,61 @@ class JacobianBasedGramianComputer(GramianComputer, ABC): def __init__(self, jacobian_computer): self.jacobian_computer = jacobian_computer + def compute_jacobian( + self, + grad_outputs: tuple[Tensor, ...], + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + rg_outputs: Sequence[Tensor], + ) -> Tensor: + return ComputeModuleJacobians.apply( + self.jacobian_computer, args, kwargs, rg_outputs, *grad_outputs + ) + @staticmethod def _to_gramian(jacobian: Tensor) -> Tensor: return jacobian @ jacobian.T +class ComputeModuleJacobians(torch.autograd.Function): + @staticmethod + def forward( + jacobian_computer: JacobianComputer, + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + rg_outputs: Sequence[Tensor], + *grad_outputs: Tensor, + ) -> Tensor: + # There is no non-batched dimension + jacobian = jacobian_computer(grad_outputs, args, kwargs, rg_outputs) + return jacobian + + @staticmethod + def vmap( + _, + in_dims: tuple, + # tuple[None, tuple[PyTree, ...], dict[str, PyTree], Sequence[int], *tuple[int | None, ...]] + jacobian_computer: JacobianComputer, + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + rg_outputs: Sequence[Tensor], + *jac_outputs: Tensor, + ) -> tuple[Tensor, None]: + # There is a non-batched dimension + # We do 
not vmap over the args for the non-batched dimension + in_dims = (in_dims[4:], None, None, None) + generalized_jacobian = torch.vmap(jacobian_computer, in_dims=in_dims)( + jac_outputs, args, kwargs, rg_outputs + ) + shape = generalized_jacobian.shape + jacobian = generalized_jacobian.reshape([shape[0] * shape[1], -1]) + return jacobian, None + + @staticmethod + def setup_context(*_) -> None: + pass + + class JacobianBasedGramianComputerWithoutCrossTerms(JacobianBasedGramianComputer): """ Stateful GramianComputer that waits for all usages to be counted before returning the gramian. @@ -49,8 +100,7 @@ def __call__( ) -> Tensor: """Compute what we can for a module and optionally return the gramian if it's ready.""" - jacobian_matrix = self.jacobian_computer(grad_outputs, args, kwargs, rg_outputs) - return self._to_gramian(jacobian_matrix) + return self._to_gramian(self.compute_jacobian(grad_outputs, args, kwargs, rg_outputs)) class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer): @@ -79,7 +129,7 @@ def __call__( ) -> Optional[Tensor]: """Compute what we can for a module and optionally return the gramian if it's ready.""" - jacobian_matrix = self.jacobian_computer(grad_outputs, args, kwargs, rg_outputs) + jacobian_matrix = self.compute_jacobian(grad_outputs, args, kwargs, rg_outputs) if self.summed_jacobian is None: self.summed_jacobian = jacobian_matrix diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index a5fb48c4..29285646 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -1,5 +1,4 @@ import weakref -from collections.abc import Sequence from typing import cast import torch @@ -10,7 +9,7 @@ from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator -from ._jacobian_computer import JacobianComputer +from ._gramian_computer import GramianComputer # Note about import from protected _pytree 
module: # PyTorch maintainers plan to make pytree public (see @@ -50,7 +49,7 @@ def __init__( # seems to be a better practice (and it only works if the function to call is static). self._finalizer = weakref.finalize(self, ModuleHookManager.remove_hooks, self._handles) - def hook_module(self, module: nn.Module, jacobian_computer: JacobianComputer) -> None: + def hook_module(self, module: nn.Module, gramian_computer: GramianComputer) -> None: """ Add a module hook used to insert Jacobian accumulation nodes into the backward graph. @@ -62,7 +61,7 @@ def hook_module(self, module: nn.Module, jacobian_computer: JacobianComputer) -> self.gramian_accumulation_phase, self._target_edges, self._gramian_accumulator, - jacobian_computer, + gramian_computer, ) self._handles.append(module.register_forward_hook(hook, with_kwargs=True)) @@ -93,12 +92,12 @@ def __init__( gramian_accumulation_phase: BoolRef, target_edges: EdgeRegistry, gramian_accumulator: GramianAccumulator, - jacobian_computer: JacobianComputer, + gramian_computer: GramianComputer, ): self.gramian_accumulation_phase = gramian_accumulation_phase self.target_edges = target_edges self.gramian_accumulator = gramian_accumulator - self.jacobian_computer = jacobian_computer + self.gramian_computer = gramian_computer def __call__( self, @@ -124,7 +123,7 @@ def __call__( # require grad return outputs - self.gramian_accumulator.track_module_paths(module) + self.gramian_computer.track_forward_call() # We only care about running the JacobianAccumulator node, so we need one of its child # edges (the edges of the original outputs of the model) as target. 
For memory @@ -135,11 +134,10 @@ def __call__( autograd_fn_rg_outputs = JacobianAccumulator.apply( self.gramian_accumulation_phase, - self.jacobian_computer, + self.gramian_computer, args, kwargs, self.gramian_accumulator, - module, *rg_outputs, ) @@ -163,17 +161,16 @@ class JacobianAccumulator(torch.autograd.Function): @staticmethod def forward( gramian_accumulation_phase: BoolRef, - jacobian_computer: JacobianComputer, + gramian_computer: GramianComputer, args: tuple[PyTree, ...], kwargs: dict[str, PyTree], gramian_accumulator: GramianAccumulator, - module: nn.Module, *rg_tensors: Tensor, ) -> tuple[Tensor, ...]: return tuple(t.detach() for t in rg_tensors) # For Python version > 3.10, the type of `inputs` should become - # tuple[BoolRef, JacobianComputer, tuple[PyTree, ...], dict[str, PyTree], GramianAccumulator, nn.Module, *tuple[Tensor, ...]] + # tuple[BoolRef, GramianComputer, tuple[PyTree, ...], dict[str, PyTree], GramianAccumulator, nn.Module, *tuple[Tensor, ...]] @staticmethod def setup_context( ctx, @@ -181,69 +178,21 @@ def setup_context( _, ): ctx.gramian_accumulation_phase = inputs[0] - ctx.jacobian_computer = inputs[1] + ctx.gramian_computer = inputs[1] ctx.args = inputs[2] ctx.kwargs = inputs[3] ctx.gramian_accumulator = inputs[4] - ctx.module = inputs[5] - ctx.rg_outputs = inputs[6:] + ctx.rg_outputs = inputs[5:] @staticmethod def backward(ctx, *grad_outputs: Tensor) -> tuple: - # For python > 3.10: -> tuple[None, None, None, None, None, None, *tuple[Tensor, ...]] + # For python > 3.10: -> tuple[None, None, None, None, None, *tuple[Tensor, ...]] if not ctx.gramian_accumulation_phase: - return None, None, None, None, None, None, *grad_outputs - - path_jacobian = ComputeModuleJacobians.apply( - ctx.jacobian_computer, - ctx.args, - ctx.kwargs, - ctx.rg_outputs, - ctx.module, - *grad_outputs, - ) - ctx.gramian_accumulator.accumulate_path_jacobian(ctx.module, path_jacobian) - - return None, None, None, None, None, None, *grad_outputs - - -class 
ComputeModuleJacobians(torch.autograd.Function): - - @staticmethod - def forward( - jacobian_computer: JacobianComputer, - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], - rg_outputs: Sequence[Tensor], - module: nn.Module, - *grad_outputs: Tensor, - ) -> Tensor: - # There is no non-batched dimension - jacobian = jacobian_computer(grad_outputs, args, kwargs, rg_outputs) - return jacobian + return None, None, None, None, None, *grad_outputs - @staticmethod - def vmap( - _, - in_dims: tuple, # tuple[None, tuple[PyTree, ...], dict[str, PyTree], Sequence[int], None, *tuple[int | None, ...]] - jacobian_computer: JacobianComputer, - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], - rg_outputs: Sequence[Tensor], - module: nn.Module, - *jac_outputs: Tensor, - ) -> tuple[Tensor, None]: - # There is a non-batched dimension - # We do not vmap over the args for the non-batched dimension - in_dims = (in_dims[5:], None, None, None) - generalized_jacobian = torch.vmap(jacobian_computer, in_dims=in_dims)( - jac_outputs, args, kwargs, rg_outputs - ) - shape = generalized_jacobian.shape - jacobian = generalized_jacobian.reshape([shape[0] * shape[1], -1]) - return jacobian, None + optional_gramian = ctx.gramian_computer(grad_outputs, ctx.args, ctx.kwargs, ctx.rg_outputs) + if optional_gramian is not None: + ctx.gramian_accumulator.accumulate_gramian(optional_gramian) - @staticmethod - def setup_context(*_) -> None: - pass + return None, None, None, None, None, *grad_outputs diff --git a/tests/unit/autogram/test_gramian_accumulator.py b/tests/unit/autogram/test_gramian_accumulator.py deleted file mode 100644 index ef2f2072..00000000 --- a/tests/unit/autogram/test_gramian_accumulator.py +++ /dev/null @@ -1,72 +0,0 @@ -from pytest import mark -from torch import nn -from torch.testing import assert_close -from utils.tensors import randn_, zeros_ - -from torchjd.autogram._gramian_accumulator import GramianAccumulator - - -class FakeModule(nn.Module): - pass - - 
-@mark.parametrize( - ["sizes", "number_of_jacobians"], - [ - ([4, 7], [3, 7]), - ([3, 8, 4], [0, 7, 1]), - ], -) -def test_adding_jacobians_one_by_one(sizes: list[int], number_of_jacobians: list[int]): - batch_size = 10 - gramian_accumulator = GramianAccumulator() - - keys = [FakeModule() for _ in sizes] - for key, n in zip(keys, number_of_jacobians): - for _ in range(n): - gramian_accumulator.track_module_paths(key) - - expected_gramian = zeros_([batch_size, batch_size]) - - for key, size, n in zip(keys, sizes, number_of_jacobians): - batched_shape = [batch_size, size] - cumulated_jacobian = zeros_(batched_shape) - for i in range(n): - jacobian = randn_(batched_shape) - gramian_accumulator.accumulate_path_jacobian(key, jacobian) - cumulated_jacobian += jacobian - jacobian_matrix = cumulated_jacobian.reshape([batch_size, -1]) - expected_gramian.addmm_(jacobian_matrix, jacobian_matrix.T) - - gramian = gramian_accumulator.gramian - assert_close(gramian, expected_gramian, rtol=5e-06, atol=2e-05) - - -def test_returns_none_if_no_jacobian_were_provided(): - gramian_accumulator = GramianAccumulator() - assert gramian_accumulator.gramian is None - - -@mark.parametrize( - ["sizes", "number_of_jacobians"], - [ - ([5, 7], [3, 7]), - ([3, 8, 4], [0, 7, 1]), - ], -) -def test_internal_dicts_are_cleaned(sizes: list[int], number_of_jacobians: list[int]): - batch_size = 10 - gramian_accumulator = GramianAccumulator() - - keys = [FakeModule() for shape in sizes] - for key, n in zip(keys, number_of_jacobians): - for _ in range(n): - gramian_accumulator.track_module_paths(key) - - for key, size, n in zip(keys, sizes, number_of_jacobians): - batched_shape = [batch_size, size] - for i in range(n): - jacobian = randn_(batched_shape) - gramian_accumulator.accumulate_path_jacobian(key, jacobian) - assert key not in gramian_accumulator._summed_jacobians.keys() - assert key not in gramian_accumulator._path_counter.keys() From 9e102e75332e64c0ad1c068dac23ca00088889f8 Mon Sep 17 00:00:00 
2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Oct 2025 19:40:45 +0200 Subject: [PATCH 20/32] Simplify JacobianAccumulator --- src/torchjd/autogram/_module_hook_manager.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 29285646..f6ff622f 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -188,11 +188,11 @@ def setup_context( def backward(ctx, *grad_outputs: Tensor) -> tuple: # For python > 3.10: -> tuple[None, None, None, None, None, *tuple[Tensor, ...]] - if not ctx.gramian_accumulation_phase: - return None, None, None, None, None, *grad_outputs - - optional_gramian = ctx.gramian_computer(grad_outputs, ctx.args, ctx.kwargs, ctx.rg_outputs) - if optional_gramian is not None: - ctx.gramian_accumulator.accumulate_gramian(optional_gramian) + if ctx.gramian_accumulation_phase: + optional_gramian = ctx.gramian_computer( + grad_outputs, ctx.args, ctx.kwargs, ctx.rg_outputs + ) + if optional_gramian is not None: + ctx.gramian_accumulator.accumulate_gramian(optional_gramian) return None, None, None, None, None, *grad_outputs From be6e41dc22453408b6854870463288467270328a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Oct 2025 19:47:41 +0200 Subject: [PATCH 21/32] Move vmap handling to JacobianComputer --- src/torchjd/autogram/_gramian_computer.py | 55 +------------------- src/torchjd/autogram/_jacobian_computer.py | 58 ++++++++++++++++++++-- 2 files changed, 56 insertions(+), 57 deletions(-) diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index d4666cd9..21c98e19 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -2,7 +2,6 @@ from collections.abc import Sequence from typing import Optional -import torch from torch import Tensor from 
torch.utils._pytree import PyTree @@ -31,61 +30,11 @@ class JacobianBasedGramianComputer(GramianComputer, ABC): def __init__(self, jacobian_computer): self.jacobian_computer = jacobian_computer - def compute_jacobian( - self, - grad_outputs: tuple[Tensor, ...], - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], - rg_outputs: Sequence[Tensor], - ) -> Tensor: - return ComputeModuleJacobians.apply( - self.jacobian_computer, args, kwargs, rg_outputs, *grad_outputs - ) - @staticmethod def _to_gramian(jacobian: Tensor) -> Tensor: return jacobian @ jacobian.T -class ComputeModuleJacobians(torch.autograd.Function): - @staticmethod - def forward( - jacobian_computer: JacobianComputer, - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], - rg_outputs: Sequence[Tensor], - *grad_outputs: Tensor, - ) -> Tensor: - # There is no non-batched dimension - jacobian = jacobian_computer(grad_outputs, args, kwargs, rg_outputs) - return jacobian - - @staticmethod - def vmap( - _, - in_dims: tuple, - # tuple[None, tuple[PyTree, ...], dict[str, PyTree], Sequence[int], *tuple[int | None, ...]] - jacobian_computer: JacobianComputer, - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], - rg_outputs: Sequence[Tensor], - *jac_outputs: Tensor, - ) -> tuple[Tensor, None]: - # There is a non-batched dimension - # We do not vmap over the args for the non-batched dimension - in_dims = (in_dims[4:], None, None, None) - generalized_jacobian = torch.vmap(jacobian_computer, in_dims=in_dims)( - jac_outputs, args, kwargs, rg_outputs - ) - shape = generalized_jacobian.shape - jacobian = generalized_jacobian.reshape([shape[0] * shape[1], -1]) - return jacobian, None - - @staticmethod - def setup_context(*_) -> None: - pass - - class JacobianBasedGramianComputerWithoutCrossTerms(JacobianBasedGramianComputer): """ Stateful GramianComputer that waits for all usages to be counted before returning the gramian. 
@@ -100,7 +49,7 @@ def __call__( ) -> Tensor: """Compute what we can for a module and optionally return the gramian if it's ready.""" - return self._to_gramian(self.compute_jacobian(grad_outputs, args, kwargs, rg_outputs)) + return self._to_gramian(self.jacobian_computer(grad_outputs, args, kwargs, rg_outputs)) class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer): @@ -129,7 +78,7 @@ def __call__( ) -> Optional[Tensor]: """Compute what we can for a module and optionally return the gramian if it's ready.""" - jacobian_matrix = self.compute_jacobian(grad_outputs, args, kwargs, rg_outputs) + jacobian_matrix = self.jacobian_computer(grad_outputs, args, kwargs, rg_outputs) if self.summed_jacobian is None: self.summed_jacobian = jacobian_matrix diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index aeb909fb..b17f0a2a 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Callable, Sequence from typing import cast import torch @@ -35,13 +35,24 @@ def __init__(self, module: nn.Module): else: self.frozen_params[name] = param - @abstractmethod def __call__( self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree], rg_outputs: Sequence[Tensor], + ) -> Tensor: + return ComputeModuleJacobians.apply( + self._compute_jacobian, args, kwargs, rg_outputs, *grad_outputs + ) + + @abstractmethod + def _compute_jacobian( + self, + grad_outputs: tuple[Tensor, ...], + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + rg_outputs: Sequence[Tensor], ) -> Tensor: """ Computes and returns the Jacobian. The output must be a matrix (2D Tensor). @@ -56,7 +67,7 @@ class FunctionalJacobianComputer(JacobianComputer): function. 
""" - def __call__( + def _compute_jacobian( self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], @@ -115,7 +126,7 @@ class AutogradJacobianComputer(JacobianComputer): require making an extra forward pass. """ - def __call__( + def _compute_jacobian( self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...], @@ -134,3 +145,42 @@ def __call__( flattened_grads = torch.cat([g.reshape(-1) for g in grads]) jacobian = flattened_grads.unsqueeze(0) return jacobian + + +class ComputeModuleJacobians(torch.autograd.Function): + @staticmethod + def forward( + compute_jacobian_fn: Callable, + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + rg_outputs: Sequence[Tensor], + *grad_outputs: Tensor, + ) -> Tensor: + # There is no non-batched dimension + jacobian = compute_jacobian_fn(grad_outputs, args, kwargs, rg_outputs) + return jacobian + + @staticmethod + def vmap( + _, + in_dims: tuple, + # tuple[None, tuple[PyTree, ...], dict[str, PyTree], Sequence[int], *tuple[int | None, ...]] + compute_jacobian_fn: Callable, + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + rg_outputs: Sequence[Tensor], + *jac_outputs: Tensor, + ) -> tuple[Tensor, None]: + # There is a non-batched dimension + # We do not vmap over the args for the non-batched dimension + in_dims = (in_dims[4:], None, None, None) + generalized_jacobian = torch.vmap(compute_jacobian_fn, in_dims=in_dims)( + jac_outputs, args, kwargs, rg_outputs + ) + shape = generalized_jacobian.shape + jacobian = generalized_jacobian.reshape([shape[0] * shape[1], -1]) + return jacobian, None + + @staticmethod + def setup_context(*_) -> None: + pass From 8967ce144af07a2addc82e98593ac0e10444753c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Oct 2025 19:49:16 +0200 Subject: [PATCH 22/32] Improve docstrings --- src/torchjd/autogram/_gramian_computer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_gramian_computer.py 
b/src/torchjd/autogram/_gramian_computer.py index 21c98e19..581c067e 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -37,7 +37,10 @@ def _to_gramian(jacobian: Tensor) -> Tensor: class JacobianBasedGramianComputerWithoutCrossTerms(JacobianBasedGramianComputer): """ - Stateful GramianComputer that waits for all usages to be counted before returning the gramian. + Stateless JacobianBasedGramianComputer that always returns the gramian when computing it. + + This has the effect of ignoring potential conflict from between the different usages of a same + module. """ def __call__( @@ -54,7 +57,8 @@ def __call__( class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer): """ - Stateful GramianComputer that waits for all usages to be counted before returning the gramian. + Stateful JacobianBasedGramianComputer that waits for all usages to be counted before returning + the gramian. """ def __init__(self, jacobian_computer: JacobianComputer): From b791485586427c24614aff5a570534660df3deac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Oct 2025 19:50:29 +0200 Subject: [PATCH 23/32] Use _to_gramian in JacobianBasedGramianComputerWithCrossTerms --- src/torchjd/autogram/_gramian_computer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index 581c067e..25718ba4 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -92,7 +92,7 @@ def __call__( self.remaining_counter -= 1 if self.remaining_counter == 0: - gramian = self.summed_jacobian @ self.summed_jacobian.T + gramian = self._to_gramian(self.summed_jacobian) del self.summed_jacobian return gramian else: From dd81b9434de749131b8206ffb0a6652b61948767 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Oct 2025 19:50:52 +0200 Subject: [PATCH 24/32] 
Remove JacobianBasedGramianComputerWithoutCrossTerms --- src/torchjd/autogram/_gramian_computer.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index 25718ba4..3d7c8270 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -35,26 +35,6 @@ def _to_gramian(jacobian: Tensor) -> Tensor: return jacobian @ jacobian.T -class JacobianBasedGramianComputerWithoutCrossTerms(JacobianBasedGramianComputer): - """ - Stateless JacobianBasedGramianComputer that always returns the gramian when computing it. - - This has the effect of ignoring potential conflict from between the different usages of a same - module. - """ - - def __call__( - self, - grad_outputs: tuple[Tensor, ...], - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], - rg_outputs: Sequence[Tensor], - ) -> Tensor: - """Compute what we can for a module and optionally return the gramian if it's ready.""" - - return self._to_gramian(self.jacobian_computer(grad_outputs, args, kwargs, rg_outputs)) - - class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer): """ Stateful JacobianBasedGramianComputer that waits for all usages to be counted before returning From f3c3961804404a105cdc399860be28c936c75a72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Oct 2025 20:09:06 +0200 Subject: [PATCH 25/32] Simplify how in_dims are computed in FunctionalJacobianComputer --- src/torchjd/autogram/_jacobian_computer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index b17f0a2a..a585b9c8 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -72,12 +72,12 @@ def _compute_jacobian( grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, 
PyTree], - rg_outputs: Sequence[Tensor], + _: Sequence[Tensor], ) -> Tensor: - rg_output_in_dims = (0,) * len(rg_outputs) - arg_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) + grad_outputs_in_dims = (0,) * len(grad_outputs) + args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs) - in_dims = (rg_output_in_dims, arg_in_dims, kwargs_in_dims) + in_dims = (grad_outputs_in_dims, args_in_dims, kwargs_in_dims) vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims) return vmapped_vjp(grad_outputs, args, kwargs) From cd2a8d8f45f3ac40ff142ea9384e1f61fa66fe66 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 14 Oct 2025 18:35:25 +0200 Subject: [PATCH 26/32] Fix mypy error --- src/torchjd/autogram/_gramian_computer.py | 2 +- src/torchjd/autogram/_jacobian_computer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index 3d7c8270..e4b609f4 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -44,7 +44,7 @@ class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer): def __init__(self, jacobian_computer: JacobianComputer): super().__init__(jacobian_computer) self.remaining_counter = 0 - self.summed_jacobian = None + self.summed_jacobian: Optional[Tensor] = None def reset(self) -> None: self.remaining_counter = 0 diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index a585b9c8..f074add5 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -133,7 +133,7 @@ def _compute_jacobian( __: dict[str, PyTree], rg_outputs: Sequence[Tensor], ) -> Tensor: - flat_rg_params, _ = tree_flatten(self.rg_params) + flat_rg_params = tree_flatten(self.rg_params)[0] grads 
= torch.autograd.grad( rg_outputs, flat_rg_params, From 0fcd8705af40e5d67ae2eba800546b516b221c25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 14 Oct 2025 19:14:18 +0200 Subject: [PATCH 27/32] Use ___ for variable name --- src/torchjd/autogram/_jacobian_computer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index f074add5..4994951e 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -133,7 +133,7 @@ def _compute_jacobian( __: dict[str, PyTree], rg_outputs: Sequence[Tensor], ) -> Tensor: - flat_rg_params = tree_flatten(self.rg_params)[0] + flat_rg_params, ___ = tree_flatten(self.rg_params) grads = torch.autograd.grad( rg_outputs, flat_rg_params, From ae356e4865f7a3e6160e6d998de6fe608ecf461f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 14 Oct 2025 19:17:22 +0200 Subject: [PATCH 28/32] Make _make_gramian_computer protected --- src/torchjd/autogram/_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index c9657dac..ecde7483 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -194,14 +194,14 @@ def __init__( def _hook_module_recursively(self, module: nn.Module) -> None: if any(p.requires_grad for p in module.parameters(recurse=False)): self._check_module_is_compatible(module) - gramian_computer = self.make_gramian_computer(module) + gramian_computer = self._make_gramian_computer(module) self._gramian_computers[module] = gramian_computer self._module_hook_manager.hook_module(module, gramian_computer) else: for child in module.children(): self._hook_module_recursively(child) - def make_gramian_computer(self, module: nn.Module) -> GramianComputer: + def _make_gramian_computer(self, module: nn.Module) -> GramianComputer: 
jacobian_computer: JacobianComputer if self._batch_dim is not None: jacobian_computer = FunctionalJacobianComputer(module) From 84a6dcf73f4cc01948a98e769172b86c804e7f3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 14 Oct 2025 19:30:22 +0200 Subject: [PATCH 29/32] Add comment --- src/torchjd/autogram/_jacobian_computer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index 4994951e..5a66d424 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -42,6 +42,7 @@ def __call__( kwargs: dict[str, PyTree], rg_outputs: Sequence[Tensor], ) -> Tensor: + # This makes __call__ vmappable. return ComputeModuleJacobians.apply( self._compute_jacobian, args, kwargs, rg_outputs, *grad_outputs ) From 22a463cfc5ee7a2977559517c9a9a44dd515de28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 14 Oct 2025 19:55:49 +0200 Subject: [PATCH 30/32] Improve type consistency --- src/torchjd/autogram/_gramian_computer.py | 7 ++- src/torchjd/autogram/_jacobian_computer.py | 45 +++++++++++--------- src/torchjd/autogram/_module_hook_manager.py | 5 ++- 3 files changed, 31 insertions(+), 26 deletions(-) diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index e4b609f4..2bc62f21 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from collections.abc import Sequence from typing import Optional from torch import Tensor @@ -12,10 +11,10 @@ class GramianComputer(ABC): @abstractmethod def __call__( self, + rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree], - rg_outputs: Sequence[Tensor], ) -> Optional[Tensor]: """Compute what we can for a module and optionally return the gramian if it's ready.""" 
@@ -55,14 +54,14 @@ def track_forward_call(self) -> None: def __call__( self, + rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree], - rg_outputs: Sequence[Tensor], ) -> Optional[Tensor]: """Compute what we can for a module and optionally return the gramian if it's ready.""" - jacobian_matrix = self.jacobian_computer(grad_outputs, args, kwargs, rg_outputs) + jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs) if self.summed_jacobian is None: self.summed_jacobian = jacobian_matrix diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index 5a66d424..aee08ed6 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from collections.abc import Callable, Sequence +from collections.abc import Callable from typing import cast import torch @@ -37,23 +37,23 @@ def __init__(self, module: nn.Module): def __call__( self, + rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree], - rg_outputs: Sequence[Tensor], ) -> Tensor: # This makes __call__ vmappable. return ComputeModuleJacobians.apply( - self._compute_jacobian, args, kwargs, rg_outputs, *grad_outputs + self._compute_jacobian, rg_outputs, grad_outputs, args, kwargs ) @abstractmethod def _compute_jacobian( self, + rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree], - rg_outputs: Sequence[Tensor], ) -> Tensor: """ Computes and returns the Jacobian. The output must be a matrix (2D Tensor). 
@@ -70,10 +70,10 @@ class FunctionalJacobianComputer(JacobianComputer): def _compute_jacobian( self, + _: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree], - _: Sequence[Tensor], ) -> Tensor: grad_outputs_in_dims = (0,) * len(grad_outputs) args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) @@ -96,9 +96,9 @@ def _call_on_one_instance( # different size. args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j) kwargs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), kwargs_j) - grad_outputs_j_ = [x.unsqueeze(0) for x in grad_outputs_j] + grad_outputs_j_ = tuple(x.unsqueeze(0) for x in grad_outputs_j) - def functional_model_call(rg_params: dict[str, Parameter]) -> list[Tensor]: + def functional_model_call(rg_params: dict[str, Parameter]) -> tuple[Tensor, ...]: all_state = [ cast(dict[str, Tensor], rg_params), dict(self.module.named_buffers()), @@ -106,7 +106,7 @@ def functional_model_call(rg_params: dict[str, Parameter]) -> list[Tensor]: ] output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j) flat_outputs = tree_flatten(output)[0] - rg_outputs = [t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad] + rg_outputs = tuple(t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad) return rg_outputs vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1] @@ -129,10 +129,10 @@ class AutogradJacobianComputer(JacobianComputer): def _compute_jacobian( self, + rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...], __: dict[str, PyTree], - rg_outputs: Sequence[Tensor], ) -> Tensor: flat_rg_params, ___ = tree_flatten(self.rg_params) grads = torch.autograd.grad( @@ -151,32 +151,35 @@ def _compute_jacobian( class ComputeModuleJacobians(torch.autograd.Function): @staticmethod def forward( - compute_jacobian_fn: Callable, + compute_jacobian_fn: Callable[ + [tuple[Tensor, ...], 
tuple[Tensor, ...], tuple[PyTree, ...], dict[str, PyTree]], Tensor + ], + rg_outputs: tuple[Tensor, ...], + grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree], - rg_outputs: Sequence[Tensor], - *grad_outputs: Tensor, ) -> Tensor: # There is no non-batched dimension - jacobian = compute_jacobian_fn(grad_outputs, args, kwargs, rg_outputs) + jacobian = compute_jacobian_fn(rg_outputs, grad_outputs, args, kwargs) return jacobian @staticmethod def vmap( _, - in_dims: tuple, - # tuple[None, tuple[PyTree, ...], dict[str, PyTree], Sequence[int], *tuple[int | None, ...]] + in_dims: tuple[None, None, tuple[int, ...], None, None], compute_jacobian_fn: Callable, + rg_outputs: tuple[Tensor, ...], + jac_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree], - rg_outputs: Sequence[Tensor], - *jac_outputs: Tensor, ) -> tuple[Tensor, None]: # There is a non-batched dimension - # We do not vmap over the args for the non-batched dimension - in_dims = (in_dims[4:], None, None, None) - generalized_jacobian = torch.vmap(compute_jacobian_fn, in_dims=in_dims)( - jac_outputs, args, kwargs, rg_outputs + # We do not vmap over the args, kwargs, or rg_outputs for the non-batched dimension + generalized_jacobian = torch.vmap(compute_jacobian_fn, in_dims=in_dims[1:])( + rg_outputs, + jac_outputs, + args, + kwargs, ) shape = generalized_jacobian.shape jacobian = generalized_jacobian.reshape([shape[0] * shape[1], -1]) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index f6ff622f..2b72d5f1 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -190,7 +190,10 @@ def backward(ctx, *grad_outputs: Tensor) -> tuple: if ctx.gramian_accumulation_phase: optional_gramian = ctx.gramian_computer( - grad_outputs, ctx.args, ctx.kwargs, ctx.rg_outputs + ctx.rg_outputs, + grad_outputs, + ctx.args, + ctx.kwargs, ) if optional_gramian is not 
None:
                 ctx.gramian_accumulator.accumulate_gramian(optional_gramian)

From 3e6641b69329e25a028a1d9602d793123fdcdd01 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Val=C3=A9rian=20Rey?=
Date: Tue, 14 Oct 2025 20:36:06 +0200
Subject: [PATCH 31/32] Simplify JacobianComputer docstrings

---
 src/torchjd/autogram/_jacobian_computer.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py
index aee08ed6..26452f5d 100644
--- a/src/torchjd/autogram/_jacobian_computer.py
+++ b/src/torchjd/autogram/_jacobian_computer.py
@@ -62,10 +62,9 @@ def _compute_jacobian(
 
 class FunctionalJacobianComputer(JacobianComputer):
     """
-    Represents a function that computes Jacobians for a module's forward pass with respect to its
-    parameters using the functional differentiation API. This requires to use vmap, so it's not
-    compatible with every module, and it requires to have an extra forward pass to create the vjp
-    function.
+    JacobianComputer using the functional differentiation API. This requires using vmap, so it's
+    not compatible with every module, and it requires an extra forward pass to create the
+    vjp function.
     """
 
     def _compute_jacobian(
@@ -121,10 +120,8 @@ def functional_model_call(rg_params: dict[str, Parameter]) -> tuple[Tensor, ...]
 
 class AutogradJacobianComputer(JacobianComputer):
     """
-    Represents a function that computes Jacobians for a module's forward pass with respect to its
-    parameters using the autograd engine. The __call__ function takes both the inputs and the
-    cotangents but ignores the inputs. The main advantage of using this method is that it doesn't
-    require making an extra forward pass.
+    JacobianComputer using the autograd engine. The main advantage of using this method is that it
+    doesn't require making an extra forward pass.
""" def _compute_jacobian( From 08baaf1763c2c495d74dfa64a3c3bab9b9fb6a21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 14 Oct 2025 20:39:31 +0200 Subject: [PATCH 32/32] Fix comment --- src/torchjd/autogram/_module_hook_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 2b72d5f1..7fc4b80c 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -170,7 +170,7 @@ def forward( return tuple(t.detach() for t in rg_tensors) # For Python version > 3.10, the type of `inputs` should become - # tuple[BoolRef, GramianComputer, tuple[PyTree, ...], dict[str, PyTree], GramianAccumulator, nn.Module, *tuple[Tensor, ...]] + # tuple[BoolRef, GramianComputer, tuple[PyTree, ...], dict[str, PyTree], GramianAccumulator, *tuple[Tensor, ...]] @staticmethod def setup_context( ctx,