Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
f9529e0
Make `Weighting.forward` use positional-only arguments, this makes ty…
PierreQuinton Feb 2, 2026
99f971d
Remove useless (for ty) typing exception
PierreQuinton Feb 2, 2026
fa9be50
Add cast in NashMTL.
PierreQuinton Feb 2, 2026
6fbe795
Add a cast to bool (to avoid numpy.bool) in NashMTL.
PierreQuinton Feb 2, 2026
aa64c48
make ty ignore `_nash_mtl.py` (this is now done in pyproject.toml)
PierreQuinton Feb 2, 2026
03720c8
make `JacobianComputer._compute_jacobian` use positional-only argumen…
PierreQuinton Feb 2, 2026
c26d8f2
Remove type checking of subclasses of `autograd.Functions` methods.
PierreQuinton Feb 2, 2026
9567281
Use positional-only arguments for Transform.__call__
PierreQuinton Feb 2, 2026
428979f
use positional-only arguments for `Differentiate._differentiate`
PierreQuinton Feb 2, 2026
4987483
fix name of parameters of methods in `OrderedSet`
PierreQuinton Feb 2, 2026
0b65d75
Change CI to use ty
PierreQuinton Feb 2, 2026
096ba76
fixup
PierreQuinton Feb 2, 2026
29ef8ab
Merge branch 'main' into replace-mypy-with-ty
ValerianRey Feb 2, 2026
3f9f95a
Make JacobianComputer subclasses also have pos-only arguments
ValerianRey Feb 2, 2026
20e9cf6
Make Transform subclasses also have pos-only arguments
ValerianRey Feb 2, 2026
1650a60
Same but for Differentiate subclasses
ValerianRey Feb 2, 2026
9305fd0
Explain that we use ty in contributing.md
ValerianRey Feb 2, 2026
2bce2c3
Remove mypy badge
ValerianRey Feb 2, 2026
12c511e
Make Weighting subclasses also use positional-only params
ValerianRey Feb 2, 2026
0c5fc62
Fix typing error in interactive_plotter.py
ValerianRey Feb 2, 2026
8422da0
Fix typing error in test_compute_gramian_various_output_shapes
ValerianRey Feb 2, 2026
479ef5e
Make ModuleFactory generic
ValerianRey Feb 2, 2026
adc3d09
Add casts to PSDMatrix
ValerianRey Feb 2, 2026
e6d426c
Ignore unsupported-operator when calling .grad of BatchedTensor
ValerianRey Feb 2, 2026
0a3850c
Ignore type errors in the lightning example's test
ValerianRey Feb 2, 2026
9ab23e7
Run ty against the indicated value of pyproject.toml (src and tests) …
ValerianRey Feb 2, 2026
ca546c5
Add test optional dependencies to typing correctness check
ValerianRey Feb 3, 2026
ce05f17
Fixup
ValerianRey Feb 3, 2026
2087f7a
Add plot to typing correctness dependencies
ValerianRey Feb 3, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ jobs:

- uses: ./.github/actions/install-deps
with:
groups: check
groups: check test plot

- name: Run mypy
run: uv run mypy src/torchjd
- name: Run ty
run: uv run ty check

