diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index 7baabdb7..b0451e16 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -11,6 +11,7 @@ from torchjd.aggregation import Aggregator, Weighting from torchjd.autogram import Engine from torchjd.autojac import backward +from torchjd.autojac._jac_to_grad import jac_to_grad def autograd_forward_backward( @@ -29,7 +30,8 @@ def autojac_forward_backward( aggregator: Aggregator, ) -> None: losses = forward_pass(model, inputs, loss_fn, reduce_to_vector) - backward(losses, aggregator=aggregator) + backward(losses) + jac_to_grad(model.parameters(), aggregator) def autograd_gramian_forward_backward(