diff --git a/docs/source/examples/iwmtl.rst b/docs/source/examples/iwmtl.rst index ce075b06..4c1c7a4c 100644 --- a/docs/source/examples/iwmtl.rst +++ b/docs/source/examples/iwmtl.rst @@ -31,7 +31,7 @@ The following example shows how to do that. optimizer = SGD(params, lr=0.1) mse = MSELoss(reduction="none") weighting = Flattening(UPGradWeighting()) - engine = Engine(shared_module.modules(), batch_dim=0) + engine = Engine(shared_module, batch_dim=0) inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task diff --git a/docs/source/examples/iwrm.rst b/docs/source/examples/iwrm.rst index d25e0165..a326f582 100644 --- a/docs/source/examples/iwrm.rst +++ b/docs/source/examples/iwrm.rst @@ -129,7 +129,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac params = model.parameters() optimizer = SGD(params, lr=0.1) weighting = UPGradWeighting() - engine = Engine(model.modules(), batch_dim=0) + engine = Engine(model, batch_dim=0) for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] diff --git a/docs/source/examples/partial_jd.rst b/docs/source/examples/partial_jd.rst index c8d9c781..c86a653a 100644 --- a/docs/source/examples/partial_jd.rst +++ b/docs/source/examples/partial_jd.rst @@ -33,7 +33,7 @@ first ``Linear`` layer, thereby reducing memory usage and computation time. # Create the autogram engine that will compute the Gramian of the # Jacobian with respect to the two last Linear layers' parameters. - engine = Engine(model[2:].modules(), batch_dim=0) + engine = Engine(model[2:], batch_dim=0) params = model.parameters() optimizer = SGD(params, lr=0.1) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 9cab68f1..48888562 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -1,4 +1,3 @@ -from collections.abc import Iterable from typing import cast import torch @@ -58,8 +57,9 @@ class Engine: backpropagate the losses. This is equivalent to doing a step of standard Jacobian descent using :func:`torchjd.autojac.backward`. - :param modules: A collection of modules whose direct (non-recursive) parameters will contribute - to the Gramian of the Jacobian. + :param modules: The modules whose parameters will contribute to the Gramian of the Jacobian. + Several modules can be provided, but it's important that none of them is a child module of + another of them. :param batch_dim: If the modules work with batches and process each batch element independently, then many intermediary Jacobians are sparse (block-diagonal), which allows for a substantial memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates @@ -91,7 +91,7 @@ class Engine: weighting = UPGradWeighting() # Create the engine before the backward pass, and only once. - engine = Engine(model.modules(), batch_dim=0) + engine = Engine(model, batch_dim=0) for input, target in zip(inputs, targets): output = model(input).squeeze(dim=1) # shape: [16] @@ -173,7 +173,7 @@ class Engine: def __init__( self, - modules: Iterable[nn.Module], + *modules: nn.Module, batch_dim: int | None, ): self._gramian_accumulator = GramianAccumulator() @@ -183,16 +183,16 @@ def __init__( self._target_edges, self._gramian_accumulator, batch_dim is not None ) - self._hook_modules(modules) + for module in modules: + self._hook_module_recursively(module) - def _hook_modules(self, modules: Iterable[nn.Module]) -> None: - _modules = set(modules) - - # Add module forward hooks to compute jacobians - for module in _modules: - if any(p.requires_grad for p in module.parameters(recurse=False)): - self._check_module_is_compatible(module) - self._module_hook_manager.hook_module(module) + def _hook_module_recursively(self, module: nn.Module) -> None: + if any(p.requires_grad for p in module.parameters(recurse=False)): + self._check_module_is_compatible(module) + self._module_hook_manager.hook_module(module) + else: + for child in module.children(): + self._hook_module_recursively(child) def _check_module_is_compatible(self, module: nn.Module) -> None: if self._batch_dim is not None: diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 0ab25bf0..058dd188 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -9,7 +9,6 @@ from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator -from ._module_utils import get_used_params from ._vjp import VJP, AutogradVJP, FunctionalVJP # Note about import from protected _pytree module: @@ -126,8 +125,8 @@ def __call__( # require grad return outputs - rg_params, _ = get_used_params(module) - self.gramian_accumulator.track_parameter_paths(rg_params.values()) + rg_params = [p for p in module.parameters(recurse=True) if p.requires_grad] + self.gramian_accumulator.track_parameter_paths(rg_params) # We only care about running the JacobianAccumulator node, so we need one of its child # edges (the edges of the original outputs of the model) as target. For memory diff --git a/src/torchjd/autogram/_module_utils.py b/src/torchjd/autogram/_module_utils.py deleted file mode 100644 index c2e5c0df..00000000 --- a/src/torchjd/autogram/_module_utils.py +++ /dev/null @@ -1,47 +0,0 @@ -from torch import nn - - -def get_used_params(module: nn.Module) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]: - """ - Gets all parameters that a module uses. In reality, we return all direct params (which may - include some unused params) and all the indirectly used params that we know about (we may be - missing some in weird modules). - - Returns the tuple containing the params that require grad and the params that don't require - grad. - """ - - direct_rg_params, direct_frozen_params = _get_direct_params(module) - indirect_rg_params, indirect_frozen_params = _get_indirectly_used_params(module) - rg_params = direct_rg_params | indirect_rg_params - frozen_params = direct_frozen_params | indirect_frozen_params - - return rg_params, frozen_params - - -def _get_direct_params( - module: nn.Module, prefix: str = "" -) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]: - rg_params = dict[str, nn.Parameter]() - frozen_params = dict[str, nn.Parameter]() - - for name, param in module.named_parameters(recurse=False): - if param.requires_grad: - rg_params[prefix + name] = param - else: - frozen_params[prefix + name] = param - - return rg_params, frozen_params - - -def _get_indirectly_used_params( - module: nn.Module, -) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]: - # MHA uses its out_proj child params itself. Note that we also check that the MHA still has - # an out_proj attribute because it might change in the future (which will remove the - # necessity of custom code for MHA entirely). See the status of - # https://github.com/pytorch/pytorch/pull/126568 - if isinstance(module, nn.MultiheadAttention) and hasattr(module, "out_proj"): - return _get_direct_params(module.out_proj, prefix="out_proj.") - - return {}, {} diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index acf79c3b..86df495b 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -6,8 +6,6 @@ from torch.nn import Parameter from torch.utils._pytree import PyTree, tree_flatten, tree_map_only, tree_unflatten -from torchjd.autogram._module_utils import get_used_params - # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see # https://github.com/pytorch/pytorch/issues/65761, https://github.com/pytorch/pytorch/pull/137400). @@ -39,7 +37,15 @@ class ModuleVJP(VJP, ABC): def __init__(self, module: nn.Module): self.module = module - self.rg_params, self.frozen_params = get_used_params(module) + + self.rg_params = dict[str, Parameter]() + self.frozen_params = dict[str, Parameter]() + + for name, param in module.named_parameters(recurse=True): + if param.requires_grad: + self.rg_params[name] = param + else: + self.frozen_params[name] = param class FunctionalVJP(ModuleVJP): diff --git a/tests/doc/test_autogram.py b/tests/doc/test_autogram.py index 06bc0589..64ce48f7 100644 --- a/tests/doc/test_autogram.py +++ b/tests/doc/test_autogram.py @@ -20,7 +20,7 @@ def test_engine(): weighting = UPGradWeighting() # Create the engine before the backward pass, and only once. - engine = Engine(model.modules(), batch_dim=0) + engine = Engine(model, batch_dim=0) for input, target in zip(inputs, targets): output = model(input).squeeze(dim=1) # shape: [16] diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 871d1c1c..53d92ed2 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -94,7 +94,7 @@ def test_iwmtl(): optimizer = SGD(params, lr=0.1) mse = MSELoss(reduction="none") weighting = Flattening(UPGradWeighting()) - engine = Engine(shared_module.modules(), batch_dim=0) + engine = Engine(shared_module, batch_dim=0) inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task @@ -184,7 +184,7 @@ def test_autogram(): params = model.parameters() optimizer = SGD(params, lr=0.1) weighting = UPGradWeighting() - engine = Engine(model.modules(), batch_dim=0) + engine = Engine(model, batch_dim=0) for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] @@ -374,7 +374,7 @@ def test_partial_jd(): # Create the autogram engine that will compute the Gramian of the # Jacobian with respect to the two last Linear layers' parameters. - engine = Engine(model[2:].modules(), batch_dim=0) + engine = Engine(model[2:], batch_dim=0) params = model.parameters() optimizer = SGD(params, lr=0.1) diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index ca6b9ba7..5188ba86 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -121,7 +121,7 @@ def post_fn(): print(autojac_times) print() - engine = Engine(model.modules(), batch_dim=0) + engine = Engine(model, batch_dim=0) autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs)) print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}") print(autogram_times) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index b7f43de9..31f63559 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -150,7 +150,7 @@ def _assert_gramian_is_equivalent_to_autograd( torch.manual_seed(0) model_autogram = architecture().to(device=DEVICE) - engine = Engine(model_autogram.modules(), batch_dim=batch_dim) + engine = Engine(model_autogram, batch_dim=batch_dim) inputs = make_tensors(batch_size, input_shapes) targets = make_tensors(batch_size, output_shapes) @@ -261,7 +261,7 @@ def test_compute_gramian_various_output_shapes( torch.manual_seed(0) model_autogram = architecture().to(device=DEVICE) - engine = Engine(model_autogram.modules(), batch_dim=batch_dim) + engine = Engine(model_autogram, batch_dim=batch_dim) inputs = make_tensors(batch_size, input_shapes) targets = make_tensors(batch_size, output_shapes) @@ -324,7 +324,7 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True) torch.manual_seed(0) - engine = Engine(gramian_modules, batch_dim=batch_dim) + engine = Engine(*gramian_modules, batch_dim=batch_dim) output = model(input) losses = reduce_to_vector(loss_fn(output)) @@ -349,7 +349,7 @@ def test_iwrm_steps_with_autogram( model = architecture().to(device=DEVICE) - engine = Engine(model.modules(), batch_dim=batch_dim) + engine = Engine(model, batch_dim=batch_dim) optimizer = SGD(model.parameters(), lr=1e-7) for i in range(n_iter): @@ -388,7 +388,7 @@ def test_autograd_while_modules_are_hooked( autograd_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None} # Hook modules and optionally compute the Gramian - engine = Engine(model_autogram.modules(), batch_dim=batch_dim) + engine = Engine(model_autogram, batch_dim=batch_dim) if use_engine: torch.manual_seed(0) # Fix randomness for random models output = model_autogram(input) @@ -420,7 +420,7 @@ def test_incompatible_modules(architecture: type[nn.Module], batch_dim: int | No model = architecture().to(device=DEVICE) with pytest.raises(ValueError): - _ = Engine(model.modules(), batch_dim=batch_dim) + _ = Engine(model, batch_dim=batch_dim) def test_compute_gramian_manual(): @@ -434,7 +434,7 @@ def test_compute_gramian_manual(): torch.manual_seed(0) model = Linear(in_dims, out_dims).to(device=DEVICE) - engine = Engine(model.modules(), batch_dim=None) + engine = Engine(model, batch_dim=None) input = randn_(in_dims) output = model(input) @@ -480,8 +480,8 @@ def test_reshape_equivariance(shape: list[int], batch_dim: int | None): output_size = prod(shape[1:]) model = Linear(input_size, output_size).to(device=DEVICE) - engine1 = Engine([model], batch_dim=None) - engine2 = Engine([model], batch_dim=None) + engine1 = Engine(model, batch_dim=None) + engine2 = Engine(model, batch_dim=None) input = randn_([input_size]) output = model(input) @@ -520,8 +520,8 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: output_size = prod(shape[1:]) model = Linear(input_size, output_size).to(device=DEVICE) - engine1 = Engine([model], batch_dim=None) - engine2 = Engine([model], batch_dim=None) + engine1 = Engine(model, batch_dim=None) + engine2 = Engine(model, batch_dim=None) input = randn_([input_size]) output = model(input).reshape(shape[1:]) @@ -563,8 +563,8 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): output_size = input_size model = Linear(input_size, output_size).to(device=DEVICE) - engine1 = Engine([model], batch_dim=batch_dim) - engine2 = Engine([model], batch_dim=None) + engine1 = Engine(model, batch_dim=batch_dim) + engine2 = Engine(model, batch_dim=None) input = randn_([batch_size, input_size]) output = model(input) @@ -595,8 +595,8 @@ def test_batched_non_batched_equivalence_2(architecture: ShapedModule, batch_siz torch.manual_seed(0) model_none = architecture().to(device=DEVICE) - engine_0 = Engine(model_0.modules(), batch_dim=0) - engine_none = Engine(model_none.modules(), batch_dim=None) + engine_0 = Engine(model_0, batch_dim=0) + engine_none = Engine(model_none, batch_dim=None) inputs = make_tensors(batch_size, input_shapes) targets = make_tensors(batch_size, output_shapes)