diff --git a/CHANGELOG.md b/CHANGELOG.md index e627f169..84d75334 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user. ## [Unreleased] +### Changed + +- Removed an unnecessary internal cloning of gradients. This should slightly improve the memory + efficiency of `autojac`. + ## [0.8.1] - 2026-01-07 ### Added diff --git a/src/torchjd/autojac/_transform/_accumulate.py b/src/torchjd/autojac/_transform/_accumulate.py index 5a1ac89c..7bfce193 100644 --- a/src/torchjd/autojac/_transform/_accumulate.py +++ b/src/torchjd/autojac/_transform/_accumulate.py @@ -15,11 +15,16 @@ def __call__(self, gradients: TensorDict) -> TensorDict: if hasattr(key, "grad") and key.grad is not None: key.grad += gradients[key] else: - # We clone the value because we do not want subsequent accumulations to also affect - # this value (in case it is still used outside). We do not detach from the - # computation graph because the value can have grad_fn that we want to keep track of - # (in case it was obtained via create_graph=True and a differentiable aggregator). - key.grad = gradients[key].clone() + # We do not clone the value to save memory and time, so subsequent modifications of + # the value of key.grad (subsequent accumulations) will also affect the value of + # gradients[key] and outside changes to the value of gradients[key] will also affect + # the value of key.grad. So to be safe, the values of gradients should not be used + # anymore after being passed to this function. + # + # We do not detach from the computation graph because the value can have grad_fn + # that we want to keep track of (in case it was obtained via create_graph=True and a + # differentiable aggregator). 
+ key.grad = gradients[key] return {} diff --git a/tests/unit/autojac/_transform/test_accumulate.py b/tests/unit/autojac/_transform/test_accumulate.py index 6c8bbbfe..45db6d61 100644 --- a/tests/unit/autojac/_transform/test_accumulate.py +++ b/tests/unit/autojac/_transform/test_accumulate.py @@ -45,11 +45,12 @@ def test_multiple_accumulation(iterations: int): value1 = ones_([]) value2 = ones_([1]) value3 = ones_([2, 3]) - input = {key1: value1, key2: value2, key3: value3} accumulate = Accumulate() for i in range(iterations): + # Clone values to ensure that we accumulate values that are never used afterwards + input = {key1: value1.clone(), key2: value2.clone(), key3: value3.clone()} accumulate(input) grads = {key1: key1.grad, key2: key2.grad, key3: key3.grad}