From 99b4260f5a5a29c325703c90a72b8eaed5f54839 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 8 Jan 2026 01:25:03 +0100 Subject: [PATCH 01/35] feat: Use .jac fields --- docs/source/examples/amp.rst | 8 +- docs/source/examples/basic_usage.rst | 21 +-- docs/source/examples/iwmtl.rst | 4 +- docs/source/examples/iwrm.rst | 17 +- .../source/examples/lightning_integration.rst | 8 +- docs/source/examples/monitoring.rst | 8 +- docs/source/examples/mtl.rst | 8 +- docs/source/examples/partial_jd.rst | 2 +- docs/source/examples/rnn.rst | 8 +- src/torchjd/autojac/_backward.py | 42 ++--- src/torchjd/autojac/_mtl_backward.py | 41 +++-- src/torchjd/autojac/_transform/__init__.py | 7 +- src/torchjd/autojac/_transform/_accumulate.py | 37 ++--- src/torchjd/autojac/_transform/_aggregate.py | 151 ------------------ src/torchjd/utils/__init__.py | 3 + src/torchjd/utils/_accumulation.py | 43 +++++ src/torchjd/utils/_jac_to_grad.py | 75 +++++++++ tests/doc/test_backward.py | 5 +- tests/doc/test_rst.py | 51 +++--- 19 files changed, 258 insertions(+), 281 deletions(-) delete mode 100644 src/torchjd/autojac/_transform/_aggregate.py create mode 100644 src/torchjd/utils/__init__.py create mode 100644 src/torchjd/utils/_accumulation.py create mode 100644 src/torchjd/utils/_jac_to_grad.py diff --git a/docs/source/examples/amp.rst b/docs/source/examples/amp.rst index 0e719bfd..df6b78ec 100644 --- a/docs/source/examples/amp.rst +++ b/docs/source/examples/amp.rst @@ -12,7 +12,7 @@ case, the losses) should preferably be scaled with a `GradScaler following example shows the resulting code for a multi-task learning use-case. .. code-block:: python - :emphasize-lines: 2, 17, 27, 34, 36-38 + :emphasize-lines: 2, 18, 28, 35-36, 38-39 import torch from torch.amp import GradScaler @@ -21,6 +21,7 @@ following example shows the resulting code for a multi-task learning use-case. from torchjd.aggregation import UPGrad from torchjd.autojac import mtl_backward + from torchjd.utils import jac_to_grad shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) task1_module = Linear(3, 1) @@ -48,10 +49,11 @@ following example shows the resulting code for a multi-task learning use-case. loss2 = loss_fn(output2, target2) scaled_losses = scaler.scale([loss1, loss2]) - optimizer.zero_grad() - mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator) + mtl_backward(losses=scaled_losses, features=features) + jac_to_grad(shared_module.parameters(), aggregator) scaler.step(optimizer) scaler.update() + optimizer.zero_grad() .. hint:: Within the ``torch.autocast`` context, some operations may be done in ``float16`` type. For diff --git a/docs/source/examples/basic_usage.rst b/docs/source/examples/basic_usage.rst index 8fa4320b..0920ef1c 100644 --- a/docs/source/examples/basic_usage.rst +++ b/docs/source/examples/basic_usage.rst @@ -20,6 +20,7 @@ Import several classes from ``torch`` and ``torchjd``: from torchjd import autojac from torchjd.aggregation import UPGrad + from torchjd.utils import jac_to_grad Define the model and the optimizer, as usual: @@ -59,20 +60,16 @@ We can now compute the losses associated to each element of the batch. The last steps are similar to gradient descent-based optimization, but using the two losses. -Reset the ``.grad`` field of each model parameter: - -.. code-block:: python - - optimizer.zero_grad() - Perform the Jacobian descent backward pass: .. code-block:: python - autojac.backward([loss1, loss2], aggregator) + autojac.backward([loss1, loss2]) + jac_to_grad(model.parameters(), aggregator) -This will populate the ``.grad`` field of each model parameter with the corresponding aggregated -Jacobian matrix. +The first function will populate the ``.jac`` field of each model parameter with the corresponding +Jacobian, and the second one will aggregate these Jacobians and store the result in the ``.grad`` +field of the parameters. It also resets the ``.jac`` fields to ``None`` to save some memory. Update each parameter based on its ``.grad`` field, using the ``optimizer``: @@ -81,3 +78,9 @@ Update each parameter based on its ``.grad`` field, using the ``optimizer``: optimizer.step() The model's parameters have been updated! + +As usual, you should now reset the ``.grad`` field of each model parameter: + +.. code-block:: python + + optimizer.zero_grad() diff --git a/docs/source/examples/iwmtl.rst b/docs/source/examples/iwmtl.rst index 4c1c7a4c..8b2410f7 100644 --- a/docs/source/examples/iwmtl.rst +++ b/docs/source/examples/iwmtl.rst @@ -10,7 +10,7 @@ this Gramian to reweight the gradients and resolve conflict entirely. The following example shows how to do that. .. code-block:: python - :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 41-42 + :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 40-41 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -51,10 +51,10 @@ The following example shows how to do that. # Obtain the weights that lead to no conflict between reweighted gradients weights = weighting(gramian) # shape: [16, 2] - optimizer.zero_grad() # Do the standard backward pass, but weighted using the obtained weights losses.backward(weights) optimizer.step() + optimizer.zero_grad() .. note:: In this example, the tensor of losses is a matrix rather than a vector. The gramian is thus a diff --git a/docs/source/examples/iwrm.rst b/docs/source/examples/iwrm.rst index a326f582..3a3cc5d5 100644 --- a/docs/source/examples/iwrm.rst +++ b/docs/source/examples/iwrm.rst @@ -50,6 +50,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac + X = torch.randn(8, 16, 10) Y = torch.randn(8, 16) @@ -64,11 +65,11 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] loss = loss_fn(y_hat, y) # shape: [] (scalar) - optimizer.zero_grad() loss.backward() optimizer.step() + optimizer.zero_grad() In this baseline example, the update may negatively affect the loss of some elements of the batch. @@ -76,7 +77,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac .. tab-item:: autojac .. code-block:: python - :emphasize-lines: 5-6, 12, 16, 21, 23 + :emphasize-lines: 5-7, 13, 17, 22-24 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -84,6 +85,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac from torchjd.aggregation import UPGrad from torchjd.autojac import backward + from torchjd.utils import jac_to_grad X = torch.randn(8, 16, 10) Y = torch.randn(8, 16) @@ -99,11 +101,11 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] - optimizer.zero_grad() - backward(losses, aggregator) - + backward(losses) + jac_to_grad(model.parameters(), aggregator) optimizer.step() + optimizer.zero_grad() Here, we compute the Jacobian of the per-sample losses with respect to the model parameters and use it to update the model such that no loss from the batch is (locally) increased. @@ -111,7 +113,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac .. tab-item:: autogram (recommended) .. code-block:: python - :emphasize-lines: 5-6, 12, 16-17, 21, 23-25 + :emphasize-lines: 5-6, 13, 17-18, 22-25 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -120,6 +122,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac from torchjd.aggregation import UPGradWeighting from torchjd.autogram import Engine + X = torch.randn(8, 16, 10) Y = torch.randn(8, 16) @@ -134,11 +137,11 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] - optimizer.zero_grad() gramian = engine.compute_gramian(losses) # shape: [16, 16] weights = weighting(gramian) # shape: [16] losses.backward(weights) optimizer.step() + optimizer.zero_grad() Here, the per-sample gradients are never fully stored in memory, leading to large improvements in memory usage and speed compared to autojac, in most practical cases. The diff --git a/docs/source/examples/lightning_integration.rst b/docs/source/examples/lightning_integration.rst index 203f63b5..a010361c 100644 --- a/docs/source/examples/lightning_integration.rst +++ b/docs/source/examples/lightning_integration.rst @@ -11,7 +11,7 @@ The following code example demonstrates a basic multi-task learning setup using <../docs/autojac/mtl_backward>` at each training iteration. .. code-block:: python - :emphasize-lines: 9-10, 18, 32 + :emphasize-lines: 9-11, 19, 32-33 import torch from lightning import LightningModule, Trainer @@ -23,6 +23,7 @@ The following code example demonstrates a basic multi-task learning setup using from torchjd.aggregation import UPGrad from torchjd.autojac import mtl_backward + from torchjd.utils import jac_to_grad class Model(LightningModule): def __init__(self): @@ -43,9 +44,10 @@ The following code example demonstrates a basic multi-task learning setup using loss2 = mse_loss(output2, target2) opt = self.optimizers() - opt.zero_grad() - mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad()) + mtl_backward(losses=[loss1, loss2], features=features) + jac_to_grad(self.feature_extractor.parameters(), UPGrad()) opt.step() + opt.zero_grad() def configure_optimizers(self) -> OptimizerLRScheduler: optimizer = Adam(self.parameters(), lr=1e-3) diff --git a/docs/source/examples/monitoring.rst b/docs/source/examples/monitoring.rst index 8ec675aa..6297e62c 100644 --- a/docs/source/examples/monitoring.rst +++ b/docs/source/examples/monitoring.rst @@ -15,7 +15,7 @@ Jacobian descent is doing something different than gradient descent. With they have a negative inner product). .. code-block:: python - :emphasize-lines: 9-11, 13-18, 33-34 + :emphasize-lines: 10-12, 14-19, 34-35 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -24,6 +24,7 @@ they have a negative inner product). from torchjd.aggregation import UPGrad from torchjd.autojac import mtl_backward + from torchjd.utils import jac_to_grad def print_weights(_, __, weights: torch.Tensor) -> None: """Prints the extracted weights.""" @@ -63,6 +64,7 @@ they have a negative inner product). loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - optimizer.zero_grad() - mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) + mtl_backward(losses=[loss1, loss2], features=features) + jac_to_grad(shared_module.parameters(), aggregator) optimizer.step() + optimizer.zero_grad() diff --git a/docs/source/examples/mtl.rst b/docs/source/examples/mtl.rst index d726ae3a..e739d8a2 100644 --- a/docs/source/examples/mtl.rst +++ b/docs/source/examples/mtl.rst @@ -19,7 +19,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks. .. code-block:: python - :emphasize-lines: 5-6, 19, 33 + :emphasize-lines: 5-7, 20, 33-34 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -27,6 +27,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks. from torchjd.aggregation import UPGrad from torchjd.autojac import mtl_backward + from torchjd.utils import jac_to_grad shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) task1_module = Linear(3, 1) @@ -52,9 +53,10 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks. loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - optimizer.zero_grad() - mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) + mtl_backward(losses=[loss1, loss2], features=features) + jac_to_grad(shared_module.parameters(), aggregator) optimizer.step() + optimizer.zero_grad() .. note:: In this example, the Jacobian is only with respect to the shared parameters. The task-specific diff --git a/docs/source/examples/partial_jd.rst b/docs/source/examples/partial_jd.rst index c86a653a..ad82205a 100644 --- a/docs/source/examples/partial_jd.rst +++ b/docs/source/examples/partial_jd.rst @@ -41,8 +41,8 @@ first ``Linear`` layer, thereby reducing memory usage and computation time. for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] - optimizer.zero_grad() gramian = engine.compute_gramian(losses) weights = weighting(gramian) losses.backward(weights) optimizer.step() + optimizer.zero_grad() diff --git a/docs/source/examples/rnn.rst b/docs/source/examples/rnn.rst index d9cb8b98..847b66d9 100644 --- a/docs/source/examples/rnn.rst +++ b/docs/source/examples/rnn.rst @@ -6,7 +6,7 @@ element of the output sequences. If the gradients of these losses are likely to descent can be leveraged to enhance optimization. .. code-block:: python - :emphasize-lines: 5-6, 10, 17, 20 + :emphasize-lines: 5-7, 11, 18, 20-21 import torch from torch.nn import RNN @@ -14,6 +14,7 @@ descent can be leveraged to enhance optimization. from torchjd.aggregation import UPGrad from torchjd.autojac import backward + from torchjd.utils import jac_to_grad rnn = RNN(input_size=10, hidden_size=20, num_layers=2) optimizer = SGD(rnn.parameters(), lr=0.1) @@ -26,9 +27,10 @@ descent can be leveraged to enhance optimization. output, _ = rnn(input) # output is of shape [5, 3, 20]. losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element. - optimizer.zero_grad() - backward(losses, aggregator, parallel_chunk_size=1) + backward(losses, parallel_chunk_size=1) + jac_to_grad(rnn.parameters(), aggregator) optimizer.step() + optimizer.zero_grad() .. note:: At the time of writing, there seems to be an incompatibility between ``torch.vmap`` and diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index ca3009bc..46ac2d48 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -2,28 +2,23 @@ from torch import Tensor -from torchjd.aggregation import Aggregator - -from ._transform import Accumulate, Aggregate, Diagonalize, Init, Jac, OrderedSet, Transform +from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet, Transform from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors def backward( tensors: Sequence[Tensor] | Tensor, - aggregator: Aggregator, inputs: Iterable[Tensor] | None = None, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> None: r""" - Computes the Jacobian of all values in ``tensors`` with respect to all ``inputs``. Computes its - aggregation by the provided ``aggregator`` and accumulates it in the ``.grad`` fields of the - ``inputs``. - - :param tensors: The tensor or tensors to differentiate. Should be non-empty. The Jacobian - matrices will have one row for each value of each of these tensors. - :param aggregator: Aggregator used to reduce the Jacobian into a vector. - :param inputs: The tensors with respect to which the Jacobian must be computed. These must have + Computes the Jacobians of all values in ``tensors`` with respect to all ``inputs`` and + accumulates them in the `.jac` fields of the `inputs`. + + :param tensors: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will + have one row for each value of each of these tensors. + :param inputs: The tensors with respect to which the Jacobians must be computed. These must have their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors that were used to compute the ``tensors`` parameter. :param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to @@ -41,7 +36,6 @@ def backward( >>> import torch >>> - >>> from torchjd.aggregation import UPGrad >>> from torchjd.autojac import backward >>> >>> param = torch.tensor([1., 2.], requires_grad=True) @@ -49,12 +43,13 @@ def backward( >>> y1 = torch.tensor([-1., 1.]) @ param >>> y2 = (param ** 2).sum() >>> - >>> backward([y1, y2], UPGrad()) + >>> backward([y1, y2]) >>> - >>> param.grad - tensor([0.5000, 2.5000]) + >>> param.jac + tensor([[-1., 1.], + [ 2., 4.]]) - The ``.grad`` field of ``param`` now contains the aggregation of the Jacobian of + The ``.jac`` field of ``param`` now contains the Jacobian of :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. .. warning:: @@ -80,7 +75,6 @@ def backward( backward_transform = _create_transform( tensors=tensors_, - aggregator=aggregator, inputs=inputs_, retain_graph=retain_graph, parallel_chunk_size=parallel_chunk_size, @@ -91,12 +85,11 @@ def backward( def _create_transform( tensors: OrderedSet[Tensor], - aggregator: Aggregator, inputs: OrderedSet[Tensor], retain_graph: bool, parallel_chunk_size: int | None, ) -> Transform: - """Creates the Jacobian descent backward transform.""" + """Creates the backward transform.""" # Transform that creates gradient outputs containing only ones. init = Init(tensors) @@ -107,10 +100,7 @@ def _create_transform( # Transform that computes the required Jacobians. jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph) - # Transform that aggregates the Jacobians. - aggregate = Aggregate(aggregator, inputs) - - # Transform that accumulates the result in the .grad field of the inputs. - accumulate = Accumulate() + # Transform that accumulates the result in the .jac field of the inputs. + accumulate = AccumulateJac() - return accumulate << aggregate << jac << diag << init + return accumulate << jac << diag << init diff --git a/src/torchjd/autojac/_mtl_backward.py b/src/torchjd/autojac/_mtl_backward.py index 4bdac023..0d05447d 100644 --- a/src/torchjd/autojac/_mtl_backward.py +++ b/src/torchjd/autojac/_mtl_backward.py @@ -2,16 +2,23 @@ from torch import Tensor -from torchjd.aggregation import Aggregator - -from ._transform import Accumulate, Aggregate, Grad, Init, Jac, OrderedSet, Select, Stack, Transform +from ._transform import ( + AccumulateGrad, + AccumulateJac, + Grad, + Init, + Jac, + OrderedSet, + Select, + Stack, + Transform, +) from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors def mtl_backward( losses: Sequence[Tensor], features: Sequence[Tensor] | Tensor, - aggregator: Aggregator, tasks_params: Sequence[Iterable[Tensor]] | None = None, shared_params: Iterable[Tensor] | None = None, retain_graph: bool = False, @@ -23,21 +30,18 @@ def mtl_backward( This function computes the gradient of each task-specific loss with respect to its task-specific parameters and accumulates it in their ``.grad`` fields. Then, it computes the Jacobian of all - losses with respect to the shared parameters, aggregates it and accumulates the result in their - ``.grad`` fields. + losses with respect to the shared parameters and accumulates it in their ``.jac`` fields. - :param losses: The task losses. The Jacobian matrix will have one row per loss. + :param losses: The task losses. The Jacobians will have one row per loss. :param features: The last shared representation used for all tasks, as given by the feature extractor. Should be non-empty. - :param aggregator: Aggregator used to reduce the Jacobian into a vector. :param tasks_params: The parameters of each task-specific head. Their ``requires_grad`` flags must be set to ``True``. If not provided, the parameters considered for each task will default to the leaf tensors that are in the computation graph of its loss, but that were not used to compute the ``features``. - :param shared_params: The parameters of the shared feature extractor. The Jacobian matrix will - have one column for each value in these tensors. Their ``requires_grad`` flags must be set - to ``True``. If not provided, defaults to the leaf tensors that are in the computation graph - of the ``features``. + :param shared_params: The parameters of the shared feature extractor. Their ``requires_grad`` + flags must be set to ``True``. If not provided, defaults to the leaf tensors that are in the + computation graph of the ``features``. :param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to ``False``. :param parallel_chunk_size: The number of scalars to differentiate simultaneously in the @@ -95,7 +99,6 @@ def mtl_backward( backward_transform = _create_transform( losses=losses_, features=features_, - aggregator=aggregator, tasks_params=tasks_params_, shared_params=shared_params_, retain_graph=retain_graph, @@ -108,7 +111,6 @@ def mtl_backward( def _create_transform( losses: OrderedSet[Tensor], features: OrderedSet[Tensor], - aggregator: Aggregator, tasks_params: list[OrderedSet[Tensor]], shared_params: OrderedSet[Tensor], retain_graph: bool, @@ -140,13 +142,10 @@ def _create_transform( # Transform that computes the Jacobians of the losses w.r.t. the shared parameters. jac = Jac(features, shared_params, parallel_chunk_size, retain_graph) - # Transform that aggregates the Jacobians. - aggregate = Aggregate(aggregator, shared_params) - - # Transform that accumulates the result in the .grad field of the shared parameters. - accumulate = Accumulate() + # Transform that accumulates the result in the .jac field of the shared parameters. + accumulate = AccumulateJac() - return accumulate << aggregate << jac << stack + return accumulate << jac << stack def _create_task_transform( @@ -167,7 +166,7 @@ def _create_task_transform( # Transform that accumulates the gradients w.r.t. the task-specific parameters into their # .grad fields. - accumulate = Accumulate() << Select(task_params) + accumulate = AccumulateGrad() << Select(task_params) # Transform that backpropagates the gradients of the losses w.r.t. the features. backpropagate = Select(features) diff --git a/src/torchjd/autojac/_transform/__init__.py b/src/torchjd/autojac/_transform/__init__.py index 46be392d..10d1c512 100644 --- a/src/torchjd/autojac/_transform/__init__.py +++ b/src/torchjd/autojac/_transform/__init__.py @@ -1,5 +1,4 @@ -from ._accumulate import Accumulate -from ._aggregate import Aggregate +from ._accumulate import AccumulateGrad, AccumulateJac from ._base import Composition, Conjunction, RequirementError, Transform from ._diagonalize import Diagonalize from ._grad import Grad @@ -10,8 +9,8 @@ from ._stack import Stack __all__ = [ - "Accumulate", - "Aggregate", + "AccumulateGrad", + "AccumulateJac", "Composition", "Conjunction", "Diagonalize", diff --git a/src/torchjd/autojac/_transform/_accumulate.py b/src/torchjd/autojac/_transform/_accumulate.py index 5a1ac89c..5aaa6bd1 100644 --- a/src/torchjd/autojac/_transform/_accumulate.py +++ b/src/torchjd/autojac/_transform/_accumulate.py @@ -1,44 +1,33 @@ from torch import Tensor +from torchjd.utils._accumulation import _accumulate_grads, _accumulate_jacs + from ._base import TensorDict, Transform -class Accumulate(Transform): +class AccumulateGrad(Transform): """ Transform from Gradients to {} that accumulates gradients with respect to keys into their ``grad`` field. """ def __call__(self, gradients: TensorDict) -> TensorDict: - for key in gradients.keys(): - _check_expects_grad(key) - if hasattr(key, "grad") and key.grad is not None: - key.grad += gradients[key] - else: - # We clone the value because we do not want subsequent accumulations to also affect - # this value (in case it is still used outside). We do not detach from the - # computation graph because the value can have grad_fn that we want to keep track of - # (in case it was obtained via create_graph=True and a differentiable aggregator). - key.grad = gradients[key].clone() - + _accumulate_grads(gradients.keys(), gradients.values()) return {} def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: return set() -def _check_expects_grad(tensor: Tensor) -> None: - if not _expects_grad(tensor): - raise ValueError( - "Cannot populate the .grad field of a Tensor that does not satisfy:" - "`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`." - ) - - -def _expects_grad(tensor: Tensor) -> bool: +class AccumulateJac(Transform): """ - Determines whether a Tensor expects its .grad attribute to be populated. - See https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf for more information. + Transform from Jacobians to {} that accumulates jacobians with respect to keys into their + ``jac`` field. """ - return tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad) + def __call__(self, jacobians: TensorDict) -> TensorDict: + _accumulate_jacs(jacobians.keys(), jacobians.values()) + return {} + + def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + return set() diff --git a/src/torchjd/autojac/_transform/_aggregate.py b/src/torchjd/autojac/_transform/_aggregate.py deleted file mode 100644 index 6f1b2cca..00000000 --- a/src/torchjd/autojac/_transform/_aggregate.py +++ /dev/null @@ -1,151 +0,0 @@ -from collections import OrderedDict -from collections.abc import Hashable -from typing import TypeVar - -import torch -from torch import Tensor - -from torchjd.aggregation import Aggregator - -from ._base import RequirementError, TensorDict, Transform -from ._ordered_set import OrderedSet - -_KeyType = TypeVar("_KeyType", bound=Hashable) -_ValueType = TypeVar("_ValueType") - - -class Aggregate(Transform): - """ - Transform aggregating Jacobians into Gradients. - - It does so by reshaping these Jacobians into matrices, concatenating them into a single matrix, - applying an aggregator to it, separating the result back into one gradient vector per key, and - finally reshaping those into gradients of the same shape as their corresponding keys. - - :param aggregator: The aggregator used to aggregate the concatenated jacobian matrix. - :param key_order: Order in which the different jacobian matrices must be concatenated. - """ - - def __init__(self, aggregator: Aggregator, key_order: OrderedSet[Tensor]): - matrixify = _Matrixify() - aggregate_matrices = _AggregateMatrices(aggregator, key_order) - reshape = _Reshape() - - self._aggregator_str = str(aggregator) - self.transform = reshape << aggregate_matrices << matrixify - - def __call__(self, input: TensorDict) -> TensorDict: - return self.transform(input) - - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: - return self.transform.check_keys(input_keys) - - -class _AggregateMatrices(Transform): - """ - Transform aggregating JacobiansMatrices into GradientsVectors. - - It does so by concatenating the matrices into a single matrix, applying an aggregator to it and - separating the result back into one gradient vector per key. - - :param aggregator: The aggregator used to aggregate the concatenated jacobian matrix. - :param key_order: Order in which the different jacobian matrices must be concatenated. - """ - - def __init__(self, aggregator: Aggregator, key_order: OrderedSet[Tensor]): - self.key_order = key_order - self.aggregator = aggregator - - def __call__(self, jacobian_matrices: TensorDict) -> TensorDict: - """ - Concatenates the provided ``jacobian_matrices`` into a single matrix and aggregates it using - the ``aggregator``. Returns the dictionary mapping each key from ``jacobian_matrices`` to - the part of the obtained gradient vector, that corresponds to the jacobian matrix given for - that key. - - :param jacobian_matrices: The dictionary of jacobian matrices to aggregate. The first - dimension of each jacobian matrix should be the same. - """ - ordered_matrices = self._select_ordered_subdict(jacobian_matrices, self.key_order) - return self._aggregate_group(ordered_matrices, self.aggregator) - - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: - if not set(self.key_order) == input_keys: - raise RequirementError( - f"The input_keys must match the key_order. Found input_keys {input_keys} and" - f"key_order {self.key_order}." - ) - return input_keys - - @staticmethod - def _select_ordered_subdict( - dictionary: dict[_KeyType, _ValueType], ordered_keys: OrderedSet[_KeyType] - ) -> OrderedDict[_KeyType, _ValueType]: - """ - Selects a subset of a dictionary corresponding to the keys given by ``ordered_keys``. - Returns an OrderedDict in the same order as the provided ``ordered_keys``. - """ - - return OrderedDict([(key, dictionary[key]) for key in ordered_keys]) - - @staticmethod - def _aggregate_group( - jacobian_matrices: OrderedDict[Tensor, Tensor], aggregator: Aggregator - ) -> TensorDict: - """ - Unites the jacobian matrices and aggregates them using an - :class:`~torchjd.aggregation._aggregator_bases.Aggregator`. Returns the obtained gradient - vectors. - """ - - if len(jacobian_matrices) == 0: - return {} - - united_jacobian_matrix = _AggregateMatrices._unite(jacobian_matrices) - united_gradient_vector = aggregator(united_jacobian_matrix) - gradient_vectors = _AggregateMatrices._disunite(united_gradient_vector, jacobian_matrices) - return gradient_vectors - - @staticmethod - def _unite(jacobian_matrices: OrderedDict[Tensor, Tensor]) -> Tensor: - return torch.cat(list(jacobian_matrices.values()), dim=1) - - @staticmethod - def _disunite( - united_gradient_vector: Tensor, jacobian_matrices: OrderedDict[Tensor, Tensor] - ) -> TensorDict: - gradient_vectors = {} - start = 0 - for key, jacobian_matrix in jacobian_matrices.items(): - end = start + jacobian_matrix.shape[1] - current_gradient_vector = united_gradient_vector[start:end] - gradient_vectors[key] = current_gradient_vector - start = end - return gradient_vectors - - -class _Matrixify(Transform): - """Transform reshaping Jacobians into JacobianMatrices.""" - - def __call__(self, jacobians: TensorDict) -> TensorDict: - jacobian_matrices = { - key: jacobian.view(jacobian.shape[0], -1) for key, jacobian in jacobians.items() - } - return jacobian_matrices - - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: - return input_keys - - -class _Reshape(Transform): - """Transform reshaping GradientVectors into Gradients.""" - - def __call__(self, gradient_vectors: TensorDict) -> TensorDict: - gradients = { - key: gradient_vector.view(key.shape) - for key, gradient_vector in gradient_vectors.items() - } - return gradients - - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: - return input_keys diff --git a/src/torchjd/utils/__init__.py b/src/torchjd/utils/__init__.py new file mode 100644 index 00000000..158fcece --- /dev/null +++ b/src/torchjd/utils/__init__.py @@ -0,0 +1,3 @@ +from ._jac_to_grad import jac_to_grad + +__all__ = ["jac_to_grad"] diff --git a/src/torchjd/utils/_accumulation.py b/src/torchjd/utils/_accumulation.py new file mode 100644 index 00000000..60abf0d2 --- /dev/null +++ b/src/torchjd/utils/_accumulation.py @@ -0,0 +1,43 @@ +from collections.abc import Iterable + +from torch import Tensor + + +def _accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None: + for param, jac in zip(params, jacobians, strict=True): + _check_expects_grad(param) + if hasattr(param, "jac") and param.jac is not None: + param.jac += jac + else: + # TODO: this could be a serious memory issue + # We clone the value because we do not want subsequent accumulations to also affect + # this value (in case it is still used outside). We do not detach from the + # computation graph because the value can have grad_fn that we want to keep track of + # (in case it was obtained via create_graph=True and a differentiable aggregator). + param.jac = jac.clone() + + +def _accumulate_grads(params: Iterable[Tensor], gradients: Iterable[Tensor]) -> None: + for param, grad in zip(params, gradients, strict=True): + _check_expects_grad(param) + if hasattr(param, "grad") and param.grad is not None: + param.grad += grad + else: + param.grad = grad.clone() + + +def _check_expects_grad(tensor: Tensor) -> None: + if not _expects_grad(tensor): + raise ValueError( + "Cannot populate the .grad field of a Tensor that does not satisfy:" + "`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`." + ) + + +def _expects_grad(tensor: Tensor) -> bool: + """ + Determines whether a Tensor expects its .grad attribute to be populated. + See https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf for more information. + """ + + return tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad) diff --git a/src/torchjd/utils/_jac_to_grad.py b/src/torchjd/utils/_jac_to_grad.py new file mode 100644 index 00000000..76331157 --- /dev/null +++ b/src/torchjd/utils/_jac_to_grad.py @@ -0,0 +1,75 @@ +from collections.abc import Iterable + +import torch +from torch import Tensor + +from torchjd.aggregation import Aggregator +from torchjd.utils._accumulation import _accumulate_grads + + +def jac_to_grad( + params: Iterable[Tensor], aggregator: Aggregator, retain_jacs: bool = False +) -> None: + """ + Aggregates the Jacobians stored in the ``.jac`` fields of ``params`` and accumulates the result + into their ``.grad`` fields. + + :param params: The parameters whose ``.jac`` fields should be aggregated. All Jacobians must + have the same first dimension (number of outputs). + :param aggregator: The aggregator used to reduce the Jacobians into gradients. + :param retain_jacs: Whether to preserve the ``.jac`` fields of the parameters. + """ + + params_ = list(params) + + if len(params_) == 0: + return + + if not all([hasattr(p, "jac") and p.jac is not None for p in params_]): + raise ValueError( + "Some `jac` fields were not populated. Did you use `autojac.backward` before calling " + "`jac_to_grad`?" + ) + + jacobians = [p.jac for p in params_] + + # TODO: check that the Jacobian shapes match + + jacobian_matrix = _unite_jacobians(jacobians) + gradient_vector = aggregator(jacobian_matrix) + gradients = _disunite_gradient(gradient_vector, jacobians, params_) + _accumulate_grads(params_, gradients) + + if not retain_jacs: + _free_jacs(params_) + + +def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: + jacobian_matrices = [jacobian.reshape(jacobian.shape[0], -1) for jacobian in jacobians] + jacobian_matrix = torch.concat(jacobian_matrices, dim=1) + return jacobian_matrix + + +def _disunite_gradient( + gradient_vector: Tensor, jacobians: list[Tensor], params: list[Tensor] +) -> list[Tensor]: + gradient_vectors = [] + start = 0 + for jacobian in jacobians: + end = start + jacobian[0].numel() + current_gradient_vector = gradient_vector[start:end] + gradient_vectors.append(current_gradient_vector) + start = end + gradients = [g.view(param.shape) for param, g in zip(params, gradient_vectors, strict=True)] + return gradients + + +def _free_jacs(params: Iterable[Tensor]) -> None: + """ + Clears the ``.jac`` fields of the provided parameters by setting them to ``None``. + + :param params: The parameters whose ``.jac`` fields should be cleared. + """ + + for p in params: + p.jac = None diff --git a/tests/doc/test_backward.py b/tests/doc/test_backward.py index ca2e1a25..a099378b 100644 --- a/tests/doc/test_backward.py +++ b/tests/doc/test_backward.py @@ -9,7 +9,6 @@ def test_backward(): import torch - from torchjd.aggregation import UPGrad from torchjd.autojac import backward param = torch.tensor([1.0, 2.0], requires_grad=True) @@ -17,6 +16,6 @@ def test_backward(): y1 = torch.tensor([-1.0, 1.0]) @ param y2 = (param**2).sum() - backward([y1, y2], UPGrad()) + backward([y1, y2]) - assert_close(param.grad, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04) + assert_close(param.jac, torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04) diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index b64b504c..0b0d19ef 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -15,6 +15,7 @@ def test_amp(): from torchjd.aggregation import UPGrad from torchjd.autojac import mtl_backward + from torchjd.utils import jac_to_grad shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) task1_module = Linear(3, 1) @@ -42,10 +43,11 @@ def test_amp(): loss2 = loss_fn(output2, target2) scaled_losses = scaler.scale([loss1, loss2]) - optimizer.zero_grad() - mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator) + mtl_backward(losses=scaled_losses, features=features) + jac_to_grad(shared_module.parameters(), aggregator) scaler.step(optimizer) scaler.update() + optimizer.zero_grad() def test_basic_usage(): @@ -55,6 +57,7 @@ def test_basic_usage(): from torchjd import autojac from torchjd.aggregation import UPGrad + from torchjd.utils import jac_to_grad model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2)) optimizer = SGD(model.parameters(), lr=0.1) @@ -69,9 +72,10 @@ def test_basic_usage(): loss1 = loss_fn(output[:, 0], target1) loss2 = loss_fn(output[:, 1], target2) - optimizer.zero_grad() - autojac.backward([loss1, loss2], aggregator) + autojac.backward([loss1, loss2]) + jac_to_grad(model.parameters(), aggregator) optimizer.step() + optimizer.zero_grad() def test_iwmtl(): @@ -114,10 +118,10 @@ def test_iwmtl(): # Obtain the weights that lead to no conflict between reweighted gradients weights = weighting(gramian) # shape: [16, 2] - optimizer.zero_grad() # Do the standard backward pass, but weighted using the obtained weights losses.backward(weights) optimizer.step() + optimizer.zero_grad() def test_iwrm(): @@ -138,9 +142,9 @@ def test_autograd(): for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] loss = loss_fn(y_hat, y) # shape: [] (scalar) - optimizer.zero_grad() loss.backward() optimizer.step() + optimizer.zero_grad() def test_autojac(): import torch @@ -149,6 +153,7 @@ def test_autojac(): from torchjd.aggregation import UPGrad from torchjd.autojac import backward + from torchjd.utils import jac_to_grad X = torch.randn(8, 16, 10) Y = torch.randn(8, 16) @@ -163,9 +168,10 @@ def test_autojac(): for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] - optimizer.zero_grad() - backward(losses, aggregator) + backward(losses) + jac_to_grad(model.parameters(), aggregator) optimizer.step() + optimizer.zero_grad() def test_autogram(): import torch @@ -189,11 +195,11 @@ def test_autogram(): for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] - optimizer.zero_grad() gramian = engine.compute_gramian(losses) # shape: [16, 16] weights = weighting(gramian) # shape: [16] losses.backward(weights) optimizer.step() + optimizer.zero_grad() test_autograd() test_autojac() @@ -220,6 +226,7 @@ def test_lightning_integration(): from torchjd.aggregation import UPGrad from torchjd.autojac import mtl_backward + from torchjd.utils import jac_to_grad class Model(LightningModule): def __init__(self): @@ -240,9 +247,11 @@ def training_step(self, batch, batch_idx) -> None: loss2 = mse_loss(output2, target2) opt = self.optimizers() - opt.zero_grad() - mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad()) + + mtl_backward(losses=[loss1, loss2], features=features) + jac_to_grad(self.feature_extractor.parameters(), UPGrad()) opt.step() + opt.zero_grad() def configure_optimizers(self) -> OptimizerLRScheduler: optimizer = Adam(self.parameters(), lr=1e-3) @@ -275,6 +284,7 @@ def test_monitoring(): from torchjd.aggregation import UPGrad from torchjd.autojac import mtl_backward + from torchjd.utils import jac_to_grad def print_weights(_, __, weights: torch.Tensor) -> None: """Prints the extracted weights.""" @@ -314,9 +324,10 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch. loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - optimizer.zero_grad() - mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) + mtl_backward(losses=[loss1, loss2], features=features) + jac_to_grad(shared_module.parameters(), aggregator) optimizer.step() + optimizer.zero_grad() def test_mtl(): @@ -326,6 +337,7 @@ def test_mtl(): from torchjd.aggregation import UPGrad from torchjd.autojac import mtl_backward + from torchjd.utils import jac_to_grad shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) task1_module = Linear(3, 1) @@ -351,9 +363,10 @@ def test_mtl(): loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - optimizer.zero_grad() - mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) + mtl_backward(losses=[loss1, loss2], features=features) + jac_to_grad(shared_module.parameters(), aggregator) optimizer.step() + optimizer.zero_grad() def test_partial_jd(): @@ -382,11 +395,11 @@ def test_partial_jd(): for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] - optimizer.zero_grad() gramian = engine.compute_gramian(losses) weights = weighting(gramian) losses.backward(weights) optimizer.step() + optimizer.zero_grad() def test_rnn(): @@ -396,6 +409,7 @@ def test_rnn(): from torchjd.aggregation import UPGrad from torchjd.autojac import backward + from torchjd.utils import jac_to_grad rnn = RNN(input_size=10, hidden_size=20, num_layers=2) optimizer = SGD(rnn.parameters(), lr=0.1) @@ -408,6 +422,7 @@ def test_rnn(): output, _ = rnn(input) # output is of shape [5, 3, 20]. losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element. - optimizer.zero_grad() - backward(losses, aggregator, parallel_chunk_size=1) + backward(losses, parallel_chunk_size=1) + jac_to_grad(rnn.parameters(), aggregator) optimizer.step() + optimizer.zero_grad() From 7dda66fca3018b9ea8ac09db94aa28b01754184c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 8 Jan 2026 11:49:37 +0100 Subject: [PATCH 02/35] Delete jac field instead of setting to None --- src/torchjd/utils/_accumulation.py | 2 +- src/torchjd/utils/_jac_to_grad.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchjd/utils/_accumulation.py b/src/torchjd/utils/_accumulation.py index 60abf0d2..1881ecba 100644 --- a/src/torchjd/utils/_accumulation.py +++ b/src/torchjd/utils/_accumulation.py @@ -6,7 +6,7 @@ def _accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None: for param, jac in zip(params, jacobians, strict=True): _check_expects_grad(param) - if hasattr(param, "jac") and param.jac is not None: + if hasattr(param, "jac"): param.jac += jac else: # TODO: this could be a serious memory issue diff --git a/src/torchjd/utils/_jac_to_grad.py b/src/torchjd/utils/_jac_to_grad.py index 76331157..96dfe61e 100644 --- a/src/torchjd/utils/_jac_to_grad.py +++ b/src/torchjd/utils/_jac_to_grad.py @@ -66,10 +66,10 @@ def _disunite_gradient( def _free_jacs(params: Iterable[Tensor]) -> None: """ - Clears the ``.jac`` fields of the provided parameters by setting them to ``None``. + Deletes the ``.jac`` field of the provided parameters. :param params: The parameters whose ``.jac`` fields should be cleared. """ for p in params: - p.jac = None + del p.jac From e08e45f27c41a158cdc39d13c39adf27eb05a0ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 8 Jan 2026 13:52:53 +0100 Subject: [PATCH 03/35] [WIP] Fix jac undefined errors --- src/torchjd/utils/_accumulation.py | 10 +++++++--- src/torchjd/utils/_jac_to_grad.py | 22 +++++++++++++--------- src/torchjd/utils/_tensor_with_jac.py | 11 +++++++++++ 3 files changed, 31 insertions(+), 12 deletions(-) create mode 100644 src/torchjd/utils/_tensor_with_jac.py diff --git a/src/torchjd/utils/_accumulation.py b/src/torchjd/utils/_accumulation.py index 1881ecba..5a538e2b 100644 --- a/src/torchjd/utils/_accumulation.py +++ b/src/torchjd/utils/_accumulation.py @@ -1,20 +1,24 @@ from collections.abc import Iterable +from typing import cast from torch import Tensor +from torchjd.utils._tensor_with_jac import TensorWithJac + def _accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None: for param, jac in zip(params, jacobians, strict=True): _check_expects_grad(param) - if hasattr(param, "jac"): - param.jac += jac + if hasattr(param, "jac"): # No check for None because jac cannot be None + param_ = cast(TensorWithJac, param) + param_.jac += jac else: # TODO: this could be a serious memory issue # We clone the value because we do not want subsequent accumulations to also affect # this value (in case it is still used outside). We do not detach from the # computation graph because the value can have grad_fn that we want to keep track of # (in case it was obtained via create_graph=True and a differentiable aggregator). - param.jac = jac.clone() + param.__setattr__("jac", jac.clone()) def _accumulate_grads(params: Iterable[Tensor], gradients: Iterable[Tensor]) -> None: diff --git a/src/torchjd/utils/_jac_to_grad.py b/src/torchjd/utils/_jac_to_grad.py index 96dfe61e..d6c8108d 100644 --- a/src/torchjd/utils/_jac_to_grad.py +++ b/src/torchjd/utils/_jac_to_grad.py @@ -1,10 +1,12 @@ from collections.abc import Iterable +from typing import cast import torch from torch import Tensor from torchjd.aggregation import Aggregator from torchjd.utils._accumulation import _accumulate_grads +from torchjd.utils._tensor_with_jac import TensorWithJac def jac_to_grad( @@ -20,17 +22,19 @@ def jac_to_grad( :param retain_jacs: Whether to preserve the ``.jac`` fields of the parameters. """ - params_ = list(params) + params_ = list[TensorWithJac]() + for p in params: + if not hasattr(p, "jac"): + raise ValueError( + "Some `jac` fields were not populated. Did you use `autojac.backward` before" + "calling `jac_to_grad`?" + ) + p_ = cast(TensorWithJac, p) + params_.append(p_) if len(params_) == 0: return - if not all([hasattr(p, "jac") and p.jac is not None for p in params_]): - raise ValueError( - "Some `jac` fields were not populated. Did you use `autojac.backward` before calling " - "`jac_to_grad`?" - ) - jacobians = [p.jac for p in params_] # TODO: check that the Jacobian shapes match @@ -51,7 +55,7 @@ def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: def _disunite_gradient( - gradient_vector: Tensor, jacobians: list[Tensor], params: list[Tensor] + gradient_vector: Tensor, jacobians: list[Tensor], params: list[TensorWithJac] ) -> list[Tensor]: gradient_vectors = [] start = 0 @@ -64,7 +68,7 @@ def _disunite_gradient( return gradients -def _free_jacs(params: Iterable[Tensor]) -> None: +def _free_jacs(params: Iterable[TensorWithJac]) -> None: """ Deletes the ``.jac`` field of the provided parameters. diff --git a/src/torchjd/utils/_tensor_with_jac.py b/src/torchjd/utils/_tensor_with_jac.py new file mode 100644 index 00000000..86af8222 --- /dev/null +++ b/src/torchjd/utils/_tensor_with_jac.py @@ -0,0 +1,11 @@ +from torch import Tensor + + +class TensorWithJac(Tensor): + """ + Tensor known to have a populated jac field. + + Should not be directly instantiated, but can be used as a type hint and can be casted to. + """ + + jac: Tensor From 57e5c6dddf03a1e52911e88425fd0b0381148950 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 8 Jan 2026 16:04:56 +0100 Subject: [PATCH 04/35] Add check that the number of rows of the jacobians is consistant --- src/torchjd/utils/_jac_to_grad.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchjd/utils/_jac_to_grad.py b/src/torchjd/utils/_jac_to_grad.py index d6c8108d..3f1754d5 100644 --- a/src/torchjd/utils/_jac_to_grad.py +++ b/src/torchjd/utils/_jac_to_grad.py @@ -37,7 +37,8 @@ def jac_to_grad( jacobians = [p.jac for p in params_] - # TODO: check that the Jacobian shapes match + if not all([jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]]): + raise ValueError("All Jacobians should have the same number of rows.") jacobian_matrix = _unite_jacobians(jacobians) gradient_vector = aggregator(jacobian_matrix) From 6ea59835505c938bd04deba68e65a945263dc309 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 8 Jan 2026 16:12:14 +0100 Subject: [PATCH 05/35] Add check of jac shape before assigning to .jac --- src/torchjd/utils/_accumulation.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/torchjd/utils/_accumulation.py b/src/torchjd/utils/_accumulation.py index 78f72088..d2b85ba0 100644 --- a/src/torchjd/utils/_accumulation.py +++ b/src/torchjd/utils/_accumulation.py @@ -22,6 +22,17 @@ def _accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> N # We do not detach from the computation graph because the value can have grad_fn # that we want to keep track of (in case it was obtained via create_graph=True and a # differentiable aggregator). + # + # We also check that the shape is correct to be consistent with torch, that checks that + # the grad shape is correct before assigning it. + + if jac.shape[1:] != param.shape: + raise RuntimeError( + f"attempting to assign a jacobian of size '{list(jac.shape)}' to a tensor of " + f"size '{list(param.shape)}'. Please ensure that the tensor and each row of the" + " jacobian are the same size" + ) + param.__setattr__("jac", jac) From b7b3a75ffdd98bdedd706ded1433e3356c68363b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 8 Jan 2026 16:22:06 +0100 Subject: [PATCH 06/35] Add changelog entry --- CHANGELOG.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 84d75334..20753457 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,31 @@ changelog does not include internal changes that do not affect the user. ### Changed +- **BREAKING**: Removed from `backward` and `mtl_backward` the responsibility to aggregate the + Jacobian. Now, these functions compute and populate the `.jac` fields of the parameters, and a new + function `torchjd.utils.jac_to_grad` should then be called to aggregate those `.jac` fields into + `.grad` fields. + This means that users now have more control on what they do with the Jacobians (they can easily + aggregate them group by group or even param by param if they want), but it now requires an extra + line of code to do the Jacobian descent step. To update, please change: + ```python + backward(losses, aggregator) + ``` + to + ```python + backward(losses) + jac_to_grad(model.parameters(), aggregator) + ``` + and + ```python + mtl_backward(losses, features, aggregator) + ``` + to + ```python + mtl_backward(losses, features) + jac_to_grad(shared_module.parameters(), aggregator) + ``` + - Removed an unnecessary internal cloning of gradient. This should slightly improve the memory efficiency of `autojac`. From 28cf7019bb9691a437889a93e4271a06c99acf0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 8 Jan 2026 16:49:42 +0100 Subject: [PATCH 07/35] Move check of jacobian shape outside of if/else --- src/torchjd/utils/_accumulation.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/torchjd/utils/_accumulation.py b/src/torchjd/utils/_accumulation.py index d2b85ba0..d77ca69f 100644 --- a/src/torchjd/utils/_accumulation.py +++ b/src/torchjd/utils/_accumulation.py @@ -9,6 +9,15 @@ def _accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None: for param, jac in zip(params, jacobians, strict=True): _check_expects_grad(param) + # We that the shape is correct to be consistent with torch, that checks that the grad + # shape is correct before assigning it. + if jac.shape[1:] != param.shape: + raise RuntimeError( + f"attempting to assign a jacobian of size '{list(jac.shape)}' to a tensor of " + f"size '{list(param.shape)}'. Please ensure that the tensor and each row of the" + " jacobian are the same size" + ) + if hasattr(param, "jac"): # No check for None because jac cannot be None param_ = cast(TensorWithJac, param) param_.jac += jac @@ -22,17 +31,6 @@ def _accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> N # We do not detach from the computation graph because the value can have grad_fn # that we want to keep track of (in case it was obtained via create_graph=True and a # differentiable aggregator). - # - # We also check that the shape is correct to be consistent with torch, that checks that - # the grad shape is correct before assigning it. - - if jac.shape[1:] != param.shape: - raise RuntimeError( - f"attempting to assign a jacobian of size '{list(jac.shape)}' to a tensor of " - f"size '{list(param.shape)}'. Please ensure that the tensor and each row of the" - " jacobian are the same size" - ) - param.__setattr__("jac", jac) From 9a2a0ecdcc2cd682251024542b01239fe60ec9e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 8 Jan 2026 16:53:58 +0100 Subject: [PATCH 08/35] Improve docstring of AccumulateJac and AccumulateGrad --- src/torchjd/autojac/_transform/_accumulate.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/torchjd/autojac/_transform/_accumulate.py b/src/torchjd/autojac/_transform/_accumulate.py index 5aaa6bd1..cfd28b4f 100644 --- a/src/torchjd/autojac/_transform/_accumulate.py +++ b/src/torchjd/autojac/_transform/_accumulate.py @@ -9,6 +9,9 @@ class AccumulateGrad(Transform): """ Transform from Gradients to {} that accumulates gradients with respect to keys into their ``grad`` field. + + The Gradients are not cloned and may be modified in-place by subsequent accumulations, so they + should not be used elsewhere. """ def __call__(self, gradients: TensorDict) -> TensorDict: @@ -23,6 +26,9 @@ class AccumulateJac(Transform): """ Transform from Jacobians to {} that accumulates jacobians with respect to keys into their ``jac`` field. + + The Jacobians are not cloned and may be modified in-place by subsequent accumulations, so they + should not be used elsewhere. """ def __call__(self, jacobians: TensorDict) -> TensorDict: From 382e9d3547888854af033cd0c65ec28691dfbede Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 9 Jan 2026 20:16:13 +0100 Subject: [PATCH 09/35] Fix tests - Remove test_aggregate.py - Update test_accumulate.py and test_interactions.py to test on AccumulateGrad instead of Accumulate - Fix tests in test_backward.py and test_mtl_backward.py to match the new interface: check the jac field instead of the .grad field. - Use _asserts.py for helper functions common to backward.py and mtl_backward.py --- tests/unit/autojac/_asserts.py | 35 ++++ .../autojac/_transform/test_accumulate.py | 20 +- .../unit/autojac/_transform/test_aggregate.py | 155 -------------- .../autojac/_transform/test_interactions.py | 12 +- tests/unit/autojac/test_backward.py | 88 ++++---- tests/unit/autojac/test_mtl_backward.py | 194 +++++++++--------- 6 files changed, 187 insertions(+), 317 deletions(-) create mode 100644 tests/unit/autojac/_asserts.py delete mode 100644 tests/unit/autojac/_transform/test_aggregate.py diff --git a/tests/unit/autojac/_asserts.py b/tests/unit/autojac/_asserts.py new file mode 100644 index 00000000..3221b44b --- /dev/null +++ b/tests/unit/autojac/_asserts.py @@ -0,0 +1,35 @@ +from typing import cast + +import torch +from torch.testing import assert_close + +from torchjd.utils._tensor_with_jac import TensorWithJac + + +def assert_has_jac(t: torch.Tensor) -> None: + assert hasattr(t, "jac") + t_ = cast(TensorWithJac, t) + assert t_.jac is not None and t_.jac.shape[1:] == t_.shape + + +def assert_has_no_jac(t: torch.Tensor) -> None: + assert not hasattr(t, "jac") + + +def assert_jac_close(t: torch.Tensor, expected_jac: torch.Tensor) -> None: + assert hasattr(t, "jac") + t_ = cast(TensorWithJac, t) + assert_close(t_.jac, expected_jac) + + +def assert_has_grad(t: torch.Tensor) -> None: + assert (t.grad is not None) and (t.shape == t.grad.shape) + + +def assert_has_no_grad(t: torch.Tensor) -> None: + assert t.grad is None + + +def assert_grad_close(t: torch.Tensor, expected_grad: torch.Tensor) -> None: + assert t.grad is not None + assert_close(t.grad, expected_grad) diff --git a/tests/unit/autojac/_transform/test_accumulate.py b/tests/unit/autojac/_transform/test_accumulate.py index 45db6d61..d1166f58 100644 --- a/tests/unit/autojac/_transform/test_accumulate.py +++ b/tests/unit/autojac/_transform/test_accumulate.py @@ -2,12 +2,12 @@ from utils.dict_assertions import assert_tensor_dicts_are_close from utils.tensors import ones_, tensor_, zeros_ -from torchjd.autojac._transform import Accumulate +from torchjd.autojac._transform import AccumulateGrad def test_single_accumulation(): """ - Tests that the Accumulate transform correctly accumulates gradients in .grad fields when run + Tests that the AccumulateGrad transform correctly accumulates gradients in .grad fields when run once. """ @@ -19,7 +19,7 @@ def test_single_accumulation(): value3 = ones_([2, 3]) input = {key1: value1, key2: value2, key3: value3} - accumulate = Accumulate() + accumulate = AccumulateGrad() output = accumulate(input) expected_output = {} @@ -35,7 +35,7 @@ def test_single_accumulation(): @mark.parametrize("iterations", [1, 2, 4, 10, 13]) def test_multiple_accumulation(iterations: int): """ - Tests that the Accumulate transform correctly accumulates gradients in .grad fields when run + Tests that the AccumulateGrad transform correctly accumulates gradients in .grad fields when run `iterations` times. """ @@ -46,7 +46,7 @@ def test_multiple_accumulation(iterations: int): value2 = ones_([1]) value3 = ones_([2, 3]) - accumulate = Accumulate() + accumulate = AccumulateGrad() for i in range(iterations): # Clone values to ensure that we accumulate values that are not ever used afterwards @@ -65,7 +65,7 @@ def test_multiple_accumulation(iterations: int): def test_no_requires_grad_fails(): """ - Tests that the Accumulate transform raises an error when it tries to populate a .grad of a + Tests that the AccumulateGrad transform raises an error when it tries to populate a .grad of a tensor that does not require grad. """ @@ -73,7 +73,7 @@ def test_no_requires_grad_fails(): value = ones_([1]) input = {key: value} - accumulate = Accumulate() + accumulate = AccumulateGrad() with raises(ValueError): accumulate(input) @@ -81,7 +81,7 @@ def test_no_requires_grad_fails(): def test_no_leaf_and_no_retains_grad_fails(): """ - Tests that the Accumulate transform raises an error when it tries to populate a .grad of a + Tests that the AccumulateGrad transform raises an error when it tries to populate a .grad of a tensor that is not a leaf and that does not retain grad. """ @@ -89,7 +89,7 @@ def test_no_leaf_and_no_retains_grad_fails(): value = ones_([1]) input = {key: value} - accumulate = Accumulate() + accumulate = AccumulateGrad() with raises(ValueError): accumulate(input) @@ -99,7 +99,7 @@ def test_check_keys(): """Tests that the `check_keys` method works correctly.""" key = tensor_([1.0], requires_grad=True) - accumulate = Accumulate() + accumulate = AccumulateGrad() output_keys = accumulate.check_keys({key}) assert output_keys == set() diff --git a/tests/unit/autojac/_transform/test_aggregate.py b/tests/unit/autojac/_transform/test_aggregate.py deleted file mode 100644 index 5beaed20..00000000 --- a/tests/unit/autojac/_transform/test_aggregate.py +++ /dev/null @@ -1,155 +0,0 @@ -import math - -import torch -from pytest import mark, raises -from settings import DEVICE -from utils.dict_assertions import assert_tensor_dicts_are_close -from utils.tensors import rand_, tensor_, zeros_ - -from torchjd.aggregation import Random -from torchjd.autojac._transform import OrderedSet, RequirementError -from torchjd.autojac._transform._aggregate import _AggregateMatrices, _Matrixify, _Reshape -from torchjd.autojac._transform._base import TensorDict - - -def _make_jacobian_matrices(n_outputs: int, rng: torch.Generator) -> TensorDict: - jacobian_shapes = [[n_outputs, math.prod(shape)] for shape in _param_shapes] - jacobian_list = [rand_(shape, generator=rng) for shape in jacobian_shapes] - jacobian_matrices = {key: jac for key, jac in zip(_keys, jacobian_list)} - return jacobian_matrices - - -_param_shapes = [ - [], - [1], - [2], - [5], - [1, 1], - [2, 3], - [5, 5], - [1, 1, 1], - [2, 3, 4], - [5, 5, 5], - [1, 1, 1, 1], - [2, 3, 4, 5], - [5, 5, 5, 5], -] -_keys = [zeros_(shape) for shape in _param_shapes] - -_rng = torch.Generator(device=DEVICE) -_rng.manual_seed(0) -_jacobian_matrix_dicts = [_make_jacobian_matrices(n_outputs, _rng) for n_outputs in [1, 2, 5]] - - -@mark.parametrize("jacobian_matrices", _jacobian_matrix_dicts) -def test_aggregate_matrices_output_structure(jacobian_matrices: TensorDict): - """ - Tests that applying _AggregateMatrices to various dictionaries of jacobian matrices gives an - output of the desired structure. - """ - - aggregate_matrices = _AggregateMatrices(Random(), key_order=OrderedSet(_keys)) - gradient_vectors = aggregate_matrices(jacobian_matrices) - - assert set(jacobian_matrices.keys()) == set(gradient_vectors.keys()) - - for key in jacobian_matrices.keys(): - assert gradient_vectors[key].numel() == jacobian_matrices[key][0].numel() - - -def test_aggregate_matrices_empty_dict(): - """Tests that applying _AggregateMatrices to an empty input gives an empty output.""" - - aggregate_matrices = _AggregateMatrices(Random(), key_order=OrderedSet([])) - gradient_vectors = aggregate_matrices({}) - assert len(gradient_vectors) == 0 - - -def test_matrixify(): - """Tests that the Matrixify transform correctly creates matrices from the jacobians.""" - - n_outputs = 5 - key1 = zeros_([]) - key2 = zeros_([1]) - key3 = zeros_([2, 3]) - value1 = tensor_([1.0] * n_outputs) - value2 = tensor_([[2.0]] * n_outputs) - value3 = tensor_([[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]] * n_outputs) - input = {key1: value1, key2: value2, key3: value3} - - matrixify = _Matrixify() - - output = matrixify(input) - expected_output = { - key1: tensor_([[1.0]] * n_outputs), - key2: tensor_([[2.0]] * n_outputs), - key3: tensor_([[3.0, 4.0, 5.0, 6.0, 7.0, 8.0]] * n_outputs), - } - - assert_tensor_dicts_are_close(output, expected_output) - - -def test_reshape(): - """Tests that the Reshape transform correctly creates gradients from gradient vectors.""" - - key1 = zeros_([]) - key2 = zeros_([1]) - key3 = zeros_([2, 3]) - value1 = tensor_([1.0]) - value2 = tensor_([2.0]) - value3 = tensor_([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) - input = {key1: value1, key2: value2, key3: value3} - - reshape = _Reshape() - - output = reshape(input) - expected_output = { - key1: tensor_(1.0), - key2: tensor_([2.0]), - key3: tensor_([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]), - } - - assert_tensor_dicts_are_close(output, expected_output) - - -def test_aggregate_matrices_check_keys(): - """ - Tests that the `check_keys` method works correctly: the input_keys must match the stored - key_order. - """ - - key1 = tensor_([1.0]) - key2 = tensor_([2.0]) - key3 = tensor_([2.0]) - aggregate = _AggregateMatrices(Random(), OrderedSet([key2, key1])) - - output_keys = aggregate.check_keys({key1, key2}) - assert output_keys == {key1, key2} - - with raises(RequirementError): - aggregate.check_keys({key1}) - - with raises(RequirementError): - aggregate.check_keys({key1, key2, key3}) - - -def test_matrixify_check_keys(): - """Tests that the `check_keys` method works correctly.""" - - key1 = tensor_([1.0]) - key2 = tensor_([2.0]) - matrixify = _Matrixify() - - output_keys = matrixify.check_keys({key1, key2}) - assert output_keys == {key1, key2} - - -def test_reshape_check_keys(): - """Tests that the `check_keys` method works correctly.""" - - key1 = tensor_([1.0]) - key2 = tensor_([2.0]) - reshape = _Reshape() - - output_keys = reshape.check_keys({key1, key2}) - assert output_keys == {key1, key2} diff --git a/tests/unit/autojac/_transform/test_interactions.py b/tests/unit/autojac/_transform/test_interactions.py index 8a943e83..a712dcef 100644 --- a/tests/unit/autojac/_transform/test_interactions.py +++ b/tests/unit/autojac/_transform/test_interactions.py @@ -5,7 +5,7 @@ from utils.tensors import tensor_, zeros_ from torchjd.autojac._transform import ( - Accumulate, + AccumulateGrad, Conjunction, Diagonalize, Grad, @@ -186,10 +186,10 @@ def test_conjunction_is_associative(): def test_conjunction_accumulate_select(): """ - Tests that it is possible to conjunct an Accumulate and a Select in this order. - It is not trivial since the type of the TensorDict returned by the first transform (Accumulate) - is EmptyDict, which is not the type that the conjunction should return (Gradients), but a - subclass of it. + Tests that it is possible to conjunct an AccumulateGrad and a Select in this order. + It is not trivial since the type of the TensorDict returned by the first transform + (AccumulateGrad) is EmptyDict, which is not the type that the conjunction should return + (Gradients), but a subclass of it. """ key = tensor_([1.0, 2.0, 3.0], requires_grad=True) @@ -197,7 +197,7 @@ def test_conjunction_accumulate_select(): input = {key: value} select = Select(set()) - accumulate = Accumulate() + accumulate = AccumulateGrad() conjunction = accumulate | select output = conjunction(input) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 885a9c15..30b9ef6c 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -1,14 +1,13 @@ import torch from pytest import mark, raises -from torch.autograd import grad -from torch.testing import assert_close from utils.tensors import randn_, tensor_ -from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad from torchjd.autojac import backward from torchjd.autojac._backward import _create_transform from torchjd.autojac._transform import OrderedSet +from ._asserts import assert_has_jac, assert_has_no_jac, assert_jac_close + def test_check_create_transform(): """Tests that _create_transform creates a valid Transform.""" @@ -21,7 +20,6 @@ def test_check_create_transform(): transform = _create_transform( tensors=OrderedSet([y1, y2]), - aggregator=Mean(), inputs=OrderedSet([a1, a2]), retain_graph=False, parallel_chunk_size=None, @@ -31,8 +29,7 @@ def test_check_create_transform(): assert output_keys == set() -@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()]) -def test_various_aggregators(aggregator: Aggregator): +def test_shape_is_correct(): """Tests that backward works for various aggregators.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -41,24 +38,22 @@ def test_various_aggregators(aggregator: Aggregator): y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + a2.norm() - backward([y1, y2], aggregator) + backward([y1, y2]) for a in [a1, a2]: - assert (a.grad is not None) and (a.shape == a.grad.shape) + assert_has_jac(a) -@mark.parametrize("aggregator", [Mean(), UPGrad()]) @mark.parametrize("shape", [(1, 3), (2, 3), (2, 6), (5, 8), (20, 55)]) @mark.parametrize("manually_specify_inputs", [True, False]) @mark.parametrize("chunk_size", [1, 2, None]) def test_value_is_correct( - aggregator: Aggregator, shape: tuple[int, int], manually_specify_inputs: bool, chunk_size: int | None, ): """ - Tests that the .grad value filled by backward is correct in a simple example of matrix-vector + Tests that the .jac value filled by backward is correct in a simple example of matrix-vector product. """ @@ -73,16 +68,15 @@ def test_value_is_correct( backward( [output], - aggregator, inputs=inputs, parallel_chunk_size=chunk_size, ) - assert_close(input.grad, aggregator(J)) + assert_jac_close(input, J) def test_empty_inputs(): - """Tests that backward does not fill the .grad values if no input is specified.""" + """Tests that backward does not fill the .jac values if no input is specified.""" a1 = tensor_([1.0, 2.0], requires_grad=True) a2 = tensor_([3.0, 4.0], requires_grad=True) @@ -90,15 +84,15 @@ def test_empty_inputs(): y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + a2.norm() - backward([y1, y2], Mean(), inputs=[]) + backward([y1, y2], inputs=[]) for a in [a1, a2]: - assert a.grad is None + assert_has_no_jac(a) def test_partial_inputs(): """ - Tests that backward fills the right .grad values when only a subset of the actual inputs are + Tests that backward fills the right .jac values when only a subset of the actual inputs are specified as inputs. """ @@ -108,10 +102,10 @@ def test_partial_inputs(): y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + a2.norm() - backward([y1, y2], Mean(), inputs=[a1]) + backward([y1, y2], inputs=[a1]) - assert (a1.grad is not None) and (a1.shape == a1.grad.shape) - assert a2.grad is None + assert_has_jac(a1) + assert_has_no_jac(a2) def test_empty_tensors_fails(): @@ -121,7 +115,7 @@ def test_empty_tensors_fails(): a2 = tensor_([3.0, 4.0], requires_grad=True) with raises(ValueError): - backward([], UPGrad(), inputs=[a1, a2]) + backward([], inputs=[a1, a2]) def test_multiple_tensors(): @@ -130,8 +124,6 @@ def test_multiple_tensors(): containing the all the values of the original tensors. """ - aggregator = UPGrad() - a1 = tensor_([1.0, 2.0], requires_grad=True) a2 = tensor_([3.0, 4.0], requires_grad=True) inputs = [a1, a2] @@ -139,16 +131,17 @@ def test_multiple_tensors(): y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + a2.norm() - backward([y1, y2], aggregator, retain_graph=True) + # TODO: improve that + backward([y1, y2], retain_graph=True) - input_to_grad = {a: a.grad for a in inputs} + input_to_jac = {a: a.jac for a in inputs} for a in inputs: - a.grad = None + del a.jac - backward(torch.cat([y1.reshape(-1), y2.reshape(-1)]), aggregator) + backward(torch.cat([y1.reshape(-1), y2.reshape(-1)])) for a in inputs: - assert (a.grad == input_to_grad[a]).all() + assert_jac_close(a, input_to_jac[a]) @mark.parametrize("chunk_size", [None, 1, 2, 4]) @@ -161,10 +154,10 @@ def test_various_valid_chunk_sizes(chunk_size): y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + a2.norm() - backward([y1, y2], UPGrad(), parallel_chunk_size=chunk_size) + backward([y1, y2], parallel_chunk_size=chunk_size) for a in [a1, a2]: - assert (a.grad is not None) and (a.shape == a.grad.shape) + assert_has_jac(a) @mark.parametrize("chunk_size", [0, -1]) @@ -178,7 +171,7 @@ def test_non_positive_chunk_size_fails(chunk_size: int): y2 = (a1**2).sum() + a2.norm() with raises(ValueError): - backward([y1, y2], UPGrad(), parallel_chunk_size=chunk_size) + backward([y1, y2], parallel_chunk_size=chunk_size) def test_input_retaining_grad_fails(): @@ -192,8 +185,13 @@ def test_input_retaining_grad_fails(): b.retain_grad() y = 3 * b + # backward itself doesn't raise the error, but it fills b.grad with a BatchedTensor + # (and it also fills b.jac with the correct Jacobian) + backward(tensors=y, inputs=[b]) + with raises(RuntimeError): - backward(tensors=y, aggregator=UPGrad(), inputs=[b]) + # Using such a BatchedTensor should result in an error + _ = -b.grad def test_non_input_retaining_grad_fails(): @@ -208,7 +206,7 @@ def test_non_input_retaining_grad_fails(): y = 3 * b # backward itself doesn't raise the error, but it fills b.grad with a BatchedTensor - backward(tensors=y, aggregator=UPGrad(), inputs=[a]) + backward(tensors=y, inputs=[a]) with raises(RuntimeError): # Using such a BatchedTensor should result in an error @@ -227,18 +225,12 @@ def test_tensor_used_multiple_times(chunk_size: int | None): c = a * b d = a * c e = a * d - aggregator = UPGrad() - backward([d, e], aggregator=aggregator, parallel_chunk_size=chunk_size) + backward([d, e], parallel_chunk_size=chunk_size) - expected_jacobian = tensor_( - [ - [2.0 * 3.0 * (a**2).item()], - [2.0 * 4.0 * (a**3).item()], - ], - ) + expected_jacobian = tensor_([2.0 * 3.0 * (a**2).item(), 2.0 * 4.0 * (a**3).item()]) - assert_close(a.grad, aggregator(expected_jacobian).squeeze()) + assert_jac_close(a, expected_jacobian) def test_repeated_tensors(): @@ -257,7 +249,7 @@ def test_repeated_tensors(): y2 = (a1**2).sum() + (a2**2).sum() with raises(ValueError): - backward([y1, y1, y2], Sum()) + backward([y1, y1, y2]) def test_repeated_inputs(): @@ -273,10 +265,10 @@ def test_repeated_inputs(): y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + (a2**2).sum() - expected_grad_wrt_a1 = grad([y1, y2], a1, retain_graph=True)[0] - expected_grad_wrt_a2 = grad([y1, y2], a2, retain_graph=True)[0] + J1 = tensor_([[-1.0, 1.0], [2.0, 4.0]]) + J2 = tensor_([[1.0, 1.0], [6.0, 8.0]]) - backward([y1, y2], Sum(), inputs=[a1, a1, a2]) + backward([y1, y2], inputs=[a1, a1, a2]) - assert_close(a1.grad, expected_grad_wrt_a1) - assert_close(a2.grad, expected_grad_wrt_a2) + assert_jac_close(a1, J1) + assert_jac_close(a2, J2) diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index 86595f92..1bd247f6 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -2,14 +2,21 @@ from pytest import mark, raises from settings import DTYPE from torch.autograd import grad -from torch.testing import assert_close from utils.tensors import arange_, rand_, randn_, tensor_ -from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad from torchjd.autojac import mtl_backward from torchjd.autojac._mtl_backward import _create_transform from torchjd.autojac._transform import OrderedSet +from ._asserts import ( + assert_grad_close, + assert_has_grad, + assert_has_jac, + assert_has_no_grad, + assert_has_no_jac, + assert_jac_close, +) + def test_check_create_transform(): """Tests that _create_transform creates a valid Transform.""" @@ -26,7 +33,6 @@ def test_check_create_transform(): transform = _create_transform( losses=OrderedSet([y1, y2]), features=OrderedSet([f1, f2]), - aggregator=Mean(), tasks_params=[OrderedSet([p1]), OrderedSet([p2])], shared_params=OrderedSet([p0]), retain_graph=False, @@ -37,9 +43,8 @@ def test_check_create_transform(): assert output_keys == set() -@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()]) -def test_various_aggregators(aggregator: Aggregator): - """Tests that mtl_backward works for various aggregators.""" +def test_shape_is_correct(): + """Tests that mtl_backward works correctly.""" p0 = tensor_([1.0, 2.0], requires_grad=True) p1 = tensor_([1.0, 2.0], requires_grad=True) @@ -50,26 +55,25 @@ def test_various_aggregators(aggregator: Aggregator): y1 = f1 * p1[0] + f2 * p1[1] y2 = f1 * p2[0] + f2 * p2[1] - mtl_backward(losses=[y1, y2], features=[f1, f2], aggregator=aggregator) + mtl_backward(losses=[y1, y2], features=[f1, f2]) - for p in [p0, p1, p2]: - assert (p.grad is not None) and (p.shape == p.grad.shape) + assert_has_jac(p0) + for p in [p1, p2]: + assert_has_grad(p) -@mark.parametrize("aggregator", [Mean(), UPGrad()]) @mark.parametrize("shape", [(1, 3), (2, 3), (2, 6), (5, 8), (20, 55)]) @mark.parametrize("manually_specify_shared_params", [True, False]) @mark.parametrize("manually_specify_tasks_params", [True, False]) @mark.parametrize("chunk_size", [1, 2, None]) def test_value_is_correct( - aggregator: Aggregator, shape: tuple[int, int], manually_specify_shared_params: bool, manually_specify_tasks_params: bool, chunk_size: int | None, ): """ - Tests that the .grad value filled by mtl_backward is correct in a simple example of + Tests that the .jac value filled by mtl_backward is correct in a simple example of matrix-vector product for three tasks whose loss are given by a simple inner product of the shared features with the task parameter. @@ -100,20 +104,17 @@ def test_value_is_correct( mtl_backward( losses=[y1, y2, y3], features=f, - aggregator=aggregator, tasks_params=tasks_params, shared_params=shared_params, parallel_chunk_size=chunk_size, ) - assert_close(p1.grad, f) - assert_close(p2.grad, f) - assert_close(p3.grad, f) + assert_grad_close(p1, f) + assert_grad_close(p2, f) + assert_grad_close(p3, f) expected_jacobian = torch.stack((p1, p2, p3)) @ J - expected_aggregation = aggregator(expected_jacobian) - - assert_close(p0.grad, expected_aggregation) + assert_jac_close(p0, expected_jacobian) def test_empty_tasks_fails(): @@ -125,7 +126,7 @@ def test_empty_tasks_fails(): f2 = (p0**2).sum() + p0.norm() with raises(ValueError): - mtl_backward(losses=[], features=[f1, f2], aggregator=UPGrad()) + mtl_backward(losses=[], features=[f1, f2]) def test_single_task(): @@ -138,10 +139,10 @@ def test_single_task(): f2 = (p0**2).sum() + p0.norm() y1 = f1 * p1[0] + f2 * p1[1] - mtl_backward(losses=[y1], features=[f1, f2], aggregator=UPGrad()) + mtl_backward(losses=[y1], features=[f1, f2]) - for p in [p0, p1]: - assert (p.grad is not None) and (p.shape == p.grad.shape) + assert_has_jac(p0) + assert_has_grad(p1) def test_incoherent_task_number_fails(): @@ -163,7 +164,6 @@ def test_incoherent_task_number_fails(): mtl_backward( losses=[y1, y2], features=[f1, f2], - aggregator=UPGrad(), tasks_params=[[p1]], # Wrong shared_params=[p0], ) @@ -171,14 +171,13 @@ def test_incoherent_task_number_fails(): mtl_backward( losses=[y1], # Wrong features=[f1, f2], - aggregator=UPGrad(), tasks_params=[[p1], [p2]], shared_params=[p0], ) def test_empty_params(): - """Tests that mtl_backward does not fill the .grad values if no parameter is specified.""" + """Tests that mtl_backward does not fill the .jac/.grad values if no parameter is specified.""" p0 = tensor_([1.0, 2.0], requires_grad=True) p1 = tensor_([1.0, 2.0], requires_grad=True) @@ -192,13 +191,13 @@ def test_empty_params(): mtl_backward( losses=[y1, y2], features=[f1, f2], - aggregator=UPGrad(), tasks_params=[[], []], shared_params=[], ) - for p in [p0, p1, p2]: - assert p.grad is None + assert_has_no_jac(p0) + for p in [p1, p2]: + assert_has_no_grad(p) def test_multiple_params_per_task(): @@ -216,10 +215,11 @@ def test_multiple_params_per_task(): y1 = f1 * p1_a + (f2 * p1_b).sum() + (f1 * p1_c).sum() y2 = f1 * p2_a * (f2 * p2_b).sum() - mtl_backward(losses=[y1, y2], features=[f1, f2], aggregator=UPGrad()) + mtl_backward(losses=[y1, y2], features=[f1, f2]) - for p in [p0, p1_a, p1_b, p1_c, p2_a, p2_b]: - assert (p.grad is not None) and (p.shape == p.grad.shape) + assert_has_jac(p0) + for p in [p1_a, p1_b, p1_c, p2_a, p2_b]: + assert_has_grad(p) @mark.parametrize( @@ -249,19 +249,20 @@ def test_various_shared_params(shared_params_shapes: list[tuple[int]]): mtl_backward( losses=[y1, y2], features=features, - aggregator=UPGrad(), tasks_params=[[p1], [p2]], # Enforce differentiation w.r.t. params that haven't been used shared_params=shared_params, ) - for p in [*shared_params, p1, p2]: - assert (p.grad is not None) and (p.shape == p.grad.shape) + for p in shared_params: + assert_has_jac(p) + for p in [p1, p2]: + assert_has_grad(p) def test_partial_params(): """ - Tests that mtl_backward fills the right .grad values when only a subset of the parameters are - specified as inputs. + Tests that mtl_backward fills the right .jac/.grad values when only a subset of the parameters + are specified as inputs. """ p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -276,14 +277,13 @@ def test_partial_params(): mtl_backward( losses=[y1, y2], features=[f1, f2], - aggregator=Mean(), tasks_params=[[p1], []], shared_params=[p0], ) - assert (p0.grad is not None) and (p0.shape == p0.grad.shape) - assert (p1.grad is not None) and (p1.shape == p1.grad.shape) - assert p2.grad is None + assert_has_jac(p0) + assert_has_grad(p1) + assert_has_no_grad(p2) def test_empty_features_fails(): @@ -299,7 +299,7 @@ def test_empty_features_fails(): y2 = f1 * p2[0] + f2 * p2[1] with raises(ValueError): - mtl_backward(losses=[y1, y2], features=[], aggregator=UPGrad()) + mtl_backward(losses=[y1, y2], features=[]) @mark.parametrize( @@ -323,10 +323,11 @@ def test_various_single_features(shape: tuple[int, ...]): y1 = (f * p1[0]).sum() + (f * p1[1]).sum() y2 = (f * p2[0]).sum() * (f * p2[1]).sum() - mtl_backward(losses=[y1, y2], features=f, aggregator=UPGrad()) + mtl_backward(losses=[y1, y2], features=f) - for p in [p0, p1, p2]: - assert (p.grad is not None) and (p.shape == p.grad.shape) + assert_has_jac(p0) + for p in [p1, p2]: + assert_has_grad(p) @mark.parametrize( @@ -354,10 +355,11 @@ def test_various_feature_lists(shapes: list[tuple[int]]): y1 = sum([(f * p).sum() for f, p in zip(features, p1)]) y2 = (features[0] * p2).sum() - mtl_backward(losses=[y1, y2], features=features, aggregator=UPGrad()) + mtl_backward(losses=[y1, y2], features=features) - for p in [p0, p1, p2]: - assert (p.grad is not None) and (p.shape == p.grad.shape) + assert_has_jac(p0) + for p in [p1, p2]: + assert_has_grad(p) def test_non_scalar_loss_fails(): @@ -373,7 +375,7 @@ def test_non_scalar_loss_fails(): y2 = f1 * p2[0] + f2 * p2[1] with raises(ValueError): - mtl_backward(losses=[y1, y2], features=[f1, f2], aggregator=UPGrad()) + mtl_backward(losses=[y1, y2], features=[f1, f2]) @mark.parametrize("chunk_size", [None, 1, 2, 4]) @@ -392,12 +394,12 @@ def test_various_valid_chunk_sizes(chunk_size): mtl_backward( losses=[y1, y2], features=[f1, f2], - aggregator=UPGrad(), parallel_chunk_size=chunk_size, ) - for p in [p0, p1, p2]: - assert (p.grad is not None) and (p.shape == p.grad.shape) + assert_has_jac(p0) + for p in [p1, p2]: + assert_has_grad(p) @mark.parametrize("chunk_size", [0, -1]) @@ -417,15 +419,14 @@ def test_non_positive_chunk_size_fails(chunk_size: int): mtl_backward( losses=[y1, y2], features=[f1, f2], - aggregator=UPGrad(), parallel_chunk_size=chunk_size, ) def test_shared_param_retaining_grad_fails(): """ - Tests that mtl_backward raises an error when some shared param in the computation graph of the - ``features`` parameter retains grad and vmap has to be used. + Tests that mtl_backward fails to fill a valid `.grad` when some shared param in the computation + graph of the ``features`` parameter retains grad and vmap has to be used. """ p0 = tensor_(1.0, requires_grad=True) @@ -438,14 +439,17 @@ def test_shared_param_retaining_grad_fails(): y1 = p1 * f y2 = p2 * f + # mtl_backward itself doesn't raise the error, but it fills a.grad with a BatchedTensor + mtl_backward( + losses=[y1, y2], + features=[f], + tasks_params=[[p1], [p2]], + shared_params=[a, p0], + ) + with raises(RuntimeError): - mtl_backward( - losses=[y1, y2], - features=[f], - aggregator=UPGrad(), - tasks_params=[[p1], [p2]], - shared_params=[a, p0], - ) + # Using such a BatchedTensor should result in an error + _ = -a.grad def test_shared_activation_retaining_grad_fails(): @@ -468,7 +472,6 @@ def test_shared_activation_retaining_grad_fails(): mtl_backward( losses=[y1, y2], features=[f], - aggregator=UPGrad(), tasks_params=[[p1], [p2]], shared_params=[p0], ) @@ -490,15 +493,14 @@ def test_tasks_params_overlap(): y1 = f * p1 * p12 y2 = f * p2 * p12 - aggregator = UPGrad() - mtl_backward(losses=[y1, y2], features=[f], aggregator=aggregator) + mtl_backward(losses=[y1, y2], features=[f]) - assert_close(p2.grad, f * p12) - assert_close(p1.grad, f * p12) - assert_close(p12.grad, f * p1 + f * p2) + assert_grad_close(p2, f * p12) + assert_grad_close(p1, f * p12) + assert_grad_close(p12, f * p1 + f * p2) J = tensor_([[-8.0, 8.0], [-12.0, 12.0]]) - assert_close(p0.grad, aggregator(J)) + assert_jac_close(p0, J) def test_tasks_params_are_the_same(): @@ -511,13 +513,12 @@ def test_tasks_params_are_the_same(): y1 = f * p1 y2 = f + p1 - aggregator = UPGrad() - mtl_backward(losses=[y1, y2], features=[f], aggregator=aggregator) + mtl_backward(losses=[y1, y2], features=[f]) - assert_close(p1.grad, f + 1) + assert_grad_close(p1, f + 1) J = tensor_([[-2.0, 2.0], [-1.0, 1.0]]) - assert_close(p0.grad, aggregator(J)) + assert_jac_close(p0, J) def test_task_params_is_subset_of_other_task_params(): @@ -534,14 +535,13 @@ def test_task_params_is_subset_of_other_task_params(): y1 = f * p1 y2 = y1 * p2 - aggregator = UPGrad() - mtl_backward(losses=[y1, y2], features=[f], aggregator=aggregator, retain_graph=True) + mtl_backward(losses=[y1, y2], features=[f], retain_graph=True) - assert_close(p2.grad, y1) - assert_close(p1.grad, p2 * f + f) + assert_grad_close(p2, y1) + assert_grad_close(p1, p2 * f + f) J = tensor_([[-2.0, 2.0], [-6.0, 6.0]]) - assert_close(p0.grad, aggregator(J)) + assert_jac_close(p0, J) def test_shared_params_overlapping_with_tasks_params_fails(): @@ -562,7 +562,6 @@ def test_shared_params_overlapping_with_tasks_params_fails(): mtl_backward( losses=[y1, y2], features=[f], - aggregator=UPGrad(), tasks_params=[[p1], [p0, p2]], # Problem: p0 is also shared shared_params=[p0], ) @@ -586,7 +585,6 @@ def test_default_shared_params_overlapping_with_default_tasks_params_fails(): mtl_backward( losses=[y1, y2], features=[f], - aggregator=UPGrad(), ) @@ -610,7 +608,7 @@ def test_repeated_losses(): with raises(ValueError): losses = [y1, y1, y2] - mtl_backward(losses=losses, features=[f1, f2], aggregator=Sum(), retain_graph=True) + mtl_backward(losses=losses, features=[f1, f2], retain_graph=True) def test_repeated_features(): @@ -633,7 +631,7 @@ def test_repeated_features(): with raises(ValueError): features = [f1, f1, f2] - mtl_backward(losses=[y1, y2], features=features, aggregator=Sum()) + mtl_backward(losses=[y1, y2], features=features) def test_repeated_shared_params(): @@ -648,20 +646,20 @@ def test_repeated_shared_params(): p2 = tensor_([3.0, 4.0], requires_grad=True) f1 = tensor_([-1.0, 1.0]) @ p0 - f2 = (p0**2).sum() + p0.norm() + f2 = (p0**2).sum() y1 = f1 * p1[0] + f2 * p1[1] y2 = f1 * p2[0] + f2 * p2[1] - expected_grad_wrt_p0 = grad([y1, y2], [p0], retain_graph=True)[0] - expected_grad_wrt_p1 = grad([y1], [p1], retain_graph=True)[0] - expected_grad_wrt_p2 = grad([y2], [p2], retain_graph=True)[0] + J0 = tensor_([[3.0, 9.0], [5.0, 19.0]]) + g1 = grad([y1], [p1], retain_graph=True)[0] + g2 = grad([y2], [p2], retain_graph=True)[0] shared_params = [p0, p0] - mtl_backward(losses=[y1, y2], features=[f1, f2], aggregator=Sum(), shared_params=shared_params) + mtl_backward(losses=[y1, y2], features=[f1, f2], shared_params=shared_params) - assert_close(p0.grad, expected_grad_wrt_p0) - assert_close(p1.grad, expected_grad_wrt_p1) - assert_close(p2.grad, expected_grad_wrt_p2) + assert_jac_close(p0, J0) + assert_grad_close(p1, g1) + assert_grad_close(p2, g2) def test_repeated_task_params(): @@ -676,17 +674,17 @@ def test_repeated_task_params(): p2 = tensor_([3.0, 4.0], requires_grad=True) f1 = tensor_([-1.0, 1.0]) @ p0 - f2 = (p0**2).sum() + p0.norm() + f2 = (p0**2).sum() y1 = f1 * p1[0] + f2 * p1[1] y2 = f1 * p2[0] + f2 * p2[1] - expected_grad_wrt_p0 = grad([y1, y2], [p0], retain_graph=True)[0] - expected_grad_wrt_p1 = grad([y1], [p1], retain_graph=True)[0] - expected_grad_wrt_p2 = grad([y2], [p2], retain_graph=True)[0] + J0 = tensor_([[3.0, 9.0], [5.0, 19.0]]) + g1 = grad([y1], [p1], retain_graph=True)[0] + g2 = grad([y2], [p2], retain_graph=True)[0] tasks_params = [[p1, p1], [p2]] - mtl_backward(losses=[y1, y2], features=[f1, f2], aggregator=Sum(), tasks_params=tasks_params) + mtl_backward(losses=[y1, y2], features=[f1, f2], tasks_params=tasks_params) - assert_close(p0.grad, expected_grad_wrt_p0) - assert_close(p1.grad, expected_grad_wrt_p1) - assert_close(p2.grad, expected_grad_wrt_p2) + assert_jac_close(p0, J0) + assert_grad_close(p1, g1) + assert_grad_close(p2, g2) From 7e95c374f50ff32bd699cd7a939c2dc5e811aa71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 9 Jan 2026 20:24:07 +0100 Subject: [PATCH 10/35] Simplify a test --- tests/unit/autojac/test_backward.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 30b9ef6c..3bd50e81 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -121,27 +121,35 @@ def test_empty_tensors_fails(): def test_multiple_tensors(): """ Tests that giving multiple tensors to backward is equivalent to giving a single tensor - containing the all the values of the original tensors. + containing all the values of the original tensors. """ + J1 = tensor_([[-1.0, 1.0], [2.0, 4.0]]) + J2 = tensor_([[1.0, 1.0], [0.6, 0.8]]) + + # First computation graph: multiple tensors a1 = tensor_([1.0, 2.0], requires_grad=True) a2 = tensor_([3.0, 4.0], requires_grad=True) - inputs = [a1, a2] y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + a2.norm() - # TODO: improve that - backward([y1, y2], retain_graph=True) + backward([y1, y2]) + + assert_jac_close(a1, J1) + assert_jac_close(a2, J2) + + # Second computation graph: single concatenated tensor + b1 = tensor_([1.0, 2.0], requires_grad=True) + b2 = tensor_([3.0, 4.0], requires_grad=True) - input_to_jac = {a: a.jac for a in inputs} - for a in inputs: - del a.jac + z1 = tensor_([-1.0, 1.0]) @ b1 + b2.sum() + z2 = (b1**2).sum() + b2.norm() - backward(torch.cat([y1.reshape(-1), y2.reshape(-1)])) + backward(torch.cat([z1.reshape(-1), z2.reshape(-1)])) - for a in inputs: - assert_jac_close(a, input_to_jac[a]) + assert_jac_close(b1, J1) + assert_jac_close(b2, J2) @mark.parametrize("chunk_size", [None, 1, 2, 4]) From 18763513cfa546e25e76db0bd15aafea23ec496a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 12 Jan 2026 16:19:42 +0100 Subject: [PATCH 11/35] Add unit tests for AccumulateJac Rename existing AccumulateGrad tests to make naming explicit and add corresponding tests for AccumulateJac, including a shape mismatch test. Co-Authored-By: Claude Opus 4.5 --- .../autojac/_transform/test_accumulate.py | 130 +++++++++++++++++- 1 file changed, 123 insertions(+), 7 deletions(-) diff --git a/tests/unit/autojac/_transform/test_accumulate.py b/tests/unit/autojac/_transform/test_accumulate.py index d1166f58..4bea1c84 100644 --- a/tests/unit/autojac/_transform/test_accumulate.py +++ b/tests/unit/autojac/_transform/test_accumulate.py @@ -2,10 +2,10 @@ from utils.dict_assertions import assert_tensor_dicts_are_close from utils.tensors import ones_, tensor_, zeros_ -from torchjd.autojac._transform import AccumulateGrad +from torchjd.autojac._transform import AccumulateGrad, AccumulateJac -def test_single_accumulation(): +def test_single_grad_accumulation(): """ Tests that the AccumulateGrad transform correctly accumulates gradients in .grad fields when run once. @@ -33,7 +33,7 @@ def test_single_accumulation(): @mark.parametrize("iterations", [1, 2, 4, 10, 13]) -def test_multiple_accumulation(iterations: int): +def test_multiple_grad_accumulations(iterations: int): """ Tests that the AccumulateGrad transform correctly accumulates gradients in .grad fields when run `iterations` times. @@ -63,7 +63,7 @@ def test_multiple_accumulation(iterations: int): assert_tensor_dicts_are_close(grads, expected_grads) -def test_no_requires_grad_fails(): +def test_accumulate_grad_fails_when_no_requires_grad(): """ Tests that the AccumulateGrad transform raises an error when it tries to populate a .grad of a tensor that does not require grad. @@ -79,7 +79,7 @@ def test_no_requires_grad_fails(): accumulate(input) -def test_no_leaf_and_no_retains_grad_fails(): +def test_accumulate_grad_fails_when_no_leaf_and_no_retains_grad(): """ Tests that the AccumulateGrad transform raises an error when it tries to populate a .grad of a tensor that is not a leaf and that does not retain grad. @@ -95,11 +95,127 @@ def test_no_leaf_and_no_retains_grad_fails(): accumulate(input) -def test_check_keys(): - """Tests that the `check_keys` method works correctly.""" +def test_accumulate_grad_check_keys(): + """Tests that the `check_keys` method works correctly for AccumulateGrad.""" key = tensor_([1.0], requires_grad=True) accumulate = AccumulateGrad() output_keys = accumulate.check_keys({key}) assert output_keys == set() + + +def test_single_jac_accumulation(): + """ + Tests that the AccumulateJac transform correctly accumulates jacobians in .jac fields when run + once. + """ + + key1 = zeros_([], requires_grad=True) + key2 = zeros_([1], requires_grad=True) + key3 = zeros_([2, 3], requires_grad=True) + value1 = ones_([4]) + value2 = ones_([4, 1]) + value3 = ones_([4, 2, 3]) + input = {key1: value1, key2: value2, key3: value3} + + accumulate = AccumulateJac() + + output = accumulate(input) + expected_output = {} + + assert_tensor_dicts_are_close(output, expected_output) + + jacs = {key1: key1.jac, key2: key2.jac, key3: key3.jac} + expected_jacs = {key1: value1, key2: value2, key3: value3} + + assert_tensor_dicts_are_close(jacs, expected_jacs) + + +@mark.parametrize("iterations", [1, 2, 4, 10, 13]) +def test_multiple_jac_accumulations(iterations: int): + """ + Tests that the AccumulateJac transform correctly accumulates jacobians in .jac fields when run + `iterations` times. + """ + + key1 = zeros_([], requires_grad=True) + key2 = zeros_([1], requires_grad=True) + key3 = zeros_([2, 3], requires_grad=True) + value1 = ones_([4]) + value2 = ones_([4, 1]) + value3 = ones_([4, 2, 3]) + + accumulate = AccumulateJac() + + for i in range(iterations): + # Clone values to ensure that we accumulate values that are not ever used afterwards + input = {key1: value1.clone(), key2: value2.clone(), key3: value3.clone()} + accumulate(input) + + jacs = {key1: key1.jac, key2: key2.jac, key3: key3.jac} + expected_jacs = { + key1: iterations * value1, + key2: iterations * value2, + key3: iterations * value3, + } + + assert_tensor_dicts_are_close(jacs, expected_jacs) + + +def test_accumulate_jac_fails_when_no_requires_grad(): + """ + Tests that the AccumulateJac transform raises an error when it tries to populate a .jac of a + tensor that does not require grad. + """ + + key = zeros_([1], requires_grad=False) + value = ones_([4, 1]) + input = {key: value} + + accumulate = AccumulateJac() + + with raises(ValueError): + accumulate(input) + + +def test_accumulate_jac_fails_when_no_leaf_and_no_retains_grad(): + """ + Tests that the AccumulateJac transform raises an error when it tries to populate a .jac of a + tensor that is not a leaf and that does not retain grad. + """ + + key = tensor_([1.0], requires_grad=True) * 2 + value = ones_([4, 1]) + input = {key: value} + + accumulate = AccumulateJac() + + with raises(ValueError): + accumulate(input) + + +def test_accumulate_jac_fails_when_shape_mismatch(): + """ + Tests that the AccumulateJac transform raises an error when the jacobian shape does not match + the parameter shape (ignoring the first dimension). + """ + + key = zeros_([2, 3], requires_grad=True) + value = ones_([4, 3, 2]) # Wrong shape: should be [4, 2, 3], not [4, 3, 2] + input = {key: value} + + accumulate = AccumulateJac() + + with raises(RuntimeError): + accumulate(input) + + +def test_accumulate_jac_check_keys(): + """Tests that the `check_keys` method works correctly for AccumulateJac.""" + + key = tensor_([1.0], requires_grad=True) + accumulate = AccumulateJac() + + output_keys = accumulate.check_keys({key}) + assert output_keys == set() From 4f24e39ec1b7e25699c6e189afc3d89833dc5237 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 12 Jan 2026 16:32:33 +0100 Subject: [PATCH 12/35] Refactor accumulate tests to use loops and assert helpers --- .../autojac/_transform/test_accumulate.py | 86 ++++++------------- 1 file changed, 27 insertions(+), 59 deletions(-) diff --git a/tests/unit/autojac/_transform/test_accumulate.py b/tests/unit/autojac/_transform/test_accumulate.py index 4bea1c84..6dadf1ef 100644 --- a/tests/unit/autojac/_transform/test_accumulate.py +++ b/tests/unit/autojac/_transform/test_accumulate.py @@ -1,4 +1,5 @@ from pytest import mark, raises +from unit.autojac._asserts import assert_grad_close, assert_jac_close from utils.dict_assertions import assert_tensor_dicts_are_close from utils.tensors import ones_, tensor_, zeros_ @@ -11,25 +12,18 @@ def test_single_grad_accumulation(): once. """ - key1 = zeros_([], requires_grad=True) - key2 = zeros_([1], requires_grad=True) - key3 = zeros_([2, 3], requires_grad=True) - value1 = ones_([]) - value2 = ones_([1]) - value3 = ones_([2, 3]) - input = {key1: value1, key2: value2, key3: value3} + shapes = [[], [1], [2, 3]] + keys = [zeros_(shape, requires_grad=True) for shape in shapes] + values = [ones_(shape) for shape in shapes] + input = dict(zip(keys, values)) accumulate = AccumulateGrad() output = accumulate(input) - expected_output = {} + assert_tensor_dicts_are_close(output, {}) - assert_tensor_dicts_are_close(output, expected_output) - - grads = {key1: key1.grad, key2: key2.grad, key3: key3.grad} - expected_grads = {key1: value1, key2: value2, key3: value3} - - assert_tensor_dicts_are_close(grads, expected_grads) + for key, value in zip(keys, values): + assert_grad_close(key, value) @mark.parametrize("iterations", [1, 2, 4, 10, 13]) @@ -39,28 +33,18 @@ def test_multiple_grad_accumulations(iterations: int): `iterations` times. """ - key1 = zeros_([], requires_grad=True) - key2 = zeros_([1], requires_grad=True) - key3 = zeros_([2, 3], requires_grad=True) - value1 = ones_([]) - value2 = ones_([1]) - value3 = ones_([2, 3]) - + shapes = [[], [1], [2, 3]] + keys = [zeros_(shape, requires_grad=True) for shape in shapes] + values = [ones_(shape) for shape in shapes] accumulate = AccumulateGrad() for i in range(iterations): # Clone values to ensure that we accumulate values that are not ever used afterwards - input = {key1: value1.clone(), key2: value2.clone(), key3: value3.clone()} + input = {key: value.clone() for key, value in zip(keys, values)} accumulate(input) - grads = {key1: key1.grad, key2: key2.grad, key3: key3.grad} - expected_grads = { - key1: iterations * value1, - key2: iterations * value2, - key3: iterations * value3, - } - - assert_tensor_dicts_are_close(grads, expected_grads) + for key, value in zip(keys, values): + assert_grad_close(key, iterations * value) def test_accumulate_grad_fails_when_no_requires_grad(): @@ -111,25 +95,18 @@ def test_single_jac_accumulation(): once. """ - key1 = zeros_([], requires_grad=True) - key2 = zeros_([1], requires_grad=True) - key3 = zeros_([2, 3], requires_grad=True) - value1 = ones_([4]) - value2 = ones_([4, 1]) - value3 = ones_([4, 2, 3]) - input = {key1: value1, key2: value2, key3: value3} + shapes = [[], [1], [2, 3]] + keys = [zeros_(shape, requires_grad=True) for shape in shapes] + values = [ones_([4] + shape) for shape in shapes] + input = dict(zip(keys, values)) accumulate = AccumulateJac() output = accumulate(input) - expected_output = {} + assert_tensor_dicts_are_close(output, {}) - assert_tensor_dicts_are_close(output, expected_output) - - jacs = {key1: key1.jac, key2: key2.jac, key3: key3.jac} - expected_jacs = {key1: value1, key2: value2, key3: value3} - - assert_tensor_dicts_are_close(jacs, expected_jacs) + for key, value in zip(keys, values): + assert_jac_close(key, value) @mark.parametrize("iterations", [1, 2, 4, 10, 13]) @@ -139,28 +116,19 @@ def test_multiple_jac_accumulations(iterations: int): `iterations` times. """ - key1 = zeros_([], requires_grad=True) - key2 = zeros_([1], requires_grad=True) - key3 = zeros_([2, 3], requires_grad=True) - value1 = ones_([4]) - value2 = ones_([4, 1]) - value3 = ones_([4, 2, 3]) + shapes = [[], [1], [2, 3]] + keys = [zeros_(shape, requires_grad=True) for shape in shapes] + values = [ones_([4] + shape) for shape in shapes] accumulate = AccumulateJac() for i in range(iterations): # Clone values to ensure that we accumulate values that are not ever used afterwards - input = {key1: value1.clone(), key2: value2.clone(), key3: value3.clone()} + input = {key: value.clone() for key, value in zip(keys, values)} accumulate(input) - jacs = {key1: key1.jac, key2: key2.jac, key3: key3.jac} - expected_jacs = { - key1: iterations * value1, - key2: iterations * value2, - key3: iterations * value3, - } - - assert_tensor_dicts_are_close(jacs, expected_jacs) + for key, value in zip(keys, values): + assert_jac_close(key, iterations * value) def test_accumulate_jac_fails_when_no_requires_grad(): From b1aaee9138cfeea1162eecfa715bbb00b41bd5da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 12 Jan 2026 18:17:27 +0100 Subject: [PATCH 13/35] Move newly added functions --- CHANGELOG.md | 2 +- docs/source/examples/amp.rst | 5 ++--- docs/source/examples/basic_usage.rst | 3 +-- docs/source/examples/iwrm.rst | 9 +++------ .../source/examples/lightning_integration.rst | 5 ++--- docs/source/examples/monitoring.rst | 5 ++--- docs/source/examples/mtl.rst | 5 ++--- docs/source/examples/rnn.rst | 5 ++--- src/torchjd/autojac/__init__.py | 3 ++- .../{utils => autojac}/_accumulation.py | 14 ++++++++++--- .../{utils => autojac}/_jac_to_grad.py | 6 +++--- src/torchjd/autojac/_transform/_accumulate.py | 7 +++---- src/torchjd/utils/__init__.py | 3 --- src/torchjd/utils/_tensor_with_jac.py | 11 ---------- tests/doc/test_rst.py | 20 +++++++------------ tests/unit/autojac/_asserts.py | 2 +- 16 files changed, 42 insertions(+), 63 deletions(-) rename src/torchjd/{utils => autojac}/_accumulation.py (87%) rename src/torchjd/{utils => autojac}/_jac_to_grad.py (94%) delete mode 100644 src/torchjd/utils/__init__.py delete mode 100644 src/torchjd/utils/_tensor_with_jac.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 20753457..e933c3ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ changelog does not include internal changes that do not affect the user. - **BREAKING**: Removed from `backward` and `mtl_backward` the responsibility to aggregate the Jacobian. Now, these functions compute and populate the `.jac` fields of the parameters, and a new - function `torchjd.utils.jac_to_grad` should then be called to aggregate those `.jac` fields into + function `torchjd.autojac.jac_to_grad` should then be called to aggregate those `.jac` fields into `.grad` fields. This means that users now have more control on what they do with the Jacobians (they can easily aggregate them group by group or even param by param if they want), but it now requires an extra diff --git a/docs/source/examples/amp.rst b/docs/source/examples/amp.rst index 469b2cf3..c89aa7bc 100644 --- a/docs/source/examples/amp.rst +++ b/docs/source/examples/amp.rst @@ -12,7 +12,7 @@ case, the losses) should preferably be scaled with a `GradScaler following example shows the resulting code for a multi-task learning use-case. .. code-block:: python - :emphasize-lines: 2, 18, 28, 35-36, 38-39 + :emphasize-lines: 2, 17, 27, 34-35, 37-38 import torch from torch.amp import GradScaler @@ -20,8 +20,7 @@ following example shows the resulting code for a multi-task learning use-case. from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import mtl_backward - from torchjd.utils import jac_to_grad + from torchjd.autojac import mtl_backward, jac_to_grad shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) task1_module = Linear(3, 1) diff --git a/docs/source/examples/basic_usage.rst b/docs/source/examples/basic_usage.rst index 0920ef1c..2662a822 100644 --- a/docs/source/examples/basic_usage.rst +++ b/docs/source/examples/basic_usage.rst @@ -19,8 +19,7 @@ Import several classes from ``torch`` and ``torchjd``: from torch.optim import SGD from torchjd import autojac - from torchjd.aggregation import UPGrad - from torchjd.utils import jac_to_grad + from torchjd.aggregation import UPGrad, jac_to_grad Define the model and the optimizer, as usual: diff --git a/docs/source/examples/iwrm.rst b/docs/source/examples/iwrm.rst index 18eec975..0a392726 100644 --- a/docs/source/examples/iwrm.rst +++ b/docs/source/examples/iwrm.rst @@ -50,7 +50,6 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac - X = torch.randn(8, 16, 10) Y = torch.randn(8, 16) @@ -78,15 +77,14 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac .. tab-item:: autojac .. code-block:: python - :emphasize-lines: 5-7, 13, 17, 22-24 + :emphasize-lines: 5-6, 12, 16, 21-23 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import backward - from torchjd.utils import jac_to_grad + from torchjd.autojac import backward, jac_to_grad X = torch.randn(8, 16, 10) Y = torch.randn(8, 16) @@ -115,7 +113,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac .. tab-item:: autogram (recommended) .. code-block:: python - :emphasize-lines: 5-6, 13, 17-18, 22-25 + :emphasize-lines: 5-6, 12, 16-17, 21-24 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -124,7 +122,6 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac from torchjd.aggregation import UPGradWeighting from torchjd.autogram import Engine - X = torch.randn(8, 16, 10) Y = torch.randn(8, 16) diff --git a/docs/source/examples/lightning_integration.rst b/docs/source/examples/lightning_integration.rst index a010361c..61449b97 100644 --- a/docs/source/examples/lightning_integration.rst +++ b/docs/source/examples/lightning_integration.rst @@ -11,7 +11,7 @@ The following code example demonstrates a basic multi-task learning setup using <../docs/autojac/mtl_backward>` at each training iteration. .. code-block:: python - :emphasize-lines: 9-11, 19, 32-33 + :emphasize-lines: 9-10, 18, 31-32 import torch from lightning import LightningModule, Trainer @@ -22,8 +22,7 @@ The following code example demonstrates a basic multi-task learning setup using from torch.utils.data import DataLoader, TensorDataset from torchjd.aggregation import UPGrad - from torchjd.autojac import mtl_backward - from torchjd.utils import jac_to_grad + from torchjd.autojac import mtl_backward, jac_to_grad class Model(LightningModule): def __init__(self): diff --git a/docs/source/examples/monitoring.rst b/docs/source/examples/monitoring.rst index 6297e62c..69cc0e1b 100644 --- a/docs/source/examples/monitoring.rst +++ b/docs/source/examples/monitoring.rst @@ -15,7 +15,7 @@ Jacobian descent is doing something different than gradient descent. With they have a negative inner product). .. code-block:: python - :emphasize-lines: 10-12, 14-19, 34-35 + :emphasize-lines: 9-11, 13-18, 33-34 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -23,8 +23,7 @@ they have a negative inner product). from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import mtl_backward - from torchjd.utils import jac_to_grad + from torchjd.autojac import mtl_backward, jac_to_grad def print_weights(_, __, weights: torch.Tensor) -> None: """Prints the extracted weights.""" diff --git a/docs/source/examples/mtl.rst b/docs/source/examples/mtl.rst index e739d8a2..ce74647b 100644 --- a/docs/source/examples/mtl.rst +++ b/docs/source/examples/mtl.rst @@ -19,15 +19,14 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks. .. code-block:: python - :emphasize-lines: 5-7, 20, 33-34 + :emphasize-lines: 5-6, 19, 32-33 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import mtl_backward - from torchjd.utils import jac_to_grad + from torchjd.autojac import mtl_backward, jac_to_grad shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) task1_module = Linear(3, 1) diff --git a/docs/source/examples/rnn.rst b/docs/source/examples/rnn.rst index 847b66d9..42eb2f91 100644 --- a/docs/source/examples/rnn.rst +++ b/docs/source/examples/rnn.rst @@ -6,15 +6,14 @@ element of the output sequences. If the gradients of these losses are likely to descent can be leveraged to enhance optimization. .. code-block:: python - :emphasize-lines: 5-7, 11, 18, 20-21 + :emphasize-lines: 5-6, 10, 17, 19-20 import torch from torch.nn import RNN from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import backward - from torchjd.utils import jac_to_grad + from torchjd.autojac import backward, jac_to_grad rnn = RNN(input_size=10, hidden_size=20, num_layers=2) optimizer = SGD(rnn.parameters(), lr=0.1) diff --git a/src/torchjd/autojac/__init__.py b/src/torchjd/autojac/__init__.py index 846c062c..ab99d98b 100644 --- a/src/torchjd/autojac/__init__.py +++ b/src/torchjd/autojac/__init__.py @@ -6,6 +6,7 @@ """ from ._backward import backward +from ._jac_to_grad import jac_to_grad from ._mtl_backward import mtl_backward -__all__ = ["backward", "mtl_backward"] +__all__ = ["backward", "jac_to_grad", "mtl_backward"] diff --git a/src/torchjd/utils/_accumulation.py b/src/torchjd/autojac/_accumulation.py similarity index 87% rename from src/torchjd/utils/_accumulation.py rename to src/torchjd/autojac/_accumulation.py index d77ca69f..d561aaea 100644 --- a/src/torchjd/utils/_accumulation.py +++ b/src/torchjd/autojac/_accumulation.py @@ -3,10 +3,18 @@ from torch import Tensor -from torchjd.utils._tensor_with_jac import TensorWithJac + +class TensorWithJac(Tensor): + """ + Tensor known to have a populated jac field. + + Should not be directly instantiated, but can be used as a type hint and can be casted to. + """ + + jac: Tensor -def _accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None: +def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None: for param, jac in zip(params, jacobians, strict=True): _check_expects_grad(param) # We that the shape is correct to be consistent with torch, that checks that the grad @@ -34,7 +42,7 @@ def _accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> N param.__setattr__("jac", jac) -def _accumulate_grads(params: Iterable[Tensor], gradients: Iterable[Tensor]) -> None: +def accumulate_grads(params: Iterable[Tensor], gradients: Iterable[Tensor]) -> None: for param, grad in zip(params, gradients, strict=True): _check_expects_grad(param) if hasattr(param, "grad") and param.grad is not None: diff --git a/src/torchjd/utils/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py similarity index 94% rename from src/torchjd/utils/_jac_to_grad.py rename to src/torchjd/autojac/_jac_to_grad.py index 3f1754d5..279818ee 100644 --- a/src/torchjd/utils/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -5,8 +5,8 @@ from torch import Tensor from torchjd.aggregation import Aggregator -from torchjd.utils._accumulation import _accumulate_grads -from torchjd.utils._tensor_with_jac import TensorWithJac + +from ._accumulation import TensorWithJac, accumulate_grads def jac_to_grad( @@ -43,7 +43,7 @@ def jac_to_grad( jacobian_matrix = _unite_jacobians(jacobians) gradient_vector = aggregator(jacobian_matrix) gradients = _disunite_gradient(gradient_vector, jacobians, params_) - _accumulate_grads(params_, gradients) + accumulate_grads(params_, gradients) if not retain_jacs: _free_jacs(params_) diff --git a/src/torchjd/autojac/_transform/_accumulate.py b/src/torchjd/autojac/_transform/_accumulate.py index cfd28b4f..082ef1df 100644 --- a/src/torchjd/autojac/_transform/_accumulate.py +++ b/src/torchjd/autojac/_transform/_accumulate.py @@ -1,7 +1,6 @@ from torch import Tensor -from torchjd.utils._accumulation import _accumulate_grads, _accumulate_jacs - +from .._accumulation import accumulate_grads, accumulate_jacs from ._base import TensorDict, Transform @@ -15,7 +14,7 @@ class AccumulateGrad(Transform): """ def __call__(self, gradients: TensorDict) -> TensorDict: - _accumulate_grads(gradients.keys(), gradients.values()) + accumulate_grads(gradients.keys(), gradients.values()) return {} def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: @@ -32,7 +31,7 @@ class AccumulateJac(Transform): """ def __call__(self, jacobians: TensorDict) -> TensorDict: - _accumulate_jacs(jacobians.keys(), jacobians.values()) + accumulate_jacs(jacobians.keys(), jacobians.values()) return {} def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: diff --git a/src/torchjd/utils/__init__.py b/src/torchjd/utils/__init__.py deleted file mode 100644 index 158fcece..00000000 --- a/src/torchjd/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from ._jac_to_grad import jac_to_grad - -__all__ = ["jac_to_grad"] diff --git a/src/torchjd/utils/_tensor_with_jac.py b/src/torchjd/utils/_tensor_with_jac.py deleted file mode 100644 index 86af8222..00000000 --- a/src/torchjd/utils/_tensor_with_jac.py +++ /dev/null @@ -1,11 +0,0 @@ -from torch import Tensor - - -class TensorWithJac(Tensor): - """ - Tensor known to have a populated jac field. - - Should not be directly instantiated, but can be used as a type hint and can be casted to. - """ - - jac: Tensor diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index bb4e60bb..1fac0308 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -14,8 +14,7 @@ def test_amp(): from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import mtl_backward - from torchjd.utils import jac_to_grad + from torchjd.autojac import jac_to_grad, mtl_backward shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) task1_module = Linear(3, 1) @@ -58,7 +57,7 @@ def test_basic_usage(): from torchjd import autojac from torchjd.aggregation import UPGrad - from torchjd.utils import jac_to_grad + from torchjd.autojac import jac_to_grad model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2)) optimizer = SGD(model.parameters(), lr=0.1) @@ -155,8 +154,7 @@ def test_autojac(): from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import backward - from torchjd.utils import jac_to_grad + from torchjd.autojac import backward, jac_to_grad X = torch.randn(8, 16, 10) Y = torch.randn(8, 16) @@ -229,8 +227,7 @@ def test_lightning_integration(): from torch.utils.data import DataLoader, TensorDataset from torchjd.aggregation import UPGrad - from torchjd.autojac import mtl_backward - from torchjd.utils import jac_to_grad + from torchjd.autojac import jac_to_grad, mtl_backward class Model(LightningModule): def __init__(self): @@ -287,8 +284,7 @@ def test_monitoring(): from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import mtl_backward - from torchjd.utils import jac_to_grad + from torchjd.autojac import jac_to_grad, mtl_backward def print_weights(_, __, weights: torch.Tensor) -> None: """Prints the extracted weights.""" @@ -340,8 +336,7 @@ def test_mtl(): from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import mtl_backward - from torchjd.utils import jac_to_grad + from torchjd.autojac import jac_to_grad, mtl_backward shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) task1_module = Linear(3, 1) @@ -413,8 +408,7 @@ def test_rnn(): from torch.optim import SGD from torchjd.aggregation import UPGrad - from torchjd.autojac import backward - from torchjd.utils import jac_to_grad + from torchjd.autojac import backward, jac_to_grad rnn = RNN(input_size=10, hidden_size=20, num_layers=2) optimizer = SGD(rnn.parameters(), lr=0.1) diff --git a/tests/unit/autojac/_asserts.py b/tests/unit/autojac/_asserts.py index 3221b44b..742998d8 100644 --- a/tests/unit/autojac/_asserts.py +++ b/tests/unit/autojac/_asserts.py @@ -3,7 +3,7 @@ import torch from torch.testing import assert_close -from torchjd.utils._tensor_with_jac import TensorWithJac +from torchjd.autojac._accumulation import TensorWithJac def assert_has_jac(t: torch.Tensor) -> None: From ce9231b827bb7b93e74ea5e8c0d6b4e1c9466e9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 12 Jan 2026 18:27:59 +0100 Subject: [PATCH 14/35] Rename retain_jacs to retain_jac --- src/torchjd/autojac/_jac_to_grad.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 279818ee..94e6296c 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -9,9 +9,7 @@ from ._accumulation import TensorWithJac, accumulate_grads -def jac_to_grad( - params: Iterable[Tensor], aggregator: Aggregator, retain_jacs: bool = False -) -> None: +def jac_to_grad(params: Iterable[Tensor], aggregator: Aggregator, retain_jac: bool = False) -> None: """ Aggregates the Jacobians stored in the ``.jac`` fields of ``params`` and accumulates the result into their ``.grad`` fields. @@ -19,7 +17,7 @@ def jac_to_grad( :param params: The parameters whose ``.jac`` fields should be aggregated. All Jacobians must have the same first dimension (number of outputs). :param aggregator: The aggregator used to reduce the Jacobians into gradients. - :param retain_jacs: Whether to preserve the ``.jac`` fields of the parameters. + :param retain_jac: Whether to preserve the ``.jac`` fields of the parameters. """ params_ = list[TensorWithJac]() @@ -45,7 +43,7 @@ def jac_to_grad( gradients = _disunite_gradient(gradient_vector, jacobians, params_) accumulate_grads(params_, gradients) - if not retain_jacs: + if not retain_jac: _free_jacs(params_) From c3537139ba5c6606571235a1549d2c00a74a5982 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 12 Jan 2026 18:31:09 +0100 Subject: [PATCH 15/35] Rename params to tensors in jac_to_grad --- src/torchjd/autojac/_jac_to_grad.py | 44 +++++++++++++++-------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 94e6296c..2b2763ee 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -9,42 +9,44 @@ from ._accumulation import TensorWithJac, accumulate_grads -def jac_to_grad(params: Iterable[Tensor], aggregator: Aggregator, retain_jac: bool = False) -> None: +def jac_to_grad( + tensors: Iterable[Tensor], aggregator: Aggregator, retain_jac: bool = False +) -> None: """ - Aggregates the Jacobians stored in the ``.jac`` fields of ``params`` and accumulates the result + Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result into their ``.grad`` fields. - :param params: The parameters whose ``.jac`` fields should be aggregated. All Jacobians must + :param tensors: The tensors whose ``.jac`` fields should be aggregated. All Jacobians must have the same first dimension (number of outputs). :param aggregator: The aggregator used to reduce the Jacobians into gradients. - :param retain_jac: Whether to preserve the ``.jac`` fields of the parameters. + :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors. """ - params_ = list[TensorWithJac]() - for p in params: - if not hasattr(p, "jac"): + tensors_ = list[TensorWithJac]() + for t in tensors: + if not hasattr(t, "jac"): raise ValueError( "Some `jac` fields were not populated. Did you use `autojac.backward` before" "calling `jac_to_grad`?" ) - p_ = cast(TensorWithJac, p) - params_.append(p_) + t_ = cast(TensorWithJac, t) + tensors_.append(t_) - if len(params_) == 0: + if len(tensors_) == 0: return - jacobians = [p.jac for p in params_] + jacobians = [t.jac for t in tensors_] if not all([jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]]): raise ValueError("All Jacobians should have the same number of rows.") jacobian_matrix = _unite_jacobians(jacobians) gradient_vector = aggregator(jacobian_matrix) - gradients = _disunite_gradient(gradient_vector, jacobians, params_) - accumulate_grads(params_, gradients) + gradients = _disunite_gradient(gradient_vector, jacobians, tensors_) + accumulate_grads(tensors_, gradients) if not retain_jac: - _free_jacs(params_) + _free_jacs(tensors_) def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: @@ -54,7 +56,7 @@ def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: def _disunite_gradient( - gradient_vector: Tensor, jacobians: list[Tensor], params: list[TensorWithJac] + gradient_vector: Tensor, jacobians: list[Tensor], tensors: list[TensorWithJac] ) -> list[Tensor]: gradient_vectors = [] start = 0 @@ -63,16 +65,16 @@ def _disunite_gradient( current_gradient_vector = gradient_vector[start:end] gradient_vectors.append(current_gradient_vector) start = end - gradients = [g.view(param.shape) for param, g in zip(params, gradient_vectors, strict=True)] + gradients = [g.view(t.shape) for t, g in zip(tensors, gradient_vectors, strict=True)] return gradients -def _free_jacs(params: Iterable[TensorWithJac]) -> None: +def _free_jacs(tensors: Iterable[TensorWithJac]) -> None: """ - Deletes the ``.jac`` field of the provided parameters. + Deletes the ``.jac`` field of the provided tensors. - :param params: The parameters whose ``.jac`` fields should be cleared. + :param tensors: The tensors whose ``.jac`` fields should be cleared. """ - for p in params: - del p.jac + for t in tensors: + del t.jac From 5cf8c1c3cdd52e99010879a78ff3dd3bd045b6b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 12 Jan 2026 19:06:37 +0100 Subject: [PATCH 16/35] Ad jac_to_grad tests --- tests/unit/autojac/test_jac_to_grad.py | 103 +++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 tests/unit/autojac/test_jac_to_grad.py diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py new file mode 100644 index 00000000..7d16247c --- /dev/null +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -0,0 +1,103 @@ +from pytest import mark, raises +from unit.autojac._asserts import assert_grad_close, assert_has_jac, assert_has_no_jac +from utils.tensors import tensor_ + +from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad +from torchjd.autojac._jac_to_grad import jac_to_grad + + +@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()]) +def test_various_aggregators(aggregator: Aggregator): + """Tests that jac_to_grad works for various aggregators.""" + + t1 = tensor_(1.0, requires_grad=True) + t2 = tensor_([2.0, 3.0], requires_grad=True) + jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]) + t1.__setattr__("jac", jac[:, 0]) + t2.__setattr__("jac", jac[:, 1:]) + expected_grad = aggregator(jac) + g1 = expected_grad[0] + g2 = expected_grad[1:] + + jac_to_grad([t1, t2], aggregator) + + assert_grad_close(t1, g1) + assert_grad_close(t2, g2) + + +def test_single_tensor(): + """Tests that jac_to_grad works when a single tensor is provided.""" + + aggregator = UPGrad() + t = tensor_([2.0, 3.0, 4.0], requires_grad=True) + jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]) + t.__setattr__("jac", jac) + g = aggregator(jac) + + jac_to_grad([t], aggregator) + + assert_grad_close(t, g) + + +def test_no_jac_field(): + """Tests that jac_to_grad fails when a tensor does not have a jac field.""" + + aggregator = UPGrad() + t1 = tensor_(1.0, requires_grad=True) + t2 = tensor_([2.0, 3.0], requires_grad=True) + jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]) + t2.__setattr__("jac", jac[:, 1:]) + + with raises(ValueError): + jac_to_grad([t1, t2], aggregator) + + +def test_no_requires_grad(): + """Tests that jac_to_grad fails when a tensor does not require grad.""" + + aggregator = UPGrad() + t1 = tensor_(1.0, requires_grad=True) + t2 = tensor_([2.0, 3.0], requires_grad=False) + jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]) + t1.__setattr__("jac", jac[:, 0]) + t2.__setattr__("jac", jac[:, 1:]) + + with raises(ValueError): + jac_to_grad([t1, t2], aggregator) + + +def test_row_mismatch(): + """Tests that jac_to_grad fails when the number of rows of the .jac is not constant.""" + + aggregator = UPGrad() + t1 = tensor_(1.0, requires_grad=True) + t2 = tensor_([2.0, 3.0], requires_grad=True) + t1.__setattr__("jac", tensor_([5.0, 6.0, 7.0])) # 3 rows + t2.__setattr__("jac", tensor_([[1.0, 2.0], [3.0, 4.0]])) # 2 rows + + with raises(ValueError): + jac_to_grad([t1, t2], aggregator) + + +def test_no_tensors(): + """Tests that jac_to_grad correctly does nothing when an empty list of tensors is provided.""" + + jac_to_grad([], aggregator=UPGrad()) + + +@mark.parametrize("retain_jac", [True, False]) +def test_jacs_are_freed(retain_jac: bool): + """Tests that jac_to_grad frees the jac fields if an only if retain_jac is False.""" + + aggregator = UPGrad() + t1 = tensor_(1.0, requires_grad=True) + t2 = tensor_([2.0, 3.0], requires_grad=True) + jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]) + t1.__setattr__("jac", jac[:, 0]) + t2.__setattr__("jac", jac[:, 1:]) + + jac_to_grad([t1, t2], aggregator, retain_jac=retain_jac) + + check = assert_has_jac if retain_jac else assert_has_no_jac + check(t1) + check(t2) From e93f2d96e331903048e6ba5ebd446a96108724af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 12 Jan 2026 19:13:45 +0100 Subject: [PATCH 17/35] Remove duplicated optimizer.zero_grad() lines --- docs/source/examples/amp.rst | 1 - docs/source/examples/iwrm.rst | 3 --- tests/doc/test_rst.py | 5 ----- 3 files changed, 9 deletions(-) diff --git a/docs/source/examples/amp.rst b/docs/source/examples/amp.rst index c89aa7bc..97431667 100644 --- a/docs/source/examples/amp.rst +++ b/docs/source/examples/amp.rst @@ -53,7 +53,6 @@ following example shows the resulting code for a multi-task learning use-case. scaler.step(optimizer) scaler.update() optimizer.zero_grad() - optimizer.zero_grad() .. hint:: Within the ``torch.autocast`` context, some operations may be done in ``float16`` type. For diff --git a/docs/source/examples/iwrm.rst b/docs/source/examples/iwrm.rst index 0a392726..ebc2bde5 100644 --- a/docs/source/examples/iwrm.rst +++ b/docs/source/examples/iwrm.rst @@ -69,7 +69,6 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac optimizer.step() optimizer.zero_grad() - optimizer.zero_grad() In this baseline example, the update may negatively affect the loss of some elements of the batch. @@ -105,7 +104,6 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac optimizer.step() optimizer.zero_grad() - optimizer.zero_grad() Here, we compute the Jacobian of the per-sample losses with respect to the model parameters and use it to update the model such that no loss from the batch is (locally) increased. @@ -141,7 +139,6 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac losses.backward(weights) optimizer.step() optimizer.zero_grad() - optimizer.zero_grad() Here, the per-sample gradients are never fully stored in memory, leading to large improvements in memory usage and speed compared to autojac, in most practical cases. The diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 1fac0308..b89ac77b 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -47,7 +47,6 @@ def test_amp(): scaler.step(optimizer) scaler.update() optimizer.zero_grad() - optimizer.zero_grad() def test_basic_usage(): @@ -122,7 +121,6 @@ def test_iwmtl(): losses.backward(weights) optimizer.step() optimizer.zero_grad() - optimizer.zero_grad() def test_iwrm(): @@ -146,7 +144,6 @@ def test_autograd(): loss.backward() optimizer.step() optimizer.zero_grad() - optimizer.zero_grad() def test_autojac(): import torch @@ -201,7 +198,6 @@ def test_autogram(): losses.backward(weights) optimizer.step() optimizer.zero_grad() - optimizer.zero_grad() test_autograd() test_autojac() @@ -399,7 +395,6 @@ def test_partial_jd(): losses.backward(weights) optimizer.step() optimizer.zero_grad() - optimizer.zero_grad() def test_rnn(): From 2fb685636bdd478090a1ba9d348892fcaa4cbf06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 12 Jan 2026 19:18:06 +0100 Subject: [PATCH 18/35] Fix formulation about freeing jacs --- docs/source/examples/basic_usage.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/examples/basic_usage.rst b/docs/source/examples/basic_usage.rst index 2662a822..c3ee871c 100644 --- a/docs/source/examples/basic_usage.rst +++ b/docs/source/examples/basic_usage.rst @@ -68,7 +68,7 @@ Perform the Jacobian descent backward pass: The first function will populate the ``.jac`` field of each model parameter with the corresponding Jacobian, and the second one will aggregate these Jacobians and store the result in the ``.grad`` -field of the parameters. It also resets the ``.jac`` fields to ``None`` to save some memory. +field of the parameters. It also deletes the ``.jac`` fields save some memory. Update each parameter based on its ``.grad`` field, using the ``optimizer``: From 57fe5b405663b79407c7869db584d3be7220528e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 12 Jan 2026 19:28:43 +0100 Subject: [PATCH 19/35] Add doc entry for jac_to_grad and usage example --- docs/source/docs/autojac/index.rst | 1 + docs/source/docs/autojac/jac_to_grad.rst | 6 ++++++ src/torchjd/autojac/_jac_to_grad.py | 25 +++++++++++++++++++++++- tests/doc/test_jac_to_grad.py | 22 +++++++++++++++++++++ 4 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 docs/source/docs/autojac/jac_to_grad.rst create mode 100644 tests/doc/test_jac_to_grad.py diff --git a/docs/source/docs/autojac/index.rst b/docs/source/docs/autojac/index.rst index 4ca478cf..5eeb22af 100644 --- a/docs/source/docs/autojac/index.rst +++ b/docs/source/docs/autojac/index.rst @@ -10,3 +10,4 @@ autojac backward.rst mtl_backward.rst + jac_to_grad.rst diff --git a/docs/source/docs/autojac/jac_to_grad.rst b/docs/source/docs/autojac/jac_to_grad.rst new file mode 100644 index 00000000..0b61f00e --- /dev/null +++ b/docs/source/docs/autojac/jac_to_grad.rst @@ -0,0 +1,6 @@ +:hide-toc: + +jac_to_grad +=========== + +.. autofunction:: torchjd.autojac.jac_to_grad diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 2b2763ee..d3a72b9e 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -12,7 +12,7 @@ def jac_to_grad( tensors: Iterable[Tensor], aggregator: Aggregator, retain_jac: bool = False ) -> None: - """ + r""" Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result into their ``.grad`` fields. @@ -20,6 +20,29 @@ def jac_to_grad( have the same first dimension (number of outputs). :param aggregator: The aggregator used to reduce the Jacobians into gradients. :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors. + + .. admonition:: + Example + + This example shows how to use ``jac_to_grad`` after a call to ``backward`` + + >>> import torch + >>> + >>> from torchjd.autojac import backward, jac_to_grad + >>> from torchjd.aggregation import UPGrad + >>> + >>> param = torch.tensor([1., 2.], requires_grad=True) + >>> # Compute arbitrary quantities that are function of param + >>> y1 = torch.tensor([-1., 1.]) @ param + >>> y2 = (param ** 2).sum() + >>> + >>> backward([y1, y2]) + >>> jac_to_grad([param], aggregator=UPGrad()) + >>> param.grad + tensor([-1., 1.]) + + The ``.grad`` field of ``param`` now contains the aggregation of the Jacobian of + :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. """ tensors_ = list[TensorWithJac]() diff --git a/tests/doc/test_jac_to_grad.py b/tests/doc/test_jac_to_grad.py new file mode 100644 index 00000000..c41d0bf3 --- /dev/null +++ b/tests/doc/test_jac_to_grad.py @@ -0,0 +1,22 @@ +""" +This file contains the test of the jac_to_grad usage example, with a verification of the value of +the obtained `.grad` field. +""" + +from torch.testing import assert_close + + +def test_jac_to_grad(): + import torch + + from torchjd.aggregation import UPGrad + from torchjd.autojac import backward, jac_to_grad + + param = torch.tensor([1.0, 2.0], requires_grad=True) + # Compute arbitrary quantities that are function of param + y1 = torch.tensor([-1.0, 1.0]) @ param + y2 = (param**2).sum() + backward([y1, y2]) + jac_to_grad([param], aggregator=UPGrad()) + + assert_close(param.grad, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04) From 0a1fc210eec5b677f19ad88eac9c1aef15775b78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 12 Jan 2026 19:31:27 +0100 Subject: [PATCH 20/35] Add comments in jac_to_grad example --- src/torchjd/autojac/_jac_to_grad.py | 4 ++-- tests/doc/test_jac_to_grad.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index d3a72b9e..2924043d 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -36,8 +36,8 @@ def jac_to_grad( >>> y1 = torch.tensor([-1., 1.]) @ param >>> y2 = (param ** 2).sum() >>> - >>> backward([y1, y2]) - >>> jac_to_grad([param], aggregator=UPGrad()) + >>> backward([y1, y2]) # param now has a .jac field + >>> jac_to_grad([param], aggregator=UPGrad()) # param now has a .grad field >>> param.grad tensor([-1., 1.]) diff --git a/tests/doc/test_jac_to_grad.py b/tests/doc/test_jac_to_grad.py index c41d0bf3..57bd42f0 100644 --- a/tests/doc/test_jac_to_grad.py +++ b/tests/doc/test_jac_to_grad.py @@ -16,7 +16,7 @@ def test_jac_to_grad(): # Compute arbitrary quantities that are function of param y1 = torch.tensor([-1.0, 1.0]) @ param y2 = (param**2).sum() - backward([y1, y2]) - jac_to_grad([param], aggregator=UPGrad()) + backward([y1, y2]) # param now has a .jac field + jac_to_grad([param], aggregator=UPGrad()) # param now has a .grad field assert_close(param.grad, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04) From f1ee074fd5f66b33730544ff3c727c5237bd82b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 12 Jan 2026 19:31:51 +0100 Subject: [PATCH 21/35] Fix docstring of test_backward.py --- tests/doc/test_backward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/doc/test_backward.py b/tests/doc/test_backward.py index a099378b..f989a1a7 100644 --- a/tests/doc/test_backward.py +++ b/tests/doc/test_backward.py @@ -1,6 +1,6 @@ """ This file contains the test of the backward usage example, with a verification of the value of the -obtained `.grad` field. +obtained `.jac` field. """ from torch.testing import assert_close From 8b3d447801e95bc1beaa579c5d7f63c8e9847541 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 14 Jan 2026 01:55:16 +0100 Subject: [PATCH 22/35] Fix formatting in backward docstring --- src/torchjd/autojac/_backward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index 46ac2d48..e0188976 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -14,7 +14,7 @@ def backward( ) -> None: r""" Computes the Jacobians of all values in ``tensors`` with respect to all ``inputs`` and - accumulates them in the `.jac` fields of the `inputs`. + accumulates them in the ``.jac`` fields of the ``inputs``. :param tensors: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will have one row for each value of each of these tensors. From 87b66f85d4ce8d9faa1c24ae028ab16b18f41d12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 14 Jan 2026 01:56:06 +0100 Subject: [PATCH 23/35] Fix comment in accumulate_jacs that applied to accumulate_grads --- src/torchjd/autojac/_accumulation.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autojac/_accumulation.py b/src/torchjd/autojac/_accumulation.py index d561aaea..01e48a21 100644 --- a/src/torchjd/autojac/_accumulation.py +++ b/src/torchjd/autojac/_accumulation.py @@ -31,14 +31,13 @@ def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> No param_.jac += jac else: # We do not clone the value to save memory and time, so subsequent modifications of - # the value of key.grad (subsequent accumulations) will also affect the value of - # gradients[key] and outside changes to the value of gradients[key] will also affect - # the value of key.grad. So to be safe, the values of gradients should not be used + # the value of key.jac (subsequent accumulations) will also affect the value of + # jacobians[key] and outside changes to the value of jacobians[key] will also affect + # the value of key.jac. So to be safe, the values of jacobians should not be used # anymore after being passed to this function. # # We do not detach from the computation graph because the value can have grad_fn - # that we want to keep track of (in case it was obtained via create_graph=True and a - # differentiable aggregator). + # that we want to keep track of (in case it was obtained via create_graph=True). param.__setattr__("jac", jac) From 139439517137b237cbce412236180af946fe5c19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 14 Jan 2026 01:58:36 +0100 Subject: [PATCH 24/35] Fix error message in _check_expects_grad --- src/torchjd/autojac/_accumulation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autojac/_accumulation.py b/src/torchjd/autojac/_accumulation.py index 01e48a21..52c7ccb2 100644 --- a/src/torchjd/autojac/_accumulation.py +++ b/src/torchjd/autojac/_accumulation.py @@ -16,7 +16,7 @@ class TensorWithJac(Tensor): def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None: for param, jac in zip(params, jacobians, strict=True): - _check_expects_grad(param) + _check_expects_grad(param, field_name=".jac") # We that the shape is correct to be consistent with torch, that checks that the grad # shape is correct before assigning it. if jac.shape[1:] != param.shape: @@ -43,17 +43,17 @@ def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> No def accumulate_grads(params: Iterable[Tensor], gradients: Iterable[Tensor]) -> None: for param, grad in zip(params, gradients, strict=True): - _check_expects_grad(param) + _check_expects_grad(param, field_name=".grad") if hasattr(param, "grad") and param.grad is not None: param.grad += grad else: param.grad = grad -def _check_expects_grad(tensor: Tensor) -> None: +def _check_expects_grad(tensor: Tensor, field_name: str) -> None: if not _expects_grad(tensor): raise ValueError( - "Cannot populate the .grad field of a Tensor that does not satisfy:" + f"Cannot populate the {field_name} field of a Tensor that does not satisfy:\n" "`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`." ) From 16349a04406cd0738038b59a9b77b2958bc55f80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 14 Jan 2026 01:59:58 +0100 Subject: [PATCH 25/35] Fix wrong import in basic_usage.rst --- docs/source/examples/basic_usage.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/examples/basic_usage.rst b/docs/source/examples/basic_usage.rst index c3ee871c..64a1dbcd 100644 --- a/docs/source/examples/basic_usage.rst +++ b/docs/source/examples/basic_usage.rst @@ -19,7 +19,8 @@ Import several classes from ``torch`` and ``torchjd``: from torch.optim import SGD from torchjd import autojac - from torchjd.aggregation import UPGrad, jac_to_grad + from torchjd.aggregation import UPGrad + from torchjd.autojac import jac_to_grad Define the model and the optimizer, as usual: From 0a8cc620a891a7c3cbb2d25afd4a750af2700090 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 14 Jan 2026 02:10:11 +0100 Subject: [PATCH 26/35] Add explanation about how jac_to_grad works in jac_to_grad's docstring --- src/torchjd/autojac/_jac_to_grad.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 2924043d..26b4167d 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -21,6 +21,12 @@ def jac_to_grad( :param aggregator: The aggregator used to reduce the Jacobians into gradients. :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors. + .. note:: + This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all + of their dimensions except the first one), then concatenates those matrices into a combined + Jacobian matrix. The aggregator is then used on this matrix, which returns a combined + gradient vector, that is split and reshaped to fit into the ``.grad`` fields of the tensors. + .. admonition:: Example From c13a75b4759ed5eb28b592496341bca666502fca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 14 Jan 2026 02:13:04 +0100 Subject: [PATCH 27/35] Improve description of parameters in jac_to_grad --- src/torchjd/autojac/_jac_to_grad.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 26b4167d..4fb587f9 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -17,9 +17,10 @@ def jac_to_grad( into their ``.grad`` fields. :param tensors: The tensors whose ``.jac`` fields should be aggregated. All Jacobians must - have the same first dimension (number of outputs). + have the same first dimension (e.g. number of losses). :param aggregator: The aggregator used to reduce the Jacobians into gradients. - :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors. + :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been + used. Defaults to ``False``. .. note:: This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all From 674f6ad79e44e7df486eddf5ee503696d1aa0058 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 14 Jan 2026 02:18:22 +0100 Subject: [PATCH 28/35] Improve error message and usage example of jac_to_grad --- src/torchjd/autojac/_jac_to_grad.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 4fb587f9..b1432523 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -35,8 +35,8 @@ def jac_to_grad( >>> import torch >>> - >>> from torchjd.autojac import backward, jac_to_grad >>> from torchjd.aggregation import UPGrad + >>> from torchjd.autojac import backward, jac_to_grad >>> >>> param = torch.tensor([1., 2.], requires_grad=True) >>> # Compute arbitrary quantities that are function of param @@ -48,7 +48,7 @@ def jac_to_grad( >>> param.grad tensor([-1., 1.]) - The ``.grad`` field of ``param`` now contains the aggregation of the Jacobian of + The ``.grad`` field of ``param`` now contains the aggregation (by UPGrad) of the Jacobian of :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. """ @@ -56,8 +56,8 @@ def jac_to_grad( for t in tensors: if not hasattr(t, "jac"): raise ValueError( - "Some `jac` fields were not populated. Did you use `autojac.backward` before" - "calling `jac_to_grad`?" + "Some `jac` fields were not populated. Did you use `autojac.backward` or " + "`autojac.mtl_backward` before calling `jac_to_grad`?" ) t_ = cast(TensorWithJac, t) tensors_.append(t_) From 8a0fb0ece62f32184cf18e8509e45fba901e0653 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 14 Jan 2026 02:20:44 +0100 Subject: [PATCH 29/35] Make _disunite_gradient use less memory --- src/torchjd/autojac/_jac_to_grad.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index b1432523..a224c405 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -88,14 +88,13 @@ def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: def _disunite_gradient( gradient_vector: Tensor, jacobians: list[Tensor], tensors: list[TensorWithJac] ) -> list[Tensor]: - gradient_vectors = [] + gradients = list[Tensor]() start = 0 - for jacobian in jacobians: + for jacobian, t in zip(jacobians, tensors, strict=True): end = start + jacobian[0].numel() current_gradient_vector = gradient_vector[start:end] - gradient_vectors.append(current_gradient_vector) + gradients.append(current_gradient_vector.view(t.shape)) start = end - gradients = [g.view(t.shape) for t, g in zip(tensors, gradient_vectors, strict=True)] return gradients From 0e8add2869035c4dcd041b93937c6911f4970bd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 14 Jan 2026 02:26:28 +0100 Subject: [PATCH 30/35] Free .jacs earlier to divide by two peak memory @PierreQuinton this was a big issue that we didn't spot earlier. I don't think the jacobian_matrix can be a view of the concatenated jacobians, so I think that having both the individual matrices + the combined matrix alive at the same time means using double memory. With this _free_jacs call much ealier, if the garbage collector is reactive, we shouldn't have this issue of doubling the peak memory usage for no reason. I think we should check that this PR doesn't introduce a huge memory efficiency regression. Can't merge without doing that. --- src/torchjd/autojac/_jac_to_grad.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index a224c405..5029e72c 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -70,14 +70,14 @@ def jac_to_grad( if not all([jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]]): raise ValueError("All Jacobians should have the same number of rows.") + if not retain_jac: + _free_jacs(tensors_) + jacobian_matrix = _unite_jacobians(jacobians) gradient_vector = aggregator(jacobian_matrix) gradients = _disunite_gradient(gradient_vector, jacobians, tensors_) accumulate_grads(tensors_, gradients) - if not retain_jac: - _free_jacs(tensors_) - def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: jacobian_matrices = [jacobian.reshape(jacobian.shape[0], -1) for jacobian in jacobians] From 430a8a2e05529c8dc782f0d1bf51c0eccf603a05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 14 Jan 2026 14:55:52 +0100 Subject: [PATCH 31/35] Use Tensor.split in _disunit_gradient --- src/torchjd/autojac/_jac_to_grad.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 5029e72c..8c85025f 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -88,13 +88,8 @@ def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: def _disunite_gradient( gradient_vector: Tensor, jacobians: list[Tensor], tensors: list[TensorWithJac] ) -> list[Tensor]: - gradients = list[Tensor]() - start = 0 - for jacobian, t in zip(jacobians, tensors, strict=True): - end = start + jacobian[0].numel() - current_gradient_vector = gradient_vector[start:end] - gradients.append(current_gradient_vector.view(t.shape)) - start = end + gradient_vectors = gradient_vector.split([t.numel() for t in tensors]) + gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors)] return gradients From f0fe5297c26437cfcd34a999dbbc24773a4924d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 14 Jan 2026 15:09:13 +0100 Subject: [PATCH 32/35] Add kwargs to assert_jac_close and assert_grad_close These functions really are wrappers around assert_close, so we'd like them to always also take the parameters of assert_close, even if those change in the future, and to have the same default values. So I think kwargs is justified here. Also it's not user facing so the lack of documentation of the expected types will not be visible. --- tests/unit/autojac/_asserts.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/autojac/_asserts.py b/tests/unit/autojac/_asserts.py index 742998d8..09f2520f 100644 --- a/tests/unit/autojac/_asserts.py +++ b/tests/unit/autojac/_asserts.py @@ -16,10 +16,10 @@ def assert_has_no_jac(t: torch.Tensor) -> None: assert not hasattr(t, "jac") -def assert_jac_close(t: torch.Tensor, expected_jac: torch.Tensor) -> None: +def assert_jac_close(t: torch.Tensor, expected_jac: torch.Tensor, **kwargs) -> None: assert hasattr(t, "jac") t_ = cast(TensorWithJac, t) - assert_close(t_.jac, expected_jac) + assert_close(t_.jac, expected_jac, **kwargs) def assert_has_grad(t: torch.Tensor) -> None: @@ -30,6 +30,6 @@ def assert_has_no_grad(t: torch.Tensor) -> None: assert t.grad is None -def assert_grad_close(t: torch.Tensor, expected_grad: torch.Tensor) -> None: +def assert_grad_close(t: torch.Tensor, expected_grad: torch.Tensor, **kwargs) -> None: assert t.grad is not None - assert_close(t.grad, expected_grad) + assert_close(t.grad, expected_grad, **kwargs) From cff6d8e9fa18828c60b83a521940c1e79d36186f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 14 Jan 2026 15:10:30 +0100 Subject: [PATCH 33/35] Rename expected_jacobian to J in some test --- tests/unit/autojac/test_backward.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 3bd50e81..dcfeba14 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -236,9 +236,9 @@ def test_tensor_used_multiple_times(chunk_size: int | None): backward([d, e], parallel_chunk_size=chunk_size) - expected_jacobian = tensor_([2.0 * 3.0 * (a**2).item(), 2.0 * 4.0 * (a**3).item()]) + J = tensor_([2.0 * 3.0 * (a**2).item(), 2.0 * 4.0 * (a**3).item()]) - assert_jac_close(a, expected_jacobian) + assert_jac_close(a, J) def test_repeated_tensors(): From 84bd552585a506d8877ed9442aa7ff643b1427d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 14 Jan 2026 15:28:43 +0100 Subject: [PATCH 34/35] Move asserts to tests/utils and use them in doc tests --- tests/doc/test_backward.py | 4 ++-- tests/doc/test_jac_to_grad.py | 4 ++-- tests/unit/autojac/_transform/test_accumulate.py | 2 +- tests/unit/autojac/test_backward.py | 3 +-- tests/unit/autojac/test_jac_to_grad.py | 2 +- tests/unit/autojac/test_mtl_backward.py | 13 ++++++------- .../{unit/autojac/_asserts.py => utils/asserts.py} | 0 7 files changed, 13 insertions(+), 15 deletions(-) rename tests/{unit/autojac/_asserts.py => utils/asserts.py} (100%) diff --git a/tests/doc/test_backward.py b/tests/doc/test_backward.py index f989a1a7..032f902d 100644 --- a/tests/doc/test_backward.py +++ b/tests/doc/test_backward.py @@ -3,7 +3,7 @@ obtained `.jac` field. """ -from torch.testing import assert_close +from utils.asserts import assert_jac_close def test_backward(): @@ -18,4 +18,4 @@ def test_backward(): backward([y1, y2]) - assert_close(param.jac, torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04) + assert_jac_close(param, torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04) diff --git a/tests/doc/test_jac_to_grad.py b/tests/doc/test_jac_to_grad.py index 57bd42f0..1f064a6c 100644 --- a/tests/doc/test_jac_to_grad.py +++ b/tests/doc/test_jac_to_grad.py @@ -3,7 +3,7 @@ the obtained `.grad` field. """ -from torch.testing import assert_close +from utils.asserts import assert_grad_close def test_jac_to_grad(): @@ -19,4 +19,4 @@ def test_jac_to_grad(): backward([y1, y2]) # param now has a .jac field jac_to_grad([param], aggregator=UPGrad()) # param now has a .grad field - assert_close(param.grad, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04) + assert_grad_close(param, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04) diff --git a/tests/unit/autojac/_transform/test_accumulate.py b/tests/unit/autojac/_transform/test_accumulate.py index 6dadf1ef..c2c1cf28 100644 --- a/tests/unit/autojac/_transform/test_accumulate.py +++ b/tests/unit/autojac/_transform/test_accumulate.py @@ -1,5 +1,5 @@ from pytest import mark, raises -from unit.autojac._asserts import assert_grad_close, assert_jac_close +from utils.asserts import assert_grad_close, assert_jac_close from utils.dict_assertions import assert_tensor_dicts_are_close from utils.tensors import ones_, tensor_, zeros_ diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index dcfeba14..23f93921 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -1,13 +1,12 @@ import torch from pytest import mark, raises +from utils.asserts import assert_has_jac, assert_has_no_jac, assert_jac_close from utils.tensors import randn_, tensor_ from torchjd.autojac import backward from torchjd.autojac._backward import _create_transform from torchjd.autojac._transform import OrderedSet -from ._asserts import assert_has_jac, assert_has_no_jac, assert_jac_close - def test_check_create_transform(): """Tests that _create_transform creates a valid Transform.""" diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index 7d16247c..60ea6838 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -1,5 +1,5 @@ from pytest import mark, raises -from unit.autojac._asserts import assert_grad_close, assert_has_jac, assert_has_no_jac +from utils.asserts import assert_grad_close, assert_has_jac, assert_has_no_jac from utils.tensors import tensor_ from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index 1bd247f6..3be3650a 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -2,13 +2,7 @@ from pytest import mark, raises from settings import DTYPE from torch.autograd import grad -from utils.tensors import arange_, rand_, randn_, tensor_ - -from torchjd.autojac import mtl_backward -from torchjd.autojac._mtl_backward import _create_transform -from torchjd.autojac._transform import OrderedSet - -from ._asserts import ( +from utils.asserts import ( assert_grad_close, assert_has_grad, assert_has_jac, @@ -16,6 +10,11 @@ assert_has_no_jac, assert_jac_close, ) +from utils.tensors import arange_, rand_, randn_, tensor_ + +from torchjd.autojac import mtl_backward +from torchjd.autojac._mtl_backward import _create_transform +from torchjd.autojac._transform import OrderedSet def test_check_create_transform(): diff --git a/tests/unit/autojac/_asserts.py b/tests/utils/asserts.py similarity index 100% rename from tests/unit/autojac/_asserts.py rename to tests/utils/asserts.py From 4bb561d23f6cedf4f7a551067de35336cd69f613 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 14 Jan 2026 16:02:56 +0100 Subject: [PATCH 35/35] Rename test and update docstring to match its changes --- tests/unit/autojac/test_backward.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 23f93921..54d60b50 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -28,8 +28,8 @@ def test_check_create_transform(): assert output_keys == set() -def test_shape_is_correct(): - """Tests that backward works for various aggregators.""" +def test_jac_is_populated(): + """Tests that backward correctly fills the .jac field.""" a1 = tensor_([1.0, 2.0], requires_grad=True) a2 = tensor_([3.0, 4.0], requires_grad=True)