diff --git a/pytorch_neat/recurrent_net.py b/pytorch_neat/recurrent_net.py index 1bf811b..413971e 100644 --- a/pytorch_neat/recurrent_net.py +++ b/pytorch_neat/recurrent_net.py @@ -33,7 +33,7 @@ def dense_from_coo(shape, conns, dtype=torch.float64): idxs, weights = conns if len(idxs) == 0: return mat - rows, cols = np.array(idxs).transpose() + rows, cols = np.array(idxs, dtype=np.int64).transpose() mat[torch.tensor(rows), torch.tensor(cols)] = torch.tensor( weights, dtype=dtype) return mat