Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user.

## [Unreleased]

### Changed

- Removed an unnecessary internal cloning of gradients. This should slightly improve the memory
efficiency of `autojac`.

## [0.8.1] - 2026-01-07

### Added
Expand Down
15 changes: 10 additions & 5 deletions src/torchjd/autojac/_transform/_accumulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,16 @@ def __call__(self, gradients: TensorDict) -> TensorDict:
if hasattr(key, "grad") and key.grad is not None:
key.grad += gradients[key]
else:
# We clone the value because we do not want subsequent accumulations to also affect
# this value (in case it is still used outside). We do not detach from the
# computation graph because the value can have grad_fn that we want to keep track of
# (in case it was obtained via create_graph=True and a differentiable aggregator).
key.grad = gradients[key].clone()
# We do not clone the value, in order to save memory and time. As a result, subsequent
# modifications of key.grad (such as later accumulations) will also affect the value of
# gradients[key], and outside changes to gradients[key] will also affect key.grad. To be
# safe, the values of gradients should therefore no longer be used after being passed to
# this function.
#
# We do not detach from the computation graph because the value can have grad_fn
# that we want to keep track of (in case it was obtained via create_graph=True and a
# differentiable aggregator).
key.grad = gradients[key]

return {}

Expand Down
3 changes: 2 additions & 1 deletion tests/unit/autojac/_transform/test_accumulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ def test_multiple_accumulation(iterations: int):
value1 = ones_([])
value2 = ones_([1])
value3 = ones_([2, 3])
input = {key1: value1, key2: value2, key3: value3}

accumulate = Accumulate()

for i in range(iterations):
# Clone values to ensure that we accumulate values that are never used afterwards
input = {key1: value1.clone(), key2: value2.clone(), key3: value3.clone()}
accumulate(input)

grads = {key1: key1.grad, key2: key2.grad, key3: key3.grad}
Expand Down
Loading