diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index f5e1028d..09bd7ca8 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -4,7 +4,7 @@ import torch from torch import Tensor, nn from torch.autograd.graph import get_gradient_edge -from torch.utils._pytree import PyTree, TreeSpec, tree_flatten, tree_map, tree_unflatten +from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_unflatten from torch.utils.hooks import RemovableHandle as TorchRemovableHandle from ._edge_registry import EdgeRegistry @@ -127,7 +127,6 @@ def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree: autograd_fn_outputs = JacobianAccumulator.apply( self.gramian_accumulation_phase, - output_spec, vjp, args, self.gramian_accumulator, @@ -152,7 +151,6 @@ class JacobianAccumulator(torch.autograd.Function): @staticmethod def forward( gramian_accumulation_phase: BoolRef, - output_spec: TreeSpec, vjp: VJP, args: PyTree, gramian_accumulator: GramianAccumulator, @@ -162,7 +160,7 @@ def forward( return tuple(x.detach() for x in xs) # For Python version > 3.10, the type of `inputs` should become - # tuple[BoolRef, TreeSpec, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]] + # tuple[BoolRef, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]] @staticmethod def setup_context( ctx, @@ -170,42 +168,38 @@ def setup_context( _, ): ctx.gramian_accumulation_phase = inputs[0] - ctx.output_spec = inputs[1] - ctx.vjp = inputs[2] - ctx.args = inputs[3] - ctx.gramian_accumulator = inputs[4] - ctx.module = inputs[5] + ctx.vjp = inputs[1] + ctx.args = inputs[2] + ctx.gramian_accumulator = inputs[3] + ctx.module = inputs[4] @staticmethod - def backward(ctx, *flat_grad_outputs: Tensor): + def backward(ctx, *grad_outputs: Tensor): if not ctx.gramian_accumulation_phase: - return None, None, None, None, None, None, *flat_grad_outputs + return None, None, None, None, None, *grad_outputs AccumulateJacobian.apply( - ctx.output_spec, ctx.vjp, ctx.args, ctx.gramian_accumulator, ctx.module, - *flat_grad_outputs, + *grad_outputs, ) - return None, None, None, None, None, None, *flat_grad_outputs + return None, None, None, None, None, *grad_outputs class AccumulateJacobian(torch.autograd.Function): @staticmethod def forward( - output_spec: TreeSpec, vjp: VJP, args: PyTree, gramian_accumulator: GramianAccumulator, module: nn.Module, - *flat_grad_outputs: Tensor, + *grad_outputs: Tensor, ) -> None: # There is no non-batched dimension - grad_outputs = tree_unflatten(flat_grad_outputs, output_spec) generalized_jacobians = vjp(grad_outputs, args) path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians) gramian_accumulator.accumulate_path_jacobians(path_jacobians) @@ -214,17 +208,15 @@ def forward( def vmap( _, in_dims: PyTree, - output_spec: TreeSpec, vjp: VJP, args: PyTree, gramian_accumulator: GramianAccumulator, module: nn.Module, - *flat_jac_outputs: Tensor, + *jac_outputs: Tensor, ) -> tuple[None, None]: # There is a non-batched dimension - jac_outputs = tree_unflatten(flat_jac_outputs, output_spec) # We do not vmap over the args for the non-batched dimension - in_dims = (tree_unflatten(in_dims[5:], output_spec), tree_map(lambda _: None, args)) + in_dims = (in_dims[4:], tree_map(lambda _: None, args)) generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians) gramian_accumulator.accumulate_path_jacobians(path_jacobians) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 08088b1c..1dad99c7 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -19,7 +19,7 @@ class VJP(ABC): """Represents an abstract VJP function.""" @abstractmethod - def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]: + def __call__(self, grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]: """ Computes and returns the dictionary of parameter names to their gradients for the given grad_outputs (cotangents) and at the given inputs. @@ -56,32 +56,35 @@ def __init__(self, module: nn.Module): super().__init__(module) self.vmapped_vjp = torch.vmap(self._call_on_one_instance) - def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]: + def __call__(self, grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]: return self.vmapped_vjp(grad_outputs, args) - def _call_on_one_instance(self, grad_outputs_j: PyTree, args_j: PyTree) -> dict[str, Tensor]: + def _call_on_one_instance( + self, grad_outputs_j: tuple[Tensor, ...], args_j: PyTree + ) -> dict[str, Tensor]: # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a # "batch" of 1 activation (or grad_output). This is because some layers (e.g. # nn.Flatten) do not work equivalently if they're provided with a batch or with # an element of a batch. We thus always provide them with batches, just of a # different size. args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j) - grad_outputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), grad_outputs_j) + grad_outputs_j_ = [x.unsqueeze(0) for x in grad_outputs_j] - def functional_model_call(trainable_params: dict[str, Parameter]) -> Tensor: + def flat_functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]: all_state = { **trainable_params, **dict(self.module.named_buffers()), **self.frozen_params, } - return torch.func.functional_call(self.module, all_state, args_j) + output = torch.func.functional_call(self.module, all_state, args_j) + return tree_flatten(output)[0] - vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1] + vjp_func = torch.func.vjp(flat_functional_model_call, self.trainable_params)[1] # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the # functional has a single primal which is dict(module.named_parameters()). We therefore take # the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters. - return vjp_func(grad_outputs_j)[0] + return vjp_func(grad_outputs_j_)[0] class AutogradVJP(ModuleVJP): @@ -105,11 +108,10 @@ def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params) - def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: - flat_grad_outputs = tree_flatten(grad_outputs)[0] + def __call__(self, grad_outputs: tuple[Tensor, ...], _: PyTree) -> dict[str, Tensor]: # Only keep the grad_outputs corresponding to outputs that require grad. - grad_outputs_ = [grad_output for grad_output, rg in zip(flat_grad_outputs, self.mask) if rg] + grad_outputs_ = [grad_output for grad_output, rg in zip(grad_outputs, self.mask) if rg] grads = torch.autograd.grad( self.outputs_that_require_grad,