Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/torchjd/autogram/_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,13 @@ def compute_gramian(self, output: Tensor) -> Tensor:

reshaped_output = output.reshape([-1])

self._module_hook_manager.gramian_accumulation_phase = True
self._module_hook_manager.gramian_accumulation_phase.value = True

try:
square_gramian = self._compute_square_gramian(reshaped_output)
finally:
# Reset everything that has a state, even if the previous call raised an exception
self._module_hook_manager.gramian_accumulation_phase = False
self._module_hook_manager.gramian_accumulation_phase.value = False
self._gramian_accumulator.reset()
self._target_edges.reset()

Expand Down
204 changes: 137 additions & 67 deletions src/torchjd/autogram/_module_hook_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import weakref
from collections.abc import Callable
from typing import cast

import torch
Expand All @@ -19,6 +21,16 @@
# still support older versions of PyTorch where pytree is protected).


class BoolRef:
    """
    Mutable wrapper around a single boolean, acting as a shared reference.

    Every holder of the same instance observes assignments to ``value``, and
    the wrapper itself is truthy exactly when the wrapped boolean is ``True``,
    so it can be used directly in conditions (``if ref: ...``).
    """

    def __init__(self, value: bool):
        # The wrapped boolean; mutate this attribute to update all holders.
        self.value = value

    def __bool__(self) -> bool:
        # Delegate truthiness to the wrapped value.
        return self.value


class ModuleHookManager:
"""
Class responsible for handling the hooks and Nodes that compute the Gramian reverse accumulation.
Expand All @@ -35,9 +47,19 @@ def __init__(
):
self._target_edges = target_edges
self._gramian_accumulator = gramian_accumulator
self.gramian_accumulation_phase = False
self.gramian_accumulation_phase = BoolRef(False)
self._handles: list[TorchRemovableHandle] = []

# When the ModuleHookManager is not referenced anymore, there is no reason to keep the hooks
# alive. In fact, keeping the hooks alive would also keep the target edges alive, which
# would keep the graph or part of the graph alive. Since the graph contains nodes that store
# the module in their context, which themselves reference their hooks, the hooks will be
# caught in a reference cycle and will not be freed by the garbage collector. It is thus
# important to remove the hooks whenever we're sure we won't need them anymore.
# We could have used a __del__ method here, with the same effects, but weakref.finalize
# seems to be a better practice (and it only works if the function to call is static).
self._finalizer = weakref.finalize(self, ModuleHookManager.remove_hooks, self._handles)

def hook_module(self, module: nn.Module) -> None:
"""
Add a module hook used to insert Jacobian accumulation nodes into the backward graph.
Expand All @@ -46,85 +68,133 @@ def hook_module(self, module: nn.Module) -> None:
enabling Gramian computation.
"""

def module_hook(_: nn.Module, args: PyTree, output: PyTree) -> PyTree:
if self.gramian_accumulation_phase:
return output

flat_outputs, tree_spec = tree_flatten(output)
hook = Hook(self.gramian_accumulation_phase, self._target_edges, self._gramian_accumulator)
self._handles.append(module.register_forward_hook(hook))

if not any(isinstance(t, Tensor) for t in flat_outputs):
# This can happen only if a module returns no Tensor, for instance some niche usage
# such as a module that prints something.
return output

requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
self._gramian_accumulator.track_parameter_paths(requires_grad_params)
@staticmethod
def remove_hooks(handles: list[TorchRemovableHandle]) -> None:
"""
Remove all registered hooks. This method is deliberately static so that it can be called by
weakref.finalize.
"""

# We only care about running the JacobianAccumulator node, so we need one of its child
# edges (the edges of the original outputs of the model) as target. For memory
# efficiency, we select the smallest one (that requires grad).
inf = float("inf")
preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
index = cast(int, preference.argmin().item())
self._target_edges.register(get_gradient_edge(flat_outputs[index]))
for handle in handles:
handle.remove()

return self._apply_jacobian_accumulator(module, args, tree_spec, flat_outputs)

handle = module.register_forward_hook(module_hook)
self._handles.append(handle)
class AccumulateJacobian(torch.autograd.Function):

def _apply_jacobian_accumulator(
self,
module: nn.Module,
args: PyTree,
@staticmethod
def forward(
ctx,
tree_spec: TreeSpec,
flat_outputs: list[Tensor],
) -> PyTree:
vjp = torch.vmap(get_functional_vjp(module))

class AccumulateJacobian(torch.autograd.Function):

@staticmethod
def forward(*flat_grad_outputs: Tensor) -> None:
grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec)
jacobians = vjp(grad_outputs, args)
self._gramian_accumulator.accumulate_path_jacobians(
{
module.get_parameter(param_name): jacobian
for param_name, jacobian in jacobians.items()
}
)
vjp: Callable[[PyTree, PyTree], dict[str, Tensor]],
args: PyTree,
gramian_accumulator: GramianAccumulator,
module: nn.Module,
*flat_grad_outputs: Tensor,
) -> None:
grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec)
jacobians = vjp(grad_outputs, args)
gramian_accumulator.accumulate_path_jacobians(
{
module.get_parameter(param_name): jacobian
for param_name, jacobian in jacobians.items()
}
)


class JacobianAccumulator(torch.autograd.Function):
"""
Autograd function that accumulates Jacobian Gramians during the first backward pass.

