diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 273131d2..6e0f138b 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -9,7 +9,7 @@ from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator -from ._vjp import AutogradVJP, FunctionalVJP, VJPType +from ._vjp import VJP, AutogradVJP, FunctionalVJP, Vmapped # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see @@ -93,7 +93,7 @@ class AccumulateJacobian(torch.autograd.Function): @staticmethod def forward( output_spec: TreeSpec, - vjp: VJPType, + vjp: VJP, args: PyTree, gramian_accumulator: GramianAccumulator, module: nn.Module, @@ -110,7 +110,7 @@ def vmap( _, in_dims: PyTree, output_spec: TreeSpec, - vjp: VJPType, + vjp: VJP, args: PyTree, gramian_accumulator: GramianAccumulator, module: nn.Module, @@ -157,7 +157,7 @@ class JacobianAccumulator(torch.autograd.Function): def forward( gramian_accumulation_phase: BoolRef, output_spec: TreeSpec, - vjp: VJPType, + vjp: VJP, args: PyTree, gramian_accumulator: GramianAccumulator, module: nn.Module, @@ -166,7 +166,7 @@ def forward( return tuple(x.detach() for x in xs) # For Python version > 3.10, the type of `inputs` should become - # tuple[BoolRef, TreeSpec, VJPType, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]] + # tuple[BoolRef, TreeSpec, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]] @staticmethod def setup_context( ctx, @@ -232,8 +232,9 @@ def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree: index = cast(int, preference.argmin().item()) self.target_edges.register(get_gradient_edge(flat_outputs[index])) + vjp: VJP if self.has_batch_dim: - vjp = torch.vmap(FunctionalVJP(module)) + vjp = Vmapped(FunctionalVJP(module)) else: vjp = AutogradVJP(module, flat_outputs) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index b05a7408..6520e969 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -15,11 +15,18 @@ # still support older versions of PyTorch where pytree is protected). -# This includes vmapped VJPs, which are not of type VJP. -VJPType = Callable[[PyTree, PyTree], dict[str, Tensor]] +class VJP(ABC): + """Represents an abstract VJP function.""" + @abstractmethod + def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]: + """ + Computes and returns the dictionary of parameter names to their gradients for the given + grad_outputs (cotangents) and at the given inputs. + """ -class VJP(ABC): + +class ModuleVJP(VJP, ABC): """ Represents an abstract VJP function for a module's forward pass with respect to its parameters. @@ -37,15 +44,19 @@ def __init__(self, module: nn.Module): else: self.frozen_params[name] = param - @abstractmethod + +class Vmapped(VJP): + """VJP wrapper that applies the wrapped VJP, vmapped on the first dimension.""" + + def __init__(self, vjp: VJP): + super().__init__() + self.vmapped_vjp = torch.vmap(vjp) + def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]: - """ - Computes and returns the dictionary of parameter names to their gradients for the given - grad_outputs (cotangents) and at the given inputs. - """ + return self.vmapped_vjp(grad_outputs, inputs) -class FunctionalVJP(VJP): +class FunctionalVJP(ModuleVJP): """ Represents a VJP function for a module's forward pass with respect to its parameters using the func api. The __call__ function takes both the inputs and the cotangents that can be vmapped @@ -95,7 +106,7 @@ def functional_model_call(primals: dict[str, Parameter]) -> Tensor: return torch.func.vjp(functional_model_call, self.trainable_params)[1] -class AutogradVJP(VJP): +class AutogradVJP(ModuleVJP): """ Represents a VJP function for a module's forward pass with respect to its parameters using the autograd engine. The __call__ function takes both the inputs and the cotangents but ignores the