From fcd865af5612eac1927888fd2135cb4e13300b29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 4 Feb 2026 15:30:14 +0100 Subject: [PATCH] perf(aggregation): Prevent cuda sync in normalize --- src/torchjd/_linalg/_gramian.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/torchjd/_linalg/_gramian.py b/src/torchjd/_linalg/_gramian.py index 5a1553f1..edc819dd 100644 --- a/src/torchjd/_linalg/_gramian.py +++ b/src/torchjd/_linalg/_gramian.py @@ -51,11 +51,12 @@ def normalize(gramian: PSDMatrix, eps: float) -> PSDMatrix: sqrt of the sum of the diagonal elements. The gramian of the (Frobenius) normalization of `A` is therefore `G` divided by the sum of its diagonal elements. """ + squared_frobenius_norm = gramian.diagonal().sum() - if squared_frobenius_norm < eps: - output = torch.zeros_like(gramian) - else: - output = gramian / squared_frobenius_norm + condition = squared_frobenius_norm < eps + + # Use torch.where rather than an if-else to avoid CUDA synchronization. + output = torch.where(condition, torch.zeros_like(gramian), gramian / squared_frobenius_norm) return cast(PSDMatrix, output)