Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
e757e31
Add non-batched support
PierreQuinton Aug 27, 2025
8f6e5ee
Update src/torchjd/autogram/_vjp.py
PierreQuinton Sep 3, 2025
d507daa
Update tests/unit/autogram/test_engine.py
PierreQuinton Sep 3, 2025
d60328e
Fix batched_dim param in some Engine usages
ValerianRey Sep 4, 2025
4103e4d
Make batched_dim parameter name explicit when creating Engine
ValerianRey Sep 4, 2025
4fadac6
Rename batched_dims to batched_dim in test_gramian_is_correct
ValerianRey Sep 4, 2025
5c98ee8
Fix parameter description of batched_dim
ValerianRey Sep 4, 2025
e5e6e9d
Improve error message in _check_module_is_compatible
ValerianRey Sep 4, 2025
c6c6a3f
Remove parameter description of removed parameter grad_output in comp…
ValerianRey Sep 4, 2025
af3373b
Rename flat_gramian to square_gramian
ValerianRey Sep 4, 2025
aeb8f3b
Remove redundant cast to dict in AutogradVJP
ValerianRey Sep 4, 2025
d5b8685
Rename info to _ in AccumulateJacobian.vmap
ValerianRey Sep 4, 2025
b789e35
Type-hint in_dims as PyTree in AccumulateJacobian.vmap
ValerianRey Sep 4, 2025
e949dff
Rename tree_spec to output_spec in ModuleHookManager
ValerianRey Sep 4, 2025
16f6aa9
Rename self.tree_spec to self.param_spec in AutogradVJP
ValerianRey Sep 4, 2025
ab134a8
Add example in comment of reshape_gramian
ValerianRey Sep 4, 2025
8d66d57
Improve variable names in compute_quadratic_form
ValerianRey Sep 4, 2025
4bdc400
Revamp documentation of compute_gramian
ValerianRey Sep 4, 2025
33216c1
Update src/torchjd/autogram/_engine.py
ValerianRey Sep 5, 2025
c090dfe
Add ... indexing in jac_output for code clarity
ValerianRey Sep 5, 2025
3466a3c
Fix formatting of docstring
ValerianRey Sep 5, 2025
d2bab7a
Add more parametrizations to test_reshape_equivariance
ValerianRey Sep 5, 2025
79e0609
Improve parametrization of test_movedim_equivariance
ValerianRey Sep 5, 2025
c40f44c
Improve parametrization of test_batched_non_batched_equivalence
ValerianRey Sep 5, 2025
2fe3b51
Add comment in compute_gramian
ValerianRey Sep 5, 2025
d82a8f6
Improve clarity of reshape_gramian
ValerianRey Sep 5, 2025
c846ed7
Improve clarity of movedim_gramian
ValerianRey Sep 5, 2025
c2fd0a0
Revert removal of _handles in ModuleHookManager
ValerianRey Sep 5, 2025
1cbff18
Add more edge cases to test_quadratic_form_invariance_to_reshape
ValerianRey Sep 5, 2025
6938560
Add more edge cases to test_quadratic_form_invariance_to_movedim
ValerianRey Sep 5, 2025
1b56558
Factorize code into _make_path_jacobians and use for-loop
ValerianRey Sep 5, 2025
1d78e15
Fix model not being moved to cuda in new tests
ValerianRey Sep 5, 2025
91d739b
Make test_equivalence_autojac_autogram also work with non-batched engine
ValerianRey Sep 6, 2025
92b7b1a
Make separation between trainable and frozen params in VJP and rename…
ValerianRey Sep 6, 2025
9b92cf7
Add allow_unused=True and materialize_grads=True in call to autograd.…
ValerianRey Sep 6, 2025
7950bb4
Stop trying to differentiate outputs that dont require grad in Autogr…
ValerianRey Sep 6, 2025
4dc3f7c
Fix variable name
ValerianRey Sep 6, 2025
23cb0a8
Rename jacobians to generalized_jacobians
ValerianRey Sep 6, 2025
d58a08b
Merge branch 'main' into vgp-v3-rebased
ValerianRey Sep 6, 2025
1169b85
Replace torch.movedim by tensor.movedim
ValerianRey Sep 6, 2025
59c957c
Add batch_size variable in compute_gramian
ValerianRey Sep 6, 2025
dbfbbc6
Merge branch 'main' into vgp-v3-rebased
ValerianRey Sep 9, 2025
9a01a12
Add GeneralizedWeighting
ValerianRey Sep 9, 2025
31812a1
Add FakeGeneralizedWeighting in tests
ValerianRey Sep 9, 2025
ee16f9b
reshape_gramian can now take generalized gramians as inputs, its shap…
PierreQuinton Sep 9, 2025
0a25555
Implement `HierarchicalWeighting` Weighting, needs testing.
PierreQuinton Sep 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/docs/aggregation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ Abstract base classes
:undoc-members:
:exclude-members: forward

