Skip to content
Merged
30 changes: 22 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,32 @@ target-version = "py310"

[tool.ruff.lint]
select = [
"E", # pycodestyle Error
"F", # Pyflakes
"W", # pycodestyle Warning
"I", # isort
"UP", # pyupgrade
"B", # flake8-bugbear
"E", # pycodestyle Error
"F", # Pyflakes
"W", # pycodestyle Warning
"I", # isort
"UP", # pyupgrade
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"FIX", # flake8-fixme
"TID", # flake8-tidy-imports
"SIM", # flake8-simplify
"RET", # flake8-return
"PYI", # flake8-pyi
"PIE", # flake8-pie
"COM", # flake8-commas
"PERF", # Perflint
"FURB", # refurb
"RUF", # Ruff-specific rules
]

ignore = [
"E501", # line-too-long (handled by the formatter)
"E402", # module-import-not-at-top-of-file
"E501", # line-too-long (handled by the formatter)
"E402", # module-import-not-at-top-of-file
"RUF022", # __all__ not sorted
"RUF010", # Use explicit conversion flag
"RUF012", # Mutable default value for class attribute (a bit tedious to fix)
"RET504", # Unnecessary assignment return statement
]

[tool.ruff.lint.isort]
Expand Down
6 changes: 4 additions & 2 deletions src/torchjd/_linalg/_gramian.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
first dimension).
"""

contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t.ndim
contracted_dims = contracted_dims if contracted_dims >= 0 else contracted_dims + t.ndim
indices_source = list(range(t.ndim - contracted_dims))
indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
transposed = t.movedim(indices_source, indices_dest)
Expand Down Expand Up @@ -70,7 +70,9 @@ def regularize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
"""

regularization_matrix = eps * torch.eye(
gramian.shape[0], dtype=gramian.dtype, device=gramian.device
gramian.shape[0],
dtype=gramian.dtype,
device=gramian.device,
)
output = gramian + regularization_matrix
return cast(PSDMatrix, output)
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_aggregator_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _check_is_matrix(matrix: Tensor) -> None:
if not is_matrix(matrix):
raise ValueError(
"Parameter `matrix` should be a tensor of dimension 2. Found `matrix.shape = "
f"{matrix.shape}`."
f"{matrix.shape}`.",
)

@abstractmethod
Expand Down
5 changes: 3 additions & 2 deletions src/torchjd/aggregation/_aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:

@staticmethod
def _compute_balance_transformation(
M: Tensor, scale_mode: SUPPORTED_SCALE_MODE = "min"
M: Tensor,
scale_mode: SUPPORTED_SCALE_MODE = "min",
) -> Tensor:
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
Expand All @@ -130,7 +131,7 @@ def _compute_balance_transformation(
scale = lambda_.mean()
else:
raise ValueError(
f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'."
f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'.",
)

B = scale.sqrt() * V @ sigma_inv @ V.T
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, weights: Tensor):
if weights.dim() != 1:
raise ValueError(
"Parameter `weights` should be a 1-dimensional tensor. Found `weights.shape = "
f"{weights.shape}`."
f"{weights.shape}`.",
)

super().__init__()
Expand All @@ -53,5 +53,5 @@ def _check_matrix_shape(self, matrix: Tensor) -> None:
if matrix.shape[0] != len(self.weights):
raise ValueError(
f"Parameter `matrix` should have {len(self.weights)} rows (the number of specified "
f"weights). Found `matrix` with {matrix.shape[0]} rows."
f"weights). Found `matrix` with {matrix.shape[0]} rows.",
)
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
self._solver: SUPPORTED_SOLVER = solver

super().__init__(
DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver)
DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver),
)

# This prevents considering the computed weights as constant w.r.t. the matrix.
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_graddrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None):
if leak is not None and leak.dim() != 1:
raise ValueError(
"Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = "
f"{leak.shape}`."
f"{leak.shape}`.",
)

super().__init__()
Expand Down Expand Up @@ -64,7 +64,7 @@ def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None:
if self.leak is not None and n_rows != len(self.leak):
raise ValueError(
f"Parameter `matrix` should be a matrix of exactly {len(self.leak)} rows (i.e. the "
f"number of leak scalars). Found `matrix` of shape `{matrix.shape}`."
f"number of leak scalars). Found `matrix` of shape `{matrix.shape}`.",
)

def __repr__(self) -> str:
Expand Down
5 changes: 1 addition & 4 deletions src/torchjd/aggregation/_imtl_g.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
v = torch.linalg.pinv(gramian) @ d
v_sum = v.sum()

if v_sum.abs() < 1e-12:
weights = torch.zeros_like(v)
else:
weights = v / v_sum
weights = torch.zeros_like(v) if v_sum.abs() < 1e-12 else v / v_sum

return weights
8 changes: 4 additions & 4 deletions src/torchjd/aggregation/_krum.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ def __init__(self, n_byzantine: int, n_selected: int = 1):
if n_byzantine < 0:
raise ValueError(
"Parameter `n_byzantine` should be a non-negative integer. Found `n_byzantine = "
f"{n_byzantine}`."
f"{n_byzantine}`.",
)

if n_selected < 1:
raise ValueError(
"Parameter `n_selected` should be a positive integer. Found `n_selected = "
f"{n_selected}`."
f"{n_selected}`.",
)

self.n_byzantine = n_byzantine
Expand Down Expand Up @@ -85,11 +85,11 @@ def _check_matrix_shape(self, gramian: PSDMatrix) -> None:
if gramian.shape[0] < min_rows:
raise ValueError(
f"Parameter `gramian` should have at least {min_rows} rows (n_byzantine + 3). Found"
f" `gramian` with {gramian.shape[0]} rows."
f" `gramian` with {gramian.shape[0]} rows.",
)

if gramian.shape[0] < self.n_selected:
raise ValueError(
f"Parameter `gramian` should have at least {self.n_selected} rows (n_selected). "
f"Found `gramian` with {gramian.shape[0]} rows."
f"Found `gramian` with {gramian.shape[0]} rows.",
)
14 changes: 6 additions & 8 deletions src/torchjd/aggregation/_nash_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(
max_norm=max_norm,
update_weights_every=update_weights_every,
optim_niter=optim_niter,
)
),
)
self._n_tasks = n_tasks
self._max_norm = max_norm
Expand Down Expand Up @@ -144,7 +144,7 @@ def _stop_criteria(self, gtg: np.ndarray, alpha_t: np.ndarray) -> bool:
return bool(
(self.alpha_param.value is None)
or (np.linalg.norm(gtg @ alpha_t - 1 / (alpha_t + 1e-10)) < 1e-3)
or (np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) < 1e-6)
or (np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) < 1e-6),
)

def _solve_optimization(self, gtg: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -189,12 +189,10 @@ def _init_optim_problem(self) -> None:
self.phi_alpha = self._calc_phi_alpha_linearization()

G_alpha = self.G_param @ self.alpha_param
constraint = []
for i in range(self.n_tasks):
constraint.append(
-cp.log(self.alpha_param[i] * self.normalization_factor_param) - cp.log(G_alpha[i])
<= 0
)
constraint = [
-cp.log(a * self.normalization_factor_param) - cp.log(G_a) <= 0
for a, G_a in zip(self.alpha_param, G_alpha, strict=True)
]
obj = cp.Minimize(cp.sum(G_alpha) + self.phi_alpha / self.normalization_factor_param)
self.prob = cp.Problem(obj, constraint)

Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_trimmed_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, trim_number: int):
if trim_number < 0:
raise ValueError(
"Parameter `trim_number` should be a non-negative integer. Found `trim_number` = "
f"{trim_number}`."
f"{trim_number}`.",
)
self.trim_number = trim_number

Expand All @@ -41,7 +41,7 @@ def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None:
if n_rows < min_rows:
raise ValueError(
f"Parameter `matrix` should be a matrix of at least {min_rows} rows "
f"(i.e. `2 * trim_number + 1`). Found `matrix` of shape `{matrix.shape}`."
f"(i.e. `2 * trim_number + 1`). Found `matrix` of shape `{matrix.shape}`.",
)

def __repr__(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
self._solver: SUPPORTED_SOLVER = solver

super().__init__(
UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver)
UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver),
)

# This prevents considering the computed weights as constant w.r.t. the matrix.
Expand Down
19 changes: 9 additions & 10 deletions src/torchjd/aggregation/_utils/pref_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@


