Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def linkcode_resolve(domain: str, info: dict[str, str]) -> str | None:
line_str = _get_line_str(obj)
version_str = _get_version_str()

link = f"https://github.com/TorchJD/torchjd/blob/{version_str}/{file_name}{line_str}"
return link
return f"https://github.com/TorchJD/torchjd/blob/{version_str}/{file_name}{line_str}"


def _get_obj(_info: dict[str, str]):
Expand All @@ -108,8 +107,7 @@ def _get_obj(_info: dict[str, str]):
for part in full_name.split("."):
obj = getattr(obj, part)
# strip decorators, which would resolve to the source of the decorator
obj = inspect.unwrap(obj)
return obj
return inspect.unwrap(obj)


def _get_file_name(obj) -> str | None:
Expand All @@ -124,8 +122,7 @@ def _get_file_name(obj) -> str | None:
def _get_line_str(obj) -> str:
source, start = inspect.getsourcelines(obj)
end = start + len(source) - 1
line_str = f"#L{start}-L{end}"
return line_str
return f"#L{start}-L{end}"


def _get_version_str() -> str:
Expand Down
29 changes: 21 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,31 @@ target-version = "py310"

[tool.ruff.lint]
select = [
"E", # pycodestyle Error
"F", # Pyflakes
"W", # pycodestyle Warning
"I", # isort
"UP", # pyupgrade
"B", # flake8-bugbear
"E", # pycodestyle Error
"F", # Pyflakes
"W", # pycodestyle Warning
"I", # isort
"UP", # pyupgrade
"FBT", # flake8-boolean-trap
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"FIX", # flake8-fixme
"TID", # flake8-tidy-imports
"SIM", # flake8-simplify
"ARG", # flake8-unused-arguments
"RET", # flake8-return
"PYI", # flake8-pyi
"PIE", # flake8-pie
"COM", # flake8-commas
"PERF", # Perflint
"FURB", # refurb
"RUF", # Ruff-specific rules
]

ignore = [
"E501", # line-too-long (handled by the formatter)
"E402", # module-import-not-at-top-of-file
"E501", # line-too-long (handled by the formatter)
"E402", # module-import-not-at-top-of-file
"COM812",
]

[tool.ruff.lint.isort]
Expand Down
6 changes: 3 additions & 3 deletions src/torchjd/_linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor

__all__ = [
"compute_gramian",
"normalize",
"regularize",
"Matrix",
"PSDMatrix",
"PSDTensor",
"compute_gramian",
"is_matrix",
"is_psd_matrix",
"is_psd_tensor",
"normalize",
"regularize",
]
6 changes: 4 additions & 2 deletions src/torchjd/_linalg/_gramian.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
first dimension).
"""

contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t.ndim
contracted_dims = contracted_dims if contracted_dims >= 0 else contracted_dims + t.ndim
indices_source = list(range(t.ndim - contracted_dims))
indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
transposed = t.movedim(indices_source, indices_dest)
Expand Down Expand Up @@ -70,7 +70,9 @@ def regularize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
"""

regularization_matrix = eps * torch.eye(
gramian.shape[0], dtype=gramian.dtype, device=gramian.device
gramian.shape[0],
dtype=gramian.dtype,
device=gramian.device,
)
output = gramian + regularization_matrix
return cast(PSDMatrix, output)
6 changes: 3 additions & 3 deletions src/torchjd/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@
from ._weighting_bases import GeneralizedWeighting, Weighting

