diff --git a/CHANGELOG.md b/CHANGELOG.md index e933c3ff..d520e901 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,8 @@ changelog does not include internal changes that do not affect the user. jac_to_grad(shared_module.parameters(), aggregator) ``` +- Removed an unnecessary memory duplication. This should significantly improve the memory efficiency + of `autojac`. - Removed an unnecessary internal cloning of gradient. This should slightly improve the memory efficiency of `autojac`. diff --git a/src/torchjd/autojac/_transform/_jac.py b/src/torchjd/autojac/_transform/_jac.py index 133bf72d..62430797 100644 --- a/src/torchjd/autojac/_transform/_jac.py +++ b/src/torchjd/autojac/_transform/_jac.py @@ -91,7 +91,12 @@ def _differentiate(self, jac_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]: jacs_chunks.append(_get_jacs_chunk(jac_outputs_chunk, get_vjp_last)) n_inputs = len(self.inputs) - jacs = tuple(torch.cat([chunks[i] for chunks in jacs_chunks]) for i in range(n_inputs)) + if len(jacs_chunks) == 1: + # Skip torch.cat when there is a single chunk: cat always copies its inputs, which + # would needlessly double peak memory usage in that case. + jacs = jacs_chunks[0] + else: + jacs = tuple(torch.cat([chunks[i] for chunks in jacs_chunks]) for i in range(n_inputs)) + return jacs