Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ changelog does not include internal changes that do not affect the user.
### Changed

- **BREAKING**: Changed the dependencies of `CAGrad` and `NashMTL` to be optional when installing
TorchJD. Users of these aggregators will have to use `pip install torchjd[cagrad]`, `pip install
torchjd[nash_mtl]` or `pip install torchjd[full]` to install TorchJD alongside those dependencies.
This should make TorchJD more lightweight.
TorchJD. Users of these aggregators will have to use `pip install "torchjd[cagrad]"`, `pip install
"torchjd[nash_mtl]"` or `pip install "torchjd[full]"` to install TorchJD alongside those
dependencies. This should make TorchJD more lightweight.
- **BREAKING**: Made the aggregator modules and the `autojac` package protected. The aggregators
must now always be imported via their package (e.g.
`from torchjd.aggregation.upgrad import UPGrad` must be changed to
Expand Down
6 changes: 3 additions & 3 deletions docs/source/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ Note that `torchjd` requires Python 3.10, 3.11, 3.12, 3.13 or 3.14 and `torch>=2
Some aggregators (CAGrad and Nash-MTL) have additional dependencies that are not included by default
when installing `torchjd`. To install them, you can use:
```
pip install torchjd[cagrad]
pip install "torchjd[cagrad]"
```
```
pip install torchjd[nash_mtl]
pip install "torchjd[nash_mtl]"
```

To install `torchjd` with all of its optional dependencies, you can also use:
```
pip install torchjd[full]
pip install "torchjd[full]"
```
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class CAGrad(GramianWeightedAggregator):
This aggregator is not installed by default. When not installed, trying to import it should
result in the following error:
``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
To install it, use ``pip install torchjd[cagrad]``.
To install it, use ``pip install "torchjd[cagrad]"``.
"""

def __init__(self, c: float, norm_eps: float = 0.0001):
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_nash_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class NashMTL(WeightedAggregator):
This aggregator is not installed by default. When not installed, trying to import it should
result in the following error:
``ImportError: cannot import name 'NashMTL' from 'torchjd.aggregation'``.
To install it, use ``pip install torchjd[nash_mtl]``.
To install it, use ``pip install "torchjd[nash_mtl]"``.

.. warning::
This implementation was adapted from the `official implementation
Expand Down