Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 13 additions & 21 deletions src/torchjd/autogram/_module_hook_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from torch import Tensor, nn
from torch.autograd.graph import get_gradient_edge
from torch.utils._pytree import PyTree, TreeSpec, tree_flatten, tree_map, tree_unflatten
from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_unflatten
from torch.utils.hooks import RemovableHandle as TorchRemovableHandle

from ._edge_registry import EdgeRegistry
Expand Down Expand Up @@ -127,7 +127,6 @@ def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:

autograd_fn_outputs = JacobianAccumulator.apply(
self.gramian_accumulation_phase,
output_spec,
vjp,
args,
self.gramian_accumulator,
Expand All @@ -152,7 +151,6 @@ class JacobianAccumulator(torch.autograd.Function):
@staticmethod
def forward(
gramian_accumulation_phase: BoolRef,
output_spec: TreeSpec,
vjp: VJP,
args: PyTree,
gramian_accumulator: GramianAccumulator,
Expand All @@ -162,50 +160,46 @@ def forward(
return tuple(x.detach() for x in xs)

# For Python version > 3.10, the type of `inputs` should become
# tuple[BoolRef, TreeSpec, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
# tuple[BoolRef, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
@staticmethod
def setup_context(
ctx,
inputs: tuple,
_,
):
ctx.gramian_accumulation_phase = inputs[0]
ctx.output_spec = inputs[1]
ctx.vjp = inputs[2]
ctx.args = inputs[3]
ctx.gramian_accumulator = inputs[4]
ctx.module = inputs[5]
ctx.vjp = inputs[1]
ctx.args = inputs[2]
ctx.gramian_accumulator = inputs[3]
ctx.module = inputs[4]

@staticmethod
def backward(ctx, *flat_grad_outputs: Tensor):
def backward(ctx, *grad_outputs: Tensor):
if not ctx.gramian_accumulation_phase:
return None, None, None, None, None, None, *flat_grad_outputs
return None, None, None, None, None, *grad_outputs

AccumulateJacobian.apply(
ctx.output_spec,
ctx.vjp,
ctx.args,
ctx.gramian_accumulator,
ctx.module,
*flat_grad_outputs,
*grad_outputs,
)

return None, None, None, None, None, None, *flat_grad_outputs
return None, None, None, None, None, *grad_outputs


class AccumulateJacobian(torch.autograd.Function):

@staticmethod
def forward(
output_spec: TreeSpec,
vjp: VJP,
args: PyTree,
gramian_accumulator: GramianAccumulator,
module: nn.Module,
*flat_grad_outputs: Tensor,
*grad_outputs: Tensor,
) -> None:
# There is no non-batched dimension
grad_outputs = tree_unflatten(flat_grad_outputs, output_spec)
generalized_jacobians = vjp(grad_outputs, args)
path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
gramian_accumulator.accumulate_path_jacobians(path_jacobians)
Expand All @@ -214,17 +208,15 @@ def forward(
def vmap(
_,
in_dims: PyTree,
output_spec: TreeSpec,
vjp: VJP,
args: PyTree,
gramian_accumulator: GramianAccumulator,
module: nn.Module,
*flat_jac_outputs: Tensor,
*jac_outputs: Tensor,
) -> tuple[None, None]:
# There is a non-batched dimension
jac_outputs = tree_unflatten(flat_jac_outputs, output_spec)
# We do not vmap over the args for the non-batched dimension
in_dims = (tree_unflatten(in_dims[5:], output_spec), tree_map(lambda _: None, args))
in_dims = (in_dims[4:], tree_map(lambda _: None, args))
generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args)
path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
gramian_accumulator.accumulate_path_jacobians(path_jacobians)
Expand Down
24 changes: 13 additions & 11 deletions src/torchjd/autogram/_vjp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class VJP(ABC):
"""Represents an abstract VJP function."""

@abstractmethod
def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]:
def __call__(self, grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]:
"""
Computes and returns the dictionary of parameter names to their gradients for the given
grad_outputs (cotangents) and at the given inputs.
Expand Down Expand Up @@ -56,32 +56,35 @@ def __init__(self, module: nn.Module):
super().__init__(module)
self.vmapped_vjp = torch.vmap(self._call_on_one_instance)

def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]:
def __call__(self, grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]:
return self.vmapped_vjp(grad_outputs, args)

def _call_on_one_instance(self, grad_outputs_j: PyTree, args_j: PyTree) -> dict[str, Tensor]:
def _call_on_one_instance(
self, grad_outputs_j: tuple[Tensor, ...], args_j: PyTree
) -> dict[str, Tensor]:
# Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
# "batch" of 1 activation (or grad_output). This is because some layers (e.g.
# nn.Flatten) do not work equivalently if they're provided with a batch or with
# an element of a batch. We thus always provide them with batches, just of a
# different size.
args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j)
grad_outputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), grad_outputs_j)
grad_outputs_j_ = [x.unsqueeze(0) for x in grad_outputs_j]

def functional_model_call(trainable_params: dict[str, Parameter]) -> Tensor:
def flat_functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]:
all_state = {
**trainable_params,
**dict(self.module.named_buffers()),
**self.frozen_params,
}
return torch.func.functional_call(self.module, all_state, args_j)
output = torch.func.functional_call(self.module, all_state, args_j)
return tree_flatten(output)[0]

vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1]
vjp_func = torch.func.vjp(flat_functional_model_call, self.trainable_params)[1]

# vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the
# functional has a single primal which is dict(module.named_parameters()). We therefore take
# the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters.
return vjp_func(grad_outputs_j)[0]
return vjp_func(grad_outputs_j_)[0]


class AutogradVJP(ModuleVJP):
Expand All @@ -105,11 +108,10 @@ def __init__(self, module: nn.Module, outputs: Sequence[Tensor]):

self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params)

def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]:
flat_grad_outputs = tree_flatten(grad_outputs)[0]
def __call__(self, grad_outputs: tuple[Tensor, ...], _: PyTree) -> dict[str, Tensor]:

# Only keep the grad_outputs corresponding to outputs that require grad.
grad_outputs_ = [grad_output for grad_output, rg in zip(flat_grad_outputs, self.mask) if rg]
grad_outputs_ = [grad_output for grad_output, rg in zip(grad_outputs, self.mask) if rg]

grads = torch.autograd.grad(
self.outputs_that_require_grad,
Expand Down
Loading