Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 16 additions & 22 deletions src/torchjd/autogram/_module_hook_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from torch import Tensor, nn
from torch.autograd.graph import get_gradient_edge
from torch.utils._pytree import PyTree, TreeSpec, tree_flatten, tree_map, tree_unflatten
from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_unflatten
from torch.utils.hooks import RemovableHandle as TorchRemovableHandle

from ._edge_registry import EdgeRegistry
Expand Down Expand Up @@ -127,7 +127,6 @@ def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:

autograd_fn_outputs = JacobianAccumulator.apply(
self.gramian_accumulation_phase,
output_spec,
vjp,
args,
self.gramian_accumulator,
Expand All @@ -152,7 +151,6 @@ class JacobianAccumulator(torch.autograd.Function):
@staticmethod
def forward(
gramian_accumulation_phase: BoolRef,
output_spec: TreeSpec,
vjp: VJP,
args: PyTree,
gramian_accumulator: GramianAccumulator,
Expand All @@ -162,70 +160,66 @@ def forward(
return tuple(x.detach() for x in xs)

# For Python version > 3.10, the type of `inputs` should become
# tuple[BoolRef, TreeSpec, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
# tuple[BoolRef, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
@staticmethod
def setup_context(
ctx,
inputs: tuple,
_,
):
ctx.gramian_accumulation_phase = inputs[0]
ctx.output_spec = inputs[1]
ctx.vjp = inputs[2]
ctx.args = inputs[3]
ctx.gramian_accumulator = inputs[4]
ctx.module = inputs[5]
ctx.vjp = inputs[1]
ctx.args = inputs[2]
ctx.gramian_accumulator = inputs[3]
ctx.module = inputs[4]

# For Python version > 3.10, the return type should become
# tuple[None, None, None, None, None, *tuple[Tensor | None, ...]]
@staticmethod
def backward(ctx, *flat_grad_outputs: Tensor):
def backward(ctx, *flat_grad_outputs: Tensor | None) -> tuple:
if not ctx.gramian_accumulation_phase:
return None, None, None, None, None, None, *flat_grad_outputs
return None, None, None, None, None, *flat_grad_outputs

AccumulateJacobian.apply(
ctx.output_spec,
ctx.vjp,
ctx.args,
ctx.gramian_accumulator,
ctx.module,
*flat_grad_outputs,
)

return None, None, None, None, None, None, *flat_grad_outputs
return None, None, None, None, None, *flat_grad_outputs


class AccumulateJacobian(torch.autograd.Function):

@staticmethod
def forward(
output_spec: TreeSpec,
vjp: VJP,
args: PyTree,
gramian_accumulator: GramianAccumulator,
module: nn.Module,
*flat_grad_outputs: Tensor,
*flat_grad_outputs: Tensor | None,
) -> None:
# There is no non-batched dimension
grad_outputs = tree_unflatten(flat_grad_outputs, output_spec)
generalized_jacobians = vjp(grad_outputs, args)
generalized_jacobians = vjp(flat_grad_outputs, args)
path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
gramian_accumulator.accumulate_path_jacobians(path_jacobians)

@staticmethod
def vmap(
_,
in_dims: PyTree,
output_spec: TreeSpec,
vjp: VJP,
args: PyTree,
gramian_accumulator: GramianAccumulator,
module: nn.Module,
*flat_jac_outputs: Tensor,
*flat_jac_outputs: Tensor | None,
) -> tuple[None, None]:
# There is a non-batched dimension
jac_outputs = tree_unflatten(flat_jac_outputs, output_spec)
# We do not vmap over the args for the non-batched dimension
in_dims = (tree_unflatten(in_dims[5:], output_spec), tree_map(lambda _: None, args))
generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args)
in_dims = (in_dims[4:], tree_map(lambda _: None, args))
generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(flat_jac_outputs, args)
path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
gramian_accumulator.accumulate_path_jacobians(path_jacobians)
return None, None
Expand Down
38 changes: 25 additions & 13 deletions src/torchjd/autogram/_vjp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import cast

import torch
from torch import Tensor, nn
Expand All @@ -19,7 +20,9 @@ class VJP(ABC):
"""Represents an abstract VJP function."""

@abstractmethod
def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]:
def __call__(
self, flat_grad_outputs: tuple[Tensor | None, ...], args: PyTree
) -> dict[str, Tensor]:
"""
Computes and returns the dictionary of parameter names to their gradients for the given
grad_outputs (cotangents) and at the given inputs.
Expand Down Expand Up @@ -56,32 +59,37 @@ def __init__(self, module: nn.Module):
super().__init__(module)
self.vmapped_vjp = torch.vmap(self._call_on_one_instance)

def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]:
return self.vmapped_vjp(grad_outputs, args)
def __call__(
self, flat_grad_outputs: tuple[Tensor | None, ...], args: PyTree
) -> dict[str, Tensor]:
return self.vmapped_vjp(flat_grad_outputs, args)

def _call_on_one_instance(self, grad_outputs_j: PyTree, args_j: PyTree) -> dict[str, Tensor]:
def _call_on_one_instance(
self, flat_grad_outputs_j: tuple[Tensor | None, ...], args_j: PyTree
) -> dict[str, Tensor]:
# Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
# "batch" of 1 activation (or grad_output). This is because some layers (e.g.
# nn.Flatten) do not work equivalently if they're provided with a batch or with
# an element of a batch. We thus always provide them with batches, just of a
# different size.
args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j)
grad_outputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), grad_outputs_j)
flat_grad_outputs_j_ = [x.unsqueeze(0) for x in flat_grad_outputs_j]

def functional_model_call(trainable_params: dict[str, Parameter]) -> Tensor:
def flat_functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]:
all_state = {
**trainable_params,
**dict(self.module.named_buffers()),
**self.frozen_params,
}
return torch.func.functional_call(self.module, all_state, args_j)
output = torch.func.functional_call(self.module, all_state, args_j)
return tree_flatten(output)[0]

vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1]
vjp_func = torch.func.vjp(flat_functional_model_call, self.trainable_params)[1]

# vjp_func is a function that computes the vjp w.r.t. the primals (tuple). Here the
# functional has a single primal which is dict(module.named_parameters()). We therefore take
# the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters.
return vjp_func(grad_outputs_j)[0]
return vjp_func(flat_grad_outputs_j_)[0]


class AutogradVJP(ModuleVJP):
Expand All @@ -105,16 +113,20 @@ def __init__(self, module: nn.Module, outputs: Sequence[Tensor]):

self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params)

def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]:
flat_grad_outputs = tree_flatten(grad_outputs)[0]
def __call__(
self, flat_grad_outputs: tuple[Tensor | None, ...], _: PyTree
) -> dict[str, Tensor]:

# Only keep the grad_outputs corresponding to outputs that require grad.
grad_outputs_ = [grad_output for grad_output, rg in zip(flat_grad_outputs, self.mask) if rg]
flat_grad_outputs_ = [
grad_output for grad_output, rg in zip(flat_grad_outputs, self.mask) if rg
]
casted_flat_grad_outputs = cast(list[Tensor], flat_grad_outputs_)

grads = torch.autograd.grad(
self.outputs_that_require_grad,
self.flat_trainable_params,
grad_outputs_,
casted_flat_grad_outputs,
retain_graph=True,
allow_unused=True,
materialize_grads=True,
Expand Down
Loading