diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index f5e1028d..d4597f4c 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -101,7 +101,7 @@ def __init__( self.gramian_accumulator = gramian_accumulator self.has_batch_dim = has_batch_dim - def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree: + def __call__(self, module: nn.Module, args: tuple[PyTree, ...], output: PyTree) -> PyTree: if self.gramian_accumulation_phase: return output @@ -154,7 +154,7 @@ def forward( gramian_accumulation_phase: BoolRef, output_spec: TreeSpec, vjp: VJP, - args: PyTree, + args: tuple[PyTree, ...], gramian_accumulator: GramianAccumulator, module: nn.Module, *xs: Tensor, @@ -199,7 +199,7 @@ class AccumulateJacobian(torch.autograd.Function): def forward( output_spec: TreeSpec, vjp: VJP, - args: PyTree, + args: tuple[PyTree, ...], gramian_accumulator: GramianAccumulator, module: nn.Module, *flat_grad_outputs: Tensor, @@ -213,10 +213,10 @@ def forward( @staticmethod def vmap( _, - in_dims: PyTree, + in_dims: tuple[PyTree, ...], output_spec: TreeSpec, vjp: VJP, - args: PyTree, + args: tuple[PyTree, ...], gramian_accumulator: GramianAccumulator, module: nn.Module, *flat_jac_outputs: Tensor, diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 08088b1c..d6a5ef5d 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -19,7 +19,7 @@ class VJP(ABC): """Represents an abstract VJP function.""" @abstractmethod - def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]: + def __call__(self, grad_outputs: PyTree, args: tuple[PyTree, ...]) -> dict[str, Tensor]: """ Computes and returns the dictionary of parameter names to their gradients for the given grad_outputs (cotangents) and at the given inputs. @@ -56,7 +56,7 @@ def __init__(self, module: nn.Module): super().__init__(module) self.vmapped_vjp = torch.vmap(self._call_on_one_instance) - def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]: + def __call__(self, grad_outputs: PyTree, args: tuple[PyTree, ...]) -> dict[str, Tensor]: return self.vmapped_vjp(grad_outputs, args) def _call_on_one_instance(self, grad_outputs_j: PyTree, args_j: PyTree) -> dict[str, Tensor]: