diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index 4a73178d..e610ce58 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -8,9 +8,9 @@ from torchjd._linalg import PSDTensor, is_psd_tensor -_T = TypeVar("_T", contravariant=True) -_FnInputT = TypeVar("_FnInputT") -_FnOutputT = TypeVar("_FnOutputT") +_T = TypeVar("_T", contravariant=True, bound=Tensor) +_FnInputT = TypeVar("_FnInputT", bound=Tensor) +_FnOutputT = TypeVar("_FnOutputT", bound=Tensor) class Weighting(Generic[_T], nn.Module, ABC): @@ -27,9 +27,11 @@ def __init__(self): def forward(self, stat: _T) -> Tensor: """Computes the vector of weights from the input stat.""" - def __call__(self, stat: _T) -> Tensor: + def __call__(self, stat: Tensor) -> Tensor: """Computes the vector of weights from the input stat and applies all registered hooks.""" + # The value of _T (e.g. PSDMatrix) is not public, so we need the user-facing type hint of + # stat to be Tensor. return super().__call__(stat) def _compose(self, fn: Callable[[_FnInputT], _T]) -> Weighting[_FnInputT]: