diff --git a/README.md b/README.md index 02e066e0..7342a497 100644 --- a/README.md +++ b/README.md @@ -111,11 +111,11 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/). loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - optimizer.zero_grad() - loss = loss1 + loss2 - loss.backward() + mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) optimizer.step() + optimizer.zero_grad() ``` > [!NOTE] @@ -150,12 +150,12 @@ Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgr - loss = loss_fn(output, target) # shape [1] + losses = loss_fn(output, target) # shape [16] - optimizer.zero_grad() - loss.backward() + gramian = engine.compute_gramian(losses) # shape: [16, 16] + weights = weighting(gramian) # shape: [16] + losses.backward(weights) optimizer.step() + optimizer.zero_grad() ``` Lastly, you can even combine the two approaches by considering multiple tasks and each element of @@ -201,10 +201,10 @@ for input, target1, target2 in zip(inputs, task1_targets, task2_targets): # Obtain the weights that lead to no conflict between reweighted gradients weights = weighting(gramian) # shape: [16, 2] - optimizer.zero_grad() # Do the standard backward pass, but weighted using the obtained weights losses.backward(weights) optimizer.step() + optimizer.zero_grad() ``` > [!NOTE] diff --git a/docs/source/examples/amp.rst b/docs/source/examples/amp.rst index 0e719bfd..0aad8da0 100644 --- a/docs/source/examples/amp.rst +++ b/docs/source/examples/amp.rst @@ -12,7 +12,7 @@ case, the losses) should preferably be scaled with a `GradScaler following example shows the resulting code for a multi-task learning use-case. .. code-block:: python - :emphasize-lines: 2, 17, 27, 34, 36-38 + :emphasize-lines: 2, 17, 27, 34-37 import torch from torch.amp import GradScaler @@ -48,10 +48,10 @@ following example shows the resulting code for a multi-task learning use-case. loss2 = loss_fn(output2, target2) scaled_losses = scaler.scale([loss1, loss2]) - optimizer.zero_grad() mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator) scaler.step(optimizer) scaler.update() + optimizer.zero_grad() .. hint:: Within the ``torch.autocast`` context, some operations may be done in ``float16`` type. For diff --git a/docs/source/examples/basic_usage.rst b/docs/source/examples/basic_usage.rst index 8fa4320b..1cca64b7 100644 --- a/docs/source/examples/basic_usage.rst +++ b/docs/source/examples/basic_usage.rst @@ -59,12 +59,6 @@ We can now compute the losses associated to each element of the batch. The last steps are similar to gradient descent-based optimization, but using the two losses. -Reset the ``.grad`` field of each model parameter: - -.. code-block:: python - - optimizer.zero_grad() - Perform the Jacobian descent backward pass: .. code-block:: python @@ -81,3 +75,9 @@ Update each parameter based on its ``.grad`` field, using the ``optimizer``: optimizer.step() The model's parameters have been updated! + +As usual, you should now reset the ``.grad`` field of each model parameter: + +.. code-block:: python + + optimizer.zero_grad() diff --git a/docs/source/examples/iwmtl.rst b/docs/source/examples/iwmtl.rst index 4c1c7a4c..8b2410f7 100644 --- a/docs/source/examples/iwmtl.rst +++ b/docs/source/examples/iwmtl.rst @@ -10,7 +10,7 @@ this Gramian to reweight the gradients and resolve conflict entirely. The following example shows how to do that. .. code-block:: python - :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 41-42 + :emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 40-41 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -51,10 +51,10 @@ The following example shows how to do that. # Obtain the weights that lead to no conflict between reweighted gradients weights = weighting(gramian) # shape: [16, 2] - optimizer.zero_grad() # Do the standard backward pass, but weighted using the obtained weights losses.backward(weights) optimizer.step() + optimizer.zero_grad() .. note:: In this example, the tensor of losses is a matrix rather than a vector. The gramian is thus a diff --git a/docs/source/examples/iwrm.rst b/docs/source/examples/iwrm.rst index a326f582..d1b52426 100644 --- a/docs/source/examples/iwrm.rst +++ b/docs/source/examples/iwrm.rst @@ -64,11 +64,11 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] loss = loss_fn(y_hat, y) # shape: [] (scalar) - optimizer.zero_grad() loss.backward() optimizer.step() + optimizer.zero_grad() In this baseline example, the update may negatively affect the loss of some elements of the batch. @@ -76,7 +76,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac .. tab-item:: autojac .. code-block:: python - :emphasize-lines: 5-6, 12, 16, 21, 23 + :emphasize-lines: 5-6, 12, 16, 21-22 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -99,11 +99,11 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] - optimizer.zero_grad() backward(losses, aggregator) optimizer.step() + optimizer.zero_grad() Here, we compute the Jacobian of the per-sample losses with respect to the model parameters and use it to update the model such that no loss from the batch is (locally) increased. @@ -111,7 +111,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac .. tab-item:: autogram (recommended) .. code-block:: python - :emphasize-lines: 5-6, 12, 16-17, 21, 23-25 + :emphasize-lines: 5-6, 12, 16-17, 21-24 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -134,11 +134,11 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] - optimizer.zero_grad() gramian = engine.compute_gramian(losses) # shape: [16, 16] weights = weighting(gramian) # shape: [16] losses.backward(weights) optimizer.step() + optimizer.zero_grad() Here, the per-sample gradients are never fully stored in memory, leading to large improvements in memory usage and speed compared to autojac, in most practical cases. The diff --git a/docs/source/examples/lightning_integration.rst b/docs/source/examples/lightning_integration.rst index 203f63b5..c1fbba3b 100644 --- a/docs/source/examples/lightning_integration.rst +++ b/docs/source/examples/lightning_integration.rst @@ -11,7 +11,7 @@ The following code example demonstrates a basic multi-task learning setup using <../docs/autojac/mtl_backward>` at each training iteration. .. code-block:: python - :emphasize-lines: 9-10, 18, 32 + :emphasize-lines: 9-10, 18, 31 import torch from lightning import LightningModule, Trainer @@ -43,9 +43,9 @@ The following code example demonstrates a basic multi-task learning setup using loss2 = mse_loss(output2, target2) opt = self.optimizers() - opt.zero_grad() mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad()) opt.step() + opt.zero_grad() def configure_optimizers(self) -> OptimizerLRScheduler: optimizer = Adam(self.parameters(), lr=1e-3) diff --git a/docs/source/examples/monitoring.rst b/docs/source/examples/monitoring.rst index 8ec675aa..f12fd1da 100644 --- a/docs/source/examples/monitoring.rst +++ b/docs/source/examples/monitoring.rst @@ -63,6 +63,6 @@ they have a negative inner product). loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - optimizer.zero_grad() mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) optimizer.step() + optimizer.zero_grad() diff --git a/docs/source/examples/mtl.rst b/docs/source/examples/mtl.rst index d726ae3a..dd770340 100644 --- a/docs/source/examples/mtl.rst +++ b/docs/source/examples/mtl.rst @@ -19,7 +19,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks. .. code-block:: python - :emphasize-lines: 5-6, 19, 33 + :emphasize-lines: 5-6, 19, 32 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -52,9 +52,9 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks. loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - optimizer.zero_grad() mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) optimizer.step() + optimizer.zero_grad() .. note:: In this example, the Jacobian is only with respect to the shared parameters. The task-specific diff --git a/docs/source/examples/partial_jd.rst b/docs/source/examples/partial_jd.rst index c86a653a..ad82205a 100644 --- a/docs/source/examples/partial_jd.rst +++ b/docs/source/examples/partial_jd.rst @@ -41,8 +41,8 @@ first ``Linear`` layer, thereby reducing memory usage and computation time. for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] - optimizer.zero_grad() gramian = engine.compute_gramian(losses) weights = weighting(gramian) losses.backward(weights) optimizer.step() + optimizer.zero_grad() diff --git a/docs/source/examples/rnn.rst b/docs/source/examples/rnn.rst index d9cb8b98..a257c938 100644 --- a/docs/source/examples/rnn.rst +++ b/docs/source/examples/rnn.rst @@ -6,7 +6,7 @@ element of the output sequences. If the gradients of these losses are likely to descent can be leveraged to enhance optimization. .. code-block:: python - :emphasize-lines: 5-6, 10, 17, 20 + :emphasize-lines: 5-6, 10, 17, 19 import torch from torch.nn import RNN @@ -26,9 +26,9 @@ descent can be leveraged to enhance optimization. output, _ = rnn(input) # output is of shape [5, 3, 20]. losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element. - optimizer.zero_grad() backward(losses, aggregator, parallel_chunk_size=1) optimizer.step() + optimizer.zero_grad() .. note:: At the time of writing, there seems to be an incompatibility between ``torch.vmap`` and diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 22787896..643845fc 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -77,7 +77,7 @@ class Engine: Train a model using Gramian-based Jacobian descent. .. code-block:: python - :emphasize-lines: 5-6, 15-16, 18-19, 26-28 + :emphasize-lines: 5-6, 15-16, 18-19, 26-29 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -103,11 +103,11 @@ class Engine: output = model(input).squeeze(dim=1) # shape: [16] losses = criterion(output, target) # shape: [16] - optimizer.zero_grad() gramian = engine.compute_gramian(losses) # shape: [16, 16] weights = weighting(gramian) # shape: [16] losses.backward(weights) optimizer.step() + optimizer.zero_grad() This is equivalent to just calling ``torchjd.autojac.backward(losses, UPGrad())``. However, since the Jacobian never has to be entirely in memory, it is often much more diff --git a/tests/doc/test_autogram.py b/tests/doc/test_autogram.py index 64ce48f7..4445bc67 100644 --- a/tests/doc/test_autogram.py +++ b/tests/doc/test_autogram.py @@ -26,8 +26,8 @@ def test_engine(): output = model(input).squeeze(dim=1) # shape: [16] losses = criterion(output, target) # shape: [16] - optimizer.zero_grad() gramian = engine.compute_gramian(losses) # shape: [16, 16] weights = weighting(gramian) # shape: [16] losses.backward(weights) optimizer.step() + optimizer.zero_grad() diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index b64b504c..867aad6b 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -42,10 +42,10 @@ def test_amp(): loss2 = loss_fn(output2, target2) scaled_losses = scaler.scale([loss1, loss2]) - optimizer.zero_grad() mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator) scaler.step(optimizer) scaler.update() + optimizer.zero_grad() def test_basic_usage(): @@ -69,9 +69,9 @@ def test_basic_usage(): loss1 = loss_fn(output[:, 0], target1) loss2 = loss_fn(output[:, 1], target2) - optimizer.zero_grad() autojac.backward([loss1, loss2], aggregator) optimizer.step() + optimizer.zero_grad() def test_iwmtl(): @@ -114,10 +114,10 @@ def test_iwmtl(): # Obtain the weights that lead to no conflict between reweighted gradients weights = weighting(gramian) # shape: [16, 2] - optimizer.zero_grad() # Do the standard backward pass, but weighted using the obtained weights losses.backward(weights) optimizer.step() + optimizer.zero_grad() def test_iwrm(): @@ -138,9 +138,9 @@ def test_autograd(): for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] loss = loss_fn(y_hat, y) # shape: [] (scalar) - optimizer.zero_grad() loss.backward() optimizer.step() + optimizer.zero_grad() def test_autojac(): import torch @@ -163,9 +163,9 @@ def test_autojac(): for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] - optimizer.zero_grad() backward(losses, aggregator) optimizer.step() + optimizer.zero_grad() def test_autogram(): import torch @@ -189,11 +189,11 @@ def test_autogram(): for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] - optimizer.zero_grad() gramian = engine.compute_gramian(losses) # shape: [16, 16] weights = weighting(gramian) # shape: [16] losses.backward(weights) optimizer.step() + optimizer.zero_grad() test_autograd() test_autojac() @@ -240,9 +240,9 @@ def training_step(self, batch, batch_idx) -> None: loss2 = mse_loss(output2, target2) opt = self.optimizers() - opt.zero_grad() mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad()) opt.step() + opt.zero_grad() def configure_optimizers(self) -> OptimizerLRScheduler: optimizer = Adam(self.parameters(), lr=1e-3) @@ -314,9 +314,9 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch. loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - optimizer.zero_grad() mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) optimizer.step() + optimizer.zero_grad() def test_mtl(): @@ -351,9 +351,9 @@ def test_mtl(): loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - optimizer.zero_grad() mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) optimizer.step() + optimizer.zero_grad() def test_partial_jd(): @@ -382,11 +382,11 @@ def test_partial_jd(): for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] - optimizer.zero_grad() gramian = engine.compute_gramian(losses) weights = weighting(gramian) losses.backward(weights) optimizer.step() + optimizer.zero_grad() def test_rnn(): @@ -408,6 +408,6 @@ def test_rnn(): output, _ = rnn(input) # output is of shape [5, 3, 20]. losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element. - optimizer.zero_grad() backward(losses, aggregator, parallel_chunk_size=1) optimizer.step() + optimizer.zero_grad() diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 50135796..62347b33 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -345,7 +345,7 @@ def test_iwrm_steps_with_autogram(factory: ModuleFactory, batch_size: int, batch loss_fn = make_mse_loss_fn(targets) autogram_forward_backward(model, inputs, loss_fn, engine, weighting) optimizer.step() - model.zero_grad() + optimizer.zero_grad() @mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS)