From 958b7ab95cb5ee044f1c3cd2aac93fe4380fc195 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 16 Jan 2026 18:20:25 +0100 Subject: [PATCH] refactor: Avoid cat in Jac when not needed --- CHANGELOG.md | 2 ++ src/torchjd/autojac/_transform/_jac.py | 7 ++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e933c3ff..d520e901 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,8 @@ changelog does not include internal changes that do not affect the user. jac_to_grad(shared_module.parameters(), aggregator) ``` +- Removed an unnecessary memory duplication. This should significantly improve the memory efficiency + of `autojac`. - Removed an unnecessary internal cloning of gradient. This should slightly improve the memory efficiency of `autojac`. diff --git a/src/torchjd/autojac/_transform/_jac.py b/src/torchjd/autojac/_transform/_jac.py index 133bf72d..62430797 100644 --- a/src/torchjd/autojac/_transform/_jac.py +++ b/src/torchjd/autojac/_transform/_jac.py @@ -91,7 +91,12 @@ def _differentiate(self, jac_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]: jacs_chunks.append(_get_jacs_chunk(jac_outputs_chunk, get_vjp_last)) n_inputs = len(self.inputs) - jacs = tuple(torch.cat([chunks[i] for chunks in jacs_chunks]) for i in range(n_inputs)) + if len(jacs_chunks) == 1: + # Avoid cat when there is a single chunk, since it would needlessly double memory usage + jacs = jacs_chunks[0] + else: + jacs = tuple(torch.cat([chunks[i] for chunks in jacs_chunks]) for i in range(n_inputs)) + return jacs