From a51b5f39abfbda308685312916ee185d35773dcd Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 15 Oct 2025 13:36:13 +0200 Subject: [PATCH] refactor(autogram): Create block-diagonal matrix using einsum. --- src/torchjd/autogram/_engine.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index ecde7483..22787896 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -319,11 +319,9 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: if has_non_batch_dim: # There is one non-batched dimension, it is the first one non_batch_dim_len = output.shape[0] - jac_output_shape = [output.shape[0]] + list(output.shape) - - jac_output = torch.zeros(jac_output_shape, device=output.device, dtype=output.dtype) - for i in range(non_batch_dim_len): - jac_output[i, i, ...] = 1.0 + identity_matrix = torch.eye(non_batch_dim_len, device=output.device, dtype=output.dtype) + ones = torch.ones_like(output[0]) + jac_output = torch.einsum("ij, ... -> ij...", identity_matrix, ones) _ = vmap(differentiation)(jac_output) else: