diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py
index 058dd188..e39d1b25 100644
--- a/src/torchjd/autogram/_module_hook_manager.py
+++ b/src/torchjd/autogram/_module_hook_manager.py
@@ -206,52 +206,49 @@ def backward(ctx, *grad_outputs: Tensor) -> tuple:
         if not ctx.gramian_accumulation_phase:
             return None, None, None, None, None, None, *grad_outputs
 
-        AccumulateJacobian.apply(
+        path_jacobians = ComputeModuleJacobians.apply(
             ctx.vjp,
             ctx.args,
             ctx.kwargs,
-            ctx.gramian_accumulator,
             ctx.module,
             *grad_outputs,
         )
+        ctx.gramian_accumulator.accumulate_path_jacobians(path_jacobians)
 
         return None, None, None, None, None, None, *grad_outputs
 
 
-class AccumulateJacobian(torch.autograd.Function):
+class ComputeModuleJacobians(torch.autograd.Function):
     @staticmethod
     def forward(
         vjp: VJP,
         args: tuple[PyTree, ...],
         kwargs: dict[str, PyTree],
-        gramian_accumulator: GramianAccumulator,
         module: nn.Module,
         *grad_outputs: Tensor,
-    ) -> None:
+    ) -> dict[Tensor, Tensor]:
         # There is no non-batched dimension
         generalized_jacobians = vjp(grad_outputs, args, kwargs)
-        path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
-        gramian_accumulator.accumulate_path_jacobians(path_jacobians)
+        path_jacobians = ComputeModuleJacobians._make_path_jacobians(module, generalized_jacobians)
+        return path_jacobians
 
     @staticmethod
     def vmap(
         _,
-        in_dims: tuple,  # tuple[None, tuple[PyTree, ...], dict[str, PyTree], None, None, *tuple[int | None, ...]]
+        in_dims: tuple,  # tuple[None, tuple[PyTree, ...], dict[str, PyTree], None, *tuple[int | None, ...]]
         vjp: VJP,
         args: tuple[PyTree, ...],
         kwargs: dict[str, PyTree],
-        gramian_accumulator: GramianAccumulator,
         module: nn.Module,
         *jac_outputs: Tensor,
-    ) -> tuple[None, None]:
+    ) -> tuple[dict[Tensor, Tensor], None]:
         # There is a non-batched dimension
         # We do not vmap over the args for the non-batched dimension
-        in_dims = (in_dims[5:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs))
+        in_dims = (in_dims[4:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs))
         generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args, kwargs)
-        path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
-        gramian_accumulator.accumulate_path_jacobians(path_jacobians)
-        return None, None
+        path_jacobians = ComputeModuleJacobians._make_path_jacobians(module, generalized_jacobians)
+        return path_jacobians, None
 
     @staticmethod
     def _make_path_jacobians(