check-todos:
name: Absence of TODOs
Expand Down
7 changes: 4 additions & 3 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,12 @@ uv run pre-commit install
uv run make clean
```

## Running `mypy`
## Type checking

From the root of the repo, run:
We use [ty](https://docs.astral.sh/ty/) for type-checking. If you're using VS Code, we recommend
installing the `ty` extension. You can also run it from the root of the repo with:
```bash
uv run mypy src/torchjd
uv run ty check
```

## Development guidelines
Expand Down
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
[![Static Badge](https://img.shields.io/badge/%F0%9F%92%AC_ChatBot-chat.torchjd.org-blue?logo=%F0%9F%92%AC)](https://chat.torchjd.org)
[![Tests](https://github.com/TorchJD/torchjd/actions/workflows/checks.yml/badge.svg)](https://github.com/TorchJD/torchjd/actions/workflows/checks.yml)
[![codecov](https://codecov.io/gh/TorchJD/torchjd/graph/badge.svg?token=8AUCZE76QH)](https://codecov.io/gh/TorchJD/torchjd)
[![mypy](https://img.shields.io/github/actions/workflow/status/TorchJD/torchjd/checks.yml?label=mypy)](https://github.com/TorchJD/torchjd/actions/workflows/checks.yml)
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/TorchJD/torchjd/main.svg)](https://results.pre-commit.ci/latest/github/TorchJD/torchjd/main)
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torchjd)](https://pypi.org/project/torchjd/)
[![Static Badge](https://img.shields.io/badge/Discord%20-%20community%20-%20%235865F2?logo=discord&logoColor=%23FFFFFF&label=Discord)](https://discord.gg/76KkRnb3nk)
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Changelog = "https://github.com/TorchJD/torchjd/blob/main/CHANGELOG.md"

[dependency-groups]
check = [
"mypy>=1.16.0",
"ty>=0.0.14",
"pre-commit>=2.9.2", # isort doesn't work before 2.9.2
]

Expand Down Expand Up @@ -114,3 +114,7 @@ exclude_lines = [
"pragma: not covered",
"@overload",
]

[tool.ty.src]
include = ["src", "tests"]
exclude = ["src/torchjd/aggregation/_nash_mtl.py"]
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())

def forward(self, gramian: PSDMatrix) -> Tensor:
def forward(self, gramian: PSDMatrix, /) -> Tensor:
w = self.weighting(gramian)
B = self._compute_balance_transformation(gramian, self._scale_mode)
alpha = B @ w
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, c: float, norm_eps: float = 0.0001):
self.c = c
self.norm_eps = norm_eps

def forward(self, gramian: PSDMatrix) -> Tensor:
def forward(self, gramian: PSDMatrix, /) -> Tensor:
U, S, _ = torch.svd(normalize(gramian, self.norm_eps))

reduced_matrix = U @ S.sqrt().diag()
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, weights: Tensor):
super().__init__()
self.weights = weights

def forward(self, matrix: Tensor) -> Tensor:
def forward(self, matrix: Tensor, /) -> Tensor:
self._check_matrix_shape(matrix)
return self.weights

Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
self.reg_eps = reg_eps
self.solver: SUPPORTED_SOLVER = solver

def forward(self, gramian: PSDMatrix) -> Tensor:
def forward(self, gramian: PSDMatrix, /) -> Tensor:
u = self.weighting(gramian)
G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
w = project_weights(u, G, self.solver)
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_imtl_g.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class IMTLGWeighting(Weighting[PSDMatrix]):
:class:`~torchjd.aggregation.IMTLG`.
"""

def forward(self, gramian: PSDMatrix) -> Tensor:
def forward(self, gramian: PSDMatrix, /) -> Tensor:
d = torch.sqrt(torch.diagonal(gramian))
v = torch.linalg.pinv(gramian) @ d
v_sum = v.sum()
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_krum.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, n_byzantine: int, n_selected: int = 1):
self.n_byzantine = n_byzantine
self.n_selected = n_selected

def forward(self, gramian: PSDMatrix) -> Tensor:
def forward(self, gramian: PSDMatrix, /) -> Tensor:
self._check_matrix_shape(gramian)
gradient_norms_squared = torch.diagonal(gramian)
distances_squared = (
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class MeanWeighting(Weighting[Matrix]):
\mathbb{R}^m`.
"""

def forward(self, matrix: Tensor) -> Tensor:
def forward(self, matrix: Tensor, /) -> Tensor:
device = matrix.device
dtype = matrix.dtype
m = matrix.shape[0]
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_mgda.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, epsilon: float = 0.001, max_iters: int = 100):
self.epsilon = epsilon
self.max_iters = max_iters

def forward(self, gramian: PSDMatrix) -> Tensor:
def forward(self, gramian: PSDMatrix, /) -> Tensor:
"""
This is the Frank-Wolfe solver in Algorithm 2 of `Multi-Task Learning as Multi-Objective
Optimization
Expand All @@ -65,7 +65,7 @@ def forward(self, gramian: PSDMatrix) -> Tensor:
elif b <= a:
gamma = 0.0
else:
gamma = (b - a) / (b + c - 2 * a) # type: ignore[assignment]
gamma = (b - a) / (b + c - 2 * a)
alpha = (1 - gamma) * alpha + gamma * e_t
if gamma < self.epsilon:
break
Expand Down
8 changes: 4 additions & 4 deletions src/torchjd/aggregation/_nash_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# mypy: ignore-errors
from typing import cast

from torchjd._linalg import Matrix

Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(

def reset(self) -> None:
"""Resets the internal state of the algorithm."""
self.weighting.reset()
cast(_NashMTLWeighting, self.weighting).reset()

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -141,7 +141,7 @@ def __init__(
self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32)

def _stop_criteria(self, gtg: np.ndarray, alpha_t: np.ndarray) -> bool:
return (
return bool(
(self.alpha_param.value is None)
or (np.linalg.norm(gtg @ alpha_t - 1 / (alpha_t + 1e-10)) < 1e-3)
or (np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) < 1e-6)
Expand Down Expand Up @@ -198,7 +198,7 @@ def _init_optim_problem(self) -> None:
obj = cp.Minimize(cp.sum(G_alpha) + self.phi_alpha / self.normalization_factor_param)
self.prob = cp.Problem(obj, constraint)

def forward(self, matrix: Tensor) -> Tensor:
def forward(self, matrix: Tensor, /) -> Tensor:
if self.step == 0:
self._init_optim_problem()

Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_pcgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class PCGradWeighting(Weighting[PSDMatrix]):
:class:`~torchjd.aggregation.PCGrad`.
"""

def forward(self, gramian: PSDMatrix) -> Tensor:
def forward(self, gramian: PSDMatrix, /) -> Tensor:
# Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration
device = gramian.device
dtype = gramian.dtype
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class RandomWeighting(Weighting[Matrix]):
at each call.
"""

def forward(self, matrix: Tensor) -> Tensor:
def forward(self, matrix: Tensor, /) -> Tensor:
random_vector = torch.randn(matrix.shape[0], device=matrix.device, dtype=matrix.dtype)
weights = F.softmax(random_vector, dim=-1)
return weights
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class SumWeighting(Weighting[Matrix]):
:math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`.
"""

def forward(self, matrix: Tensor) -> Tensor:
def forward(self, matrix: Tensor, /) -> Tensor:
device = matrix.device
dtype = matrix.dtype
weights = torch.ones(matrix.shape[0], device=device, dtype=dtype)
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
self.reg_eps = reg_eps
self.solver: SUPPORTED_SOLVER = solver

def forward(self, gramian: PSDMatrix) -> Tensor:
def forward(self, gramian: PSDMatrix, /) -> Tensor:
U = torch.diag(self.weighting(gramian))
G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
W = project_weights(U, G, self.solver)
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_weighting_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self):
super().__init__()

@abstractmethod
def forward(self, stat: _T) -> Tensor:
def forward(self, stat: _T, /) -> Tensor:
"""Computes the vector of weights from the input stat."""

def __call__(self, stat: Tensor) -> Tensor:
Expand All @@ -51,7 +51,7 @@ def __init__(self, weighting: Weighting[_FnOutputT], fn: Callable[[_T], _FnOutpu
self.fn = fn
self.weighting = weighting

def forward(self, stat: _T) -> Tensor:
def forward(self, stat: _T, /) -> Tensor:
return self.weighting(self.fn(stat))


Expand Down
5 changes: 4 additions & 1 deletion src/torchjd/autogram/_jacobian_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def _compute_jacobian(
grad_outputs: tuple[Tensor, ...],
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
/,
) -> Matrix:
"""
Computes and returns the Jacobian. The output must be a matrix (2D Tensor).
Expand All @@ -75,6 +76,7 @@ def _compute_jacobian(
grad_outputs: tuple[Tensor, ...],
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
/,
) -> Matrix:
grad_outputs_in_dims = (0,) * len(grad_outputs)
args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
Expand Down Expand Up @@ -133,6 +135,7 @@ def _compute_jacobian(
grad_outputs: tuple[Tensor, ...],
_: tuple[PyTree, ...],
__: dict[str, PyTree],
/,
) -> Matrix:
flat_rg_params, ___ = tree_flatten(self.rg_params)
grads = torch.autograd.grad(
Expand Down Expand Up @@ -172,7 +175,7 @@ def vmap(
jac_outputs: tuple[Tensor, ...],
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
) -> tuple[Tensor, None]:
) -> tuple[Tensor, None]: # type: ignore[reportIncompatibleMethodOverride]
# There is a non-batched dimension
# We do not vmap over the args, kwargs, or rg_outputs for the non-batched dimension
generalized_jacobian = torch.vmap(compute_jacobian_fn, in_dims=in_dims[1:])(
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autogram/_module_hook_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def setup_context(
ctx,
inputs: tuple,
_,
) -> None:
) -> None: # type: ignore[reportIncompatibleMethodOverride]
ctx.gramian_accumulation_phase = inputs[0]
ctx.gramian_computer = inputs[1]
ctx.args = inputs[2]
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/autojac/_transform/_accumulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class AccumulateGrad(Transform):
should not be used elsewhere.
"""

def __call__(self, gradients: TensorDict) -> TensorDict:
def __call__(self, gradients: TensorDict, /) -> TensorDict:
accumulate_grads(gradients.keys(), gradients.values())
return {}

Expand All @@ -30,7 +30,7 @@ class AccumulateJac(Transform):
should not be used elsewhere.
"""

