From 1d92c6e292a3564f1da40cfe279e65f5ed5a7080 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sun, 12 Oct 2025 09:09:42 +0200 Subject: [PATCH 1/4] Make `AccumulateJacobian` return the jacobian w.r.t. the parameters of the module rather than providing them to the GramianAccumulator directly. Rename the class to `ComputeModuleJacobians` to better reflect its role. --- src/torchjd/autogram/_module_hook_manager.py | 23 +++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 058dd188..66a99ee3 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -206,33 +206,32 @@ def backward(ctx, *grad_outputs: Tensor) -> tuple: if not ctx.gramian_accumulation_phase: return None, None, None, None, None, None, *grad_outputs - AccumulateJacobian.apply( + path_jacobians = ComputeModuleJacobians.apply( ctx.vjp, ctx.args, ctx.kwargs, - ctx.gramian_accumulator, ctx.module, *grad_outputs, ) + ctx.gramian_accumulator.accumulate_path_jacobians(path_jacobians) return None, None, None, None, None, None, *grad_outputs -class AccumulateJacobian(torch.autograd.Function): +class ComputeModuleJacobians(torch.autograd.Function): @staticmethod def forward( vjp: VJP, args: tuple[PyTree, ...], kwargs: dict[str, PyTree], - gramian_accumulator: GramianAccumulator, module: nn.Module, *grad_outputs: Tensor, - ) -> None: + ) -> dict[Tensor, Tensor]: # There is no non-batched dimension generalized_jacobians = vjp(grad_outputs, args, kwargs) - path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians) - gramian_accumulator.accumulate_path_jacobians(path_jacobians) + path_jacobians = ComputeModuleJacobians._make_path_jacobians(module, generalized_jacobians) + return path_jacobians @staticmethod def vmap( @@ -241,17 +240,15 @@ def vmap( vjp: VJP, args: tuple[PyTree, ...], kwargs: dict[str, PyTree], - gramian_accumulator: GramianAccumulator, module: nn.Module, *jac_outputs: Tensor, - ) -> tuple[None, None]: + ) -> tuple[dict[Tensor, Tensor], dict[Tensor, None]]: # There is a non-batched dimension # We do not vmap over the args for the non-batched dimension - in_dims = (in_dims[5:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs)) + in_dims = (in_dims[4:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs)) generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args, kwargs) - path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians) - gramian_accumulator.accumulate_path_jacobians(path_jacobians) - return None, None + path_jacobians = ComputeModuleJacobians._make_path_jacobians(module, generalized_jacobians) + return path_jacobians, tree_map(lambda _: None, path_jacobians) @staticmethod def _make_path_jacobians( From c5c1d6029a90530a916f2a1bc446e1965fc962c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 12 Oct 2025 18:26:56 +0200 Subject: [PATCH 2/4] Fix type hint of in_dims --- src/torchjd/autogram/_module_hook_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 66a99ee3..0e21319d 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -236,7 +236,7 @@ def forward( @staticmethod def vmap( _, - in_dims: tuple, # tuple[None, tuple[PyTree, ...], dict[str, PyTree], None, None, *tuple[int | None, ...]] + in_dims: tuple, # tuple[None, tuple[PyTree, ...], dict[str, PyTree], None, *tuple[int | None, ...]] vjp: VJP, args: tuple[PyTree, ...], kwargs: dict[str, PyTree], From 782a480ca0abb4b6aab81bf6a396192dcdeda062 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 12 Oct 2025 18:27:08 +0200 Subject: [PATCH 3/4] Create variable for out_dims in vmap --- src/torchjd/autogram/_module_hook_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 0e21319d..f1895aa0 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -248,7 +248,8 @@ def vmap( in_dims = (in_dims[4:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs)) generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args, kwargs) path_jacobians = ComputeModuleJacobians._make_path_jacobians(module, generalized_jacobians) - return path_jacobians, tree_map(lambda _: None, path_jacobians) + out_dims = tree_map(lambda _: None, path_jacobians) + return path_jacobians, out_dims @staticmethod def _make_path_jacobians( From 63e6e426b4981961e9d9ad8a71e6f53dbff4e331 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Oct 2025 00:20:36 +0200 Subject: [PATCH 4/4] Change returned out_dims to None --- src/torchjd/autogram/_module_hook_manager.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index f1895aa0..e39d1b25 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -242,14 +242,13 @@ def vmap( kwargs: dict[str, PyTree], module: nn.Module, *jac_outputs: Tensor, - ) -> tuple[dict[Tensor, Tensor], dict[Tensor, None]]: + ) -> tuple[dict[Tensor, Tensor], None]: # There is a non-batched dimension # We do not vmap over the args for the non-batched dimension in_dims = (in_dims[4:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs)) generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args, kwargs) path_jacobians = ComputeModuleJacobians._make_path_jacobians(module, generalized_jacobians) - out_dims = tree_map(lambda _: None, path_jacobians) - return path_jacobians, out_dims + return path_jacobians, None @staticmethod def _make_path_jacobians(