diff --git a/metrics.py b/metrics.py index d53c716..82f0c9e 100644 --- a/metrics.py +++ b/metrics.py @@ -6,8 +6,9 @@ import numpy as np def compute_metrics(x): - sx = np.sort(-x, axis=1) - d = np.diag(-x) + x = -np.concatenate(x, axis=0) + sx = np.sort(x, axis=1) + d = np.diag(x) d = d[:, np.newaxis] ind = sx - d ind = np.where(ind == 0)