105 changes: 105 additions & 0 deletions mobie/tables/metrics_table.py
@@ -0,0 +1,105 @@
import json
import multiprocessing
import os

import luigi
import numpy as np
import pandas as pd
from cluster_tools.evaluation import ObjectIouWorkflow, ObjectViWorkflow
from pybdv.metadata import get_data_path

from ..config import write_global_config
from ..metadata import load_image_dict

METRICS = {
    'iou': ObjectIouWorkflow,
    'voi': ObjectViWorkflow
}


def read_metrics(path, metric):
    with open(path) as f:
        scores = json.load(f)

    if metric == 'iou':  # iou just has a single metric per object
        data = np.array([
            [int(label_id), score] for label_id, score in scores.items()
        ])
        columns = ['label_id', 'iou-score']
    else:  # voi has two metrics per object
        data = np.array([
            [int(label_id), score[0], score[1]] for label_id, score in scores.items()
        ])
        columns = ['label_id', 'voi-split-score', 'voi-merge-score']

    return data, columns
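
# Illustrative note (not part of this PR): based on the parsing above, the score
# files produced by the cluster_tools workflows are assumed to map label ids to
# scores roughly like this:
#   iou:  {"1": 0.87, "2": 0.42, ...}                  one iou score per object
#   voi:  {"1": [0.05, 0.11], "2": [0.3, 0.02], ...}   [split, merge] per object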


def write_metric_table(table_path, data, columns):
    if os.path.exists(table_path):
        # the table exists already: check that the label ids agree, then
        # overwrite existing score columns or merge in new ones by label id
        df = pd.read_csv(table_path, sep='\t')

        label_ids = data[:, 0]
        if not np.allclose(label_ids, df['label_id'].values):
            raise RuntimeError("Label ids in metrics table disagree")

        for col_id, col_name in enumerate(columns[1:], 1):
            if col_name in df.columns:
                df[col_name] = data[:, col_id]
            else:
                merge_data = np.concatenate([data[:, 0:1], data[:, col_id:col_id + 1]], axis=1)
                df = df.merge(pd.DataFrame(merge_data, columns=['label_id', col_name]))
    else:
        df = pd.DataFrame(data, columns=columns)
    df.to_csv(table_path, index=False, sep='\t')


def compute_metrics_table(
    dataset_folder,
    seg_name,
    gt_name,
    metric='iou',
    scale=0,
    tmp_folder=None,
    target='local',
    max_jobs=multiprocessing.cpu_count()
):
    """ Compute per-object validation metrics for a segmentation against a ground-truth
    labeling and write the scores to the segmentation's metrics table.
    """

    if metric not in METRICS:
        msg = f"Metric {metric} is not supported. Only {list(METRICS.keys())} are supported."
        raise ValueError(msg)
    task = METRICS[metric]

    # resolve the data paths for the segmentation and the ground-truth
    image_folder = os.path.join(dataset_folder, 'images')
    image_dict = load_image_dict(os.path.join(image_folder, 'images.json'))

    seg_entry = image_dict[seg_name]
    seg_path = os.path.join(image_folder, seg_entry['storage']['local'])
    seg_path = get_data_path(seg_path, return_absolute_path=True)

    gt_entry = image_dict[gt_name]
    gt_path = os.path.join(image_folder, gt_entry['storage']['local'])
    gt_path = get_data_path(gt_path, return_absolute_path=True)

    tmp_folder = f'tmp_metrics_{seg_name}_{gt_name}' if tmp_folder is None else tmp_folder
    config_dir = os.path.join(tmp_folder, 'configs')
    write_global_config(config_dir)

    # run the cluster_tools workflow that computes the per-object scores
    key = f'setup0/timepoint0/s{scale}'
    out_path = os.path.join(tmp_folder, f'scores_{metric}.json')
    t = task(tmp_folder=tmp_folder, config_dir=config_dir,
             target=target, max_jobs=max_jobs,
             seg_path=seg_path, seg_key=key,
             gt_path=gt_path, gt_key=key,
             output_path=out_path)
    if not luigi.build([t], local_scheduler=True):
        raise RuntimeError("Computing metrics failed.")

    # write the scores to the metrics table of the segmentation
    data, columns = read_metrics(out_path, metric)

    table_dir = os.path.join(dataset_folder, seg_entry['tableFolder'])
    table_path = os.path.join(table_dir, 'metrics.csv')

    write_metric_table(table_path, data, columns)
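
As a usage illustration (not part of this diff), the new function could be called like this for the dataset created in the test below; the dataset folder and source names are the ones used there:

# hypothetical driver script
from mobie.tables.metrics_table import compute_metrics_table

# compute per-object IoU scores for the segmentation 'seg' against the
# ground-truth 'gt' and append them to the segmentation's metrics table
compute_metrics_table(
    dataset_folder='./data/test',
    seg_name='seg', gt_name='gt',
    metric='iou', scale=0,
    target='local', max_jobs=4
)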
Empty file added test/tables/__init__.py
82 changes: 82 additions & 0 deletions test/tables/test_metrics_table.py
@@ -0,0 +1,82 @@
import os
import unittest
from glob import glob
from shutil import rmtree

import h5py
import numpy as np
import pandas as pd
from mobie import initialize_dataset, add_segmentation
from mobie.metadata import load_image_dict


# TODO use better test data that has at least some overlap between segmentation and ground-truth
class TestMetricsTable(unittest.TestCase):
    project_folder = './data'
    dataset_folder = './data/test'
    gt_name = 'gt'
    seg_name = 'seg'
    n_segments = 32

    def setUp(self):
        os.makedirs(self.project_folder, exist_ok=True)
        shape = (256, 256, 256)
        chunks = (64, 64, 64)
        tmp_data = './data/data.h5'
        with h5py.File(tmp_data, 'w') as f:
            f.create_dataset('raw', data=np.random.rand(*shape), chunks=chunks)
            f.create_dataset('gt', data=np.random.randint(0, 25, size=shape).astype('uint64'), chunks=chunks)
            f.create_dataset('seg', data=np.random.randint(0, self.n_segments, size=shape).astype('uint64'),
                             chunks=chunks)

        initialize_dataset(tmp_data, 'raw', self.project_folder, 'test', 'raw',
                           resolution=(1, 1, 1), chunks=chunks, scale_factors=[[2, 2, 2]])
        add_segmentation(tmp_data, 'gt', self.project_folder, 'test', self.gt_name,
                         resolution=(1, 1, 1), chunks=chunks, scale_factors=[[2, 2, 2]])
        add_segmentation(tmp_data, 'seg', self.project_folder, 'test', self.seg_name,
                         resolution=(1, 1, 1), chunks=chunks, scale_factors=[[2, 2, 2]])

    def tearDown(self):
        rmtree(self.project_folder)
        tmp_folders = glob('tmp*')
        for tmp_folder in tmp_folders:
            rmtree(tmp_folder)

    def _load_table(self):
        image_folder = os.path.join(self.dataset_folder, 'images')
        image_dict = load_image_dict(os.path.join(image_folder, 'images.json'))

        seg_entry = image_dict[self.seg_name]
        table_dir = os.path.join(self.dataset_folder, seg_entry['tableFolder'])
        table_path = os.path.join(table_dir, 'metrics.csv')

        return pd.read_csv(table_path, sep='\t')

    def test_iou(self):
        from mobie.tables.metrics_table import compute_metrics_table
        compute_metrics_table(
            self.dataset_folder, self.seg_name, self.gt_name,
            metric='iou'
        )
        table = self._load_table()
        self.assertIn('iou-score', table.columns)
        scores = table['iou-score'].values
        self.assertTrue(0 <= np.min(scores) <= 1)
        self.assertTrue(0 <= np.max(scores) <= 1)

    def test_voi(self):
        from mobie.tables.metrics_table import compute_metrics_table
        compute_metrics_table(
            self.dataset_folder, self.seg_name, self.gt_name,
            metric='voi'
        )
        table = self._load_table()
        for score_name in ('voi-split-score', 'voi-merge-score'):
            self.assertIn(score_name, table.columns)
            scores = table[score_name].values
            self.assertTrue(0 <= np.min(scores))
            self.assertTrue(0 <= np.max(scores))


if __name__ == '__main__':
    unittest.main()
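
For completeness, a minimal sketch of how the resulting table could be inspected once the workflow has run; the column names follow read_metrics above, while the exact table path depends on the segmentation's tableFolder entry and is only assumed here:

# minimal sketch, assuming tableFolder resolves to 'tables/seg'
import pandas as pd

# the metrics table is tab-separated despite the .csv extension
table = pd.read_csv('./data/test/tables/seg/metrics.csv', sep='\t')
worst = table.sort_values('iou-score').head(10)  # objects with the lowest overlap
print(worst[['label_id', 'iou-score']])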