def pref_vector_to_weighting(
pref_vector: Tensor | None, default: Weighting[Matrix]
pref_vector: Tensor | None,
default: Weighting[Matrix],
) -> Weighting[Matrix]:
"""
Returns the weighting associated to a given preference vector, with a fallback to a default
Expand All @@ -17,19 +18,17 @@ def pref_vector_to_weighting(

if pref_vector is None:
return default
else:
if pref_vector.ndim != 1:
raise ValueError(
"Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = "
f"{pref_vector.ndim}`."
)
return ConstantWeighting(pref_vector)
if pref_vector.ndim != 1:
raise ValueError(
"Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = "
f"{pref_vector.ndim}`.",
)
return ConstantWeighting(pref_vector)


def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str:
"""Returns a suffix string containing the representation of the optional preference vector."""

if pref_vector is None:
return ""
else:
return f"([{vector_to_str(pref_vector)}])"
return f"([{vector_to_str(pref_vector)}])"
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_weighting_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
_FnOutputT = TypeVar("_FnOutputT", bound=Tensor)


class Weighting(Generic[_T], nn.Module, ABC):
class Weighting(nn.Module, ABC, Generic[_T]):
r"""
Abstract base class for all weighting methods. It has the role of extracting a vector of weights
of dimension :math:`m` from some statistic of a matrix of dimension :math:`m \times n`,
Expand Down
6 changes: 3 additions & 3 deletions src/torchjd/autogram/_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def _check_module_is_compatible(self, module: nn.Module) -> None:
f" are {_MODULES_INCOMPATIBLE_WITH_BATCHED} (and their subclasses). The "
f"recommended fix is to replace incompatible layers by something else (e.g. "
f"BatchNorm by InstanceNorm). If you really can't and performance is not a "
f"priority, you may also just set `batch_dim=None` when creating the engine."
f"priority, you may also just set `batch_dim=None` when creating the engine.",
)
if isinstance(module, _TRACK_RUNNING_STATS_MODULE_TYPES) and module.track_running_stats:
raise ValueError(
Expand All @@ -231,7 +231,7 @@ def _check_module_is_compatible(self, module: nn.Module) -> None:
f" to performing in-place operations on tensors and having side-effects during "
f"the forward pass. Try setting `track_running_stats` to `False`. If you really"
f" can't and performance is not a priority, you may also just set "
f"`batch_dim=None` when creating the engine."
f"`batch_dim=None` when creating the engine.",
)

# Currently, the type PSDMatrix is hidden from users, so Tensor is correct.
Expand Down Expand Up @@ -278,7 +278,7 @@ def compute_gramian(self, output: Tensor) -> Tensor:
target_shape = []

if has_non_batch_dim:
target_shape = [-1] + target_shape
target_shape = [-1, *target_shape]

reshaped_output = ordered_output.reshape(target_shape)
# There are four different cases for the shape of reshaped_output:
Expand Down
3 changes: 1 addition & 2 deletions src/torchjd/autogram/_gramian_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,4 @@ def __call__(
gramian = compute_gramian(self.summed_jacobian)
del self.summed_jacobian
return gramian
else:
return None
return None
4 changes: 2 additions & 2 deletions src/torchjd/autogram/_gramian_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def movedim(gramian: PSDTensor, half_source: list[int], half_destination: list[i

# Map everything to the range [0, gramian.ndim//2[
half_ndim = gramian.ndim // 2
half_source_ = [i if 0 <= i else i + half_ndim for i in half_source]
half_destination_ = [i if 0 <= i else i + half_ndim for i in half_destination]
half_source_ = [i if i >= 0 else i + half_ndim for i in half_source]
half_destination_ = [i if i >= 0 else i + half_ndim for i in half_destination]

# Mirror the half source and the half destination and use the result to move the dimensions of
# the gramian
Expand Down
9 changes: 7 additions & 2 deletions src/torchjd/autogram/_jacobian_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ def __call__(
) -> Matrix:
# This makes __call__ vmappable.
return ComputeModuleJacobians.apply(
self._compute_jacobian, rg_outputs, grad_outputs, args, kwargs
self._compute_jacobian,
rg_outputs,
grad_outputs,
args,
kwargs,
)

@abstractmethod
Expand Down Expand Up @@ -155,7 +159,8 @@ class ComputeModuleJacobians(torch.autograd.Function):
@staticmethod
def forward(
compute_jacobian_fn: Callable[
[tuple[Tensor, ...], tuple[Tensor, ...], tuple[PyTree, ...], dict[str, PyTree]], Matrix
[tuple[Tensor, ...], tuple[Tensor, ...], tuple[PyTree, ...], dict[str, PyTree]],
Matrix,
],
rg_outputs: tuple[Tensor, ...],
grad_outputs: tuple[Tensor, ...],
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/autojac/_accumulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> No
raise RuntimeError(
f"attempting to assign a jacobian of size '{list(jac.shape)}' to a tensor of "
f"size '{list(param.shape)}'. Please ensure that the tensor and each row of the"
" jacobian are the same size"
" jacobian are the same size",
)

if is_tensor_with_jac(param):
Expand Down Expand Up @@ -57,7 +57,7 @@ def _check_expects_grad(tensor: Tensor, field_name: str) -> None:
if not _expects_grad(tensor):
raise ValueError(
f"Cannot populate the {field_name} field of a Tensor that does not satisfy:\n"
"`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`."
"`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`.",
)


Expand Down
Loading