From dc09a24843592d78eb46a3f851b0fbc1b547f857 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 16:01:44 +0200 Subject: [PATCH 01/13] Fix type hints --- src/torchjd/autogram/_module_hook_manager.py | 18 ++++++++++-------- src/torchjd/autogram/_vjp.py | 10 +++++++--- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 3cc4cf6b..df137161 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -101,7 +101,7 @@ def __init__( self.gramian_accumulator = gramian_accumulator self.has_batch_dim = has_batch_dim - def __call__(self, module: nn.Module, args: PyTree, outputs: PyTree) -> PyTree: + def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) -> PyTree: if self.gramian_accumulation_phase: return outputs @@ -161,7 +161,7 @@ class JacobianAccumulator(torch.autograd.Function): def forward( gramian_accumulation_phase: BoolRef, vjp: VJP, - args: PyTree, + args: tuple[PyTree, ...], gramian_accumulator: GramianAccumulator, module: nn.Module, *rg_tensors: Tensor, @@ -169,7 +169,7 @@ def forward( return tuple(t.detach() for t in rg_tensors) # For Python version > 3.10, the type of `inputs` should become - # tuple[BoolRef, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]] + # tuple[BoolRef, VJP, tuple[PyTree, ...], GramianAccumulator, nn.Module, *tuple[Tensor, ...]] @staticmethod def setup_context( ctx, @@ -183,7 +183,9 @@ def setup_context( ctx.module = inputs[4] @staticmethod - def backward(ctx, *grad_outputs: Tensor): + def backward(ctx, *grad_outputs: Tensor) -> tuple: + # Return type for python > 3.10: # tuple[None, None, None, None, None, *tuple[Tensor, ...]] + if not ctx.gramian_accumulation_phase: return None, None, None, None, None, *grad_outputs @@ -203,7 +205,7 @@ class AccumulateJacobian(torch.autograd.Function): @staticmethod def 
forward( vjp: VJP, - args: PyTree, + args: tuple[PyTree, ...], gramian_accumulator: GramianAccumulator, module: nn.Module, *grad_outputs: Tensor, @@ -216,9 +218,9 @@ def forward( @staticmethod def vmap( _, - in_dims: PyTree, + in_dims: tuple, # tuple[None, tuple[PyTree, ...], None, None, *tuple[int | None, ...]] vjp: VJP, - args: PyTree, + args: tuple[PyTree, ...], gramian_accumulator: GramianAccumulator, module: nn.Module, *jac_outputs: Tensor, @@ -244,5 +246,5 @@ def _make_path_jacobians( return path_jacobians @staticmethod - def setup_context(*_): + def setup_context(*_) -> None: pass diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 4e6c4985..909531c5 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -19,7 +19,9 @@ class VJP(ABC): """Represents an abstract VJP function.""" @abstractmethod - def __call__(self, grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]: + def __call__( + self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...] + ) -> dict[str, Tensor]: """ Computes and returns the dictionary of parameter names to their gradients for the given grad_outputs (cotangents) and at the given inputs. @@ -56,11 +58,13 @@ def __init__(self, module: nn.Module): super().__init__(module) self.vmapped_vjp = torch.vmap(self._call_on_one_instance) - def __call__(self, grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]: + def __call__( + self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...] + ) -> dict[str, Tensor]: return self.vmapped_vjp(grad_outputs, args) def _call_on_one_instance( - self, grad_outputs_j: tuple[Tensor, ...], args_j: PyTree + self, grad_outputs_j: tuple[Tensor, ...], args_j: tuple[PyTree, ...] ) -> dict[str, Tensor]: # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a # "batch" of 1 activation (or grad_output). This is because some layers (e.g. 
From 16032c3cea6306dfc1b7b63eeabf746860e76f06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 16:06:23 +0200 Subject: [PATCH 02/13] Remove xfail on WithModuleWithStringArg --- tests/unit/autogram/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 74084f3a..6818d3c3 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -109,6 +109,7 @@ (Ndim3Output, 32), (Ndim4Output, 32), (WithDropout, 32), + (WithModuleWithStringArg, 32), (WithModuleWithStringOutput, 32), (FreeParam, 32), (NoFreeParam, 32), @@ -167,7 +168,6 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc Randomness, WithModuleTrackingRunningStats, param(WithRNN, marks=mark.xfail_if_cuda), - WithModuleWithStringArg, ], ) @mark.parametrize("batch_size", [1, 3, 32]) From 49e5d5f0e81c174b80b5833837417ce94ed892e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 19:17:00 +0200 Subject: [PATCH 03/13] Add in_dims param to FunctionalVJP and compute it in Hook --- src/torchjd/autogram/_module_hook_manager.py | 10 +++++++++- src/torchjd/autogram/_vjp.py | 4 ++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index df137161..6623743e 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -129,7 +129,15 @@ def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) index = cast(int, preference.argmin().item()) self.target_edges.register(get_gradient_edge(rg_outputs[index])) - vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, rg_outputs) + rg_output_in_dims = (0,) * len(rg_outputs) + arg_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) + in_dims = 
(rg_output_in_dims, arg_in_dims) + + vjp = ( + FunctionalVJP(module, in_dims) + if self.has_batch_dim + else AutogradVJP(module, rg_outputs) + ) autograd_fn_rg_outputs = JacobianAccumulator.apply( self.gramian_accumulation_phase, diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 909531c5..181900d6 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -54,9 +54,9 @@ class FunctionalVJP(ModuleVJP): every module, and it requires to have an extra forward pass to create the vjp function. """ - def __init__(self, module: nn.Module): + def __init__(self, module: nn.Module, in_dims: tuple[PyTree, ...]): super().__init__(module) - self.vmapped_vjp = torch.vmap(self._call_on_one_instance) + self.vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims) def __call__( self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...] From 889e87fce635df286938480147fd0878baf20600 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 19:41:41 +0200 Subject: [PATCH 04/13] Add WithModuleWithHybridPyTreeArg --- tests/unit/autogram/test_engine.py | 2 ++ tests/utils/architectures.py | 47 ++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 6818d3c3..fb37863c 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -51,6 +51,7 @@ WithBuffered, WithDropout, WithModuleTrackingRunningStats, + WithModuleWithHybridPyTreeArg, WithModuleWithStringArg, WithModuleWithStringOutput, WithNoTensorOutput, @@ -110,6 +111,7 @@ (Ndim4Output, 32), (WithDropout, 32), (WithModuleWithStringArg, 32), + (WithModuleWithHybridPyTreeArg, 32), (WithModuleWithStringOutput, 32), (FreeParam, 32), (NoFreeParam, 32), diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 073420d2..3eb4e8cf 100644 --- a/tests/utils/architectures.py +++ 
b/tests/utils/architectures.py @@ -797,6 +797,53 @@ def forward(self, input: Tensor) -> Tensor: return self.with_string_arg("two", input) +class WithModuleWithHybridPyTreeArg(ShapedModule): + """ + Model containing a module that has a PyTree argument containing a mix of tensor and non-tensor + leaves. + """ + + INPUT_SHAPES = (18,) + OUTPUT_SHAPES = (3,) + + class WithHybridPyTreeArg(nn.Module): + def __init__(self): + super().__init__() + self.m0 = nn.Parameter(torch.randn(3, 3)) + self.m1 = nn.Parameter(torch.randn(4, 3)) + self.m2 = nn.Parameter(torch.randn(5, 3)) + self.m3 = nn.Parameter(torch.randn(6, 3)) + + def forward(self, input: PyTree) -> Tensor: + t0 = input["one"][0][0] + t1 = input["one"][0][1] + t2 = input["one"][1] + t3 = input["two"] + + c0 = input["one"][0][3] + c1 = input["one"][0][4][0] + c2 = input["one"][2] + c3 = input["three"] + + return c0 * t0 @ self.m0 + c1 * t1 @ self.m1 + c2 * t2 @ self.m2 + c3 * t3 @ self.m3 + + def __init__(self): + super().__init__() + self.with_string_arg = self.WithHybridPyTreeArg() + + def forward(self, input: Tensor) -> Tensor: + t0, t1, t2, t3 = input[:, 0:3], input[:, 3:7], input[:, 7:12], input[:, 12:18] + + tree = { + "zero": "unused", + "one": [(t0, t1, "unused", 0.2, [0.3, "unused"]), t2, 0.4, "unused"], + "two": t3, + "three": 0.5, + } + + return self.with_string_arg(tree) + + class WithModuleWithStringOutput(ShapedModule): """Model containing a module that has a string output.""" From e201fb77c755c2100a1a418c91fe8bca91376000 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 19:51:39 +0200 Subject: [PATCH 05/13] Improve WithModuleWithHybridPyTreeArg * With this extra linear module, we now also check that the gradients wrt the args that require grad are correctly backpropagated. 
--- tests/utils/architectures.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 3eb4e8cf..ef4933f5 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -803,7 +803,7 @@ class WithModuleWithHybridPyTreeArg(ShapedModule): leaves. """ - INPUT_SHAPES = (18,) + INPUT_SHAPES = (10,) OUTPUT_SHAPES = (3,) class WithHybridPyTreeArg(nn.Module): @@ -829,9 +829,12 @@ def forward(self, input: PyTree) -> Tensor: def __init__(self): super().__init__() + self.linear = nn.Linear(10, 18) self.with_string_arg = self.WithHybridPyTreeArg() def forward(self, input: Tensor) -> Tensor: + input = self.linear(input) + t0, t1, t2, t3 = input[:, 0:3], input[:, 3:7], input[:, 7:12], input[:, 12:18] tree = { From 15ae846db72bb366a1c7a7ec68de40de65be8558 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 20:05:11 +0200 Subject: [PATCH 06/13] Relax warning about inputs and outputs of Engine --- src/torchjd/autogram/_engine.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index be7104f0..90cd2582 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -120,9 +120,8 @@ class Engine: this, but for example `BatchNorm `_ does not (it computes some average and standard deviation over the elements of the batch). - * Their inputs and outputs can be any PyTree (tensor, tuple or list of tensors, dict of - tensors, or any nesting of those structures), but each of these tensors must be batched on - its first dimension. `Transformers + * Their inputs and outputs can be anything, but each input tensor and each output tensor + must be batched on its first dimension. `Transformers `_ and `RNNs `_ are thus not supported yet. 
This is only an implementation issue, so it should be fixed soon (please From b876e7602e22caa256937ae43c1c9a805b7bc529 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 20:27:07 +0200 Subject: [PATCH 07/13] Fix type hint in AutogradVJP.__call__ --- src/torchjd/autogram/_vjp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 181900d6..442b7fe9 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -107,7 +107,9 @@ def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]): self.rg_outputs = rg_outputs self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params) - def __call__(self, grad_outputs: tuple[Tensor, ...], _: PyTree) -> dict[str, Tensor]: + def __call__( + self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...] + ) -> dict[str, Tensor]: grads = torch.autograd.grad( self.rg_outputs, self.flat_trainable_params, From a5dfca0b02449de6bc0e0cdcb22cee6f34f42802 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 20:37:01 +0200 Subject: [PATCH 08/13] Add kwargs support --- src/torchjd/autogram/_module_hook_manager.py | 39 +++++++++++++------- src/torchjd/autogram/_vjp.py | 16 +++++--- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 6623743e..708a568d 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -65,7 +65,7 @@ def hook_module(self, module: nn.Module) -> None: self._gramian_accumulator, self._has_batch_dim, ) - self._handles.append(module.register_forward_hook(hook)) + self._handles.append(module.register_forward_hook(hook, with_kwargs=True)) @staticmethod def remove_hooks(handles: list[TorchRemovableHandle]) -> None: @@ -101,7 +101,13 @@ def __init__( self.gramian_accumulator 
= gramian_accumulator self.has_batch_dim = has_batch_dim - def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) -> PyTree: + def __call__( + self, + module: nn.Module, + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + outputs: PyTree, + ) -> PyTree: if self.gramian_accumulation_phase: return outputs @@ -131,7 +137,8 @@ def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) rg_output_in_dims = (0,) * len(rg_outputs) arg_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) - in_dims = (rg_output_in_dims, arg_in_dims) + kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs) + in_dims = (rg_output_in_dims, arg_in_dims, kwargs_in_dims) vjp = ( FunctionalVJP(module, in_dims) @@ -143,6 +150,7 @@ def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) self.gramian_accumulation_phase, vjp, args, + kwargs, self.gramian_accumulator, module, *rg_outputs, @@ -170,6 +178,7 @@ def forward( gramian_accumulation_phase: BoolRef, vjp: VJP, args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], gramian_accumulator: GramianAccumulator, module: nn.Module, *rg_tensors: Tensor, @@ -177,7 +186,7 @@ def forward( return tuple(t.detach() for t in rg_tensors) # For Python version > 3.10, the type of `inputs` should become - # tuple[BoolRef, VJP, tuple[PyTree, ...], GramianAccumulator, nn.Module, *tuple[Tensor, ...]] + # tuple[BoolRef, VJP, tuple[PyTree, ...], dict[str, PyTree], GramianAccumulator, nn.Module, *tuple[Tensor, ...]] @staticmethod def setup_context( ctx, @@ -187,25 +196,27 @@ def setup_context( ctx.gramian_accumulation_phase = inputs[0] ctx.vjp = inputs[1] ctx.args = inputs[2] - ctx.gramian_accumulator = inputs[3] - ctx.module = inputs[4] + ctx.kwargs = inputs[3] + ctx.gramian_accumulator = inputs[4] + ctx.module = inputs[5] @staticmethod def backward(ctx, *grad_outputs: Tensor) -> tuple: - # Return type for python > 3.10: # tuple[None, None, None, 
None, None, *tuple[Tensor, ...]] + # For python > 3.10: -> tuple[None, None, None, None, None, None, *tuple[Tensor, ...]] if not ctx.gramian_accumulation_phase: - return None, None, None, None, None, *grad_outputs + return None, None, None, None, None, None, *grad_outputs AccumulateJacobian.apply( ctx.vjp, ctx.args, + ctx.kwargs, ctx.gramian_accumulator, ctx.module, *grad_outputs, ) - return None, None, None, None, None, *grad_outputs + return None, None, None, None, None, None, *grad_outputs class AccumulateJacobian(torch.autograd.Function): @@ -214,29 +225,31 @@ class AccumulateJacobian(torch.autograd.Function): def forward( vjp: VJP, args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], gramian_accumulator: GramianAccumulator, module: nn.Module, *grad_outputs: Tensor, ) -> None: # There is no non-batched dimension - generalized_jacobians = vjp(grad_outputs, args) + generalized_jacobians = vjp(grad_outputs, args, kwargs) path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians) gramian_accumulator.accumulate_path_jacobians(path_jacobians) @staticmethod def vmap( _, - in_dims: tuple, # tuple[None, tuple[PyTree, ...], None, None, *tuple[int | None, ...]] + in_dims: tuple, # tuple[None, tuple[PyTree, ...], dict[str, PyTree], None, None, *tuple[int | None, ...]] vjp: VJP, args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], gramian_accumulator: GramianAccumulator, module: nn.Module, *jac_outputs: Tensor, ) -> tuple[None, None]: # There is a non-batched dimension # We do not vmap over the args for the non-batched dimension - in_dims = (in_dims[4:], tree_map(lambda _: None, args)) - generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) + in_dims = (in_dims[5:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs)) + generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args, kwargs) path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians) 
gramian_accumulator.accumulate_path_jacobians(path_jacobians) return None, None diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 442b7fe9..3a94543c 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -20,7 +20,7 @@ class VJP(ABC): @abstractmethod def __call__( - self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...] + self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree] ) -> dict[str, Tensor]: """ Computes and returns the dictionary of parameter names to their gradients for the given @@ -59,12 +59,15 @@ def __init__(self, module: nn.Module, in_dims: tuple[PyTree, ...]): self.vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims) def __call__( - self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...] + self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree] ) -> dict[str, Tensor]: - return self.vmapped_vjp(grad_outputs, args) + return self.vmapped_vjp(grad_outputs, args, kwargs) def _call_on_one_instance( - self, grad_outputs_j: tuple[Tensor, ...], args_j: tuple[PyTree, ...] + self, + grad_outputs_j: tuple[Tensor, ...], + args_j: tuple[PyTree, ...], + kwargs_j: dict[str, PyTree], ) -> dict[str, Tensor]: # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a # "batch" of 1 activation (or grad_output). This is because some layers (e.g. @@ -72,6 +75,7 @@ def _call_on_one_instance( # an element of a batch. We thus always provide them with batches, just of a # different size. 
args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j) + kwargs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), kwargs_j) grad_outputs_j_ = [x.unsqueeze(0) for x in grad_outputs_j] def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]: @@ -80,7 +84,7 @@ def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor **dict(self.module.named_buffers()), **self.frozen_params, } - output = torch.func.functional_call(self.module, all_state, args_j) + output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j) flat_outputs = tree_flatten(output)[0] rg_outputs = [t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad] return rg_outputs @@ -108,7 +112,7 @@ def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]): self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params) def __call__( - self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...] + self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...], __: dict[str, PyTree] ) -> dict[str, Tensor]: grads = torch.autograd.grad( self.rg_outputs, From b089449b7bb6a4ccbdc40cb7649305730477d22c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 20:37:30 +0200 Subject: [PATCH 09/13] Add WithModuleWithHybridKwargs test --- tests/unit/autogram/test_engine.py | 2 ++ tests/utils/architectures.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index fb37863c..4de7f2c9 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -51,6 +51,7 @@ WithBuffered, WithDropout, WithModuleTrackingRunningStats, + WithModuleWithHybridKwargs, WithModuleWithHybridPyTreeArg, WithModuleWithStringArg, WithModuleWithStringOutput, @@ -113,6 +114,7 @@ (WithModuleWithStringArg, 32), (WithModuleWithHybridPyTreeArg, 32), (WithModuleWithStringOutput, 
32), + (WithModuleWithHybridKwargs, 32), (FreeParam, 32), (NoFreeParam, 32), param(Cifar10Model, 16, marks=mark.slow), diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index ef4933f5..e6c7aaa6 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -847,6 +847,31 @@ def forward(self, input: Tensor) -> Tensor: return self.with_string_arg(tree) +class WithModuleWithHybridKwargs(ShapedModule): + """Model containing a module that has a string keyword argument.""" + + INPUT_SHAPES = (2,) + OUTPUT_SHAPES = (3,) + + class WithStringArg(nn.Module): + def __init__(self): + super().__init__() + self.matrix = nn.Parameter(torch.randn(2, 3)) + + def forward(self, s: str, input: Tensor) -> Tensor: + if s == "two": + return input @ self.matrix * 2.0 + else: + return input @ self.matrix + + def __init__(self): + super().__init__() + self.with_string_arg = self.WithStringArg() + + def forward(self, input: Tensor) -> Tensor: + return self.with_string_arg(s="two", input=input) + + class WithModuleWithStringOutput(ShapedModule): """Model containing a module that has a string output.""" From c93b6b063a738ddf1f2a70415e2d246570437f8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 2 Oct 2025 20:13:49 +0200 Subject: [PATCH 10/13] Improve docstring of WithModuleWithHybridKwargs --- tests/utils/architectures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index e6c7aaa6..06d454be 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -848,7 +848,7 @@ def forward(self, input: Tensor) -> Tensor: class WithModuleWithHybridKwargs(ShapedModule): - """Model containing a module that has a string keyword argument.""" + """Model calling its submodule's forward with a string and a tensor as keyword arguments.""" INPUT_SHAPES = (2,) OUTPUT_SHAPES = (3,) From 11441af2be38ce02a339def15c26da0b0e91a1f2 Mon Sep 17 00:00:00 2001 
From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 2 Oct 2025 20:19:39 +0200 Subject: [PATCH 11/13] Move WithStringArg and WithHybridPyTreeArg out of their wrapping models --- tests/unit/autogram/test_engine.py | 4 +- tests/utils/architectures.py | 85 +++++++++++++----------------- 2 files changed, 40 insertions(+), 49 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 4de7f2c9..23650b80 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -50,8 +50,8 @@ WithBatchNorm, WithBuffered, WithDropout, + WithKwargs, WithModuleTrackingRunningStats, - WithModuleWithHybridKwargs, WithModuleWithHybridPyTreeArg, WithModuleWithStringArg, WithModuleWithStringOutput, @@ -114,7 +114,7 @@ (WithModuleWithStringArg, 32), (WithModuleWithHybridPyTreeArg, 32), (WithModuleWithStringOutput, 32), - (WithModuleWithHybridKwargs, 32), + (WithKwargs, 32), (FreeParam, 32), (NoFreeParam, 32), param(Cifar10Model, 16, marks=mark.slow), diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 06d454be..f17845c4 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -772,31 +772,54 @@ def forward(self, input: Tensor) -> Tensor: return input @ self.linear.weight.T + self.linear.bias +class _WithStringArg(nn.Module): + def __init__(self): + super().__init__() + self.matrix = nn.Parameter(torch.randn(2, 3)) + + def forward(self, s: str, input: Tensor) -> Tensor: + if s == "two": + return input @ self.matrix * 2.0 + else: + return input @ self.matrix + + class WithModuleWithStringArg(ShapedModule): """Model containing a module that has a string argument.""" INPUT_SHAPES = (2,) OUTPUT_SHAPES = (3,) - class WithStringArg(nn.Module): - def __init__(self): - super().__init__() - self.matrix = nn.Parameter(torch.randn(2, 3)) - - def forward(self, s: str, input: Tensor) -> Tensor: - if s == "two": - return input @ self.matrix * 2.0 - else: - return input @ self.matrix - def 
__init__(self): super().__init__() - self.with_string_arg = self.WithStringArg() + self.with_string_arg = _WithStringArg() def forward(self, input: Tensor) -> Tensor: return self.with_string_arg("two", input) +class _WithHybridPyTreeArg(nn.Module): + def __init__(self): + super().__init__() + self.m0 = nn.Parameter(torch.randn(3, 3)) + self.m1 = nn.Parameter(torch.randn(4, 3)) + self.m2 = nn.Parameter(torch.randn(5, 3)) + self.m3 = nn.Parameter(torch.randn(6, 3)) + + def forward(self, input: PyTree) -> Tensor: + t0 = input["one"][0][0] + t1 = input["one"][0][1] + t2 = input["one"][1] + t3 = input["two"] + + c0 = input["one"][0][3] + c1 = input["one"][0][4][0] + c2 = input["one"][2] + c3 = input["three"] + + return c0 * t0 @ self.m0 + c1 * t1 @ self.m1 + c2 * t2 @ self.m2 + c3 * t3 @ self.m3 + + class WithModuleWithHybridPyTreeArg(ShapedModule): """ Model containing a module that has a PyTree argument containing a mix of tensor and non-tensor @@ -806,31 +829,10 @@ class WithModuleWithHybridPyTreeArg(ShapedModule): INPUT_SHAPES = (10,) OUTPUT_SHAPES = (3,) - class WithHybridPyTreeArg(nn.Module): - def __init__(self): - super().__init__() - self.m0 = nn.Parameter(torch.randn(3, 3)) - self.m1 = nn.Parameter(torch.randn(4, 3)) - self.m2 = nn.Parameter(torch.randn(5, 3)) - self.m3 = nn.Parameter(torch.randn(6, 3)) - - def forward(self, input: PyTree) -> Tensor: - t0 = input["one"][0][0] - t1 = input["one"][0][1] - t2 = input["one"][1] - t3 = input["two"] - - c0 = input["one"][0][3] - c1 = input["one"][0][4][0] - c2 = input["one"][2] - c3 = input["three"] - - return c0 * t0 @ self.m0 + c1 * t1 @ self.m1 + c2 * t2 @ self.m2 + c3 * t3 @ self.m3 - def __init__(self): super().__init__() self.linear = nn.Linear(10, 18) - self.with_string_arg = self.WithHybridPyTreeArg() + self.with_string_arg = _WithHybridPyTreeArg() def forward(self, input: Tensor) -> Tensor: input = self.linear(input) @@ -847,26 +849,15 @@ def forward(self, input: Tensor) -> Tensor: return 
self.with_string_arg(tree) -class WithModuleWithHybridKwargs(ShapedModule): +class WithKwargs(ShapedModule): """Model calling its submodule's forward with a string and a tensor as keyword arguments.""" INPUT_SHAPES = (2,) OUTPUT_SHAPES = (3,) - class WithStringArg(nn.Module): - def __init__(self): - super().__init__() - self.matrix = nn.Parameter(torch.randn(2, 3)) - - def forward(self, s: str, input: Tensor) -> Tensor: - if s == "two": - return input @ self.matrix * 2.0 - else: - return input @ self.matrix - def __init__(self): super().__init__() - self.with_string_arg = self.WithStringArg() + self.with_string_arg = _WithStringArg() def forward(self, input: Tensor) -> Tensor: return self.with_string_arg(s="two", input=input) From 309356ffc77492527cfff68f124d95db53c52e3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 2 Oct 2025 20:22:56 +0200 Subject: [PATCH 12/13] Add WithModuleWithHybridPyTreeKwarg, move some stuff --- tests/unit/autogram/test_engine.py | 6 +++-- tests/utils/architectures.py | 39 ++++++++++++++++++++++++++---- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 23650b80..7c88ae9d 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -50,10 +50,11 @@ WithBatchNorm, WithBuffered, WithDropout, - WithKwargs, WithModuleTrackingRunningStats, WithModuleWithHybridPyTreeArg, + WithModuleWithHybridPyTreeKwarg, WithModuleWithStringArg, + WithModuleWithStringKwarg, WithModuleWithStringOutput, WithNoTensorOutput, WithRNN, @@ -114,7 +115,8 @@ (WithModuleWithStringArg, 32), (WithModuleWithHybridPyTreeArg, 32), (WithModuleWithStringOutput, 32), - (WithKwargs, 32), + (WithModuleWithStringKwarg, 32), + (WithModuleWithHybridPyTreeKwarg, 32), (FreeParam, 32), (NoFreeParam, 32), param(Cifar10Model, 16, marks=mark.slow), diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index f17845c4..6166670b 
100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -798,6 +798,20 @@ def forward(self, input: Tensor) -> Tensor: return self.with_string_arg("two", input) +class WithModuleWithStringKwarg(ShapedModule): + """Model calling its submodule's forward with a string and a tensor as keyword arguments.""" + + INPUT_SHAPES = (2,) + OUTPUT_SHAPES = (3,) + + def __init__(self): + super().__init__() + self.with_string_arg = _WithStringArg() + + def forward(self, input: Tensor) -> Tensor: + return self.with_string_arg(s="two", input=input) + + class _WithHybridPyTreeArg(nn.Module): def __init__(self): super().__init__() @@ -849,18 +863,33 @@ def forward(self, input: Tensor) -> Tensor: return self.with_string_arg(tree) -class WithKwargs(ShapedModule): - """Model calling its submodule's forward with a string and a tensor as keyword arguments.""" +class WithModuleWithHybridPyTreeKwarg(ShapedModule): + """ + Model calling its submodule's forward with a PyTree keyword argument containing a mix of tensors + and non-tensor values. 
+ """ - INPUT_SHAPES = (2,) + INPUT_SHAPES = (10,) OUTPUT_SHAPES = (3,) def __init__(self): super().__init__() - self.with_string_arg = _WithStringArg() + self.linear = nn.Linear(10, 18) + self.with_string_arg = _WithHybridPyTreeArg() def forward(self, input: Tensor) -> Tensor: - return self.with_string_arg(s="two", input=input) + input = self.linear(input) + + t0, t1, t2, t3 = input[:, 0:3], input[:, 3:7], input[:, 7:12], input[:, 12:18] + + tree = { + "zero": "unused", + "one": [(t0, t1, "unused", 0.2, [0.3, "unused"]), t2, 0.4, "unused"], + "two": t3, + "three": 0.5, + } + + return self.with_string_arg(input=tree) class WithModuleWithStringOutput(ShapedModule): From 2ec986abb65f6a2222d77f598076f4f3e7f7203a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 2 Oct 2025 20:24:01 +0200 Subject: [PATCH 13/13] (unrelated) fix typo --- tests/utils/architectures.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 6166670b..3dc126ad 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -1112,11 +1112,11 @@ def forward(self, input: Tensor) -> Tensor: # Other torchvision.models were not added for the following reasons: -# - VGG16: Sometimes takes to much memory on autojac even with bs=2, nut autogram seems ok. +# - VGG16: Sometimes takes too much memory on autojac even with bs=2, but autogram seems ok. # - DenseNet: no way to easily replace the BatchNorms (no norm_layer param) # - InceptionV3: no way to easily replace the BatchNorms (no norm_layer param) # - GoogleNet: no way to easily replace the BatchNorms (no norm_layer param) # - ShuffleNetV2: no way to easily replace the BatchNorms (no norm_layer param) -# - ResNeXt: Sometimes takes to much memory on autojac even with bs=2, nut autogram seems ok. -# - WideResNet50: Sometimes takes to much memory on autojac even with bs=2, nut autogram seems ok.
+# - ResNeXt: Sometimes takes too much memory on autojac even with bs=2, but autogram seems ok. +# - WideResNet50: Sometimes takes too much memory on autojac even with bs=2, but autogram seems ok. # - MNASNet: no way to easily replace the BatchNorms (no norm_layer param)