From 746f3e1dee8c1a29cf9cc096bf56ad836c019752 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 29 Jan 2026 18:14:01 +0100 Subject: [PATCH] refactor(linalg): Add missing overload to --- src/torchjd/_linalg/_gramian.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/torchjd/_linalg/_gramian.py b/src/torchjd/_linalg/_gramian.py index 1d7bebff..5a1553f1 100644 --- a/src/torchjd/_linalg/_gramian.py +++ b/src/torchjd/_linalg/_gramian.py @@ -21,6 +21,11 @@ def compute_gramian(t: Matrix, contracted_dims: Literal[1]) -> PSDMatrix: pass +@overload +def compute_gramian(t: Tensor, contracted_dims: int) -> PSDTensor: + pass + + def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor: """ Computes the `Gramian matrix `_ of the input.