Skip to content

Commit eeb3d9f

Browse files
authored
typing: Fix ty errors (#634)
* Use ty: ignore directives rather than type: ignore * Fix type error in compute_gramian_with_autograd
1 parent 927c1a4 commit eeb3d9f

File tree

8 files changed

+13
-12
lines changed

8 files changed

+13
-12
lines changed

src/torchjd/autogram/_jacobian_computer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def vmap(
181181
jac_outputs: tuple[Tensor, ...],
182182
args: tuple[PyTree, ...],
183183
kwargs: dict[str, PyTree],
184-
) -> tuple[Tensor, None]: # type: ignore[reportIncompatibleMethodOverride]
184+
) -> tuple[Tensor, None]: # ty: ignore[invalid-method-override]
185185
# There is a non-batched dimension
186186
# We do not vmap over the args, kwargs, or rg_outputs for the non-batched dimension
187187
generalized_jacobian = torch.vmap(compute_jacobian_fn, in_dims=in_dims[1:])(

src/torchjd/autogram/_module_hook_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def setup_context(
174174
ctx: Any,
175175
inputs: tuple,
176176
_,
177-
) -> None: # type: ignore[reportIncompatibleMethodOverride]
177+
) -> None: # ty: ignore[invalid-method-override]
178178
ctx.gramian_accumulation_phase = inputs[0]
179179
ctx.gramian_computer = inputs[1]
180180
ctx.args = inputs[2]

src/torchjd/autojac/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> O
179179

180180
# accumulate_grads contains instances of AccumulateGrad, which contain a `variable` field.
181181
# They cannot be typed as such because AccumulateGrad is not public.
182-
leaves = OrderedSet([g.variable for g in accumulate_grads]) # type: ignore[attr-defined]
182+
leaves = OrderedSet([g.variable for g in accumulate_grads]) # ty: ignore[unresolved-attribute]
183183

184184
return leaves
185185

tests/unit/aggregation/test_aligned_mtl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_representations() -> None:
3939

4040

4141
def test_invalid_scale_mode() -> None:
42-
aggregator = AlignedMTL(scale_mode="test") # type: ignore[arg-type]
42+
aggregator = AlignedMTL(scale_mode="test") # ty: ignore[invalid-argument-type]
4343
matrix = ones_(3, 4)
4444
with raises(ValueError, match=r"Invalid scale_mode=.*Expected"):
4545
aggregator(matrix)

tests/unit/autojac/test_backward.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def test_input_retaining_grad_fails() -> None:
317317

318318
with raises(RuntimeError):
319319
# Using such a BatchedTensor should result in an error
320-
_ = -b.grad # type: ignore[unsupported-operator]
320+
_ = -b.grad # ty: ignore[unsupported-operator]
321321

322322

323323
def test_non_input_retaining_grad_fails() -> None:
@@ -336,7 +336,7 @@ def test_non_input_retaining_grad_fails() -> None:
336336

337337
with raises(RuntimeError):
338338
# Using such a BatchedTensor should result in an error
339-
_ = -b.grad # type: ignore[unsupported-operator]
339+
_ = -b.grad # ty: ignore[unsupported-operator]
340340

341341

342342
@mark.parametrize("chunk_size", [1, 3, None])

tests/unit/autojac/test_jac.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def test_input_retaining_grad_fails() -> None:
315315

316316
with raises(RuntimeError):
317317
# Using such a BatchedTensor should result in an error
318-
_ = -b.grad # type: ignore[unsupported-operator]
318+
_ = -b.grad # ty: ignore[unsupported-operator]
319319

320320

321321
def test_non_input_retaining_grad_fails() -> None:
@@ -334,7 +334,7 @@ def test_non_input_retaining_grad_fails() -> None:
334334

335335
with raises(RuntimeError):
336336
# Using such a BatchedTensor should result in an error
337-
_ = -b.grad # type: ignore[unsupported-operator]
337+
_ = -b.grad # ty: ignore[unsupported-operator]
338338

339339

340340
@mark.parametrize("chunk_size", [1, 3, None])

tests/unit/autojac/test_mtl_backward.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def test_shared_param_retaining_grad_fails() -> None:
448448

449449
with raises(RuntimeError):
450450
# Using such a BatchedTensor should result in an error
451-
_ = -a.grad # type: ignore[unsupported-operator]
451+
_ = -a.grad # ty: ignore[unsupported-operator]
452452

453453

454454
def test_shared_activation_retaining_grad_fails() -> None:
@@ -477,7 +477,7 @@ def test_shared_activation_retaining_grad_fails() -> None:
477477

478478
with raises(RuntimeError):
479479
# Using such a BatchedTensor should result in an error
480-
_ = -a.grad # type: ignore[unsupported-operator]
480+
_ = -a.grad # ty: ignore[unsupported-operator]
481481

482482

483483
def test_tasks_params_overlap() -> None:

tests/utils/forward_backwards.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,10 @@ def get_vjp(grad_outputs: Tensor) -> list[Tensor]:
139139

140140
jacobians = vmap(get_vjp)(torch.diag(torch.ones_like(output)))
141141
jacobian_matrices = [jacobian.reshape([jacobian.shape[0], -1]) for jacobian in jacobians]
142-
gramian = sum([jacobian @ jacobian.T for jacobian in jacobian_matrices])
142+
products = [jacobian @ jacobian.T for jacobian in jacobian_matrices]
143+
gramian = torch.stack(products).sum(dim=0)
143144

144-
return gramian
145+
return PSDTensor(gramian)
145146

146147

147148
class CloneParams:

0 commit comments

Comments (0)