Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 4 additions & 31 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ jobs:
# This reduces false positives due to rate limits
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

typing:
name: Typing correctness
code-quality:
name: Code quality (ty and ruff)
runs-on: ubuntu-latest
steps:
- name: Checkout repository
Expand All @@ -128,32 +128,5 @@ jobs:
- name: Run ty
run: uv run ty check --output-format=github

check-todos:
name: Absence of TODOs
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v6

- name: Scan for TODO strings
run: |
echo "Scanning codebase for TODOs..."

git grep -nE "TODO" -- . ':(exclude).github/workflows/*' > todos_found.txt || true

if [ -s todos_found.txt ]; then
echo "❌ ERROR: Found TODOs in the following files:"
echo "-------------------------------------------"

while IFS=: read -r file line content; do
echo "::error file=$file,line=$line::TODO found at $file:$line - must be resolved before merge:%0A$content"
done < todos_found.txt

echo "-------------------------------------------"
echo "Please resolve these TODOs or track them in an issue before merging."

exit 1
else
echo "✅ No TODOs found. Codebase is clean!"
exit 0
fi
- name: Run ruff
run: uv run ruff check --output-format=github
30 changes: 5 additions & 25 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,12 @@ repos:
- id: check-docstring-first # Check a common error of defining a docstring after code.
- id: check-merge-conflict # Check for files that contain merge conflict strings.

- repo: https://github.com/PyCQA/flake8
rev: 7.3.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.0
hooks:
- id: flake8 # Check style and syntax. Does not modify code, issues have to be solved manually.
args: [
'--ignore=E501,E203,W503,E402', # Ignore line length problems, space after colon problems, line break occurring before a binary operator problems, module level import not at top of file problems.
]

- repo: https://github.com/pycqa/isort
rev: 7.0.0
hooks:
- id: isort # Sort imports.
args: [
--multi-line=3,
--line-length=100,
--trailing-comma,
--force-grid-wrap=0,
--use-parentheses,
--ensure-newline-before-comments,
]

- repo: https://github.com/psf/black-pre-commit-mirror
rev: 25.12.0
hooks:
- id: black # Format code.
args: [--line-length=100]
- id: ruff
args: [ --fix, --ignore, FIX ]
- id: ruff-format

ci:
autoupdate_commit_msg: 'chore: Update pre-commit hooks'
Expand Down
27 changes: 27 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Changelog = "https://github.com/TorchJD/torchjd/blob/main/CHANGELOG.md"

[dependency-groups]
check = [
"ruff>=0.14.14",
"ty>=0.0.14",
"pre-commit>=2.9.2", # isort doesn't work before 2.9.2
]
Expand Down Expand Up @@ -122,6 +123,32 @@ exclude_lines = [
"@overload",
]

[tool.ruff]
line-length = 100
target-version = "py310"

[tool.ruff.lint]
select = [
"E", # pycodestyle Error
"F", # Pyflakes
"W", # pycodestyle Warning
"I", # isort
"UP", # pyupgrade
"B", # flake8-bugbear
"FIX", # flake8-fixme
]

ignore = [
"E501", # line-too-long (handled by the formatter)
"E402", # module-import-not-at-top-of-file
]

[tool.ruff.lint.isort]
combine-as-imports = true

[tool.ruff.format]
quote-style = "double"

[tool.ty.src]
include = ["src", "tests"]
exclude = ["src/torchjd/aggregation/_nash_mtl.py"]
3 changes: 1 addition & 2 deletions src/torchjd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from collections.abc import Callable
from warnings import warn as _warn

from .autojac import backward as _backward
from .autojac import mtl_backward as _mtl_backward
from .autojac import backward as _backward, mtl_backward as _mtl_backward

