diff --git a/transformcl.py b/transformcl.py index 581189e..1bd29b1 100644 --- a/transformcl.py +++ b/transformcl.py @@ -138,7 +138,9 @@ def var(cl): """ xp = array_namespace(cl) - ell = xp.arange(cl.shape[-1]) + # ell cannot be an integer here as, within the array api + # only floating-point dtypes are allowed in __truediv__ + ell = xp.arange(cl.shape[-1], dtype=xp.float64) return xp.sum((2 * ell + 1) / (4 * xp.pi) * cl, axis=-1)