From 24002402abaddb07b2dfc4a98d77c70108bef3ca Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 19 Jan 2026 10:35:19 +0100 Subject: [PATCH 1/2] refactor(autojac): Use `TypeGuard` for `TensorWithJac` --- src/torchjd/autojac/_accumulation.py | 11 +++++++---- src/torchjd/autojac/_jac_to_grad.py | 8 +++----- tests/utils/asserts.py | 16 ++++++---------- 3 files changed, 16 insertions(+), 19 deletions(-) diff --git a/src/torchjd/autojac/_accumulation.py b/src/torchjd/autojac/_accumulation.py index 52c7ccb2..27085261 100644 --- a/src/torchjd/autojac/_accumulation.py +++ b/src/torchjd/autojac/_accumulation.py @@ -1,5 +1,5 @@ from collections.abc import Iterable -from typing import cast +from typing import TypeGuard from torch import Tensor @@ -14,6 +14,10 @@ class TensorWithJac(Tensor): jac: Tensor +def is_tensor_with_jac(t: Tensor) -> TypeGuard[TensorWithJac]: + return hasattr(t, "jac") + + def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None: for param, jac in zip(params, jacobians, strict=True): _check_expects_grad(param, field_name=".jac") @@ -26,9 +30,8 @@ def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> No " jacobian are the same size" ) - if hasattr(param, "jac"): # No check for None because jac cannot be None - param_ = cast(TensorWithJac, param) - param_.jac += jac + if is_tensor_with_jac(param): # No check for None because jac cannot be None + param.jac += jac else: # We do not clone the value to save memory and time, so subsequent modifications of # the value of key.jac (subsequent accumulations) will also affect the value of diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 8c85025f..61427467 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -1,12 +1,11 @@ from collections.abc import Iterable -from typing import cast import torch from torch import Tensor from torchjd.aggregation import Aggregator -from ._accumulation import TensorWithJac, accumulate_grads +from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac def jac_to_grad( @@ -54,13 +53,12 @@ def jac_to_grad( tensors_ = list[TensorWithJac]() for t in tensors: - if not hasattr(t, "jac"): + if not is_tensor_with_jac(t): raise ValueError( "Some `jac` fields were not populated. Did you use `autojac.backward` or " "`autojac.mtl_backward` before calling `jac_to_grad`?" ) - t_ = cast(TensorWithJac, t) - tensors_.append(t_) + tensors_.append(t) if len(tensors_) == 0: return diff --git a/tests/utils/asserts.py b/tests/utils/asserts.py index 3828aa5e..0d2a0a37 100644 --- a/tests/utils/asserts.py +++ b/tests/utils/asserts.py @@ -1,26 +1,22 @@ -from typing import cast - import torch from torch.testing import assert_close from torchjd._linalg.matrix import PSDMatrix -from torchjd.autojac._accumulation import TensorWithJac +from torchjd.autojac._accumulation import is_tensor_with_jac def assert_has_jac(t: torch.Tensor) -> None: - assert hasattr(t, "jac") - t_ = cast(TensorWithJac, t) - assert t_.jac is not None and t_.jac.shape[1:] == t_.shape + assert is_tensor_with_jac(t) + assert t.jac is not None and t.jac.shape[1:] == t.shape def assert_has_no_jac(t: torch.Tensor) -> None: - assert not hasattr(t, "jac") + assert not is_tensor_with_jac(t) def assert_jac_close(t: torch.Tensor, expected_jac: torch.Tensor, **kwargs) -> None: - assert hasattr(t, "jac") - t_ = cast(TensorWithJac, t) - assert_close(t_.jac, expected_jac, **kwargs) + assert is_tensor_with_jac(t) + assert_close(t.jac, expected_jac, **kwargs) def assert_has_grad(t: torch.Tensor) -> None: From eec082b3f4bd36c8d124f8e35b0fdfc358fe8f98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Mon, 19 Jan 2026 16:58:30 +0100 Subject: [PATCH 2/2] Remove outdated comment Co-authored-by: Pierre Quinton --- src/torchjd/autojac/_accumulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autojac/_accumulation.py b/src/torchjd/autojac/_accumulation.py index 27085261..bafd3d62 100644 --- a/src/torchjd/autojac/_accumulation.py +++ b/src/torchjd/autojac/_accumulation.py @@ -30,7 +30,7 @@ def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> No " jacobian are the same size" ) - if is_tensor_with_jac(param): # No check for None because jac cannot be None + if is_tensor_with_jac(param): param.jac += jac else: # We do not clone the value to save memory and time, so subsequent modifications of