diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 7fc4b80c..082ace69 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -53,7 +53,7 @@ def hook_module(self, module: nn.Module, gramian_computer: GramianComputer) -> N """ Add a module hook used to insert Jacobian accumulation nodes into the backward graph. - The hook injects a JacobianAccumulator function into the computation graph after the module, + The hook injects a AutogramNode function into the computation graph after the module, enabling Gramian computation. """ @@ -125,14 +125,14 @@ def __call__( self.gramian_computer.track_forward_call() - # We only care about running the JacobianAccumulator node, so we need one of its child + # We only care about running the AutogramNode, so we need one of its child # edges (the edges of the original outputs of the model) as target. For memory # efficiency, we select the smallest one (that requires grad). preference = torch.tensor([t.numel() for t in rg_outputs]) index = cast(int, preference.argmin().item()) self.target_edges.register(get_gradient_edge(rg_outputs[index])) - autograd_fn_rg_outputs = JacobianAccumulator.apply( + autograd_fn_rg_outputs = AutogramNode.apply( self.gramian_accumulation_phase, self.gramian_computer, args, @@ -147,13 +147,10 @@ def __call__( return tree_unflatten(flat_outputs, output_spec) -class JacobianAccumulator(torch.autograd.Function): +class AutogramNode(torch.autograd.Function): """ - Autograd function that accumulates Jacobian Gramians during the first backward pass. - - Acts as identity on forward pass. During the autogram algorithm, computes the Jacobian - of outputs w.r.t. module parameters and feeds it to the gramian accumulator. Uses a - toggle mechanism to activate only during the Gramian accumulation phase. + Autograd function that is identity on forward and that launches the computation and accumulation + of the gramian on backward. """ generate_vmap_rule = True