Merged
8 changes: 2 additions & 6 deletions src/torchjd/autogram/_module_hook_manager.py
@@ -9,7 +9,7 @@

from ._edge_registry import EdgeRegistry
from ._gramian_accumulator import GramianAccumulator
from ._vjp import VJP, AutogradVJP, FunctionalVJP, Vmapped
from ._vjp import VJP, AutogradVJP, FunctionalVJP

# Note about import from protected _pytree module:
# PyTorch maintainers plan to make pytree public (see
@@ -232,11 +232,7 @@ def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
index = cast(int, preference.argmin().item())
self.target_edges.register(get_gradient_edge(flat_outputs[index]))

vjp: VJP
if self.has_batch_dim:
vjp = Vmapped(FunctionalVJP(module))
else:
vjp = AutogradVJP(module, flat_outputs)
vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs)

autograd_fn_outputs = JacobianAccumulator.apply(
self.gramian_accumulation_phase,
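The call-site change above drops the Vmapped wrapper: with this PR, FunctionalVJP applies torch.vmap internally (see the _vjp.py diff below), so the hook manager simply picks FunctionalVJP when the module has a batch dimension and AutogradVJP otherwise. As a rough illustration of what the removed Vmapped class did, here is a minimal sketch, separate from the diff itself; the toy per-instance function and every name in it are invented for illustration and are not torchjd code.

import torch


def one_instance_vjp_like(grad_output_j: torch.Tensor, input_j: torch.Tensor) -> torch.Tensor:
    # Stand-in for a per-instance VJP: combines one cotangent with one activation.
    return grad_output_j.unsqueeze(1) * input_j.unsqueeze(0)


# torch.vmap turns a callable defined on a single instance into one defined on a whole
# batch, mapping over the first dimension of every argument (what Vmapped used to do).
vmapped = torch.vmap(one_instance_vjp_like)
grad_outputs = torch.randn(5, 2)
inputs = torch.randn(5, 3)
print(vmapped(grad_outputs, inputs).shape)  # torch.Size([5, 2, 3])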
52 changes: 15 additions & 37 deletions src/torchjd/autogram/_vjp.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from collections.abc import Sequence

import torch
from torch import Tensor, nn
@@ -45,31 +45,21 @@ def __init__(self, module: nn.Module):
self.frozen_params[name] = param


class Vmapped(VJP):
"""VJP wrapper that applies the wrapped VJP, vmapped on the first dimension."""

def __init__(self, vjp: VJP):
super().__init__()
self.vmapped_vjp = torch.vmap(vjp)

def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]:
return self.vmapped_vjp(grad_outputs, inputs)


class FunctionalVJP(ModuleVJP):
"""
Represents a VJP function for a module's forward pass with respect to its parameters using the
func api. The __call__ function takes both the inputs and the cotangents that can be vmapped
jointly in both terms to avoid providing to block diagonal jacobians. The disadvantage of using
this method is that it makes an extra forward pass.

:params module: The module to differentiate.
functional differentiation API. This requires using vmap, so it is not compatible with every
module, and it needs an extra forward pass to create the vjp function.
"""

def __init__(self, module: nn.Module):
super().__init__(module)
self.vmapped_vjp = torch.vmap(self._call_on_one_instance)

def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]:
return self.vmapped_vjp(grad_outputs, inputs)

def __call__(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]:
def _call_on_one_instance(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]:
# Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
# "batch" of 1 activation (or grad_output). This is because some layers (e.g.
# nn.Flatten) do not work equivalently if they're provided with a batch or with
@@ -78,32 +78,20 @@ def __call__(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor
inputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), inputs_j)
grad_outputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), grad_outputs_j)

# _vjp_from_module returns a function that computes the vjp w.r.t. to the
# primals (tuple), here the functional has a single primal which is
# dict(module.named_parameters()). We therefore take the 0'th element to obtain
# the dict of gradients w.r.t. the module's named_parameters.
return self._vjp_from_module(inputs_j)(grad_outputs_j)[0]

def _vjp_from_module(self, inputs: PyTree) -> Callable[[PyTree], tuple[dict[str, Tensor]]]:
"""
Create a VJP function for a module's forward pass with respect to its parameters.

Returns a function that computes vector-Jacobian products for the module's parameters given
fixed inputs. Only parameters with requires_grad=True are included in the differentiation.

:param inputs: Fixed inputs to the module for the VJP computation.
:returns: VJP function that takes cotangents and returns parameter gradients.
"""

def functional_model_call(primals: dict[str, Parameter]) -> Tensor:
all_state = {
**primals,
**dict(self.module.named_buffers()),
**self.frozen_params,
}
return torch.func.functional_call(self.module, all_state, inputs)
return torch.func.functional_call(self.module, all_state, inputs_j)

vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1]

return torch.func.vjp(functional_model_call, self.trainable_params)[1]
# vjp_func is a function that computes the vjp w.r.t. the primals (tuple). Here the
# functional has a single primal which is dict(module.named_parameters()). We therefore take
# the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters.
return vjp_func(grad_outputs_j)[0]


class AutogradVJP(ModuleVJP):
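For context on the new FunctionalVJP: per its updated docstring, the VJP of the module's forward pass with respect to its parameters is built with the functional differentiation API (torch.func.functional_call plus torch.func.vjp), defined on a single instance and then vmapped over the batch dimension. The following minimal sketch shows that pattern on a toy nn.Linear; it reflects my reading of the diff above, and the module, variable names, and shapes are invented for illustration rather than taken from torchjd.

import torch
from torch import Tensor, nn

module = nn.Linear(3, 2)
params = dict(module.named_parameters())


def vjp_on_one_instance(grad_output_j: Tensor, input_j: Tensor) -> dict[str, Tensor]:
    # unsqueeze(0): treat the single instance as a batch of size 1, since some layers
    # (e.g. nn.Flatten) behave differently when given an unbatched input, as the comment
    # in the diff notes.
    input_j = input_j.unsqueeze(0)
    grad_output_j = grad_output_j.unsqueeze(0)

    def functional_model_call(primals: dict[str, Tensor]) -> Tensor:
        # Re-run the forward pass with the parameters supplied as explicit primals.
        return torch.func.functional_call(module, primals, input_j)

    # torch.func.vjp returns (outputs, vjp_fn); vjp_fn maps cotangents to a tuple with
    # one entry per primal. The single primal here is the dict of named parameters,
    # hence the [0].
    vjp_fn = torch.func.vjp(functional_model_call, params)[1]
    return vjp_fn(grad_output_j)[0]


inputs = torch.randn(4, 3)      # a batch of 4 instances
cotangents = torch.randn(4, 2)  # one cotangent per instance output
per_sample_grads = torch.vmap(vjp_on_one_instance)(cotangents, inputs)
print({name: g.shape for name, g in per_sample_grads.items()})
# {'weight': torch.Size([4, 2, 3]), 'bias': torch.Size([4, 2])}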