Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/torchjd/autogram/_module_hook_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ._edge_registry import EdgeRegistry
from ._gramian_accumulator import GramianAccumulator
from ._vjp import AutogradVJP, FunctionalVJP, VJPType
from ._vjp import VJP, AutogradVJP, FunctionalVJP, Vmapped

# Note about import from protected _pytree module:
# PyTorch maintainers plan to make pytree public (see
Expand Down Expand Up @@ -93,7 +93,7 @@ class AccumulateJacobian(torch.autograd.Function):
@staticmethod
def forward(
output_spec: TreeSpec,
vjp: VJPType,
vjp: VJP,
args: PyTree,
gramian_accumulator: GramianAccumulator,
module: nn.Module,
Expand All @@ -110,7 +110,7 @@ def vmap(
_,
in_dims: PyTree,
output_spec: TreeSpec,
vjp: VJPType,
vjp: VJP,
args: PyTree,
gramian_accumulator: GramianAccumulator,
module: nn.Module,
Expand Down Expand Up @@ -157,7 +157,7 @@ class JacobianAccumulator(torch.autograd.Function):
def forward(
gramian_accumulation_phase: BoolRef,
output_spec: TreeSpec,
vjp: VJPType,
vjp: VJP,
args: PyTree,
gramian_accumulator: GramianAccumulator,
module: nn.Module,
Expand All @@ -166,7 +166,7 @@ def forward(
return tuple(x.detach() for x in xs)

# For Python version > 3.10, the type of `inputs` should become
# tuple[BoolRef, TreeSpec, VJPType, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
# tuple[BoolRef, TreeSpec, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
@staticmethod
def setup_context(
ctx,
Expand Down Expand Up @@ -232,8 +232,9 @@ def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
index = cast(int, preference.argmin().item())
self.target_edges.register(get_gradient_edge(flat_outputs[index]))

vjp: VJP
if self.has_batch_dim:
vjp = torch.vmap(FunctionalVJP(module))
vjp = Vmapped(FunctionalVJP(module))
else:
vjp = AutogradVJP(module, flat_outputs)

Expand Down
31 changes: 21 additions & 10 deletions src/torchjd/autogram/_vjp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,18 @@
# still support older versions of PyTorch where pytree is protected).


# This includes vmapped VJPs, which are not of type VJP.
VJPType = Callable[[PyTree, PyTree], dict[str, Tensor]]
class VJP(ABC):
    """
    Abstract interface for a vector-Jacobian product (VJP) function.

    Concrete subclasses implement ``__call__`` to map cotangents and inputs to per-parameter
    gradients.
    """

    @abstractmethod
    def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]:
        """
        Evaluates the VJP.

        :param grad_outputs: The cotangents, i.e. the gradients with respect to the outputs.
        :param inputs: The inputs at which the VJP is evaluated.
        :returns: A dictionary mapping each parameter name to the gradient of that parameter.
        """

class VJP(ABC):

class ModuleVJP(VJP, ABC):
"""
Represents an abstract VJP function for a module's forward pass with respect to its parameters.

Expand All @@ -37,15 +44,19 @@ def __init__(self, module: nn.Module):
else:
self.frozen_params[name] = param

@abstractmethod

class Vmapped(VJP):
    """
    VJP decorator that runs another VJP through ``torch.vmap``, i.e. vectorized over the first
    dimension of its arguments.
    """

    def __init__(self, vjp: VJP):
        super().__init__()
        # The vmapped callable is built once here, so that __call__ only has to forward to it.
        self.vmapped_vjp = torch.vmap(vjp)

    def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]:
        """
        Applies the wrapped VJP, vmapped, to the given cotangents and inputs.

        :param grad_outputs: The cotangents (gradients with respect to the outputs).
        :param inputs: The inputs at which the wrapped VJP is evaluated.
        :returns: The dictionary of parameter names to their gradients, as produced by the
            vmapped VJP.
        """
        gradients = self.vmapped_vjp(grad_outputs, inputs)
        return gradients


class FunctionalVJP(VJP):
class FunctionalVJP(ModuleVJP):
"""
Represents a VJP function for a module's forward pass with respect to its parameters using the
func api. The __call__ function takes both the inputs and the cotangents that can be vmapped
Expand Down Expand Up @@ -95,7 +106,7 @@ def functional_model_call(primals: dict[str, Parameter]) -> Tensor:
return torch.func.vjp(functional_model_call, self.trainable_params)[1]


class AutogradVJP(VJP):
class AutogradVJP(ModuleVJP):
"""
Represents a VJP function for a module's forward pass with respect to its parameters using the
autograd engine. The __call__ function takes both the inputs and the cotangents but ignores the
Expand Down
Loading