Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tests/unit/aggregation/_asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def assert_linear_under_scaling(
"""Tests empirically that a given `Aggregator` satisfies the linear under scaling property."""

for _ in range(n_runs):
c1 = rand_(matrix.shape[0], dtype=matrix.dtype)
c2 = rand_(matrix.shape[0], dtype=matrix.dtype)
alpha = rand_([], dtype=matrix.dtype)
beta = rand_([], dtype=matrix.dtype)
c1 = rand_(matrix.shape[0])
c2 = rand_(matrix.shape[0])
alpha = rand_([])
beta = rand_([])

x1 = aggregator(torch.diag(c1) @ matrix)
x2 = aggregator(torch.diag(c2) @ matrix)
Expand Down
72 changes: 32 additions & 40 deletions tests/unit/aggregation/_matrix_samplers.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,26 @@
from abc import ABC, abstractmethod

import torch
from settings import DTYPE
from torch import Tensor
from torch.nn.functional import normalize
from utils.tensors import randint_, randn_, randperm_, zeros_


class MatrixSampler(ABC):
"""Abstract base class for sampling matrices of a given shape, rank and dtype."""
"""Abstract base class for sampling matrices of a given shape, rank."""

def __init__(self, m: int, n: int, rank: int, dtype: torch.dtype = DTYPE):
self._check_params(m, n, rank, dtype)
def __init__(self, m: int, n: int, rank: int):
self._check_params(m, n, rank)
self.m = m
self.n = n
self.rank = rank
self.dtype = dtype

def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None:
def _check_params(self, m: int, n: int, rank: int) -> None:
"""Checks that the provided __init__ parameters are acceptable."""

assert m >= 0
assert n >= 0
assert 0 <= rank <= min(m, n)
assert dtype in {torch.float32, torch.float64}

@abstractmethod
def __call__(self, rng: torch.Generator | None = None) -> Tensor:
Expand All @@ -35,24 +32,24 @@ def __repr__(self) -> str:
def __str__(self) -> str:
return (
f"{self.__class__.__name__.replace('MatrixSampler', '')}"
f"({self.m}x{self.n}r{self.rank}:{str(self.dtype)[6:]})"
f"({self.m}x{self.n}r{self.rank})"
)


class NormalSampler(MatrixSampler):
"""Sampler for random normal matrices of shape [m, n] with provided rank and dtype."""
"""Sampler for random normal matrices of shape [m, n] with provided rank."""

def __call__(self, rng: torch.Generator | None = None) -> Tensor:
U = _sample_orthonormal_matrix(self.m, dtype=self.dtype, rng=rng)
Vt = _sample_orthonormal_matrix(self.n, dtype=self.dtype, rng=rng)
S = torch.diag(torch.abs(randn_([self.rank], dtype=self.dtype, generator=rng)))
U = _sample_orthonormal_matrix(self.m, rng=rng)
Vt = _sample_orthonormal_matrix(self.n, rng=rng)
S = torch.diag(torch.abs(randn_([self.rank], generator=rng)))
A = U[:, : self.rank] @ S @ Vt[: self.rank, :]
return A


class StrongSampler(MatrixSampler):
"""
Sampler for random strongly stationary matrices of shape [m, n] with provided rank and dtype.
Sampler for random strongly stationary matrices of shape [m, n] with provided rank.

Definition: A matrix A is said to be strongly stationary if there exists a vector 0 < v such
that v^T A = 0.
Expand All @@ -61,25 +58,24 @@ class StrongSampler(MatrixSampler):
orthogonal to v.
"""

def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None:
super()._check_params(m, n, rank, dtype)
def _check_params(self, m: int, n: int, rank: int) -> None:
super()._check_params(m, n, rank)
assert 1 < m
assert 0 < rank <= min(m - 1, n)

def __call__(self, rng: torch.Generator | None = None) -> Tensor:
v = torch.abs(randn_([self.m], dtype=self.dtype, generator=rng))
v = torch.abs(randn_([self.m], generator=rng))
U1 = normalize(v, dim=0).unsqueeze(1)
U2 = _sample_semi_orthonormal_complement(U1, rng=rng)
Vt = _sample_orthonormal_matrix(self.n, dtype=self.dtype, rng=rng)
S = torch.diag(torch.abs(randn_([self.rank], dtype=self.dtype, generator=rng)))
Vt = _sample_orthonormal_matrix(self.n, rng=rng)
S = torch.diag(torch.abs(randn_([self.rank], generator=rng)))
A = U2[:, : self.rank] @ S @ Vt[: self.rank, :]
return A


class StrictlyWeakSampler(MatrixSampler):
"""
Sampler for random strictly weakly stationary matrices of shape [m, n] with provided rank and
dtype.
Sampler for random strictly weakly stationary matrices of shape [m, n] with provided rank.

Definition: A matrix A is said to be weakly stationary if there exists a vector 0 <= v, v != 0,
such that v^T A = 0.
Expand All @@ -97,60 +93,57 @@ class StrictlyWeakSampler(MatrixSampler):
stationary.
"""

def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None:
super()._check_params(m, n, rank, dtype)
def _check_params(self, m: int, n: int, rank: int) -> None:
super()._check_params(m, n, rank)
assert 1 < m
assert 0 < rank <= min(m - 1, n)

def __call__(self, rng: torch.Generator | None = None) -> Tensor:
u = torch.abs(randn_([self.m], dtype=self.dtype, generator=rng))
u = torch.abs(randn_([self.m], generator=rng))
split_index = randint_(1, self.m, [], generator=rng).item()
shuffled_range = randperm_(self.m, generator=rng)
v = zeros_(self.m, dtype=self.dtype)
v = zeros_(self.m)
v[shuffled_range[:split_index]] = normalize(u[shuffled_range[:split_index]], dim=0)
v_prime = zeros_(self.m, dtype=self.dtype)
v_prime = zeros_(self.m)
v_prime[shuffled_range[split_index:]] = normalize(u[shuffled_range[split_index:]], dim=0)
U1 = torch.stack([v, v_prime]).T
U2 = _sample_semi_orthonormal_complement(U1, rng=rng)
U = torch.hstack([U1, U2])
Vt = _sample_orthonormal_matrix(self.n, dtype=self.dtype, rng=rng)
S = torch.diag(torch.abs(randn_([self.rank], dtype=self.dtype, generator=rng)))
Vt = _sample_orthonormal_matrix(self.n, rng=rng)
S = torch.diag(torch.abs(randn_([self.rank], generator=rng)))
A = U[:, 1 : self.rank + 1] @ S @ Vt[: self.rank, :]
return A


class NonWeakSampler(MatrixSampler):
"""
Sampler for a random non weakly-stationary matrices of shape [m, n] with provided rank and
dtype.
Sampler for a random non weakly-stationary matrices of shape [m, n] with provided rank.

Obtaining such a matrix is done by sampling a positive u, and by then sampling a matrix A that
has u as one of its left-singular vectors, with positive singular value s. Any 0 <= v, v != 0,
satisfies v^T A != 0. Otherwise, we would have 0 = v^T A A^T u = s v^T u > 0, which is a
contradiction. A is thus not weakly stationary.
"""

def _check_params(self, m: int, n: int, rank: int, dtype: torch.dtype) -> None:
super()._check_params(m, n, rank, dtype)
def _check_params(self, m: int, n: int, rank: int) -> None:
super()._check_params(m, n, rank)
assert 0 < rank

def __call__(self, rng: torch.Generator | None = None) -> Tensor:
u = torch.abs(randn_([self.m], dtype=self.dtype, generator=rng))
u = torch.abs(randn_([self.m], generator=rng))
U1 = normalize(u, dim=0).unsqueeze(1)
U2 = _sample_semi_orthonormal_complement(U1, rng=rng)
U = torch.hstack([U1, U2])
Vt = _sample_orthonormal_matrix(self.n, dtype=self.dtype, rng=rng)
S = torch.diag(torch.abs(randn_([self.rank], dtype=self.dtype, generator=rng)))
Vt = _sample_orthonormal_matrix(self.n, rng=rng)
S = torch.diag(torch.abs(randn_([self.rank], generator=rng)))
A = U[:, : self.rank] @ S @ Vt[: self.rank, :]
return A


def _sample_orthonormal_matrix(
dim: int, dtype: torch.dtype, rng: torch.Generator | None = None
) -> Tensor:
def _sample_orthonormal_matrix(dim: int, rng: torch.Generator | None = None) -> Tensor:
"""Uniformly samples a random orthonormal matrix of shape [dim, dim]."""

return _sample_semi_orthonormal_complement(zeros_([dim, 0], dtype=dtype), rng=rng)
return _sample_semi_orthonormal_complement(zeros_([dim, 0]), rng=rng)


def _sample_semi_orthonormal_complement(Q: Tensor, rng: torch.Generator | None = None) -> Tensor:
Expand All @@ -161,9 +154,8 @@ def _sample_semi_orthonormal_complement(Q: Tensor, rng: torch.Generator | None =
:param Q: A semi-orthonormal matrix (i.e. Q^T Q = I) of shape [m, k], with k <= m.
"""

dtype = Q.dtype
m, k = Q.shape
A = randn_([m, m - k], dtype=dtype, generator=rng)
A = randn_([m, m - k], generator=rng)

# project A onto the orthogonal complement of Q
A_proj = A - Q @ (Q.T @ A)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/aggregation/test_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

def _make_aggregator(matrix: Tensor) -> Constant:
n_rows = matrix.shape[0]
weights = tensor_([1.0 / n_rows] * n_rows, dtype=matrix.dtype)
weights = tensor_([1.0 / n_rows] * n_rows)
return Constant(weights)


Expand Down
Loading