From 4d5c09d1d3d8347ea73be77ee0d8c23ee96d637e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 17 Sep 2025 17:48:26 +0200 Subject: [PATCH 1/5] Refactor ModuleHookManager * Use a dedicated Hook class for readability and improved debugging * Use BoolRef so that we can have a pointer to the boolean value in the Hook without having a pointer to the ModuleHookManager in the Hook (which is a reference cycle). * Use the module given as input to Hook.__call__ rather than the module provided to hook_module. This avoids having to store a reference to the module in its hook (which is a reference cycle). --- src/torchjd/autogram/_engine.py | 4 +- src/torchjd/autogram/_module_hook_manager.py | 73 ++++++++++++-------- 2 files changed, 45 insertions(+), 32 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 96610e38..b4c47715 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -170,13 +170,13 @@ def compute_gramian(self, output: Tensor) -> Tensor: reshaped_output = output.reshape([-1]) - self._module_hook_manager.gramian_accumulation_phase = True + self._module_hook_manager.gramian_accumulation_phase.value = True try: square_gramian = self._compute_square_gramian(reshaped_output) finally: # Reset everything that has a state, even if the previous call raised an exception - self._module_hook_manager.gramian_accumulation_phase = False + self._module_hook_manager.gramian_accumulation_phase.value = False self._gramian_accumulator.reset() self._target_edges.reset() diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 07381184..f6ffb984 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -3,7 +3,7 @@ import torch from torch import Tensor, nn from torch.autograd.graph import get_gradient_edge -from torch.utils._pytree import PyTree, TreeSpec, tree_flatten, 
tree_unflatten +from torch.utils._pytree import PyTree, tree_flatten, tree_unflatten from torch.utils.hooks import RemovableHandle as TorchRemovableHandle from ._edge_registry import EdgeRegistry @@ -19,6 +19,16 @@ # still support older versions of PyTorch where pytree is protected). +class BoolRef: + """Class wrapping a boolean value, acting as a reference to this boolean value.""" + + def __init__(self, value: bool): + self.value = value + + def __bool__(self) -> bool: + return self.value + + class ModuleHookManager: """ Class responsible for handling hooks and Nodes that computes the Gramian reverse accumulation. @@ -35,7 +45,7 @@ def __init__( ): self._target_edges = target_edges self._gramian_accumulator = gramian_accumulator - self.gramian_accumulation_phase = False + self.gramian_accumulation_phase = BoolRef(False) self._handles: list[TorchRemovableHandle] = [] def hook_module(self, module: nn.Module) -> None: @@ -46,40 +56,43 @@ def hook_module(self, module: nn.Module) -> None: enabling Gramian computation. """ - def module_hook(_: nn.Module, args: PyTree, output: PyTree) -> PyTree: - if self.gramian_accumulation_phase: - return output + hook = Hook(self.gramian_accumulation_phase, self._target_edges, self._gramian_accumulator) + self._handles.append(module.register_forward_hook(hook)) - flat_outputs, tree_spec = tree_flatten(output) - if not any(isinstance(t, Tensor) for t in flat_outputs): - # This can happen only if a module returns no Tensor, for instance some niche usage - # such as a module that prints something. 
- return output +class Hook: + def __init__( + self, + gramian_accumulation_phase: BoolRef, + target_edges: EdgeRegistry, + gramian_accumulator: GramianAccumulator, + ): + self.gramian_accumulation_phase = gramian_accumulation_phase + self.target_edges = target_edges + self.gramian_accumulator = gramian_accumulator - requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad] - self._gramian_accumulator.track_parameter_paths(requires_grad_params) + def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree: + if self.gramian_accumulation_phase: + return output - # We only care about running the JacobianAccumulator node, so we need one of its child - # edges (the edges of the original ouputs of the model) as target. For memory - # efficiency, we select the smallest one (that requires grad). - inf = float("inf") - preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs]) - index = cast(int, preference.argmin().item()) - self._target_edges.register(get_gradient_edge(flat_outputs[index])) + flat_outputs, tree_spec = tree_flatten(output) - return self._apply_jacobian_accumulator(module, args, tree_spec, flat_outputs) + if not any(isinstance(t, Tensor) for t in flat_outputs): + # This can happen only if a module returns no Tensor, for instance some niche usage + # such as a module that prints something. + return output - handle = module.register_forward_hook(module_hook) - self._handles.append(handle) + requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad] + self.gramian_accumulator.track_parameter_paths(requires_grad_params) + + # We only care about running the JacobianAccumulator node, so we need one of its child + # edges (the edges of the original outputs of the model) as target. For memory + # efficiency, we select the smallest one (that requires grad). 
+ inf = float("inf") + preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs]) + index = cast(int, preference.argmin().item()) + self.target_edges.register(get_gradient_edge(flat_outputs[index])) - def _apply_jacobian_accumulator( - self, - module: nn.Module, - args: PyTree, - tree_spec: TreeSpec, - flat_outputs: list[Tensor], - ) -> PyTree: vjp = torch.vmap(get_functional_vjp(module)) class AccumulateJacobian(torch.autograd.Function): @@ -88,7 +101,7 @@ class AccumulateJacobian(torch.autograd.Function): def forward(*flat_grad_outputs: Tensor) -> None: grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec) jacobians = vjp(grad_outputs, args) - self._gramian_accumulator.accumulate_path_jacobians( + self.gramian_accumulator.accumulate_path_jacobians( { module.get_parameter(param_name): jacobian for param_name, jacobian in jacobians.items() From fa69e9c54a33ec6691eefeb06a9fd26631391d6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 18 Sep 2025 16:48:25 +0200 Subject: [PATCH 2/5] Refactor autograd Functions * Instead of using non-local variables, make them take these variables as input to forward * This seems more standard practice, and fixes a reference cycle issue --- src/torchjd/autogram/_module_hook_manager.py | 123 ++++++++++++------- 1 file changed, 76 insertions(+), 47 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index f6ffb984..0f8accad 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -3,7 +3,7 @@ import torch from torch import Tensor, nn from torch.autograd.graph import get_gradient_edge -from torch.utils._pytree import PyTree, tree_flatten, tree_unflatten +from torch.utils._pytree import PyTree, TreeSpec, tree_flatten, tree_unflatten from torch.utils.hooks import RemovableHandle as TorchRemovableHandle from ._edge_registry import EdgeRegistry @@ -60,6 +60,69 @@ 
def hook_module(self, module: nn.Module) -> None: self._handles.append(module.register_forward_hook(hook)) +class AccumulateJacobian(torch.autograd.Function): + + @staticmethod + def forward( + ctx, tree_spec: TreeSpec, vjp, args, gramian_accumulator, module, *flat_grad_outputs: Tensor + ) -> None: + grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec) + jacobians = vjp(grad_outputs, args) + gramian_accumulator.accumulate_path_jacobians( + { + module.get_parameter(param_name): jacobian + for param_name, jacobian in jacobians.items() + } + ) + + +class JacobianAccumulator(torch.autograd.Function): + """ + Autograd function that accumulates Jacobian Gramians during the first backward pass. + + Acts as identity on forward pass. During the autogram algorithm, computes the Jacobian + of outputs w.r.t. module parameters and feeds it to the gramian accumulator. Uses a + toggle mechanism to activate only during the Gramian accumulation phase. + """ + + generate_vmap_rule = True + + @staticmethod + def forward( + ctx, + gramian_accumulation_phase: BoolRef, + tree_spec, + vjp, + args, + gramian_accumulator, + module, + *xs: Tensor, + ) -> tuple[Tensor, ...]: + ctx.gramian_accumulation_phase = gramian_accumulation_phase + ctx.tree_spec = tree_spec + ctx.vjp = vjp + ctx.args = args + ctx.gramian_accumulator = gramian_accumulator + ctx.module = module + return tuple([x.detach() for x in xs]) + + @staticmethod + def backward(ctx, *flat_grad_outputs: Tensor): + if not ctx.gramian_accumulation_phase: + return None, None, None, None, None, None, *flat_grad_outputs + + AccumulateJacobian.apply( + ctx.tree_spec, + ctx.vjp, + ctx.args, + ctx.gramian_accumulator, + ctx.module, + *flat_grad_outputs, + ) + + return None, None, None, None, None, None, *flat_grad_outputs + + class Hook: def __init__( self, @@ -95,49 +158,15 @@ def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree: vjp = torch.vmap(get_functional_vjp(module)) - class 
AccumulateJacobian(torch.autograd.Function): - - @staticmethod - def forward(*flat_grad_outputs: Tensor) -> None: - grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec) - jacobians = vjp(grad_outputs, args) - self.gramian_accumulator.accumulate_path_jacobians( - { - module.get_parameter(param_name): jacobian - for param_name, jacobian in jacobians.items() - } - ) - - @staticmethod - def setup_context(*_): - pass - - class JacobianAccumulator(torch.autograd.Function): - """ - Autograd function that accumulates Jacobian Gramians during the first backward pass. - - Acts as identity on forward pass. During the autogram algorithm, computes the Jacobian - of outputs w.r.t. module parameters and feeds it to the gramian accumulator. Uses a - toggle mechanism to activate only during the Gramian accumulation phase. - """ - - generate_vmap_rule = True - - @staticmethod - def forward(*xs: Tensor) -> tuple[Tensor, ...]: - return tuple([x.detach() for x in xs]) - - @staticmethod - def setup_context(*_): - pass - - @staticmethod - def backward(ctx, *flat_grad_outputs: Tensor): - if not self.gramian_accumulation_phase: - return flat_grad_outputs - - AccumulateJacobian.apply(*flat_grad_outputs) - - return flat_grad_outputs - - return tree_unflatten(JacobianAccumulator.apply(*flat_outputs), tree_spec) + return tree_unflatten( + JacobianAccumulator.apply( + self.gramian_accumulation_phase, + tree_spec, + vjp, + args, + self.gramian_accumulator, + module, + *flat_outputs, + ), + tree_spec, + ) From 95e95bdfe46ba8e2b7d24885a99eefeb8480258b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 18 Sep 2025 17:11:16 +0200 Subject: [PATCH 3/5] Add finalizer in ModuleHookManager to unhook * This solves the last reference cycle issue --- src/torchjd/autogram/_module_hook_manager.py | 21 ++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 
0f8accad..62efdbcd 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -1,3 +1,4 @@ +import weakref from typing import cast import torch @@ -48,6 +49,16 @@ def __init__( self.gramian_accumulation_phase = BoolRef(False) self._handles: list[TorchRemovableHandle] = [] + # When the ModuleHookManager is not referenced anymore, there is no reason to keep the hooks + # alive. In fact, keeping the hooks alive would also keep the target edges alive, which + # would keep the graph or part of the graph alive. Since the graph contains nodes that store + # the module in their context, which themselves reference their hooks, the hooks will be + # caught in a reference cycle and will not be freed by the garbage collector. It is thus + # important to remove the hooks whenever we're sure we won't need them anymore. + # We could have used a __del__ method here, with the same effects, but weakref.finalize + # seems to be a better practice (and it only works if the function to call is static). + self._finalizer = weakref.finalize(self, ModuleHookManager.remove_hooks, self._handles) + def hook_module(self, module: nn.Module) -> None: """ Add a module hook used to insert Jacobian accumulation nodes into the backward graph. @@ -59,6 +70,16 @@ def hook_module(self, module: nn.Module) -> None: hook = Hook(self.gramian_accumulation_phase, self._target_edges, self._gramian_accumulator) self._handles.append(module.register_forward_hook(hook)) + @staticmethod + def remove_hooks(handles: list[TorchRemovableHandle]) -> None: + """ + Remove all registered hooks. This method is deliberately static so that it can be called by + weakref.finalize. 
+ """ + + for handle in handles: + handle.remove() + class AccumulateJacobian(torch.autograd.Function): From 39952b6ea38f4e10faad584237e7ebbb3f2a3835 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 18 Sep 2025 17:20:32 +0200 Subject: [PATCH 4/5] Remove manual garbage collection * Remove garbage_collect marker * Remove garbage_collect_if_marked fixture This is not needed anymore since the garbage collector's job is much easier now --- tests/unit/autogram/test_engine.py | 12 ++++++------ tests/unit/conftest.py | 17 ----------------- 2 files changed, 6 insertions(+), 23 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 7ccfc3e8..b1fe5e3b 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -96,12 +96,12 @@ (FreeParam, 32), (NoFreeParam, 32), param(Randomness, 32, marks=mark.xfail), - param(Cifar10Model, 16, marks=[mark.slow, mark.garbage_collect]), - param(AlexNet, 2, marks=[mark.slow, mark.garbage_collect]), - param(InstanceNormResNet18, 4, marks=[mark.slow, mark.garbage_collect]), - param(GroupNormMobileNetV3Small, 3, marks=[mark.slow, mark.garbage_collect]), - param(SqueezeNet, 8, marks=[mark.slow, mark.garbage_collect]), - param(InstanceNormMobileNetV2, 2, marks=[mark.slow, mark.garbage_collect]), + param(Cifar10Model, 16, marks=[mark.slow]), + param(AlexNet, 2, marks=[mark.slow]), + param(InstanceNormResNet18, 4, marks=[mark.slow]), + param(GroupNormMobileNetV3Small, 3, marks=[mark.slow]), + param(SqueezeNet, 8, marks=[mark.slow]), + param(InstanceNormMobileNetV2, 2, marks=[mark.slow]), ] diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 6a58931b..91bca4f9 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,4 +1,3 @@ -import gc import os import random as rand @@ -32,28 +31,12 @@ def fix_randomness() -> None: torch.use_deterministic_algorithms(True) -@fixture(autouse=True) -def 
garbage_collect_if_marked(request): - """ - Since garbage collection takes some time, we only do it when needed (when the test or the - parametrization of the test is marked with mark.garbage_collect). This is currently useful for - freeing CUDA memory after a lot has been allocated. - """ - - yield - if request.node.get_closest_marker("garbage_collect"): - if DEVICE.type == "cuda": - torch.cuda.empty_cache() - gc.collect() - - def pytest_addoption(parser): parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") def pytest_configure(config): config.addinivalue_line("markers", "slow: mark test as slow to run") - config.addinivalue_line("markers", "garbage_collect: do garbage collection after test") def pytest_collection_modifyitems(config, items): From 71c11de511c455d4a477200c4174911dc6f77df9 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 19 Sep 2025 10:00:18 +0200 Subject: [PATCH 5/5] Type the forward methods of both `autograd.Function`s --- src/torchjd/autogram/_module_hook_manager.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 62efdbcd..c6f42130 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -1,4 +1,5 @@ import weakref +from collections.abc import Callable from typing import cast import torch @@ -85,7 +86,13 @@ class AccumulateJacobian(torch.autograd.Function): @staticmethod def forward( - ctx, tree_spec: TreeSpec, vjp, args, gramian_accumulator, module, *flat_grad_outputs: Tensor + ctx, + tree_spec: TreeSpec, + vjp: Callable[[PyTree, PyTree], dict[str, Tensor]], + args: PyTree, + gramian_accumulator: GramianAccumulator, + module: nn.Module, + *flat_grad_outputs: Tensor, ) -> None: grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec) jacobians = vjp(grad_outputs, args) @@ -112,11 +119,11 @@ class 
JacobianAccumulator(torch.autograd.Function): def forward( ctx, gramian_accumulation_phase: BoolRef, - tree_spec, - vjp, - args, - gramian_accumulator, - module, + tree_spec: TreeSpec, + vjp: Callable[[PyTree, PyTree], dict[str, Tensor]], + args: PyTree, + gramian_accumulator: GramianAccumulator, + module: nn.Module, *xs: Tensor, ) -> tuple[Tensor, ...]: ctx.gramian_accumulation_phase = gramian_accumulation_phase