diff --git a/src/torchjd/autojac/_accumulation.py b/src/torchjd/autojac/_accumulation.py
index 52c7ccb2..bafd3d62 100644
--- a/src/torchjd/autojac/_accumulation.py
+++ b/src/torchjd/autojac/_accumulation.py
@@ -1,5 +1,5 @@
 from collections.abc import Iterable
-from typing import cast
+from typing import TypeGuard
 
 from torch import Tensor
 
@@ -14,6 +14,10 @@ class TensorWithJac(Tensor):
     jac: Tensor
 
 
+def is_tensor_with_jac(t: Tensor) -> TypeGuard[TensorWithJac]:
+    return hasattr(t, "jac")
+
+
 def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None:
     for param, jac in zip(params, jacobians, strict=True):
         _check_expects_grad(param, field_name=".jac")
@@ -26,9 +30,8 @@ def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> No
                 " jacobian are the same size"
             )
 
-        if hasattr(param, "jac"):  # No check for None because jac cannot be None
-            param_ = cast(TensorWithJac, param)
-            param_.jac += jac
+        if is_tensor_with_jac(param):
+            param.jac += jac
         else:
             # We do not clone the value to save memory and time, so subsequent modifications of
             # the value of key.jac (subsequent accumulations) will also affect the value of
diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py
index 8c85025f..61427467 100644
--- a/src/torchjd/autojac/_jac_to_grad.py
+++ b/src/torchjd/autojac/_jac_to_grad.py
@@ -1,12 +1,11 @@
 from collections.abc import Iterable
-from typing import cast
 
 import torch
 from torch import Tensor
 
 from torchjd.aggregation import Aggregator
 
-from ._accumulation import TensorWithJac, accumulate_grads
+from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac
 
 
 def jac_to_grad(
@@ -54,13 +53,12 @@ def jac_to_grad(
 
     tensors_ = list[TensorWithJac]()
     for t in tensors:
-        if not hasattr(t, "jac"):
+        if not is_tensor_with_jac(t):
             raise ValueError(
                 "Some `jac` fields were not populated. Did you use `autojac.backward` or "
                 "`autojac.mtl_backward` before calling `jac_to_grad`?"
             )
-        t_ = cast(TensorWithJac, t)
-        tensors_.append(t_)
+        tensors_.append(t)
 
     if len(tensors_) == 0:
         return
diff --git a/tests/utils/asserts.py b/tests/utils/asserts.py
index 3828aa5e..0d2a0a37 100644
--- a/tests/utils/asserts.py
+++ b/tests/utils/asserts.py
@@ -1,26 +1,22 @@
-from typing import cast
-
 import torch
 from torch.testing import assert_close
 
 from torchjd._linalg.matrix import PSDMatrix
-from torchjd.autojac._accumulation import TensorWithJac
+from torchjd.autojac._accumulation import is_tensor_with_jac
 
 
 def assert_has_jac(t: torch.Tensor) -> None:
-    assert hasattr(t, "jac")
-    t_ = cast(TensorWithJac, t)
-    assert t_.jac is not None and t_.jac.shape[1:] == t_.shape
+    assert is_tensor_with_jac(t)
+    assert t.jac is not None and t.jac.shape[1:] == t.shape
 
 
 def assert_has_no_jac(t: torch.Tensor) -> None:
-    assert not hasattr(t, "jac")
+    assert not is_tensor_with_jac(t)
 
 
 def assert_jac_close(t: torch.Tensor, expected_jac: torch.Tensor, **kwargs) -> None:
-    assert hasattr(t, "jac")
-    t_ = cast(TensorWithJac, t)
-    assert_close(t_.jac, expected_jac, **kwargs)
+    assert is_tensor_with_jac(t)
+    assert_close(t.jac, expected_jac, **kwargs)
 
 
 def assert_has_grad(t: torch.Tensor) -> None: