diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 32b7e127..58720967 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -109,8 +109,8 @@ jobs: # This reduces false positives due to rate limits GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - typing: - name: Typing correctness + code-quality: + name: Code quality (ty and ruff) runs-on: ubuntu-latest steps: - name: Checkout repository @@ -128,32 +128,5 @@ jobs: - name: Run ty run: uv run ty check --output-format=github - check-todos: - name: Absence of TODOs - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v6 - - - name: Scan for TODO strings - run: | - echo "Scanning codebase for TODOs..." - - git grep -nE "TODO" -- . ':(exclude).github/workflows/*' > todos_found.txt || true - - if [ -s todos_found.txt ]; then - echo "❌ ERROR: Found TODOs in the following files:" - echo "-------------------------------------------" - - while IFS=: read -r file line content; do - echo "::error file=$file,line=$line::TODO found at $file:$line - must be resolved before merge:%0A$content" - done < todos_found.txt - - echo "-------------------------------------------" - echo "Please resolve these TODOs or track them in an issue before merging." - - exit 1 - else - echo "✅ No TODOs found. Codebase is clean!" - exit 0 - fi + - name: Run ruff + run: uv run ruff check --output-format=github diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a01494aa..6b8b4224 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,32 +9,12 @@ repos: - id: check-docstring-first # Check a common error of defining a docstring after code. - id: check-merge-conflict # Check for files that contain merge conflict strings. -- repo: https://github.com/PyCQA/flake8 - rev: 7.3.0 +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.14 hooks: - - id: flake8 # Check style and syntax. Does not modify code, issues have to be solved manually. 
- args: [ - '--ignore=E501,E203,W503,E402', # Ignore line length problems, space after colon problems, line break occurring before a binary operator problems, module level import not at top of file problems. - ] - -- repo: https://github.com/pycqa/isort - rev: 7.0.0 - hooks: - - id: isort # Sort imports. - args: [ - --multi-line=3, - --line-length=100, - --trailing-comma, - --force-grid-wrap=0, - --use-parentheses, - --ensure-newline-before-comments, - ] - -- repo: https://github.com/psf/black-pre-commit-mirror - rev: 25.12.0 - hooks: - - id: black # Format code. - args: [--line-length=100] + - id: ruff + args: [ --fix, --extend-ignore, FIX ] + - id: ruff-format ci: autoupdate_commit_msg: 'chore: Update pre-commit hooks' diff --git a/pyproject.toml b/pyproject.toml index 7f879851..06071519 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ Changelog = "https://github.com/TorchJD/torchjd/blob/main/CHANGELOG.md" [dependency-groups] check = [ + "ruff>=0.14.14", "ty>=0.0.14", "pre-commit>=2.9.2", # isort doesn't work before 2.9.2 ] @@ -122,6 +123,32 @@ exclude_lines = [ "@overload", ] +[tool.ruff] +line-length = 100 +target-version = "py310" + +[tool.ruff.lint] +select = [ + "E", # pycodestyle Error + "F", # Pyflakes + "W", # pycodestyle Warning + "I", # isort + "UP", # pyupgrade + "B", # flake8-bugbear + "FIX", # flake8-fixme +] + +ignore = [ + "E501", # line-too-long (handled by the formatter) + "E402", # module-import-not-at-top-of-file +] + +[tool.ruff.lint.isort] +combine-as-imports = true + +[tool.ruff.format] +quote-style = "double" + [tool.ty.src] include = ["src", "tests"] exclude = ["src/torchjd/aggregation/_nash_mtl.py"] diff --git a/src/torchjd/__init__.py b/src/torchjd/__init__.py index a74b6c78..4253561a 100644 --- a/src/torchjd/__init__.py +++ b/src/torchjd/__init__.py @@ -1,8 +1,7 @@ from collections.abc import Callable from warnings import warn as _warn -from .autojac import backward as _backward -from .autojac import mtl_backward as 
_mtl_backward +from .autojac import backward as _backward, mtl_backward as _mtl_backward _deprecated_items: dict[str, tuple[str, Callable]] = { "backward": ("autojac", _backward), diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py index 6e0620ca..b6ea1327 100644 --- a/src/torchjd/aggregation/_graddrop.py +++ b/src/torchjd/aggregation/_graddrop.py @@ -74,5 +74,5 @@ def __str__(self) -> str: if self.leak is None: leak_str = "" else: - leak_str = f"([{', '.join(['{:.2f}'.format(l_).rstrip('0') for l_ in self.leak])}])" + leak_str = f"([{', '.join([f'{l_:.2f}'.rstrip('0') for l_ in self.leak])}])" return f"GradDrop{leak_str}" diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index a2608404..8f753c2a 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -53,7 +53,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: dtype = gramian.dtype alpha = torch.ones(gramian.shape[0], device=device, dtype=dtype) / gramian.shape[0] - for i in range(self.max_iters): + for _ in range(self.max_iters): t = torch.argmin(gramian @ alpha) e_t = torch.zeros(gramian.shape[0], device=device, dtype=dtype) e_t[t] = 1.0 diff --git a/src/torchjd/aggregation/_utils/str.py b/src/torchjd/aggregation/_utils/str.py index 8fda8b26..82a04540 100644 --- a/src/torchjd/aggregation/_utils/str.py +++ b/src/torchjd/aggregation/_utils/str.py @@ -7,5 +7,5 @@ def vector_to_str(vector: Tensor) -> str: `1.23, 1., ...`. 
""" - weights_str = ", ".join(["{:.2f}".format(value).rstrip("0") for value in vector]) + weights_str = ", ".join([f"{value:.2f}".rstrip("0") for value in vector]) return weights_str diff --git a/src/torchjd/autogram/_gramian_accumulator.py b/src/torchjd/autogram/_gramian_accumulator.py index 91ace4e6..e9fe81f8 100644 --- a/src/torchjd/autogram/_gramian_accumulator.py +++ b/src/torchjd/autogram/_gramian_accumulator.py @@ -1,5 +1,3 @@ -from typing import Optional - from torchjd._linalg import PSDMatrix @@ -13,7 +11,7 @@ class GramianAccumulator: """ def __init__(self) -> None: - self._gramian: Optional[PSDMatrix] = None + self._gramian: PSDMatrix | None = None def reset(self) -> None: self._gramian = None @@ -25,7 +23,7 @@ def accumulate_gramian(self, gramian: PSDMatrix) -> None: self._gramian = gramian @property - def gramian(self) -> Optional[PSDMatrix]: + def gramian(self) -> PSDMatrix | None: """ Get the Gramian matrix accumulated so far. diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index f5be882c..8c1546e0 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, cast +from typing import cast from torch import Tensor from torch.utils._pytree import PyTree @@ -16,12 +16,14 @@ def __call__( grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree], - ) -> Optional[PSDMatrix]: + ) -> PSDMatrix | None: """Compute what we can for a module and optionally return the gramian if it's ready.""" + @abstractmethod def track_forward_call(self) -> None: """Track that the module's forward was called. Necessary in some implementations.""" + @abstractmethod def reset(self) -> None: """Reset state if any. 
Necessary in some implementations.""" @@ -40,7 +42,7 @@ class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer): def __init__(self, jacobian_computer: JacobianComputer): super().__init__(jacobian_computer) self.remaining_counter = 0 - self.summed_jacobian: Optional[Matrix] = None + self.summed_jacobian: Matrix | None = None def reset(self) -> None: self.remaining_counter = 0 @@ -55,7 +57,7 @@ def __call__( grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree], - ) -> Optional[PSDMatrix]: + ) -> PSDMatrix | None: """Compute what we can for a module and optionally return the gramian if it's ready.""" jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index fe4c22a5..ef48b784 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -141,7 +141,7 @@ def __call__( *rg_outputs, ) - for idx, output in zip(rg_output_indices, autograd_fn_rg_outputs): + for idx, output in zip(rg_output_indices, autograd_fn_rg_outputs, strict=True): flat_outputs[idx] = output return tree_unflatten(flat_outputs, output_spec) diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py index 41b5f108..1c809d2d 100644 --- a/src/torchjd/autojac/_jac.py +++ b/src/torchjd/autojac/_jac.py @@ -1,5 +1,4 @@ -from collections.abc import Sequence -from typing import Iterable +from collections.abc import Iterable, Sequence from torch import Tensor diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 61427467..352e2655 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -87,7 +87,7 @@ def _disunite_gradient( gradient_vector: Tensor, jacobians: list[Tensor], tensors: list[TensorWithJac] ) -> list[Tensor]: gradient_vectors = gradient_vector.split([t.numel() for t in tensors]) - 
gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors)] + gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors, strict=True)] return gradients diff --git a/src/torchjd/autojac/_mtl_backward.py b/src/torchjd/autojac/_mtl_backward.py index 5755c9ee..831099ed 100644 --- a/src/torchjd/autojac/_mtl_backward.py +++ b/src/torchjd/autojac/_mtl_backward.py @@ -132,7 +132,7 @@ def _create_transform( OrderedSet([loss]), retain_graph, ) - for task_params, loss in zip(tasks_params, losses) + for task_params, loss in zip(tasks_params, losses, strict=True) ] # Transform that stacks the gradients of the losses w.r.t. the shared representations into a diff --git a/src/torchjd/autojac/_transform/_diagonalize.py b/src/torchjd/autojac/_transform/_diagonalize.py index cc7791ea..88e5525e 100644 --- a/src/torchjd/autojac/_transform/_diagonalize.py +++ b/src/torchjd/autojac/_transform/_diagonalize.py @@ -65,7 +65,7 @@ def __call__(self, tensors: TensorDict, /) -> TensorDict: diagonal_matrix = torch.cat(flattened_considered_values).diag() diagonalized_tensors = { key: diagonal_matrix[:, begin:end].reshape((-1,) + key.shape) - for (begin, end), key in zip(self.indices, self.key_order) + for (begin, end), key in zip(self.indices, self.key_order, strict=True) } return diagonalized_tensors diff --git a/src/torchjd/autojac/_transform/_differentiate.py b/src/torchjd/autojac/_transform/_differentiate.py index 260d1dab..3cec097d 100644 --- a/src/torchjd/autojac/_transform/_differentiate.py +++ b/src/torchjd/autojac/_transform/_differentiate.py @@ -41,7 +41,7 @@ def __call__(self, tensors: TensorDict, /) -> TensorDict: tensor_outputs = [tensors[output] for output in self.outputs] differentiated_tuple = self._differentiate(tensor_outputs) - new_differentiations = dict(zip(self.inputs, differentiated_tuple)) + new_differentiations = dict(zip(self.inputs, differentiated_tuple, strict=True)) return type(tensors)(new_differentiations) @abstractmethod diff --git 
a/src/torchjd/autojac/_transform/_materialize.py b/src/torchjd/autojac/_transform/_materialize.py index 98f60e99..89100168 100644 --- a/src/torchjd/autojac/_transform/_materialize.py +++ b/src/torchjd/autojac/_transform/_materialize.py @@ -16,7 +16,7 @@ def materialize( """ tensors = [] - for optional_tensor, input in zip(optional_tensors, inputs): + for optional_tensor, input in zip(optional_tensors, inputs, strict=True): if optional_tensor is None: tensors.append(torch.zeros_like(input)) else: diff --git a/tests/conftest.py b/tests/conftest.py index 06c3d98b..5288aa1f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -53,11 +53,11 @@ def pytest_make_parametrize_id(config, val, argname): MAX_SIZE = 40 optional_string = None # Returning None means using pytest's way of making the string - if isinstance(val, (Aggregator, ModuleFactory, Weighting)): + if isinstance(val, Aggregator | ModuleFactory | Weighting): optional_string = str(val) elif isinstance(val, Tensor): optional_string = "T" + str(list(val.shape)) # T to indicate that it's a tensor - elif isinstance(val, (tuple, list, set)) and len(val) < 20: + elif isinstance(val, tuple | list | set) and len(val) < 20: optional_string = str(val) elif isinstance(val, RaisesExc): optional_string = " or ".join([f"{exc.__name__}" for exc in val.expected_exceptions]) diff --git a/tests/doc/test_autogram.py b/tests/doc/test_autogram.py index 4445bc67..43651824 100644 --- a/tests/doc/test_autogram.py +++ b/tests/doc/test_autogram.py @@ -22,7 +22,7 @@ def test_engine(): # Create the engine before the backward pass, and only once. 
engine = Engine(model, batch_dim=0) - for input, target in zip(inputs, targets): + for input, target in zip(inputs, targets, strict=True): output = model(input).squeeze(dim=1) # shape: [16] losses = criterion(output, target) # shape: [16] diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 327cb333..cdf26812 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -35,7 +35,7 @@ def test_amp(): task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task - for input, target1, target2 in zip(inputs, task1_targets, task2_targets): + for input, target1, target2 in zip(inputs, task1_targets, task2_targets, strict=False): with torch.autocast(device_type="cpu", dtype=torch.float16): features = shared_module(input) output1 = task1_module(features) @@ -105,7 +105,7 @@ def test_iwmtl(): task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task task2_targets = torch.randn(8, 16) # 8 batches of 16 targets for the second task - for input, target1, target2 in zip(inputs, task1_targets, task2_targets): + for input, target1, target2 in zip(inputs, task1_targets, task2_targets, strict=False): features = shared_module(input) # shape: [16, 3] out1 = task1_module(features).squeeze(1) # shape: [16] out2 = task2_module(features).squeeze(1) # shape: [16] @@ -140,7 +140,7 @@ def test_autograd(): params = model.parameters() optimizer = SGD(params, lr=0.1) - for x, y in zip(X, Y): + for x, y in zip(X, Y, strict=False): y_hat = model(x).squeeze(dim=1) # shape: [16] loss = loss_fn(y_hat, y) # shape: [] (scalar) loss.backward() @@ -165,7 +165,7 @@ def test_autojac(): optimizer = SGD(params, lr=0.1) aggregator = UPGrad() - for x, y in zip(X, Y): + for x, y in zip(X, Y, strict=False): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] backward(losses) @@ -192,7 +192,7 @@ def test_autogram(): weighting = 
UPGradWeighting() engine = Engine(model, batch_dim=0) - for x, y in zip(X, Y): + for x, y in zip(X, Y, strict=False): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] gramian = engine.compute_gramian(losses) # shape: [16, 16] @@ -318,7 +318,7 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch. task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task - for input, target1, target2 in zip(inputs, task1_targets, task2_targets): + for input, target1, target2 in zip(inputs, task1_targets, task2_targets, strict=False): features = shared_module(input) output1 = task1_module(features) output2 = task2_module(features) @@ -356,7 +356,7 @@ def test_mtl(): task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task - for input, target1, target2 in zip(inputs, task1_targets, task2_targets): + for input, target1, target2 in zip(inputs, task1_targets, task2_targets, strict=False): features = shared_module(input) output1 = task1_module(features) output2 = task2_module(features) @@ -392,7 +392,7 @@ def test_partial_jd(): params = model.parameters() optimizer = SGD(params, lr=0.1) - for x, y in zip(X, Y): + for x, y in zip(X, Y, strict=False): y_hat = model(x).squeeze(dim=1) # shape: [16] losses = loss_fn(y_hat, y) # shape: [16] gramian = engine.compute_gramian(losses) @@ -417,7 +417,7 @@ def test_rnn(): inputs = torch.randn(8, 5, 3, 10) # 8 batches of 3 sequences of length 5 and of dim 10. targets = torch.randn(8, 5, 3, 20) # 8 batches of 3 sequences of length 5 and of dim 20. - for input, target in zip(inputs, targets): + for input, target in zip(inputs, targets, strict=False): output, _ = rnn(input) # output is of shape [5, 3, 20]. 
losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element. diff --git a/tests/plots/interactive_plotter.py b/tests/plots/interactive_plotter.py index 25f034e5..d78b7fde 100644 --- a/tests/plots/interactive_plotter.py +++ b/tests/plots/interactive_plotter.py @@ -7,8 +7,8 @@ import torch from dash import Dash, Input, Output, callback, dcc, html from plotly.graph_objs import Figure -from plots._utils import Plotter, angle_to_coord, coord_to_angle +from plots._utils import Plotter, angle_to_coord, coord_to_angle from torchjd.aggregation import ( IMTLG, MGDA, diff --git a/tests/profiling/plot_memory_timeline.py b/tests/profiling/plot_memory_timeline.py index f9197101..f7cbeec4 100644 --- a/tests/profiling/plot_memory_timeline.py +++ b/tests/profiling/plot_memory_timeline.py @@ -28,7 +28,7 @@ def from_event(event: dict): def extract_memory_timeline(path: Path) -> np.ndarray: - with open(path, "r") as f: + with open(path) as f: data = json.load(f) events = data["traceEvents"] @@ -53,7 +53,7 @@ def plot_memory_timelines(experiment: str, folders: list[str]) -> None: timelines.append(extract_memory_timeline(path)) fig, ax = plt.subplots(figsize=(12, 6)) - for folder, timeline in zip(folders, timelines): + for folder, timeline in zip(folders, timelines, strict=True): time = (timeline[:, 0] - timeline[0, 0]) // 1000 # Make time start at 0 and convert to ms. 
memory = timeline[:, 1] ax.plot(time, memory, label=folder, linewidth=1.5) diff --git a/tests/profiling/run_profiler.py b/tests/profiling/run_profiler.py index ebab7849..b143a55b 100644 --- a/tests/profiling/run_profiler.py +++ b/tests/profiling/run_profiler.py @@ -1,5 +1,5 @@ import gc -from typing import Callable +from collections.abc import Callable import torch from settings import DEVICE diff --git a/tests/profiling/speed_grad_vs_jac_vs_gram.py b/tests/profiling/speed_grad_vs_jac_vs_gram.py index 16be875e..13b57b62 100644 --- a/tests/profiling/speed_grad_vs_jac_vs_gram.py +++ b/tests/profiling/speed_grad_vs_jac_vs_gram.py @@ -55,9 +55,7 @@ def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_si A = Mean() W = A.weighting - print( - f"\nTimes for forward + backward on {factory} with BS={batch_size}, A={A}" f" on {DEVICE}." - ) + print(f"\nTimes for forward + backward on {factory} with BS={batch_size}, A={A} on {DEVICE}.") def fn_autograd(): autograd_forward_backward(model, inputs, loss_fn) diff --git a/tests/unit/aggregation/_matrix_samplers.py b/tests/unit/aggregation/_matrix_samplers.py index a106e56f..8a3acd8d 100644 --- a/tests/unit/aggregation/_matrix_samplers.py +++ b/tests/unit/aggregation/_matrix_samplers.py @@ -31,8 +31,7 @@ def __repr__(self) -> str: def __str__(self) -> str: return ( - f"{self.__class__.__name__.replace('MatrixSampler', '')}" - f"({self.m}x{self.n}r{self.rank})" + f"{self.__class__.__name__.replace('MatrixSampler', '')}({self.m}x{self.n}r{self.rank})" ) diff --git a/tests/unit/aggregation/test_nash_mtl.py b/tests/unit/aggregation/test_nash_mtl.py index e4a47642..44e15400 100644 --- a/tests/unit/aggregation/test_nash_mtl.py +++ b/tests/unit/aggregation/test_nash_mtl.py @@ -55,7 +55,7 @@ def test_nash_mtl_reset(): aggregator.reset() results = [aggregator(matrix) for matrix in matrices] - for result, expected in zip(results, expecteds): + for result, expected in zip(results, expecteds, strict=True): 
assert_close(result, expected) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index e1cfa16c..2461e383 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -341,7 +341,7 @@ def test_iwrm_steps_with_autogram(factory: ModuleFactory, batch_size: int, batch engine = Engine(model, batch_dim=batch_dim) optimizer = SGD(model.parameters(), lr=1e-7) - for i in range(n_iter): + for _ in range(n_iter): inputs, targets = make_inputs_and_targets(model, batch_size) loss_fn = make_mse_loss_fn(targets) autogram_forward_backward(model, inputs, loss_fn, engine, weighting) diff --git a/tests/unit/autojac/_transform/test_accumulate.py b/tests/unit/autojac/_transform/test_accumulate.py index c2c1cf28..eaa09549 100644 --- a/tests/unit/autojac/_transform/test_accumulate.py +++ b/tests/unit/autojac/_transform/test_accumulate.py @@ -15,14 +15,14 @@ def test_single_grad_accumulation(): shapes = [[], [1], [2, 3]] keys = [zeros_(shape, requires_grad=True) for shape in shapes] values = [ones_(shape) for shape in shapes] - input = dict(zip(keys, values)) + input = dict(zip(keys, values, strict=True)) accumulate = AccumulateGrad() output = accumulate(input) assert_tensor_dicts_are_close(output, {}) - for key, value in zip(keys, values): + for key, value in zip(keys, values, strict=True): assert_grad_close(key, value) @@ -38,12 +38,12 @@ def test_multiple_grad_accumulations(iterations: int): values = [ones_(shape) for shape in shapes] accumulate = AccumulateGrad() - for i in range(iterations): + for _ in range(iterations): # Clone values to ensure that we accumulate values that are not ever used afterwards - input = {key: value.clone() for key, value in zip(keys, values)} + input = {key: value.clone() for key, value in zip(keys, values, strict=True)} accumulate(input) - for key, value in zip(keys, values): + for key, value in zip(keys, values, strict=True): assert_grad_close(key, iterations * value) @@ -98,14 +98,14 @@ 
def test_single_jac_accumulation(): shapes = [[], [1], [2, 3]] keys = [zeros_(shape, requires_grad=True) for shape in shapes] values = [ones_([4] + shape) for shape in shapes] - input = dict(zip(keys, values)) + input = dict(zip(keys, values, strict=True)) accumulate = AccumulateJac() output = accumulate(input) assert_tensor_dicts_are_close(output, {}) - for key, value in zip(keys, values): + for key, value in zip(keys, values, strict=True): assert_jac_close(key, value) @@ -122,12 +122,12 @@ def test_multiple_jac_accumulations(iterations: int): accumulate = AccumulateJac() - for i in range(iterations): + for _ in range(iterations): # Clone values to ensure that we accumulate values that are not ever used afterwards - input = {key: value.clone() for key, value in zip(keys, values)} + input = {key: value.clone() for key, value in zip(keys, values, strict=True)} accumulate(input) - for key, value in zip(keys, values): + for key, value in zip(keys, values, strict=True): assert_jac_close(key, iterations * value) diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py index 8feac59e..3a5fb9a4 100644 --- a/tests/unit/autojac/test_jac.py +++ b/tests/unit/autojac/test_jac.py @@ -42,7 +42,7 @@ def test_jac(): jacobians = jac(outputs, inputs) assert len(jacobians) == len([a1, a2]) - for jacobian, a in zip(jacobians, [a1, a2]): + for jacobian, a in zip(jacobians, [a1, a2], strict=True): assert jacobian.shape[0] == len([y1, y2]) assert jacobian.shape[1:] == a.shape diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index c5515f40..00bda738 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -351,7 +351,7 @@ def test_various_feature_lists(shapes: list[tuple[int]]): features = [rand_(shape) @ p0 for shape in shapes] - y1 = sum([(f * p).sum() for f, p in zip(features, p1)]) + y1 = sum([(f * p).sum() for f, p in zip(features, p1, strict=True)]) y2 = (features[0] * 
p2).sum() mtl_backward(losses=[y1, y2], features=features) diff --git a/tests/unit/autojac/test_utils.py b/tests/unit/autojac/test_utils.py index a7036d7c..f4dbf7a4 100644 --- a/tests/unit/autojac/test_utils.py +++ b/tests/unit/autojac/test_utils.py @@ -152,7 +152,7 @@ def test_get_leaf_tensors_deep(depth: int): one = tensor_(1.0, requires_grad=True) sum_ = tensor_(0.0, requires_grad=False) - for i in range(depth): + for _ in range(depth): sum_ = sum_ + one leaves = get_leaf_tensors(tensors=[sum_], excluded=set()) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 2d7f95da..f1b98b6d 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -7,6 +7,7 @@ from torch import Tensor, nn from torch.nn import Flatten, ReLU from torch.utils._pytree import PyTree + from utils.contexts import fork_rng _T = TypeVar("_T", bound=nn.Module) @@ -47,7 +48,7 @@ def get_in_out_shapes(module: nn.Module) -> tuple[PyTree, PyTree]: if isinstance(module, ShapedModule): return module.INPUT_SHAPES, module.OUTPUT_SHAPES - elif isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d)): + elif isinstance(module, nn.BatchNorm2d | nn.InstanceNorm2d): HEIGHT = 6 # Arbitrary choice WIDTH = 6 # Arbitrary choice shape = (module.num_features, HEIGHT, WIDTH) diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index 14ce1e43..f8b9dfe2 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -5,14 +5,14 @@ from torch.nn.functional import mse_loss from torch.utils._pytree import PyTree, tree_flatten, tree_map from torch.utils.hooks import RemovableHandle -from utils.architectures import get_in_out_shapes -from utils.contexts import fork_rng from torchjd._linalg import PSDTensor from torchjd.aggregation import Aggregator, Weighting from torchjd.autogram import Engine from torchjd.autojac import backward from torchjd.autojac._jac_to_grad import jac_to_grad +from utils.architectures import 
get_in_out_shapes +from utils.contexts import fork_rng def autograd_forward_backward( diff --git a/tests/utils/tensors.py b/tests/utils/tensors.py index 6d8066dc..6c91a08c 100644 --- a/tests/utils/tensors.py +++ b/tests/utils/tensors.py @@ -4,6 +4,7 @@ from settings import DEVICE, DTYPE from torch import nn from torch.utils._pytree import PyTree, tree_map + from utils.architectures import get_in_out_shapes from utils.contexts import fork_rng