From 1f5978ffaad0e7c881a46afc2adf1937e7e737e8 Mon Sep 17 00:00:00 2001 From: dikanquit Date: Fri, 29 May 2026 11:40:36 +0200 Subject: [PATCH] roll: skip 0-size dim and reject duplicate dims modulo raised on shape[d]=0, dup dims overwrote shrink_arg --- tinygrad/mixin/movement.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tinygrad/mixin/movement.py b/tinygrad/mixin/movement.py index dc03177ec9d4b..ec1998ade0e8f 100644 --- a/tinygrad/mixin/movement.py +++ b/tinygrad/mixin/movement.py @@ -522,9 +522,11 @@ def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]|None=None) - if dims is None: return self.flatten().roll(shifts, 0).reshape(self.shape) dims, shifts = tuple(self._resolve_dim(d) for d in make_tuple(dims, 1)), make_tuple(shifts, 1) if len(dims) != len(shifts): raise RuntimeError(f"{len(dims)=} != {len(shifts)=}") + if len(set(dims)) != len(dims): raise RuntimeError(f"duplicate dims in roll: {dims=}") shrink_arg: list[tuple[sint, sint]|None] = [None] * self.ndim - for d, s in zip(dims, shifts): shrink_arg[d] = (delta:=self.shape[d]-s%self.shape[d], delta+self.shape[d]) - return self.repeat(*tuple(2 if i in dims else 1 for i in range(self.ndim))).shrink(tuple(shrink_arg)) + for d, s in zip(dims, shifts): + if self.shape[d] != 0: shrink_arg[d] = (delta:=self.shape[d]-s%self.shape[d], delta+self.shape[d]) + return self.repeat(*tuple(2 if i in dims and self.shape[i] != 0 else 1 for i in range(self.ndim))).shrink(tuple(shrink_arg)) # *** movement ops with expand ***