From 932848528cb9ff3e3dc1c54c285ea9fa7aa3eb75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 29 Sep 2025 18:14:40 +0200 Subject: [PATCH 1/2] Rename primals to trainable_params in functional_model_call --- src/torchjd/autogram/_vjp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 602728fe..060c3040 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -68,9 +68,9 @@ def _call_on_one_instance(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dic inputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), inputs_j) grad_outputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), grad_outputs_j) - def functional_model_call(primals: dict[str, Parameter]) -> Tensor: + def functional_model_call(trainable_params: dict[str, Parameter]) -> Tensor: all_state = { - **primals, + **trainable_params, **dict(self.module.named_buffers()), **self.frozen_params, } From a39ad40e73eb0ad83afbe181c04dfbd5c002262c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 29 Sep 2025 18:16:03 +0200 Subject: [PATCH 2/2] Rename inputs to args --- src/torchjd/autogram/_vjp.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 060c3040..08088b1c 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -19,7 +19,7 @@ class VJP(ABC): """Represents an abstract VJP function.""" @abstractmethod - def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]: + def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]: """ Computes and returns the dictionary of parameter names to their gradients for the given grad_outputs (cotangents) and at the given inputs. @@ -56,16 +56,16 @@ def __init__(self, module: nn.Module): super().__init__(module) self.vmapped_vjp = torch.vmap(self._call_on_one_instance) - def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]: - return self.vmapped_vjp(grad_outputs, inputs) + def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]: + return self.vmapped_vjp(grad_outputs, args) - def _call_on_one_instance(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]: + def _call_on_one_instance(self, grad_outputs_j: PyTree, args_j: PyTree) -> dict[str, Tensor]: # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a # "batch" of 1 activation (or grad_output). This is because some layers (e.g. # nn.Flatten) do not work equivalently if they're provided with a batch or with # an element of a batch. We thus always provide them with batches, just of a # different size. - inputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), inputs_j) + args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j) grad_outputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), grad_outputs_j) def functional_model_call(trainable_params: dict[str, Parameter]) -> Tensor: @@ -74,7 +74,7 @@ def functional_model_call(trainable_params: dict[str, Parameter]) -> Tensor: **dict(self.module.named_buffers()), **self.frozen_params, } - return torch.func.functional_call(self.module, all_state, inputs_j) + return torch.func.functional_call(self.module, all_state, args_j) vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1]