From 896e55d300e01893336eb26a4c6362a345fe5cdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 21 Dec 2025 17:54:07 +0100 Subject: [PATCH 1/2] test: Remove dtype arg from MatrixSampler --- tests/unit/aggregation/_matrix_samplers.py | 62 ++++++++++------------ 1 file changed, 29 insertions(+), 33 deletions(-) diff --git a/tests/unit/aggregation/_matrix_samplers.py b/tests/unit/aggregation/_matrix_samplers.py index f24e2028..698e2d15 100644 --- a/tests/unit/aggregation/_matrix_samplers.py +++ b/tests/unit/aggregation/_matrix_samplers.py @@ -8,22 +8,20 @@ class MatrixSampler(ABC): - """Abstract base class for sampling matrices of a given shape, rank and dtype.""" + """Abstract base class for sampling matrices of a given shape, rank.""" - def __init__(self, m: int, n: int, rank: int, dtype: torch.dtype = DTYPE): - self._check_params(m, n, rank, dtype) + def __init__(self, m: int, n: int, rank: int): + self._check_params(m, n, rank) self.m = m self.n = n self.rank = rank - self.dtype = dtype - def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None: + def _check_params(self, m: int, n: int, rank: int) -> None: """Checks that the provided __init__ parameters are acceptable.""" assert m >= 0 assert n >= 0 assert 0 <= rank <= min(m, n) - assert dtype in {torch.float32, torch.float64} @abstractmethod def __call__(self, rng: torch.Generator | None = None) -> Tensor: @@ -35,24 +33,24 @@ def __repr__(self) -> str: def __str__(self) -> str: return ( f"{self.__class__.__name__.replace('MatrixSampler', '')}" - f"({self.m}x{self.n}r{self.rank}:{str(self.dtype)[6:]})" + f"({self.m}x{self.n}r{self.rank})" ) class NormalSampler(MatrixSampler): - """Sampler for random normal matrices of shape [m, n] with provided rank and dtype.""" + """Sampler for random normal matrices of shape [m, n] with provided rank.""" def __call__(self, rng: torch.Generator | None = None) -> Tensor: - U = _sample_orthonormal_matrix(self.m, dtype=self.dtype, rng=rng) - Vt = _sample_orthonormal_matrix(self.n, dtype=self.dtype, rng=rng) - S = torch.diag(torch.abs(randn_([self.rank], dtype=self.dtype, generator=rng))) + U = _sample_orthonormal_matrix(self.m, dtype=DTYPE, rng=rng) + Vt = _sample_orthonormal_matrix(self.n, dtype=DTYPE, rng=rng) + S = torch.diag(torch.abs(randn_([self.rank], dtype=DTYPE, generator=rng))) A = U[:, : self.rank] @ S @ Vt[: self.rank, :] return A class StrongSampler(MatrixSampler): """ - Sampler for random strongly stationary matrices of shape [m, n] with provided rank and dtype. + Sampler for random strongly stationary matrices of shape [m, n] with provided rank. Definition: A matrix A is said to be strongly stationary if there exists a vector 0 < v such that v^T A = 0. @@ -61,25 +59,24 @@ class StrongSampler(MatrixSampler): orthogonal to v. """ - def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None: - super()._check_params(m, n, rank, dtype) + def _check_params(self, m: int, n: int, rank: int) -> None: + super()._check_params(m, n, rank) assert 1 < m assert 0 < rank <= min(m - 1, n) def __call__(self, rng: torch.Generator | None = None) -> Tensor: - v = torch.abs(randn_([self.m], dtype=self.dtype, generator=rng)) + v = torch.abs(randn_([self.m], dtype=DTYPE, generator=rng)) U1 = normalize(v, dim=0).unsqueeze(1) U2 = _sample_semi_orthonormal_complement(U1, rng=rng) - Vt = _sample_orthonormal_matrix(self.n, dtype=self.dtype, rng=rng) - S = torch.diag(torch.abs(randn_([self.rank], dtype=self.dtype, generator=rng))) + Vt = _sample_orthonormal_matrix(self.n, dtype=DTYPE, rng=rng) + S = torch.diag(torch.abs(randn_([self.rank], dtype=DTYPE, generator=rng))) A = U2[:, : self.rank] @ S @ Vt[: self.rank, :] return A class StrictlyWeakSampler(MatrixSampler): """ - Sampler for random strictly weakly stationary matrices of shape [m, n] with provided rank and - dtype. + Sampler for random strictly weakly stationary matrices of shape [m, n] with provided rank. Definition: A matrix A is said to be weakly stationary if there exists a vector 0 <= v, v != 0, such that v^T A = 0. @@ -97,32 +94,31 @@ class StrictlyWeakSampler(MatrixSampler): stationary. """ - def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None: - super()._check_params(m, n, rank, dtype) + def _check_params(self, m: int, n: int, rank: int) -> None: + super()._check_params(m, n, rank) assert 1 < m assert 0 < rank <= min(m - 1, n) def __call__(self, rng: torch.Generator | None = None) -> Tensor: - u = torch.abs(randn_([self.m], dtype=self.dtype, generator=rng)) + u = torch.abs(randn_([self.m], dtype=DTYPE, generator=rng)) split_index = randint_(1, self.m, [], generator=rng).item() shuffled_range = randperm_(self.m, generator=rng) - v = zeros_(self.m, dtype=self.dtype) + v = zeros_(self.m, dtype=DTYPE) v[shuffled_range[:split_index]] = normalize(u[shuffled_range[:split_index]], dim=0) - v_prime = zeros_(self.m, dtype=self.dtype) + v_prime = zeros_(self.m, dtype=DTYPE) v_prime[shuffled_range[split_index:]] = normalize(u[shuffled_range[split_index:]], dim=0) U1 = torch.stack([v, v_prime]).T U2 = _sample_semi_orthonormal_complement(U1, rng=rng) U = torch.hstack([U1, U2]) - Vt = _sample_orthonormal_matrix(self.n, dtype=self.dtype, rng=rng) - S = torch.diag(torch.abs(randn_([self.rank], dtype=self.dtype, generator=rng))) + Vt = _sample_orthonormal_matrix(self.n, dtype=DTYPE, rng=rng) + S = torch.diag(torch.abs(randn_([self.rank], dtype=DTYPE, generator=rng))) A = U[:, 1 : self.rank + 1] @ S @ Vt[: self.rank, :] return A class NonWeakSampler(MatrixSampler): """ - Sampler for a random non weakly-stationary matrices of shape [m, n] with provided rank and - dtype. + Sampler for a random non weakly-stationary matrices of shape [m, n] with provided rank. Obtaining such a matrix is done by sampling a positive u, and by then sampling a matrix A that has u as one of its left-singular vectors, with positive singular value s. Any 0 <= v, v != 0, @@ -130,17 +126,17 @@ class NonWeakSampler(MatrixSampler): contradiction. A is thus not weakly stationary. """ - def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None: - super()._check_params(m, n, rank, dtype) + def _check_params(self, m: int, n: int, rank: int) -> None: + super()._check_params(m, n, rank) assert 0 < rank def __call__(self, rng: torch.Generator | None = None) -> Tensor: - u = torch.abs(randn_([self.m], dtype=self.dtype, generator=rng)) + u = torch.abs(randn_([self.m], dtype=DTYPE, generator=rng)) U1 = normalize(u, dim=0).unsqueeze(1) U2 = _sample_semi_orthonormal_complement(U1, rng=rng) U = torch.hstack([U1, U2]) - Vt = _sample_orthonormal_matrix(self.n, dtype=self.dtype, rng=rng) - S = torch.diag(torch.abs(randn_([self.rank], dtype=self.dtype, generator=rng))) + Vt = _sample_orthonormal_matrix(self.n, dtype=DTYPE, rng=rng) + S = torch.diag(torch.abs(randn_([self.rank], dtype=DTYPE, generator=rng))) A = U[:, : self.rank] @ S @ Vt[: self.rank, :] return A From bc826a238087cb3226e1da9f3cb464983e1a1020 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 21 Dec 2025 18:06:46 +0100 Subject: [PATCH 2/2] test: Minimize the number of dtype kwargs --- tests/unit/aggregation/_asserts.py | 8 ++--- tests/unit/aggregation/_matrix_samplers.py | 38 ++++++++++------------ tests/unit/aggregation/test_constant.py | 2 +- 3 files changed, 22 insertions(+), 26 deletions(-) diff --git a/tests/unit/aggregation/_asserts.py b/tests/unit/aggregation/_asserts.py index 332bd5b8..15b69874 100644 --- a/tests/unit/aggregation/_asserts.py +++ b/tests/unit/aggregation/_asserts.py @@ -67,10 +67,10 @@ def assert_linear_under_scaling( """Tests empirically that a given `Aggregator` satisfies the linear under scaling property.""" for _ in range(n_runs): - c1 = rand_(matrix.shape[0], dtype=matrix.dtype) - c2 = rand_(matrix.shape[0], dtype=matrix.dtype) - alpha = rand_([], dtype=matrix.dtype) - beta = rand_([], dtype=matrix.dtype) + c1 = rand_(matrix.shape[0]) + c2 = rand_(matrix.shape[0]) + alpha = rand_([]) + beta = rand_([]) x1 = aggregator(torch.diag(c1) @ matrix) x2 = aggregator(torch.diag(c2) @ matrix) diff --git a/tests/unit/aggregation/_matrix_samplers.py b/tests/unit/aggregation/_matrix_samplers.py index 698e2d15..a106e56f 100644 --- a/tests/unit/aggregation/_matrix_samplers.py +++ b/tests/unit/aggregation/_matrix_samplers.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod import torch -from settings import DTYPE from torch import Tensor from torch.nn.functional import normalize from utils.tensors import randint_, randn_, randperm_, zeros_ @@ -41,9 +40,9 @@ class NormalSampler(MatrixSampler): """Sampler for random normal matrices of shape [m, n] with provided rank.""" def __call__(self, rng: torch.Generator | None = None) -> Tensor: - U = _sample_orthonormal_matrix(self.m, dtype=DTYPE, rng=rng) - Vt = _sample_orthonormal_matrix(self.n, dtype=DTYPE, rng=rng) - S = torch.diag(torch.abs(randn_([self.rank], dtype=DTYPE, generator=rng))) + U = _sample_orthonormal_matrix(self.m, rng=rng) + Vt = _sample_orthonormal_matrix(self.n, rng=rng) + S = torch.diag(torch.abs(randn_([self.rank], generator=rng))) A = U[:, : self.rank] @ S @ Vt[: self.rank, :] return A @@ -65,11 +64,11 @@ def _check_params(self, m: int, n: int, rank: int) -> None: assert 0 < rank <= min(m - 1, n) def __call__(self, rng: torch.Generator | None = None) -> Tensor: - v = torch.abs(randn_([self.m], dtype=DTYPE, generator=rng)) + v = torch.abs(randn_([self.m], generator=rng)) U1 = normalize(v, dim=0).unsqueeze(1) U2 = _sample_semi_orthonormal_complement(U1, rng=rng) - Vt = _sample_orthonormal_matrix(self.n, dtype=DTYPE, rng=rng) - S = torch.diag(torch.abs(randn_([self.rank], dtype=DTYPE, generator=rng))) + Vt = _sample_orthonormal_matrix(self.n, rng=rng) + S = torch.diag(torch.abs(randn_([self.rank], generator=rng))) A = U2[:, : self.rank] @ S @ Vt[: self.rank, :] return A @@ -100,18 +99,18 @@ def _check_params(self, m: int, n: int, rank: int) -> None: assert 0 < rank <= min(m - 1, n) def __call__(self, rng: torch.Generator | None = None) -> Tensor: - u = torch.abs(randn_([self.m], dtype=DTYPE, generator=rng)) + u = torch.abs(randn_([self.m], generator=rng)) split_index = randint_(1, self.m, [], generator=rng).item() shuffled_range = randperm_(self.m, generator=rng) - v = zeros_(self.m, dtype=DTYPE) + v = zeros_(self.m) v[shuffled_range[:split_index]] = normalize(u[shuffled_range[:split_index]], dim=0) - v_prime = zeros_(self.m, dtype=DTYPE) + v_prime = zeros_(self.m) v_prime[shuffled_range[split_index:]] = normalize(u[shuffled_range[split_index:]], dim=0) U1 = torch.stack([v, v_prime]).T U2 = _sample_semi_orthonormal_complement(U1, rng=rng) U = torch.hstack([U1, U2]) - Vt = _sample_orthonormal_matrix(self.n, dtype=DTYPE, rng=rng) - S = torch.diag(torch.abs(randn_([self.rank], dtype=DTYPE, generator=rng))) + Vt = _sample_orthonormal_matrix(self.n, rng=rng) + S = torch.diag(torch.abs(randn_([self.rank], generator=rng))) A = U[:, 1 : self.rank + 1] @ S @ Vt[: self.rank, :] return A @@ -131,22 +130,20 @@ def _check_params(self, m: int, n: int, rank: int) -> None: assert 0 < rank def __call__(self, rng: torch.Generator | None = None) -> Tensor: - u = torch.abs(randn_([self.m], dtype=DTYPE, generator=rng)) + u = torch.abs(randn_([self.m], generator=rng)) U1 = normalize(u, dim=0).unsqueeze(1) U2 = _sample_semi_orthonormal_complement(U1, rng=rng) U = torch.hstack([U1, U2]) - Vt = _sample_orthonormal_matrix(self.n, dtype=DTYPE, rng=rng) - S = torch.diag(torch.abs(randn_([self.rank], dtype=DTYPE, generator=rng))) + Vt = _sample_orthonormal_matrix(self.n, rng=rng) + S = torch.diag(torch.abs(randn_([self.rank], generator=rng))) A = U[:, : self.rank] @ S @ Vt[: self.rank, :] return A -def _sample_orthonormal_matrix( - dim: int, dtype: torch.dtype, rng: torch.Generator | None = None -) -> Tensor: +def _sample_orthonormal_matrix(dim: int, rng: torch.Generator | None = None) -> Tensor: """Uniformly samples a random orthonormal matrix of shape [dim, dim].""" - return _sample_semi_orthonormal_complement(zeros_([dim, 0], dtype=dtype), rng=rng) + return _sample_semi_orthonormal_complement(zeros_([dim, 0]), rng=rng) def _sample_semi_orthonormal_complement(Q: Tensor, rng: torch.Generator | None = None) -> Tensor: @@ -157,9 +154,8 @@ def _sample_semi_orthonormal_complement(Q: Tensor, rng: torch.Generator | None = :param Q: A semi-orthonormal matrix (i.e. Q^T Q = I) of shape [m, k], with k <= m. """ - dtype = Q.dtype m, k = Q.shape - A = randn_([m, m - k], dtype=dtype, generator=rng) + A = randn_([m, m - k], generator=rng) # project A onto the orthogonal complement of Q A_proj = A - Q @ (Q.T @ A) diff --git a/tests/unit/aggregation/test_constant.py b/tests/unit/aggregation/test_constant.py index dd860e23..984b7a1f 100644 --- a/tests/unit/aggregation/test_constant.py +++ b/tests/unit/aggregation/test_constant.py @@ -18,7 +18,7 @@ def _make_aggregator(matrix: Tensor) -> Constant: n_rows = matrix.shape[0] - weights = tensor_([1.0 / n_rows] * n_rows, dtype=matrix.dtype) + weights = tensor_([1.0 / n_rows] * n_rows) return Constant(weights)