_deprecated_items: dict[str, tuple[str, Callable]] = {
"backward": ("autojac", _backward),
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_graddrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,5 @@ def __str__(self) -> str:
if self.leak is None:
leak_str = ""
else:
leak_str = f"([{', '.join(['{:.2f}'.format(l_).rstrip('0') for l_ in self.leak])}])"
leak_str = f"([{', '.join([f'{l_:.2f}'.rstrip('0') for l_ in self.leak])}])"
return f"GradDrop{leak_str}"
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_mgda.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
dtype = gramian.dtype

alpha = torch.ones(gramian.shape[0], device=device, dtype=dtype) / gramian.shape[0]
for i in range(self.max_iters):
for _ in range(self.max_iters):
t = torch.argmin(gramian @ alpha)
e_t = torch.zeros(gramian.shape[0], device=device, dtype=dtype)
e_t[t] = 1.0
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_utils/str.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ def vector_to_str(vector: Tensor) -> str:
`1.23, 1., ...`.
"""

weights_str = ", ".join(["{:.2f}".format(value).rstrip("0") for value in vector])
weights_str = ", ".join([f"{value:.2f}".rstrip("0") for value in vector])
return weights_str
6 changes: 2 additions & 4 deletions src/torchjd/autogram/_gramian_accumulator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional

from torchjd._linalg import PSDMatrix


Expand All @@ -13,7 +11,7 @@ class GramianAccumulator:
"""

def __init__(self) -> None:
self._gramian: Optional[PSDMatrix] = None
self._gramian: PSDMatrix | None = None

def reset(self) -> None:
self._gramian = None
Expand All @@ -25,7 +23,7 @@ def accumulate_gramian(self, gramian: PSDMatrix) -> None:
self._gramian = gramian

@property
def gramian(self) -> Optional[PSDMatrix]:
def gramian(self) -> PSDMatrix | None:
"""
Get the Gramian matrix accumulated so far.

Expand Down
10 changes: 6 additions & 4 deletions src/torchjd/autogram/_gramian_computer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional, cast
from typing import cast

from torch import Tensor
from torch.utils._pytree import PyTree
Expand All @@ -16,12 +16,14 @@ def __call__(
grad_outputs: tuple[Tensor, ...],
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
) -> Optional[PSDMatrix]:
) -> PSDMatrix | None:
"""Compute what we can for a module and optionally return the gramian if it's ready."""

@abstractmethod
def track_forward_call(self) -> None:
"""Track that the module's forward was called. Necessary in some implementations."""

@abstractmethod
def reset(self) -> None:
"""Reset state if any. Necessary in some implementations."""

Expand All @@ -40,7 +42,7 @@ class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
def __init__(self, jacobian_computer: JacobianComputer):
super().__init__(jacobian_computer)
self.remaining_counter = 0
self.summed_jacobian: Optional[Matrix] = None
self.summed_jacobian: Matrix | None = None

def reset(self) -> None:
self.remaining_counter = 0
Expand All @@ -55,7 +57,7 @@ def __call__(
grad_outputs: tuple[Tensor, ...],
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
) -> Optional[PSDMatrix]:
) -> PSDMatrix | None:
"""Compute what we can for a module and optionally return the gramian if it's ready."""

jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs)
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autogram/_module_hook_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __call__(
*rg_outputs,
)

for idx, output in zip(rg_output_indices, autograd_fn_rg_outputs):
for idx, output in zip(rg_output_indices, autograd_fn_rg_outputs, strict=True):
flat_outputs[idx] = output

return tree_unflatten(flat_outputs, output_spec)
Expand Down
3 changes: 1 addition & 2 deletions src/torchjd/autojac/_jac.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from collections.abc import Sequence
from typing import Iterable
from collections.abc import Iterable, Sequence

from torch import Tensor

Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autojac/_jac_to_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _disunite_gradient(
gradient_vector: Tensor, jacobians: list[Tensor], tensors: list[TensorWithJac]
) -> list[Tensor]:
gradient_vectors = gradient_vector.split([t.numel() for t in tensors])
gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors)]
gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors, strict=True)]
return gradients


Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autojac/_mtl_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _create_transform(
OrderedSet([loss]),
retain_graph,
)
for task_params, loss in zip(tasks_params, losses)
for task_params, loss in zip(tasks_params, losses, strict=True)
]

# Transform that stacks the gradients of the losses w.r.t. the shared representations into a
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autojac/_transform/_diagonalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __call__(self, tensors: TensorDict, /) -> TensorDict:
diagonal_matrix = torch.cat(flattened_considered_values).diag()
diagonalized_tensors = {
key: diagonal_matrix[:, begin:end].reshape((-1,) + key.shape)
for (begin, end), key in zip(self.indices, self.key_order)
for (begin, end), key in zip(self.indices, self.key_order, strict=True)
}
return diagonalized_tensors

Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autojac/_transform/_differentiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __call__(self, tensors: TensorDict, /) -> TensorDict:
tensor_outputs = [tensors[output] for output in self.outputs]

differentiated_tuple = self._differentiate(tensor_outputs)
new_differentiations = dict(zip(self.inputs, differentiated_tuple))
new_differentiations = dict(zip(self.inputs, differentiated_tuple, strict=True))
return type(tensors)(new_differentiations)

@abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autojac/_transform/_materialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def materialize(
"""

tensors = []
for optional_tensor, input in zip(optional_tensors, inputs):
for optional_tensor, input in zip(optional_tensors, inputs, strict=True):
if optional_tensor is None:
tensors.append(torch.zeros_like(input))
else:
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ def pytest_make_parametrize_id(config, val, argname):
MAX_SIZE = 40
optional_string = None # Returning None means using pytest's way of making the string

if isinstance(val, (Aggregator, ModuleFactory, Weighting)):
if isinstance(val, Aggregator | ModuleFactory | Weighting):
optional_string = str(val)
elif isinstance(val, Tensor):
optional_string = "T" + str(list(val.shape)) # T to indicate that it's a tensor
elif isinstance(val, (tuple, list, set)) and len(val) < 20:
elif isinstance(val, tuple | list | set) and len(val) < 20:
optional_string = str(val)
elif isinstance(val, RaisesExc):
optional_string = " or ".join([f"{exc.__name__}" for exc in val.expected_exceptions])
Expand Down
2 changes: 1 addition & 1 deletion tests/doc/test_autogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_engine():
# Create the engine before the backward pass, and only once.
engine = Engine(model, batch_dim=0)

for input, target in zip(inputs, targets):
for input, target in zip(inputs, targets, strict=True):
output = model(input).squeeze(dim=1) # shape: [16]
losses = criterion(output, target) # shape: [16]

Expand Down
18 changes: 9 additions & 9 deletions tests/doc/test_rst.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_amp():
task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
for input, target1, target2 in zip(inputs, task1_targets, task2_targets, strict=False):
with torch.autocast(device_type="cpu", dtype=torch.float16):
features = shared_module(input)
output1 = task1_module(features)
Expand Down Expand Up @@ -105,7 +105,7 @@ def test_iwmtl():
task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16) # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
for input, target1, target2 in zip(inputs, task1_targets, task2_targets, strict=False):
features = shared_module(input) # shape: [16, 3]
out1 = task1_module(features).squeeze(1) # shape: [16]
out2 = task2_module(features).squeeze(1) # shape: [16]
Expand Down Expand Up @@ -140,7 +140,7 @@ def test_autograd():
params = model.parameters()
optimizer = SGD(params, lr=0.1)

for x, y in zip(X, Y):
for x, y in zip(X, Y, strict=False):
y_hat = model(x).squeeze(dim=1) # shape: [16]
loss = loss_fn(y_hat, y) # shape: [] (scalar)
loss.backward()
Expand All @@ -165,7 +165,7 @@ def test_autojac():
optimizer = SGD(params, lr=0.1)
aggregator = UPGrad()

for x, y in zip(X, Y):
for x, y in zip(X, Y, strict=False):
y_hat = model(x).squeeze(dim=1) # shape: [16]
losses = loss_fn(y_hat, y) # shape: [16]
backward(losses)
Expand All @@ -192,7 +192,7 @@ def test_autogram():
weighting = UPGradWeighting()
engine = Engine(model, batch_dim=0)

for x, y in zip(X, Y):
for x, y in zip(X, Y, strict=False):
y_hat = model(x).squeeze(dim=1) # shape: [16]
losses = loss_fn(y_hat, y) # shape: [16]
gramian = engine.compute_gramian(losses) # shape: [16, 16]
Expand Down Expand Up @@ -318,7 +318,7 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.
task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
for input, target1, target2 in zip(inputs, task1_targets, task2_targets, strict=False):
features = shared_module(input)
output1 = task1_module(features)
output2 = task2_module(features)
Expand Down Expand Up @@ -356,7 +356,7 @@ def test_mtl():
task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
for input, target1, target2 in zip(inputs, task1_targets, task2_targets, strict=False):
features = shared_module(input)
output1 = task1_module(features)
output2 = task2_module(features)
Expand Down Expand Up @@ -392,7 +392,7 @@ def test_partial_jd():
params = model.parameters()
optimizer = SGD(params, lr=0.1)

for x, y in zip(X, Y):
for x, y in zip(X, Y, strict=False):
y_hat = model(x).squeeze(dim=1) # shape: [16]
losses = loss_fn(y_hat, y) # shape: [16]
gramian = engine.compute_gramian(losses)
Expand All @@ -417,7 +417,7 @@ def test_rnn():
inputs = torch.randn(8, 5, 3, 10) # 8 batches of 3 sequences of length 5 and of dim 10.
targets = torch.randn(8, 5, 3, 20) # 8 batches of 3 sequences of length 5 and of dim 20.

for input, target in zip(inputs, targets):
for input, target in zip(inputs, targets, strict=False):
output, _ = rnn(input) # output is of shape [5, 3, 20].
losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element.

Expand Down
2 changes: 1 addition & 1 deletion tests/plots/interactive_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch
from dash import Dash, Input, Output, callback, dcc, html
from plotly.graph_objs import Figure
from plots._utils import Plotter, angle_to_coord, coord_to_angle

from plots._utils import Plotter, angle_to_coord, coord_to_angle
from torchjd.aggregation import (
IMTLG,
MGDA,
Expand Down
Loading