@staticmethod
def setup_context(*_):
pass
Acts as identity on forward pass. During the autogram algorithm, computes the Jacobian
of outputs w.r.t. module parameters and feeds it to the gramian accumulator. Uses a
toggle mechanism to activate only during the Gramian accumulation phase.
"""

class JacobianAccumulator(torch.autograd.Function):
"""
Autograd function that accumulates Jacobian Gramians during the first backward pass.
generate_vmap_rule = True

Acts as identity on forward pass. During the autogram algorithm, computes the Jacobian
of outputs w.r.t. module parameters and feeds it to the gramian accumulator. Uses a
toggle mechanism to activate only during the Gramian accumulation phase.
"""
@staticmethod
def forward(
    ctx,
    gramian_accumulation_phase: BoolRef,
    tree_spec: TreeSpec,
    vjp: Callable[[PyTree, PyTree], dict[str, Tensor]],
    args: PyTree,
    gramian_accumulator: GramianAccumulator,
    module: nn.Module,
    *xs: Tensor,
) -> tuple[Tensor, ...]:
    """
    Act as the identity on the tensor inputs: return detached copies of ``xs``.

    All non-tensor arguments are stashed on ``ctx`` so that ``backward`` can
    later decide — via the shared ``gramian_accumulation_phase`` flag — whether
    to trigger Jacobian accumulation.
    """
    # gramian_accumulation_phase is a shared BoolRef, so the value it holds at
    # backward time (not the value at forward time) is what backward will see.
    ctx.gramian_accumulation_phase = gramian_accumulation_phase
    ctx.tree_spec = tree_spec
    ctx.vjp = vjp
    ctx.args = args
    ctx.gramian_accumulator = gramian_accumulator
    ctx.module = module
    # Detached copies: this function is an identity on the forward pass.
    return tuple([x.detach() for x in xs])

@staticmethod
def backward(ctx, *flat_grad_outputs: Tensor):
    """
    Pass the incoming gradients through unchanged; when the Gramian
    accumulation phase is active, additionally trigger Jacobian accumulation
    for the module stored in ``ctx``.
    """
    if not ctx.gramian_accumulation_phase:
        # Outside the Gramian accumulation phase this node is a plain
        # identity. The six leading Nones correspond to the six non-tensor
        # arguments of forward (flag, tree_spec, vjp, args, accumulator,
        # module), which receive no gradient.
        return None, None, None, None, None, None, *flat_grad_outputs

    # NOTE(review): the accumulation is delegated to a separate autograd
    # Function (AccumulateJacobian) rather than done inline — presumably so
    # the computation is itself recorded by autograd; confirm the rationale.
    AccumulateJacobian.apply(
        ctx.tree_spec,
        ctx.vjp,
        ctx.args,
        ctx.gramian_accumulator,
        ctx.module,
        *flat_grad_outputs,
    )

    return None, None, None, None, None, None, *flat_grad_outputs


class Hook:
def __init__(
    self,
    gramian_accumulation_phase: BoolRef,
    target_edges: EdgeRegistry,
    gramian_accumulator: GramianAccumulator,
):
    """
    Store the shared state this forward hook needs.

    ``gramian_accumulation_phase`` is a shared BoolRef owned by the
    ModuleHookManager; mutating its ``value`` there is observed here.
    """
    # Shared flag: when truthy, the hook becomes a no-op (see __call__).
    self.gramian_accumulation_phase = gramian_accumulation_phase
    # Registry of backward-graph edges to target during Gramian computation.
    self.target_edges = target_edges
    # Accumulator that collects per-parameter Jacobians into a Gramian.
    self.gramian_accumulator = gramian_accumulator

generate_vmap_rule = True
def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
if self.gramian_accumulation_phase:
return output

@staticmethod
def forward(*xs: Tensor) -> tuple[Tensor, ...]:
return tuple([x.detach() for x in xs])
flat_outputs, tree_spec = tree_flatten(output)

@staticmethod
def setup_context(*_):
pass
if not any(isinstance(t, Tensor) for t in flat_outputs):
# This can happen only if a module returns no Tensor, for instance some niche usage
# such as a module that prints something.
return output

@staticmethod
def backward(ctx, *flat_grad_outputs: Tensor):
if not self.gramian_accumulation_phase:
return flat_grad_outputs
requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
self.gramian_accumulator.track_parameter_paths(requires_grad_params)

AccumulateJacobian.apply(*flat_grad_outputs)
# We only care about running the JacobianAccumulator node, so we need one of its child
# edges (the edges of the original outputs of the model) as target. For memory
# efficiency, we select the smallest one (that requires grad).
inf = float("inf")
preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
index = cast(int, preference.argmin().item())
self.target_edges.register(get_gradient_edge(flat_outputs[index]))

return flat_grad_outputs
vjp = torch.vmap(get_functional_vjp(module))

return tree_unflatten(JacobianAccumulator.apply(*flat_outputs), tree_spec)
return tree_unflatten(
JacobianAccumulator.apply(
self.gramian_accumulation_phase,
tree_spec,
vjp,
args,
self.gramian_accumulator,
module,
*flat_outputs,
),
tree_spec,
)
12 changes: 6 additions & 6 deletions tests/unit/autogram/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@
(FreeParam, 32),
(NoFreeParam, 32),
param(Randomness, 32, marks=mark.xfail),
param(Cifar10Model, 16, marks=[mark.slow, mark.garbage_collect]),
param(AlexNet, 2, marks=[mark.slow, mark.garbage_collect]),
param(InstanceNormResNet18, 4, marks=[mark.slow, mark.garbage_collect]),
param(GroupNormMobileNetV3Small, 3, marks=[mark.slow, mark.garbage_collect]),
param(SqueezeNet, 8, marks=[mark.slow, mark.garbage_collect]),
param(InstanceNormMobileNetV2, 2, marks=[mark.slow, mark.garbage_collect]),
param(Cifar10Model, 16, marks=[mark.slow]),
param(AlexNet, 2, marks=[mark.slow]),
param(InstanceNormResNet18, 4, marks=[mark.slow]),
param(GroupNormMobileNetV3Small, 3, marks=[mark.slow]),
param(SqueezeNet, 8, marks=[mark.slow]),
param(InstanceNormMobileNetV2, 2, marks=[mark.slow]),
]


Expand Down
17 changes: 0 additions & 17 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import gc
import os
import random as rand

Expand Down Expand Up @@ -32,28 +31,12 @@ def fix_randomness() -> None:
torch.use_deterministic_algorithms(True)


@fixture(autouse=True)
def garbage_collect_if_marked(request):
"""
Since garbage collection takes some time, we only do it when needed (when the test or the
parametrization of the test is marked with mark.garbage_collect). This is currently useful for
freeing CUDA memory after a lot has been allocated.
"""

yield
if request.node.get_closest_marker("garbage_collect"):
if DEVICE.type == "cuda":
torch.cuda.empty_cache()
gc.collect()


def pytest_addoption(parser):
parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")


def pytest_configure(config):
config.addinivalue_line("markers", "slow: mark test as slow to run")
config.addinivalue_line("markers", "garbage_collect: do garbage collection after test")


def pytest_collection_modifyitems(config, items):
Expand Down
Loading