Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/torchjd/autojac/_accumulation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Iterable
from typing import cast
from typing import TypeGuard

from torch import Tensor

Expand All @@ -14,6 +14,10 @@ class TensorWithJac(Tensor):
jac: Tensor


def is_tensor_with_jac(t: Tensor) -> TypeGuard[TensorWithJac]:
return hasattr(t, "jac")


def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None:
for param, jac in zip(params, jacobians, strict=True):
_check_expects_grad(param, field_name=".jac")
Expand All @@ -26,9 +30,8 @@ def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> No
" jacobian are the same size"
)

if hasattr(param, "jac"): # No check for None because jac cannot be None
param_ = cast(TensorWithJac, param)
param_.jac += jac
if is_tensor_with_jac(param):
param.jac += jac
else:
# We do not clone the value to save memory and time, so subsequent modifications of
# the value of key.jac (subsequent accumulations) will also affect the value of
Expand Down
8 changes: 3 additions & 5 deletions src/torchjd/autojac/_jac_to_grad.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from collections.abc import Iterable
from typing import cast

import torch
from torch import Tensor

from torchjd.aggregation import Aggregator

from ._accumulation import TensorWithJac, accumulate_grads
from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac


def jac_to_grad(
Expand Down Expand Up @@ -54,13 +53,12 @@ def jac_to_grad(

tensors_ = list[TensorWithJac]()
for t in tensors:
if not hasattr(t, "jac"):
if not is_tensor_with_jac(t):
raise ValueError(
"Some `jac` fields were not populated. Did you use `autojac.backward` or "
"`autojac.mtl_backward` before calling `jac_to_grad`?"
)
t_ = cast(TensorWithJac, t)
tensors_.append(t_)
tensors_.append(t)

if len(tensors_) == 0:
return
Expand Down
16 changes: 6 additions & 10 deletions tests/utils/asserts.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
from typing import cast

import torch
from torch.testing import assert_close

from torchjd._linalg.matrix import PSDMatrix
from torchjd.autojac._accumulation import TensorWithJac
from torchjd.autojac._accumulation import is_tensor_with_jac


def assert_has_jac(t: torch.Tensor) -> None:
    """Assert that ``t`` has a populated ``jac`` field of a compatible shape.

    The diff extraction left both the pre-change body (``hasattr`` + ``cast``)
    and the post-change body interleaved; this is the merged, post-change
    version using the ``is_tensor_with_jac`` type guard.
    """
    assert is_tensor_with_jac(t)
    # jac is expected to stack one jacobian row per objective, so its
    # trailing dimensions must match the tensor's own shape.
    assert t.jac is not None and t.jac.shape[1:] == t.shape


def assert_has_no_jac(t: torch.Tensor) -> None:
    """Assert that ``t`` does NOT carry a ``jac`` field.

    The diff extraction left both the old ``hasattr`` line and the new
    type-guard line in place; this is the merged, post-change version.
    """
    assert not is_tensor_with_jac(t)


def assert_jac_close(t: torch.Tensor, expected_jac: torch.Tensor, **kwargs) -> None:
    """Assert that ``t.jac`` exists and is numerically close to ``expected_jac``.

    Extra keyword arguments (e.g. ``rtol``/``atol``) are forwarded to
    :func:`torch.testing.assert_close`. The diff extraction left both the
    pre-change body (``hasattr`` + ``cast``) and the post-change body
    interleaved; this is the merged, post-change version.
    """
    assert is_tensor_with_jac(t)
    assert_close(t.jac, expected_jac, **kwargs)


def assert_has_grad(t: torch.Tensor) -> None:
Expand Down
Loading