From 3a60aac22eb9751bf015eea171b7451bbc76561d Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 30 Sep 2025 15:02:39 +0200 Subject: [PATCH 1/6] Make `VJP` take flat `grad_outputs`. This allows removing the parameter `output_spec` from both `autograd.Function` in `ModuleHookManager`. --- src/torchjd/autogram/_module_hook_manager.py | 34 +++++++++----------- src/torchjd/autogram/_vjp.py | 25 +++++++++----- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index f5e1028d..f7774020 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -4,7 +4,7 @@ import torch from torch import Tensor, nn from torch.autograd.graph import get_gradient_edge -from torch.utils._pytree import PyTree, TreeSpec, tree_flatten, tree_map, tree_unflatten +from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_unflatten from torch.utils.hooks import RemovableHandle as TorchRemovableHandle from ._edge_registry import EdgeRegistry @@ -123,11 +123,14 @@ def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree: index = cast(int, preference.argmin().item()) self.target_edges.register(get_gradient_edge(flat_outputs[index])) - vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs) + vjp = ( + FunctionalVJP(module, output_spec) + if self.has_batch_dim + else AutogradVJP(module, flat_outputs) + ) autograd_fn_outputs = JacobianAccumulator.apply( self.gramian_accumulation_phase, - output_spec, vjp, args, self.gramian_accumulator, @@ -152,7 +155,6 @@ class JacobianAccumulator(torch.autograd.Function): @staticmethod def forward( gramian_accumulation_phase: BoolRef, - output_spec: TreeSpec, vjp: VJP, args: PyTree, gramian_accumulator: GramianAccumulator, @@ -170,19 +172,17 @@ def setup_context( _, ): ctx.gramian_accumulation_phase = inputs[0] - ctx.output_spec = inputs[1] - ctx.vjp = inputs[2] - ctx.args = inputs[3] - ctx.gramian_accumulator = inputs[4] - ctx.module = inputs[5] + ctx.vjp = inputs[1] + ctx.args = inputs[2] + ctx.gramian_accumulator = inputs[3] + ctx.module = inputs[4] @staticmethod def backward(ctx, *flat_grad_outputs: Tensor): if not ctx.gramian_accumulation_phase: - return None, None, None, None, None, None, *flat_grad_outputs + return None, None, None, None, None, *flat_grad_outputs AccumulateJacobian.apply( - ctx.output_spec, ctx.vjp, ctx.args, ctx.gramian_accumulator, @@ -190,14 +190,13 @@ def backward(ctx, *flat_grad_outputs: Tensor): *flat_grad_outputs, ) - return None, None, None, None, None, None, *flat_grad_outputs + return None, None, None, None, None, *flat_grad_outputs class AccumulateJacobian(torch.autograd.Function): @staticmethod def forward( - output_spec: TreeSpec, vjp: VJP, args: PyTree, gramian_accumulator: GramianAccumulator, @@ -205,8 +204,7 @@ def forward( *flat_grad_outputs: Tensor, ) -> None: # There is no non-batched dimension - grad_outputs = tree_unflatten(flat_grad_outputs, output_spec) - generalized_jacobians = vjp(grad_outputs, args) + generalized_jacobians = vjp(flat_grad_outputs, args) path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians) gramian_accumulator.accumulate_path_jacobians(path_jacobians) @@ -214,7 +212,6 @@ def forward( def vmap( _, in_dims: PyTree, - output_spec: TreeSpec, vjp: VJP, args: PyTree, gramian_accumulator: GramianAccumulator, @@ -222,10 +219,9 @@ def vmap( *flat_jac_outputs: Tensor, ) -> tuple[None, None]: # There is a non-batched dimension - jac_outputs = tree_unflatten(flat_jac_outputs, output_spec) # We do not vmap over the args for the non-batched dimension - in_dims = (tree_unflatten(in_dims[5:], output_spec), tree_map(lambda _: None, args)) - generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) + in_dims = (in_dims[4:], tree_map(lambda _: None, args)) + generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(flat_jac_outputs, args) path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians) gramian_accumulator.accumulate_path_jacobians(path_jacobians) return None, None diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 08088b1c..c9737a83 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -4,7 +4,7 @@ import torch from torch import Tensor, nn from torch.nn import Parameter -from torch.utils._pytree import PyTree, tree_flatten, tree_map_only, tree_unflatten +from torch.utils._pytree import PyTree, TreeSpec, tree_flatten, tree_map_only, tree_unflatten # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see @@ -19,7 +19,9 @@ class VJP(ABC): """Represents an abstract VJP function.""" @abstractmethod - def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]: + def __call__( + self, flat_grad_outputs: tuple[Tensor | None, ...], args: PyTree + ) -> dict[str, Tensor]: """ Computes and returns the dictionary of parameter names to their gradients for the given grad_outputs (cotangents) and at the given inputs. @@ -52,20 +54,26 @@ class FunctionalVJP(ModuleVJP): every module, and it requires to have an extra forward pass to create the vjp function. """ - def __init__(self, module: nn.Module): + def __init__(self, module: nn.Module, output_spec: TreeSpec): super().__init__(module) + self.output_spec = output_spec self.vmapped_vjp = torch.vmap(self._call_on_one_instance) - def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]: - return self.vmapped_vjp(grad_outputs, args) + def __call__( + self, flat_grad_outputs: tuple[Tensor | None, ...], args: PyTree + ) -> dict[str, Tensor]: + return self.vmapped_vjp(flat_grad_outputs, args) - def _call_on_one_instance(self, grad_outputs_j: PyTree, args_j: PyTree) -> dict[str, Tensor]: + def _call_on_one_instance( + self, flat_grad_outputs_j: tuple[Tensor | None, ...], args_j: PyTree + ) -> dict[str, Tensor]: # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a # "batch" of 1 activation (or grad_output). This is because some layers (e.g. # nn.Flatten) do not work equivalently if they're provided with a batch or with # an element of a batch. We thus always provide them with batches, just of a # different size. args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j) + grad_outputs_j = tree_unflatten(flat_grad_outputs_j, self.output_spec) grad_outputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), grad_outputs_j) def functional_model_call(trainable_params: dict[str, Parameter]) -> Tensor: @@ -105,8 +113,9 @@ def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params) - def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: - flat_grad_outputs = tree_flatten(grad_outputs)[0] + def __call__( + self, flat_grad_outputs: tuple[Tensor | None, ...], _: PyTree + ) -> dict[str, Tensor]: # Only keep the grad_outputs corresponding to outputs that require grad. grad_outputs_ = [grad_output for grad_output, rg in zip(flat_grad_outputs, self.mask) if rg] From e1ccb7119a4a5f5e0c6b4ee25f1c82938a5109a0 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 1 Oct 2025 09:34:16 +0200 Subject: [PATCH 2/6] Fix type hint for future versions. --- src/torchjd/autogram/_module_hook_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index f7774020..62527475 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -164,7 +164,7 @@ def forward( return tuple(x.detach() for x in xs) # For Python version > 3.10, the type of `inputs` should become - # tuple[BoolRef, TreeSpec, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]] + # tuple[BoolRef, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]] @staticmethod def setup_context( ctx, From 959837007822ab80fcc8721cc0992f76d6b0a03c Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 1 Oct 2025 09:42:52 +0200 Subject: [PATCH 3/6] Flatten functional call of model so that output_spec is removed from FunctionalVJP. `output_spec` now only appears in the hook. --- src/torchjd/autogram/_module_hook_manager.py | 6 +----- src/torchjd/autogram/_vjp.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 62527475..7943ea25 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -123,11 +123,7 @@ def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree: index = cast(int, preference.argmin().item()) self.target_edges.register(get_gradient_edge(flat_outputs[index])) - vjp = ( - FunctionalVJP(module, output_spec) - if self.has_batch_dim - else AutogradVJP(module, flat_outputs) - ) + vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs) autograd_fn_outputs = JacobianAccumulator.apply( self.gramian_accumulation_phase, diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index c9737a83..b857a472 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -4,7 +4,7 @@ import torch from torch import Tensor, nn from torch.nn import Parameter -from torch.utils._pytree import PyTree, TreeSpec, tree_flatten, tree_map_only, tree_unflatten +from torch.utils._pytree import PyTree, tree_flatten, tree_map_only, tree_unflatten # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see @@ -54,9 +54,8 @@ class FunctionalVJP(ModuleVJP): every module, and it requires to have an extra forward pass to create the vjp function. """ - def __init__(self, module: nn.Module, output_spec: TreeSpec): + def __init__(self, module: nn.Module): super().__init__(module) - self.output_spec = output_spec self.vmapped_vjp = torch.vmap(self._call_on_one_instance) def __call__( @@ -73,23 +72,25 @@ def _call_on_one_instance( # an element of a batch. We thus always provide them with batches, just of a # different size. args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j) - grad_outputs_j = tree_unflatten(flat_grad_outputs_j, self.output_spec) - grad_outputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), grad_outputs_j) + flat_grad_outputs_j = tree_map_only( + torch.Tensor, lambda x: x.unsqueeze(0), flat_grad_outputs_j + ) - def functional_model_call(trainable_params: dict[str, Parameter]) -> Tensor: + def flat_functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]: all_state = { **trainable_params, **dict(self.module.named_buffers()), **self.frozen_params, } - return torch.func.functional_call(self.module, all_state, args_j) + output = torch.func.functional_call(self.module, all_state, args_j) + return tree_flatten(output)[0] - vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1] + vjp_func = torch.func.vjp(flat_functional_model_call, self.trainable_params)[1] # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the # functional has a single primal which is dict(module.named_parameters()). We therefore take # the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters. - return vjp_func(grad_outputs_j)[0] + return vjp_func(flat_grad_outputs_j)[0] class AutogradVJP(ModuleVJP): From a07532dea4deacc76e205639f3aa82734843c1de Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 1 Oct 2025 09:44:44 +0200 Subject: [PATCH 4/6] Change `tuple[Tensor | None, ...]` to `tuple[Tensor]`. This will be done in another PR. --- src/torchjd/autogram/_vjp.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index b857a472..42592bdc 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -19,9 +19,7 @@ class VJP(ABC): """Represents an abstract VJP function.""" @abstractmethod - def __call__( - self, flat_grad_outputs: tuple[Tensor | None, ...], args: PyTree - ) -> dict[str, Tensor]: + def __call__(self, flat_grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]: """ Computes and returns the dictionary of parameter names to their gradients for the given grad_outputs (cotangents) and at the given inputs. @@ -58,13 +56,11 @@ def __init__(self, module: nn.Module): super().__init__(module) self.vmapped_vjp = torch.vmap(self._call_on_one_instance) - def __call__( - self, flat_grad_outputs: tuple[Tensor | None, ...], args: PyTree - ) -> dict[str, Tensor]: + def __call__(self, flat_grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]: return self.vmapped_vjp(flat_grad_outputs, args) def _call_on_one_instance( - self, flat_grad_outputs_j: tuple[Tensor | None, ...], args_j: PyTree + self, flat_grad_outputs_j: tuple[Tensor, ...], args_j: PyTree ) -> dict[str, Tensor]: # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a # "batch" of 1 activation (or grad_output). This is because some layers (e.g. @@ -114,9 +110,7 @@ def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params) - def __call__( - self, flat_grad_outputs: tuple[Tensor | None, ...], _: PyTree - ) -> dict[str, Tensor]: + def __call__(self, flat_grad_outputs: tuple[Tensor, ...], _: PyTree) -> dict[str, Tensor]: # Only keep the grad_outputs corresponding to outputs that require grad. grad_outputs_ = [grad_output for grad_output, rg in zip(flat_grad_outputs, self.mask) if rg] From 89d5f8fc8166abc4ecd064639c5800dec7fceb3d Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 1 Oct 2025 09:56:14 +0200 Subject: [PATCH 5/6] Make `flat_grad_outputs_j` a list --- src/torchjd/autogram/_vjp.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 42592bdc..6d852607 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -68,9 +68,7 @@ def _call_on_one_instance( # an element of a batch. We thus always provide them with batches, just of a # different size. args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j) - flat_grad_outputs_j = tree_map_only( - torch.Tensor, lambda x: x.unsqueeze(0), flat_grad_outputs_j - ) + flat_grad_outputs_j_ = [x.unsqueeze(0) for x in flat_grad_outputs_j] def flat_functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]: all_state = { @@ -86,7 +84,7 @@ def flat_functional_model_call(trainable_params: dict[str, Parameter]) -> list[T # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the # functional has a single primal which is dict(module.named_parameters()). We therefore take # the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters. - return vjp_func(flat_grad_outputs_j)[0] + return vjp_func(flat_grad_outputs_j_)[0] class AutogradVJP(ModuleVJP): From 4ea34d9fd7476134725cadca72b085c6ada61379 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 12:12:58 +0200 Subject: [PATCH 6/6] Rename flat_grad_outputs to grad_outputs * They're always flat now --- src/torchjd/autogram/_module_hook_manager.py | 16 ++++++++-------- src/torchjd/autogram/_vjp.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 7943ea25..09bd7ca8 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -174,19 +174,19 @@ def setup_context( ctx.module = inputs[4] @staticmethod - def backward(ctx, *flat_grad_outputs: Tensor): + def backward(ctx, *grad_outputs: Tensor): if not ctx.gramian_accumulation_phase: - return None, None, None, None, None, *flat_grad_outputs + return None, None, None, None, None, *grad_outputs AccumulateJacobian.apply( ctx.vjp, ctx.args, ctx.gramian_accumulator, ctx.module, - *flat_grad_outputs, + *grad_outputs, ) - return None, None, None, None, None, *flat_grad_outputs + return None, None, None, None, None, *grad_outputs class AccumulateJacobian(torch.autograd.Function): @@ -197,10 +197,10 @@ def forward( args: PyTree, gramian_accumulator: GramianAccumulator, module: nn.Module, - *flat_grad_outputs: Tensor, + *grad_outputs: Tensor, ) -> None: # There is no non-batched dimension - generalized_jacobians = vjp(flat_grad_outputs, args) + generalized_jacobians = vjp(grad_outputs, args) path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians) gramian_accumulator.accumulate_path_jacobians(path_jacobians) @@ -212,12 +212,12 @@ def vmap( args: PyTree, gramian_accumulator: GramianAccumulator, module: nn.Module, - *flat_jac_outputs: Tensor, + *jac_outputs: Tensor, ) -> tuple[None, None]: # There is a non-batched dimension # We do not vmap over the args for the non-batched dimension in_dims = (in_dims[4:], tree_map(lambda _: None, args)) - generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(flat_jac_outputs, args) + generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians) gramian_accumulator.accumulate_path_jacobians(path_jacobians) return None, None diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 6d852607..1dad99c7 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -19,7 +19,7 @@ class VJP(ABC): """Represents an abstract VJP function.""" @abstractmethod - def __call__(self, flat_grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]: + def __call__(self, grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]: """ Computes and returns the dictionary of parameter names to their gradients for the given grad_outputs (cotangents) and at the given inputs. @@ -56,11 +56,11 @@ def __init__(self, module: nn.Module): super().__init__(module) self.vmapped_vjp = torch.vmap(self._call_on_one_instance) - def __call__(self, flat_grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]: - return self.vmapped_vjp(flat_grad_outputs, args) + def __call__(self, grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]: + return self.vmapped_vjp(grad_outputs, args) def _call_on_one_instance( - self, flat_grad_outputs_j: tuple[Tensor, ...], args_j: PyTree + self, grad_outputs_j: tuple[Tensor, ...], args_j: PyTree ) -> dict[str, Tensor]: # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a # "batch" of 1 activation (or grad_output). This is because some layers (e.g. @@ -68,7 +68,7 @@ def _call_on_one_instance( # an element of a batch. We thus always provide them with batches, just of a # different size. args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j) - flat_grad_outputs_j_ = [x.unsqueeze(0) for x in flat_grad_outputs_j] + grad_outputs_j_ = [x.unsqueeze(0) for x in grad_outputs_j] def flat_functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]: all_state = { @@ -84,7 +84,7 @@ def flat_functional_model_call(trainable_params: dict[str, Parameter]) -> list[T # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the # functional has a single primal which is dict(module.named_parameters()). We therefore take # the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters. - return vjp_func(flat_grad_outputs_j_)[0] + return vjp_func(grad_outputs_j_)[0] class AutogradVJP(ModuleVJP): @@ -108,10 +108,10 @@ def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params) - def __call__(self, flat_grad_outputs: tuple[Tensor, ...], _: PyTree) -> dict[str, Tensor]: + def __call__(self, grad_outputs: tuple[Tensor, ...], _: PyTree) -> dict[str, Tensor]: # Only keep the grad_outputs corresponding to outputs that require grad. - grad_outputs_ = [grad_output for grad_output, rg in zip(flat_grad_outputs, self.mask) if rg] + grad_outputs_ = [grad_output for grad_output, rg in zip(grad_outputs, self.mask) if rg] grads = torch.autograd.grad( self.outputs_that_require_grad,