__all__ = [
"IMTLG",
"MGDA",
"Aggregator",
"AlignedMTL",
"AlignedMTLWeighting",
Expand All @@ -92,14 +94,12 @@
"Flattening",
"GeneralizedWeighting",
"GradDrop",
"IMTLG",
"IMTLGWeighting",
"Krum",
"KrumWeighting",
"MGDAWeighting",
"Mean",
"MeanWeighting",
"MGDA",
"MGDAWeighting",
"PCGrad",
"PCGradWeighting",
"Random",
Expand Down
8 changes: 3 additions & 5 deletions src/torchjd/aggregation/_aggregator_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _check_is_matrix(matrix: Tensor) -> None:
if not is_matrix(matrix):
raise ValueError(
"Parameter `matrix` should be a tensor of dimension 2. Found `matrix.shape = "
f"{matrix.shape}`."
f"{matrix.shape}`.",
)

@abstractmethod
Expand Down Expand Up @@ -59,13 +59,11 @@ def combine(matrix: Matrix, weights: Tensor) -> Tensor:
weights.
"""

vector = weights @ matrix
return vector
return weights @ matrix

def forward(self, matrix: Matrix) -> Tensor:
weights = self.weighting(matrix)
vector = self.combine(matrix, weights)
return vector
return self.combine(matrix, weights)


class GramianWeightedAggregator(WeightedAggregator):
Expand Down
19 changes: 8 additions & 11 deletions src/torchjd/aggregation/_aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def __init__(

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, "
f"scale_mode={repr(self._scale_mode)})"
f"{self.__class__.__name__}(pref_vector={self._pref_vector!r}, "
f"scale_mode={self._scale_mode!r})"
)

def __str__(self) -> str:
Expand Down Expand Up @@ -101,21 +101,19 @@ def __init__(
def forward(self, gramian: PSDMatrix, /) -> Tensor:
w = self.weighting(gramian)
B = self._compute_balance_transformation(gramian, self._scale_mode)
alpha = B @ w

return alpha
return B @ w

@staticmethod
def _compute_balance_transformation(
M: Tensor, scale_mode: SUPPORTED_SCALE_MODE = "min"
M: Tensor,
scale_mode: SUPPORTED_SCALE_MODE = "min",
) -> Tensor:
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
rank = sum(lambda_ > tol)

if rank == 0:
identity = torch.eye(len(M), dtype=M.dtype, device=M.device)
return identity
return torch.eye(len(M), dtype=M.dtype, device=M.device)

order = torch.argsort(lambda_, dim=-1, descending=True)
lambda_, V = lambda_[order][:rank], V[:, order][:, :rank]
Expand All @@ -130,8 +128,7 @@ def _compute_balance_transformation(
scale = lambda_.mean()
else:
raise ValueError(
f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'."
f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'.",
)

B = scale.sqrt() * V @ sigma_inv @ V.T
return B
return scale.sqrt() * V @ sigma_inv @ V.T
4 changes: 1 addition & 3 deletions src/torchjd/aggregation/_cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,4 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
# We are approximately on the pareto front
weight_array = np.zeros(dimension)

weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype)

return weights
return torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype)
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def forward(self, matrix: Matrix) -> Tensor:
return length * unit_target_vector

def __repr__(self) -> str:
return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"
return f"{self.__class__.__name__}(pref_vector={self._pref_vector!r})"

def __str__(self) -> str:
return f"ConFIG{pref_vector_to_str_suffix(self._pref_vector)}"
6 changes: 3 additions & 3 deletions src/torchjd/aggregation/_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, weights: Tensor):
self._weights = weights

def __repr__(self) -> str:
return f"{self.__class__.__name__}(weights={repr(self._weights)})"
return f"{self.__class__.__name__}(weights={self._weights!r})"

def __str__(self) -> str:
weights_str = vector_to_str(self._weights)
Expand All @@ -39,7 +39,7 @@ def __init__(self, weights: Tensor):
if weights.dim() != 1:
raise ValueError(
"Parameter `weights` should be a 1-dimensional tensor. Found `weights.shape = "
f"{weights.shape}`."
f"{weights.shape}`.",
)

super().__init__()
Expand All @@ -53,5 +53,5 @@ def _check_matrix_shape(self, matrix: Tensor) -> None:
if matrix.shape[0] != len(self.weights):
raise ValueError(
f"Parameter `matrix` should have {len(self.weights)} rows (the number of specified "
f"weights). Found `matrix` with {matrix.shape[0]} rows."
f"weights). Found `matrix` with {matrix.shape[0]} rows.",
)
9 changes: 4 additions & 5 deletions src/torchjd/aggregation/_dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ def __init__(
self._solver: SUPPORTED_SOLVER = solver

super().__init__(
DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver)
DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver),
)

# This prevents considering the computed weights as constant w.r.t. the matrix.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps="
f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={repr(self._solver)})"
f"{self.__class__.__name__}(pref_vector={self._pref_vector!r}, norm_eps="
f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={self._solver!r})"
)

def __str__(self) -> str:
Expand Down Expand Up @@ -88,5 +88,4 @@ def __init__(
def forward(self, gramian: PSDMatrix, /) -> Tensor:
u = self.weighting(gramian)
G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
w = project_weights(u, G, self.solver)
return w
return project_weights(u, G, self.solver)
3 changes: 1 addition & 2 deletions src/torchjd/aggregation/_flattening.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,4 @@ def forward(self, generalized_gramian: PSDTensor) -> Tensor:
shape = generalized_gramian.shape[:k]
square_gramian = flatten(generalized_gramian)
weights_vector = self.weighting(square_gramian)
weights = weights_vector.reshape(shape)
return weights
return weights_vector.reshape(shape)
6 changes: 3 additions & 3 deletions src/torchjd/aggregation/_graddrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None):
if leak is not None and leak.dim() != 1:
raise ValueError(
"Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = "
f"{leak.shape}`."
f"{leak.shape}`.",
)

super().__init__()
Expand Down Expand Up @@ -64,11 +64,11 @@ def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None:
if self.leak is not None and n_rows != len(self.leak):
raise ValueError(
f"Parameter `matrix` should be a matrix of exactly {len(self.leak)} rows (i.e. the "
f"number of leak scalars). Found `matrix` of shape `{matrix.shape}`."
f"number of leak scalars). Found `matrix` of shape `{matrix.shape}`.",
)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(f={repr(self.f)}, leak={repr(self.leak)})"
return f"{self.__class__.__name__}(f={self.f!r}, leak={self.leak!r})"

def __str__(self) -> str:
if self.leak is None:
Expand Down
7 changes: 1 addition & 6 deletions src/torchjd/aggregation/_imtl_g.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,4 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
v = torch.linalg.pinv(gramian) @ d
v_sum = v.sum()

if v_sum.abs() < 1e-12:
weights = torch.zeros_like(v)
else:
weights = v / v_sum

return weights
return torch.zeros_like(v) if v_sum.abs() < 1e-12 else v / v_sum
12 changes: 5 additions & 7 deletions src/torchjd/aggregation/_krum.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ def __init__(self, n_byzantine: int, n_selected: int = 1):
if n_byzantine < 0:
raise ValueError(
"Parameter `n_byzantine` should be a non-negative integer. Found `n_byzantine = "
f"{n_byzantine}`."
f"{n_byzantine}`.",
)

if n_selected < 1:
raise ValueError(
"Parameter `n_selected` should be a positive integer. Found `n_selected = "
f"{n_selected}`."
f"{n_selected}`.",
)

self.n_byzantine = n_byzantine
Expand All @@ -76,20 +76,18 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:

_, selected_indices = torch.topk(scores, k=self.n_selected, largest=False)
one_hot_selected_indices = F.one_hot(selected_indices, num_classes=gramian.shape[0])
weights = one_hot_selected_indices.sum(dim=0).to(dtype=gramian.dtype) / self.n_selected

return weights
return one_hot_selected_indices.sum(dim=0).to(dtype=gramian.dtype) / self.n_selected

def _check_matrix_shape(self, gramian: PSDMatrix) -> None:
min_rows = self.n_byzantine + 3
if gramian.shape[0] < min_rows:
raise ValueError(
f"Parameter `gramian` should have at least {min_rows} rows (n_byzantine + 3). Found"
f" `gramian` with {gramian.shape[0]} rows."
f" `gramian` with {gramian.shape[0]} rows.",
)

if gramian.shape[0] < self.n_selected:
raise ValueError(
f"Parameter `gramian` should have at least {self.n_selected} rows (n_selected). "
f"Found `gramian` with {gramian.shape[0]} rows."
f"Found `gramian` with {gramian.shape[0]} rows.",
)
3 changes: 1 addition & 2 deletions src/torchjd/aggregation/_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,4 @@ def forward(self, matrix: Tensor, /) -> Tensor:
device = matrix.device
dtype = matrix.dtype
m = matrix.shape[0]
weights = torch.full(size=[m], fill_value=1 / m, device=device, dtype=dtype)
return weights
return torch.full(size=[m], fill_value=1 / m, device=device, dtype=dtype)
Loading