diff --git a/src/torchjd/_linalg/_gramian.py b/src/torchjd/_linalg/_gramian.py index 1d7bebff..5a1553f1 100644 --- a/src/torchjd/_linalg/_gramian.py +++ b/src/torchjd/_linalg/_gramian.py @@ -21,6 +21,11 @@ def compute_gramian(t: Matrix, contracted_dims: Literal[1]) -> PSDMatrix: pass +@overload +def compute_gramian(t: Tensor, contracted_dims: int) -> PSDTensor: + pass + + def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor: """ Computes the `Gramian matrix `_ of the input.