From 150e2208b0a23375dcc3ab18518a4f170e10c7e4 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 11 Feb 2026 09:33:14 +0100 Subject: [PATCH 01/11] Add tidy imports --- pyproject.toml | 13 +++++++------ src/torchjd/autojac/_transform/_accumulate.py | 3 ++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 06071519..30dfcc24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,13 +129,14 @@ target-version = "py310" [tool.ruff.lint] select = [ - "E", # pycodestyle Error - "F", # Pyflakes - "W", # pycodestyle Warning - "I", # isort - "UP", # pyupgrade - "B", # flake8-bugbear + "E", # pycodestyle Error + "F", # Pyflakes + "W", # pycodestyle Warning + "I", # isort + "UP", # pyupgrade + "B", # flake8-bugbear "FIX", # flake8-fixme + "TID", # flake8-tidy-imports ] ignore = [ diff --git a/src/torchjd/autojac/_transform/_accumulate.py b/src/torchjd/autojac/_transform/_accumulate.py index 38dfb10e..a61c52da 100644 --- a/src/torchjd/autojac/_transform/_accumulate.py +++ b/src/torchjd/autojac/_transform/_accumulate.py @@ -1,6 +1,7 @@ from torch import Tensor -from .._accumulation import accumulate_grads, accumulate_jacs +from torchjd.autojac._accumulation import accumulate_grads, accumulate_jacs + from ._base import TensorDict, Transform From 693438482a7eaa57af73be2037699aa54c1f7459 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 5 Feb 2026 11:22:54 +0100 Subject: [PATCH 02/11] Add c4 (flake8-comprehensions) --- pyproject.toml | 1 + src/torchjd/autojac/_jac_to_grad.py | 2 +- src/torchjd/autojac/_transform/_grad.py | 2 +- src/torchjd/autojac/_transform/_jac.py | 2 +- src/torchjd/autojac/_utils.py | 4 +-- tests/plots/_utils.py | 39 ++++++++++++------------- tests/unit/autojac/test_mtl_backward.py | 6 ++-- tests/utils/architectures.py | 8 ++--- tests/utils/tensors.py | 2 +- 9 files changed, 32 insertions(+), 34 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 30dfcc24..834a8622 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -135,6 +135,7 @@ select = [ "I", # isort "UP", # pyupgrade "B", # flake8-bugbear + "C4", # flake8-comprehensions "FIX", # flake8-fixme "TID", # flake8-tidy-imports ] diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 352e2655..2b2e2561 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -65,7 +65,7 @@ def jac_to_grad( jacobians = [t.jac for t in tensors_] - if not all([jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]]): + if not all(jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]): raise ValueError("All Jacobians should have the same number of rows.") if not retain_jac: diff --git a/src/torchjd/autojac/_transform/_grad.py b/src/torchjd/autojac/_transform/_grad.py index a52b7b15..e61d4748 100644 --- a/src/torchjd/autojac/_transform/_grad.py +++ b/src/torchjd/autojac/_transform/_grad.py @@ -48,7 +48,7 @@ def _differentiate(self, grad_outputs: Sequence[Tensor], /) -> tuple[Tensor, ... """ if len(self.inputs) == 0: - return tuple() + return () if len(self.outputs) == 0: return tuple(torch.zeros_like(input) for input in self.inputs) diff --git a/src/torchjd/autojac/_transform/_jac.py b/src/torchjd/autojac/_transform/_jac.py index 0783e22a..38326d26 100644 --- a/src/torchjd/autojac/_transform/_jac.py +++ b/src/torchjd/autojac/_transform/_jac.py @@ -57,7 +57,7 @@ def _differentiate(self, jac_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...] 
""" if len(self.inputs) == 0: - return tuple() + return () if len(self.outputs) == 0: return tuple( diff --git a/src/torchjd/autojac/_utils.py b/src/torchjd/autojac/_utils.py index 87ae5068..1a460ce5 100644 --- a/src/torchjd/autojac/_utils.py +++ b/src/torchjd/autojac/_utils.py @@ -42,10 +42,10 @@ def get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> O """ - if any([tensor.grad_fn is None for tensor in tensors]): + if any(tensor.grad_fn is None for tensor in tensors): raise ValueError("All `tensors` should have a `grad_fn`.") - if any([tensor.grad_fn is None for tensor in excluded]): + if any(tensor.grad_fn is None for tensor in excluded): raise ValueError("All `excluded` tensors should have a `grad_fn`.") accumulate_grads = _get_descendant_accumulate_grads( diff --git a/tests/plots/_utils.py b/tests/plots/_utils.py index fe844d66..39cd08f8 100644 --- a/tests/plots/_utils.py +++ b/tests/plots/_utils.py @@ -62,7 +62,7 @@ def make_vector_scatter( text_size: float = 12, marker_size: float = 12, ) -> Scatter: - line = dict(color=color, width=line_width) + line = {"color": color, "width": line_width} if dash: line["dash"] = "dash" @@ -71,16 +71,16 @@ def make_vector_scatter( y=[0, gradient[1]], mode="lines+markers+text", line=line, - marker=dict( - symbol="arrow", - color=color, - size=marker_size, - angleref="previous", - ), + marker={ + "symbol": "arrow", + "color": color, + "size": marker_size, + "angleref": "previous", + }, name=label, text=["", label], textposition=textposition, - textfont=dict(color=color, size=text_size), + textfont={"color": color, "size": text_size}, showlegend=showlegend, ) return scatter @@ -121,9 +121,13 @@ def make_cone_scatter( ) if printable: - fillpattern = dict( - bgcolor="white", shape="\\", fgcolor="rgba(0, 220, 0, 0.5)", size=30, solidity=0.15 - ) + fillpattern = { + "bgcolor": "white", + "shape": "\\", + "fgcolor": "rgba(0, 220, 0, 0.5)", + "size": 30, + "solidity": 0.15, + } else: fillpattern = None @@ 
-133,7 +137,7 @@ def make_cone_scatter( fill="toself", # Fill the area inside the polygon mode="lines", fillcolor="rgba(0, 255, 0, 0.07)", - line=dict(color="rgb(0, 220, 0)", width=2), + line={"color": "rgb(0, 220, 0)", "width": 2}, name=label, fillpattern=fillpattern, ) @@ -146,11 +150,7 @@ def make_segment_scatter(start: torch.Tensor, end: torch.Tensor) -> Scatter: x=[start[0], end[0]], y=[start[1], end[1]], mode="lines", - line=dict( - color="rgb(150, 150, 150)", - width=2.5, - dash="longdash", - ), + line={"color": "rgb(150, 150, 150)", "width": 2.5, "dash": "longdash"}, ) return segment @@ -161,10 +161,7 @@ def make_polygon_scatter(points: list[torch.Tensor]) -> Scatter: x=[point[0] for point in points], y=[point[1] for point in points], mode="lines", - line=dict( - color="rgb(100, 100, 100)", - width=1.5, - ), + line={"color": "rgb(100, 100, 100)", "width": 1.5}, ) return polygon diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index 00bda738..f3f7c549 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -224,13 +224,13 @@ def test_multiple_params_per_task(): @mark.parametrize( "shared_params_shapes", [ - [tuple()], + [()], [(2,)], [(3, 2)], [(4, 3, 2)], - [tuple(), (2,)], + [(), (2,)], [(3, 2), (2,)], - [(4, 3, 2), (3, 2), tuple()], + [(4, 3, 2), (3, 2), ()], [(5, 4, 3, 2), (5, 4, 3, 2)], ], ) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index f1b98b6d..d0b8adb2 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -393,7 +393,7 @@ def __init__(self, shape: tuple[int, ...]): self.matrix = nn.Parameter(torch.randn(shape)) def forward(self, _: PyTree) -> PyTree: - return {"one": [None, tuple()], "two": None} + return {"one": [None, ()], "two": None} class _EmptyTupleOutput(nn.Module): def __init__(self, shape: tuple[int, ...]): @@ -401,7 +401,7 @@ def __init__(self, shape: tuple[int, ...]): self.matrix = 
nn.Parameter(torch.randn(shape)) def forward(self, _: PyTree) -> tuple: - return tuple() + return () class _EmptyPytreeOutput(nn.Module): def __init__(self, shape: tuple[int, ...]): @@ -409,7 +409,7 @@ def __init__(self, shape: tuple[int, ...]): self.matrix = nn.Parameter(torch.randn(shape)) def forward(self, _: PyTree) -> PyTree: - return {"one": [tuple(), tuple()], "two": [[], []]} + return {"one": [(), ()], "two": [[], []]} def __init__(self): super().__init__() @@ -670,7 +670,7 @@ class Ndim0Output(ShapedModule): """Simple model whose output is a scalar.""" INPUT_SHAPES = (5,) - OUTPUT_SHAPES = tuple() + OUTPUT_SHAPES = () def __init__(self): super().__init__() diff --git a/tests/utils/tensors.py b/tests/utils/tensors.py index 6c91a08c..b5f7ee94 100644 --- a/tests/utils/tensors.py +++ b/tests/utils/tensors.py @@ -38,6 +38,6 @@ def make_inputs_and_targets(model: nn.Module, batch_size: int) -> tuple[PyTree, def _make_tensors(batch_size: int, tensor_shapes: PyTree) -> PyTree: def is_leaf(s): - return isinstance(s, tuple) and all([isinstance(e, int) for e in s]) + return isinstance(s, tuple) and all(isinstance(e, int) for e in s) return tree_map(lambda s: randn_((batch_size,) + s), tensor_shapes, is_leaf=is_leaf) From 65b9050383ad73b03ad594e54575b6caedc2fe54 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 5 Feb 2026 11:25:33 +0100 Subject: [PATCH 03/11] Add SIM (simplification of if/else and conditions) --- pyproject.toml | 1 + src/torchjd/_linalg/_gramian.py | 2 +- src/torchjd/aggregation/_imtl_g.py | 5 +---- src/torchjd/autogram/_gramian_utils.py | 4 ++-- tests/unit/aggregation/_matrix_samplers.py | 6 +++--- tests/unit/autojac/test_backward.py | 5 +---- tests/unit/autojac/test_jac.py | 5 +---- tests/unit/autojac/test_mtl_backward.py | 12 +++--------- 8 files changed, 13 insertions(+), 27 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 834a8622..3d638152 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -138,6 +138,7 @@ select = [ "C4", 
# flake8-comprehensions "FIX", # flake8-fixme "TID", # flake8-tidy-imports + "SIM", # flake8-simplify ] ignore = [ diff --git a/src/torchjd/_linalg/_gramian.py b/src/torchjd/_linalg/_gramian.py index edc819dd..f048c238 100644 --- a/src/torchjd/_linalg/_gramian.py +++ b/src/torchjd/_linalg/_gramian.py @@ -35,7 +35,7 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor: first dimension). """ - contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t.ndim + contracted_dims = contracted_dims if contracted_dims >= 0 else contracted_dims + t.ndim indices_source = list(range(t.ndim - contracted_dims)) indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1)) transposed = t.movedim(indices_source, indices_dest) diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 6355158e..f45e8c2e 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -34,9 +34,6 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: v = torch.linalg.pinv(gramian) @ d v_sum = v.sum() - if v_sum.abs() < 1e-12: - weights = torch.zeros_like(v) - else: - weights = v / v_sum + weights = torch.zeros_like(v) if v_sum.abs() < 1e-12 else v / v_sum return weights diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index bfc0ed9e..12c2bfe2 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -75,8 +75,8 @@ def movedim(gramian: PSDTensor, half_source: list[int], half_destination: list[i # Map everything to the range [0, gramian.ndim//2[ half_ndim = gramian.ndim // 2 - half_source_ = [i if 0 <= i else i + half_ndim for i in half_source] - half_destination_ = [i if 0 <= i else i + half_ndim for i in half_destination] + half_source_ = [i if i >= 0 else i + half_ndim for i in half_source] + half_destination_ = [i if i >= 0 else i + half_ndim for i in half_destination] # Mirror the half source and the half 
destination and use the result to move the dimensions of # the gramian diff --git a/tests/unit/aggregation/_matrix_samplers.py b/tests/unit/aggregation/_matrix_samplers.py index 8a3acd8d..1b5cc8ab 100644 --- a/tests/unit/aggregation/_matrix_samplers.py +++ b/tests/unit/aggregation/_matrix_samplers.py @@ -59,7 +59,7 @@ class StrongSampler(MatrixSampler): def _check_params(self, m: int, n: int, rank: int) -> None: super()._check_params(m, n, rank) - assert 1 < m + assert m > 1 assert 0 < rank <= min(m - 1, n) def __call__(self, rng: torch.Generator | None = None) -> Tensor: @@ -94,7 +94,7 @@ class StrictlyWeakSampler(MatrixSampler): def _check_params(self, m: int, n: int, rank: int) -> None: super()._check_params(m, n, rank) - assert 1 < m + assert m > 1 assert 0 < rank <= min(m - 1, n) def __call__(self, rng: torch.Generator | None = None) -> Tensor: @@ -126,7 +126,7 @@ class NonWeakSampler(MatrixSampler): def _check_params(self, m: int, n: int, rank: int) -> None: super()._check_params(m, n, rank) - assert 0 < rank + assert rank > 0 def __call__(self, rng: torch.Generator | None = None) -> Tensor: u = torch.abs(randn_([self.m], generator=rng)) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 0a1ce91d..80bd06e3 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -60,10 +60,7 @@ def test_value_is_correct( input = randn_([shape[1]], requires_grad=True) output = J @ input # Note that the Jacobian of output w.r.t. input is J. 
- if manually_specify_inputs: - inputs = [input] - else: - inputs = None + inputs = [input] if manually_specify_inputs else None backward( [output], diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py index 3a5fb9a4..3d776a8d 100644 --- a/tests/unit/autojac/test_jac.py +++ b/tests/unit/autojac/test_jac.py @@ -64,10 +64,7 @@ def test_value_is_correct( input = randn_([shape[1]], requires_grad=True) output = J @ input # Note that the Jacobian of output w.r.t. input is J. - if manually_specify_inputs: - inputs = [input] - else: - inputs = None + inputs = [input] if manually_specify_inputs else None jacobians = jac( [output], diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index f3f7c549..21611561 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -90,15 +90,9 @@ def test_value_is_correct( y2 = p2 @ f y3 = p3 @ f - if manually_specify_shared_params: - shared_params = [p0] - else: - shared_params = None - - if manually_specify_tasks_params: - tasks_params = [[p1], [p2], [p3]] - else: - tasks_params = None + shared_params = [p0] if manually_specify_shared_params else None + + tasks_params = [[p1], [p2], [p3]] if manually_specify_tasks_params else None mtl_backward( losses=[y1, y2, y3], From 9b15093a4e082c74ef7a5e226cfd7f08e765c6ab Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 5 Feb 2026 11:38:30 +0100 Subject: [PATCH 04/11] Add PERF (detect Python patterns that are slow, like list lookups).
--- pyproject.toml | 1 + src/torchjd/aggregation/_nash_mtl.py | 10 ++++------ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3d638152..a021d29d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,6 +139,7 @@ select = [ "FIX", # flake8-fixme "TID", # flake8-tidy-imports "SIM", # flake8-simplify + "PERF", # Perflint ] ignore = [ diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 7485a358..e875d46b 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -189,12 +189,10 @@ def _init_optim_problem(self) -> None: self.phi_alpha = self._calc_phi_alpha_linearization() G_alpha = self.G_param @ self.alpha_param - constraint = [] - for i in range(self.n_tasks): - constraint.append( - -cp.log(self.alpha_param[i] * self.normalization_factor_param) - cp.log(G_alpha[i]) - <= 0 - ) + constraint = [ + -cp.log(a * self.normalization_factor_param) - cp.log(G_a) <= 0 + for a, G_a in zip(self.alpha_param, G_alpha, strict=True) + ] obj = cp.Minimize(cp.sum(G_alpha) + self.phi_alpha / self.normalization_factor_param) self.prob = cp.Problem(obj, constraint) From 258d965629b5f1bd588277c19f5c0a0b145ed0de Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 5 Feb 2026 11:53:06 +0100 Subject: [PATCH 05/11] Add FURB (Modernizes old-school Python idioms into cleaner, faster versions) --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index a021d29d..a3a00529 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -140,6 +140,7 @@ select = [ "TID", # flake8-tidy-imports "SIM", # flake8-simplify "PERF", # Perflint + "FURB", # refurb ] ignore = [ From 0bae8d8ce13394518f9b61f97efd4152158308be Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 11 Feb 2026 09:51:00 +0100 Subject: [PATCH 06/11] Add RUFF with few exceptions. 
--- pyproject.toml | 7 +++++-- src/torchjd/autogram/_engine.py | 2 +- src/torchjd/autojac/_transform/_diagonalize.py | 2 +- src/torchjd/autojac/_transform/_jac.py | 2 +- tests/unit/autogram/test_engine.py | 4 ++-- tests/unit/autogram/test_gramian_utils.py | 14 +++++++------- tests/unit/autojac/_transform/test_accumulate.py | 4 ++-- tests/unit/autojac/_transform/test_interactions.py | 4 +--- tests/utils/architectures.py | 10 +++++----- tests/utils/forward_backwards.py | 2 +- tests/utils/tensors.py | 2 +- 11 files changed, 27 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a3a00529..a52bdb05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,11 +141,14 @@ select = [ "SIM", # flake8-simplify "PERF", # Perflint "FURB", # refurb + "RUF", # Ruff-specific rules ] ignore = [ - "E501", # line-too-long (handled by the formatter) - "E402", # module-import-not-at-top-of-file + "E501", # line-too-long (handled by the formatter) + "E402", # module-import-not-at-top-of-file + "RUF022", # __all__ not sorted + "RUF010", # Use explicit conversion flag ] [tool.ruff.lint.isort] diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 3e9ddfee..4a09a1cc 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -278,7 +278,7 @@ def compute_gramian(self, output: Tensor) -> Tensor: target_shape = [] if has_non_batch_dim: - target_shape = [-1] + target_shape + target_shape = [-1, *target_shape] reshaped_output = ordered_output.reshape(target_shape) # There are four different cases for the shape of reshaped_output: diff --git a/src/torchjd/autojac/_transform/_diagonalize.py b/src/torchjd/autojac/_transform/_diagonalize.py index 88e5525e..a832da63 100644 --- a/src/torchjd/autojac/_transform/_diagonalize.py +++ b/src/torchjd/autojac/_transform/_diagonalize.py @@ -64,7 +64,7 @@ def __call__(self, tensors: TensorDict, /) -> TensorDict: flattened_considered_values = [tensors[key].reshape([-1]) for key 
in self.key_order] diagonal_matrix = torch.cat(flattened_considered_values).diag() diagonalized_tensors = { - key: diagonal_matrix[:, begin:end].reshape((-1,) + key.shape) + key: diagonal_matrix[:, begin:end].reshape((-1, *key.shape)) for (begin, end), key in zip(self.indices, self.key_order, strict=True) } return diagonalized_tensors diff --git a/src/torchjd/autojac/_transform/_jac.py b/src/torchjd/autojac/_transform/_jac.py index 38326d26..d78cf3df 100644 --- a/src/torchjd/autojac/_transform/_jac.py +++ b/src/torchjd/autojac/_transform/_jac.py @@ -62,7 +62,7 @@ def _differentiate(self, jac_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...] if len(self.outputs) == 0: return tuple( [ - torch.empty((0,) + input.shape, device=input.device, dtype=input.dtype) + torch.empty((0, *input.shape), device=input.device, dtype=input.dtype) for input in self.inputs ] ) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 2461e383..64f6b926 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -536,11 +536,11 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): input = randn_([batch_size, input_size]) engine1 = Engine(model1, batch_dim=batch_dim) - output1 = model1(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim) + output1 = model1(input).reshape([batch_size, *non_batched_shape]).movedim(0, batch_dim) gramian1 = engine1.compute_gramian(output1) engine2 = Engine(model2, batch_dim=None) - output2 = model2(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim) + output2 = model2(input).reshape([batch_size, *non_batched_shape]).movedim(0, batch_dim) gramian2 = engine2.compute_gramian(output2) assert_close(gramian1, gramian2) diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py index ac912918..5f74df8d 100644 --- a/tests/unit/autogram/test_gramian_utils.py +++ 
b/tests/unit/autogram/test_gramian_utils.py @@ -29,8 +29,8 @@ def test_reshape_equivarience(original_shape: list[int], target_shape: list[int]): """Tests that reshape_gramian is such that compute_gramian is equivariant to a reshape.""" - original_matrix = randn_(original_shape + [2]) - target_matrix = original_matrix.reshape(target_shape + [2]) + original_matrix = randn_([*original_shape, 2]) + target_matrix = original_matrix.reshape([*target_shape, 2]) original_gramian = compute_gramian(original_matrix, 1) target_gramian = compute_gramian(target_matrix, 1) @@ -56,7 +56,7 @@ def test_reshape_equivarience(original_shape: list[int], target_shape: list[int] ], ) def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]): - matrix = randn_(original_shape + [2]) + matrix = randn_([*original_shape, 2]) gramian = compute_gramian(matrix, 1) reshaped_gramian = reshape(gramian, target_shape) assert_is_psd_tensor(reshaped_gramian, atol=1e-04, rtol=0.0) @@ -73,7 +73,7 @@ def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]): ], ) def test_flatten_yields_matrix(shape: list[int]): - matrix = randn_(shape + [2]) + matrix = randn_([*shape, 2]) gramian = compute_gramian(matrix, 1) flattened_gramian = flatten(gramian) assert is_psd_matrix(flattened_gramian) @@ -90,7 +90,7 @@ def test_flatten_yields_matrix(shape: list[int]): ], ) def test_flatten_yields_psd(shape: list[int]): - matrix = randn_(shape + [2]) + matrix = randn_([*shape, 2]) gramian = compute_gramian(matrix, 1) flattened_gramian = flatten(gramian) assert_is_psd_matrix(flattened_gramian, atol=1e-04, rtol=0.0) @@ -117,7 +117,7 @@ def test_flatten_yields_psd(shape: list[int]): def test_movedim_equivariance(shape: list[int], source: list[int], destination: list[int]): """Tests that movedim_gramian is such that compute_gramian is equivariant to a movedim.""" - original_matrix = randn_(shape + [2]) + original_matrix = randn_([*shape, 2]) target_matrix = 
original_matrix.movedim(source, destination) original_gramian = compute_gramian(original_matrix, 1) @@ -147,7 +147,7 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: ], ) def test_movedim_yields_psd(shape: list[int], source: list[int], destination: list[int]): - matrix = randn_(shape + [2]) + matrix = randn_([*shape, 2]) gramian = compute_gramian(matrix, 1) moveddim_gramian = movedim(gramian, source, destination) assert_is_psd_tensor(moveddim_gramian) diff --git a/tests/unit/autojac/_transform/test_accumulate.py b/tests/unit/autojac/_transform/test_accumulate.py index eaa09549..8c179a89 100644 --- a/tests/unit/autojac/_transform/test_accumulate.py +++ b/tests/unit/autojac/_transform/test_accumulate.py @@ -97,7 +97,7 @@ def test_single_jac_accumulation(): shapes = [[], [1], [2, 3]] keys = [zeros_(shape, requires_grad=True) for shape in shapes] - values = [ones_([4] + shape) for shape in shapes] + values = [ones_([4, *shape]) for shape in shapes] input = dict(zip(keys, values, strict=True)) accumulate = AccumulateJac() @@ -118,7 +118,7 @@ def test_multiple_jac_accumulations(iterations: int): shapes = [[], [1], [2, 3]] keys = [zeros_(shape, requires_grad=True) for shape in shapes] - values = [ones_([4] + shape) for shape in shapes] + values = [ones_([4, *shape]) for shape in shapes] accumulate = AccumulateJac() diff --git a/tests/unit/autojac/_transform/test_interactions.py b/tests/unit/autojac/_transform/test_interactions.py index a712dcef..470f5d6d 100644 --- a/tests/unit/autojac/_transform/test_interactions.py +++ b/tests/unit/autojac/_transform/test_interactions.py @@ -235,9 +235,7 @@ def test_equivalence_jac_grads(): grad_2_A, grad_2_b, grad_2_c = grad_dict_2[A], grad_dict_2[b], grad_dict_2[c] n_outputs = len(outputs) - batched_grad_outputs = [ - zeros_((n_outputs,) + grad_output.shape) for grad_output in grad_outputs - ] + batched_grad_outputs = [zeros_((n_outputs, *grad_output.shape)) for grad_output in grad_outputs] for i, 
grad_output in enumerate(grad_outputs): batched_grad_outputs[i][i] = grad_output diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index d0b8adb2..82b563e1 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Generic, TypeVar +from typing import ClassVar, Generic, TypeVar import torch import torchvision @@ -131,7 +131,7 @@ class SingleInputPyTreeOutput(ShapedModule): """Module taking a single input and returning a complex PyTree of tensors as output.""" INPUT_SHAPES = (50,) - OUTPUT_SHAPES = { + OUTPUT_SHAPES: ClassVar = { "first": ((50,), [(60,), (70,)]), "second": (80,), "third": ([((90,),)],), @@ -156,7 +156,7 @@ def forward(self, input: Tensor) -> PyTree: class PyTreeInputSingleOutput(ShapedModule): """Module taking a complex PyTree of tensors as input and returning a single output.""" - INPUT_SHAPES = { + INPUT_SHAPES: ClassVar = { "one": [((10,), [(20,), (30,)]), (12,)], "two": (14,), } @@ -193,12 +193,12 @@ class PyTreeInputPyTreeOutput(ShapedModule): output. 
""" - INPUT_SHAPES = { + INPUT_SHAPES: ClassVar = { "one": [((10,), [(20,), (30,)]), (12,)], "two": (14,), } - OUTPUT_SHAPES = { + OUTPUT_SHAPES: ClassVar = { "first": ((50,), [(60,), (70,)]), "second": (80,), "third": ([((90,),)],), diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index f8b9dfe2..57483dd7 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -198,7 +198,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): for module in self.model.modules(): self._restore_original_params(module) - return False # don’t suppress exceptions + return False # don't suppress exceptions def _restore_original_params(self, module: nn.Module): original_params = self._module_to_original_params.pop(module, {}) diff --git a/tests/utils/tensors.py b/tests/utils/tensors.py index b5f7ee94..7988157d 100644 --- a/tests/utils/tensors.py +++ b/tests/utils/tensors.py @@ -40,4 +40,4 @@ def _make_tensors(batch_size: int, tensor_shapes: PyTree) -> PyTree: def is_leaf(s): return isinstance(s, tuple) and all(isinstance(e, int) for e in s) - return tree_map(lambda s: randn_((batch_size,) + s), tensor_shapes, is_leaf=is_leaf) + return tree_map(lambda s: randn_((batch_size, *s)), tensor_shapes, is_leaf=is_leaf) From 24e4ec3e97f1d135493913796129d4382e521e9f Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 11 Feb 2026 09:56:09 +0100 Subject: [PATCH 07/11] Add RET (flake8 returns) --- pyproject.toml | 2 ++ src/torchjd/aggregation/_utils/pref_vector.py | 16 +++++++--------- src/torchjd/autogram/_gramian_computer.py | 3 +-- src/torchjd/autojac/_transform/_jac.py | 3 +-- tests/plots/_utils.py | 5 +---- tests/utils/architectures.py | 8 +++----- tests/utils/forward_backwards.py | 3 +-- 7 files changed, 16 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a52bdb05..5a0fe687 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,6 +139,7 @@ select = [ "FIX", # flake8-fixme "TID", # 
flake8-tidy-imports "SIM", # flake8-simplify + "RET", # flake8-return "PERF", # Perflint "FURB", # refurb "RUF", # Ruff-specific rules @@ -149,6 +150,7 @@ ignore = [ "E402", # module-import-not-at-top-of-file "RUF022", # __all__ not sorted "RUF010", # Use explicit conversion flag + "RET504", # Unnecessary assignment return statement ] [tool.ruff.lint.isort] diff --git a/src/torchjd/aggregation/_utils/pref_vector.py b/src/torchjd/aggregation/_utils/pref_vector.py index fea943d7..d2d4a3b6 100644 --- a/src/torchjd/aggregation/_utils/pref_vector.py +++ b/src/torchjd/aggregation/_utils/pref_vector.py @@ -17,13 +17,12 @@ def pref_vector_to_weighting( if pref_vector is None: return default - else: - if pref_vector.ndim != 1: - raise ValueError( - "Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = " - f"{pref_vector.ndim}`." - ) - return ConstantWeighting(pref_vector) + if pref_vector.ndim != 1: + raise ValueError( + "Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = " + f"{pref_vector.ndim}`." 
+ ) + return ConstantWeighting(pref_vector) def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str: @@ -31,5 +30,4 @@ def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str: if pref_vector is None: return "" - else: - return f"([{vector_to_str(pref_vector)}])" + return f"([{vector_to_str(pref_vector)}])" diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index 8c1546e0..829e5da3 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -73,5 +73,4 @@ def __call__( gramian = compute_gramian(self.summed_jacobian) del self.summed_jacobian return gramian - else: - return None + return None diff --git a/src/torchjd/autojac/_transform/_jac.py b/src/torchjd/autojac/_transform/_jac.py index d78cf3df..b5090f49 100644 --- a/src/torchjd/autojac/_transform/_jac.py +++ b/src/torchjd/autojac/_transform/_jac.py @@ -115,5 +115,4 @@ def _get_jacs_chunk( grad_outputs = [tensor.squeeze(0) for tensor in jac_outputs_chunk] gradients = get_vjp(grad_outputs) return tuple(gradient.unsqueeze(0) for gradient in gradients) - else: - return torch.vmap(get_vjp, chunk_size=chunk_size)(jac_outputs_chunk) + return torch.vmap(get_vjp, chunk_size=chunk_size)(jac_outputs_chunk) diff --git a/tests/plots/_utils.py b/tests/plots/_utils.py index 39cd08f8..979ec360 100644 --- a/tests/plots/_utils.py +++ b/tests/plots/_utils.py @@ -239,10 +239,7 @@ def coord_to_angle(x: float, y: float) -> tuple[float, float]: if r == 0: raise ValueError("No angle") - elif y >= 0: - angle = np.arccos(x / r) - else: - angle = 2 * np.pi - np.arccos(x / r) + angle = np.arccos(x / r) if y >= 0 else 2 * np.pi - np.arccos(x / r) return angle, r diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 82b563e1..fb353c75 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -48,14 +48,13 @@ def get_in_out_shapes(module: nn.Module) -> tuple[PyTree, PyTree]: if 
isinstance(module, ShapedModule): return module.INPUT_SHAPES, module.OUTPUT_SHAPES - elif isinstance(module, nn.BatchNorm2d | nn.InstanceNorm2d): + if isinstance(module, nn.BatchNorm2d | nn.InstanceNorm2d): HEIGHT = 6 # Arbitrary choice WIDTH = 6 # Arbitrary choice shape = (module.num_features, HEIGHT, WIDTH) return shape, shape - else: - raise ValueError("Unknown input / output shapes of module", module) + raise ValueError("Unknown input / output shapes of module", module) class OverlyNested(ShapedModule): @@ -808,8 +807,7 @@ def __init__(self): def forward(self, s: str, input: Tensor) -> Tensor: if s == "two": return input @ self.matrix * 2.0 - else: - return input @ self.matrix + return input @ self.matrix class WithModuleWithStringArg(ShapedModule): diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index 57483dd7..b350a894 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -111,8 +111,7 @@ def reshape_raw_losses(raw_losses: Tensor) -> Tensor: if raw_losses.ndim == 1: return raw_losses.unsqueeze(1) - else: - return raw_losses.flatten(start_dim=1) + return raw_losses.flatten(start_dim=1) def compute_gramian_with_autograd( From 91d6a67bac192b8677c90bd01b61bd0f0f26570e Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 5 Feb 2026 12:40:39 +0100 Subject: [PATCH 08/11] Add PYI --- pyproject.toml | 1 + src/torchjd/aggregation/_weighting_bases.py | 2 +- src/torchjd/autojac/_transform/_init.py | 4 ++-- src/torchjd/autojac/_transform/_select.py | 4 ++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5a0fe687..ded8fee0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -140,6 +140,7 @@ select = [ "TID", # flake8-tidy-imports "SIM", # flake8-simplify "RET", # flake8-return + "PYI", # flake8-pyi "PERF", # Perflint "FURB", # refurb "RUF", # Ruff-specific rules diff --git a/src/torchjd/aggregation/_weighting_bases.py 
b/src/torchjd/aggregation/_weighting_bases.py index 74129a36..3b037891 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -13,7 +13,7 @@ _FnOutputT = TypeVar("_FnOutputT", bound=Tensor) -class Weighting(Generic[_T], nn.Module, ABC): +class Weighting(nn.Module, ABC, Generic[_T]): r""" Abstract base class for all weighting methods. It has the role of extracting a vector of weights of dimension :math:`m` from some statistic of a matrix of dimension :math:`m \times n`, diff --git a/src/torchjd/autojac/_transform/_init.py b/src/torchjd/autojac/_transform/_init.py index 551f8197..11bdcb68 100644 --- a/src/torchjd/autojac/_transform/_init.py +++ b/src/torchjd/autojac/_transform/_init.py @@ -1,4 +1,4 @@ -from collections.abc import Set +from collections.abc import Set as AbstractSet import torch from torch import Tensor @@ -13,7 +13,7 @@ class Init(Transform): :param values: Tensors for which Gradients must be returned. """ - def __init__(self, values: Set[Tensor]): + def __init__(self, values: AbstractSet[Tensor]): self.values = values def __call__(self, input: TensorDict, /) -> TensorDict: diff --git a/src/torchjd/autojac/_transform/_select.py b/src/torchjd/autojac/_transform/_select.py index 29a6bcd2..76712ab9 100644 --- a/src/torchjd/autojac/_transform/_select.py +++ b/src/torchjd/autojac/_transform/_select.py @@ -1,4 +1,4 @@ -from collections.abc import Set +from collections.abc import Set as AbstractSet from torch import Tensor @@ -12,7 +12,7 @@ class Select(Transform): :param keys: The keys that should be included in the returned subset. 
""" - def __init__(self, keys: Set[Tensor]): + def __init__(self, keys: AbstractSet[Tensor]): self.keys = keys def __call__(self, tensor_dict: TensorDict, /) -> TensorDict: From b1296c0d810542e35c5c3df0930220efb23d2c00 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 5 Feb 2026 12:41:27 +0100 Subject: [PATCH 09/11] Add PIE --- pyproject.toml | 1 + src/torchjd/autojac/_transform/_base.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ded8fee0..44c59d70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,6 +141,7 @@ select = [ "SIM", # flake8-simplify "RET", # flake8-return "PYI", # flake8-pyi + "PIE", # flake8-pie "PERF", # Perflint "FURB", # refurb "RUF", # Ruff-specific rules diff --git a/src/torchjd/autojac/_transform/_base.py b/src/torchjd/autojac/_transform/_base.py index db8ff2cb..fbbbdcd6 100644 --- a/src/torchjd/autojac/_transform/_base.py +++ b/src/torchjd/autojac/_transform/_base.py @@ -22,8 +22,6 @@ class RequirementError(ValueError): """Inappropriate set of inputs keys.""" - pass - class Transform(ABC): """ From d282f59df3729430951ea6dae9890153330a7f4a Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 5 Feb 2026 12:46:03 +0100 Subject: [PATCH 10/11] Add COM (commas) --- pyproject.toml | 1 + src/torchjd/_linalg/_gramian.py | 4 ++- src/torchjd/aggregation/_aggregator_bases.py | 2 +- src/torchjd/aggregation/_aligned_mtl.py | 5 ++-- src/torchjd/aggregation/_constant.py | 4 +-- src/torchjd/aggregation/_dualproj.py | 2 +- src/torchjd/aggregation/_graddrop.py | 4 +-- src/torchjd/aggregation/_krum.py | 8 ++--- src/torchjd/aggregation/_nash_mtl.py | 4 +-- src/torchjd/aggregation/_trimmed_mean.py | 4 +-- src/torchjd/aggregation/_upgrad.py | 2 +- src/torchjd/aggregation/_utils/pref_vector.py | 5 ++-- src/torchjd/autogram/_engine.py | 4 +-- src/torchjd/autogram/_jacobian_computer.py | 9 ++++-- src/torchjd/autojac/_accumulation.py | 4 +-- src/torchjd/autojac/_jac_to_grad.py | 10 +++++-- 
src/torchjd/autojac/_mtl_backward.py | 3 +- .../autojac/_transform/_diagonalize.py | 2 +- .../autojac/_transform/_differentiate.py | 2 +- src/torchjd/autojac/_transform/_init.py | 2 +- src/torchjd/autojac/_transform/_jac.py | 5 ++-- .../autojac/_transform/_materialize.py | 3 +- src/torchjd/autojac/_transform/_select.py | 2 +- src/torchjd/autojac/_utils.py | 8 +++-- tests/plots/_utils.py | 21 +++++++++---- tests/plots/interactive_plotter.py | 5 ++-- tests/profiling/run_profiler.py | 5 +++- tests/profiling/speed_grad_vs_jac_vs_gram.py | 6 +++- tests/settings.py | 4 +-- tests/unit/aggregation/_asserts.py | 9 ++++-- tests/unit/aggregation/test_krum.py | 5 +++- tests/unit/aggregation/test_pcgrad.py | 5 +++- tests/unit/aggregation/test_values.py | 6 ++-- tests/unit/autogram/test_engine.py | 17 ++++++++--- tests/unit/autojac/_transform/test_jac.py | 30 +++++++++++++++---- tests/unit/linalg/test_gramian.py | 2 +- tests/utils/architectures.py | 6 ++-- tests/utils/forward_backwards.py | 4 ++- 38 files changed, 151 insertions(+), 73 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 44c59d70..8d19daf5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -142,6 +142,7 @@ select = [ "RET", # flake8-return "PYI", # flake8-pyi "PIE", # flake8-pie + "COM", # flake8-commas "PERF", # Perflint "FURB", # refurb "RUF", # Ruff-specific rules diff --git a/src/torchjd/_linalg/_gramian.py b/src/torchjd/_linalg/_gramian.py index f048c238..58a2af82 100644 --- a/src/torchjd/_linalg/_gramian.py +++ b/src/torchjd/_linalg/_gramian.py @@ -70,7 +70,9 @@ def regularize(gramian: PSDMatrix, eps: float) -> PSDMatrix: """ regularization_matrix = eps * torch.eye( - gramian.shape[0], dtype=gramian.dtype, device=gramian.device + gramian.shape[0], + dtype=gramian.dtype, + device=gramian.device, ) output = gramian + regularization_matrix return cast(PSDMatrix, output) diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 6935199b..5a656f67 
100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -21,7 +21,7 @@ def _check_is_matrix(matrix: Tensor) -> None: if not is_matrix(matrix): raise ValueError( "Parameter `matrix` should be a tensor of dimension 2. Found `matrix.shape = " - f"{matrix.shape}`." + f"{matrix.shape}`.", ) @abstractmethod diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index bf4f8dc0..fe807e0a 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -107,7 +107,8 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: @staticmethod def _compute_balance_transformation( - M: Tensor, scale_mode: SUPPORTED_SCALE_MODE = "min" + M: Tensor, + scale_mode: SUPPORTED_SCALE_MODE = "min", ) -> Tensor: lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig tol = torch.max(lambda_) * len(M) * torch.finfo().eps @@ -130,7 +131,7 @@ def _compute_balance_transformation( scale = lambda_.mean() else: raise ValueError( - f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'." + f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'.", ) B = scale.sqrt() * V @ sigma_inv @ V.T diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index f4f062bf..03b629ea 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -39,7 +39,7 @@ def __init__(self, weights: Tensor): if weights.dim() != 1: raise ValueError( "Parameter `weights` should be a 1-dimensional tensor. Found `weights.shape = " - f"{weights.shape}`." + f"{weights.shape}`.", ) super().__init__() @@ -53,5 +53,5 @@ def _check_matrix_shape(self, matrix: Tensor) -> None: if matrix.shape[0] != len(self.weights): raise ValueError( f"Parameter `matrix` should have {len(self.weights)} rows (the number of specified " - f"weights). Found `matrix` with {matrix.shape[0]} rows." 
+ f"weights). Found `matrix` with {matrix.shape[0]} rows.", ) diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 4fc8cefd..d7d88648 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -40,7 +40,7 @@ def __init__( self._solver: SUPPORTED_SOLVER = solver super().__init__( - DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver) + DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver), ) # This prevents considering the computed weights as constant w.r.t. the matrix. diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py index b6ea1327..f8a39426 100644 --- a/src/torchjd/aggregation/_graddrop.py +++ b/src/torchjd/aggregation/_graddrop.py @@ -30,7 +30,7 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None): if leak is not None and leak.dim() != 1: raise ValueError( "Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = " - f"{leak.shape}`." + f"{leak.shape}`.", ) super().__init__() @@ -64,7 +64,7 @@ def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None: if self.leak is not None and n_rows != len(self.leak): raise ValueError( f"Parameter `matrix` should be a matrix of exactly {len(self.leak)} rows (i.e. the " - f"number of leak scalars). Found `matrix` of shape `{matrix.shape}`." + f"number of leak scalars). Found `matrix` of shape `{matrix.shape}`.", ) def __repr__(self) -> str: diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index bce211c6..93565174 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -49,13 +49,13 @@ def __init__(self, n_byzantine: int, n_selected: int = 1): if n_byzantine < 0: raise ValueError( "Parameter `n_byzantine` should be a non-negative integer. Found `n_byzantine = " - f"{n_byzantine}`." 
+ f"{n_byzantine}`.", ) if n_selected < 1: raise ValueError( "Parameter `n_selected` should be a positive integer. Found `n_selected = " - f"{n_selected}`." + f"{n_selected}`.", ) self.n_byzantine = n_byzantine @@ -85,11 +85,11 @@ def _check_matrix_shape(self, gramian: PSDMatrix) -> None: if gramian.shape[0] < min_rows: raise ValueError( f"Parameter `gramian` should have at least {min_rows} rows (n_byzantine + 3). Found" - f" `gramian` with {gramian.shape[0]} rows." + f" `gramian` with {gramian.shape[0]} rows.", ) if gramian.shape[0] < self.n_selected: raise ValueError( f"Parameter `gramian` should have at least {self.n_selected} rows (n_selected). " - f"Found `gramian` with {gramian.shape[0]} rows." + f"Found `gramian` with {gramian.shape[0]} rows.", ) diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index e875d46b..83455245 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -84,7 +84,7 @@ def __init__( max_norm=max_norm, update_weights_every=update_weights_every, optim_niter=optim_niter, - ) + ), ) self._n_tasks = n_tasks self._max_norm = max_norm @@ -144,7 +144,7 @@ def _stop_criteria(self, gtg: np.ndarray, alpha_t: np.ndarray) -> bool: return bool( (self.alpha_param.value is None) or (np.linalg.norm(gtg @ alpha_t - 1 / (alpha_t + 1e-10)) < 1e-3) - or (np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) < 1e-6) + or (np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) < 1e-6), ) def _solve_optimization(self, gtg: np.ndarray) -> np.ndarray: diff --git a/src/torchjd/aggregation/_trimmed_mean.py b/src/torchjd/aggregation/_trimmed_mean.py index f4e3dfc4..07ed055a 100644 --- a/src/torchjd/aggregation/_trimmed_mean.py +++ b/src/torchjd/aggregation/_trimmed_mean.py @@ -20,7 +20,7 @@ def __init__(self, trim_number: int): if trim_number < 0: raise ValueError( "Parameter `trim_number` should be a non-negative integer. 
Found `trim_number` = " - f"{trim_number}`." + f"{trim_number}`.", ) self.trim_number = trim_number @@ -41,7 +41,7 @@ def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None: if n_rows < min_rows: raise ValueError( f"Parameter `matrix` should be a matrix of at least {min_rows} rows " - f"(i.e. `2 * trim_number + 1`). Found `matrix` of shape `{matrix.shape}`." + f"(i.e. `2 * trim_number + 1`). Found `matrix` of shape `{matrix.shape}`.", ) def __repr__(self) -> str: diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 6b8ec0f6..8234b3a8 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -41,7 +41,7 @@ def __init__( self._solver: SUPPORTED_SOLVER = solver super().__init__( - UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver) + UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver), ) # This prevents considering the computed weights as constant w.r.t. the matrix. diff --git a/src/torchjd/aggregation/_utils/pref_vector.py b/src/torchjd/aggregation/_utils/pref_vector.py index d2d4a3b6..caffabd9 100644 --- a/src/torchjd/aggregation/_utils/pref_vector.py +++ b/src/torchjd/aggregation/_utils/pref_vector.py @@ -8,7 +8,8 @@ def pref_vector_to_weighting( - pref_vector: Tensor | None, default: Weighting[Matrix] + pref_vector: Tensor | None, + default: Weighting[Matrix], ) -> Weighting[Matrix]: """ Returns the weighting associated to a given preference vector, with a fallback to a default @@ -20,7 +21,7 @@ def pref_vector_to_weighting( if pref_vector.ndim != 1: raise ValueError( "Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = " - f"{pref_vector.ndim}`." 
+ f"{pref_vector.ndim}`.", ) return ConstantWeighting(pref_vector) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 4a09a1cc..f8248c0c 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -222,7 +222,7 @@ def _check_module_is_compatible(self, module: nn.Module) -> None: f" are {_MODULES_INCOMPATIBLE_WITH_BATCHED} (and their subclasses). The " f"recommended fix is to replace incompatible layers by something else (e.g. " f"BatchNorm by InstanceNorm). If you really can't and performance is not a " - f"priority, you may also just set `batch_dim=None` when creating the engine." + f"priority, you may also just set `batch_dim=None` when creating the engine.", ) if isinstance(module, _TRACK_RUNNING_STATS_MODULE_TYPES) and module.track_running_stats: raise ValueError( @@ -231,7 +231,7 @@ def _check_module_is_compatible(self, module: nn.Module) -> None: f" to performing in-place operations on tensors and having side-effects during " f"the forward pass. Try setting `track_running_stats` to `False`. If you really" f" can't and performance is not a priority, you may also just set " - f"`batch_dim=None` when creating the engine." + f"`batch_dim=None` when creating the engine.", ) # Currently, the type PSDMatrix is hidden from users, so Tensor is correct. diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index 6929f88e..45cc71ba 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -46,7 +46,11 @@ def __call__( ) -> Matrix: # This makes __call__ vmappable. 
return ComputeModuleJacobians.apply( - self._compute_jacobian, rg_outputs, grad_outputs, args, kwargs + self._compute_jacobian, + rg_outputs, + grad_outputs, + args, + kwargs, ) @abstractmethod @@ -155,7 +159,8 @@ class ComputeModuleJacobians(torch.autograd.Function): @staticmethod def forward( compute_jacobian_fn: Callable[ - [tuple[Tensor, ...], tuple[Tensor, ...], tuple[PyTree, ...], dict[str, PyTree]], Matrix + [tuple[Tensor, ...], tuple[Tensor, ...], tuple[PyTree, ...], dict[str, PyTree]], + Matrix, ], rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...], diff --git a/src/torchjd/autojac/_accumulation.py b/src/torchjd/autojac/_accumulation.py index bafd3d62..4272da6b 100644 --- a/src/torchjd/autojac/_accumulation.py +++ b/src/torchjd/autojac/_accumulation.py @@ -27,7 +27,7 @@ def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> No raise RuntimeError( f"attempting to assign a jacobian of size '{list(jac.shape)}' to a tensor of " f"size '{list(param.shape)}'. Please ensure that the tensor and each row of the" - " jacobian are the same size" + " jacobian are the same size", ) if is_tensor_with_jac(param): @@ -57,7 +57,7 @@ def _check_expects_grad(tensor: Tensor, field_name: str) -> None: if not _expects_grad(tensor): raise ValueError( f"Cannot populate the {field_name} field of a Tensor that does not satisfy:\n" - "`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`." 
+ "`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`.", ) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 2b2e2561..06c9c47d 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -9,7 +9,9 @@ def jac_to_grad( - tensors: Iterable[Tensor], aggregator: Aggregator, retain_jac: bool = False + tensors: Iterable[Tensor], + aggregator: Aggregator, + retain_jac: bool = False, ) -> None: r""" Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result @@ -56,7 +58,7 @@ def jac_to_grad( if not is_tensor_with_jac(t): raise ValueError( "Some `jac` fields were not populated. Did you use `autojac.backward` or " - "`autojac.mtl_backward` before calling `jac_to_grad`?" + "`autojac.mtl_backward` before calling `jac_to_grad`?", ) tensors_.append(t) @@ -84,7 +86,9 @@ def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: def _disunite_gradient( - gradient_vector: Tensor, jacobians: list[Tensor], tensors: list[TensorWithJac] + gradient_vector: Tensor, + jacobians: list[Tensor], + tensors: list[TensorWithJac], ) -> list[Tensor]: gradient_vectors = gradient_vector.split([t.numel() for t in tensors]) gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors, strict=True)] diff --git a/src/torchjd/autojac/_mtl_backward.py b/src/torchjd/autojac/_mtl_backward.py index 831099ed..20fd1b01 100644 --- a/src/torchjd/autojac/_mtl_backward.py +++ b/src/torchjd/autojac/_mtl_backward.py @@ -184,7 +184,8 @@ def _check_losses_are_scalar(losses: Iterable[Tensor]) -> None: def _check_no_overlap( - shared_params: Iterable[Tensor], tasks_params: Sequence[Iterable[Tensor]] + shared_params: Iterable[Tensor], + tasks_params: Sequence[Iterable[Tensor]], ) -> None: task_param_set = {param for task_params in tasks_params for param in task_params} shared_param_set = set(shared_params) diff --git a/src/torchjd/autojac/_transform/_diagonalize.py 
b/src/torchjd/autojac/_transform/_diagonalize.py index a832da63..182306de 100644 --- a/src/torchjd/autojac/_transform/_diagonalize.py +++ b/src/torchjd/autojac/_transform/_diagonalize.py @@ -73,6 +73,6 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: if not set(self.key_order) == input_keys: raise RequirementError( f"The input_keys must match the key_order. Found input_keys {input_keys} and" - f"key_order {self.key_order}." + f"key_order {self.key_order}.", ) return input_keys diff --git a/src/torchjd/autojac/_transform/_differentiate.py b/src/torchjd/autojac/_transform/_differentiate.py index 3cec097d..18117f3c 100644 --- a/src/torchjd/autojac/_transform/_differentiate.py +++ b/src/torchjd/autojac/_transform/_differentiate.py @@ -60,7 +60,7 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: if not outputs == input_keys: raise RequirementError( f"The input_keys must match the expected outputs. Found input_keys {input_keys} and" - f"outputs {outputs}." + f"outputs {outputs}.", ) return set(self.inputs) diff --git a/src/torchjd/autojac/_transform/_init.py b/src/torchjd/autojac/_transform/_init.py index 11bdcb68..26042979 100644 --- a/src/torchjd/autojac/_transform/_init.py +++ b/src/torchjd/autojac/_transform/_init.py @@ -22,6 +22,6 @@ def __call__(self, input: TensorDict, /) -> TensorDict: def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: if not input_keys == set(): raise RequirementError( - f"The input_keys should be the empty set. Found input_keys {input_keys}." + f"The input_keys should be the empty set. Found input_keys {input_keys}.", ) return set(self.values) diff --git a/src/torchjd/autojac/_transform/_jac.py b/src/torchjd/autojac/_transform/_jac.py index b5090f49..1f33d1d9 100644 --- a/src/torchjd/autojac/_transform/_jac.py +++ b/src/torchjd/autojac/_transform/_jac.py @@ -64,7 +64,7 @@ def _differentiate(self, jac_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...] 
[ torch.empty((0, *input.shape), device=input.device, dtype=input.dtype) for input in self.inputs - ] + ], ) # If the jac_outputs are correct, this value should be the same for all jac_outputs. @@ -101,7 +101,8 @@ def _differentiate(self, jac_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...] def _get_jacs_chunk( - jac_outputs_chunk: list[Tensor], get_vjp: Callable[[Sequence[Tensor]], tuple[Tensor, ...]] + jac_outputs_chunk: list[Tensor], + get_vjp: Callable[[Sequence[Tensor]], tuple[Tensor, ...]], ) -> tuple[Tensor, ...]: """ Computes the jacobian matrix chunk corresponding to the provided get_vjp function, either by diff --git a/src/torchjd/autojac/_transform/_materialize.py b/src/torchjd/autojac/_transform/_materialize.py index 89100168..038565b2 100644 --- a/src/torchjd/autojac/_transform/_materialize.py +++ b/src/torchjd/autojac/_transform/_materialize.py @@ -5,7 +5,8 @@ def materialize( - optional_tensors: Sequence[Tensor | None], inputs: Sequence[Tensor] + optional_tensors: Sequence[Tensor | None], + inputs: Sequence[Tensor], ) -> tuple[Tensor, ...]: """ Transforms a sequence of optional tensors by changing each None by a tensor of zeros of the same diff --git a/src/torchjd/autojac/_transform/_select.py b/src/torchjd/autojac/_transform/_select.py index 76712ab9..1575ecf6 100644 --- a/src/torchjd/autojac/_transform/_select.py +++ b/src/torchjd/autojac/_transform/_select.py @@ -24,6 +24,6 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: if not keys.issubset(input_keys): raise RequirementError( f"The input_keys should be a super set of the keys to select. Found input_keys " - f"{input_keys} and keys to select {keys}." 
+ f"{input_keys} and keys to select {keys}.", ) return keys diff --git a/src/torchjd/autojac/_utils.py b/src/torchjd/autojac/_utils.py index 1a460ce5..e286cf20 100644 --- a/src/torchjd/autojac/_utils.py +++ b/src/torchjd/autojac/_utils.py @@ -12,12 +12,13 @@ def check_optional_positive_chunk_size(parallel_chunk_size: int | None) -> None: if not (parallel_chunk_size is None or parallel_chunk_size > 0): raise ValueError( "`parallel_chunk_size` should be `None` or greater than `0`. (got " - f"{parallel_chunk_size})" + f"{parallel_chunk_size})", ) def as_checked_ordered_set( - tensors: Sequence[Tensor] | Tensor, variable_name: str + tensors: Sequence[Tensor] | Tensor, + variable_name: str, ) -> OrderedSet[Tensor]: if isinstance(tensors, Tensor): tensors = [tensors] @@ -61,7 +62,8 @@ def get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> O def _get_descendant_accumulate_grads( - roots: OrderedSet[Node], excluded_nodes: set[Node] + roots: OrderedSet[Node], + excluded_nodes: set[Node], ) -> OrderedSet[Node]: """ Gets the AccumulateGrad descendants of the specified nodes. 
diff --git a/tests/plots/_utils.py b/tests/plots/_utils.py index 979ec360..0705893b 100644 --- a/tests/plots/_utils.py +++ b/tests/plots/_utils.py @@ -28,7 +28,11 @@ def make_fig(self) -> Figure: for i in range(len(results)): scatter = make_vector_scatter( - results[i], "black", str(self.aggregators[i]), showlegend=True, dash=True + results[i], + "black", + str(self.aggregators[i]), + showlegend=True, + dash=True, ) fig.add_trace(scatter) @@ -87,7 +91,11 @@ def make_vector_scatter( def make_cone_scatter( - start_angle: float, opening: float, label: str, scale: float = 100.0, printable: bool = False + start_angle: float, + opening: float, + label: str, + scale: float = 100.0, + printable: bool = False, ) -> Scatter: if opening < -1e-8: cone_outline = np.zeros([0, 2]) @@ -105,7 +113,7 @@ def make_cone_scatter( start_vec, # Tip of the first vector end_vec, # Tip of the second vector [0, 0], # Back to the origin to close the cone - ] + ], ) else: middle_point = angle_to_coord(middle_angle, scale) @@ -117,7 +125,7 @@ def make_cone_scatter( middle_point, # Tip of the vector in-between end_vec, # Tip of the second vector [0, 0], # Back to the origin to close the cone - ] + ], ) if printable: @@ -167,7 +175,10 @@ def make_polygon_scatter(points: list[torch.Tensor]) -> Scatter: def make_right_angle( - vector: torch.Tensor, size: float, positive_para: bool = True, positive_orth: bool = True + vector: torch.Tensor, + size: float, + positive_para: bool = True, + positive_orth: bool = True, ) -> list[torch.Tensor]: vec_para = vector / torch.linalg.norm(vector) * size vec_orth = torch.tensor([-vec_para[1], vec_para[0]]) diff --git a/tests/plots/interactive_plotter.py b/tests/plots/interactive_plotter.py index d78b7fde..0d26a1e6 100644 --- a/tests/plots/interactive_plotter.py +++ b/tests/plots/interactive_plotter.py @@ -39,7 +39,7 @@ def main() -> None: [0.0, 1.0], [1.0, -1.0], [1.0, 0.0], - ] + ], ) aggregators = [ @@ -148,7 +148,8 @@ def update_aggregators(value: list[str]) -> 
Figure: def make_gradient_div( - i: int, initial_gradient: torch.Tensor + i: int, + initial_gradient: torch.Tensor, ) -> tuple[html.Div, dcc.Input, dcc.Input]: x = initial_gradient[0].item() y = initial_gradient[1].item() diff --git a/tests/profiling/run_profiler.py b/tests/profiling/run_profiler.py index b143a55b..7707cb01 100644 --- a/tests/profiling/run_profiler.py +++ b/tests/profiling/run_profiler.py @@ -90,7 +90,10 @@ def _get_profiler_activities() -> list[ProfilerActivity]: def _save_and_print_trace( - prof: profile, method_name: str, factory: ModuleFactory, batch_size: int + prof: profile, + method_name: str, + factory: ModuleFactory, + batch_size: int, ) -> None: filename = f"{factory}-bs{batch_size}-{DEVICE.type}.json" output_dir = TRACES_DIR / method_name diff --git a/tests/profiling/speed_grad_vs_jac_vs_gram.py b/tests/profiling/speed_grad_vs_jac_vs_gram.py index 13b57b62..d68d67aa 100644 --- a/tests/profiling/speed_grad_vs_jac_vs_gram.py +++ b/tests/profiling/speed_grad_vs_jac_vs_gram.py @@ -105,7 +105,11 @@ def post_fn(): print_times("autograd", autograd_times) autograd_gramian_times = time_call( - fn_autograd_gramian, init_fn_autograd_gramian, pre_fn, post_fn, n_runs + fn_autograd_gramian, + init_fn_autograd_gramian, + pre_fn, + post_fn, + n_runs, ) print_times("autograd gramian", autograd_gramian_times) diff --git a/tests/settings.py b/tests/settings.py index b7fe2345..008080b4 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -13,7 +13,7 @@ if _device_str not in _POSSIBLE_TEST_DEVICES: raise ValueError( f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}.\n" - f"Possible values: {_POSSIBLE_TEST_DEVICES}." + f"Possible values: {_POSSIBLE_TEST_DEVICES}.", ) if _device_str == "cuda:0" and not torch.cuda.is_available(): @@ -29,7 +29,7 @@ if _dtype_str not in _POSSIBLE_TEST_DTYPES: raise ValueError( f"Invalid value of environment variable PYTEST_TORCH_DTYPE: {_dtype_str}.\n" - f"Possible values: {_POSSIBLE_TEST_DTYPES}." 
+ f"Possible values: {_POSSIBLE_TEST_DTYPES}.", ) DTYPE = getattr(torch, _dtype_str) # "float32" => torch.float32 diff --git a/tests/unit/aggregation/_asserts.py b/tests/unit/aggregation/_asserts.py index 15b69874..8c119674 100644 --- a/tests/unit/aggregation/_asserts.py +++ b/tests/unit/aggregation/_asserts.py @@ -22,7 +22,10 @@ def assert_expected_structure(aggregator: Aggregator, matrix: Tensor) -> None: def assert_non_conflicting( - aggregator: Aggregator, matrix: Tensor, atol: float = 4e-04, rtol: float = 4e-04 + aggregator: Aggregator, + matrix: Tensor, + atol: float = 4e-04, + rtol: float = 4e-04, ) -> None: """Tests empirically that a given `Aggregator` satisfies the non-conflicting property.""" @@ -81,7 +84,9 @@ def assert_linear_under_scaling( def assert_strongly_stationary( - aggregator: Aggregator, matrix: Tensor, threshold: float = 5e-03 + aggregator: Aggregator, + matrix: Tensor, + threshold: float = 5e-03, ) -> None: """ Tests empirically that a given `Aggregator` is strongly stationary. 
diff --git a/tests/unit/aggregation/test_krum.py b/tests/unit/aggregation/test_krum.py index ff75e5cf..48fa4019 100644 --- a/tests/unit/aggregation/test_krum.py +++ b/tests/unit/aggregation/test_krum.py @@ -62,7 +62,10 @@ def test_n_selected_check(n_selected: int, expectation: ExceptionContext): ], ) def test_matrix_shape_check( - n_byzantine: int, n_selected: int, n_rows: int, expectation: ExceptionContext + n_byzantine: int, + n_selected: int, + n_rows: int, + expectation: ExceptionContext, ): aggregator = Krum(n_byzantine=n_byzantine, n_selected=n_selected) matrix = ones_([n_rows, 5]) diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index d55d87c0..57a9120c 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -52,7 +52,10 @@ def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]): pc_grad_weighting = PCGradWeighting() upgrad_sum_weighting = UPGradWeighting( - ones_((2,)), norm_eps=0.0, reg_eps=0.0, solver="quadprog" + ones_((2,)), + norm_eps=0.0, + reg_eps=0.0, + solver="quadprog", ) result = pc_grad_weighting(gramian) diff --git a/tests/unit/aggregation/test_values.py b/tests/unit/aggregation/test_values.py index 5fed3869..67f24862 100644 --- a/tests/unit/aggregation/test_values.py +++ b/tests/unit/aggregation/test_values.py @@ -39,7 +39,7 @@ [75.0, -666.0, 23], # adversarial row [1.0, 2.0, 3.0], [2.0, 0.0, 1.0], - ] + ], ) J_TrimmedMean = tensor( [ @@ -47,7 +47,7 @@ [1.0, -1e11], [-1e10, 1e10], [2.0, 2.0], - ] + ], ) AGGREGATOR_PARAMETRIZATIONS = [ @@ -101,7 +101,7 @@ J_base, tensor([0.0542, 0.7061, 0.7061]), marks=mark.filterwarnings("ignore::UserWarning"), - ) + ), ) except ImportError: diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 64f6b926..00ae67b1 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -144,7 +144,9 @@ def _assert_gramian_is_equivalent_to_autograd( - factory: 
ModuleFactory, batch_size: int, batch_dim: int | None + factory: ModuleFactory, + batch_size: int, + batch_dim: int | None, ): model_autograd, model_autogram = factory(), factory() engine = Engine(model_autogram, batch_dim=batch_dim) @@ -208,7 +210,9 @@ def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int @mark.parametrize("batch_size", [1, 3, 32]) @mark.parametrize("batch_dim", [param(0, marks=mark.xfail), None]) def test_compute_gramian_with_weird_modules( - factory: ModuleFactory, batch_size: int, batch_dim: int | None + factory: ModuleFactory, + batch_size: int, + batch_dim: int | None, ): """ Tests that compute_gramian works even with some problematic modules when batch_dim is None. It @@ -230,7 +234,9 @@ def test_compute_gramian_with_weird_modules( @mark.parametrize("batch_size", [1, 3, 32]) @mark.parametrize("batch_dim", [0, None]) def test_compute_gramian_unsupported_architectures( - factory: ModuleFactory, batch_size: int, batch_dim: int | None + factory: ModuleFactory, + batch_size: int, + batch_dim: int | None, ): """ Tests compute_gramian on some architectures that are known to be unsupported. 
It is expected to @@ -353,7 +359,10 @@ def test_iwrm_steps_with_autogram(factory: ModuleFactory, batch_size: int, batch @mark.parametrize("use_engine", [False, True]) @mark.parametrize("batch_dim", [0, None]) def test_autograd_while_modules_are_hooked( - factory: ModuleFactory, batch_size: int, use_engine: bool, batch_dim: int | None + factory: ModuleFactory, + batch_size: int, + use_engine: bool, + batch_dim: int | None, ): """ Tests that the hooks added when constructing the engine do not interfere with a simple autograd diff --git a/tests/unit/autojac/_transform/test_jac.py b/tests/unit/autojac/_transform/test_jac.py index c00e43d2..e1efecf5 100644 --- a/tests/unit/autojac/_transform/test_jac.py +++ b/tests/unit/autojac/_transform/test_jac.py @@ -106,10 +106,16 @@ def test_retain_graph(): input = {y: eye_(2)} jac_retain_graph = Jac( - outputs=OrderedSet([y]), inputs=OrderedSet([a1, a2]), chunk_size=None, retain_graph=True + outputs=OrderedSet([y]), + inputs=OrderedSet([a1, a2]), + chunk_size=None, + retain_graph=True, ) jac_discard_graph = Jac( - outputs=OrderedSet([y]), inputs=OrderedSet([a1, a2]), chunk_size=None, retain_graph=False + outputs=OrderedSet([y]), + inputs=OrderedSet([a1, a2]), + chunk_size=None, + retain_graph=False, ) jac_retain_graph(input) @@ -140,10 +146,16 @@ def test_two_levels(): input = {z: eye_(2)} outer_jac = Jac( - outputs=OrderedSet([y]), inputs=OrderedSet([a1, a2]), chunk_size=None, retain_graph=True + outputs=OrderedSet([y]), + inputs=OrderedSet([a1, a2]), + chunk_size=None, + retain_graph=True, ) inner_jac = Jac( - outputs=OrderedSet([z]), inputs=OrderedSet([y]), chunk_size=None, retain_graph=True + outputs=OrderedSet([z]), + inputs=OrderedSet([y]), + chunk_size=None, + retain_graph=True, ) composed_jac = outer_jac << inner_jac jac = Jac(outputs=OrderedSet([z]), inputs=OrderedSet([a1, a2]), chunk_size=None) @@ -236,7 +248,10 @@ def test_composition_of_jacs_is_jac(): input = {z1: tensor_([1.0, 0.0]), z2: tensor_([0.0, 1.0])} 
outer_jac = Jac( - outputs=OrderedSet([y1, y2]), inputs=OrderedSet([a]), chunk_size=None, retain_graph=True + outputs=OrderedSet([y1, y2]), + inputs=OrderedSet([a]), + chunk_size=None, + retain_graph=True, ) inner_jac = Jac( outputs=OrderedSet([z1, z2]), @@ -291,7 +306,10 @@ def test_create_graph(): input = {y: eye_(2)} jac = Jac( - outputs=OrderedSet([y]), inputs=OrderedSet([a1, a2]), chunk_size=None, create_graph=True + outputs=OrderedSet([y]), + inputs=OrderedSet([a1, a2]), + chunk_size=None, + create_graph=True, ) jacobians = jac(input) diff --git a/tests/unit/linalg/test_gramian.py b/tests/unit/linalg/test_gramian.py index 6fc22512..53373822 100644 --- a/tests/unit/linalg/test_gramian.py +++ b/tests/unit/linalg/test_gramian.py @@ -57,7 +57,7 @@ def test_compute_gramian_matrix_input_0(): [ [[[1.0, 3.0], [2.0, 4.0]], [[2.0, 6.0], [4.0, 8.0]]], [[[3.0, 9.0], [6.0, 12.0]], [[4.0, 12.0], [8.0, 16.0]]], - ] + ], ) assert_close(gramian, expected) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index fb353c75..5e8586b2 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -1148,7 +1148,7 @@ class InstanceNormResNet18(ShapedModule): def __init__(self): super().__init__() self.resnet18 = torchvision.models.resnet18( - norm_layer=partial(nn.InstanceNorm2d, track_running_stats=False, affine=True) + norm_layer=partial(nn.InstanceNorm2d, track_running_stats=False, affine=True), ) def forward(self, input: Tensor) -> Tensor: @@ -1164,7 +1164,7 @@ class GroupNormMobileNetV3Small(ShapedModule): def __init__(self): super().__init__() self.mobile_net = torchvision.models.mobilenet_v3_small( - norm_layer=partial(nn.GroupNorm, 2, affine=True) + norm_layer=partial(nn.GroupNorm, 2, affine=True), ) def forward(self, input: Tensor) -> Tensor: @@ -1194,7 +1194,7 @@ class InstanceNormMobileNetV2(ShapedModule): def __init__(self): super().__init__() self.mobilenet = torchvision.models.mobilenet_v2( - norm_layer=partial(nn.InstanceNorm2d, 
track_running_stats=False, affine=True) + norm_layer=partial(nn.InstanceNorm2d, track_running_stats=False, affine=True), ) def forward(self, input: Tensor) -> Tensor: diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index b350a894..57f9b90a 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -115,7 +115,9 @@ def reshape_raw_losses(raw_losses: Tensor) -> Tensor: def compute_gramian_with_autograd( - output: Tensor, params: list[nn.Parameter], retain_graph: bool = False + output: Tensor, + params: list[nn.Parameter], + retain_graph: bool = False, ) -> PSDTensor: """ Computes the Gramian of the Jacobian of the outputs with respect to the params using vmapped From f06327740b78cadde8cb1318369daeb772f8f5ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 13 Feb 2026 11:38:09 +0100 Subject: [PATCH 11/11] Remove RUF012 --- pyproject.toml | 1 + tests/utils/architectures.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8d19daf5..251f0623 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -153,6 +153,7 @@ ignore = [ "E402", # module-import-not-at-top-of-file "RUF022", # __all__ not sorted "RUF010", # Use explicit conversion flag + "RUF012", # Mutable default value for class attribute (a bit tedious to fix) "RET504", # Unnecessary assignment return statement ] diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 5e8586b2..cf3261f8 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -1,5 +1,5 @@ from functools import partial -from typing import ClassVar, Generic, TypeVar +from typing import Generic, TypeVar import torch import torchvision @@ -130,7 +130,7 @@ class SingleInputPyTreeOutput(ShapedModule): """Module taking a single input and returning a complex PyTree of tensors as output.""" INPUT_SHAPES = (50,) - OUTPUT_SHAPES: ClassVar = { + OUTPUT_SHAPES = { 
"first": ((50,), [(60,), (70,)]), "second": (80,), "third": ([((90,),)],), @@ -155,7 +155,7 @@ def forward(self, input: Tensor) -> PyTree: class PyTreeInputSingleOutput(ShapedModule): """Module taking a complex PyTree of tensors as input and returning a single output.""" - INPUT_SHAPES: ClassVar = { + INPUT_SHAPES = { "one": [((10,), [(20,), (30,)]), (12,)], "two": (14,), } @@ -192,12 +192,12 @@ class PyTreeInputPyTreeOutput(ShapedModule): output. """ - INPUT_SHAPES: ClassVar = { + INPUT_SHAPES = { "one": [((10,), [(20,), (30,)]), (12,)], "two": (14,), } - OUTPUT_SHAPES: ClassVar = { + OUTPUT_SHAPES = { "first": ((50,), [(60,), (70,)]), "second": (80,), "third": ([((90,),)],),