Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions src/torchjd/autogram/diagonal_sparse_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ def decorator(func):
class DiagonalSparseTensor(torch.Tensor):

@staticmethod
def __new__(cls, data: Tensor, v_to_p: list[int]):
def __new__(cls, data: Tensor, v_to_p: list[int], unsqueeze_dims: list[int] | None = None):
if unsqueeze_dims is None:
unsqueeze_dims = []

# At the moment, this class is not compositional, so we assert
# that the tensor we're wrapping is exactly a Tensor
assert type(data) is Tensor
Expand All @@ -39,12 +42,18 @@ def __new__(cls, data: Tensor, v_to_p: list[int]):
assert not data.requires_grad or not torch.is_grad_enabled()

shape = [data.shape[i] for i in v_to_p]
for d in unsqueeze_dims:
shape.insert(d, 1)
return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device)

def __init__(self, data: Tensor, v_to_p: list[int]):
def __init__(self, data: Tensor, v_to_p: list[int], unsqueeze_dims: list[int] | None = None):
self.contiguous_data = data # self.data cannot be used here.
self.v_to_p = v_to_p

if unsqueeze_dims is None:
unsqueeze_dims = []
self.unsqueeze_dims = unsqueeze_dims

def to_dense(self) -> Tensor:
if self.contiguous_data.ndim == 0:
return self.contiguous_data
Expand All @@ -54,10 +63,17 @@ def to_dense(self) -> Tensor:
p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij")
v_indices_grid = tuple(p_indices_grid[i] for i in self.v_to_p)

shape_before_unsqueeze = [self.contiguous_data.shape[i] for i in self.v_to_p]
res = torch.zeros(
self.shape, device=self.contiguous_data.device, dtype=self.contiguous_data.dtype
shape_before_unsqueeze,
device=self.contiguous_data.device,
dtype=self.contiguous_data.dtype,
)
res[v_indices_grid] = self.contiguous_data

for d in self.unsqueeze_dims:
res.unsqueeze_(d)

return res

@classmethod
Expand All @@ -84,15 +100,17 @@ def __repr__(self):
)


def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor:
def diagonal_sparse_tensor(
data: Tensor, v_to_p: list[int], unsqueeze_dims: list[int] | None = None
) -> Tensor:
if not all(0 <= i < data.ndim for i in v_to_p):
raise ValueError(f"Elements in v_to_p map to dimensions in data. Found {v_to_p}.")
if len(set(v_to_p)) != data.ndim:
raise ValueError("Every dimension in data must appear at least once in v_to_p.")
if len(v_to_p) == data.ndim:
return torch.movedim(data, (list(range(data.ndim))), v_to_p)
else:
return DiagonalSparseTensor(data, v_to_p)
return DiagonalSparseTensor(data, v_to_p, unsqueeze_dims)


# pointwise functions applied to one Tensor with `0.0 → 0`
Expand Down Expand Up @@ -194,3 +212,9 @@ def mean(t: Tensor) -> Tensor:
def sum(t: Tensor) -> Tensor:
assert isinstance(t, DiagonalSparseTensor)
return aten.sum.default(t.contiguous_data)


@implements(aten.unsqueeze.default)
def unsqueeze(t: Tensor, dim: int) -> Tensor:
assert isinstance(t, DiagonalSparseTensor)
return diagonal_sparse_tensor(t.contiguous_data, t.v_to_p, t.unsqueeze_dims + [dim])
Loading