diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 96610e38..b4c47715 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -170,13 +170,13 @@ def compute_gramian(self, output: Tensor) -> Tensor: reshaped_output = output.reshape([-1]) - self._module_hook_manager.gramian_accumulation_phase = True + self._module_hook_manager.gramian_accumulation_phase.value = True try: square_gramian = self._compute_square_gramian(reshaped_output) finally: # Reset everything that has a state, even if the previous call raised an exception - self._module_hook_manager.gramian_accumulation_phase = False + self._module_hook_manager.gramian_accumulation_phase.value = False self._gramian_accumulator.reset() self._target_edges.reset() diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 07381184..c6f42130 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -1,3 +1,5 @@ +import weakref +from collections.abc import Callable from typing import cast import torch @@ -19,6 +21,16 @@ # still support older versions of PyTorch where pytree is protected). +class BoolRef: + """Class wrapping a boolean value, acting as a reference to this boolean value.""" + + def __init__(self, value: bool): + self.value = value + + def __bool__(self) -> bool: + return self.value + + class ModuleHookManager: """ Class responsible for handling hooks and Nodes that computes the Gramian reverse accumulation. @@ -35,9 +47,19 @@ def __init__( ): self._target_edges = target_edges self._gramian_accumulator = gramian_accumulator - self.gramian_accumulation_phase = False + self.gramian_accumulation_phase = BoolRef(False) self._handles: list[TorchRemovableHandle] = [] + # When the ModuleHookManager is not referenced anymore, there is no reason to keep the hooks + # alive. 
In fact, keeping the hooks alive would also keep the target edges alive, which + # would keep the graph or part of the graph alive. Since the graph contains nodes that store + # the module in their context, which themselves reference their hooks, the hooks will be + # caught in a reference cycle and will not be freed by the garbage collector. It is thus + # important to remove the hooks whenever we're sure we won't need them anymore. + # We could have used a __del__ method here, with the same effects, but weakref.finalize + # seems to be a better practice (and it only works if the function to call is static). + self._finalizer = weakref.finalize(self, ModuleHookManager.remove_hooks, self._handles) + def hook_module(self, module: nn.Module) -> None: """ Add a module hook used to insert Jacobian accumulation nodes into the backward graph. @@ -46,85 +68,133 @@ def hook_module(self, module: nn.Module) -> None: enabling Gramian computation. """ - def module_hook(_: nn.Module, args: PyTree, output: PyTree) -> PyTree: - if self.gramian_accumulation_phase: - return output - - flat_outputs, tree_spec = tree_flatten(output) + hook = Hook(self.gramian_accumulation_phase, self._target_edges, self._gramian_accumulator) + self._handles.append(module.register_forward_hook(hook)) - if not any(isinstance(t, Tensor) for t in flat_outputs): - # This can happen only if a module returns no Tensor, for instance some niche usage - # such as a module that prints something. - return output - - requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad] - self._gramian_accumulator.track_parameter_paths(requires_grad_params) + @staticmethod + def remove_hooks(handles: list[TorchRemovableHandle]) -> None: + """ + Remove all registered hooks. This method is deliberately static so that it can be called by + weakref.finalize. 
+ """ - # We only care about running the JacobianAccumulator node, so we need one of its child - # edges (the edges of the original ouputs of the model) as target. For memory - # efficiency, we select the smallest one (that requires grad). - inf = float("inf") - preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs]) - index = cast(int, preference.argmin().item()) - self._target_edges.register(get_gradient_edge(flat_outputs[index])) + for handle in handles: + handle.remove() - return self._apply_jacobian_accumulator(module, args, tree_spec, flat_outputs) - handle = module.register_forward_hook(module_hook) - self._handles.append(handle) +class AccumulateJacobian(torch.autograd.Function): - def _apply_jacobian_accumulator( - self, - module: nn.Module, - args: PyTree, + @staticmethod + def forward( + ctx, tree_spec: TreeSpec, - flat_outputs: list[Tensor], - ) -> PyTree: - vjp = torch.vmap(get_functional_vjp(module)) - - class AccumulateJacobian(torch.autograd.Function): - - @staticmethod - def forward(*flat_grad_outputs: Tensor) -> None: - grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec) - jacobians = vjp(grad_outputs, args) - self._gramian_accumulator.accumulate_path_jacobians( - { - module.get_parameter(param_name): jacobian - for param_name, jacobian in jacobians.items() - } - ) + vjp: Callable[[PyTree, PyTree], dict[str, Tensor]], + args: PyTree, + gramian_accumulator: GramianAccumulator, + module: nn.Module, + *flat_grad_outputs: Tensor, + ) -> None: + grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec) + jacobians = vjp(grad_outputs, args) + gramian_accumulator.accumulate_path_jacobians( + { + module.get_parameter(param_name): jacobian + for param_name, jacobian in jacobians.items() + } + ) + + +class JacobianAccumulator(torch.autograd.Function): + """ + Autograd function that accumulates Jacobian Gramians during the first backward pass. 
- @staticmethod - def setup_context(*_): - pass + Acts as identity on forward pass. During the autogram algorithm, computes the Jacobian + of outputs w.r.t. module parameters and feeds it to the gramian accumulator. Uses a + toggle mechanism to activate only during the Gramian accumulation phase. + """ - class JacobianAccumulator(torch.autograd.Function): - """ - Autograd function that accumulates Jacobian Gramians during the first backward pass. + generate_vmap_rule = True - Acts as identity on forward pass. During the autogram algorithm, computes the Jacobian - of outputs w.r.t. module parameters and feeds it to the gramian accumulator. Uses a - toggle mechanism to activate only during the Gramian accumulation phase. - """ + @staticmethod + def forward( + ctx, + gramian_accumulation_phase: BoolRef, + tree_spec: TreeSpec, + vjp: Callable[[PyTree, PyTree], dict[str, Tensor]], + args: PyTree, + gramian_accumulator: GramianAccumulator, + module: nn.Module, + *xs: Tensor, + ) -> tuple[Tensor, ...]: + ctx.gramian_accumulation_phase = gramian_accumulation_phase + ctx.tree_spec = tree_spec + ctx.vjp = vjp + ctx.args = args + ctx.gramian_accumulator = gramian_accumulator + ctx.module = module + return tuple([x.detach() for x in xs]) + + @staticmethod + def backward(ctx, *flat_grad_outputs: Tensor): + if not ctx.gramian_accumulation_phase: + return None, None, None, None, None, None, *flat_grad_outputs + + AccumulateJacobian.apply( + ctx.tree_spec, + ctx.vjp, + ctx.args, + ctx.gramian_accumulator, + ctx.module, + *flat_grad_outputs, + ) + + return None, None, None, None, None, None, *flat_grad_outputs + + +class Hook: + def __init__( + self, + gramian_accumulation_phase: BoolRef, + target_edges: EdgeRegistry, + gramian_accumulator: GramianAccumulator, + ): + self.gramian_accumulation_phase = gramian_accumulation_phase + self.target_edges = target_edges + self.gramian_accumulator = gramian_accumulator - generate_vmap_rule = True + def __call__(self, module: nn.Module, args: 
PyTree, output: PyTree) -> PyTree: + if self.gramian_accumulation_phase: + return output - @staticmethod - def forward(*xs: Tensor) -> tuple[Tensor, ...]: - return tuple([x.detach() for x in xs]) + flat_outputs, tree_spec = tree_flatten(output) - @staticmethod - def setup_context(*_): - pass + if not any(isinstance(t, Tensor) for t in flat_outputs): + # This can happen only if a module returns no Tensor, for instance some niche usage + # such as a module that prints something. + return output - @staticmethod - def backward(ctx, *flat_grad_outputs: Tensor): - if not self.gramian_accumulation_phase: - return flat_grad_outputs + requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad] + self.gramian_accumulator.track_parameter_paths(requires_grad_params) - AccumulateJacobian.apply(*flat_grad_outputs) + # We only care about running the JacobianAccumulator node, so we need one of its child + # edges (the edges of the original outputs of the model) as target. For memory + # efficiency, we select the smallest one (that requires grad). 
+ inf = float("inf") + preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs]) + index = cast(int, preference.argmin().item()) + self.target_edges.register(get_gradient_edge(flat_outputs[index])) - return flat_grad_outputs + vjp = torch.vmap(get_functional_vjp(module)) - return tree_unflatten(JacobianAccumulator.apply(*flat_outputs), tree_spec) + return tree_unflatten( + JacobianAccumulator.apply( + self.gramian_accumulation_phase, + tree_spec, + vjp, + args, + self.gramian_accumulator, + module, + *flat_outputs, + ), + tree_spec, + ) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 7ccfc3e8..b1fe5e3b 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -96,12 +96,12 @@ (FreeParam, 32), (NoFreeParam, 32), param(Randomness, 32, marks=mark.xfail), - param(Cifar10Model, 16, marks=[mark.slow, mark.garbage_collect]), - param(AlexNet, 2, marks=[mark.slow, mark.garbage_collect]), - param(InstanceNormResNet18, 4, marks=[mark.slow, mark.garbage_collect]), - param(GroupNormMobileNetV3Small, 3, marks=[mark.slow, mark.garbage_collect]), - param(SqueezeNet, 8, marks=[mark.slow, mark.garbage_collect]), - param(InstanceNormMobileNetV2, 2, marks=[mark.slow, mark.garbage_collect]), + param(Cifar10Model, 16, marks=[mark.slow]), + param(AlexNet, 2, marks=[mark.slow]), + param(InstanceNormResNet18, 4, marks=[mark.slow]), + param(GroupNormMobileNetV3Small, 3, marks=[mark.slow]), + param(SqueezeNet, 8, marks=[mark.slow]), + param(InstanceNormMobileNetV2, 2, marks=[mark.slow]), ] diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 6a58931b..91bca4f9 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,4 +1,3 @@ -import gc import os import random as rand @@ -32,28 +31,12 @@ def fix_randomness() -> None: torch.use_deterministic_algorithms(True) -@fixture(autouse=True) -def garbage_collect_if_marked(request): - """ - Since garbage collection 
takes some time, we only do it when needed (when the test or the - parametrization of the test is marked with mark.garbage_collect). This is currently useful for - freeing CUDA memory after a lot has been allocated. - """ - - yield - if request.node.get_closest_marker("garbage_collect"): - if DEVICE.type == "cuda": - torch.cuda.empty_cache() - gc.collect() - - def pytest_addoption(parser): parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") def pytest_configure(config): config.addinivalue_line("markers", "slow: mark test as slow to run") - config.addinivalue_line("markers", "garbage_collect: do garbage collection after test") def pytest_collection_modifyitems(config, items):