From 0fe81726f77de899a3326dfcbd601bddf827ccd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 2 Oct 2025 20:49:02 +0200 Subject: [PATCH 01/15] Add WithTransformer test --- tests/unit/autogram/test_engine.py | 2 ++ tests/utils/architectures.py | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index d4cac1ec..415a1b16 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -61,6 +61,7 @@ WithRNN, WithSideEffect, WithSomeFrozenModule, + WithTransformer, ) from utils.dict_assertions import assert_tensor_dicts_are_close from utils.forward_backwards import ( @@ -118,6 +119,7 @@ (WithModuleWithStringOutput, 32), (WithModuleWithStringKwarg, 32), (WithModuleWithHybridPyTreeKwarg, 32), + (WithTransformer, 32), (FreeParam, 32), (NoFreeParam, 32), param(Cifar10Model, 16, marks=mark.slow), diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index c537c2d7..c4f02ac6 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -931,6 +931,29 @@ def forward(self, input: Tensor) -> Tensor: return output +class WithTransformer(ShapedModule): + """Module containing a single Transformer.""" + + INPUT_SHAPES = ((10, 8), (20, 8)) + OUTPUT_SHAPES = (20, 8) + + def __init__(self): + super().__init__() + self.transformer = nn.Transformer( + d_model=8, + nhead=2, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=32, + batch_first=True, + dropout=0.0, + ) + + def forward(self, input: tuple[Tensor, Tensor]) -> Tensor: + src, tgt = input + return self.transformer(src, tgt) + + class FreeParam(ShapedModule): """ Model that contains a free (i.e. not contained in a submodule) parameter, that is used at the From 89b1ac0824e4b37b197f8404b99ca4da613809f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 2 Oct 2025 20:59:30 +0200 Subject: [PATCH 02/15] Add test_batched_non_batched_equivalence_2 --- tests/unit/autogram/test_engine.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 415a1b16..16efdaeb 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -567,3 +567,33 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): gramian2 = engine2.compute_gramian(output) assert_close(gramian1, gramian2) + + +@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) +def test_batched_non_batched_equivalence_2(architecture: ShapedModule, batch_size: int): + """ + Same as test_batched_non_batched_equivalence but on real architectures, and thus only between + batch_size=0 and batch_size=None. + """ + + input_shapes = architecture.INPUT_SHAPES + output_shapes = architecture.OUTPUT_SHAPES + + torch.manual_seed(0) + model = architecture().to(device=DEVICE) + + engine_0 = Engine(model.modules(), batch_dim=0) + engine_none = Engine(model.modules(), batch_dim=None) + + inputs = make_tensors(batch_size, input_shapes) + targets = make_tensors(batch_size, output_shapes) + loss_fn = make_mse_loss_fn(targets) + + torch.random.manual_seed(0) # Fix randomness for random models + output = model(inputs) + losses = reduce_to_vector(loss_fn(output)) + + gramian_0 = engine_0.compute_gramian(losses) + gramian_none = engine_none.compute_gramian(losses) + + assert_close(gramian_0, gramian_none, rtol=1e-4, atol=1e-5) From a259716e12766b4b2460c270af34fd2c097a56c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 2 Oct 2025 21:56:53 +0200 Subject: [PATCH 03/15] Fix test_batched_non_batched_equivalence_2 to not have two engines on the same model --- tests/unit/autogram/test_engine.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 16efdaeb..715ebcd3 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -580,20 +580,26 @@ def test_batched_non_batched_equivalence_2(architecture: ShapedModule, batch_siz output_shapes = architecture.OUTPUT_SHAPES torch.manual_seed(0) - model = architecture().to(device=DEVICE) + model_0 = architecture().to(device=DEVICE) + torch.manual_seed(0) + model_none = architecture().to(device=DEVICE) - engine_0 = Engine(model.modules(), batch_dim=0) - engine_none = Engine(model.modules(), batch_dim=None) + engine_0 = Engine(model_0.modules(), batch_dim=0) + engine_none = Engine(model_none.modules(), batch_dim=None) inputs = make_tensors(batch_size, input_shapes) targets = make_tensors(batch_size, output_shapes) loss_fn = make_mse_loss_fn(targets) torch.random.manual_seed(0) # Fix randomness for random models - output = model(inputs) - losses = reduce_to_vector(loss_fn(output)) + output = model_0(inputs) + losses_0 = reduce_to_vector(loss_fn(output)) + + torch.random.manual_seed(0) # Fix randomness for random models + output = model_none(inputs) + losses_none = reduce_to_vector(loss_fn(output)) - gramian_0 = engine_0.compute_gramian(losses) - gramian_none = engine_none.compute_gramian(losses) + gramian_0 = engine_0.compute_gramian(losses_0) + gramian_none = engine_none.compute_gramian(losses_none) assert_close(gramian_0, gramian_none, rtol=1e-4, atol=1e-5) From 1d83347269d910c5c19c49643a1273cabf306c46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 3 Oct 2025 21:05:55 +0200 Subject: [PATCH 04/15] Add quick fix to handle transformers --- src/torchjd/autogram/_module_hook_manager.py | 15 ++++++++++++++- src/torchjd/autogram/_vjp.py | 9 +++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 9f4285ee..6504032c 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -4,6 +4,7 @@ import torch from torch import Tensor, nn from torch.autograd.graph import get_gradient_edge +from torch.nn import MultiheadAttention from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_unflatten from torch.utils.hooks import RemovableHandle as TorchRemovableHandle @@ -126,6 +127,14 @@ def __call__( return outputs requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad] + + # Quickfix to handle Transformers. + if isinstance(module, MultiheadAttention): + if module.out_proj.weight is not None and module.out_proj.weight.requires_grad: + requires_grad_params.append(module.out_proj.weight) + if module.out_proj.bias is not None and module.out_proj.bias.requires_grad: + requires_grad_params.append(module.out_proj.bias) + self.gramian_accumulator.track_parameter_paths(requires_grad_params) # We only care about running the JacobianAccumulator node, so we need one of its child @@ -260,7 +269,11 @@ def _make_path_jacobians( ) -> dict[Tensor, Tensor]: path_jacobians: dict[Tensor, Tensor] = {} for param_name, generalized_jacobian in generalized_jacobians.items(): - key = module.get_parameter(param_name) + # Quickfix to handle Transformers. + if isinstance(module, MultiheadAttention) and param_name.startswith("out_proj."): + key = module.out_proj.get_parameter(param_name[9:]) + else: + key = module.get_parameter(param_name) jacobian = generalized_jacobian.reshape([-1] + list(key.shape)) path_jacobians[key] = jacobian return path_jacobians diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 3a94543c..e5bbf531 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -46,6 +46,15 @@ def __init__(self, module: nn.Module): else: self.frozen_params[name] = param + # Quickfix to handle Transformers. + if isinstance(module, nn.MultiheadAttention): + submodule = module.out_proj + for name, param in submodule.named_parameters(recurse=False): + if param.requires_grad: + self.trainable_params["out_proj." + name] = param + else: + self.frozen_params["out_proj." + name] = param + class FunctionalVJP(ModuleVJP): """ From 48e41fe91c02af2d9d9f9d8ece76452a7635f37a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 9 Oct 2025 15:43:43 +0200 Subject: [PATCH 05/15] Filter user warning from WithTransformer test --- tests/unit/autogram/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 715ebcd3..9f64621f 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -119,7 +119,7 @@ (WithModuleWithStringOutput, 32), (WithModuleWithStringKwarg, 32), (WithModuleWithHybridPyTreeKwarg, 32), - (WithTransformer, 32), + param(WithTransformer, 32, marks=mark.filterwarnings("ignore:There is a performance drop")), (FreeParam, 32), (NoFreeParam, 32), param(Cifar10Model, 16, marks=mark.slow), From ea821ed0f2ad673b1a147a4e8c3e7f2c1a438835 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 9 Oct 2025 16:32:09 +0200 Subject: [PATCH 06/15] Remove special case in _make_path_jacobians * It's actually not needed. We can either do mha.out_proj.get_parameter("weight") or mha.get_parameter("out_proj.weight"). The latter is much easier to handle since "out_proj.weight" is already the param name that we store. --- src/torchjd/autogram/_module_hook_manager.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 6504032c..036e0fbb 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -269,11 +269,7 @@ def _make_path_jacobians( ) -> dict[Tensor, Tensor]: path_jacobians: dict[Tensor, Tensor] = {} for param_name, generalized_jacobian in generalized_jacobians.items(): - # Quickfix to handle Transformers. - if isinstance(module, MultiheadAttention) and param_name.startswith("out_proj."): - key = module.out_proj.get_parameter(param_name[9:]) - else: - key = module.get_parameter(param_name) + key = module.get_parameter(param_name) jacobian = generalized_jacobian.reshape([-1] + list(key.shape)) path_jacobians[key] = jacobian return path_jacobians From 11cdcdb9ffd7bc5f97cd85e09b577246575786d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 9 Oct 2025 16:34:34 +0200 Subject: [PATCH 07/15] Extract code to cleanly handle indirectly_used params --- src/torchjd/autogram/_module_hook_manager.py | 33 +++++++++++++------- src/torchjd/autogram/_vjp.py | 30 +++++++++++------- 2 files changed, 41 insertions(+), 22 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 036e0fbb..8d41837b 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -4,7 +4,6 @@ import torch from torch import Tensor, nn from torch.autograd.graph import get_gradient_edge -from torch.nn import MultiheadAttention from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_unflatten from torch.utils.hooks import RemovableHandle as TorchRemovableHandle @@ -126,16 +125,8 @@ def __call__( # require grad return outputs - requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad] - - # Quickfix to handle Transformers. - if isinstance(module, MultiheadAttention): - if module.out_proj.weight is not None and module.out_proj.weight.requires_grad: - requires_grad_params.append(module.out_proj.weight) - if module.out_proj.bias is not None and module.out_proj.bias.requires_grad: - requires_grad_params.append(module.out_proj.bias) - - self.gramian_accumulator.track_parameter_paths(requires_grad_params) + rg_params = _get_direct_rg_params(module) + _get_indirectly_used_rg_params(module) + self.gramian_accumulator.track_parameter_paths(rg_params) # We only care about running the JacobianAccumulator node, so we need one of its child # edges (the edges of the original outputs of the model) as target. For memory @@ -170,6 +161,26 @@ def __call__( return tree_unflatten(flat_outputs, output_spec) +def _get_direct_rg_params(module: nn.Module) -> list[nn.Parameter]: + return [p for p in module.parameters(recurse=False) if p.requires_grad] + + +def _get_indirectly_used_rg_params(module: nn.Module) -> list[nn.Parameter]: + """ + Get the parameters that are used by module but that are not its direct params. This is a fairly + unusual setup that has to be handled on a case-by-case basis. + """ + + # MHA uses its out_proj child params itself. Note that we also check that the MHA still has + # an out_proj attribute because it might change in the future (which will remove the + # necessity of custom code for MHA entirely). See the status of + # https://github.com/pytorch/pytorch/pull/126568 + if isinstance(module, nn.MultiheadAttention) and hasattr(module, "out_proj"): + return _get_direct_rg_params(module.out_proj) + else: + return [] + + class JacobianAccumulator(torch.autograd.Function): """ Autograd function that accumulates Jacobian Gramians during the first backward pass. diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index e5bbf531..78f17189 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -40,20 +40,28 @@ def __init__(self, module: nn.Module): self.trainable_params = dict[str, Parameter]() self.frozen_params = dict[str, Parameter]() + self._save_direct_params(module) + self._save_indirectly_used_params(module) + + def _save_direct_params(self, module: nn.Module, prefix: str = "") -> None: for name, param in module.named_parameters(recurse=False): if param.requires_grad: - self.trainable_params[name] = param + self.trainable_params[prefix + name] = param else: - self.frozen_params[name] = param - - # Quickfix to handle Transformers. - if isinstance(module, nn.MultiheadAttention): - submodule = module.out_proj - for name, param in submodule.named_parameters(recurse=False): - if param.requires_grad: - self.trainable_params["out_proj." + name] = param - else: - self.frozen_params["out_proj." + name] = param + self.frozen_params[prefix + name] = param + + def _save_indirectly_used_params(self, module: nn.Module) -> None: + """ + Save the parameters that are used by module but that are not its direct params. This is a + fairly unusual setup that has to be handled on a case-by-case basis. + """ + + # MHA uses its out_proj child params itself. Note that we also check that the MHA still has + # an out_proj attribute because it might change in the future (which will remove the + # necessity of custom code for MHA entirely). See the status of + # https://github.com/pytorch/pytorch/pull/126568 + if isinstance(module, nn.MultiheadAttention) and hasattr(module, "out_proj"): + self._save_direct_params(module.out_proj, "out_proj.") class FunctionalVJP(ModuleVJP): From 6531f00635f1e2da944ca6875da955daae50b9d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 9 Oct 2025 17:00:40 +0200 Subject: [PATCH 08/15] Add WithMultiHeadAttention --- tests/unit/autogram/test_engine.py | 2 ++ tests/utils/architectures.py | 25 ++++++++++++++++++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 9f64621f..0a364746 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -57,6 +57,7 @@ WithModuleWithStringArg, WithModuleWithStringKwarg, WithModuleWithStringOutput, + WithMultiHeadAttention, WithNoTensorOutput, WithRNN, WithSideEffect, @@ -119,6 +120,7 @@ (WithModuleWithStringOutput, 32), (WithModuleWithStringKwarg, 32), (WithModuleWithHybridPyTreeKwarg, 32), + (WithMultiHeadAttention, 32), param(WithTransformer, 32, marks=mark.filterwarnings("ignore:There is a performance drop")), (FreeParam, 32), (NoFreeParam, 32), diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index c4f02ac6..22a2b58b 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -931,8 +931,31 @@ def forward(self, input: Tensor) -> Tensor: return output +class WithMultiHeadAttention(ShapedModule): + """Module containing a MultiheadAttention layer.""" + + INPUT_SHAPES = ((20, 8), (10, 9), (10, 11)) + OUTPUT_SHAPES = (20, 8) + + def __init__(self): + super().__init__() + self.mha = nn.MultiheadAttention( + embed_dim=8, + num_heads=2, + dropout=0.0, + batch_first=True, + kdim=9, + vdim=11, + ) + + def forward(self, input: tuple[Tensor, Tensor, Tensor]) -> Tensor: + query, key, value = input + attn_output, _ = self.mha(query, key, value) + return attn_output + + class WithTransformer(ShapedModule): - """Module containing a single Transformer.""" + """Module containing a Transformer.""" INPUT_SHAPES = ((10, 8), (20, 8)) OUTPUT_SHAPES = (20, 8) From 8c175d9e741e61beb70a2099875505e12b4f065b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 9 Oct 2025 17:03:47 +0200 Subject: [PATCH 09/15] Adapt warning about transformers --- src/torchjd/autogram/_engine.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 90cd2582..de6ba99f 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -113,19 +113,16 @@ class Engine: memory-efficient, and thus typically faster, to use the Gramian-based approach. .. warning:: - When providing a non-None ``batch_dim``, all provided modules must respect a few - conditions: + When providing a non-None ``batch_dim``, all provided modules must respect a few conditions: * They should treat the elements of the batch independently. Most common layers respect this, but for example `BatchNorm `_ does not (it computes some average and standard deviation over the elements of the batch). * Their inputs and outputs can be anything, but each input tensor and each output tensor - must be batched on its first dimension. `Transformers - `_ and `RNNs + must be batched on its first dimension. `RNNs `_ are thus not - supported yet. This is only an implementation issue, so it should be fixed soon (please - open an issue if you need extra focus on this). + supported yet. * They should not perform in-place operations on tensors (for instance you should not use ``track_running_stats=True`` in normalization layers). * They should not have side effects during the forward pass (since their forward pass will @@ -133,7 +130,9 @@ class Engine: * If they have some randomness during the forward pass, they should not have direct trainable parameters. It is, however, perfectly fine for random modules to have child modules that have trainable parameters, so if you have a random module with some direct - parameters, a simple fix is to wrap these parameters into a child module. + parameters, a simple fix is to wrap these parameters into a child module. For this reason, + `Transformers `_ + should be used with `dropout=0.0`. If you're building your own architecture, respecting those criteria should be quite easy. However, if you're using an existing architecture, you may have to modify it to make it From 82b1d391223c6470f4fdaef48a51935f1a3b6cc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 9 Oct 2025 17:26:48 +0200 Subject: [PATCH 10/15] Improve warnings in autogram engine --- src/torchjd/autogram/_engine.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index de6ba99f..a60bd4c4 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -120,19 +120,29 @@ class Engine: `_ does not (it computes some average and standard deviation over the elements of the batch). * Their inputs and outputs can be anything, but each input tensor and each output tensor - must be batched on its first dimension. `RNNs - `_ are thus not - supported yet. + must be batched on its first dimension. When available (e.g. in `Transformers + `_, + `MultiheadAttention + `_, + etc.), the ``batch_first`` parameter has to be set to ``True``. Also, this makes `RNNs + `_ not supported yet + because their hidden state is batched on dimension 1 even if ``batch_first`` is ``True``. * They should not perform in-place operations on tensors (for instance you should not use ``track_running_stats=True`` in normalization layers). * They should not have side effects during the forward pass (since their forward pass will be called twice, the side effects could be different from what's expected). * If they have some randomness during the forward pass, they should not have direct - trainable parameters. It is, however, perfectly fine for random modules to have child - modules that have trainable parameters, so if you have a random module with some direct - parameters, a simple fix is to wrap these parameters into a child module. For this reason, - `Transformers `_ - should be used with `dropout=0.0`. + trainable parameters. For this reason, + `Transformers + `_, which use a + dropout function (rather than a `Dropout + `_ layer) in a + module with some trainable parameters, has to be used with + ``dropout=0.0``. Note that a `Dropout + `_ layers are + entirely supported and should be preferred. It is also perfectly fine for random modules + to have child modules that have trainable parameters, so if you have a random module with + some direct parameters, a simple fix is to wrap these parameters into a child module. If you're building your own architecture, respecting those criteria should be quite easy. However, if you're using an existing architecture, you may have to modify it to make it From b105dfe711bcbda29a12088f9b4ceb43576076c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 10 Oct 2025 00:52:12 +0200 Subject: [PATCH 11/15] Add WithTransformerLarge & speed test --- tests/speed/autogram/grad_vs_jac_vs_gram.py | 2 ++ tests/unit/autogram/test_engine.py | 6 ++++++ tests/utils/architectures.py | 18 ++++++++++++++++++ 3 files changed, 26 insertions(+) diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index a41b2905..ca6b9ba7 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -13,6 +13,7 @@ NoFreeParam, ShapedModule, SqueezeNet, + WithTransformerLarge, ) from utils.forward_backwards import ( autograd_forward_backward, @@ -27,6 +28,7 @@ from torchjd.autogram import Engine PARAMETRIZATIONS = [ + (WithTransformerLarge, 8), (FreeParam, 64), (NoFreeParam, 64), (Cifar10Model, 64), diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 0a364746..2582d513 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -63,6 +63,7 @@ WithSideEffect, WithSomeFrozenModule, WithTransformer, + WithTransformerLarge, ) from utils.dict_assertions import assert_tensor_dicts_are_close from utils.forward_backwards import ( @@ -130,6 +131,11 @@ param(GroupNormMobileNetV3Small, 3, marks=mark.slow), param(SqueezeNet, 8, marks=mark.slow), param(InstanceNormMobileNetV2, 2, marks=mark.slow), + param( + WithTransformerLarge, + 8, + marks=[mark.slow, mark.filterwarnings("ignore:There is a performance drop")], + ), ] diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 22a2b58b..d5760d03 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -977,6 +977,24 @@ def forward(self, input: tuple[Tensor, Tensor]) -> Tensor: return self.transformer(src, tgt) +class WithTransformerLarge(ShapedModule): + """Module containing a large Transformer.""" + + INPUT_SHAPES = ((10, 512), (20, 512)) + OUTPUT_SHAPES = (20, 512) + + def __init__(self): + super().__init__() + self.transformer = nn.Transformer( + batch_first=True, + dropout=0.0, + ) + + def forward(self, input: tuple[Tensor, Tensor]) -> Tensor: + src, tgt = input + return self.transformer(src, tgt) + + class FreeParam(ShapedModule): """ Model that contains a free (i.e. not contained in a submodule) parameter, that is used at the From 047bd0fe8d665d0528329cc67abab29b1f6df57e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 10 Oct 2025 15:18:27 +0200 Subject: [PATCH 12/15] Factorize code to extract used params --- src/torchjd/autogram/_module_hook_manager.py | 25 ++--------- src/torchjd/autogram/_module_utils.py | 47 ++++++++++++++++++++ src/torchjd/autogram/_vjp.py | 28 ++---------- 3 files changed, 53 insertions(+), 47 deletions(-) create mode 100644 src/torchjd/autogram/_module_utils.py diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 8d41837b..e21297e3 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -9,6 +9,7 @@ from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator +from ._module_utils import get_used_params from ._vjp import VJP, AutogradVJP, FunctionalVJP # Note about import from protected _pytree module: @@ -125,8 +126,8 @@ def __call__( # require grad return outputs - rg_params = _get_direct_rg_params(module) + _get_indirectly_used_rg_params(module) - self.gramian_accumulator.track_parameter_paths(rg_params) + rg_params, _ = get_used_params(module) + self.gramian_accumulator.track_parameter_paths(list(rg_params.values())) # We only care about running the JacobianAccumulator node, so we need one of its child # edges (the edges of the original outputs of the model) as target. For memory @@ -161,26 +162,6 @@ def __call__( return tree_unflatten(flat_outputs, output_spec) -def _get_direct_rg_params(module: nn.Module) -> list[nn.Parameter]: - return [p for p in module.parameters(recurse=False) if p.requires_grad] - - -def _get_indirectly_used_rg_params(module: nn.Module) -> list[nn.Parameter]: - """ - Get the parameters that are used by module but that are not its direct params. This is a fairly - unusual setup that has to be handled on a case-by-case basis. - """ - - # MHA uses its out_proj child params itself. Note that we also check that the MHA still has - # an out_proj attribute because it might change in the future (which will remove the - # necessity of custom code for MHA entirely). See the status of - # https://github.com/pytorch/pytorch/pull/126568 - if isinstance(module, nn.MultiheadAttention) and hasattr(module, "out_proj"): - return _get_direct_rg_params(module.out_proj) - else: - return [] - - class JacobianAccumulator(torch.autograd.Function): """ Autograd function that accumulates Jacobian Gramians during the first backward pass. diff --git a/src/torchjd/autogram/_module_utils.py b/src/torchjd/autogram/_module_utils.py new file mode 100644 index 00000000..c2e5c0df --- /dev/null +++ b/src/torchjd/autogram/_module_utils.py @@ -0,0 +1,47 @@ +from torch import nn + + +def get_used_params(module: nn.Module) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]: + """ + Gets all parameters that a module uses. In reality, we return all direct params (which may + include some unused params) and all the indirectly used params that we know about (we may be + missing some in weird modules). + + Returns the tuple containing the params that require grad and the params that don't require + grad. + """ + + direct_rg_params, direct_frozen_params = _get_direct_params(module) + indirect_rg_params, indirect_frozen_params = _get_indirectly_used_params(module) + rg_params = direct_rg_params | indirect_rg_params + frozen_params = direct_frozen_params | indirect_frozen_params + + return rg_params, frozen_params + + +def _get_direct_params( + module: nn.Module, prefix: str = "" +) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]: + rg_params = dict[str, nn.Parameter]() + frozen_params = dict[str, nn.Parameter]() + + for name, param in module.named_parameters(recurse=False): + if param.requires_grad: + rg_params[prefix + name] = param + else: + frozen_params[prefix + name] = param + + return rg_params, frozen_params + + +def _get_indirectly_used_params( + module: nn.Module, +) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]: + # MHA uses its out_proj child params itself. Note that we also check that the MHA still has + # an out_proj attribute because it might change in the future (which will remove the + # necessity of custom code for MHA entirely). See the status of + # https://github.com/pytorch/pytorch/pull/126568 + if isinstance(module, nn.MultiheadAttention) and hasattr(module, "out_proj"): + return _get_direct_params(module.out_proj, prefix="out_proj.") + + return {}, {} diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 78f17189..a0a1422b 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -6,6 +6,8 @@ from torch.nn import Parameter from torch.utils._pytree import PyTree, tree_flatten, tree_map_only, tree_unflatten +from torchjd.autogram._module_utils import get_used_params + # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see # https://github.com/pytorch/pytorch/issues/65761, https://github.com/pytorch/pytorch/pull/137400). @@ -37,31 +39,7 @@ class ModuleVJP(VJP, ABC): def __init__(self, module: nn.Module): self.module = module - self.trainable_params = dict[str, Parameter]() - self.frozen_params = dict[str, Parameter]() - - self._save_direct_params(module) - self._save_indirectly_used_params(module) - - def _save_direct_params(self, module: nn.Module, prefix: str = "") -> None: - for name, param in module.named_parameters(recurse=False): - if param.requires_grad: - self.trainable_params[prefix + name] = param - else: - self.frozen_params[prefix + name] = param - - def _save_indirectly_used_params(self, module: nn.Module) -> None: - """ - Save the parameters that are used by module but that are not its direct params. This is a - fairly unusual setup that has to be handled on a case-by-case basis. - """ - - # MHA uses its out_proj child params itself. Note that we also check that the MHA still has - # an out_proj attribute because it might change in the future (which will remove the - # necessity of custom code for MHA entirely). See the status of - # https://github.com/pytorch/pytorch/pull/126568 - if isinstance(module, nn.MultiheadAttention) and hasattr(module, "out_proj"): - self._save_direct_params(module.out_proj, "out_proj.") + self.trainable_params, self.frozen_params = get_used_params(module) class FunctionalVJP(ModuleVJP): From 1e1ff02d32940580b38370431554650d2f377cde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 10 Oct 2025 15:44:08 +0200 Subject: [PATCH 13/15] Rename trainable to rg --- src/torchjd/autogram/_vjp.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index a0a1422b..acf79c3b 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -39,7 +39,7 @@ class ModuleVJP(VJP, ABC): def __init__(self, module: nn.Module): self.module = module - self.trainable_params, self.frozen_params = get_used_params(module) + self.rg_params, self.frozen_params = get_used_params(module) class FunctionalVJP(ModuleVJP): @@ -73,9 +73,9 @@ def _call_on_one_instance( kwargs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), kwargs_j) grad_outputs_j_ = [x.unsqueeze(0) for x in grad_outputs_j] - def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]: + def functional_model_call(rg_params: dict[str, Parameter]) -> list[Tensor]: all_state = { - **trainable_params, + **rg_params, **dict(self.module.named_buffers()), **self.frozen_params, } @@ -84,7 +84,7 @@ def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor rg_outputs = [t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad] return rg_outputs - vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1] + vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1] # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the # functional has a single primal which is dict(module.named_parameters()). We therefore take @@ -104,14 +104,14 @@ def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]): super().__init__(module) self.rg_outputs = rg_outputs - self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params) + self.flat_rg_params, self.param_spec = tree_flatten(self.rg_params) def __call__( self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...], __: dict[str, PyTree] ) -> dict[str, Tensor]: grads = torch.autograd.grad( self.rg_outputs, - self.flat_trainable_params, + self.flat_rg_params, grad_outputs, retain_graph=True, allow_unused=True, From 9268babc97e16c3e5223c403ebaca1cc8f3e9b21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 10 Oct 2025 15:50:40 +0200 Subject: [PATCH 14/15] Add extra explanation to test_batched_non_batched_equivalence_2 --- tests/unit/autogram/test_engine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 2582d513..b7f43de9 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -582,6 +582,9 @@ def test_batched_non_batched_equivalence_2(architecture: ShapedModule, batch_siz """ Same as test_batched_non_batched_equivalence but on real architectures, and thus only between batch_size=0 and batch_size=None. + + If for some architecture this test passes but the test_compute_gramian doesn't pass, it could be + that the get_used_params does not work for some module of the architecture. """ input_shapes = architecture.INPUT_SHAPES From 9497ed9eb329db22ec0e97002982dae44eaedb43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Fri, 10 Oct 2025 16:46:13 +0200 Subject: [PATCH 15/15] Remove useless cast to list Co-authored-by: Pierre Quinton --- src/torchjd/autogram/_module_hook_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index e21297e3..0ab25bf0 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -127,7 +127,7 @@ def __call__( return outputs rg_params, _ = get_used_params(module) - self.gramian_accumulator.track_parameter_paths(list(rg_params.values())) + self.gramian_accumulator.track_parameter_paths(rg_params.values()) # We only care about running the JacobianAccumulator node, so we need one of its child # edges (the edges of the original outputs of the model) as target. For memory