def __call__(self, jacobians: TensorDict) -> TensorDict:
def __call__(self, jacobians: TensorDict, /) -> TensorDict:
accumulate_jacs(jacobians.keys(), jacobians.values())
return {}

Expand Down
6 changes: 3 additions & 3 deletions src/torchjd/autojac/_transform/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __str__(self) -> str:
return type(self).__name__

@abstractmethod
def __call__(self, input: TensorDict) -> TensorDict:
def __call__(self, input: TensorDict, /) -> TensorDict:
"""Applies the transform to the input."""

@abstractmethod
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(self, outer: Transform, inner: Transform):
def __str__(self) -> str:
return str(self.outer) + " ∘ " + str(self.inner)

def __call__(self, input: TensorDict) -> TensorDict:
def __call__(self, input: TensorDict, /) -> TensorDict:
intermediate = self.inner(input)
return self.outer(intermediate)

Expand Down Expand Up @@ -107,7 +107,7 @@ def __str__(self) -> str:
strings.append(s)
return "(" + " | ".join(strings) + ")"

def __call__(self, tensor_dict: TensorDict) -> TensorDict:
def __call__(self, tensor_dict: TensorDict, /) -> TensorDict:
union: TensorDict = {}
for transform in self.transforms:
union |= transform(tensor_dict)
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autojac/_transform/_diagonalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self, key_order: OrderedSet[Tensor]):
self.indices.append((begin, end))
begin = end

