From f9529e07ed7b311279a3989b1e5a1a042677da0c Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 2 Feb 2026 11:14:30 +0100 Subject: [PATCH 01/28] Make `Weighting.forward` use positional-only arguments, this makes ty more happy about typing. The problem was that the name of the parameter in the class `Weighting` was `stat`, in subclasses, it was `tensor`, `matrix` or `gramian`, which don't match (it would make giving named parameters incorrect in terms of liskov). --- src/torchjd/aggregation/_weighting_bases.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index e610ce58..74129a36 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -24,7 +24,7 @@ def __init__(self): super().__init__() @abstractmethod - def forward(self, stat: _T) -> Tensor: + def forward(self, stat: _T, /) -> Tensor: """Computes the vector of weights from the input stat.""" def __call__(self, stat: Tensor) -> Tensor: @@ -51,7 +51,7 @@ def __init__(self, weighting: Weighting[_FnOutputT], fn: Callable[[_T], _FnOutpu self.fn = fn self.weighting = weighting - def forward(self, stat: _T) -> Tensor: + def forward(self, stat: _T, /) -> Tensor: return self.weighting(self.fn(stat)) From 99f971d9cda80160fb4d1ff3c59f9e2f72706c1f Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 2 Feb 2026 11:16:16 +0100 Subject: [PATCH 02/28] Remove useless (for ty) typing exception --- src/torchjd/aggregation/_mgda.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index f6edbf6b..dc4bfbee 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -65,7 +65,7 @@ def forward(self, gramian: PSDMatrix) -> Tensor: elif b <= a: gamma = 0.0 else: - gamma = (b - a) / (b + c - 2 * a) # type: ignore[assignment] + gamma = (b - a) / (b + c - 2 * a) alpha = (1 - gamma) * alpha + gamma * e_t if gamma < self.epsilon: break From fa9be50a34cd2fd4cc5bfe54906a90238504f2cd Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 2 Feb 2026 11:22:10 +0100 Subject: [PATCH 03/28] Add cast in NashMTL. --- src/torchjd/aggregation/_nash_mtl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 43939e3b..391149ec 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -24,6 +24,7 @@ # SOFTWARE. # mypy: ignore-errors +from typing import cast from torchjd._linalg import Matrix @@ -96,7 +97,7 @@ def __init__( def reset(self) -> None: """Resets the internal state of the algorithm.""" - self.weighting.reset() + cast(_NashMTLWeighting, self.weighting).reset() def __repr__(self) -> str: return ( From 6fbe79527d7aee2bdeacc4743343c0da8c3e6a1f Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 2 Feb 2026 11:31:36 +0100 Subject: [PATCH 04/28] Add a cast to bool (to avoid numpy.bool) in NashMTL. --- src/torchjd/aggregation/_nash_mtl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 391149ec..d0377652 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -142,7 +142,7 @@ def __init__( self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32) def _stop_criteria(self, gtg: np.ndarray, alpha_t: np.ndarray) -> bool: - return ( + return bool( (self.alpha_param.value is None) or (np.linalg.norm(gtg @ alpha_t - 1 / (alpha_t + 1e-10)) < 1e-3) or (np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) < 1e-6) From aa64c4819c9c188ef5e9e4145799a07bdee23d50 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 2 Feb 2026 11:41:08 +0100 Subject: [PATCH 05/28] make ty ignore `_nash_mtl.py` (this is now done in pyproject.toml --- pyproject.toml | 4 ++++ src/torchjd/aggregation/_nash_mtl.py | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c91e8ed5..5e772f26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,3 +114,7 @@ exclude_lines = [ "pragma: not covered", "@overload", ] + +[tool.ty.src] +include = ["src", "tests"] +exclude = ["src/torchjd/aggregation/_nash_mtl.py"] diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index d0377652..ef88959e 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -23,7 +23,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -# mypy: ignore-errors from typing import cast from torchjd._linalg import Matrix From 03720c8a1211c5e8bd4f577cf7ccf17dc0acf35b Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 2 Feb 2026 11:43:17 +0100 Subject: [PATCH 06/28] make `JacobianComputer._compute_jacobian` use positional-only argument to allow renaming them. --- src/torchjd/autogram/_jacobian_computer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index fb0280f0..d419c685 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -56,6 +56,7 @@ def _compute_jacobian( grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree], + /, ) -> Matrix: """ Computes and returns the Jacobian. The output must be a matrix (2D Tensor). From c26d8f294bcfacc625f51c17e88333b088a72a58 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 2 Feb 2026 13:05:52 +0100 Subject: [PATCH 07/28] Remove type checking of subclasses of `autograd.Functions` methods. --- src/torchjd/autogram/_jacobian_computer.py | 2 +- src/torchjd/autogram/_module_hook_manager.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index d419c685..5950e171 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -173,7 +173,7 @@ def vmap( jac_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree], - ) -> tuple[Tensor, None]: + ) -> tuple[Tensor, None]: # type: ignore[reportIncompatibleMethodOverride] # There is a non-batched dimension # We do not vmap over the args, kwargs, or rg_outputs for the non-batched dimension generalized_jacobian = torch.vmap(compute_jacobian_fn, in_dims=in_dims[1:])( diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index f72b2c75..fe4c22a5 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -173,7 +173,7 @@ def setup_context( ctx, inputs: tuple, _, - ) -> None: + ) -> None: # type: ignore[reportIncompatibleMethodOverride] ctx.gramian_accumulation_phase = inputs[0] ctx.gramian_computer = inputs[1] ctx.args = inputs[2] From 95672813b1603b7c2f2bde658aef73db1fa23393 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 2 Feb 2026 13:06:57 +0100 Subject: [PATCH 08/28] Use positional-only arguments for Transform.__call__ --- src/torchjd/autojac/_transform/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autojac/_transform/_base.py b/src/torchjd/autojac/_transform/_base.py index 9751d39e..79fc82fb 100644 --- a/src/torchjd/autojac/_transform/_base.py +++ b/src/torchjd/autojac/_transform/_base.py @@ -41,7 +41,7 @@ def __str__(self) -> str: return type(self).__name__ @abstractmethod - def __call__(self, input: TensorDict) -> TensorDict: + def __call__(self, input: TensorDict, /) -> TensorDict: """Applies the transform to the input.""" @abstractmethod From 428979f880690181b0998dfce8d5f39b380c1610 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 2 Feb 2026 13:07:41 +0100 Subject: [PATCH 09/28] use positional-only arguments for `Differentiate._differentiate` --- src/torchjd/autojac/_transform/_differentiate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autojac/_transform/_differentiate.py b/src/torchjd/autojac/_transform/_differentiate.py index ddd1f064..e08e8776 100644 --- a/src/torchjd/autojac/_transform/_differentiate.py +++ b/src/torchjd/autojac/_transform/_differentiate.py @@ -45,7 +45,7 @@ def __call__(self, tensors: TensorDict) -> TensorDict: return type(tensors)(new_differentiations) @abstractmethod - def _differentiate(self, tensor_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]: + def _differentiate(self, tensor_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...]: """ Abstract method for differentiating the outputs with respect to the inputs, and applying the linear transformations represented by the tensor_outputs to the results. From 49874830785aa855a7ef7bf092b78c57961de84d Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 2 Feb 2026 13:09:33 +0100 Subject: [PATCH 10/28] fix name of parameters of methods in `OrderedSet` --- src/torchjd/autojac/_transform/_ordered_set.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autojac/_transform/_ordered_set.py b/src/torchjd/autojac/_transform/_ordered_set.py index ae5784af..c929cb45 100644 --- a/src/torchjd/autojac/_transform/_ordered_set.py +++ b/src/torchjd/autojac/_transform/_ordered_set.py @@ -20,10 +20,10 @@ def difference_update(self, elements: set[_T]) -> None: for element in elements: self.discard(element) - def add(self, element: _T) -> None: + def add(self, value: _T) -> None: """Adds the specified element to the OrderedSet.""" - self.ordered_dict[element] = None + self.ordered_dict[value] = None def __add__(self, other: OrderedSet[_T]) -> OrderedSet[_T]: """Creates a new OrderedSet with the elements of self followed by the elements of other.""" @@ -40,5 +40,5 @@ def __iter__(self) -> Iterator[_T]: def __len__(self) -> int: return len(self.ordered_dict) - def __contains__(self, element: object) -> bool: - return element in self.ordered_dict + def __contains__(self, x: object) -> bool: + return x in self.ordered_dict From 0b65d7527edd6eb2bba51b6625727f9e578595e7 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 2 Feb 2026 16:09:31 +0100 Subject: [PATCH 11/28] Change CI to use ty --- .github/workflows/checks.yml | 4 ++-- pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 9fdd6fa4..5a00b0f2 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -122,8 +122,8 @@ jobs: with: groups: check - - name: Run mypy - run: uv run mypy src/torchjd + - name: Run ty + run: uv run ty src/torchjd check-todos: name: Absence of TODOs diff --git a/pyproject.toml b/pyproject.toml index 5e772f26..aac71590 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,7 @@ Changelog = "https://github.com/TorchJD/torchjd/blob/main/CHANGELOG.md" [dependency-groups] check = [ - "mypy>=1.16.0", + "ty>=0.0.14", "pre-commit>=2.9.2", # isort doesn't work before 2.9.2 ] From 096ba76881309bf448185a95c954ed8c862972ac Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 2 Feb 2026 16:13:17 +0100 Subject: [PATCH 12/28] fixup --- .github/workflows/checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 5a00b0f2..c878b83e 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -123,7 +123,7 @@ jobs: groups: check - name: Run ty - run: uv run ty src/torchjd + run: uv run ty check src/torchjd check-todos: name: Absence of TODOs From 3f9f95a53285e0542015bbfdac9f8875a2701b6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 2 Feb 2026 20:31:19 +0100 Subject: [PATCH 13/28] Make JacobianComputer subclasses also have pos-only arguments --- src/torchjd/autogram/_jacobian_computer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index 5950e171..6929f88e 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -76,6 +76,7 @@ def _compute_jacobian( grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree], + /, ) -> Matrix: grad_outputs_in_dims = (0,) * len(grad_outputs) args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) @@ -134,6 +135,7 @@ def _compute_jacobian( grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...], __: dict[str, PyTree], + /, ) -> Matrix: flat_rg_params, ___ = tree_flatten(self.rg_params) grads = torch.autograd.grad( From 20e9cf609342f35d934b360d36285ae2399592bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 2 Feb 2026 20:36:10 +0100 Subject: [PATCH 14/28] Make Transform subclasses also have pos-only arguments Similarly to the previous commit, this is not strictly necessary according to LSP, but I think it's weird that subclasses don't enforce pos-only arguments if the parent class enforces that. --- src/torchjd/autojac/_transform/_accumulate.py | 4 ++-- src/torchjd/autojac/_transform/_base.py | 4 ++-- src/torchjd/autojac/_transform/_diagonalize.py | 2 +- src/torchjd/autojac/_transform/_differentiate.py | 2 +- src/torchjd/autojac/_transform/_init.py | 2 +- src/torchjd/autojac/_transform/_select.py | 2 +- src/torchjd/autojac/_transform/_stack.py | 2 +- tests/unit/autojac/_transform/test_base.py | 2 +- tests/unit/autojac/_transform/test_stack.py | 2 +- 9 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/torchjd/autojac/_transform/_accumulate.py b/src/torchjd/autojac/_transform/_accumulate.py index 082ef1df..38dfb10e 100644 --- a/src/torchjd/autojac/_transform/_accumulate.py +++ b/src/torchjd/autojac/_transform/_accumulate.py @@ -13,7 +13,7 @@ class AccumulateGrad(Transform): should not be used elsewhere. """ - def __call__(self, gradients: TensorDict) -> TensorDict: + def __call__(self, gradients: TensorDict, /) -> TensorDict: accumulate_grads(gradients.keys(), gradients.values()) return {} @@ -30,7 +30,7 @@ class AccumulateJac(Transform): should not be used elsewhere. """ - def __call__(self, jacobians: TensorDict) -> TensorDict: + def __call__(self, jacobians: TensorDict, /) -> TensorDict: accumulate_jacs(jacobians.keys(), jacobians.values()) return {} diff --git a/src/torchjd/autojac/_transform/_base.py b/src/torchjd/autojac/_transform/_base.py index 79fc82fb..db8ff2cb 100644 --- a/src/torchjd/autojac/_transform/_base.py +++ b/src/torchjd/autojac/_transform/_base.py @@ -76,7 +76,7 @@ def __init__(self, outer: Transform, inner: Transform): def __str__(self) -> str: return str(self.outer) + " ∘ " + str(self.inner) - def __call__(self, input: TensorDict) -> TensorDict: + def __call__(self, input: TensorDict, /) -> TensorDict: intermediate = self.inner(input) return self.outer(intermediate) @@ -107,7 +107,7 @@ def __str__(self) -> str: strings.append(s) return "(" + " | ".join(strings) + ")" - def __call__(self, tensor_dict: TensorDict) -> TensorDict: + def __call__(self, tensor_dict: TensorDict, /) -> TensorDict: union: TensorDict = {} for transform in self.transforms: union |= transform(tensor_dict) diff --git a/src/torchjd/autojac/_transform/_diagonalize.py b/src/torchjd/autojac/_transform/_diagonalize.py index 339f0bcc..cc7791ea 100644 --- a/src/torchjd/autojac/_transform/_diagonalize.py +++ b/src/torchjd/autojac/_transform/_diagonalize.py @@ -60,7 +60,7 @@ def __init__(self, key_order: OrderedSet[Tensor]): self.indices.append((begin, end)) begin = end - def __call__(self, tensors: TensorDict) -> TensorDict: + def __call__(self, tensors: TensorDict, /) -> TensorDict: flattened_considered_values = [tensors[key].reshape([-1]) for key in self.key_order] diagonal_matrix = torch.cat(flattened_considered_values).diag() diagonalized_tensors = { diff --git a/src/torchjd/autojac/_transform/_differentiate.py b/src/torchjd/autojac/_transform/_differentiate.py index e08e8776..260d1dab 100644 --- a/src/torchjd/autojac/_transform/_differentiate.py +++ b/src/torchjd/autojac/_transform/_differentiate.py @@ -37,7 +37,7 @@ def __init__( self.retain_graph = retain_graph self.create_graph = create_graph - def __call__(self, tensors: TensorDict) -> TensorDict: + def __call__(self, tensors: TensorDict, /) -> TensorDict: tensor_outputs = [tensors[output] for output in self.outputs] differentiated_tuple = self._differentiate(tensor_outputs) diff --git a/src/torchjd/autojac/_transform/_init.py b/src/torchjd/autojac/_transform/_init.py index 361eb0a6..551f8197 100644 --- a/src/torchjd/autojac/_transform/_init.py +++ b/src/torchjd/autojac/_transform/_init.py @@ -16,7 +16,7 @@ class Init(Transform): def __init__(self, values: Set[Tensor]): self.values = values - def __call__(self, input: TensorDict) -> TensorDict: + def __call__(self, input: TensorDict, /) -> TensorDict: return {value: torch.ones_like(value) for value in self.values} def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: diff --git a/src/torchjd/autojac/_transform/_select.py b/src/torchjd/autojac/_transform/_select.py index 9163a623..29a6bcd2 100644 --- a/src/torchjd/autojac/_transform/_select.py +++ b/src/torchjd/autojac/_transform/_select.py @@ -15,7 +15,7 @@ class Select(Transform): def __init__(self, keys: Set[Tensor]): self.keys = keys - def __call__(self, tensor_dict: TensorDict) -> TensorDict: + def __call__(self, tensor_dict: TensorDict, /) -> TensorDict: output = {key: tensor_dict[key] for key in self.keys} return type(tensor_dict)(output) diff --git a/src/torchjd/autojac/_transform/_stack.py b/src/torchjd/autojac/_transform/_stack.py index eae043e1..1a3fc2ad 100644 --- a/src/torchjd/autojac/_transform/_stack.py +++ b/src/torchjd/autojac/_transform/_stack.py @@ -23,7 +23,7 @@ class Stack(Transform): def __init__(self, transforms: Sequence[Transform]): self.transforms = transforms - def __call__(self, input: TensorDict) -> TensorDict: + def __call__(self, input: TensorDict, /) -> TensorDict: results = [transform(input) for transform in self.transforms] result = _stack(results) return result diff --git a/tests/unit/autojac/_transform/test_base.py b/tests/unit/autojac/_transform/test_base.py index 435b97a7..5da475e6 100644 --- a/tests/unit/autojac/_transform/test_base.py +++ b/tests/unit/autojac/_transform/test_base.py @@ -17,7 +17,7 @@ def __init__(self, required_keys: set[Tensor], output_keys: set[Tensor]): def __str__(self): return "T" - def __call__(self, input: TensorDict) -> TensorDict: + def __call__(self, input: TensorDict, /) -> TensorDict: # Ignore the input, create a dictionary with the right keys as an output. output_dict = {key: empty_(0) for key in self._output_keys} return output_dict diff --git a/tests/unit/autojac/_transform/test_stack.py b/tests/unit/autojac/_transform/test_stack.py index 0151fade..fc2cdf7a 100644 --- a/tests/unit/autojac/_transform/test_stack.py +++ b/tests/unit/autojac/_transform/test_stack.py @@ -15,7 +15,7 @@ class FakeGradientsTransform(Transform): def __init__(self, keys: Iterable[Tensor]): self.keys = set(keys) - def __call__(self, input: TensorDict) -> TensorDict: + def __call__(self, input: TensorDict, /) -> TensorDict: return {key: torch.ones_like(key) for key in self.keys} def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: From 1650a60fc1b5806dff53915b7dc9f27f52339978 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 2 Feb 2026 20:36:47 +0100 Subject: [PATCH 15/28] Same but for Differentiate subclasses --- src/torchjd/autojac/_transform/_grad.py | 2 +- src/torchjd/autojac/_transform/_jac.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autojac/_transform/_grad.py b/src/torchjd/autojac/_transform/_grad.py index de8d3e9b..a52b7b15 100644 --- a/src/torchjd/autojac/_transform/_grad.py +++ b/src/torchjd/autojac/_transform/_grad.py @@ -34,7 +34,7 @@ def __init__( ): super().__init__(outputs, inputs, retain_graph, create_graph) - def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]: + def _differentiate(self, grad_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...]: """ Computes the gradient of each output element with respect to each input tensor, and applies the linear transformations represented by the grad_outputs to the results. diff --git a/src/torchjd/autojac/_transform/_jac.py b/src/torchjd/autojac/_transform/_jac.py index 62430797..0783e22a 100644 --- a/src/torchjd/autojac/_transform/_jac.py +++ b/src/torchjd/autojac/_transform/_jac.py @@ -42,7 +42,7 @@ def __init__( super().__init__(outputs, inputs, retain_graph, create_graph) self.chunk_size = chunk_size - def _differentiate(self, jac_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]: + def _differentiate(self, jac_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...]: """ Computes the jacobian of each output with respect to each input, and applies the linear transformations represented by the jac_outputs to the results. From 9305fd0ed82a8e2c5faa80e987b31729dce7d981 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 2 Feb 2026 20:44:31 +0100 Subject: [PATCH 16/28] Explain that we use ty in contributing.md --- CONTRIBUTING.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 267b69b3..9bd1eb0a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -111,11 +111,12 @@ uv run pre-commit install uv run make clean ``` -## Running `mypy` +## Type checking -From the root of the repo, run: +We use [ty](https://docs.astral.sh/ty/) for type-checking. If you're on VSCode, we recommend using +the `ty` extension. You can also run it from the root of the repo with: ```bash -uv run mypy src/torchjd +uv run ty check ``` ## Development guidelines From 2bce2c39fe82663f2cae786416f8b41b90a463c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 2 Feb 2026 20:45:25 +0100 Subject: [PATCH 17/28] Remove mypy badge - We could maybe re-add a ty badge but it's supposed to be covered by the Checks badge, which includes tests, type checking, and any other check --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 37a530dd..f05438cf 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,6 @@ [![Static Badge](https://img.shields.io/badge/%F0%9F%92%AC_ChatBot-chat.torchjd.org-blue?logo=%F0%9F%92%AC)](https://chat.torchjd.org) [![Tests](https://github.com/TorchJD/torchjd/actions/workflows/checks.yml/badge.svg)](https://github.com/TorchJD/torchjd/actions/workflows/checks.yml) [![codecov](https://codecov.io/gh/TorchJD/torchjd/graph/badge.svg?token=8AUCZE76QH)](https://codecov.io/gh/TorchJD/torchjd) -[![mypy](https://img.shields.io/github/actions/workflow/status/TorchJD/torchjd/checks.yml?label=mypy)](https://github.com/TorchJD/torchjd/actions/workflows/checks.yml) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/TorchJD/torchjd/main.svg)](https://results.pre-commit.ci/latest/github/TorchJD/torchjd/main) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torchjd)](https://pypi.org/project/torchjd/) [![Static Badge](https://img.shields.io/badge/Discord%20-%20community%20-%20%235865F2?logo=discord&logoColor=%23FFFFFF&label=Discord)](https://discord.gg/76KkRnb3nk) From 12c511e5223e811b668af2f5bb1e36fe69f3d98e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 2 Feb 2026 20:58:03 +0100 Subject: [PATCH 18/28] Make Weighting subclasses also use positional-only params --- src/torchjd/aggregation/_aligned_mtl.py | 2 +- src/torchjd/aggregation/_cagrad.py | 2 +- src/torchjd/aggregation/_constant.py | 2 +- src/torchjd/aggregation/_dualproj.py | 2 +- src/torchjd/aggregation/_imtl_g.py | 2 +- src/torchjd/aggregation/_krum.py | 2 +- src/torchjd/aggregation/_mean.py | 2 +- src/torchjd/aggregation/_mgda.py | 2 +- src/torchjd/aggregation/_nash_mtl.py | 2 +- src/torchjd/aggregation/_pcgrad.py | 2 +- src/torchjd/aggregation/_random.py | 2 +- src/torchjd/aggregation/_sum.py | 2 +- src/torchjd/aggregation/_upgrad.py | 2 +- 13 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index eadef9ab..bf4f8dc0 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -98,7 +98,7 @@ def __init__( self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) - def forward(self, gramian: PSDMatrix) -> Tensor: + def forward(self, gramian: PSDMatrix, /) -> Tensor: w = self.weighting(gramian) B = self._compute_balance_transformation(gramian, self._scale_mode) alpha = B @ w diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index 2b768c70..67f94a94 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -76,7 +76,7 @@ def __init__(self, c: float, norm_eps: float = 0.0001): self.c = c self.norm_eps = norm_eps - def forward(self, gramian: PSDMatrix) -> Tensor: + def forward(self, gramian: PSDMatrix, /) -> Tensor: U, S, _ = torch.svd(normalize(gramian, self.norm_eps)) reduced_matrix = U @ S.sqrt().diag() diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index 81404512..f4f062bf 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -45,7 +45,7 @@ def __init__(self, weights: Tensor): super().__init__() self.weights = weights - def forward(self, matrix: Tensor) -> Tensor: + def forward(self, matrix: Tensor, /) -> Tensor: self._check_matrix_shape(matrix) return self.weights diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index d91e32aa..4fc8cefd 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -85,7 +85,7 @@ def __init__( self.reg_eps = reg_eps self.solver: SUPPORTED_SOLVER = solver - def forward(self, gramian: PSDMatrix) -> Tensor: + def forward(self, gramian: PSDMatrix, /) -> Tensor: u = self.weighting(gramian) G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) w = project_weights(u, G, self.solver) diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 7c8369ee..6355158e 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -29,7 +29,7 @@ class IMTLGWeighting(Weighting[PSDMatrix]): :class:`~torchjd.aggregation.IMTLG`. """ - def forward(self, gramian: PSDMatrix) -> Tensor: + def forward(self, gramian: PSDMatrix, /) -> Tensor: d = torch.sqrt(torch.diagonal(gramian)) v = torch.linalg.pinv(gramian) @ d v_sum = v.sum() diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index 7b523360..bce211c6 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -61,7 +61,7 @@ def __init__(self, n_byzantine: int, n_selected: int = 1): self.n_byzantine = n_byzantine self.n_selected = n_selected - def forward(self, gramian: PSDMatrix) -> Tensor: + def forward(self, gramian: PSDMatrix, /) -> Tensor: self._check_matrix_shape(gramian) gradient_norms_squared = torch.diagonal(gramian) distances_squared = ( diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index f739e966..d7085e10 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -24,7 +24,7 @@ class MeanWeighting(Weighting[Matrix]): \mathbb{R}^m`. """ - def forward(self, matrix: Tensor) -> Tensor: + def forward(self, matrix: Tensor, /) -> Tensor: device = matrix.device dtype = matrix.dtype m = matrix.shape[0] diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index dc4bfbee..a2608404 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -43,7 +43,7 @@ def __init__(self, epsilon: float = 0.001, max_iters: int = 100): self.epsilon = epsilon self.max_iters = max_iters - def forward(self, gramian: PSDMatrix) -> Tensor: + def forward(self, gramian: PSDMatrix, /) -> Tensor: """ This is the Frank-Wolfe solver in Algorithm 2 of `Multi-Task Learning as Multi-Objective Optimization diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index ef88959e..76862975 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -198,7 +198,7 @@ def _init_optim_problem(self) -> None: obj = cp.Minimize(cp.sum(G_alpha) + self.phi_alpha / self.normalization_factor_param) self.prob = cp.Problem(obj, constraint) - def forward(self, matrix: Tensor) -> Tensor: + def forward(self, matrix: Tensor, /) -> Tensor: if self.step == 0: self._init_optim_problem() diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index 91af872d..d6cc3f10 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -29,7 +29,7 @@ class PCGradWeighting(Weighting[PSDMatrix]): :class:`~torchjd.aggregation.PCGrad`. """ - def forward(self, gramian: PSDMatrix) -> Tensor: + def forward(self, gramian: PSDMatrix, /) -> Tensor: # Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration device = gramian.device dtype = gramian.dtype diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index 2f2e330c..53ef188c 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -26,7 +26,7 @@ class RandomWeighting(Weighting[Matrix]): at each call. """ - def forward(self, matrix: Tensor) -> Tensor: + def forward(self, matrix: Tensor, /) -> Tensor: random_vector = torch.randn(matrix.shape[0], device=matrix.device, dtype=matrix.dtype) weights = F.softmax(random_vector, dim=-1) return weights diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index da33512a..0d8bd5d6 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -23,7 +23,7 @@ class SumWeighting(Weighting[Matrix]): :math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`. """ - def forward(self, matrix: Tensor) -> Tensor: + def forward(self, matrix: Tensor, /) -> Tensor: device = matrix.device dtype = matrix.dtype weights = torch.ones(matrix.shape[0], device=device, dtype=dtype) diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 132b72e6..6b8ec0f6 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -86,7 +86,7 @@ def __init__( self.reg_eps = reg_eps self.solver: SUPPORTED_SOLVER = solver - def forward(self, gramian: PSDMatrix) -> Tensor: + def forward(self, gramian: PSDMatrix, /) -> Tensor: U = torch.diag(self.weighting(gramian)) G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) W = project_weights(U, G, self.solver) From 0c5fc62ec0163637bc2d4d97f8e7d8bd1a64e5ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 2 Feb 2026 21:13:36 +0100 Subject: [PATCH 19/28] Fix typing error in interactive_plotter.py --- tests/plots/interactive_plotter.py | 49 +++++++++++++++++------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/tests/plots/interactive_plotter.py b/tests/plots/interactive_plotter.py index 5a40e0b6..25f034e5 100644 --- a/tests/plots/interactive_plotter.py +++ b/tests/plots/interactive_plotter.py @@ -90,11 +90,11 @@ def main() -> None: gradient_slider_inputs = [] for i in range(len(matrix)): initial_gradient = matrix[i] - div = make_gradient_div(i, initial_gradient) + div, angle_input, r_input = make_gradient_div(i, initial_gradient) gradient_divs.append(div) - gradient_slider_inputs.append(Input(div.children[1], "value")) - gradient_slider_inputs.append(Input(div.children[2], "value")) + gradient_slider_inputs.append(Input(angle_input, "value")) + gradient_slider_inputs.append(Input(r_input, "value")) aggregator_strings = [str(aggregator) for aggregator in aggregators] checklist = dcc.Checklist(aggregator_strings, [], id="aggregator-checklist") @@ -147,32 +147,39 @@ def update_aggregators(value: list[str]) -> Figure: app.run(debug=False, port=1222) -def make_gradient_div(i: int, initial_gradient: torch.Tensor) -> html.Div: +def make_gradient_div( + i: int, initial_gradient: torch.Tensor +) -> tuple[html.Div, dcc.Input, dcc.Input]: x = initial_gradient[0].item() y = initial_gradient[1].item() angle, r = coord_to_angle(x, y) + + angle_input = dcc.Input( + id=f"g{i + 1}-angle-range", + type="range", + value=angle, + min=0, + max=2 * np.pi, + style={"width": "250px"}, + ) + + r_input = dcc.Input( + id=f"g{i + 1}-r-range", + type="range", + value=r, + min=MIN_LENGTH, + max=MAX_LENGTH, + style={"width": "250px"}, + ) + div = html.Div( [ html.P(f"g{i + 1}", style={"display": "inline-block", "margin-right": 20}), - dcc.Input( - id=f"g{i + 1}-angle-range", - type="range", - value=angle, - min=0, - max=2 * np.pi, - style={"width": "250px"}, - ), - dcc.Input( - id=f"g{i + 1}-r-range", - type="range", - value=r, - min=MIN_LENGTH, - max=MAX_LENGTH, - style={"width": "250px"}, - ), + angle_input, + r_input, ], ) - return div + return div, angle_input, r_input def open_browser() -> None: From 8422da00ec36d2c475b8296814097f324ce4b895 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 2 Feb 2026 21:14:17 +0100 Subject: [PATCH 20/28] Fix typing error in test_compute_gramian_various_output_shapes --- tests/unit/autogram/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 531b6f08..56ad23d9 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -263,7 +263,7 @@ def test_compute_gramian_unsupported_architectures( ], ) def test_compute_gramian_various_output_shapes( - batch_size: int | None, + batch_size: int, reduction: Callable[[list[Tensor]], Tensor], batch_dim: int | None, movedim_source: list[int], From 479ef5eeb1ea8f755f060ff6ca4da50e8aa5e41e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 2 Feb 2026 21:28:51 +0100 Subject: [PATCH 21/28] Make ModuleFactory generic --- tests/utils/architectures.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index a74c7de8..f7e371ed 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -1,4 +1,5 @@ from functools import partial +from typing import Generic, TypeVar import torch import torchvision @@ -8,14 +9,16 @@ from torch.utils._pytree import PyTree from utils.contexts import fork_rng +_T = TypeVar("_T", bound=nn.Module) -class ModuleFactory: - def __init__(self, architecture: type[nn.Module], *args, **kwargs): - self.architecture = architecture + +class ModuleFactory(Generic[_T]): + def __init__(self, architecture: type[_T], *args, **kwargs): + self.architecture: type[_T] = architecture self.args = args self.kwargs = kwargs - def __call__(self) -> nn.Module: + def __call__(self) -> _T: with fork_rng(seed=0): return self.architecture(*self.args, **self.kwargs).to(device=DEVICE, dtype=DTYPE) From adc3d09b2cba5b3c9c253a68dbb978495c96ce39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 2 Feb 2026 21:29:51 +0100 Subject: [PATCH 22/28] Add casts to PSDMatrix --- tests/unit/autogram/test_engine.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 56ad23d9..e1cfa16c 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -1,6 +1,7 @@ from collections.abc import Callable from itertools import combinations from math import prod +from typing import cast import pytest import torch @@ -78,7 +79,7 @@ ) from utils.tensors import make_inputs_and_targets, ones_, randn_, zeros_ -from torchjd._linalg import compute_gramian +from torchjd._linalg import PSDMatrix, compute_gramian from torchjd.aggregation import UPGradWeighting from torchjd.autogram._engine import Engine from torchjd.autogram._gramian_utils import movedim, reshape @@ -457,7 +458,7 @@ def test_reshape_equivariance(shape: list[int]): engine1 = Engine(model1, batch_dim=None) output = model1(input) - gramian = engine1.compute_gramian(output) + gramian = cast(PSDMatrix, engine1.compute_gramian(output)) expected_reshaped_gramian = reshape(gramian, shape[1:]) engine2 = Engine(model2, batch_dim=None) @@ -495,7 +496,7 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: engine1 = Engine(model1, batch_dim=None) output = model1(input).reshape(shape[1:]) - gramian = engine1.compute_gramian(output) + gramian = cast(PSDMatrix, engine1.compute_gramian(output)) expected_moved_gramian = movedim(gramian, source, destination) engine2 = Engine(model2, batch_dim=None) From e6d426c430e2de90fafdbf9db276fc0427c7da33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 2 Feb 2026 21:31:36 +0100 Subject: [PATCH 23/28] Ignore unsupported-operator when calling .grad of BatchedTensor --- tests/unit/autojac/test_backward.py | 4 ++-- tests/unit/autojac/test_jac.py | 4 ++-- tests/unit/autojac/test_mtl_backward.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 54d60b50..0a1ce91d 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -198,7 +198,7 @@ def test_input_retaining_grad_fails(): with raises(RuntimeError): # Using such a BatchedTensor should result in an error - _ = -b.grad + _ = -b.grad # type: ignore[unsupported-operator] def test_non_input_retaining_grad_fails(): @@ -217,7 +217,7 @@ def test_non_input_retaining_grad_fails(): with raises(RuntimeError): # Using such a BatchedTensor should result in an error - _ = -b.grad + _ = -b.grad # type: ignore[unsupported-operator] @mark.parametrize("chunk_size", [1, 3, None]) diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py index 774d9ac9..8feac59e 100644 --- a/tests/unit/autojac/test_jac.py +++ b/tests/unit/autojac/test_jac.py @@ -197,7 +197,7 @@ def test_input_retaining_grad_fails(): with raises(RuntimeError): # Using such a BatchedTensor should result in an error - _ = -b.grad + _ = -b.grad # type: ignore[unsupported-operator] def test_non_input_retaining_grad_fails(): @@ -216,7 +216,7 @@ def test_non_input_retaining_grad_fails(): with raises(RuntimeError): # Using such a BatchedTensor should result in an error - _ = -b.grad + _ = -b.grad # type: ignore[unsupported-operator] @mark.parametrize("chunk_size", [1, 3, None]) diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index 3be3650a..c5515f40 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -448,7 +448,7 @@ def test_shared_param_retaining_grad_fails(): with raises(RuntimeError): # Using such a BatchedTensor should result in an error - _ = -a.grad + _ = -a.grad # type: ignore[unsupported-operator] def test_shared_activation_retaining_grad_fails(): @@ -477,7 +477,7 @@ def test_shared_activation_retaining_grad_fails(): with raises(RuntimeError): # Using such a BatchedTensor should result in an error - _ = -a.grad + _ = -a.grad # type: ignore[unsupported-operator] def test_tasks_params_overlap(): From 0a3850ccf2371cda96ed50a9895dc46ce6e6dbae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 2 Feb 2026 21:41:44 +0100 Subject: [PATCH 24/28] Ignore type errors in the lightning example's test --- tests/doc/test_rst.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 6513d41c..327cb333 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -4,6 +4,8 @@ functions here to test them. """ +from typing import no_type_check + from pytest import mark @@ -209,6 +211,7 @@ def test_autogram(): "ignore::FutureWarning", "ignore::lightning.fabric.utilities.warnings.PossibleUserWarning", ) +@no_type_check # Typing is annoying with Lightning, which would make the example too hard to read. def test_lightning_integration(): # Extra ---------------------------------------------------------------------------------------- import logging From 9ab23e7e18ec05c0f1cd480b91b1702431ae53c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 2 Feb 2026 21:42:36 +0100 Subject: [PATCH 25/28] Run ty against the indicated value of pyproject.toml (src and tests) in CI --- .github/workflows/checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index c878b83e..912c056e 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -123,7 +123,7 @@ jobs: groups: check - name: Run ty - run: uv run ty check src/torchjd + run: uv run ty check check-todos: name: Absence of TODOs From ca546c56c5ae2aed41d0f3197cc507ede5c6a76e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 3 Feb 2026 01:06:24 +0100 Subject: [PATCH 26/28] Add test optional dependencies to typing correctness check --- .github/workflows/checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 912c056e..1c806a5b 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -120,7 +120,7 @@ jobs: - uses: ./.github/actions/install-deps with: - groups: check + groups: check, test - name: Run ty run: uv run ty check From ce05f17782b5a6f8127cfb1c106c9e3328bd6c76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 3 Feb 2026 01:10:10 +0100 Subject: [PATCH 27/28] Fixup --- .github/workflows/checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 1c806a5b..0dff450d 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -120,7 +120,7 @@ jobs: - uses: ./.github/actions/install-deps with: - groups: check, test + groups: check test - name: Run ty run: uv run ty check From 2087f7aed005ed131ed8eca5bc823cc18f8ffc51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 3 Feb 2026 01:16:52 +0100 Subject: [PATCH 28/28] Add plot to typing correctness dependencies --- .github/workflows/checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 0dff450d..3ef322f3 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -120,7 +120,7 @@ jobs: - uses: ./.github/actions/install-deps with: - groups: check test + groups: check test plot - name: Run ty run: uv run ty check