.. autoclass:: torchjd.aggregation.GeneralizedWeighting
:members:
:undoc-members:
:exclude-members: forward


.. toctree::
:hidden:
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from ._utils.check_dependencies import (
OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError,
)
from ._weighting_bases import Weighting
from ._weighting_bases import GeneralizedWeighting, Weighting

try:
from ._cagrad import CAGrad, CAGradWeighting
Expand Down
25 changes: 25 additions & 0 deletions src/torchjd/aggregation/_weighting_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,28 @@ def __init__(self, weighting: Weighting[_FnOutputT], fn: Callable[[_T], _FnOutpu

def forward(self, stat: _T) -> Tensor:
return self.weighting(self.fn(stat))


class GeneralizedWeighting(nn.Module, ABC):
    r"""
    Abstract base class for all weightings that operate on generalized Gramians. Its role is to
    extract a tensor of weights of dimension :math:`m_1 \times \dots \times m_k` from a
    generalized Gramian of dimension
    :math:`m_1 \times \dots \times m_k \times m_k \times \dots \times m_1`.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, generalized_gramian: Tensor) -> Tensor:
        """Computes the tensor of weights from the input generalized Gramian."""

    # Overridden so that the signature and documentation are more specific than nn.Module's.
    def __call__(self, generalized_gramian: Tensor) -> Tensor:
        """
        Computes the tensor of weights from the input generalized Gramian and applies all
        registered hooks.
        """

        return super().__call__(generalized_gramian)
105 changes: 105 additions & 0 deletions src/torchjd/aggregation/hierachical_weighting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import torch
from torch import Tensor

from ..autogram._gramian_utils import reshape_gramian
from ._weighting_bases import GeneralizedWeighting, Weighting


class HierarchicalWeighting(GeneralizedWeighting):
    """
    Hierarchically reduces a generalized Gramian using a sequence of weighting functions.

    Applies multiple weightings in sequence to a generalized Gramian ``G`` of shape
    ``[n₁, ..., nₖ, nₖ, ..., n₁]``. It first applies the initial weighting to the innermost
    diagonal Gramians (of shape ``[nₖ, nₖ]``), contracts those dimensions to form a smaller
    generalized Gramian, and repeats the process with subsequent weightings. The final returned
    weights are chosen so that contracting the original Gramian directly with these weights
    produces the same quadratic form as applying the reductions step by step.

    :param weightings: A list of weighting callables, one for each hierarchical reduction step,
        ordered from the innermost dimension (``nₖ``) to the outermost one (``n₁``).
    """

    def __init__(self, weightings: list[Weighting]):
        super().__init__()
        self.weightings = weightings
        self.n_dims = len(weightings)

    def forward(self, generalized_gramian: Tensor) -> Tensor:
        # Validate input instead of a (strippable) assert: one weighting per output dimension.
        if 2 * self.n_dims != generalized_gramian.ndim:
            raise ValueError(
                f"Expected a generalized Gramian with {2 * self.n_dims} dimensions (twice the "
                f"number of weightings), but got one of shape "
                f"{list(generalized_gramian.shape)}."
            )

        # Shape [n₁, ..., nₖ] of the weight tensor to return.
        weight_shape = list(generalized_gramian.shape[: self.n_dims])

        weights: Tensor | None = None
        gramian = generalized_gramian
        for step, weighting in enumerate(self.weightings):
            # Number of (outer) dimensions that remain to be contracted at this step.
            n_remaining = self.n_dims - step
            # The innermost remaining dimension is the one contracted at this step.
            dim_size = weight_shape[n_remaining - 1]

            # Flatten to shape [m, dim_size, dim_size, m] so that each diagonal block
            # reshaped_gramian[j, :, :, j] is a plain [dim_size, dim_size] Gramian.
            reshaped_gramian = reshape_gramian(gramian, [-1, dim_size])
            step_weights = _compute_weights(weighting, reshaped_gramian)  # shape: [m, dim_size]
            gramian = _contract_gramian(reshaped_gramian, step_weights)  # shape: [m, m]

            # Unflatten the step weights to [n₁, ..., n_{k-step}] and fold them into the
            # accumulated weights by broadcasting over the already-contracted trailing dims.
            unflattened_weights = step_weights.reshape(weight_shape[:n_remaining])
            if weights is None:
                weights = unflattened_weights
            else:
                weights = _scale_weights(weights, unflattened_weights)

        if weights is None:
            # No weightings: the Gramian is a scalar and the weight tensor is scalar too.
            weights = torch.ones_like(generalized_gramian)
        return weights


def _compute_weights(weighting: Weighting, generalized_gramian: Tensor) -> Tensor:
    """
    Apply a weighting to each diagonal Gramian in a generalized Gramian.

    For a generalized Gramian ``G`` of shape ``[m, n, n, m]``, this extracts each diagonal Gramian
    ``G[j, :, :, j]`` of shape ``[n, n]`` for ``j`` in ``[m]`` and applies the provided weighting.
    The resulting weights are stacked into a tensor of shape ``[m, n]``.

    :param weighting: Callable that maps a Gramian of shape ``[n, n]`` to weights of shape ``[n]``.
    :param generalized_gramian: Tensor of shape ``[m, n, n, m]`` containing the generalized Gramian.
    :returns: Tensor of shape ``[m, n]`` containing the computed weights for each diagonal Gramian.
    """

    # Allocate from the first two entries of the *shape* ([m, n]); indexing the tensor itself
    # (`generalized_gramian[:2]`) would pass a Tensor to torch.zeros and fail.
    weights = torch.zeros(
        generalized_gramian.shape[:2],
        device=generalized_gramian.device,
        dtype=generalized_gramian.dtype,
    )
    for i in range(generalized_gramian.shape[0]):
        weights[i] = weighting(generalized_gramian[i, :, :, i])
    return weights


def _contract_gramian(generalized_gramian: Tensor, weights: Tensor) -> Tensor:
    r"""
    Compute a partial quadratic form by contracting a generalized Gramian with weight vectors on
    both sides.

    Given a generalized Gramian ``G`` of shape ``[m, n, n, m]`` and weights ``w`` of shape
    ``[m, n]``, this function computes a Gramian ``G'`` of shape ``[m, m]`` where

    .. math::

        G'[i, j] = \sum_{k, l=1}^n w[i, k] G[i, k, l, j] w[j, l].

    This can be viewed as forming a quadratic form with respect to the two innermost dimensions of
    ``G``.

    :param generalized_gramian: Tensor of shape ``[m, n, n, m]`` representing the generalized
        Gramian.
    :param weights: Tensor of shape ``[m, n]`` containing weight vectors to contract with the
        Gramian.
    :returns: Tensor of shape ``[m, m]`` containing the contracted Gramian, i.e. the partial
        quadratic form.
    """
    # left_product[i, l, j] = Σ_k w[i, k] · G[i, k, l, j]
    left_product = torch.einsum("ik,iklj->ilj", weights, generalized_gramian)
    # G'[i, j] = Σ_l left_product[i, l, j] · w[j, l]. Note that the right contraction must use
    # row j of the weights (not row i), to match the formula above.
    return torch.einsum("ilj,jl->ij", left_product, weights)


def _scale_weights(weights: Tensor, scalings: Tensor) -> Tensor:
    """
    Scale a tensor along its leading dimensions by broadcasting scaling factors.

    :param weights: Tensor of shape [n₁, ..., nₖ, nₖ₊₁, ..., nₚ].
    :param scalings: Tensor of shape [n₁, ..., nₖ] providing scaling factors for the leading
        dimensions of ``weights``.
    :returns: Tensor of the same shape as ``weights``, where each slice
        ``weights[i₁, ..., iₖ, :, ..., :]`` is multiplied by ``scalings[i₁, ..., iₖ]``.
    """
    # Append singleton dimensions to the scalings so that they broadcast over the trailing
    # dimensions of the weights.
    n_trailing_dims = weights.ndim - scalings.ndim
    broadcastable_scalings = scalings.reshape(list(scalings.shape) + [1] * n_trailing_dims)
    return weights * broadcastable_scalings
112 changes: 95 additions & 17 deletions src/torchjd/autogram/_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from typing import cast

import torch
from torch import Tensor, nn
from torch import Tensor, nn, vmap
from torch.autograd.graph import get_gradient_edge

from ._edge_registry import EdgeRegistry
from ._gramian_accumulator import GramianAccumulator
from ._gramian_utils import movedim_gramian, reshape_gramian
from ._module_hook_manager import ModuleHookManager

_INCOMPATIBLE_MODULE_TYPES = (
Expand Down Expand Up @@ -57,6 +58,10 @@ class Engine:

:param modules: A collection of modules whose direct (non-recursive) parameters will contribute
to the Gramian of the Jacobian.
:param batched_dim: If the modules work with batches and process each batch element
independently, then many intermediary jacobians are sparse (block-diagonal), which allows
for a substantial memory optimization by backpropagating a squashed Jacobian instead. This
parameter indicates the batch dimension, if any. Defaults to None.

.. admonition::
Example
Expand All @@ -79,7 +84,7 @@ class Engine:
>>>
>>> criterion = MSELoss(reduction="none")
>>> weighting = UPGradWeighting()
>>> engine = Engine(model.modules())
>>> engine = Engine(model.modules(), batched_dim=0)
>>>
>>> for input, target in zip(inputs, targets):
>>> output = model(input).squeeze(dim=1) # shape: [16]
Expand Down Expand Up @@ -127,10 +132,17 @@ class Engine:
<https://docs.pytorch.org/docs/stable/generated/torch.nn.InstanceNorm2d.html>`_ layers.
"""

def __init__(self, modules: Iterable[nn.Module]):
def __init__(
self,
modules: Iterable[nn.Module],
batched_dim: int | None = None,
):
self._gramian_accumulator = GramianAccumulator()
self._target_edges = EdgeRegistry()
self._module_hook_manager = ModuleHookManager(self._target_edges, self._gramian_accumulator)
self._batched_dim = batched_dim
self._module_hook_manager = ModuleHookManager(
self._target_edges, self._gramian_accumulator, batched_dim is not None
)

self._hook_modules(modules)

Expand All @@ -143,13 +155,15 @@ def _hook_modules(self, modules: Iterable[nn.Module]) -> None:
self._check_module_is_compatible(module)
self._module_hook_manager.hook_module(module)

@staticmethod
def _check_module_is_compatible(module: nn.Module) -> None:
if isinstance(module, _INCOMPATIBLE_MODULE_TYPES):
def _check_module_is_compatible(self, module: nn.Module) -> None:
if self._batched_dim is not None and isinstance(module, _INCOMPATIBLE_MODULE_TYPES):
raise ValueError(
f"Found a module of type {type(module)}, which is incompatible with the autogram "
f"engine. The incompatible module types are {_INCOMPATIBLE_MODULE_TYPES} (and their"
" subclasses)."
f"engine when `batched_dim` is not `None`. The incompatible module types are "
f"{_INCOMPATIBLE_MODULE_TYPES} (and their subclasses). The recommended fix is to "
f"replace incompatible layers by something else (e.g. BatchNorm by InstanceNorm), "
f"but if you really can't and performance not a priority, you may also just set"
f"`batch_dim=None` when creating the engine."
)

if isinstance(module, _TRACK_RUNNING_STATS_MODULE_TYPES) and module.track_running_stats:
Expand All @@ -161,17 +175,68 @@ def _check_module_is_compatible(module: nn.Module) -> None:
)

def compute_gramian(self, output: Tensor) -> Tensor:
"""
Compute the Gramian of the Jacobian of ``output`` with respect the direct parameters of all
``modules``.
r"""
Computes the Gramian of the Jacobian of ``output`` with respect to the direct parameters of
all ``modules``.

.. note::
This function doesn't require ``output`` to be a vector. For example, if ``output`` is
a matrix of shape :math:`[m_1, m_2]`, its Jacobian :math:`J` with respect to the
parameters will be of shape :math:`[m_1, m_2, n]`, where :math:`n` is the number of
parameters in the model. This is what we call a generalized Jacobian. The
corresponding Gramian :math:`G = J J^\top` will be of shape
:math:`[m_1, m_2, m_2, m_1]`. This is what we call a `generalized Gramian`. The number
of dimensions of the returned generalized Gramian will always be twice that of the
``output``.

A few examples:
- 0D (scalar) ``output``: 0D Gramian (this can be used to efficiently compute the
squared norm of the gradient of ``output``).
- 1D (vector) ``output``: 2D Gramian (this is the standard setting of Jacobian
descent).
- 2D (matrix) ``output``: 4D Gramian (this can happen when combining IWRM and
multi-task learning, as each sample in the batch has one loss per task).
- etc.

:param output: The vector to differentiate. Must be a 1-D tensor.
:param output: The tensor of arbitrary shape to differentiate. The shape of the returned
Gramian depends on the shape of this output, as explained in the note above.
"""

reshaped_output = output.reshape([-1])
return self._compute_square_gramian(reshaped_output)
if self._batched_dim is not None:
# move batched dim to the end
ordered_output = output.movedim(self._batched_dim, -1)
ordered_shape = list(ordered_output.shape)
batch_size = ordered_shape[-1]
has_non_batched_dim = len(ordered_shape) > 1
target_shape = [batch_size]
else:
ordered_output = output
ordered_shape = list(ordered_output.shape)
has_non_batched_dim = len(ordered_shape) > 0
target_shape = []

if has_non_batched_dim:
target_shape = [-1] + target_shape

reshaped_output = ordered_output.reshape(target_shape)
# There are four different cases for the shape of reshaped_output:
# - Not batched and not non-batched: scalar of shape []
# - Batched only: vector of shape [batch_size]
# - Non-batched only: vector of shape [dim]
# - Batched and non-batched: matrix of shape [dim, batch_size]

square_gramian = self._compute_square_gramian(reshaped_output, has_non_batched_dim)

def _compute_square_gramian(self, output: Tensor) -> Tensor:
unordered_gramian = reshape_gramian(square_gramian, ordered_shape)

if self._batched_dim is not None:
gramian = movedim_gramian(unordered_gramian, [-1], [self._batched_dim])
else:
gramian = unordered_gramian

return gramian

def _compute_square_gramian(self, output: Tensor, has_non_batched_dim: bool) -> Tensor:
self._module_hook_manager.gramian_accumulation_phase = True

leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)}))
Expand All @@ -184,7 +249,20 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
retain_graph=True,
)

_ = differentiation(torch.ones_like(output))
if has_non_batched_dim:
# There is one non-batched dimension, it is the first one
non_batched_dim_len = output.shape[0]
jac_output_shape = [output.shape[0]] + list(output.shape)

# Need to batch `grad_output` over the first dimension
jac_output = torch.zeros(jac_output_shape, device=output.device, dtype=output.dtype)
for i in range(non_batched_dim_len):
jac_output[i, i, ...] = 1

_ = vmap(differentiation)(jac_output)
else:
grad_output = torch.ones_like(output)
_ = differentiation(grad_output)

# If the gramian were None, then leaf_targets would be empty, so autograd.grad would
# have failed. So gramian is necessarily a valid Tensor here.
Expand Down
69 changes: 69 additions & 0 deletions src/torchjd/autogram/_gramian_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from math import prod

from torch import Tensor


def reshape_gramian(gramian: Tensor, shape: list[int]) -> Tensor:
    """
    Reshapes a Gramian to a provided shape. As a Gramian is a quadratic form, the reshape of the
    first half of the target dimensions must be done from the left, while the reshape of the
    second half must be done from the right.

    At most one element of ``shape`` may be ``-1``, in which case it is inferred from the number
    of elements in the first half of the Gramian's dimensions, as in ``Tensor.reshape``.

    :param gramian: Gramian to reshape.
    :param shape: First half of the target shape. The shape of the output is therefore
        ``shape + shape[::-1]``.
    """

    # Example: `gramian` of shape [24, 24] and `shape` of [4, 3, 2]:
    # - The `unordered_input_gramian` will be of shape [24, 24]
    # - The `unordered_output_gramian` will be of shape [4, 3, 2, 4, 3, 2]
    # - The `reordered_output_gramian` will be of shape [4, 3, 2, 2, 3, 4]

    # Work on a copy: resolving -1 below must not mutate the caller's list.
    shape = list(shape)

    automatic_dimensions = [i for i, size in enumerate(shape) if size == -1]
    if len(automatic_dimensions) == 1:
        index = automatic_dimensions[0]
        current_shape = gramian.shape[: gramian.ndim // 2]
        numel = prod(current_shape)
        # shape[index] == -1, so -prod(shape) is the product of all the other (specified) dims.
        specified_numel = -prod(shape)
        shape[index] = numel // specified_numel

    unordered_input_gramian = _revert_last_dims(gramian)
    unordered_output_gramian = unordered_input_gramian.reshape(shape + shape)
    reordered_output_gramian = _revert_last_dims(unordered_output_gramian)
    return reordered_output_gramian


def movedim_gramian(gramian: Tensor, source: list[int], destination: list[int]) -> Tensor:
    """
    Moves the dimensions of a Gramian from some source dimensions to destination dimensions. As a
    Gramian is a quadratic form, moving a dimension must be done simultaneously on the first half
    of the dimensions and on the mirrored position in the second half.

    :param gramian: Gramian to move the dimensions of.
    :param source: Source dimensions, that should be in the range
        [-gramian.ndim//2, gramian.ndim//2[. Its elements should be unique.
    :param destination: Destination dimensions, that should be in the range
        [-gramian.ndim//2, gramian.ndim//2[. It should have the same size as `source`, and its
        elements should be unique.
    """

    # Example: `gramian` of shape [4, 3, 2, 2, 3, 4], `source` of [-2, 2] and destination of
    # [0, 1]:
    # - `normalized_source` will be [1, 2] and `normalized_destination` will be [0, 1]
    # - `full_source` will be [1, 2, 4, 3] and `full_destination` will be [0, 1, 5, 4]
    # - The returned Gramian will be of shape [3, 2, 4, 4, 2, 3]

    # Normalize negative indices into the range [0, gramian.ndim//2[.
    half = gramian.ndim // 2
    normalized_source = [dim % half for dim in source]
    normalized_destination = [dim % half for dim in destination]

    # Each first-half dimension d has a mirror at gramian.ndim - 1 - d in the second half; both
    # must move together to preserve the quadratic-form structure.
    full_source = normalized_source + [gramian.ndim - 1 - d for d in normalized_source]
    full_destination = normalized_destination + [
        gramian.ndim - 1 - d for d in normalized_destination
    ]
    return gramian.movedim(full_source, full_destination)


def _revert_last_dims(generalized_gramian: Tensor) -> Tensor:
    # Reverse the order of the second half of the dimensions, e.g. a tensor of shape
    # [a, b, c, d] (halves [a, b] and [c, d]) becomes one of shape [a, b, d, c].
    half = generalized_gramian.ndim // 2
    second_half_dims = list(range(half, 2 * half))
    return generalized_gramian.movedim(second_half_dims, second_half_dims[::-1])
Loading
Loading