From dc09a24843592d78eb46a3f851b0fbc1b547f857 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 16:01:44 +0200 Subject: [PATCH 1/9] Fix type hints --- src/torchjd/autogram/_module_hook_manager.py | 18 ++++++++++-------- src/torchjd/autogram/_vjp.py | 10 +++++++--- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 3cc4cf6b..df137161 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -101,7 +101,7 @@ def __init__( self.gramian_accumulator = gramian_accumulator self.has_batch_dim = has_batch_dim - def __call__(self, module: nn.Module, args: PyTree, outputs: PyTree) -> PyTree: + def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) -> PyTree: if self.gramian_accumulation_phase: return outputs @@ -161,7 +161,7 @@ class JacobianAccumulator(torch.autograd.Function): def forward( gramian_accumulation_phase: BoolRef, vjp: VJP, - args: PyTree, + args: tuple[PyTree, ...], gramian_accumulator: GramianAccumulator, module: nn.Module, *rg_tensors: Tensor, @@ -169,7 +169,7 @@ def forward( return tuple(t.detach() for t in rg_tensors) # For Python version > 3.10, the type of `inputs` should become - # tuple[BoolRef, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]] + # tuple[BoolRef, VJP, tuple[PyTree, ...], GramianAccumulator, nn.Module, *tuple[Tensor, ...]] @staticmethod def setup_context( ctx, @@ -183,7 +183,9 @@ def setup_context( ctx.module = inputs[4] @staticmethod - def backward(ctx, *grad_outputs: Tensor): + def backward(ctx, *grad_outputs: Tensor) -> tuple: + # Return type for python > 3.10: # tuple[None, None, None, None, None, *tuple[Tensor, ...]] + if not ctx.gramian_accumulation_phase: return None, None, None, None, None, *grad_outputs @@ -203,7 +205,7 @@ class AccumulateJacobian(torch.autograd.Function): @staticmethod def forward( vjp: VJP, - args: PyTree, + args: tuple[PyTree, ...], gramian_accumulator: GramianAccumulator, module: nn.Module, *grad_outputs: Tensor, @@ -216,9 +218,9 @@ def forward( @staticmethod def vmap( _, - in_dims: PyTree, + in_dims: tuple, # tuple[None, tuple[PyTree, ...], None, None, *tuple[int | None, ...]] vjp: VJP, - args: PyTree, + args: tuple[PyTree, ...], gramian_accumulator: GramianAccumulator, module: nn.Module, *jac_outputs: Tensor, @@ -244,5 +246,5 @@ def _make_path_jacobians( return path_jacobians @staticmethod - def setup_context(*_): + def setup_context(*_) -> None: pass diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 4e6c4985..909531c5 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -19,7 +19,9 @@ class VJP(ABC): """Represents an abstract VJP function.""" @abstractmethod - def __call__(self, grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]: + def __call__( + self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...] + ) -> dict[str, Tensor]: """ Computes and returns the dictionary of parameter names to their gradients for the given grad_outputs (cotangents) and at the given inputs. @@ -56,11 +58,13 @@ def __init__(self, module: nn.Module): super().__init__(module) self.vmapped_vjp = torch.vmap(self._call_on_one_instance) - def __call__(self, grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]: + def __call__( + self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...] + ) -> dict[str, Tensor]: return self.vmapped_vjp(grad_outputs, args) def _call_on_one_instance( - self, grad_outputs_j: tuple[Tensor, ...], args_j: PyTree + self, grad_outputs_j: tuple[Tensor, ...], args_j: tuple[PyTree, ...] ) -> dict[str, Tensor]: # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a # "batch" of 1 activation (or grad_output). This is because some layers (e.g. From 16032c3cea6306dfc1b7b63eeabf746860e76f06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 16:06:23 +0200 Subject: [PATCH 2/9] Remove xfail on WithModuleWithStringArg --- tests/unit/autogram/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 74084f3a..6818d3c3 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -109,6 +109,7 @@ (Ndim3Output, 32), (Ndim4Output, 32), (WithDropout, 32), + (WithModuleWithStringArg, 32), (WithModuleWithStringOutput, 32), (FreeParam, 32), (NoFreeParam, 32), @@ -167,7 +168,6 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc Randomness, WithModuleTrackingRunningStats, param(WithRNN, marks=mark.xfail_if_cuda), - WithModuleWithStringArg, ], ) @mark.parametrize("batch_size", [1, 3, 32]) From 49e5d5f0e81c174b80b5833837417ce94ed892e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 19:17:00 +0200 Subject: [PATCH 3/9] Add in_dims param to FunctionalVJP and compute it in Hook --- src/torchjd/autogram/_module_hook_manager.py | 10 +++++++++- src/torchjd/autogram/_vjp.py | 4 ++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index df137161..6623743e 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -129,7 +129,15 @@ def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) index = cast(int, preference.argmin().item()) self.target_edges.register(get_gradient_edge(rg_outputs[index])) - vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, rg_outputs) + rg_output_in_dims = (0,) * len(rg_outputs) + arg_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) + in_dims = (rg_output_in_dims, arg_in_dims) + + vjp = ( + FunctionalVJP(module, in_dims) + if self.has_batch_dim + else AutogradVJP(module, rg_outputs) + ) autograd_fn_rg_outputs = JacobianAccumulator.apply( self.gramian_accumulation_phase, diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 909531c5..181900d6 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -54,9 +54,9 @@ class FunctionalVJP(ModuleVJP): every module, and it requires to have an extra forward pass to create the vjp function. """ - def __init__(self, module: nn.Module): + def __init__(self, module: nn.Module, in_dims: tuple[PyTree, ...]): super().__init__(module) - self.vmapped_vjp = torch.vmap(self._call_on_one_instance) + self.vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims) def __call__( self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...] From 889e87fce635df286938480147fd0878baf20600 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 19:41:41 +0200 Subject: [PATCH 4/9] Add WithModuleWithHybridPyTreeArg --- tests/unit/autogram/test_engine.py | 2 ++ tests/utils/architectures.py | 47 ++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 6818d3c3..fb37863c 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -51,6 +51,7 @@ WithBuffered, WithDropout, WithModuleTrackingRunningStats, + WithModuleWithHybridPyTreeArg, WithModuleWithStringArg, WithModuleWithStringOutput, WithNoTensorOutput, @@ -110,6 +111,7 @@ (Ndim4Output, 32), (WithDropout, 32), (WithModuleWithStringArg, 32), + (WithModuleWithHybridPyTreeArg, 32), (WithModuleWithStringOutput, 32), (FreeParam, 32), (NoFreeParam, 32), diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 073420d2..3eb4e8cf 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -797,6 +797,53 @@ def forward(self, input: Tensor) -> Tensor: return self.with_string_arg("two", input) +class WithModuleWithHybridPyTreeArg(ShapedModule): + """ + Model containing a module that has a PyTree argument containing a mix of tensor and non-tensor + leaves. + """ + + INPUT_SHAPES = (18,) + OUTPUT_SHAPES = (3,) + + class WithHybridPyTreeArg(nn.Module): + def __init__(self): + super().__init__() + self.m0 = nn.Parameter(torch.randn(3, 3)) + self.m1 = nn.Parameter(torch.randn(4, 3)) + self.m2 = nn.Parameter(torch.randn(5, 3)) + self.m3 = nn.Parameter(torch.randn(6, 3)) + + def forward(self, input: PyTree) -> Tensor: + t0 = input["one"][0][0] + t1 = input["one"][0][1] + t2 = input["one"][1] + t3 = input["two"] + + c0 = input["one"][0][3] + c1 = input["one"][0][4][0] + c2 = input["one"][2] + c3 = input["three"] + + return c0 * t0 @ self.m0 + c1 * t1 @ self.m1 + c2 * t2 @ self.m2 + c3 * t3 @ self.m3 + + def __init__(self): + super().__init__() + self.with_string_arg = self.WithHybridPyTreeArg() + + def forward(self, input: Tensor) -> Tensor: + t0, t1, t2, t3 = input[:, 0:3], input[:, 3:7], input[:, 7:12], input[:, 12:18] + + tree = { + "zero": "unused", + "one": [(t0, t1, "unused", 0.2, [0.3, "unused"]), t2, 0.4, "unused"], + "two": t3, + "three": 0.5, + } + + return self.with_string_arg(tree) + + class WithModuleWithStringOutput(ShapedModule): """Model containing a module that has a string output.""" From e201fb77c755c2100a1a418c91fe8bca91376000 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 19:51:39 +0200 Subject: [PATCH 5/9] Improve WithModuleWithHybridPyTreeArg * With this extra linear module, we now also check that the gradients wrt the args that require grad are correctly backpropagated. --- tests/utils/architectures.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 3eb4e8cf..ef4933f5 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -803,7 +803,7 @@ class WithModuleWithHybridPyTreeArg(ShapedModule): leaves. """ - INPUT_SHAPES = (18,) + INPUT_SHAPES = (10,) OUTPUT_SHAPES = (3,) class WithHybridPyTreeArg(nn.Module): @@ -829,9 +829,12 @@ def forward(self, input: PyTree) -> Tensor: def __init__(self): super().__init__() + self.linear = nn.Linear(10, 18) self.with_string_arg = self.WithHybridPyTreeArg() def forward(self, input: Tensor) -> Tensor: + input = self.linear(input) + t0, t1, t2, t3 = input[:, 0:3], input[:, 3:7], input[:, 7:12], input[:, 12:18] tree = { From 15ae846db72bb366a1c7a7ec68de40de65be8558 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 20:05:11 +0200 Subject: [PATCH 6/9] Relax warning about inputs and outputs of Engine --- src/torchjd/autogram/_engine.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index be7104f0..90cd2582 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -120,9 +120,8 @@ class Engine: this, but for example `BatchNorm `_ does not (it computes some average and standard deviation over the elements of the batch). - * Their inputs and outputs can be any PyTree (tensor, tuple or list of tensors, dict of - tensors, or any nesting of those structures), but each of these tensors must be batched on - its first dimension. `Transformers + * Their inputs and outputs can be anything, but each input tensor and each output tensor + must be batched on its first dimension. `Transformers `_ and `RNNs `_ are thus not supported yet. This is only an implementation issue, so it should be fixed soon (please From b876e7602e22caa256937ae43c1c9a805b7bc529 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 20:27:07 +0200 Subject: [PATCH 7/9] Fix type hint in AutogradVJP.__call__ --- src/torchjd/autogram/_vjp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 181900d6..442b7fe9 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -107,7 +107,9 @@ def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]): self.rg_outputs = rg_outputs self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params) - def __call__(self, grad_outputs: tuple[Tensor, ...], _: PyTree) -> dict[str, Tensor]: + def __call__( + self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...] + ) -> dict[str, Tensor]: grads = torch.autograd.grad( self.rg_outputs, self.flat_trainable_params, From 0d1e65d8e00e7df15a286de96362dd0c24ec6f50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 2 Oct 2025 19:52:49 +0200 Subject: [PATCH 8/9] Only compute in_dims when needed --- src/torchjd/autogram/_module_hook_manager.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 6623743e..d67d4da1 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -129,15 +129,14 @@ def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) index = cast(int, preference.argmin().item()) self.target_edges.register(get_gradient_edge(rg_outputs[index])) - rg_output_in_dims = (0,) * len(rg_outputs) - arg_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) - in_dims = (rg_output_in_dims, arg_in_dims) - - vjp = ( - FunctionalVJP(module, in_dims) - if self.has_batch_dim - else AutogradVJP(module, rg_outputs) - ) + vjp: VJP + if self.has_batch_dim: + rg_output_in_dims = (0,) * len(rg_outputs) + arg_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) + in_dims = (rg_output_in_dims, arg_in_dims) + vjp = FunctionalVJP(module, in_dims) + else: + vjp = AutogradVJP(module, rg_outputs) autograd_fn_rg_outputs = JacobianAccumulator.apply( self.gramian_accumulation_phase, From 465ef4e3055bdc07eb6ad2c79bde008cb63c016d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 2 Oct 2025 19:53:32 +0200 Subject: [PATCH 9/9] Add plural in variable name --- src/torchjd/autogram/_module_hook_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index d67d4da1..d0b78a79 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -131,9 +131,9 @@ def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) vjp: VJP if self.has_batch_dim: - rg_output_in_dims = (0,) * len(rg_outputs) - arg_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) - in_dims = (rg_output_in_dims, arg_in_dims) + rg_outputs_in_dims = (0,) * len(rg_outputs) + args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) + in_dims = (rg_outputs_in_dims, args_in_dims) vjp = FunctionalVJP(module, in_dims) else: vjp = AutogradVJP(module, rg_outputs)