diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 602728fe..08088b1c 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -19,7 +19,7 @@ class VJP(ABC): """Represents an abstract VJP function.""" @abstractmethod - def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]: + def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]: """ Computes and returns the dictionary of parameter names to their gradients for the given grad_outputs (cotangents) and at the given inputs. @@ -56,25 +56,25 @@ def __init__(self, module: nn.Module): super().__init__(module) self.vmapped_vjp = torch.vmap(self._call_on_one_instance) - def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]: - return self.vmapped_vjp(grad_outputs, inputs) + def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]: + return self.vmapped_vjp(grad_outputs, args) - def _call_on_one_instance(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]: + def _call_on_one_instance(self, grad_outputs_j: PyTree, args_j: PyTree) -> dict[str, Tensor]: # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a # "batch" of 1 activation (or grad_output). This is because some layers (e.g. # nn.Flatten) do not work equivalently if they're provided with a batch or with # an element of a batch. We thus always provide them with batches, just of a # different size. - inputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), inputs_j) + args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j) grad_outputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), grad_outputs_j) - def functional_model_call(primals: dict[str, Parameter]) -> Tensor: + def functional_model_call(trainable_params: dict[str, Parameter]) -> Tensor: all_state = { - **primals, + **trainable_params, **dict(self.module.named_buffers()), **self.frozen_params, } - return torch.func.functional_call(self.module, all_state, inputs_j) + return torch.func.functional_call(self.module, all_state, args_j) vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1]