def __call__(self, tensors: TensorDict) -> TensorDict:
def __call__(self, tensors: TensorDict, /) -> TensorDict:
flattened_considered_values = [tensors[key].reshape([-1]) for key in self.key_order]
diagonal_matrix = torch.cat(flattened_considered_values).diag()
diagonalized_tensors = {
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/autojac/_transform/_differentiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ def __init__(
self.retain_graph = retain_graph
self.create_graph = create_graph

def __call__(self, tensors: TensorDict) -> TensorDict:
def __call__(self, tensors: TensorDict, /) -> TensorDict:
tensor_outputs = [tensors[output] for output in self.outputs]

differentiated_tuple = self._differentiate(tensor_outputs)
new_differentiations = dict(zip(self.inputs, differentiated_tuple))
return type(tensors)(new_differentiations)

@abstractmethod
def _differentiate(self, tensor_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
def _differentiate(self, tensor_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...]:
"""
Abstract method for differentiating the outputs with respect to the inputs, and applying the
linear transformations represented by the tensor_outputs to the results.
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autojac/_transform/_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
):
super().__init__(outputs, inputs, retain_graph, create_graph)

def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
def _differentiate(self, grad_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...]:
"""
Computes the gradient of each output element with respect to each input tensor, and applies
the linear transformations represented by the grad_outputs to the results.
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autojac/_transform/_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Init(Transform):
def __init__(self, values: Set[Tensor]):
self.values = values

def __call__(self, input: TensorDict) -> TensorDict:
def __call__(self, input: TensorDict, /) -> TensorDict:
return {value: torch.ones_like(value) for value in self.values}

def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autojac/_transform/_jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
super().__init__(outputs, inputs, retain_graph, create_graph)
self.chunk_size = chunk_size

def _differentiate(self, jac_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
def _differentiate(self, jac_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...]:
"""
Computes the jacobian of each output with respect to each input, and applies the linear
transformations represented by the jac_outputs to the results.
Expand Down
8 changes: 4 additions & 4 deletions src/torchjd/autojac/_transform/_ordered_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ def difference_update(self, elements: set[_T]) -> None:
for element in elements:
self.discard(element)

def add(self, element: _T) -> None:
def add(self, value: _T) -> None:
"""Adds the specified element to the OrderedSet."""

self.ordered_dict[element] = None
self.ordered_dict[value] = None

def __add__(self, other: OrderedSet[_T]) -> OrderedSet[_T]:
"""Creates a new OrderedSet with the elements of self followed by the elements of other."""
Expand All @@ -40,5 +40,5 @@ def __iter__(self) -> Iterator[_T]:
def __len__(self) -> int:
return len(self.ordered_dict)

def __contains__(self, element: object) -> bool:
return element in self.ordered_dict
def __contains__(self, x: object) -> bool:
return x in self.ordered_dict
2 changes: 1 addition & 1 deletion src/torchjd/autojac/_transform/_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Select(Transform):
def __init__(self, keys: Set[Tensor]):
self.keys = keys

def __call__(self, tensor_dict: TensorDict) -> TensorDict:
def __call__(self, tensor_dict: TensorDict, /) -> TensorDict:
output = {key: tensor_dict[key] for key in self.keys}
return type(tensor_dict)(output)

Expand Down
Loading