From ee742de2a8d9b5efe566fb76d7f0601ac598e150 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Val=C3=A9rian=20Rey?=
Date: Mon, 29 Sep 2025 17:55:18 +0200
Subject: [PATCH 1/2] Merge Vmapped into FunctionalVJP

---
 src/torchjd/autogram/_module_hook_manager.py |  8 +--
 src/torchjd/autogram/_vjp.py                 | 52 ++++++--------
 2 files changed, 17 insertions(+), 43 deletions(-)

diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py
index 6e0f138b..803027bd 100644
--- a/src/torchjd/autogram/_module_hook_manager.py
+++ b/src/torchjd/autogram/_module_hook_manager.py
@@ -9,7 +9,7 @@

 from ._edge_registry import EdgeRegistry
 from ._gramian_accumulator import GramianAccumulator
-from ._vjp import VJP, AutogradVJP, FunctionalVJP, Vmapped
+from ._vjp import VJP, AutogradVJP, FunctionalVJP

 # Note about import from protected _pytree module:
 # PyTorch maintainers plan to make pytree public (see
@@ -232,11 +232,7 @@ def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
             index = cast(int, preference.argmin().item())
             self.target_edges.register(get_gradient_edge(flat_outputs[index]))

-        vjp: VJP
-        if self.has_batch_dim:
-            vjp = Vmapped(FunctionalVJP(module))
-        else:
-            vjp = AutogradVJP(module, flat_outputs)
+        vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs)

         autograd_fn_outputs = JacobianAccumulator.apply(
             self.gramian_accumulation_phase,
diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py
index 6520e969..3745b102 100644
--- a/src/torchjd/autogram/_vjp.py
+++ b/src/torchjd/autogram/_vjp.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from collections.abc import Callable, Sequence
+from collections.abc import Sequence

 import torch
 from torch import Tensor, nn
@@ -45,31 +45,21 @@ def __init__(self, module: nn.Module):
                 self.frozen_params[name] = param


-class Vmapped(VJP):
-    """VJP wrapper that applies the wrapped VJP, vmapped on the first dimension."""
-
-    def __init__(self, vjp: VJP):
-        super().__init__()
-        self.vmapped_vjp = torch.vmap(vjp)
-
-    def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]:
-        return self.vmapped_vjp(grad_outputs, inputs)
-
-
 class FunctionalVJP(ModuleVJP):
     """
     Represents a VJP function for a module's forward pass with respect to its parameters using the
-    func api. The __call__ function takes both the inputs and the cotangents that can be vmapped
-    jointly in both terms to avoid providing to block diagonal jacobians. The disadvantage of using
-    this method is that it makes an extra forward pass.
-
-    :params module: The module to differentiate.
+    functional differentiation API. This requires using vmap, so it's not compatible with
+    everything, and it requires an extra forward pass to create the vjp function.
     """

     def __init__(self, module: nn.Module):
         super().__init__(module)
+        self.vmapped_vjp = torch.vmap(self._call_on_one_instance)
+
+    def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]:
+        return self.vmapped_vjp(grad_outputs, inputs)

-    def __call__(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]:
+    def _call_on_one_instance(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]:
         # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
         # "batch" of 1 activation (or grad_output). This is because some layers (e.g.
         # nn.Flatten) do not work equivalently if they're provided with a batch or with
@@ -78,32 +78,20 @@ def __call__(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor
         inputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), inputs_j)
         grad_outputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), grad_outputs_j)

-        # _vjp_from_module returns a function that computes the vjp w.r.t. to the
-        # primals (tuple), here the functional has a single primal which is
-        # dict(module.named_parameters()). We therefore take the 0'th element to obtain
-        # the dict of gradients w.r.t. the module's named_parameters.
-        return self._vjp_from_module(inputs_j)(grad_outputs_j)[0]
-
-    def _vjp_from_module(self, inputs: PyTree) -> Callable[[PyTree], tuple[dict[str, Tensor]]]:
-        """
-        Create a VJP function for a module's forward pass with respect to its parameters.
-
-        Returns a function that computes vector-Jacobian products for the module's parameters given
-        fixed inputs. Only parameters with requires_grad=True are included in the differentiation.
-
-        :param inputs: Fixed inputs to the module for the VJP computation.
-        :returns: VJP function that takes cotangents and returns parameter gradients.
-        """
-
         def functional_model_call(primals: dict[str, Parameter]) -> Tensor:
             all_state = {
                 **primals,
                 **dict(self.module.named_buffers()),
                 **self.frozen_params,
             }
-            return torch.func.functional_call(self.module, all_state, inputs)
+            return torch.func.functional_call(self.module, all_state, inputs_j)
+
+        vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1]

-        return torch.func.vjp(functional_model_call, self.trainable_params)[1]
+        # vjp_func is a function that computes the vjp w.r.t. the primals (tuple). Here the
+        # functional has a single primal which is dict(module.named_parameters()). We therefore take
+        # the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters.
+        return vjp_func(grad_outputs_j)[0]


 class AutogradVJP(ModuleVJP):

From b11bfec96c19cad1be4a1cea598c6b541c994e09 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Val=C3=A9rian=20Rey?=
Date: Mon, 29 Sep 2025 18:03:03 +0200
Subject: [PATCH 2/2] Fix typo

---
 src/torchjd/autogram/_vjp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py
index 3745b102..602728fe 100644
--- a/src/torchjd/autogram/_vjp.py
+++ b/src/torchjd/autogram/_vjp.py
@@ -49,7 +49,7 @@ class FunctionalVJP(ModuleVJP):
     """
     Represents a VJP function for a module's forward pass with respect to its parameters using the
     functional differentiation API. This requires using vmap, so it's not compatible with
-    everything, and it requires an extra forward pass to create the vjp function.
+    every module, and it requires an extra forward pass to create the vjp function.
     """

     def __init__(self, module: nn.Module):
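
Note (not part of the patch): below is a minimal, self-contained sketch of the pattern that the merged FunctionalVJP relies on, i.e. a per-sample VJP built from torch.func.functional_call and torch.func.vjp, vmapped over the batch dimension. The toy nn.Linear module and the names per_sample_vjp, forward_with and batched_vjp are illustrative only and are not part of the torchjd API.

import torch
from torch import Tensor, nn

module = nn.Linear(3, 2)
params = dict(module.named_parameters())  # the primals we differentiate with respect to

def per_sample_vjp(grad_output_j: Tensor, input_j: Tensor) -> dict[str, Tensor]:
    # Re-add a batch dimension of size 1 so that layers expecting batched inputs behave normally.
    input_j = input_j.unsqueeze(0)
    grad_output_j = grad_output_j.unsqueeze(0)

    def forward_with(primals: dict[str, Tensor]) -> Tensor:
        # Forward pass with `primals` substituted for the module's parameters.
        return torch.func.functional_call(module, primals, (input_j,))

    # vjp_fn maps a cotangent on the output to gradients w.r.t. the primals (a 1-tuple here).
    _, vjp_fn = torch.func.vjp(forward_with, params)
    return vjp_fn(grad_output_j)[0]

# vmapping over the batch dimension yields per-sample parameter gradients in one call.
batched_vjp = torch.vmap(per_sample_vjp)

inputs = torch.randn(4, 3)
cotangents = torch.randn(4, 2)
grads = batched_vjp(cotangents, inputs)
print({name: g.shape for name, g in grads.items()})
# Expected shapes: {'weight': torch.Size([4, 2, 3]), 'bias': torch.Size([4, 2])}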