diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py
index 5aabcdc9..96610e38 100644
--- a/src/torchjd/autogram/_engine.py
+++ b/src/torchjd/autogram/_engine.py
@@ -169,11 +169,20 @@ def compute_gramian(self, output: Tensor) -> Tensor:
         """
 
         reshaped_output = output.reshape([-1])
-        return self._compute_square_gramian(reshaped_output)
 
-    def _compute_square_gramian(self, output: Tensor) -> Tensor:
         self._module_hook_manager.gramian_accumulation_phase = True
 
+        try:
+            square_gramian = self._compute_square_gramian(reshaped_output)
+        finally:
+            # Reset everything that has a state, even if the previous call raised an exception
+            self._module_hook_manager.gramian_accumulation_phase = False
+            self._gramian_accumulator.reset()
+            self._target_edges.reset()
+
+        return square_gramian
+
+    def _compute_square_gramian(self, output: Tensor) -> Tensor:
         leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)}))
 
         def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
@@ -190,9 +199,4 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
         # have failed. So gramian is necessarily a valid Tensor here.
         gramian = cast(Tensor, self._gramian_accumulator.gramian)
 
-        # Reset everything that has a state
-        self._module_hook_manager.gramian_accumulation_phase = False
-        self._gramian_accumulator.reset()
-        self._target_edges.reset()
-
         return gramian