From 2c3dacb9e48f70b84af4039bf5a5887461aa8aa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 11 Oct 2025 02:44:05 +0200 Subject: [PATCH] refactor(autogram): Make engine hook modules recursively * Make the Engine hook modules recursively. If a direct parameter exists, hook the module and do not hook its children. If not, try to hook the child modules. * This change means that only the parentmost module with direct rg params gets hooked. Its used parameters are thus simply module.parameters(recurse=True) now. This is even the case in special cases where the parent uses the child parameters, so we don't need to have a special case for MHA anymore. * Remove _module_utils: it's now trivial to know with respect to which parameters to differentiate. * Update all usages to now create the Engine with Engine(model) instead of Engine(model.modules()). For partial JD, users have to be more careful, as they should sometimes specify several modules, but these modules should be "disjoint" (i.e. no specified module should be a child of another specified module). * This mostly makes a difference on WithFreeParam. Before, we had 2 hooks (one for the parent, parameterized with the parent's param - aka the free param, and one for the child module, parameterized with the child's params). Now we simply have 1 hook for the parent, parameterized with all the parameters (i.e. parent.parameters(recurse=True)). This is probably faster (because we don't have to do 2 extra forwards and 2 extra backwards for the child, but just 1 now), but maybe a bit more memory consuming (because we have to store the Jacobian wrt the child's params and wrt the parent's free param at the same time). This case is quite niche though, and I still see it as an improvement. * Change Engine to take *modules: nn.Module instead of Iterable[nn.Module] (more convenient for the new usage, because we only specify one model 99% of the time). Update the docstring accordingly. 
--- docs/source/examples/iwmtl.rst | 2 +- docs/source/examples/iwrm.rst | 2 +- docs/source/examples/partial_jd.rst | 2 +- src/torchjd/autogram/_engine.py | 28 ++++++------ src/torchjd/autogram/_module_hook_manager.py | 5 +-- src/torchjd/autogram/_module_utils.py | 47 -------------------- src/torchjd/autogram/_vjp.py | 12 +++-- tests/doc/test_autogram.py | 2 +- tests/doc/test_rst.py | 6 +-- tests/speed/autogram/grad_vs_jac_vs_gram.py | 2 +- tests/unit/autogram/test_engine.py | 30 ++++++------- 11 files changed, 48 insertions(+), 90 deletions(-) delete mode 100644 src/torchjd/autogram/_module_utils.py diff --git a/docs/source/examples/iwmtl.rst b/docs/source/examples/iwmtl.rst index ce075b06..4c1c7a4c 100644 --- a/docs/source/examples/iwmtl.rst +++ b/docs/source/examples/iwmtl.rst @@ -31,7 +31,7 @@ The following example shows how to do that. optimizer = SGD(params, lr=0.1) mse = MSELoss(reduction="none") weighting = Flattening(UPGradWeighting()) - engine = Engine(shared_module.modules(), batch_dim=0) + engine = Engine(shared_module, batch_dim=0) inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task diff --git a/docs/source/examples/iwrm.rst b/docs/source/examples/iwrm.rst index d25e0165..a326f582 100644 --- a/docs/source/examples/iwrm.rst +++ b/docs/source/examples/iwrm.rst @@ -129,7 +129,7 @@ batch of data. 
When minimizing per-instance losses (IWRM), we use either autojac params = model.parameters() optimizer = SGD(params, lr=0.1) weighting = UPGradWeighting() - engine = Engine(model.modules(), batch_dim=0) + engine = Engine(model, batch_dim=0) for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] diff --git a/docs/source/examples/partial_jd.rst b/docs/source/examples/partial_jd.rst index c8d9c781..c86a653a 100644 --- a/docs/source/examples/partial_jd.rst +++ b/docs/source/examples/partial_jd.rst @@ -33,7 +33,7 @@ first ``Linear`` layer, thereby reducing memory usage and computation time. # Create the autogram engine that will compute the Gramian of the # Jacobian with respect to the two last Linear layers' parameters. - engine = Engine(model[2:].modules(), batch_dim=0) + engine = Engine(model[2:], batch_dim=0) params = model.parameters() optimizer = SGD(params, lr=0.1) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index a60bd4c4..48888f25 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -1,4 +1,3 @@ -from collections.abc import Iterable from typing import cast import torch @@ -63,8 +62,9 @@ class Engine: backpropagate the losses. This is equivalent to doing a step of standard Jacobian descent using :func:`torchjd.autojac.backward`. - :param modules: A collection of modules whose direct (non-recursive) parameters will contribute - to the Gramian of the Jacobian. + :param modules: The modules whose parameters will contribute to the Gramian of the Jacobian. + Several modules can be provided, but it's important that none of them is a child module of + another of them. :param batch_dim: If the modules work with batches and process each batch element independently, then many intermediary Jacobians are sparse (block-diagonal), which allows for a substantial memory optimization by backpropagating a squashed Jacobian instead. 
This parameter indicates @@ -96,7 +96,7 @@ class Engine: weighting = UPGradWeighting() # Create the engine before the backward pass, and only once. - engine = Engine(model.modules(), batch_dim=0) + engine = Engine(model, batch_dim=0) for input, target in zip(inputs, targets): output = model(input).squeeze(dim=1) # shape: [16] @@ -178,7 +178,7 @@ class Engine: def __init__( self, - modules: Iterable[nn.Module], + *modules: nn.Module, batch_dim: int | None, ): self._gramian_accumulator = GramianAccumulator() @@ -188,16 +188,16 @@ def __init__( self._target_edges, self._gramian_accumulator, batch_dim is not None ) - self._hook_modules(modules) + for module in modules: + self._hook_module_recursively(module) - def _hook_modules(self, modules: Iterable[nn.Module]) -> None: - _modules = set(modules) - - # Add module forward hooks to compute jacobians - for module in _modules: - if any(p.requires_grad for p in module.parameters(recurse=False)): - self._check_module_is_compatible(module) - self._module_hook_manager.hook_module(module) + def _hook_module_recursively(self, module: nn.Module) -> None: + if any(p.requires_grad for p in module.parameters(recurse=False)): + self._check_module_is_compatible(module) + self._module_hook_manager.hook_module(module) + else: + for child in module.children(): + self._hook_module_recursively(child) def _check_module_is_compatible(self, module: nn.Module) -> None: if self._batch_dim is not None: diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 0ab25bf0..058dd188 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -9,7 +9,6 @@ from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator -from ._module_utils import get_used_params from ._vjp import VJP, AutogradVJP, FunctionalVJP # Note about import from protected _pytree module: @@ -126,8 +125,8 @@ def __call__( # require grad return 
outputs - rg_params, _ = get_used_params(module) - self.gramian_accumulator.track_parameter_paths(rg_params.values()) + rg_params = [p for p in module.parameters(recurse=True) if p.requires_grad] + self.gramian_accumulator.track_parameter_paths(rg_params) # We only care about running the JacobianAccumulator node, so we need one of its child # edges (the edges of the original outputs of the model) as target. For memory diff --git a/src/torchjd/autogram/_module_utils.py b/src/torchjd/autogram/_module_utils.py deleted file mode 100644 index c2e5c0df..00000000 --- a/src/torchjd/autogram/_module_utils.py +++ /dev/null @@ -1,47 +0,0 @@ -from torch import nn - - -def get_used_params(module: nn.Module) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]: - """ - Gets all parameters that a module uses. In reality, we return all direct params (which may - include some unused params) and all the indirectly used params that we know about (we may be - missing some in weird modules). - - Returns the tuple containing the params that require grad and the params that don't require - grad. - """ - - direct_rg_params, direct_frozen_params = _get_direct_params(module) - indirect_rg_params, indirect_frozen_params = _get_indirectly_used_params(module) - rg_params = direct_rg_params | indirect_rg_params - frozen_params = direct_frozen_params | indirect_frozen_params - - return rg_params, frozen_params - - -def _get_direct_params( - module: nn.Module, prefix: str = "" -) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]: - rg_params = dict[str, nn.Parameter]() - frozen_params = dict[str, nn.Parameter]() - - for name, param in module.named_parameters(recurse=False): - if param.requires_grad: - rg_params[prefix + name] = param - else: - frozen_params[prefix + name] = param - - return rg_params, frozen_params - - -def _get_indirectly_used_params( - module: nn.Module, -) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]: - # MHA uses its out_proj child params itself. 
Note that we also check that the MHA still has - # an out_proj attribute because it might change in the future (which will remove the - # necessity of custom code for MHA entirely). See the status of - # https://github.com/pytorch/pytorch/pull/126568 - if isinstance(module, nn.MultiheadAttention) and hasattr(module, "out_proj"): - return _get_direct_params(module.out_proj, prefix="out_proj.") - - return {}, {} diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index acf79c3b..86df495b 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -6,8 +6,6 @@ from torch.nn import Parameter from torch.utils._pytree import PyTree, tree_flatten, tree_map_only, tree_unflatten -from torchjd.autogram._module_utils import get_used_params - # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see # https://github.com/pytorch/pytorch/issues/65761, https://github.com/pytorch/pytorch/pull/137400). @@ -39,7 +37,15 @@ class ModuleVJP(VJP, ABC): def __init__(self, module: nn.Module): self.module = module - self.rg_params, self.frozen_params = get_used_params(module) + + self.rg_params = dict[str, Parameter]() + self.frozen_params = dict[str, Parameter]() + + for name, param in module.named_parameters(recurse=True): + if param.requires_grad: + self.rg_params[name] = param + else: + self.frozen_params[name] = param class FunctionalVJP(ModuleVJP): diff --git a/tests/doc/test_autogram.py b/tests/doc/test_autogram.py index 06bc0589..64ce48f7 100644 --- a/tests/doc/test_autogram.py +++ b/tests/doc/test_autogram.py @@ -20,7 +20,7 @@ def test_engine(): weighting = UPGradWeighting() # Create the engine before the backward pass, and only once. 
- engine = Engine(model.modules(), batch_dim=0) + engine = Engine(model, batch_dim=0) for input, target in zip(inputs, targets): output = model(input).squeeze(dim=1) # shape: [16] diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 871d1c1c..53d92ed2 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -94,7 +94,7 @@ def test_iwmtl(): optimizer = SGD(params, lr=0.1) mse = MSELoss(reduction="none") weighting = Flattening(UPGradWeighting()) - engine = Engine(shared_module.modules(), batch_dim=0) + engine = Engine(shared_module, batch_dim=0) inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task @@ -184,7 +184,7 @@ def test_autogram(): params = model.parameters() optimizer = SGD(params, lr=0.1) weighting = UPGradWeighting() - engine = Engine(model.modules(), batch_dim=0) + engine = Engine(model, batch_dim=0) for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] @@ -374,7 +374,7 @@ def test_partial_jd(): # Create the autogram engine that will compute the Gramian of the # Jacobian with respect to the two last Linear layers' parameters. 
- engine = Engine(model[2:].modules(), batch_dim=0) + engine = Engine(model[2:], batch_dim=0) params = model.parameters() optimizer = SGD(params, lr=0.1) diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index ca6b9ba7..5188ba86 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -121,7 +121,7 @@ def post_fn(): print(autojac_times) print() - engine = Engine(model.modules(), batch_dim=0) + engine = Engine(model, batch_dim=0) autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs)) print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}") print(autogram_times) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index b7f43de9..31f63559 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -150,7 +150,7 @@ def _assert_gramian_is_equivalent_to_autograd( torch.manual_seed(0) model_autogram = architecture().to(device=DEVICE) - engine = Engine(model_autogram.modules(), batch_dim=batch_dim) + engine = Engine(model_autogram, batch_dim=batch_dim) inputs = make_tensors(batch_size, input_shapes) targets = make_tensors(batch_size, output_shapes) @@ -261,7 +261,7 @@ def test_compute_gramian_various_output_shapes( torch.manual_seed(0) model_autogram = architecture().to(device=DEVICE) - engine = Engine(model_autogram.modules(), batch_dim=batch_dim) + engine = Engine(model_autogram, batch_dim=batch_dim) inputs = make_tensors(batch_size, input_shapes) targets = make_tensors(batch_size, output_shapes) @@ -324,7 +324,7 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True) torch.manual_seed(0) - engine = Engine(gramian_modules, batch_dim=batch_dim) + engine = Engine(*gramian_modules, batch_dim=batch_dim) output = 
model(input) losses = reduce_to_vector(loss_fn(output)) @@ -349,7 +349,7 @@ def test_iwrm_steps_with_autogram( model = architecture().to(device=DEVICE) - engine = Engine(model.modules(), batch_dim=batch_dim) + engine = Engine(model, batch_dim=batch_dim) optimizer = SGD(model.parameters(), lr=1e-7) for i in range(n_iter): @@ -388,7 +388,7 @@ def test_autograd_while_modules_are_hooked( autograd_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None} # Hook modules and optionally compute the Gramian - engine = Engine(model_autogram.modules(), batch_dim=batch_dim) + engine = Engine(model_autogram, batch_dim=batch_dim) if use_engine: torch.manual_seed(0) # Fix randomness for random models output = model_autogram(input) @@ -420,7 +420,7 @@ def test_incompatible_modules(architecture: type[nn.Module], batch_dim: int | No model = architecture().to(device=DEVICE) with pytest.raises(ValueError): - _ = Engine(model.modules(), batch_dim=batch_dim) + _ = Engine(model, batch_dim=batch_dim) def test_compute_gramian_manual(): @@ -434,7 +434,7 @@ def test_compute_gramian_manual(): torch.manual_seed(0) model = Linear(in_dims, out_dims).to(device=DEVICE) - engine = Engine(model.modules(), batch_dim=None) + engine = Engine(model, batch_dim=None) input = randn_(in_dims) output = model(input) @@ -480,8 +480,8 @@ def test_reshape_equivariance(shape: list[int], batch_dim: int | None): output_size = prod(shape[1:]) model = Linear(input_size, output_size).to(device=DEVICE) - engine1 = Engine([model], batch_dim=None) - engine2 = Engine([model], batch_dim=None) + engine1 = Engine(model, batch_dim=None) + engine2 = Engine(model, batch_dim=None) input = randn_([input_size]) output = model(input) @@ -520,8 +520,8 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: output_size = prod(shape[1:]) model = Linear(input_size, output_size).to(device=DEVICE) - engine1 = Engine([model], batch_dim=None) - engine2 = Engine([model], batch_dim=None) 
+ engine1 = Engine(model, batch_dim=None) + engine2 = Engine(model, batch_dim=None) input = randn_([input_size]) output = model(input).reshape(shape[1:]) @@ -563,8 +563,8 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): output_size = input_size model = Linear(input_size, output_size).to(device=DEVICE) - engine1 = Engine([model], batch_dim=batch_dim) - engine2 = Engine([model], batch_dim=None) + engine1 = Engine(model, batch_dim=batch_dim) + engine2 = Engine(model, batch_dim=None) input = randn_([batch_size, input_size]) output = model(input) @@ -595,8 +595,8 @@ def test_batched_non_batched_equivalence_2(architecture: ShapedModule, batch_siz torch.manual_seed(0) model_none = architecture().to(device=DEVICE) - engine_0 = Engine(model_0.modules(), batch_dim=0) - engine_none = Engine(model_none.modules(), batch_dim=None) + engine_0 = Engine(model_0, batch_dim=0) + engine_none = Engine(model_none, batch_dim=None) inputs = make_tensors(batch_size, input_shapes) targets = make_tensors(batch_size, output_shapes)