Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions src/torchjd/autogram/_module_hook_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def hook_module(self, module: nn.Module, gramian_computer: GramianComputer) -> N
"""
Add a module hook used to insert Jacobian accumulation nodes into the backward graph.

The hook injects a JacobianAccumulator function into the computation graph after the module,
The hook injects an AutogramNode function into the computation graph after the module,
enabling Gramian computation.
"""

Expand Down Expand Up @@ -125,14 +125,14 @@ def __call__(

self.gramian_computer.track_forward_call()

# We only care about running the JacobianAccumulator node, so we need one of its child
# We only care about running the AutogramNode, so we need one of its child
# edges (the edges of the original outputs of the model) as target. For memory
# efficiency, we select the smallest one (that requires grad).
preference = torch.tensor([t.numel() for t in rg_outputs])
index = cast(int, preference.argmin().item())
self.target_edges.register(get_gradient_edge(rg_outputs[index]))

autograd_fn_rg_outputs = JacobianAccumulator.apply(
autograd_fn_rg_outputs = AutogramNode.apply(
self.gramian_accumulation_phase,
self.gramian_computer,
args,
Expand All @@ -147,13 +147,10 @@ def __call__(
return tree_unflatten(flat_outputs, output_spec)


class JacobianAccumulator(torch.autograd.Function):
class AutogramNode(torch.autograd.Function):
"""
Autograd function that accumulates Jacobian Gramians during the first backward pass.

Acts as identity on forward pass. During the autogram algorithm, computes the Jacobian
of outputs w.r.t. module parameters and feeds it to the gramian accumulator. Uses a
toggle mechanism to activate only during the Gramian accumulation phase.
Autograd function that is identity on forward and that launches the computation and accumulation
of the gramian on backward.
"""

generate_vmap_rule = True
Expand Down
Loading