Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
586e17a
Make `GramianAccumulator` track paths to Modules rather than parameters.
PierreQuinton Oct 12, 2025
7e901e2
Improve docstrings
ValerianRey Oct 12, 2025
41d3b0b
Make InterModuleParamReuse xfail
ValerianRey Oct 12, 2025
72375bb
Merge branch 'main' into gramian-accumulator-handles-modules
ValerianRey Oct 13, 2025
07be8d6
Fix GramianAccumulator tests.
PierreQuinton Oct 13, 2025
a2cec0b
Make `_make_path_jacobians` return a list rather than a dict.
PierreQuinton Oct 13, 2025
f8856ba
Make `accumulate_path_jacobian` take a `list[Tensor]` of jacobians
PierreQuinton Oct 13, 2025
a9d752e
Rename `VJP` into `JacobianComputer` as it does not compute vector ja…
PierreQuinton Oct 13, 2025
89cdc86
functional_call can take a list of dict.
PierreQuinton Oct 13, 2025
3cc2cdf
Fix typing
PierreQuinton Oct 13, 2025
6138a72
Fix typing yet again
PierreQuinton Oct 13, 2025
1a1426c
Makes `JacobianComputer` return a Jacobian Matrix. Adapt `ComputeModu…
PierreQuinton Oct 13, 2025
6750ffa
Improve docstring.
PierreQuinton Oct 13, 2025
22fd4e7
Apparently, if a in_dim corresponding to a PyTree is set to `None`, i…
PierreQuinton Oct 13, 2025
13248cb
Fix precision of tests (for cuda testing)
ValerianRey Oct 13, 2025
b2b10d0
Remove ModuleJacobianComputer
ValerianRey Oct 13, 2025
5641382
Make JacobianComputers take rg_outputs at call
ValerianRey Oct 13, 2025
ea12581
Move JacobianComputer construction to Engine
ValerianRey Oct 13, 2025
ab28096
Add GramianComputer and subclasses
ValerianRey Oct 13, 2025
47551fe
Use GramianComputer instead of JacobianComputer
ValerianRey Oct 13, 2025
9e102e7
Simplify JacobianAccumulator
ValerianRey Oct 13, 2025
be6e41d
Move vmap handling to JacobianComputer
ValerianRey Oct 13, 2025
8967ce1
Improve docstrings
ValerianRey Oct 13, 2025
b791485
Use _to_gramian in JacobianBasedGramianComputerWithCrossTerms
ValerianRey Oct 13, 2025
dd81b94
Remove JacobianBasedGramianComputerWithoutCrossTerms
ValerianRey Oct 13, 2025
f3c3961
Simplify how in_dims are computed in FunctionalJacobianComputer
ValerianRey Oct 13, 2025
944d6ac
Merge branch 'main' into gramian-accumulator-handles-modules
ValerianRey Oct 13, 2025
cd2a8d8
Fix mypy error
PierreQuinton Oct 14, 2025
0fcd870
Use ___ for variable name
ValerianRey Oct 14, 2025
ae356e4
Make _make_gramian_computer protected
ValerianRey Oct 14, 2025
84a6dcf
Add comment
ValerianRey Oct 14, 2025
22a463c
Improve type consistency
ValerianRey Oct 14, 2025
3e6641b
Simplify JacobianComputer docstrings
ValerianRey Oct 14, 2025
08baaf1
Fix comment
ValerianRey Oct 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions src/torchjd/autogram/_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@

from ._edge_registry import EdgeRegistry
from ._gramian_accumulator import GramianAccumulator
from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
from ._gramian_utils import movedim_gramian, reshape_gramian
from ._jacobian_computer import (
AutogradJacobianComputer,
FunctionalJacobianComputer,
JacobianComputer,
)
from ._module_hook_manager import ModuleHookManager

_MODULES_INCOMPATIBLE_WITH_BATCHED = (
Expand Down Expand Up @@ -179,21 +185,32 @@ def __init__(
self._gramian_accumulator = GramianAccumulator()
self._target_edges = EdgeRegistry()
self._batch_dim = batch_dim
self._module_hook_manager = ModuleHookManager(
self._target_edges, self._gramian_accumulator, batch_dim is not None
)
self._module_hook_manager = ModuleHookManager(self._target_edges, self._gramian_accumulator)
self._gramian_computers = dict[nn.Module, GramianComputer]()

for module in modules:
self._hook_module_recursively(module)

def _hook_module_recursively(self, module: nn.Module) -> None:
if any(p.requires_grad for p in module.parameters(recurse=False)):
self._check_module_is_compatible(module)
self._module_hook_manager.hook_module(module)
gramian_computer = self._make_gramian_computer(module)
self._gramian_computers[module] = gramian_computer
self._module_hook_manager.hook_module(module, gramian_computer)
else:
for child in module.children():
self._hook_module_recursively(child)

def _make_gramian_computer(self, module: nn.Module) -> GramianComputer:
    """
    Build the ``GramianComputer`` for ``module``, selecting the Jacobian strategy based on
    whether the engine was configured with a batch dimension.
    """
    if self._batch_dim is None:
        # No batch dimension: differentiate through the existing graph with autograd.
        jacobian_computer: JacobianComputer = AutogradJacobianComputer(module)
    else:
        # Batched mode: use the functional API (relies on vmap over the batch dimension).
        jacobian_computer = FunctionalJacobianComputer(module)
    return JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)

def _check_module_is_compatible(self, module: nn.Module) -> None:
if self._batch_dim is not None:
if isinstance(module, _MODULES_INCOMPATIBLE_WITH_BATCHED):
Expand Down Expand Up @@ -276,6 +293,8 @@ def compute_gramian(self, output: Tensor) -> Tensor:
self._module_hook_manager.gramian_accumulation_phase.value = False
self._gramian_accumulator.reset()
self._target_edges.reset()
for gramian_computer in self._gramian_computers.values():
gramian_computer.reset()

unordered_gramian = reshape_gramian(square_gramian, ordered_shape)

Expand Down
54 changes: 3 additions & 51 deletions src/torchjd/autogram/_gramian_accumulator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from collections import Counter
from collections.abc import Iterable
from typing import Optional

import torch
from torch import Tensor


Expand All @@ -17,60 +14,15 @@ class GramianAccumulator:

def __init__(self) -> None:
self._gramian: Optional[Tensor] = None
self._summed_jacobians = dict[Tensor, Tensor]()
self._path_counter = Counter[Tensor]()

def reset(self) -> None:
self._gramian = None
self._summed_jacobians = {}
self._path_counter = Counter()

def track_parameter_paths(self, parameters: Iterable[Tensor]) -> None:
"""
Register parameters and count their paths in the computational graph.

:param parameters: Parameter tensors to track. Duplicates increase path count.
"""
self._path_counter.update(parameters)

def accumulate_path_jacobians(self, path_jacobians: dict[Tensor, Tensor]) -> None:
"""
Add path Jacobians for multiple parameters.

:param path_jacobians: Dictionary mapping parameters to Jacobian tensors of a single path.
"""
for parameter, jacobian in path_jacobians.items():
self._accumulate_path_jacobian(parameter, jacobian)

def _accumulate_path_jacobian(self, parameter: Tensor, jacobian: Tensor) -> None:
"""
Add path Jacobian for a parameter. In case the full Jacobian is computed, accumulate its
Gramian.

:param parameter: The parameter.
:param jacobian: path Jacobian with respect to the parameter.
"""
if parameter in self._summed_jacobians:
self._summed_jacobians[parameter] += jacobian
else:
self._summed_jacobians[parameter] = jacobian
self._path_counter.subtract([parameter])
if self._path_counter[parameter] == 0:
self._accumulate_gramian(parameter)
del self._path_counter[parameter]
del self._summed_jacobians[parameter]

def _accumulate_gramian(self, parameter: Tensor) -> None:
"""
Compute the Gramian of the full Jacobian and accumulate it.

:param parameter: Parameter whose full Jacobian is available.
"""
full_jacobian_matrix = torch.flatten(self._summed_jacobians[parameter], start_dim=1)
def accumulate_gramian(self, gramian: Tensor) -> None:
if self._gramian is not None:
self._gramian.addmm_(full_jacobian_matrix, full_jacobian_matrix.T)
self._gramian.add_(gramian)
else:
self._gramian = torch.mm(full_jacobian_matrix, full_jacobian_matrix.T)
self._gramian = gramian

@property
def gramian(self) -> Optional[Tensor]:
Expand Down
78 changes: 78 additions & 0 deletions src/torchjd/autogram/_gramian_computer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from abc import ABC, abstractmethod
from typing import Optional

from torch import Tensor
from torch.utils._pytree import PyTree

from torchjd.autogram._jacobian_computer import JacobianComputer


class GramianComputer(ABC):
    """
    Abstract interface for computing a module's contribution to the Gramian from the
    quantities available at backward time (outputs, grad_outputs and forward args/kwargs).
    """

    @abstractmethod
    def __call__(
        self,
        rg_outputs: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        kwargs: dict[str, PyTree],
    ) -> Optional[Tensor]:
        """Compute what we can for a module and optionally return the gramian if it's ready."""

    def track_forward_call(self) -> None:
        """Track that the module's forward was called. Necessary in some implementations."""

    def reset(self) -> None:
        """Reset state if any. Necessary in some implementations."""


class JacobianBasedGramianComputer(GramianComputer, ABC):
    """
    Base class for ``GramianComputer``s that obtain the Gramian from a Jacobian matrix
    produced by a ``JacobianComputer``.

    :param jacobian_computer: The object used to compute the module's Jacobian matrix.
    """

    # Fix: the parameter was untyped here while the subclass annotated it as
    # JacobianComputer; annotate it for consistency.
    def __init__(self, jacobian_computer: JacobianComputer):
        self.jacobian_computer = jacobian_computer

    @staticmethod
    def _to_gramian(jacobian: Tensor) -> Tensor:
        """Return the Gramian ``J @ J.T`` of a Jacobian matrix ``J``."""
        return jacobian @ jacobian.T


class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
    """
    Stateful JacobianBasedGramianComputer that waits for all usages to be counted before returning
    the gramian.

    :param jacobian_computer: The object used to compute the module's Jacobian matrix.
    """

    def __init__(self, jacobian_computer: JacobianComputer):
        super().__init__(jacobian_computer)
        # Number of forward calls whose Jacobian has not yet been accumulated.
        self.remaining_counter = 0
        # Running sum of per-path Jacobians; None until the first path is processed.
        self.summed_jacobian: Optional[Tensor] = None

    def reset(self) -> None:
        """Forget all accumulated state."""
        self.remaining_counter = 0
        self.summed_jacobian = None

    def track_forward_call(self) -> None:
        """Count one more forward call whose Jacobian is expected."""
        self.remaining_counter += 1

    def __call__(
        self,
        rg_outputs: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        kwargs: dict[str, PyTree],
    ) -> Optional[Tensor]:
        """Compute what we can for a module and optionally return the gramian if it's ready."""

        jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs)

        if self.summed_jacobian is None:
            self.summed_jacobian = jacobian_matrix
        else:
            self.summed_jacobian += jacobian_matrix

        self.remaining_counter -= 1

        if self.remaining_counter == 0:
            gramian = self._to_gramian(self.summed_jacobian)
            # Fix: was `del self.summed_jacobian`, which removes the attribute entirely and
            # makes any later `self.summed_jacobian is None` check raise AttributeError if a
            # new accumulation round starts without reset(). Assigning None preserves the
            # invariant declared by the Optional[Tensor] annotation and matches reset().
            self.summed_jacobian = None
            return gramian
        else:
            return None
187 changes: 187 additions & 0 deletions src/torchjd/autogram/_jacobian_computer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import cast

import torch
from torch import Tensor, nn
from torch.nn import Parameter
from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only

# Note about import from protected _pytree module:
# PyTorch maintainers plan to make pytree public (see
# https://github.com/pytorch/pytorch/issues/65761, https://github.com/pytorch/pytorch/pull/137400).
# It should also come with better speed, because the current implementation is slow, according to
# https://github.com/pytorch/pytorch/issues/65761#issue-1010116111.
# When pytree becomes public, this import will have to be changed with a conditional import (to
# still support older versions of PyTorch where pytree is protected).


class JacobianComputer(ABC):
    """
    Abstract base for computing the Jacobian of a module's forward pass with respect to its
    parameters.

    :params module: The module to differentiate.
    """

    def __init__(self, module: nn.Module):
        self.module = module

        # Partition the module's parameters by whether they require gradients.
        named_params = list(module.named_parameters(recurse=True))
        self.rg_params = {name: p for name, p in named_params if p.requires_grad}
        self.frozen_params = {name: p for name, p in named_params if not p.requires_grad}

    def __call__(
        self,
        rg_outputs: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        kwargs: dict[str, PyTree],
    ) -> Tensor:
        # Routing the computation through a custom autograd.Function makes __call__ vmappable.
        return ComputeModuleJacobians.apply(
            self._compute_jacobian, rg_outputs, grad_outputs, args, kwargs
        )

    @abstractmethod
    def _compute_jacobian(
        self,
        rg_outputs: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        kwargs: dict[str, PyTree],
    ) -> Tensor:
        """
        Computes and returns the Jacobian. The output must be a matrix (2D Tensor).
        """


class FunctionalJacobianComputer(JacobianComputer):
    """
    JacobianComputer using the functional differentiation API. This requires to use vmap, so it's
    not compatible with every module, and it requires to have an extra forward pass to create the
    vjp function.
    """

    def _compute_jacobian(
        self,
        _: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        kwargs: dict[str, PyTree],
    ) -> Tensor:
        """
        Compute the Jacobian matrix by vmapping a per-instance vjp over the batch dimension.

        The first parameter (rg_outputs) is unused here: the outputs are recomputed by an
        extra functional forward pass inside `_call_on_one_instance`.
        """
        # Map over dim 0 of every grad_output, and of every Tensor leaf in args/kwargs;
        # non-Tensor leaves are not mapped (in_dim None).
        # NOTE(review): this assumes all Tensor leaves in args/kwargs are batched along
        # dim 0 — confirm against the engine's batch_dim handling.
        grad_outputs_in_dims = (0,) * len(grad_outputs)
        args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
        kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs)
        in_dims = (grad_outputs_in_dims, args_in_dims, kwargs_in_dims)
        vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims)

        return vmapped_vjp(grad_outputs, args, kwargs)

    def _call_on_one_instance(
        self,
        grad_outputs_j: tuple[Tensor, ...],
        args_j: tuple[PyTree, ...],
        kwargs_j: dict[str, PyTree],
    ) -> Tensor:
        """
        Compute the flattened gradient vector w.r.t. the module's requires-grad parameters
        for a single batch instance.
        """
        # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
        # "batch" of 1 activation (or grad_output). This is because some layers (e.g.
        # nn.Flatten) do not work equivalently if they're provided with a batch or with
        # an element of a batch. We thus always provide them with batches, just of a
        # different size.
        args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j)
        kwargs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), kwargs_j)
        grad_outputs_j_ = tuple(x.unsqueeze(0) for x in grad_outputs_j)

        def functional_model_call(rg_params: dict[str, Parameter]) -> tuple[Tensor, ...]:
            # State is split so that only rg_params is a differentiated input; buffers and
            # frozen parameters are passed as constants.
            all_state = [
                cast(dict[str, Tensor], rg_params),
                dict(self.module.named_buffers()),
                cast(dict[str, Tensor], self.frozen_params),
            ]
            output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j)
            # Keep only the output tensors that require grad: they are the ones that the
            # grad_outputs correspond to.
            flat_outputs = tree_flatten(output)[0]
            rg_outputs = tuple(t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad)
            return rg_outputs

        vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1]

        # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the
        # functional has a single primal which is dict(module.named_parameters()). We therefore take
        # the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters.
        gradients = vjp_func(grad_outputs_j_)[0]
        gradient = torch.cat([t.reshape(-1) for t in gradients.values()])
        return gradient


class AutogradJacobianComputer(JacobianComputer):
    """
    JacobianComputer backed by the autograd engine. Its main advantage is that it does not
    require an extra forward pass.
    """

    def _compute_jacobian(
        self,
        rg_outputs: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        _: tuple[PyTree, ...],
        __: dict[str, PyTree],
    ) -> Tensor:
        # Only the flat list of requires-grad parameters is needed; the tree spec is dropped.
        flat_rg_params = tree_flatten(self.rg_params)[0]
        grads = torch.autograd.grad(
            rg_outputs,
            flat_rg_params,
            grad_outputs,
            retain_graph=True,
            allow_unused=True,
            materialize_grads=True,
        )
        # Flatten each per-parameter gradient, concatenate them into one long vector, and
        # return it as a single-row Jacobian matrix.
        return torch.cat([g.reshape(-1) for g in grads]).unsqueeze(0)


class ComputeModuleJacobians(torch.autograd.Function):
    """
    Custom autograd.Function that wraps a Jacobian computation so that it can be called
    under torch.vmap: the `vmap` staticmethod defines how the call is batched.

    Uses the new-style Function protocol: `forward` takes no ctx and `setup_context` is
    defined separately (required for functions that support vmap).
    """

    @staticmethod
    def forward(
        compute_jacobian_fn: Callable[
            [tuple[Tensor, ...], tuple[Tensor, ...], tuple[PyTree, ...], dict[str, PyTree]], Tensor
        ],
        rg_outputs: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        kwargs: dict[str, PyTree],
    ) -> Tensor:
        # There is no non-batched dimension
        jacobian = compute_jacobian_fn(rg_outputs, grad_outputs, args, kwargs)
        return jacobian

    @staticmethod
    def vmap(
        _,
        in_dims: tuple[None, None, tuple[int, ...], None, None],
        compute_jacobian_fn: Callable,
        rg_outputs: tuple[Tensor, ...],
        jac_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        kwargs: dict[str, PyTree],
    ) -> tuple[Tensor, None]:
        # There is a non-batched dimension
        # We do not vmap over the args, kwargs, or rg_outputs for the non-batched dimension
        # in_dims[0] corresponds to compute_jacobian_fn (not an input to the inner vmap),
        # so only in_dims[1:] is forwarded.
        generalized_jacobian = torch.vmap(compute_jacobian_fn, in_dims=in_dims[1:])(
            rg_outputs,
            jac_outputs,
            args,
            kwargs,
        )
        # Merge the mapped dimension with the Jacobian's row dimension to return a matrix,
        # as required by the vmap protocol's (output, out_dims) return convention.
        shape = generalized_jacobian.shape
        jacobian = generalized_jacobian.reshape([shape[0] * shape[1], -1])
        return jacobian, None

    @staticmethod
    def setup_context(*_) -> None:
        # Nothing to save: this Function is only used in forward/vmap mode.
        pass
Loading
Loading