From bcfaf042b2bf1db56180e88d3ab1c5a9317d1a02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 15:35:04 +0200 Subject: [PATCH 1/3] fix(autogram): Work with rg outputs only --- src/torchjd/autogram/_module_hook_manager.py | 33 +++++++++++++------- src/torchjd/autogram/_vjp.py | 27 ++++++---------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 09bd7ca8..6a86372a 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -101,16 +101,23 @@ def __init__( self.gramian_accumulator = gramian_accumulator self.has_batch_dim = has_batch_dim - def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree: + def __call__(self, module: nn.Module, args: PyTree, outputs: PyTree) -> PyTree: if self.gramian_accumulation_phase: - return output + return outputs - flat_outputs, output_spec = tree_flatten(output) + flat_outputs, output_spec = tree_flatten(outputs) - if not any(isinstance(t, Tensor) and t.requires_grad for t in flat_outputs): + rg_outputs = list[Tensor]() + rg_output_indices = list[int]() + for idx, output in enumerate(flat_outputs): + if isinstance(output, Tensor) and output.requires_grad: + rg_outputs.append(output) + rg_output_indices.append(idx) + + if len(rg_outputs) == 0: # This can happen only if a module has a trainable param but outputs no tensor that # require grad - return output + return outputs requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad] self.gramian_accumulator.track_parameter_paths(requires_grad_params) @@ -118,23 +125,25 @@ def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree: # We only care about running the JacobianAccumulator node, so we need one of its child # edges (the edges of the original outputs of the model) as target. For memory # efficiency, we select the smallest one (that requires grad). - inf = float("inf") - preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs]) + preference = torch.tensor([t.numel() for t in rg_outputs]) index = cast(int, preference.argmin().item()) - self.target_edges.register(get_gradient_edge(flat_outputs[index])) + self.target_edges.register(get_gradient_edge(rg_outputs[index])) - vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs) + vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, rg_outputs) - autograd_fn_outputs = JacobianAccumulator.apply( + autograd_fn_rg_outputs = JacobianAccumulator.apply( self.gramian_accumulation_phase, vjp, args, self.gramian_accumulator, module, - *flat_outputs, + *rg_outputs, ) - return tree_unflatten(autograd_fn_outputs, output_spec) + for idx, output in zip(rg_output_indices, autograd_fn_rg_outputs): + flat_outputs[idx] = output + + return tree_unflatten(flat_outputs, output_spec) class JacobianAccumulator(torch.autograd.Function): diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 1dad99c7..4e6c4985 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -70,16 +70,18 @@ def _call_on_one_instance( args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j) grad_outputs_j_ = [x.unsqueeze(0) for x in grad_outputs_j] - def flat_functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]: + def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]: all_state = { **trainable_params, **dict(self.module.named_buffers()), **self.frozen_params, } output = torch.func.functional_call(self.module, all_state, args_j) - return tree_flatten(output)[0] + flat_outputs = tree_flatten(output)[0] + rg_outputs = [t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad] + return rg_outputs - vjp_func = torch.func.vjp(flat_functional_model_call, self.trainable_params)[1] + vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1] # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the # functional has a single primal which is dict(module.named_parameters()). We therefore take @@ -95,28 +97,17 @@ class AutogradVJP(ModuleVJP): forward pass. """ - def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): + def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]): super().__init__(module) - self.outputs_that_require_grad = list[Tensor]() - self.mask = list[bool]() - for output in outputs: - requires_grad = output.requires_grad - if requires_grad: - self.outputs_that_require_grad.append(output) - self.mask.append(requires_grad) - + self.rg_outputs = rg_outputs self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params) def __call__(self, grad_outputs: tuple[Tensor, ...], _: PyTree) -> dict[str, Tensor]: - - # Only keep the grad_outputs corresponding to outputs that require grad. - grad_outputs_ = [grad_output for grad_output, rg in zip(grad_outputs, self.mask) if rg] - grads = torch.autograd.grad( - self.outputs_that_require_grad, + self.rg_outputs, self.flat_trainable_params, - grad_outputs_, + grad_outputs, retain_graph=True, allow_unused=True, materialize_grads=True, From 629dcae3859ab2543d3a178e445c3d96a7633858 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 15:35:17 +0200 Subject: [PATCH 2/3] Stop marking WithModuleWithStringOutput as xfail --- tests/unit/autogram/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index bd3922d5..74084f3a 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -109,6 +109,7 @@ (Ndim3Output, 32), (Ndim4Output, 32), (WithDropout, 32), + (WithModuleWithStringOutput, 32), (FreeParam, 32), (NoFreeParam, 32), param(Cifar10Model, 16, marks=mark.slow), @@ -167,7 +168,6 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc WithModuleTrackingRunningStats, param(WithRNN, marks=mark.xfail_if_cuda), WithModuleWithStringArg, - param(WithModuleWithStringOutput, marks=mark.xfail), ], ) @mark.parametrize("batch_size", [1, 3, 32]) From 286c7946f845ce73399627863379eab9a052d73e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 1 Oct 2025 15:39:58 +0200 Subject: [PATCH 3/3] Rename --- src/torchjd/autogram/_module_hook_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 6a86372a..3cc4cf6b 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -164,9 +164,9 @@ def forward( args: PyTree, gramian_accumulator: GramianAccumulator, module: nn.Module, - *xs: Tensor, + *rg_tensors: Tensor, ) -> tuple[Tensor, ...]: - return tuple(x.detach() for x in xs) + return tuple(t.detach() for t in rg_tensors) # For Python version > 3.10, the type of `inputs` should become # tuple[BoolRef, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]]