From d6c0b4986ec523ea003e1c440c8bd103d9493715 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 1 Oct 2025 10:13:39 +0200 Subject: [PATCH 1/2] Type `args` as `tuple[PyTree, ...]` to highlight both the fact that it is a tuple and has recursive structure. --- src/torchjd/autogram/_module_hook_manager.py | 8 ++++---- src/torchjd/autogram/_vjp.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index f5e1028d..2c571b74 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -101,7 +101,7 @@ def __init__( self.gramian_accumulator = gramian_accumulator self.has_batch_dim = has_batch_dim - def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree: + def __call__(self, module: nn.Module, args: tuple[PyTree, ...], output: PyTree) -> PyTree: if self.gramian_accumulation_phase: return output @@ -154,7 +154,7 @@ def forward( gramian_accumulation_phase: BoolRef, output_spec: TreeSpec, vjp: VJP, - args: PyTree, + args: tuple[PyTree, ...], gramian_accumulator: GramianAccumulator, module: nn.Module, *xs: Tensor, @@ -199,7 +199,7 @@ class AccumulateJacobian(torch.autograd.Function): def forward( output_spec: TreeSpec, vjp: VJP, - args: PyTree, + args: tuple[PyTree, ...], gramian_accumulator: GramianAccumulator, module: nn.Module, *flat_grad_outputs: Tensor, @@ -216,7 +216,7 @@ def vmap( in_dims: PyTree, output_spec: TreeSpec, vjp: VJP, - args: PyTree, + args: tuple[PyTree, ...], gramian_accumulator: GramianAccumulator, module: nn.Module, *flat_jac_outputs: Tensor, diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 08088b1c..d6a5ef5d 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -19,7 +19,7 @@ class VJP(ABC): """Represents an abstract VJP function.""" @abstractmethod - def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]: + def __call__(self, grad_outputs: PyTree, args: tuple[PyTree, ...]) -> dict[str, Tensor]: """ Computes and returns the dictionary of parameter names to their gradients for the given grad_outputs (cotangents) and at the given inputs. @@ -56,7 +56,7 @@ def __init__(self, module: nn.Module): super().__init__(module) self.vmapped_vjp = torch.vmap(self._call_on_one_instance) - def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]: + def __call__(self, grad_outputs: PyTree, args: tuple[PyTree, ...]) -> dict[str, Tensor]: return self.vmapped_vjp(grad_outputs, args) def _call_on_one_instance(self, grad_outputs_j: PyTree, args_j: PyTree) -> dict[str, Tensor]: From a0c593e6334b35fda115a61bdd77ad0a3e6ea546 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 1 Oct 2025 10:20:17 +0200 Subject: [PATCH 2/2] Type `in_dims` as `tuple[PyTree, ...]` to highlight both the fact that it is a tuple and has recursive structure. --- src/torchjd/autogram/_module_hook_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 2c571b74..d4597f4c 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -213,7 +213,7 @@ def forward( @staticmethod def vmap( _, - in_dims: PyTree, + in_dims: tuple[PyTree, ...], output_spec: TreeSpec, vjp: VJP, args: tuple[PyTree, ...],