From f46c7857705a67aa8484ffe32330c7c16e362800 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 29 Jan 2026 14:20:52 +0100 Subject: [PATCH] refactor(aggregation): Fix user-facing typing error --- src/torchjd/aggregation/_weighting_bases.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index 4a73178d..e610ce58 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -8,9 +8,9 @@ from torchjd._linalg import PSDTensor, is_psd_tensor -_T = TypeVar("_T", contravariant=True) -_FnInputT = TypeVar("_FnInputT") -_FnOutputT = TypeVar("_FnOutputT") +_T = TypeVar("_T", contravariant=True, bound=Tensor) +_FnInputT = TypeVar("_FnInputT", bound=Tensor) +_FnOutputT = TypeVar("_FnOutputT", bound=Tensor) class Weighting(Generic[_T], nn.Module, ABC): @@ -27,9 +27,11 @@ def __init__(self): def forward(self, stat: _T) -> Tensor: """Computes the vector of weights from the input stat.""" - def __call__(self, stat: _T) -> Tensor: + def __call__(self, stat: Tensor) -> Tensor: """Computes the vector of weights from the input stat and applies all registered hooks.""" + # The value of _T (e.g. PSDMatrix) is not public, so we need the user-facing type hint of + # stat to be Tensor. return super().__call__(stat) def _compose(self, fn: Callable[[_FnInputT], _T]) -> Weighting[_FnInputT]: