diff --git a/src/torchjd/autogram/diagonal_sparse_tensor.py b/src/torchjd/autogram/diagonal_sparse_tensor.py
index c12134db..93af6585 100644
--- a/src/torchjd/autogram/diagonal_sparse_tensor.py
+++ b/src/torchjd/autogram/diagonal_sparse_tensor.py
@@ -23,7 +23,10 @@ def decorator(func):
 
 class DiagonalSparseTensor(torch.Tensor):
     @staticmethod
-    def __new__(cls, data: Tensor, v_to_p: list[int]):
+    def __new__(cls, data: Tensor, v_to_p: list[int], unsqueeze_dims: list[int] | None = None):
+        if unsqueeze_dims is None:
+            unsqueeze_dims = []
+
         # At the moment, this class is not compositional, so we assert
         # that the tensor we're wrapping is exactly a Tensor
         assert type(data) is Tensor
@@ -39,12 +42,18 @@ def __new__(cls, data: Tensor, v_to_p: list[int]):
         assert not data.requires_grad or not torch.is_grad_enabled()
 
         shape = [data.shape[i] for i in v_to_p]
+        for d in unsqueeze_dims:
+            shape.insert(d, 1)
         return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device)
 
-    def __init__(self, data: Tensor, v_to_p: list[int]):
+    def __init__(self, data: Tensor, v_to_p: list[int], unsqueeze_dims: list[int] | None = None):
         self.contiguous_data = data  # self.data cannot be used here.
         self.v_to_p = v_to_p
 
+        if unsqueeze_dims is None:
+            unsqueeze_dims = []
+        self.unsqueeze_dims = unsqueeze_dims
+
     def to_dense(self) -> Tensor:
         if self.contiguous_data.ndim == 0:
             return self.contiguous_data
@@ -54,10 +63,17 @@ def to_dense(self) -> Tensor:
         p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij")
         v_indices_grid = tuple(p_indices_grid[i] for i in self.v_to_p)
 
+        shape_before_unsqueeze = [self.contiguous_data.shape[i] for i in self.v_to_p]
         res = torch.zeros(
-            self.shape, device=self.contiguous_data.device, dtype=self.contiguous_data.dtype
+            shape_before_unsqueeze,
+            device=self.contiguous_data.device,
+            dtype=self.contiguous_data.dtype,
         )
         res[v_indices_grid] = self.contiguous_data
+
+        for d in self.unsqueeze_dims:
+            res.unsqueeze_(d)
+
         return res
 
     @classmethod
@@ -84,7 +100,9 @@ def __repr__(self):
         )
 
 
-def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor:
+def diagonal_sparse_tensor(
+    data: Tensor, v_to_p: list[int], unsqueeze_dims: list[int] | None = None
+) -> Tensor:
     if not all(0 <= i < data.ndim for i in v_to_p):
         raise ValueError(f"Elements in v_to_p map to dimensions in data. Found {v_to_p}.")
     if len(set(v_to_p)) != data.ndim:
@@ -92,7 +110,7 @@ def diagonal_sparse_tensor(data: Tensor, v_to_p: list[int]) -> Tensor:
     if len(v_to_p) == data.ndim:
         return torch.movedim(data, (list(range(data.ndim))), v_to_p)
     else:
-        return DiagonalSparseTensor(data, v_to_p)
+        return DiagonalSparseTensor(data, v_to_p, unsqueeze_dims)
 
 
 # pointwise functions applied to one Tensor with `0.0 → 0`
@@ -194,3 +212,9 @@ def mean(t: Tensor) -> Tensor:
 def sum(t: Tensor) -> Tensor:
     assert isinstance(t, DiagonalSparseTensor)
     return aten.sum.default(t.contiguous_data)
+
+
+@implements(aten.unsqueeze.default)
+def unsqueeze(t: Tensor, dim: int) -> Tensor:
+    assert isinstance(t, DiagonalSparseTensor)
+    return diagonal_sparse_tensor(t.contiguous_data, t.v_to_p, t.unsqueeze_dims + [dim])
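
For reviewers, a minimal usage sketch of the new behavior. This is not part of the patch: it assumes the module is importable from the file path above, and that `DiagonalSparseTensor` routes `aten.unsqueeze.default` through the `implements` registry via `__torch_dispatch__`, as the rest of the module suggests.

```python
import torch

# Hypothetical import path, derived from the file path in this diff.
from torchjd.autogram.diagonal_sparse_tensor import diagonal_sparse_tensor

# Physical storage: shape (2, 3).
data = torch.arange(6.0).reshape(2, 3)

# Virtual dims 0 and 2 both map to physical dim 0, so the virtual tensor has
# shape (2, 3, 2), with data[i, j] at position [i, j, i] and zeros elsewhere.
t = diagonal_sparse_tensor(data, v_to_p=[0, 1, 0])

# With this patch, unsqueeze stays sparse: the size-1 dim is recorded in
# `unsqueeze_dims` instead of densifying the tensor, and is only materialized
# by `to_dense()`.
u = t.unsqueeze(0)
assert u.shape == (1, 2, 3, 2)
assert torch.equal(u.to_dense(), t.to_dense().unsqueeze(0))
```

One design note: because `__new__` inserts the size-1 dims by iterating over `unsqueeze_dims` in order, appending `dim` in the new `aten.unsqueeze.default` handler keeps repeated unsqueezes consistent with their dense counterparts, since each insertion is interpreted relative to the shape produced by the previous ones, and `to_dense()` applies `unsqueeze_` in the same order.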