diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 90cd2582..a60bd4c4 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -113,27 +113,36 @@ class Engine: memory-efficient, and thus typically faster, to use the Gramian-based approach. .. warning:: - When providing a non-None ``batch_dim``, all provided modules must respect a few - conditions: + When providing a non-None ``batch_dim``, all provided modules must respect a few conditions: * They should treat the elements of the batch independently. Most common layers respect this, but for example `BatchNorm `_ does not (it computes some average and standard deviation over the elements of the batch). * Their inputs and outputs can be anything, but each input tensor and each output tensor - must be batched on its first dimension. `Transformers - `_ and `RNNs - `_ are thus not - supported yet. This is only an implementation issue, so it should be fixed soon (please - open an issue if you need extra focus on this). + must be batched on its first dimension. When available (e.g. in `Transformers + `_, + `MultiheadAttention + `_, + etc.), the ``batch_first`` parameter has to be set to ``True``. Also, this makes `RNNs + `_ not supported yet + because their hidden state is batched on dimension 1 even if ``batch_first`` is ``True``. * They should not perform in-place operations on tensors (for instance you should not use ``track_running_stats=True`` in normalization layers). * They should not have side effects during the forward pass (since their forward pass will be called twice, the side effects could be different from what's expected). * If they have some randomness during the forward pass, they should not have direct - trainable parameters. It is, however, perfectly fine for random modules to have child - modules that have trainable parameters, so if you have a random module with some direct - parameters, a simple fix is to wrap these parameters into a child module. + trainable parameters. For this reason, + `Transformers + `_, which use a + dropout function (rather than a `Dropout + `_ layer) in a + module with some trainable parameters, has to be used with + ``dropout=0.0``. Note that a `Dropout + `_ layers are + entirely supported and should be preferred. It is also perfectly fine for random modules + to have child modules that have trainable parameters, so if you have a random module with + some direct parameters, a simple fix is to wrap these parameters into a child module. If you're building your own architecture, respecting those criteria should be quite easy. However, if you're using an existing architecture, you may have to modify it to make it diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 9f4285ee..0ab25bf0 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -9,6 +9,7 @@ from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator +from ._module_utils import get_used_params from ._vjp import VJP, AutogradVJP, FunctionalVJP # Note about import from protected _pytree module: @@ -125,8 +126,8 @@ def __call__( # require grad return outputs - requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad] - self.gramian_accumulator.track_parameter_paths(requires_grad_params) + rg_params, _ = get_used_params(module) + self.gramian_accumulator.track_parameter_paths(rg_params.values()) # We only care about running the JacobianAccumulator node, so we need one of its child # edges (the edges of the original outputs of the model) as target. For memory diff --git a/src/torchjd/autogram/_module_utils.py b/src/torchjd/autogram/_module_utils.py new file mode 100644 index 00000000..c2e5c0df --- /dev/null +++ b/src/torchjd/autogram/_module_utils.py @@ -0,0 +1,47 @@ +from torch import nn + + +def get_used_params(module: nn.Module) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]: + """ + Gets all parameters that a module uses. In reality, we return all direct params (which may + include some unused params) and all the indirectly used params that we know about (we may be + missing some in weird modules). + + Returns the tuple containing the params that require grad and the params that don't require + grad. + """ + + direct_rg_params, direct_frozen_params = _get_direct_params(module) + indirect_rg_params, indirect_frozen_params = _get_indirectly_used_params(module) + rg_params = direct_rg_params | indirect_rg_params + frozen_params = direct_frozen_params | indirect_frozen_params + + return rg_params, frozen_params + + +def _get_direct_params( + module: nn.Module, prefix: str = "" +) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]: + rg_params = dict[str, nn.Parameter]() + frozen_params = dict[str, nn.Parameter]() + + for name, param in module.named_parameters(recurse=False): + if param.requires_grad: + rg_params[prefix + name] = param + else: + frozen_params[prefix + name] = param + + return rg_params, frozen_params + + +def _get_indirectly_used_params( + module: nn.Module, +) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]: + # MHA uses its out_proj child params itself. Note that we also check that the MHA still has + # an out_proj attribute because it might change in the future (which will remove the + # necessity of custom code for MHA entirely). See the status of + # https://github.com/pytorch/pytorch/pull/126568 + if isinstance(module, nn.MultiheadAttention) and hasattr(module, "out_proj"): + return _get_direct_params(module.out_proj, prefix="out_proj.") + + return {}, {} diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 3a94543c..acf79c3b 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -6,6 +6,8 @@ from torch.nn import Parameter from torch.utils._pytree import PyTree, tree_flatten, tree_map_only, tree_unflatten +from torchjd.autogram._module_utils import get_used_params + # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see # https://github.com/pytorch/pytorch/issues/65761, https://github.com/pytorch/pytorch/pull/137400). @@ -37,14 +39,7 @@ class ModuleVJP(VJP, ABC): def __init__(self, module: nn.Module): self.module = module - self.trainable_params = dict[str, Parameter]() - self.frozen_params = dict[str, Parameter]() - - for name, param in module.named_parameters(recurse=False): - if param.requires_grad: - self.trainable_params[name] = param - else: - self.frozen_params[name] = param + self.rg_params, self.frozen_params = get_used_params(module) class FunctionalVJP(ModuleVJP): @@ -78,9 +73,9 @@ def _call_on_one_instance( kwargs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), kwargs_j) grad_outputs_j_ = [x.unsqueeze(0) for x in grad_outputs_j] - def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]: + def functional_model_call(rg_params: dict[str, Parameter]) -> list[Tensor]: all_state = { - **trainable_params, + **rg_params, **dict(self.module.named_buffers()), **self.frozen_params, } @@ -89,7 +84,7 @@ def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor rg_outputs = [t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad] return rg_outputs - vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1] + vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1] # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the # functional has a single primal which is dict(module.named_parameters()). We therefore take @@ -109,14 +104,14 @@ def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]): super().__init__(module) self.rg_outputs = rg_outputs - self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params) + self.flat_rg_params, self.param_spec = tree_flatten(self.rg_params) def __call__( self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...], __: dict[str, PyTree] ) -> dict[str, Tensor]: grads = torch.autograd.grad( self.rg_outputs, - self.flat_trainable_params, + self.flat_rg_params, grad_outputs, retain_graph=True, allow_unused=True, diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index a41b2905..ca6b9ba7 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -13,6 +13,7 @@ NoFreeParam, ShapedModule, SqueezeNet, + WithTransformerLarge, ) from utils.forward_backwards import ( autograd_forward_backward, @@ -27,6 +28,7 @@ from torchjd.autogram import Engine PARAMETRIZATIONS = [ + (WithTransformerLarge, 8), (FreeParam, 64), (NoFreeParam, 64), (Cifar10Model, 64), diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index d4cac1ec..b7f43de9 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -57,10 +57,13 @@ WithModuleWithStringArg, WithModuleWithStringKwarg, WithModuleWithStringOutput, + WithMultiHeadAttention, WithNoTensorOutput, WithRNN, WithSideEffect, WithSomeFrozenModule, + WithTransformer, + WithTransformerLarge, ) from utils.dict_assertions import assert_tensor_dicts_are_close from utils.forward_backwards import ( @@ -118,6 +121,8 @@ (WithModuleWithStringOutput, 32), (WithModuleWithStringKwarg, 32), (WithModuleWithHybridPyTreeKwarg, 32), + (WithMultiHeadAttention, 32), + param(WithTransformer, 32, marks=mark.filterwarnings("ignore:There is a performance drop")), (FreeParam, 32), (NoFreeParam, 32), param(Cifar10Model, 16, marks=mark.slow), @@ -126,6 +131,11 @@ param(GroupNormMobileNetV3Small, 3, marks=mark.slow), param(SqueezeNet, 8, marks=mark.slow), param(InstanceNormMobileNetV2, 2, marks=mark.slow), + param( + WithTransformerLarge, + 8, + marks=[mark.slow, mark.filterwarnings("ignore:There is a performance drop")], + ), ] @@ -565,3 +575,42 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): gramian2 = engine2.compute_gramian(output) assert_close(gramian1, gramian2) + + +@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) +def test_batched_non_batched_equivalence_2(architecture: ShapedModule, batch_size: int): + """ + Same as test_batched_non_batched_equivalence but on real architectures, and thus only between + batch_size=0 and batch_size=None. + + If for some architecture this test passes but the test_compute_gramian doesn't pass, it could be + that the get_used_params does not work for some module of the architecture. + """ + + input_shapes = architecture.INPUT_SHAPES + output_shapes = architecture.OUTPUT_SHAPES + + torch.manual_seed(0) + model_0 = architecture().to(device=DEVICE) + torch.manual_seed(0) + model_none = architecture().to(device=DEVICE) + + engine_0 = Engine(model_0.modules(), batch_dim=0) + engine_none = Engine(model_none.modules(), batch_dim=None) + + inputs = make_tensors(batch_size, input_shapes) + targets = make_tensors(batch_size, output_shapes) + loss_fn = make_mse_loss_fn(targets) + + torch.random.manual_seed(0) # Fix randomness for random models + output = model_0(inputs) + losses_0 = reduce_to_vector(loss_fn(output)) + + torch.random.manual_seed(0) # Fix randomness for random models + output = model_none(inputs) + losses_none = reduce_to_vector(loss_fn(output)) + + gramian_0 = engine_0.compute_gramian(losses_0) + gramian_none = engine_none.compute_gramian(losses_none) + + assert_close(gramian_0, gramian_none, rtol=1e-4, atol=1e-5) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index c537c2d7..d5760d03 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -931,6 +931,70 @@ def forward(self, input: Tensor) -> Tensor: return output +class WithMultiHeadAttention(ShapedModule): + """Module containing a MultiheadAttention layer.""" + + INPUT_SHAPES = ((20, 8), (10, 9), (10, 11)) + OUTPUT_SHAPES = (20, 8) + + def __init__(self): + super().__init__() + self.mha = nn.MultiheadAttention( + embed_dim=8, + num_heads=2, + dropout=0.0, + batch_first=True, + kdim=9, + vdim=11, + ) + + def forward(self, input: tuple[Tensor, Tensor, Tensor]) -> Tensor: + query, key, value = input + attn_output, _ = self.mha(query, key, value) + return attn_output + + +class WithTransformer(ShapedModule): + """Module containing a Transformer.""" + + INPUT_SHAPES = ((10, 8), (20, 8)) + OUTPUT_SHAPES = (20, 8) + + def __init__(self): + super().__init__() + self.transformer = nn.Transformer( + d_model=8, + nhead=2, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=32, + batch_first=True, + dropout=0.0, + ) + + def forward(self, input: tuple[Tensor, Tensor]) -> Tensor: + src, tgt = input + return self.transformer(src, tgt) + + +class WithTransformerLarge(ShapedModule): + """Module containing a large Transformer.""" + + INPUT_SHAPES = ((10, 512), (20, 512)) + OUTPUT_SHAPES = (20, 512) + + def __init__(self): + super().__init__() + self.transformer = nn.Transformer( + batch_first=True, + dropout=0.0, + ) + + def forward(self, input: tuple[Tensor, Tensor]) -> Tensor: + src, tgt = input + return self.transformer(src, tgt) + + class FreeParam(ShapedModule): """ Model that contains a free (i.e. not contained in a submodule) parameter, that is used at the