diff --git a/docs/source/conf.py b/docs/source/conf.py index 070c988a..68320125 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -96,8 +96,7 @@ def linkcode_resolve(domain: str, info: dict[str, str]) -> str | None: line_str = _get_line_str(obj) version_str = _get_version_str() - link = f"https://github.com/TorchJD/torchjd/blob/{version_str}/{file_name}{line_str}" - return link + return f"https://github.com/TorchJD/torchjd/blob/{version_str}/{file_name}{line_str}" def _get_obj(_info: dict[str, str]): @@ -108,8 +107,7 @@ def _get_obj(_info: dict[str, str]): for part in full_name.split("."): obj = getattr(obj, part) # strip decorators, which would resolve to the source of the decorator - obj = inspect.unwrap(obj) - return obj + return inspect.unwrap(obj) def _get_file_name(obj) -> str | None: @@ -124,8 +122,7 @@ def _get_file_name(obj) -> str | None: def _get_line_str(obj) -> str: source, start = inspect.getsourcelines(obj) end = start + len(source) - 1 - line_str = f"#L{start}-L{end}" - return line_str + return f"#L{start}-L{end}" def _get_version_str() -> str: diff --git a/pyproject.toml b/pyproject.toml index 06071519..67a948f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,18 +129,31 @@ target-version = "py310" [tool.ruff.lint] select = [ - "E", # pycodestyle Error - "F", # Pyflakes - "W", # pycodestyle Warning - "I", # isort - "UP", # pyupgrade - "B", # flake8-bugbear + "E", # pycodestyle Error + "F", # Pyflakes + "W", # pycodestyle Warning + "I", # isort + "UP", # pyupgrade + "FBT", # flake8-boolean-trap + "B", # flake8-bugbear + "C4", # flake8-comprehensions "FIX", # flake8-fixme + "TID", # flake8-tidy-imports + "SIM", # flake8-simplify + "ARG", # flake8-unused-arguments + "RET", # flake8-return + "PYI", # flake8-pyi + "PIE", # flake8-pie + "COM", # flake8-commas + "PERF", # Perflint + "FURB", # refurb + "RUF", # Ruff-specific rules ] ignore = [ - "E501", # line-too-long (handled by the formatter) - "E402", # module-import-not-at-top-of-file + "E501", # line-too-long (handled by the formatter) + "E402", # module-import-not-at-top-of-file + "COM812", ] [tool.ruff.lint.isort] diff --git a/src/torchjd/_linalg/__init__.py b/src/torchjd/_linalg/__init__.py index d8489041..3722365e 100644 --- a/src/torchjd/_linalg/__init__.py +++ b/src/torchjd/_linalg/__init__.py @@ -2,13 +2,13 @@ from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor __all__ = [ - "compute_gramian", - "normalize", - "regularize", "Matrix", "PSDMatrix", "PSDTensor", + "compute_gramian", "is_matrix", "is_psd_matrix", "is_psd_tensor", + "normalize", + "regularize", ] diff --git a/src/torchjd/_linalg/_gramian.py b/src/torchjd/_linalg/_gramian.py index edc819dd..58a2af82 100644 --- a/src/torchjd/_linalg/_gramian.py +++ b/src/torchjd/_linalg/_gramian.py @@ -35,7 +35,7 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor: first dimension). """ - contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t.ndim + contracted_dims = contracted_dims if contracted_dims >= 0 else contracted_dims + t.ndim indices_source = list(range(t.ndim - contracted_dims)) indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1)) transposed = t.movedim(indices_source, indices_dest) @@ -70,7 +70,9 @@ def regularize(gramian: PSDMatrix, eps: float) -> PSDMatrix: """ regularization_matrix = eps * torch.eye( - gramian.shape[0], dtype=gramian.dtype, device=gramian.device + gramian.shape[0], + dtype=gramian.dtype, + device=gramian.device, ) output = gramian + regularization_matrix return cast(PSDMatrix, output) diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 9eed9bf7..1d06fdcc 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -81,6 +81,8 @@ from ._weighting_bases import GeneralizedWeighting, Weighting __all__ = [ + "IMTLG", + "MGDA", "Aggregator", "AlignedMTL", "AlignedMTLWeighting", @@ -92,14 +94,12 @@ "Flattening", "GeneralizedWeighting", "GradDrop", - "IMTLG", "IMTLGWeighting", "Krum", "KrumWeighting", + "MGDAWeighting", "Mean", "MeanWeighting", - "MGDA", - "MGDAWeighting", "PCGrad", "PCGradWeighting", "Random", diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 6935199b..2957e58a 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -21,7 +21,7 @@ def _check_is_matrix(matrix: Tensor) -> None: if not is_matrix(matrix): raise ValueError( "Parameter `matrix` should be a tensor of dimension 2. Found `matrix.shape = " - f"{matrix.shape}`." + f"{matrix.shape}`.", ) @abstractmethod @@ -59,13 +59,11 @@ def combine(matrix: Matrix, weights: Tensor) -> Tensor: weights. """ - vector = weights @ matrix - return vector + return weights @ matrix def forward(self, matrix: Matrix) -> Tensor: weights = self.weighting(matrix) - vector = self.combine(matrix, weights) - return vector + return self.combine(matrix, weights) class GramianWeightedAggregator(WeightedAggregator): diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index bf4f8dc0..ececad35 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -68,8 +68,8 @@ def __init__( def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, " - f"scale_mode={repr(self._scale_mode)})" + f"{self.__class__.__name__}(pref_vector={self._pref_vector!r}, " + f"scale_mode={self._scale_mode!r})" ) def __str__(self) -> str: @@ -101,21 +101,19 @@ def __init__( def forward(self, gramian: PSDMatrix, /) -> Tensor: w = self.weighting(gramian) B = self._compute_balance_transformation(gramian, self._scale_mode) - alpha = B @ w - - return alpha + return B @ w @staticmethod def _compute_balance_transformation( - M: Tensor, scale_mode: SUPPORTED_SCALE_MODE = "min" + M: Tensor, + scale_mode: SUPPORTED_SCALE_MODE = "min", ) -> Tensor: lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig tol = torch.max(lambda_) * len(M) * torch.finfo().eps rank = sum(lambda_ > tol) if rank == 0: - identity = torch.eye(len(M), dtype=M.dtype, device=M.device) - return identity + return torch.eye(len(M), dtype=M.dtype, device=M.device) order = torch.argsort(lambda_, dim=-1, descending=True) lambda_, V = lambda_[order][:rank], V[:, order][:, :rank] @@ -130,8 +128,7 @@ def _compute_balance_transformation( scale = lambda_.mean() else: raise ValueError( - f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'." + f"Invalid scale_mode={scale_mode!r}. Expected 'min', 'median', or 'rmse'.", ) - B = scale.sqrt() * V @ sigma_inv @ V.T - return B + return scale.sqrt() * V @ sigma_inv @ V.T diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index 67f94a94..d9341ec9 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -101,6 +101,4 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: # We are approximately on the pareto front weight_array = np.zeros(dimension) - weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype) - - return weights + return torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype) diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py index 980b93b4..d9f13660 100644 --- a/src/torchjd/aggregation/_config.py +++ b/src/torchjd/aggregation/_config.py @@ -70,7 +70,7 @@ def forward(self, matrix: Matrix) -> Tensor: return length * unit_target_vector def __repr__(self) -> str: - return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})" + return f"{self.__class__.__name__}(pref_vector={self._pref_vector!r})" def __str__(self) -> str: return f"ConFIG{pref_vector_to_str_suffix(self._pref_vector)}" diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index f4f062bf..3beb8c18 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -20,7 +20,7 @@ def __init__(self, weights: Tensor): self._weights = weights def __repr__(self) -> str: - return f"{self.__class__.__name__}(weights={repr(self._weights)})" + return f"{self.__class__.__name__}(weights={self._weights!r})" def __str__(self) -> str: weights_str = vector_to_str(self._weights) @@ -39,7 +39,7 @@ def __init__(self, weights: Tensor): if weights.dim() != 1: raise ValueError( "Parameter `weights` should be a 1-dimensional tensor. Found `weights.shape = " - f"{weights.shape}`." + f"{weights.shape}`.", ) super().__init__() @@ -53,5 +53,5 @@ def _check_matrix_shape(self, matrix: Tensor) -> None: if matrix.shape[0] != len(self.weights): raise ValueError( f"Parameter `matrix` should have {len(self.weights)} rows (the number of specified " - f"weights). Found `matrix` with {matrix.shape[0]} rows." + f"weights). Found `matrix` with {matrix.shape[0]} rows.", ) diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 4fc8cefd..82121fed 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -40,7 +40,7 @@ def __init__( self._solver: SUPPORTED_SOLVER = solver super().__init__( - DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver) + DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver), ) # This prevents considering the computed weights as constant w.r.t. the matrix. @@ -48,8 +48,8 @@ def __init__( def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps=" - f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={repr(self._solver)})" + f"{self.__class__.__name__}(pref_vector={self._pref_vector!r}, norm_eps=" + f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={self._solver!r})" ) def __str__(self) -> str: @@ -88,5 +88,4 @@ def __init__( def forward(self, gramian: PSDMatrix, /) -> Tensor: u = self.weighting(gramian) G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - w = project_weights(u, G, self.solver) - return w + return project_weights(u, G, self.solver) diff --git a/src/torchjd/aggregation/_flattening.py b/src/torchjd/aggregation/_flattening.py index 8db6027b..0fc0003d 100644 --- a/src/torchjd/aggregation/_flattening.py +++ b/src/torchjd/aggregation/_flattening.py @@ -29,5 +29,4 @@ def forward(self, generalized_gramian: PSDTensor) -> Tensor: shape = generalized_gramian.shape[:k] square_gramian = flatten(generalized_gramian) weights_vector = self.weighting(square_gramian) - weights = weights_vector.reshape(shape) - return weights + return weights_vector.reshape(shape) diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py index b6ea1327..43ea26b3 100644 --- a/src/torchjd/aggregation/_graddrop.py +++ b/src/torchjd/aggregation/_graddrop.py @@ -30,7 +30,7 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None): if leak is not None and leak.dim() != 1: raise ValueError( "Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = " - f"{leak.shape}`." + f"{leak.shape}`.", ) super().__init__() @@ -64,11 +64,11 @@ def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None: if self.leak is not None and n_rows != len(self.leak): raise ValueError( f"Parameter `matrix` should be a matrix of exactly {len(self.leak)} rows (i.e. the " - f"number of leak scalars). Found `matrix` of shape `{matrix.shape}`." + f"number of leak scalars). Found `matrix` of shape `{matrix.shape}`.", ) def __repr__(self) -> str: - return f"{self.__class__.__name__}(f={repr(self.f)}, leak={repr(self.leak)})" + return f"{self.__class__.__name__}(f={self.f!r}, leak={self.leak!r})" def __str__(self) -> str: if self.leak is None: diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 6355158e..100729ca 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -34,9 +34,4 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: v = torch.linalg.pinv(gramian) @ d v_sum = v.sum() - if v_sum.abs() < 1e-12: - weights = torch.zeros_like(v) - else: - weights = v / v_sum - - return weights + return torch.zeros_like(v) if v_sum.abs() < 1e-12 else v / v_sum diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index bce211c6..3ef6b714 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -49,13 +49,13 @@ def __init__(self, n_byzantine: int, n_selected: int = 1): if n_byzantine < 0: raise ValueError( "Parameter `n_byzantine` should be a non-negative integer. Found `n_byzantine = " - f"{n_byzantine}`." + f"{n_byzantine}`.", ) if n_selected < 1: raise ValueError( "Parameter `n_selected` should be a positive integer. Found `n_selected = " - f"{n_selected}`." + f"{n_selected}`.", ) self.n_byzantine = n_byzantine @@ -76,20 +76,18 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: _, selected_indices = torch.topk(scores, k=self.n_selected, largest=False) one_hot_selected_indices = F.one_hot(selected_indices, num_classes=gramian.shape[0]) - weights = one_hot_selected_indices.sum(dim=0).to(dtype=gramian.dtype) / self.n_selected - - return weights + return one_hot_selected_indices.sum(dim=0).to(dtype=gramian.dtype) / self.n_selected def _check_matrix_shape(self, gramian: PSDMatrix) -> None: min_rows = self.n_byzantine + 3 if gramian.shape[0] < min_rows: raise ValueError( f"Parameter `gramian` should have at least {min_rows} rows (n_byzantine + 3). Found" - f" `gramian` with {gramian.shape[0]} rows." + f" `gramian` with {gramian.shape[0]} rows.", ) if gramian.shape[0] < self.n_selected: raise ValueError( f"Parameter `gramian` should have at least {self.n_selected} rows (n_selected). " - f"Found `gramian` with {gramian.shape[0]} rows." + f"Found `gramian` with {gramian.shape[0]} rows.", ) diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index d7085e10..a56677b1 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -28,5 +28,4 @@ def forward(self, matrix: Tensor, /) -> Tensor: device = matrix.device dtype = matrix.dtype m = matrix.shape[0] - weights = torch.full(size=[m], fill_value=1 / m, device=device, dtype=dtype) - return weights + return torch.full(size=[m], fill_value=1 / m, device=device, dtype=dtype) diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 76862975..3057178f 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -84,7 +84,7 @@ def __init__( max_norm=max_norm, update_weights_every=update_weights_every, optim_niter=optim_niter, - ) + ), ) self._n_tasks = n_tasks self._max_norm = max_norm @@ -144,7 +144,7 @@ def _stop_criteria(self, gtg: np.ndarray, alpha_t: np.ndarray) -> bool: return bool( (self.alpha_param.value is None) or (np.linalg.norm(gtg @ alpha_t - 1 / (alpha_t + 1e-10)) < 1e-3) - or (np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) < 1e-6) + or (np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value) < 1e-6), ) def _solve_optimization(self, gtg: np.ndarray) -> np.ndarray: @@ -177,8 +177,7 @@ def _solve_optimization(self, gtg: np.ndarray) -> np.ndarray: def _calc_phi_alpha_linearization(self) -> Expression: G_prvs_alpha = self.G_param @ self.prvs_alpha_param prvs_phi_tag = 1 / self.prvs_alpha_param + (1 / G_prvs_alpha) @ self.G_param - phi_alpha = prvs_phi_tag @ (self.alpha_param - self.prvs_alpha_param) - return phi_alpha + return prvs_phi_tag @ (self.alpha_param - self.prvs_alpha_param) def _init_optim_problem(self) -> None: self.alpha_param = cp.Variable(shape=(self.n_tasks,), nonneg=True) @@ -189,12 +188,10 @@ def _init_optim_problem(self) -> None: self.phi_alpha = self._calc_phi_alpha_linearization() G_alpha = self.G_param @ self.alpha_param - constraint = [] - for i in range(self.n_tasks): - constraint.append( - -cp.log(self.alpha_param[i] * self.normalization_factor_param) - cp.log(G_alpha[i]) - <= 0 - ) + constraint = [ + -cp.log(a * self.normalization_factor_param) - cp.log(G_a) <= 0 + for a, G_a in zip(self.alpha_param, G_alpha, strict=True) + ] obj = cp.Minimize(cp.sum(G_alpha) + self.phi_alpha / self.normalization_factor_param) self.prob = cp.Problem(obj, constraint) diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index 53ef188c..9502966b 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -28,5 +28,4 @@ class RandomWeighting(Weighting[Matrix]): def forward(self, matrix: Tensor, /) -> Tensor: random_vector = torch.randn(matrix.shape[0], device=matrix.device, dtype=matrix.dtype) - weights = F.softmax(random_vector, dim=-1) - return weights + return F.softmax(random_vector, dim=-1) diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index 0d8bd5d6..13fea7dd 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -26,5 +26,4 @@ class SumWeighting(Weighting[Matrix]): def forward(self, matrix: Tensor, /) -> Tensor: device = matrix.device dtype = matrix.dtype - weights = torch.ones(matrix.shape[0], device=device, dtype=dtype) - return weights + return torch.ones(matrix.shape[0], device=device, dtype=dtype) diff --git a/src/torchjd/aggregation/_trimmed_mean.py b/src/torchjd/aggregation/_trimmed_mean.py index f4e3dfc4..cd8f6dc2 100644 --- a/src/torchjd/aggregation/_trimmed_mean.py +++ b/src/torchjd/aggregation/_trimmed_mean.py @@ -20,7 +20,7 @@ def __init__(self, trim_number: int): if trim_number < 0: raise ValueError( "Parameter `trim_number` should be a non-negative integer. Found `trim_number` = " - f"{trim_number}`." + f"{trim_number}`.", ) self.trim_number = trim_number @@ -32,8 +32,7 @@ def forward(self, matrix: Tensor) -> Tensor: sorted_matrix, _ = torch.sort(matrix, dim=0) trimmed = torch.narrow(sorted_matrix, dim=0, start=self.trim_number, length=n_remaining) - vector = trimmed.mean(dim=0) - return vector + return trimmed.mean(dim=0) def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None: min_rows = 1 + 2 * self.trim_number @@ -41,7 +40,7 @@ def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None: if n_rows < min_rows: raise ValueError( f"Parameter `matrix` should be a matrix of at least {min_rows} rows " - f"(i.e. `2 * trim_number + 1`). Found `matrix` of shape `{matrix.shape}`." + f"(i.e. `2 * trim_number + 1`). Found `matrix` of shape `{matrix.shape}`.", ) def __repr__(self) -> str: diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 6b8ec0f6..71986c79 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -41,7 +41,7 @@ def __init__( self._solver: SUPPORTED_SOLVER = solver super().__init__( - UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver) + UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver), ) # This prevents considering the computed weights as constant w.r.t. the matrix. @@ -49,8 +49,8 @@ def __init__( def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps=" - f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={repr(self._solver)})" + f"{self.__class__.__name__}(pref_vector={self._pref_vector!r}, norm_eps=" + f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={self._solver!r})" ) def __str__(self) -> str: diff --git a/src/torchjd/aggregation/_utils/pref_vector.py b/src/torchjd/aggregation/_utils/pref_vector.py index fea943d7..caffabd9 100644 --- a/src/torchjd/aggregation/_utils/pref_vector.py +++ b/src/torchjd/aggregation/_utils/pref_vector.py @@ -8,7 +8,8 @@ def pref_vector_to_weighting( - pref_vector: Tensor | None, default: Weighting[Matrix] + pref_vector: Tensor | None, + default: Weighting[Matrix], ) -> Weighting[Matrix]: """ Returns the weighting associated to a given preference vector, with a fallback to a default @@ -17,13 +18,12 @@ def pref_vector_to_weighting( if pref_vector is None: return default - else: - if pref_vector.ndim != 1: - raise ValueError( - "Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = " - f"{pref_vector.ndim}`." - ) - return ConstantWeighting(pref_vector) + if pref_vector.ndim != 1: + raise ValueError( + "Parameter `pref_vector` must be a vector (1D Tensor). Found `pref_vector.ndim = " + f"{pref_vector.ndim}`.", + ) + return ConstantWeighting(pref_vector) def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str: @@ -31,5 +31,4 @@ def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str: if pref_vector is None: return "" - else: - return f"([{vector_to_str(pref_vector)}])" + return f"([{vector_to_str(pref_vector)}])" diff --git a/src/torchjd/aggregation/_utils/str.py b/src/torchjd/aggregation/_utils/str.py index 82a04540..24f5b06c 100644 --- a/src/torchjd/aggregation/_utils/str.py +++ b/src/torchjd/aggregation/_utils/str.py @@ -7,5 +7,4 @@ def vector_to_str(vector: Tensor) -> str: `1.23, 1., ...`. """ - weights_str = ", ".join([f"{value:.2f}".rstrip("0") for value in vector]) - return weights_str + return ", ".join([f"{value:.2f}".rstrip("0") for value in vector]) diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index 74129a36..3b037891 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -13,7 +13,7 @@ _FnOutputT = TypeVar("_FnOutputT", bound=Tensor) -class Weighting(Generic[_T], nn.Module, ABC): +class Weighting(nn.Module, ABC, Generic[_T]): r""" Abstract base class for all weighting methods. It has the role of extracting a vector of weights of dimension :math:`m` from some statistic of a matrix of dimension :math:`m \times n`, diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 3e9ddfee..42b08052 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -209,9 +209,7 @@ def _make_gramian_computer(self, module: nn.Module) -> GramianComputer: jacobian_computer = FunctionalJacobianComputer(module) else: jacobian_computer = AutogradJacobianComputer(module) - gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer) - - return gramian_computer + return JacobianBasedGramianComputerWithCrossTerms(jacobian_computer) def _check_module_is_compatible(self, module: nn.Module) -> None: if self._batch_dim is not None: @@ -222,7 +220,7 @@ def _check_module_is_compatible(self, module: nn.Module) -> None: f" are {_MODULES_INCOMPATIBLE_WITH_BATCHED} (and their subclasses). The " f"recommended fix is to replace incompatible layers by something else (e.g. " f"BatchNorm by InstanceNorm). If you really can't and performance is not a " - f"priority, you may also just set `batch_dim=None` when creating the engine." + f"priority, you may also just set `batch_dim=None` when creating the engine.", ) if isinstance(module, _TRACK_RUNNING_STATS_MODULE_TYPES) and module.track_running_stats: raise ValueError( @@ -231,7 +229,7 @@ def _check_module_is_compatible(self, module: nn.Module) -> None: f" to performing in-place operations on tensors and having side-effects during " f"the forward pass. Try setting `track_running_stats` to `False`. If you really" f" can't and performance is not a priority, you may also just set " - f"`batch_dim=None` when creating the engine." + f"`batch_dim=None` when creating the engine.", ) # Currently, the type PSDMatrix is hidden from users, so Tensor is correct. @@ -278,7 +276,7 @@ def compute_gramian(self, output: Tensor) -> Tensor: target_shape = [] if has_non_batch_dim: - target_shape = [-1] + target_shape + target_shape = [-1, *target_shape] reshaped_output = ordered_output.reshape(target_shape) # There are four different cases for the shape of reshaped_output: @@ -290,7 +288,9 @@ def compute_gramian(self, output: Tensor) -> Tensor: self._module_hook_manager.gramian_accumulation_phase.value = True try: - square_gramian = self._compute_square_gramian(reshaped_output, has_non_batch_dim) + square_gramian = self._compute_square_gramian( + reshaped_output, has_non_batch_dim=has_non_batch_dim + ) finally: # Reset everything that has a state, even if the previous call raised an exception self._module_hook_manager.gramian_accumulation_phase.value = False @@ -308,7 +308,7 @@ def compute_gramian(self, output: Tensor) -> Tensor: return gramian - def _compute_square_gramian(self, output: Tensor, has_non_batch_dim: bool) -> PSDMatrix: + def _compute_square_gramian(self, output: Tensor, *, has_non_batch_dim: bool) -> PSDMatrix: leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)})) def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: @@ -333,6 +333,4 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: # If the gramian were None, then leaf_targets would be empty, so autograd.grad would # have failed. So gramian is necessarily a valid Tensor here. - gramian = cast(PSDMatrix, self._gramian_accumulator.gramian) - - return gramian + return cast(PSDMatrix, self._gramian_accumulator.gramian) diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index 8c1546e0..829e5da3 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -73,5 +73,4 @@ def __call__( gramian = compute_gramian(self.summed_jacobian) del self.summed_jacobian return gramian - else: - return None + return None diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py index bfc0ed9e..12c2bfe2 100644 --- a/src/torchjd/autogram/_gramian_utils.py +++ b/src/torchjd/autogram/_gramian_utils.py @@ -75,8 +75,8 @@ def movedim(gramian: PSDTensor, half_source: list[int], half_destination: list[i # Map everything to the range [0, gramian.ndim//2[ half_ndim = gramian.ndim // 2 - half_source_ = [i if 0 <= i else i + half_ndim for i in half_source] - half_destination_ = [i if 0 <= i else i + half_ndim for i in half_destination] + half_source_ = [i if i >= 0 else i + half_ndim for i in half_source] + half_destination_ = [i if i >= 0 else i + half_ndim for i in half_destination] # Mirror the half source and the half destination and use the result to move the dimensions of # the gramian diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index 6929f88e..c2e0eb1c 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -46,7 +46,11 @@ def __call__( ) -> Matrix: # This makes __call__ vmappable. return ComputeModuleJacobians.apply( - self._compute_jacobian, rg_outputs, grad_outputs, args, kwargs + self._compute_jacobian, + rg_outputs, + grad_outputs, + args, + kwargs, ) @abstractmethod @@ -110,8 +114,7 @@ def functional_model_call(rg_params: dict[str, Parameter]) -> tuple[Tensor, ...] ] output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j) flat_outputs = tree_flatten(output)[0] - rg_outputs = tuple(t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad) - return rg_outputs + return tuple(t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad) vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1] @@ -119,8 +122,7 @@ def functional_model_call(rg_params: dict[str, Parameter]) -> tuple[Tensor, ...] # functional has a single primal which is dict(module.named_parameters()). We therefore take # the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters. gradients = vjp_func(grad_outputs_j_)[0] - gradient = torch.cat([t.reshape(-1) for t in gradients.values()]) - return gradient + return torch.cat([t.reshape(-1) for t in gradients.values()]) class AutogradJacobianComputer(JacobianComputer): @@ -155,7 +157,8 @@ class ComputeModuleJacobians(torch.autograd.Function): @staticmethod def forward( compute_jacobian_fn: Callable[ - [tuple[Tensor, ...], tuple[Tensor, ...], tuple[PyTree, ...], dict[str, PyTree]], Matrix + [tuple[Tensor, ...], tuple[Tensor, ...], tuple[PyTree, ...], dict[str, PyTree]], + Matrix, ], rg_outputs: tuple[Tensor, ...], grad_outputs: tuple[Tensor, ...], @@ -163,8 +166,7 @@ def forward( kwargs: dict[str, PyTree], ) -> Tensor: # There is no non-batched dimension - jacobian = compute_jacobian_fn(rg_outputs, grad_outputs, args, kwargs) - return jacobian + return compute_jacobian_fn(rg_outputs, grad_outputs, args, kwargs) @staticmethod def vmap( diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index ef48b784..b86324ca 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -36,7 +36,7 @@ def __init__( ): self._target_edges = target_edges self._gramian_accumulator = gramian_accumulator - self.gramian_accumulation_phase = BoolRef(False) + self.gramian_accumulation_phase = BoolRef(value=False) self._handles: list[TorchRemovableHandle] = [] # When the ModuleHookManager is not referenced anymore, there is no reason to keep the hooks @@ -79,7 +79,7 @@ def remove_hooks(handles: list[TorchRemovableHandle]) -> None: class BoolRef: """Class wrapping a boolean value, acting as a reference to this boolean value.""" - def __init__(self, value: bool): + def __init__(self, *, value: bool): self.value = value def __bool__(self) -> bool: @@ -101,7 +101,7 @@ def __init__( def __call__( self, - module: nn.Module, + _: nn.Module, args: tuple[PyTree, ...], kwargs: dict[str, PyTree], outputs: PyTree, @@ -157,11 +157,11 @@ class AutogramNode(torch.autograd.Function): @staticmethod def forward( - gramian_accumulation_phase: BoolRef, - gramian_computer: GramianComputer, - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], - gramian_accumulator: GramianAccumulator, + _: BoolRef, + __: GramianComputer, + ___: tuple[PyTree, ...], + ____: dict[str, PyTree], + _____: GramianAccumulator, *rg_tensors: Tensor, ) -> tuple[Tensor, ...]: return tuple(t.detach() for t in rg_tensors) diff --git a/src/torchjd/autojac/_accumulation.py b/src/torchjd/autojac/_accumulation.py index bafd3d62..4272da6b 100644 --- a/src/torchjd/autojac/_accumulation.py +++ b/src/torchjd/autojac/_accumulation.py @@ -27,7 +27,7 @@ def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> No raise RuntimeError( f"attempting to assign a jacobian of size '{list(jac.shape)}' to a tensor of " f"size '{list(param.shape)}'. Please ensure that the tensor and each row of the" - " jacobian are the same size" + " jacobian are the same size", ) if is_tensor_with_jac(param): @@ -57,7 +57,7 @@ def _check_expects_grad(tensor: Tensor, field_name: str) -> None: if not _expects_grad(tensor): raise ValueError( f"Cannot populate the {field_name} field of a Tensor that does not satisfy:\n" - "`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`." + "`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`.", ) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index e0188976..486bc2bb 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -9,6 +9,7 @@ def backward( tensors: Sequence[Tensor] | Tensor, inputs: Iterable[Tensor] | None = None, + *, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> None: @@ -86,6 +87,7 @@ def backward( def _create_transform( tensors: OrderedSet[Tensor], inputs: OrderedSet[Tensor], + *, retain_graph: bool, parallel_chunk_size: int | None, ) -> Transform: @@ -98,7 +100,7 @@ def _create_transform( diag = Diagonalize(tensors) # Transform that computes the required Jacobians. - jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph) + jac = Jac(tensors, inputs, chunk_size=parallel_chunk_size, retain_graph=retain_graph) # Transform that accumulates the result in the .jac field of the inputs. accumulate = AccumulateJac() diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py index 1c809d2d..30013fdf 100644 --- a/src/torchjd/autojac/_jac.py +++ b/src/torchjd/autojac/_jac.py @@ -17,6 +17,7 @@ def jac( outputs: Sequence[Tensor] | Tensor, inputs: Iterable[Tensor] | None = None, + *, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> tuple[Tensor, ...]: @@ -136,6 +137,7 @@ def jac( def _create_transform( outputs: OrderedSet[Tensor], inputs: OrderedSet[Tensor], + *, retain_graph: bool, parallel_chunk_size: int | None, ) -> Transform: @@ -146,6 +148,6 @@ def _create_transform( diag = Diagonalize(outputs) # Transform that computes the required Jacobians. - jac = Jac(outputs, inputs, parallel_chunk_size, retain_graph) + jac = Jac(outputs, inputs, chunk_size=parallel_chunk_size, retain_graph=retain_graph) return jac << diag << init diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 352e2655..b9f712e8 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -9,7 +9,10 @@ def jac_to_grad( - tensors: Iterable[Tensor], aggregator: Aggregator, retain_jac: bool = False + tensors: Iterable[Tensor], + aggregator: Aggregator, + *, + retain_jac: bool = False, ) -> None: r""" Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result @@ -56,7 +59,7 @@ def jac_to_grad( if not is_tensor_with_jac(t): raise ValueError( "Some `jac` fields were not populated. Did you use `autojac.backward` or " - "`autojac.mtl_backward` before calling `jac_to_grad`?" + "`autojac.mtl_backward` before calling `jac_to_grad`?", ) tensors_.append(t) @@ -65,7 +68,7 @@ def jac_to_grad( jacobians = [t.jac for t in tensors_] - if not all([jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]]): + if not all(jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]): raise ValueError("All Jacobians should have the same number of rows.") if not retain_jac: @@ -73,22 +76,18 @@ def jac_to_grad( jacobian_matrix = _unite_jacobians(jacobians) gradient_vector = aggregator(jacobian_matrix) - gradients = _disunite_gradient(gradient_vector, jacobians, tensors_) + gradients = _disunite_gradient(gradient_vector, tensors_) accumulate_grads(tensors_, gradients) def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: jacobian_matrices = [jacobian.reshape(jacobian.shape[0], -1) for jacobian in jacobians] - jacobian_matrix = torch.concat(jacobian_matrices, dim=1) - return jacobian_matrix + return torch.concat(jacobian_matrices, dim=1) -def _disunite_gradient( - gradient_vector: Tensor, jacobians: list[Tensor], tensors: list[TensorWithJac] -) -> list[Tensor]: +def _disunite_gradient(gradient_vector: Tensor, tensors: list[TensorWithJac]) -> list[Tensor]: gradient_vectors = gradient_vector.split([t.numel() for t in tensors]) - gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors, strict=True)] - return gradients + return [g.view(t.shape) for g, t in zip(gradient_vectors, tensors, strict=True)] def _free_jacs(tensors: Iterable[TensorWithJac]) -> None: diff --git a/src/torchjd/autojac/_mtl_backward.py b/src/torchjd/autojac/_mtl_backward.py index 831099ed..2655834a 100644 --- a/src/torchjd/autojac/_mtl_backward.py +++ b/src/torchjd/autojac/_mtl_backward.py @@ -21,6 +21,7 @@ def mtl_backward( features: Sequence[Tensor] | Tensor, tasks_params: Sequence[Iterable[Tensor]] | None = None, shared_params: Iterable[Tensor] | None = None, + *, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> None: @@ -113,6 +114,7 @@ def _create_transform( features: OrderedSet[Tensor], tasks_params: list[OrderedSet[Tensor]], shared_params: OrderedSet[Tensor], + *, retain_graph: bool, parallel_chunk_size: int | None, ) -> Transform: @@ -130,7 +132,7 @@ def _create_transform( features, task_params, OrderedSet([loss]), - retain_graph, + retain_graph=retain_graph, ) for task_params, loss in zip(tasks_params, losses, strict=True) ] @@ -140,7 +142,7 @@ def _create_transform( stack = Stack(task_transforms) # Transform that computes the Jacobians of the losses w.r.t. the shared parameters. - jac = Jac(features, shared_params, parallel_chunk_size, retain_graph) + jac = Jac(features, shared_params, chunk_size=parallel_chunk_size, retain_graph=retain_graph) # Transform that accumulates the result in the .jac field of the shared parameters. accumulate = AccumulateJac() @@ -152,6 +154,7 @@ def _create_task_transform( features: OrderedSet[Tensor], task_params: OrderedSet[Tensor], loss: OrderedSet[Tensor], # contains a single scalar loss + *, retain_graph: bool, ) -> Transform: # Tensors with respect to which we compute the gradients. @@ -162,7 +165,7 @@ def _create_task_transform( # Transform that computes the gradients of the loss w.r.t. the task-specific parameters and # the features. - grad = Grad(loss, to_differentiate, retain_graph) + grad = Grad(loss, to_differentiate, retain_graph=retain_graph) # Transform that accumulates the gradients w.r.t. the task-specific parameters into their # .grad fields. @@ -173,8 +176,7 @@ def _create_task_transform( # Transform that accumulates the gradient of the losses w.r.t. the task-specific parameters into # their .grad fields and backpropagates the gradient of the losses w.r.t. to the features. - backward_task = (backpropagate | accumulate) << grad << init - return backward_task + return (backpropagate | accumulate) << grad << init def _check_losses_are_scalar(losses: Iterable[Tensor]) -> None: @@ -184,7 +186,8 @@ def _check_losses_are_scalar(losses: Iterable[Tensor]) -> None: def _check_no_overlap( - shared_params: Iterable[Tensor], tasks_params: Sequence[Iterable[Tensor]] + shared_params: Iterable[Tensor], + tasks_params: Sequence[Iterable[Tensor]], ) -> None: task_param_set = {param for task_params in tasks_params for param in task_params} shared_param_set = set(shared_params) diff --git a/src/torchjd/autojac/_transform/_accumulate.py b/src/torchjd/autojac/_transform/_accumulate.py index 38dfb10e..5ecd772f 100644 --- a/src/torchjd/autojac/_transform/_accumulate.py +++ b/src/torchjd/autojac/_transform/_accumulate.py @@ -1,6 +1,7 @@ from torch import Tensor -from .._accumulation import accumulate_grads, accumulate_jacs +from torchjd.autojac._accumulation import accumulate_grads, accumulate_jacs + from ._base import TensorDict, Transform @@ -17,7 +18,7 @@ def __call__(self, gradients: TensorDict, /) -> TensorDict: accumulate_grads(gradients.keys(), gradients.values()) return {} - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, _: set[Tensor], /) -> set[Tensor]: return set() @@ -34,5 +35,5 @@ def __call__(self, jacobians: TensorDict, /) -> TensorDict: accumulate_jacs(jacobians.keys(), jacobians.values()) return {} - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, _: set[Tensor], /) -> set[Tensor]: return set() diff --git a/src/torchjd/autojac/_transform/_base.py b/src/torchjd/autojac/_transform/_base.py index db8ff2cb..55ca49d2 100644 --- a/src/torchjd/autojac/_transform/_base.py +++ b/src/torchjd/autojac/_transform/_base.py @@ -22,8 +22,6 @@ class RequirementError(ValueError): """Inappropriate set of inputs keys.""" - pass - class Transform(ABC): """ @@ -45,7 +43,7 @@ def __call__(self, input: TensorDict, /) -> TensorDict: """Applies the transform to the input.""" @abstractmethod - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: """ Checks that the provided input_keys satisfy the transform's requirements and returns the corresponding output keys for recursion. @@ -80,10 +78,9 @@ def __call__(self, input: TensorDict, /) -> TensorDict: intermediate = self.inner(input) return self.outer(intermediate) - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: intermediate_keys = self.inner.check_keys(input_keys) - output_keys = self.outer.check_keys(intermediate_keys) - return output_keys + return self.outer.check_keys(intermediate_keys) class Conjunction(Transform): @@ -113,7 +110,7 @@ def __call__(self, tensor_dict: TensorDict, /) -> TensorDict: union |= transform(tensor_dict) return union - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: output_keys_list = [key for t in self.transforms for key in t.check_keys(input_keys)] output_keys = set(output_keys_list) diff --git a/src/torchjd/autojac/_transform/_diagonalize.py b/src/torchjd/autojac/_transform/_diagonalize.py index 88e5525e..ccd6fee0 100644 --- a/src/torchjd/autojac/_transform/_diagonalize.py +++ b/src/torchjd/autojac/_transform/_diagonalize.py @@ -63,16 +63,15 @@ def __init__(self, key_order: OrderedSet[Tensor]): def __call__(self, tensors: TensorDict, /) -> TensorDict: flattened_considered_values = [tensors[key].reshape([-1]) for key in self.key_order] diagonal_matrix = torch.cat(flattened_considered_values).diag() - diagonalized_tensors = { - key: diagonal_matrix[:, begin:end].reshape((-1,) + key.shape) + return { + key: diagonal_matrix[:, begin:end].reshape((-1, *key.shape)) for (begin, end), key in zip(self.indices, self.key_order, strict=True) } - return diagonalized_tensors - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: if not set(self.key_order) == input_keys: raise RequirementError( f"The input_keys must match the key_order. Found input_keys {input_keys} and" - f"key_order {self.key_order}." + f"key_order {self.key_order}.", ) return input_keys diff --git a/src/torchjd/autojac/_transform/_differentiate.py b/src/torchjd/autojac/_transform/_differentiate.py index 3cec097d..a767f51e 100644 --- a/src/torchjd/autojac/_transform/_differentiate.py +++ b/src/torchjd/autojac/_transform/_differentiate.py @@ -29,6 +29,7 @@ def __init__( self, outputs: OrderedSet[Tensor], inputs: OrderedSet[Tensor], + *, retain_graph: bool, create_graph: bool, ): @@ -55,16 +56,16 @@ def _differentiate(self, tensor_outputs: Sequence[Tensor], /) -> tuple[Tensor, . tensor_outputs should be. """ - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: outputs = set(self.outputs) if not outputs == input_keys: raise RequirementError( f"The input_keys must match the expected outputs. Found input_keys {input_keys} and" - f"outputs {outputs}." + f"outputs {outputs}.", ) return set(self.inputs) - def _get_vjp(self, grad_outputs: Sequence[Tensor], retain_graph: bool) -> tuple[Tensor, ...]: + def _get_vjp(self, grad_outputs: Sequence[Tensor], *, retain_graph: bool) -> tuple[Tensor, ...]: optional_grads = torch.autograd.grad( self.outputs, self.inputs, @@ -73,5 +74,4 @@ def _get_vjp(self, grad_outputs: Sequence[Tensor], retain_graph: bool) -> tuple[ create_graph=self.create_graph, allow_unused=True, ) - grads = materialize(optional_grads, inputs=self.inputs) - return grads + return materialize(optional_grads, inputs=self.inputs) diff --git a/src/torchjd/autojac/_transform/_grad.py b/src/torchjd/autojac/_transform/_grad.py index a52b7b15..4c1392eb 100644 --- a/src/torchjd/autojac/_transform/_grad.py +++ b/src/torchjd/autojac/_transform/_grad.py @@ -29,10 +29,11 @@ def __init__( self, outputs: OrderedSet[Tensor], inputs: OrderedSet[Tensor], + *, retain_graph: bool = False, create_graph: bool = False, ): - super().__init__(outputs, inputs, retain_graph, create_graph) + super().__init__(outputs, inputs, retain_graph=retain_graph, create_graph=create_graph) def _differentiate(self, grad_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...]: """ @@ -48,10 +49,9 @@ def _differentiate(self, grad_outputs: Sequence[Tensor], /) -> tuple[Tensor, ... """ if len(self.inputs) == 0: - return tuple() + return () if len(self.outputs) == 0: return tuple(torch.zeros_like(input) for input in self.inputs) - grads = self._get_vjp(grad_outputs, self.retain_graph) - return grads + return self._get_vjp(grad_outputs, retain_graph=self.retain_graph) diff --git a/src/torchjd/autojac/_transform/_init.py b/src/torchjd/autojac/_transform/_init.py index 551f8197..d66fba71 100644 --- a/src/torchjd/autojac/_transform/_init.py +++ b/src/torchjd/autojac/_transform/_init.py @@ -1,4 +1,4 @@ -from collections.abc import Set +from collections.abc import Set as AbstractSet import torch from torch import Tensor @@ -13,15 +13,15 @@ class Init(Transform): :param values: Tensors for which Gradients must be returned. """ - def __init__(self, values: Set[Tensor]): + def __init__(self, values: AbstractSet[Tensor]): self.values = values - def __call__(self, input: TensorDict, /) -> TensorDict: + def __call__(self, _: TensorDict, /) -> TensorDict: return {value: torch.ones_like(value) for value in self.values} - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: if not input_keys == set(): raise RequirementError( - f"The input_keys should be the empty set. Found input_keys {input_keys}." + f"The input_keys should be the empty set. Found input_keys {input_keys}.", ) return set(self.values) diff --git a/src/torchjd/autojac/_transform/_jac.py b/src/torchjd/autojac/_transform/_jac.py index 0783e22a..2ff367cf 100644 --- a/src/torchjd/autojac/_transform/_jac.py +++ b/src/torchjd/autojac/_transform/_jac.py @@ -35,11 +35,12 @@ def __init__( self, outputs: OrderedSet[Tensor], inputs: OrderedSet[Tensor], + *, chunk_size: int | None, retain_graph: bool = False, create_graph: bool = False, ): - super().__init__(outputs, inputs, retain_graph, create_graph) + super().__init__(outputs, inputs, retain_graph=retain_graph, create_graph=create_graph) self.chunk_size = chunk_size def _differentiate(self, jac_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...]: @@ -57,14 +58,14 @@ def _differentiate(self, jac_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...] """ if len(self.inputs) == 0: - return tuple() + return () if len(self.outputs) == 0: return tuple( [ - torch.empty((0,) + input.shape, device=input.device, dtype=input.dtype) + torch.empty((0, *input.shape), device=input.device, dtype=input.dtype) for input in self.inputs - ] + ], ) # If the jac_outputs are correct, this value should be the same for all jac_outputs. @@ -101,7 +102,8 @@ def _differentiate(self, jac_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...] def _get_jacs_chunk( - jac_outputs_chunk: list[Tensor], get_vjp: Callable[[Sequence[Tensor]], tuple[Tensor, ...]] + jac_outputs_chunk: list[Tensor], + get_vjp: Callable[[Sequence[Tensor]], tuple[Tensor, ...]], ) -> tuple[Tensor, ...]: """ Computes the jacobian matrix chunk corresponding to the provided get_vjp function, either by @@ -115,5 +117,4 @@ def _get_jacs_chunk( grad_outputs = [tensor.squeeze(0) for tensor in jac_outputs_chunk] gradients = get_vjp(grad_outputs) return tuple(gradient.unsqueeze(0) for gradient in gradients) - else: - return torch.vmap(get_vjp, chunk_size=chunk_size)(jac_outputs_chunk) + return torch.vmap(get_vjp, chunk_size=chunk_size)(jac_outputs_chunk) diff --git a/src/torchjd/autojac/_transform/_materialize.py b/src/torchjd/autojac/_transform/_materialize.py index 89100168..038565b2 100644 --- a/src/torchjd/autojac/_transform/_materialize.py +++ b/src/torchjd/autojac/_transform/_materialize.py @@ -5,7 +5,8 @@ def materialize( - optional_tensors: Sequence[Tensor | None], inputs: Sequence[Tensor] + optional_tensors: Sequence[Tensor | None], + inputs: Sequence[Tensor], ) -> tuple[Tensor, ...]: """ Transforms a sequence of optional tensors by changing each None by a tensor of zeros of the same diff --git a/src/torchjd/autojac/_transform/_select.py b/src/torchjd/autojac/_transform/_select.py index 29a6bcd2..b2e45caa 100644 --- a/src/torchjd/autojac/_transform/_select.py +++ b/src/torchjd/autojac/_transform/_select.py @@ -1,4 +1,4 @@ -from collections.abc import Set +from collections.abc import Set as AbstractSet from torch import Tensor @@ -12,18 +12,18 @@ class Select(Transform): :param keys: The keys that should be included in the returned subset. """ - def __init__(self, keys: Set[Tensor]): + def __init__(self, keys: AbstractSet[Tensor]): self.keys = keys def __call__(self, tensor_dict: TensorDict, /) -> TensorDict: output = {key: tensor_dict[key] for key in self.keys} return type(tensor_dict)(output) - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: keys = set(self.keys) if not keys.issubset(input_keys): raise RequirementError( f"The input_keys should be a super set of the keys to select. Found input_keys " - f"{input_keys} and keys to select {keys}." + f"{input_keys} and keys to select {keys}.", ) return keys diff --git a/src/torchjd/autojac/_transform/_stack.py b/src/torchjd/autojac/_transform/_stack.py index 1a3fc2ad..328a2a7a 100644 --- a/src/torchjd/autojac/_transform/_stack.py +++ b/src/torchjd/autojac/_transform/_stack.py @@ -25,10 +25,9 @@ def __init__(self, transforms: Sequence[Transform]): def __call__(self, input: TensorDict, /) -> TensorDict: results = [transform(input) for transform in self.transforms] - result = _stack(results) - return result + return _stack(results) - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: return {key for transform in self.transforms for key in transform.check_keys(input_keys)} @@ -40,8 +39,7 @@ def _stack(gradient_dicts: list[TensorDict]) -> TensorDict: for d in gradient_dicts: union |= d unique_keys = union.keys() - result = {key: _stack_one_key(gradient_dicts, key) for key in unique_keys} - return result + return {key: _stack_one_key(gradient_dicts, key) for key in unique_keys} def _stack_one_key(gradient_dicts: list[TensorDict], input: Tensor) -> Tensor: @@ -49,5 +47,4 @@ def _stack_one_key(gradient_dicts: list[TensorDict], input: Tensor) -> Tensor: optional_gradients = [gradients.get(input, None) for gradients in gradient_dicts] gradients = materialize(optional_gradients, [input] * len(optional_gradients)) - jacobian = torch.stack(gradients, dim=0) - return jacobian + return torch.stack(gradients, dim=0) diff --git a/src/torchjd/autojac/_utils.py b/src/torchjd/autojac/_utils.py index 87ae5068..0d029452 100644 --- a/src/torchjd/autojac/_utils.py +++ b/src/torchjd/autojac/_utils.py @@ -12,12 +12,13 @@ def check_optional_positive_chunk_size(parallel_chunk_size: int | None) -> None: if not (parallel_chunk_size is None or parallel_chunk_size > 0): raise ValueError( "`parallel_chunk_size` should be `None` or greater than `0`. (got " - f"{parallel_chunk_size})" + f"{parallel_chunk_size})", ) def as_checked_ordered_set( - tensors: Sequence[Tensor] | Tensor, variable_name: str + tensors: Sequence[Tensor] | Tensor, + variable_name: str, ) -> OrderedSet[Tensor]: if isinstance(tensors, Tensor): tensors = [tensors] @@ -42,10 +43,10 @@ def get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> O """ - if any([tensor.grad_fn is None for tensor in tensors]): + if any(tensor.grad_fn is None for tensor in tensors): raise ValueError("All `tensors` should have a `grad_fn`.") - if any([tensor.grad_fn is None for tensor in excluded]): + if any(tensor.grad_fn is None for tensor in excluded): raise ValueError("All `excluded` tensors should have a `grad_fn`.") accumulate_grads = _get_descendant_accumulate_grads( @@ -55,13 +56,12 @@ def get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> O # accumulate_grads contains instances of AccumulateGrad, which contain a `variable` field. # They cannot be typed as such because AccumulateGrad is not public. - leaves = OrderedSet([g.variable for g in accumulate_grads]) # type: ignore[attr-defined] - - return leaves + return OrderedSet([g.variable for g in accumulate_grads]) # type: ignore[attr-defined] def _get_descendant_accumulate_grads( - roots: OrderedSet[Node], excluded_nodes: set[Node] + roots: OrderedSet[Node], + excluded_nodes: set[Node], ) -> OrderedSet[Node]: """ Gets the AccumulateGrad descendants of the specified nodes. diff --git a/tests/conftest.py b/tests/conftest.py index 5288aa1f..025e08b7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,7 +27,7 @@ def fix_randomness() -> None: # reproducibility on GPU. We also use GPU to benchmark algorithms, and we would rather have them # use non-deterministic but faster algorithms. if DEVICE.type == "cpu": - torch.use_deterministic_algorithms(True) + torch.use_deterministic_algorithms(mode=True) def pytest_addoption(parser): @@ -49,7 +49,7 @@ def pytest_collection_modifyitems(config, items): item.add_marker(xfail_cuda) -def pytest_make_parametrize_id(config, val, argname): +def pytest_make_parametrize_id(config, val, argname): # noqa: ARG001 MAX_SIZE = 40 optional_string = None # Returning None means using pytest's way of making the string diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index cdf26812..283a15e4 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -238,7 +238,7 @@ def __init__(self): self.task2_head = Linear(3, 1) self.automatic_optimization = False - def training_step(self, batch, batch_idx) -> None: + def training_step(self, batch, _) -> None: input, target1, target2 = batch features = self.feature_extractor(input) @@ -256,8 +256,7 @@ def training_step(self, batch, batch_idx) -> None: opt.zero_grad() def configure_optimizers(self) -> OptimizerLRScheduler: - optimizer = Adam(self.parameters(), lr=1e-3) - return optimizer + return Adam(self.parameters(), lr=1e-3) model = Model() diff --git a/tests/plots/_utils.py b/tests/plots/_utils.py index fe844d66..11b44bb4 100644 --- a/tests/plots/_utils.py +++ b/tests/plots/_utils.py @@ -28,7 +28,11 @@ def make_fig(self) -> Figure: for i in range(len(results)): scatter = make_vector_scatter( - results[i], "black", str(self.aggregators[i]), showlegend=True, dash=True + results[i], + "black", + str(self.aggregators[i]), + showlegend=True, + dash=True, ) fig.add_trace(scatter) @@ -55,6 +59,7 @@ def make_vector_scatter( gradient: torch.Tensor, color: str, label: str, + *, showlegend: bool = False, dash: bool = False, textposition: str = "bottom center", @@ -62,32 +67,36 @@ def make_vector_scatter( text_size: float = 12, marker_size: float = 12, ) -> Scatter: - line = dict(color=color, width=line_width) + line = {"color": color, "width": line_width} if dash: line["dash"] = "dash" - scatter = go.Scatter( + return go.Scatter( x=[0, gradient[0]], y=[0, gradient[1]], mode="lines+markers+text", line=line, - marker=dict( - symbol="arrow", - color=color, - size=marker_size, - angleref="previous", - ), + marker={ + "symbol": "arrow", + "color": color, + "size": marker_size, + "angleref": "previous", + }, name=label, text=["", label], textposition=textposition, - textfont=dict(color=color, size=text_size), + textfont={"color": color, "size": text_size}, showlegend=showlegend, ) - return scatter def make_cone_scatter( - start_angle: float, opening: float, label: str, scale: float = 100.0, printable: bool = False + start_angle: float, + opening: float, + label: str, + scale: float = 100.0, + *, + printable: bool = False, ) -> Scatter: if opening < -1e-8: cone_outline = np.zeros([0, 2]) @@ -105,7 +114,7 @@ def make_cone_scatter( start_vec, # Tip of the first vector end_vec, # Tip of the second vector [0, 0], # Back to the origin to close the cone - ] + ], ) else: middle_point = angle_to_coord(middle_angle, scale) @@ -117,60 +126,56 @@ def make_cone_scatter( middle_point, # Tip of the vector in-between end_vec, # Tip of the second vector [0, 0], # Back to the origin to close the cone - ] + ], ) if printable: - fillpattern = dict( - bgcolor="white", shape="\\", fgcolor="rgba(0, 220, 0, 0.5)", size=30, solidity=0.15 - ) + fillpattern = { + "bgcolor": "white", + "shape": "\\", + "fgcolor": "rgba(0, 220, 0, 0.5)", + "size": 30, + "solidity": 0.15, + } else: fillpattern = None - cone = go.Scatter( + return go.Scatter( x=cone_outline[:, 0], y=cone_outline[:, 1], fill="toself", # Fill the area inside the polygon mode="lines", fillcolor="rgba(0, 255, 0, 0.07)", - line=dict(color="rgb(0, 220, 0)", width=2), + line={"color": "rgb(0, 220, 0)", "width": 2}, name=label, fillpattern=fillpattern, ) - return cone - def make_segment_scatter(start: torch.Tensor, end: torch.Tensor) -> Scatter: - segment = go.Scatter( + return go.Scatter( x=[start[0], end[0]], y=[start[1], end[1]], mode="lines", - line=dict( - color="rgb(150, 150, 150)", - width=2.5, - dash="longdash", - ), + line={"color": "rgb(150, 150, 150)", "width": 2.5, "dash": "longdash"}, ) - return segment - def make_polygon_scatter(points: list[torch.Tensor]) -> Scatter: - polygon = go.Scatter( + return go.Scatter( x=[point[0] for point in points], y=[point[1] for point in points], mode="lines", - line=dict( - color="rgb(100, 100, 100)", - width=1.5, - ), + line={"color": "rgb(100, 100, 100)", "width": 1.5}, ) - return polygon def make_right_angle( - vector: torch.Tensor, size: float, positive_para: bool = True, positive_orth: bool = True + vector: torch.Tensor, + size: float, + *, + positive_para: bool = True, + positive_orth: bool = True, ) -> list[torch.Tensor]: vec_para = vector / torch.linalg.norm(vector) * size vec_orth = torch.tensor([-vec_para[1], vec_para[0]]) @@ -242,10 +247,7 @@ def coord_to_angle(x: float, y: float) -> tuple[float, float]: if r == 0: raise ValueError("No angle") - elif y >= 0: - angle = np.arccos(x / r) - else: - angle = 2 * np.pi - np.arccos(x / r) + angle = np.arccos(x / r) if y >= 0 else 2 * np.pi - np.arccos(x / r) return angle, r diff --git a/tests/plots/interactive_plotter.py b/tests/plots/interactive_plotter.py index d78b7fde..0d26a1e6 100644 --- a/tests/plots/interactive_plotter.py +++ b/tests/plots/interactive_plotter.py @@ -39,7 +39,7 @@ def main() -> None: [0.0, 1.0], [1.0, -1.0], [1.0, 0.0], - ] + ], ) aggregators = [ @@ -148,7 +148,8 @@ def update_aggregators(value: list[str]) -> Figure: def make_gradient_div( - i: int, initial_gradient: torch.Tensor + i: int, + initial_gradient: torch.Tensor, ) -> tuple[html.Div, dcc.Input, dcc.Input]: x = initial_gradient[0].item() y = initial_gradient[1].item() diff --git a/tests/profiling/plot_memory_timeline.py b/tests/profiling/plot_memory_timeline.py index 0f792be1..6af5d9ae 100644 --- a/tests/profiling/plot_memory_timeline.py +++ b/tests/profiling/plot_memory_timeline.py @@ -86,7 +86,7 @@ def plot_memory_timelines(experiment: str, folders: list[str]) -> None: ax_cuda.set_ylabel("CUDA Memory (bytes)", fontsize=12) ax_cuda.set_title(f"CUDA Memory Timeline: {experiment}", fontsize=14, fontweight="bold") ax_cuda.legend(loc="best", fontsize=11) - ax_cuda.grid(True, alpha=0.3) + ax_cuda.grid(visible=True, alpha=0.3) ax_cuda.set_ylim(bottom=0) # Plot CPU memory (bottom subplot) @@ -99,7 +99,7 @@ def plot_memory_timelines(experiment: str, folders: list[str]) -> None: ax_cpu.set_ylabel("CPU Memory (bytes)", fontsize=12) ax_cpu.set_title(f"CPU Memory Timeline: {experiment}", fontsize=14, fontweight="bold") ax_cpu.legend(loc="best", fontsize=11) - ax_cpu.grid(True, alpha=0.3) + ax_cpu.grid(visible=True, alpha=0.3) ax_cpu.set_ylim(bottom=0) fig.tight_layout() diff --git a/tests/profiling/run_profiler.py b/tests/profiling/run_profiler.py index b143a55b..7707cb01 100644 --- a/tests/profiling/run_profiler.py +++ b/tests/profiling/run_profiler.py @@ -90,7 +90,10 @@ def _get_profiler_activities() -> list[ProfilerActivity]: def _save_and_print_trace( - prof: profile, method_name: str, factory: ModuleFactory, batch_size: int + prof: profile, + method_name: str, + factory: ModuleFactory, + batch_size: int, ) -> None: filename = f"{factory}-bs{batch_size}-{DEVICE.type}.json" output_dir = TRACES_DIR / method_name diff --git a/tests/profiling/speed_grad_vs_jac_vs_gram.py b/tests/profiling/speed_grad_vs_jac_vs_gram.py index 13b57b62..d68d67aa 100644 --- a/tests/profiling/speed_grad_vs_jac_vs_gram.py +++ b/tests/profiling/speed_grad_vs_jac_vs_gram.py @@ -105,7 +105,11 @@ def post_fn(): print_times("autograd", autograd_times) autograd_gramian_times = time_call( - fn_autograd_gramian, init_fn_autograd_gramian, pre_fn, post_fn, n_runs + fn_autograd_gramian, + init_fn_autograd_gramian, + pre_fn, + post_fn, + n_runs, ) print_times("autograd gramian", autograd_gramian_times) diff --git a/tests/settings.py b/tests/settings.py index b7fe2345..008080b4 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -13,7 +13,7 @@ if _device_str not in _POSSIBLE_TEST_DEVICES: raise ValueError( f"Invalid value of environment variable PYTEST_TORCH_DEVICE: {_device_str}.\n" - f"Possible values: {_POSSIBLE_TEST_DEVICES}." + f"Possible values: {_POSSIBLE_TEST_DEVICES}.", ) if _device_str == "cuda:0" and not torch.cuda.is_available(): @@ -29,7 +29,7 @@ if _dtype_str not in _POSSIBLE_TEST_DTYPES: raise ValueError( f"Invalid value of environment variable PYTEST_TORCH_DTYPE: {_dtype_str}.\n" - f"Possible values: {_POSSIBLE_TEST_DTYPES}." + f"Possible values: {_POSSIBLE_TEST_DTYPES}.", ) DTYPE = getattr(torch, _dtype_str) # "float32" => torch.float32 diff --git a/tests/unit/aggregation/_asserts.py b/tests/unit/aggregation/_asserts.py index 15b69874..8c119674 100644 --- a/tests/unit/aggregation/_asserts.py +++ b/tests/unit/aggregation/_asserts.py @@ -22,7 +22,10 @@ def assert_expected_structure(aggregator: Aggregator, matrix: Tensor) -> None: def assert_non_conflicting( - aggregator: Aggregator, matrix: Tensor, atol: float = 4e-04, rtol: float = 4e-04 + aggregator: Aggregator, + matrix: Tensor, + atol: float = 4e-04, + rtol: float = 4e-04, ) -> None: """Tests empirically that a given `Aggregator` satisfies the non-conflicting property.""" @@ -81,7 +84,9 @@ def assert_linear_under_scaling( def assert_strongly_stationary( - aggregator: Aggregator, matrix: Tensor, threshold: float = 5e-03 + aggregator: Aggregator, + matrix: Tensor, + threshold: float = 5e-03, ) -> None: """ Tests empirically that a given `Aggregator` is strongly stationary. diff --git a/tests/unit/aggregation/_matrix_samplers.py b/tests/unit/aggregation/_matrix_samplers.py index 8a3acd8d..dab23dcd 100644 --- a/tests/unit/aggregation/_matrix_samplers.py +++ b/tests/unit/aggregation/_matrix_samplers.py @@ -42,8 +42,7 @@ def __call__(self, rng: torch.Generator | None = None) -> Tensor: U = _sample_orthonormal_matrix(self.m, rng=rng) Vt = _sample_orthonormal_matrix(self.n, rng=rng) S = torch.diag(torch.abs(randn_([self.rank], generator=rng))) - A = U[:, : self.rank] @ S @ Vt[: self.rank, :] - return A + return U[:, : self.rank] @ S @ Vt[: self.rank, :] class StrongSampler(MatrixSampler): @@ -59,7 +58,7 @@ class StrongSampler(MatrixSampler): def _check_params(self, m: int, n: int, rank: int) -> None: super()._check_params(m, n, rank) - assert 1 < m + assert m > 1 assert 0 < rank <= min(m - 1, n) def __call__(self, rng: torch.Generator | None = None) -> Tensor: @@ -68,8 +67,7 @@ def __call__(self, rng: torch.Generator | None = None) -> Tensor: U2 = _sample_semi_orthonormal_complement(U1, rng=rng) Vt = _sample_orthonormal_matrix(self.n, rng=rng) S = torch.diag(torch.abs(randn_([self.rank], generator=rng))) - A = U2[:, : self.rank] @ S @ Vt[: self.rank, :] - return A + return U2[:, : self.rank] @ S @ Vt[: self.rank, :] class StrictlyWeakSampler(MatrixSampler): @@ -94,7 +92,7 @@ class StrictlyWeakSampler(MatrixSampler): def _check_params(self, m: int, n: int, rank: int) -> None: super()._check_params(m, n, rank) - assert 1 < m + assert m > 1 assert 0 < rank <= min(m - 1, n) def __call__(self, rng: torch.Generator | None = None) -> Tensor: @@ -110,8 +108,7 @@ def __call__(self, rng: torch.Generator | None = None) -> Tensor: U = torch.hstack([U1, U2]) Vt = _sample_orthonormal_matrix(self.n, rng=rng) S = torch.diag(torch.abs(randn_([self.rank], generator=rng))) - A = U[:, 1 : self.rank + 1] @ S @ Vt[: self.rank, :] - return A + return U[:, 1 : self.rank + 1] @ S @ Vt[: self.rank, :] class NonWeakSampler(MatrixSampler): @@ -126,7 +123,7 @@ class NonWeakSampler(MatrixSampler): def _check_params(self, m: int, n: int, rank: int) -> None: super()._check_params(m, n, rank) - assert 0 < rank + assert rank > 0 def __call__(self, rng: torch.Generator | None = None) -> Tensor: u = torch.abs(randn_([self.m], generator=rng)) @@ -135,8 +132,7 @@ def __call__(self, rng: torch.Generator | None = None) -> Tensor: U = torch.hstack([U1, U2]) Vt = _sample_orthonormal_matrix(self.n, rng=rng) S = torch.diag(torch.abs(randn_([self.rank], generator=rng))) - A = U[:, : self.rank] @ S @ Vt[: self.rank, :] - return A + return U[:, : self.rank] @ S @ Vt[: self.rank, :] def _sample_orthonormal_matrix(dim: int, rng: torch.Generator | None = None) -> Tensor: diff --git a/tests/unit/aggregation/test_krum.py b/tests/unit/aggregation/test_krum.py index ff75e5cf..48fa4019 100644 --- a/tests/unit/aggregation/test_krum.py +++ b/tests/unit/aggregation/test_krum.py @@ -62,7 +62,10 @@ def test_n_selected_check(n_selected: int, expectation: ExceptionContext): ], ) def test_matrix_shape_check( - n_byzantine: int, n_selected: int, n_rows: int, expectation: ExceptionContext + n_byzantine: int, + n_selected: int, + n_rows: int, + expectation: ExceptionContext, ): aggregator = Krum(n_byzantine=n_byzantine, n_selected=n_selected) matrix = ones_([n_rows, 5]) diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index d55d87c0..57a9120c 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -52,7 +52,10 @@ def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]): pc_grad_weighting = PCGradWeighting() upgrad_sum_weighting = UPGradWeighting( - ones_((2,)), norm_eps=0.0, reg_eps=0.0, solver="quadprog" + ones_((2,)), + norm_eps=0.0, + reg_eps=0.0, + solver="quadprog", ) result = pc_grad_weighting(gramian) diff --git a/tests/unit/aggregation/test_values.py b/tests/unit/aggregation/test_values.py index 5fed3869..67f24862 100644 --- a/tests/unit/aggregation/test_values.py +++ b/tests/unit/aggregation/test_values.py @@ -39,7 +39,7 @@ [75.0, -666.0, 23], # adversarial row [1.0, 2.0, 3.0], [2.0, 0.0, 1.0], - ] + ], ) J_TrimmedMean = tensor( [ @@ -47,7 +47,7 @@ [1.0, -1e11], [-1e10, 1e10], [2.0, 2.0], - ] + ], ) AGGREGATOR_PARAMETRIZATIONS = [ @@ -101,7 +101,7 @@ J_base, tensor([0.0542, 0.7061, 0.7061]), marks=mark.filterwarnings("ignore::UserWarning"), - ) + ), ) except ImportError: diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 2461e383..6b893721 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -144,7 +144,9 @@ def _assert_gramian_is_equivalent_to_autograd( - factory: ModuleFactory, batch_size: int, batch_dim: int | None + factory: ModuleFactory, + batch_size: int, + batch_dim: int | None, ): model_autograd, model_autogram = factory(), factory() engine = Engine(model_autogram, batch_dim=batch_dim) @@ -208,7 +210,9 @@ def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int @mark.parametrize("batch_size", [1, 3, 32]) @mark.parametrize("batch_dim", [param(0, marks=mark.xfail), None]) def test_compute_gramian_with_weird_modules( - factory: ModuleFactory, batch_size: int, batch_dim: int | None + factory: ModuleFactory, + batch_size: int, + batch_dim: int | None, ): """ Tests that compute_gramian works even with some problematic modules when batch_dim is None. It @@ -230,7 +234,9 @@ def test_compute_gramian_with_weird_modules( @mark.parametrize("batch_size", [1, 3, 32]) @mark.parametrize("batch_dim", [0, None]) def test_compute_gramian_unsupported_architectures( - factory: ModuleFactory, batch_size: int, batch_dim: int | None + factory: ModuleFactory, + batch_size: int, + batch_dim: int | None, ): """ Tests compute_gramian on some architectures that are known to be unsupported. It is expected to @@ -353,7 +359,11 @@ def test_iwrm_steps_with_autogram(factory: ModuleFactory, batch_size: int, batch @mark.parametrize("use_engine", [False, True]) @mark.parametrize("batch_dim", [0, None]) def test_autograd_while_modules_are_hooked( - factory: ModuleFactory, batch_size: int, use_engine: bool, batch_dim: int | None + factory: ModuleFactory, + batch_size: int, + *, + use_engine: bool, + batch_dim: int | None, ): """ Tests that the hooks added when constructing the engine do not interfere with a simple autograd @@ -536,11 +546,11 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): input = randn_([batch_size, input_size]) engine1 = Engine(model1, batch_dim=batch_dim) - output1 = model1(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim) + output1 = model1(input).reshape([batch_size, *non_batched_shape]).movedim(0, batch_dim) gramian1 = engine1.compute_gramian(output1) engine2 = Engine(model2, batch_dim=None) - output2 = model2(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim) + output2 = model2(input).reshape([batch_size, *non_batched_shape]).movedim(0, batch_dim) gramian2 = engine2.compute_gramian(output2) assert_close(gramian1, gramian2) diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py index ac912918..5f74df8d 100644 --- a/tests/unit/autogram/test_gramian_utils.py +++ b/tests/unit/autogram/test_gramian_utils.py @@ -29,8 +29,8 @@ def test_reshape_equivarience(original_shape: list[int], target_shape: list[int]): """Tests that reshape_gramian is such that compute_gramian is equivariant to a reshape.""" - original_matrix = randn_(original_shape + [2]) - target_matrix = original_matrix.reshape(target_shape + [2]) + original_matrix = randn_([*original_shape, 2]) + target_matrix = original_matrix.reshape([*target_shape, 2]) original_gramian = compute_gramian(original_matrix, 1) target_gramian = compute_gramian(target_matrix, 1) @@ -56,7 +56,7 @@ def test_reshape_equivarience(original_shape: list[int], target_shape: list[int] ], ) def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]): - matrix = randn_(original_shape + [2]) + matrix = randn_([*original_shape, 2]) gramian = compute_gramian(matrix, 1) reshaped_gramian = reshape(gramian, target_shape) assert_is_psd_tensor(reshaped_gramian, atol=1e-04, rtol=0.0) @@ -73,7 +73,7 @@ def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]): ], ) def test_flatten_yields_matrix(shape: list[int]): - matrix = randn_(shape + [2]) + matrix = randn_([*shape, 2]) gramian = compute_gramian(matrix, 1) flattened_gramian = flatten(gramian) assert is_psd_matrix(flattened_gramian) @@ -90,7 +90,7 @@ def test_flatten_yields_matrix(shape: list[int]): ], ) def test_flatten_yields_psd(shape: list[int]): - matrix = randn_(shape + [2]) + matrix = randn_([*shape, 2]) gramian = compute_gramian(matrix, 1) flattened_gramian = flatten(gramian) assert_is_psd_matrix(flattened_gramian, atol=1e-04, rtol=0.0) @@ -117,7 +117,7 @@ def test_flatten_yields_psd(shape: list[int]): def test_movedim_equivariance(shape: list[int], source: list[int], destination: list[int]): """Tests that movedim_gramian is such that compute_gramian is equivariant to a movedim.""" - original_matrix = randn_(shape + [2]) + original_matrix = randn_([*shape, 2]) target_matrix = original_matrix.movedim(source, destination) original_gramian = compute_gramian(original_matrix, 1) @@ -147,7 +147,7 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: ], ) def test_movedim_yields_psd(shape: list[int], source: list[int], destination: list[int]): - matrix = randn_(shape + [2]) + matrix = randn_([*shape, 2]) gramian = compute_gramian(matrix, 1) moveddim_gramian = movedim(gramian, source, destination) assert_is_psd_tensor(moveddim_gramian) diff --git a/tests/unit/autojac/_transform/test_accumulate.py b/tests/unit/autojac/_transform/test_accumulate.py index eaa09549..8c179a89 100644 --- a/tests/unit/autojac/_transform/test_accumulate.py +++ b/tests/unit/autojac/_transform/test_accumulate.py @@ -97,7 +97,7 @@ def test_single_jac_accumulation(): shapes = [[], [1], [2, 3]] keys = [zeros_(shape, requires_grad=True) for shape in shapes] - values = [ones_([4] + shape) for shape in shapes] + values = [ones_([4, *shape]) for shape in shapes] input = dict(zip(keys, values, strict=True)) accumulate = AccumulateJac() @@ -118,7 +118,7 @@ def test_multiple_jac_accumulations(iterations: int): shapes = [[], [1], [2, 3]] keys = [zeros_(shape, requires_grad=True) for shape in shapes] - values = [ones_([4] + shape) for shape in shapes] + values = [ones_([4, *shape]) for shape in shapes] accumulate = AccumulateJac() diff --git a/tests/unit/autojac/_transform/test_base.py b/tests/unit/autojac/_transform/test_base.py index 5da475e6..478efc42 100644 --- a/tests/unit/autojac/_transform/test_base.py +++ b/tests/unit/autojac/_transform/test_base.py @@ -17,12 +17,11 @@ def __init__(self, required_keys: set[Tensor], output_keys: set[Tensor]): def __str__(self): return "T" - def __call__(self, input: TensorDict, /) -> TensorDict: + def __call__(self, _: TensorDict, /) -> TensorDict: # Ignore the input, create a dictionary with the right keys as an output. - output_dict = {key: empty_(0) for key in self._output_keys} - return output_dict + return {key: empty_(0) for key in self._output_keys} - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: # Arbitrary requirement for testing purposes. if not input_keys == self._required_keys: raise RequirementError() diff --git a/tests/unit/autojac/_transform/test_interactions.py b/tests/unit/autojac/_transform/test_interactions.py index a712dcef..470f5d6d 100644 --- a/tests/unit/autojac/_transform/test_interactions.py +++ b/tests/unit/autojac/_transform/test_interactions.py @@ -235,9 +235,7 @@ def test_equivalence_jac_grads(): grad_2_A, grad_2_b, grad_2_c = grad_dict_2[A], grad_dict_2[b], grad_dict_2[c] n_outputs = len(outputs) - batched_grad_outputs = [ - zeros_((n_outputs,) + grad_output.shape) for grad_output in grad_outputs - ] + batched_grad_outputs = [zeros_((n_outputs, *grad_output.shape)) for grad_output in grad_outputs] for i, grad_output in enumerate(grad_outputs): batched_grad_outputs[i][i] = grad_output diff --git a/tests/unit/autojac/_transform/test_jac.py b/tests/unit/autojac/_transform/test_jac.py index c00e43d2..e1efecf5 100644 --- a/tests/unit/autojac/_transform/test_jac.py +++ b/tests/unit/autojac/_transform/test_jac.py @@ -106,10 +106,16 @@ def test_retain_graph(): input = {y: eye_(2)} jac_retain_graph = Jac( - outputs=OrderedSet([y]), inputs=OrderedSet([a1, a2]), chunk_size=None, retain_graph=True + outputs=OrderedSet([y]), + inputs=OrderedSet([a1, a2]), + chunk_size=None, + retain_graph=True, ) jac_discard_graph = Jac( - outputs=OrderedSet([y]), inputs=OrderedSet([a1, a2]), chunk_size=None, retain_graph=False + outputs=OrderedSet([y]), + inputs=OrderedSet([a1, a2]), + chunk_size=None, + retain_graph=False, ) jac_retain_graph(input) @@ -140,10 +146,16 @@ def test_two_levels(): input = {z: eye_(2)} outer_jac = Jac( - outputs=OrderedSet([y]), inputs=OrderedSet([a1, a2]), chunk_size=None, retain_graph=True + outputs=OrderedSet([y]), + inputs=OrderedSet([a1, a2]), + chunk_size=None, + retain_graph=True, ) inner_jac = Jac( - outputs=OrderedSet([z]), inputs=OrderedSet([y]), chunk_size=None, retain_graph=True + outputs=OrderedSet([z]), + inputs=OrderedSet([y]), + chunk_size=None, + retain_graph=True, ) composed_jac = outer_jac << inner_jac jac = Jac(outputs=OrderedSet([z]), inputs=OrderedSet([a1, a2]), chunk_size=None) @@ -236,7 +248,10 @@ def test_composition_of_jacs_is_jac(): input = {z1: tensor_([1.0, 0.0]), z2: tensor_([0.0, 1.0])} outer_jac = Jac( - outputs=OrderedSet([y1, y2]), inputs=OrderedSet([a]), chunk_size=None, retain_graph=True + outputs=OrderedSet([y1, y2]), + inputs=OrderedSet([a]), + chunk_size=None, + retain_graph=True, ) inner_jac = Jac( outputs=OrderedSet([z1, z2]), @@ -291,7 +306,10 @@ def test_create_graph(): input = {y: eye_(2)} jac = Jac( - outputs=OrderedSet([y]), inputs=OrderedSet([a1, a2]), chunk_size=None, create_graph=True + outputs=OrderedSet([y]), + inputs=OrderedSet([a1, a2]), + chunk_size=None, + create_graph=True, ) jacobians = jac(input) diff --git a/tests/unit/autojac/_transform/test_stack.py b/tests/unit/autojac/_transform/test_stack.py index fc2cdf7a..d2d4334b 100644 --- a/tests/unit/autojac/_transform/test_stack.py +++ b/tests/unit/autojac/_transform/test_stack.py @@ -15,10 +15,10 @@ class FakeGradientsTransform(Transform): def __init__(self, keys: Iterable[Tensor]): self.keys = set(keys) - def __call__(self, input: TensorDict, /) -> TensorDict: + def __call__(self, _: TensorDict, /) -> TensorDict: return {key: torch.ones_like(key) for key in self.keys} - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, _: set[Tensor], /) -> set[Tensor]: return self.keys diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 0a1ce91d..54056422 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -48,7 +48,7 @@ def test_jac_is_populated(): @mark.parametrize("chunk_size", [1, 2, None]) def test_value_is_correct( shape: tuple[int, int], - manually_specify_inputs: bool, + manually_specify_inputs: bool, # noqa: FBT001 chunk_size: int | None, ): """ @@ -60,10 +60,7 @@ def test_value_is_correct( input = randn_([shape[1]], requires_grad=True) output = J @ input # Note that the Jacobian of output w.r.t. input is J. - if manually_specify_inputs: - inputs = [input] - else: - inputs = None + inputs = [input] if manually_specify_inputs else None backward( [output], diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py index 3a5fb9a4..54f67d3a 100644 --- a/tests/unit/autojac/test_jac.py +++ b/tests/unit/autojac/test_jac.py @@ -52,6 +52,7 @@ def test_jac(): @mark.parametrize("chunk_size", [1, 2, None]) def test_value_is_correct( shape: tuple[int, int], + *, manually_specify_inputs: bool, chunk_size: int | None, ): @@ -64,10 +65,7 @@ def test_value_is_correct( input = randn_([shape[1]], requires_grad=True) output = J @ input # Note that the Jacobian of output w.r.t. input is J. - if manually_specify_inputs: - inputs = [input] - else: - inputs = None + inputs = [input] if manually_specify_inputs else None jacobians = jac( [output], diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index 60ea6838..c1663054 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -86,7 +86,7 @@ def test_no_tensors(): @mark.parametrize("retain_jac", [True, False]) -def test_jacs_are_freed(retain_jac: bool): +def test_jacs_are_freed(*, retain_jac: bool): """Tests that jac_to_grad frees the jac fields if an only if retain_jac is False.""" aggregator = UPGrad() diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index 00bda738..b244b152 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -67,6 +67,7 @@ def test_shape_is_correct(): @mark.parametrize("chunk_size", [1, 2, None]) def test_value_is_correct( shape: tuple[int, int], + *, manually_specify_shared_params: bool, manually_specify_tasks_params: bool, chunk_size: int | None, @@ -90,15 +91,9 @@ def test_value_is_correct( y2 = p2 @ f y3 = p3 @ f - if manually_specify_shared_params: - shared_params = [p0] - else: - shared_params = None + shared_params = [p0] if manually_specify_shared_params else None - if manually_specify_tasks_params: - tasks_params = [[p1], [p2], [p3]] - else: - tasks_params = None + tasks_params = [[p1], [p2], [p3]] if manually_specify_tasks_params else None mtl_backward( losses=[y1, y2, y3], @@ -224,13 +219,13 @@ def test_multiple_params_per_task(): @mark.parametrize( "shared_params_shapes", [ - [tuple()], + [()], [(2,)], [(3, 2)], [(4, 3, 2)], - [tuple(), (2,)], + [(), (2,)], [(3, 2), (2,)], - [(4, 3, 2), (3, 2), tuple()], + [(4, 3, 2), (3, 2), ()], [(5, 4, 3, 2), (5, 4, 3, 2)], ], ) diff --git a/tests/unit/linalg/test_gramian.py b/tests/unit/linalg/test_gramian.py index 6fc22512..53373822 100644 --- a/tests/unit/linalg/test_gramian.py +++ b/tests/unit/linalg/test_gramian.py @@ -57,7 +57,7 @@ def test_compute_gramian_matrix_input_0(): [ [[[1.0, 3.0], [2.0, 4.0]], [[2.0, 6.0], [4.0, 8.0]]], [[[3.0, 9.0], [6.0, 12.0]], [[4.0, 12.0], [8.0, 16.0]]], - ] + ], ) assert_close(gramian, expected) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index f1b98b6d..8d59ac73 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Generic, TypeVar +from typing import ClassVar, Generic, TypeVar import torch import torchvision @@ -48,14 +48,13 @@ def get_in_out_shapes(module: nn.Module) -> tuple[PyTree, PyTree]: if isinstance(module, ShapedModule): return module.INPUT_SHAPES, module.OUTPUT_SHAPES - elif isinstance(module, nn.BatchNorm2d | nn.InstanceNorm2d): + if isinstance(module, nn.BatchNorm2d | nn.InstanceNorm2d): HEIGHT = 6 # Arbitrary choice WIDTH = 6 # Arbitrary choice shape = (module.num_features, HEIGHT, WIDTH) return shape, shape - else: - raise ValueError("Unknown input / output shapes of module", module) + raise ValueError("Unknown input / output shapes of module", module) class OverlyNested(ShapedModule): @@ -103,8 +102,7 @@ def __init__(self): def forward(self, inputs: tuple[Tensor, Tensor]) -> Tensor: input1, input2 = inputs - output = input1 @ self.matrix1 + input2 @ self.matrix2 - return output + return input1 @ self.matrix1 + input2 @ self.matrix2 class MultiInputMultiOutput(ShapedModule): @@ -131,7 +129,7 @@ class SingleInputPyTreeOutput(ShapedModule): """Module taking a single input and returning a complex PyTree of tensors as output.""" INPUT_SHAPES = (50,) - OUTPUT_SHAPES = { + OUTPUT_SHAPES: ClassVar = { "first": ((50,), [(60,), (70,)]), "second": (80,), "third": ([((90,),)],), @@ -156,7 +154,7 @@ def forward(self, input: Tensor) -> PyTree: class PyTreeInputSingleOutput(ShapedModule): """Module taking a complex PyTree of tensors as input and returning a single output.""" - INPUT_SHAPES = { + INPUT_SHAPES: ClassVar = { "one": [((10,), [(20,), (30,)]), (12,)], "two": (14,), } @@ -182,9 +180,7 @@ def forward(self, inputs: PyTree) -> Tensor: output3 = input3 @ self.matrix3 output4 = input4 @ self.matrix4 output5 = input5 @ self.matrix5 - output = torch.concatenate([output1, output2, output3, output4, output5], dim=1) - - return output + return torch.concatenate([output1, output2, output3, output4, output5], dim=1) class PyTreeInputPyTreeOutput(ShapedModule): @@ -193,12 +189,12 @@ class PyTreeInputPyTreeOutput(ShapedModule): output. """ - INPUT_SHAPES = { + INPUT_SHAPES: ClassVar = { "one": [((10,), [(20,), (30,)]), (12,)], "two": (14,), } - OUTPUT_SHAPES = { + OUTPUT_SHAPES: ClassVar = { "first": ((50,), [(60,), (70,)]), "second": (80,), "third": ([((90,),)],), @@ -245,8 +241,7 @@ def forward(self, input: Tensor) -> Tensor: common_input = self.relu(self.fc0(input)) branch1 = self.fc2(self.relu(self.fc1(common_input))) branch2 = self.fc3(common_input) - output = self.fc4(self.relu(branch1 + branch2)) - return output + return self.fc4(self.relu(branch1 + branch2)) class MISOBranched(ShapedModule): @@ -393,7 +388,7 @@ def __init__(self, shape: tuple[int, ...]): self.matrix = nn.Parameter(torch.randn(shape)) def forward(self, _: PyTree) -> PyTree: - return {"one": [None, tuple()], "two": None} + return {"one": [None, ()], "two": None} class _EmptyTupleOutput(nn.Module): def __init__(self, shape: tuple[int, ...]): @@ -401,7 +396,7 @@ def __init__(self, shape: tuple[int, ...]): self.matrix = nn.Parameter(torch.randn(shape)) def forward(self, _: PyTree) -> tuple: - return tuple() + return () class _EmptyPytreeOutput(nn.Module): def __init__(self, shape: tuple[int, ...]): @@ -409,7 +404,7 @@ def __init__(self, shape: tuple[int, ...]): self.matrix = nn.Parameter(torch.randn(shape)) def forward(self, _: PyTree) -> PyTree: - return {"one": [tuple(), tuple()], "two": [[], []]} + return {"one": [(), ()], "two": [[], []]} def __init__(self): super().__init__() @@ -529,7 +524,7 @@ def __init__(self): super().__init__() self.non_frozen = nn.Linear(50, 10) self.all_frozen = nn.Linear(50, 10) - self.all_frozen.requires_grad_(False) + self.all_frozen.requires_grad_(requires_grad=False) def forward(self, input: Tensor) -> Tensor: return self.all_frozen(input) + self.non_frozen(input**2 / 5.0) @@ -662,15 +657,14 @@ def __init__(self): def forward(self, input: Tensor) -> Tensor: _ = self.linear1(input) - output = self.linear2(input) - return output + return self.linear2(input) class Ndim0Output(ShapedModule): """Simple model whose output is a scalar.""" INPUT_SHAPES = (5,) - OUTPUT_SHAPES = tuple() + OUTPUT_SHAPES = () def __init__(self): super().__init__() @@ -808,8 +802,7 @@ def __init__(self): def forward(self, s: str, input: Tensor) -> Tensor: if s == "two": return input @ self.matrix * 2.0 - else: - return input @ self.matrix + return input @ self.matrix class WithModuleWithStringArg(ShapedModule): @@ -1030,8 +1023,7 @@ def forward(self, input: Tensor) -> Tensor: output = self.relu(self.linear1(output)) output = self.relu(self.linear2(output)) output = self.relu(self.linear3(output)) - output = self.linear4(output) - return output + return self.linear4(output) class NoFreeParam(ShapedModule): @@ -1057,8 +1049,7 @@ def forward(self, input: Tensor) -> Tensor: output = self.relu(self.linear1(output)) output = self.relu(self.linear2(output)) output = self.relu(self.linear3(output)) - output = self.linear4(output) - return output + return self.linear4(output) class Cifar10Model(ShapedModule): @@ -1116,8 +1107,7 @@ def __init__(self): def forward(self, input: Tensor) -> Tensor: features = self.body(input) - output = self.head(features) - return output + return self.head(features) class AlexNet(ShapedModule): @@ -1150,7 +1140,7 @@ class InstanceNormResNet18(ShapedModule): def __init__(self): super().__init__() self.resnet18 = torchvision.models.resnet18( - norm_layer=partial(nn.InstanceNorm2d, track_running_stats=False, affine=True) + norm_layer=partial(nn.InstanceNorm2d, track_running_stats=False, affine=True), ) def forward(self, input: Tensor) -> Tensor: @@ -1166,7 +1156,7 @@ class GroupNormMobileNetV3Small(ShapedModule): def __init__(self): super().__init__() self.mobile_net = torchvision.models.mobilenet_v3_small( - norm_layer=partial(nn.GroupNorm, 2, affine=True) + norm_layer=partial(nn.GroupNorm, 2, affine=True), ) def forward(self, input: Tensor) -> Tensor: @@ -1196,7 +1186,7 @@ class InstanceNormMobileNetV2(ShapedModule): def __init__(self): super().__init__() self.mobilenet = torchvision.models.mobilenet_v2( - norm_layer=partial(nn.InstanceNorm2d, track_running_stats=False, affine=True) + norm_layer=partial(nn.InstanceNorm2d, track_running_stats=False, affine=True), ) def forward(self, input: Tensor) -> Tensor: diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index f8b9dfe2..93506cc4 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -71,8 +71,7 @@ def forward_pass( assert tree_map(lambda t: t.shape[1:], output) == expected_output_shapes loss_tensors = loss_fn(output) - losses = reduction(loss_tensors) - return losses + return reduction(loss_tensors) def make_mse_loss_fn(targets: PyTree) -> Callable[[PyTree], list[Tensor]]: @@ -80,13 +79,11 @@ def mse_loss_fn(outputs: PyTree) -> list[Tensor]: flat_outputs, _ = tree_flatten(outputs) flat_targets, _ = tree_flatten(targets) - loss_tensors = [ + return [ mse_loss(output, target, reduction="none") for output, target in zip(flat_outputs, flat_targets, strict=True) ] - return loss_tensors - return mse_loss_fn @@ -111,12 +108,14 @@ def reshape_raw_losses(raw_losses: Tensor) -> Tensor: if raw_losses.ndim == 1: return raw_losses.unsqueeze(1) - else: - return raw_losses.flatten(start_dim=1) + return raw_losses.flatten(start_dim=1) def compute_gramian_with_autograd( - output: Tensor, params: list[nn.Parameter], retain_graph: bool = False + output: Tensor, + params: list[nn.Parameter], + *, + retain_graph: bool = False, ) -> PSDTensor: """ Computes the Gramian of the Jacobian of the outputs with respect to the params using vmapped @@ -137,9 +136,7 @@ def get_vjp(grad_outputs: Tensor) -> list[Tensor]: jacobians = vmap(get_vjp)(torch.diag(torch.ones_like(output))) jacobian_matrices = [jacobian.reshape([jacobian.shape[0], -1]) for jacobian in jacobians] - gramian = sum([jacobian @ jacobian.T for jacobian in jacobian_matrices]) - - return gramian + return sum([jacobian @ jacobian.T for jacobian in jacobian_matrices]) class CloneParams: @@ -198,7 +195,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): for module in self.model.modules(): self._restore_original_params(module) - return False # don’t suppress exceptions + return False # don't suppress exceptions def _restore_original_params(self, module: nn.Module): original_params = self._module_to_original_params.pop(module, {}) diff --git a/tests/utils/tensors.py b/tests/utils/tensors.py index 6c91a08c..7988157d 100644 --- a/tests/utils/tensors.py +++ b/tests/utils/tensors.py @@ -38,6 +38,6 @@ def make_inputs_and_targets(model: nn.Module, batch_size: int) -> tuple[PyTree, def _make_tensors(batch_size: int, tensor_shapes: PyTree) -> PyTree: def is_leaf(s): - return isinstance(s, tuple) and all([isinstance(e, int) for e in s]) + return isinstance(s, tuple) and all(isinstance(e, int) for e in s) - return tree_map(lambda s: randn_((batch_size,) + s), tensor_shapes, is_leaf=is_leaf) + return tree_map(lambda s: randn_((batch_size, *s)), tensor_shapes, is_leaf=is_leaf)