diff --git a/tinygrad/mixin/movement.py b/tinygrad/mixin/movement.py index dc03177ec9d4b..4e750f7fb22ae 100644 --- a/tinygrad/mixin/movement.py +++ b/tinygrad/mixin/movement.py @@ -329,6 +329,7 @@ def flatten(self, start_dim=0, end_dim=-1) -> Self: ``` """ start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim) + if start_dim > end_dim: raise RuntimeError(f"flatten: {start_dim=} > {end_dim=}") return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim : end_dim + 1]),) + self.shape[end_dim + 1 :]) def unflatten(self, dim: int, sizes: tuple[int, ...]) -> Self: