Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions medsegpy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ class Config(object):
EARLY_STOPPING_PATIENCE = 0
EARLY_STOPPING_CRITERION = "val_loss"

# Dropout rate
DROPOUT_RATE = 0.0
MC_DROPOUT = False
MC_DROPOUT_T = 100
Comment on lines +111 to +112
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment indicating what MC_DROPOUT and MC_DROPOUT_T are referring to


# Batch sizes
TRAIN_BATCH_SIZE = 12
VALID_BATCH_SIZE = 35
Expand Down Expand Up @@ -589,6 +594,10 @@ def summary(self, additional_vars=None):
"EARLY_STOPPING_PATIENCE" if self.USE_EARLY_STOPPING else "",
"EARLY_STOPPING_CRITERION" if self.USE_EARLY_STOPPING else "",
"",
"DROPOUT_RATE",
"MC_DROPOUT",
"MC_DROPOUT_T" if self.MC_DROPOUT else "",
"",
"KERNEL_INITIALIZER",
"SEED" if self.SEED else "",
"" "INIT_WEIGHTS",
Expand Down
13 changes: 12 additions & 1 deletion medsegpy/data/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ def inference(self, model: Model, **kwargs):

workers = kwargs.pop("workers", self._cfg.NUM_WORKERS)
use_multiprocessing = kwargs.pop("use_multiprocessing", workers > 1)

kwargs["mc_dropout"] = self._cfg.MC_DROPOUT
kwargs["mc_dropout_T"] = self._cfg.MC_DROPOUT_T

for scan_id in scan_ids:
self._dataset_dicts = scan_to_dict_mapping[scan_id]

Expand All @@ -353,6 +357,13 @@ def inference(self, model: Model, **kwargs):
)
time_elapsed = time.perf_counter() - start

preds_mc_dropout = None
if isinstance(preds, dict):
if preds['preds_mc_dropout'] is not None:
preds_mc_dropout = np.squeeze(preds['preds_mc_dropout']).transpose((1, 2, 3, 0))

preds = preds['preds']

x, y, preds = self._restructure_data((x, y, preds))

input = {"x": x, "scan_id": scan_id}
Expand All @@ -363,7 +374,7 @@ def inference(self, model: Model, **kwargs):
}
input.update(scan_params)

output = {"y_pred": preds, "y_true": y, "time_elapsed": time_elapsed}
output = {"y_pred": preds, "y_mc_dropout":preds_mc_dropout, "y_true": y, "time_elapsed": time_elapsed}

yield input, output

Expand Down
5 changes: 5 additions & 0 deletions medsegpy/evaluation/sem_seg_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,15 @@ def process(self, inputs, outputs):
if includes_bg:
y_true = output["y_true"][..., 1:]
y_pred = output["y_pred"][..., 1:]
y_mc_dropout = None if output["y_mc_dropout"] is None else output["y_mc_dropout"][..., 1:]
labels = labels[..., 1:]
# if y_true.ndim == 3:
# y_true = y_true[..., np.newaxis]
# y_pred = y_pred[..., np.newaxis]
# labels = labels[..., np.newaxis]
output["y_true"] = y_true
output["y_pred"] = y_pred
output["y_mc_dropout"] = y_mc_dropout

time_elapsed = output["time_elapsed"]
if self.stream_evaluation:
Expand Down Expand Up @@ -178,6 +180,9 @@ def eval_single_scan(self, input, output, labels, time_elapsed):
with h5py.File(save_name, "w") as h5f:
h5f.create_dataset("probs", data=output["y_pred"])
h5f.create_dataset("labels", data=labels)
h5f.create_dataset("true", data=output["y_true"])
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would avoid saving y_true — it should be easily accessible from your input-data HDF5 file, and duplicating it here would use up more disk space, which is limited.

if output["y_mc_dropout"] is not None:
h5f.create_dataset("mc_dropout", data=output["y_mc_dropout"])

def evaluate(self):
"""Evaluates popular medical segmentation metrics specified in config.
Expand Down
5 changes: 3 additions & 2 deletions medsegpy/modeling/meta_arch/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def build_model(self, input_tensor=None) -> Model:
seed = cfg.SEED
depth = cfg.DEPTH
kernel_size = self.kernel_size
dropout_rate = cfg.DROPOUT_RATE
self.use_attention = cfg.USE_ATTENTION
self.use_deep_supervision = cfg.USE_DEEP_SUPERVISION

Expand Down Expand Up @@ -178,7 +179,7 @@ def build_model(self, input_tensor=None) -> Model:
num_conv=2,
activation="relu",
kernel_initializer=kernel_initializer,
dropout=0.0,
dropout=dropout_rate,
)

# Maxpool until penultimate depth.
Expand Down Expand Up @@ -220,7 +221,7 @@ def build_model(self, input_tensor=None) -> Model:
num_conv=2,
activation="relu",
kernel_initializer=kernel_initializer,
dropout=0.0,
dropout=dropout_rate,
)

if self.use_deep_supervision:
Expand Down
28 changes: 24 additions & 4 deletions medsegpy/modeling/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def inference_generator(
max_queue_size=10,
workers=1,
use_multiprocessing=False,
verbose=0,
mc_dropout=False,
mc_dropout_T=100,
Comment on lines +45 to +46
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These appear to be unused (they are accepted here but never forwarded to `inference_generator_static`) — if that's the case, delete them.

verbose=0
):
return self.inference_generator_static(
self, generator, steps, max_queue_size, workers, use_multiprocessing, verbose
Expand All @@ -57,7 +59,9 @@ def inference_generator_static(
max_queue_size=10,
workers=1,
use_multiprocessing=False,
verbose=0,
mc_dropout=False,
mc_dropout_T=100,
verbose=0
):
"""Generates predictions for the input samples from a data generator
and returns inputs, ground truth, and predictions.
Expand Down Expand Up @@ -115,6 +119,8 @@ def inference_generator_static(
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
mc_dropout=mc_dropout,
mc_dropout_T=mc_dropout_T,
verbose=verbose,
)
else:
Expand Down Expand Up @@ -252,9 +258,12 @@ def _inference_generator_tf2(
max_queue_size=10,
workers=1,
use_multiprocessing=False,
mc_dropout=False,
mc_dropout_T=100
):
"""Inference generator for TensorFlow 2."""
outputs = []
outputs_mc_dropout = []
xs = []
ys = []
with model.distribute_strategy.scope():
Expand Down Expand Up @@ -295,14 +304,21 @@ def _inference_generator_tf2(
batch_x, batch_y, batch_x_raw = _extract_inference_inputs(next(iterator))
# tmp_batch_outputs = predict_function(iterator)
tmp_batch_outputs = model.predict(batch_x)


tmp_batch_outputs_mc_dropout = None
if mc_dropout:
tmp_batch_outputs_mc_dropout = np.stack([model(batch_x, training=True) for _ in range(mc_dropout_T)])
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you're trying to do here, but it will not be reproducible, which is necessary if we are to add this to the inference loop. No random seed is being set, so the features that are dropped out will differ if you run inference on the same example twice.

I'm not sure exactly how to account for this — potentially by setting a random seed. Write a unit test to verify that this does in fact produce identical outputs when run twice.


if data_handler.should_sync:
context.async_wait() # noqa: F821
batch_outputs = tmp_batch_outputs # No error, now safe to assign.
batch_outputs_mc_dropout = tmp_batch_outputs_mc_dropout

if batch_x_raw is not None:
batch_x = batch_x_raw
for batch, running in zip(
[batch_x, batch_y, batch_outputs], [xs, ys, outputs]
[batch_x, batch_y, batch_outputs, batch_outputs_mc_dropout], [xs, ys, outputs, outputs_mc_dropout]
):
nest.map_structure_up_to(
batch, lambda x, batch_x: x.append(batch_x), running, batch
Expand All @@ -318,7 +334,11 @@ def _inference_generator_tf2(
all_xs = nest.map_structure_up_to(batch_x, np.concatenate, xs)
all_ys = nest.map_structure_up_to(batch_y, np.concatenate, ys)
all_outputs = nest.map_structure_up_to(batch_outputs, np.concatenate, outputs)
return all_xs, all_ys, all_outputs
all_outputs_mc_dropout = nest.map_structure_up_to(batch_outputs_mc_dropout, np.concatenate, outputs_mc_dropout) if mc_dropout else None

outputs = {'preds': all_outputs, 'preds_mc_dropout': all_outputs_mc_dropout}

return all_xs, all_ys, outputs

# all_xs = nest.map_structure_up_to(batch_x, concat, xs)
# all_ys = nest.map_structure_up_to(batch_y, concat, ys)
Expand Down