diff --git a/tinygrad/mixin/movement.py b/tinygrad/mixin/movement.py index dc03177ec9d4b..ec1998ade0e8f 100644 --- a/tinygrad/mixin/movement.py +++ b/tinygrad/mixin/movement.py @@ -522,9 +522,11 @@ def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]|None=None) - if dims is None: return self.flatten().roll(shifts, 0).reshape(self.shape) dims, shifts = tuple(self._resolve_dim(d) for d in make_tuple(dims, 1)), make_tuple(shifts, 1) if len(dims) != len(shifts): raise RuntimeError(f"{len(dims)=} != {len(shifts)=}") + if len(set(dims)) != len(dims): raise RuntimeError(f"duplicate dims in roll: {dims=}") shrink_arg: list[tuple[sint, sint]|None] = [None] * self.ndim - for d, s in zip(dims, shifts): shrink_arg[d] = (delta:=self.shape[d]-s%self.shape[d], delta+self.shape[d]) - return self.repeat(*tuple(2 if i in dims else 1 for i in range(self.ndim))).shrink(tuple(shrink_arg)) + for d, s in zip(dims, shifts): + if self.shape[d] != 0: shrink_arg[d] = (delta:=self.shape[d]-s%self.shape[d], delta+self.shape[d]) + return self.repeat(*tuple(2 if i in dims and self.shape[i] != 0 else 1 for i in range(self.ndim))).shrink(tuple(shrink_arg)) # *** movement ops with expand ***