diff --git a/src/torchjd/_linalg/_gramian.py b/src/torchjd/_linalg/_gramian.py index 5a1553f1..edc819dd 100644 --- a/src/torchjd/_linalg/_gramian.py +++ b/src/torchjd/_linalg/_gramian.py @@ -51,11 +51,12 @@ def normalize(gramian: PSDMatrix, eps: float) -> PSDMatrix: sqrt of the sum of the diagonal elements. The gramian of the (Frobenius) normalization of `A` is therefore `G` divided by the sum of its diagonal elements. """ + squared_frobenius_norm = gramian.diagonal().sum() - if squared_frobenius_norm < eps: - output = torch.zeros_like(gramian) - else: - output = gramian / squared_frobenius_norm + condition = squared_frobenius_norm < eps + + # Use torch.where rather than a if-else to avoid cuda synchronization. + output = torch.where(condition, torch.zeros_like(gramian), gramian / squared_frobenius_norm) return cast(PSDMatrix, output)