diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py index d24a7d9d60c4b..69ff480eeabf5 100644 --- a/tinygrad/mixin/__init__.py +++ b/tinygrad/mixin/__init__.py @@ -790,7 +790,7 @@ def argmax(self, axis=None, keepdim=False) -> Self: print(t.argmax(axis=1).numpy()) # Returns the indices of the maximum values along axis 1. ``` """ - if axis is None: return self.flatten().argmax(0) + if axis is None: return self.flatten().argmax(0).reshape((1,)*self.ndim if keepdim else ()) axis = self._resolve_dim(axis) m = self.eq(self.max(axis=axis, keepdim=True)) idx = m * type(self).arange(self.shape[axis], 0, -1, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))