From 52791a38edb145fdfa0e195b8e0037b7622b567b Mon Sep 17 00:00:00 2001 From: Isabel Gallegos Date: Wed, 2 Feb 2022 16:50:23 -0800 Subject: [PATCH 1/7] Add configurable dropout rate --- medsegpy/config.py | 5 +++++ medsegpy/modeling/meta_arch/unet.py | 5 +++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/medsegpy/config.py b/medsegpy/config.py index 44d0e9ca..39b9e7fc 100644 --- a/medsegpy/config.py +++ b/medsegpy/config.py @@ -106,6 +106,9 @@ class Config(object): EARLY_STOPPING_PATIENCE = 0 EARLY_STOPPING_CRITERION = "val_loss" + # Dropout rate + DROPOUT_RATE = 0.0 + # Batch sizes TRAIN_BATCH_SIZE = 12 VALID_BATCH_SIZE = 35 @@ -589,6 +592,8 @@ def summary(self, additional_vars=None): "EARLY_STOPPING_PATIENCE" if self.USE_EARLY_STOPPING else "", "EARLY_STOPPING_CRITERION" if self.USE_EARLY_STOPPING else "", "", + "DROPOUT_RATE", + "", "KERNEL_INITIALIZER", "SEED" if self.SEED else "", "" "INIT_WEIGHTS", diff --git a/medsegpy/modeling/meta_arch/unet.py b/medsegpy/modeling/meta_arch/unet.py index 07c70149..047e3a91 100644 --- a/medsegpy/modeling/meta_arch/unet.py +++ b/medsegpy/modeling/meta_arch/unet.py @@ -145,6 +145,7 @@ def build_model(self, input_tensor=None) -> Model: seed = cfg.SEED depth = cfg.DEPTH kernel_size = self.kernel_size + dropout_rate = cfg.DROPOUT_RATE self.use_attention = cfg.USE_ATTENTION self.use_deep_supervision = cfg.USE_DEEP_SUPERVISION @@ -178,7 +179,7 @@ def build_model(self, input_tensor=None) -> Model: num_conv=2, activation="relu", kernel_initializer=kernel_initializer, - dropout=0.0, + dropout=dropout_rate, ) # Maxpool until penultimate depth. @@ -220,7 +221,7 @@ def build_model(self, input_tensor=None) -> Model: num_conv=2, activation="relu", kernel_initializer=kernel_initializer, - dropout=0.0, + dropout=dropout_rate, ) if self.use_deep_supervision: From 848070b5d3770c6118c0d404ce1139f9a4bc40b6 Mon Sep 17 00:00:00 2001 From: Isabel Gallegos Date: Wed, 2 Feb 2022 21:11:11 -0800 Subject: [PATCH 2/7] Add MC dropout to model inference --- medsegpy/modeling/model.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/medsegpy/modeling/model.py b/medsegpy/modeling/model.py index d514653d..17e16fcf 100644 --- a/medsegpy/modeling/model.py +++ b/medsegpy/modeling/model.py @@ -255,6 +255,7 @@ def _inference_generator_tf2( ): """Inference generator for TensorFlow 2.""" outputs = [] + outputs_mc_dropout = [] xs = [] ys = [] with model.distribute_strategy.scope(): @@ -295,14 +296,21 @@ def _inference_generator_tf2( batch_x, batch_y, batch_x_raw = _extract_inference_inputs(next(iterator)) # tmp_batch_outputs = predict_function(iterator) tmp_batch_outputs = model.predict(batch_x) + + + tmp_batch_outputs_mc_dropout = None + if True: #mc_dropout: + tmp_batch_outputs_mc_dropout = np.stack([model(batch_x, training=True) for _ in range(100)]) # TODO: 100 -> T + if data_handler.should_sync: context.async_wait() # noqa: F821 batch_outputs = tmp_batch_outputs # No error, now safe to assign. + batch_outputs_mc_dropout = tmp_batch_outputs_mc_dropout if batch_x_raw is not None: batch_x = batch_x_raw for batch, running in zip( - [batch_x, batch_y, batch_outputs], [xs, ys, outputs] + [batch_x, batch_y, batch_outputs, batch_outputs_mc_dropout], [xs, ys, outputs, outputs_mc_dropout] ): nest.map_structure_up_to( batch, lambda x, batch_x: x.append(batch_x), running, batch @@ -318,7 +326,11 @@ def _inference_generator_tf2( all_xs = nest.map_structure_up_to(batch_x, np.concatenate, xs) all_ys = nest.map_structure_up_to(batch_y, np.concatenate, ys) all_outputs = nest.map_structure_up_to(batch_outputs, np.concatenate, outputs) - return all_xs, all_ys, all_outputs + all_outputs_mc_dropout = nest.map_structure_up_to(batch_outputs_mc_dropout, np.concatenate, outputs_mc_dropout) + + outputs = {'preds': all_outputs, 'preds_mc_dropout': all_outputs_mc_dropout} + + return all_xs, all_ys, outputs # all_xs = nest.map_structure_up_to(batch_x, concat, xs) # all_ys = nest.map_structure_up_to(batch_y, concat, ys) From 70bec9a707c58fa03a89da5a37fc90d08eb414e1 Mon Sep 17 00:00:00 2001 From: Isabel Gallegos Date: Wed, 2 Feb 2022 22:05:34 -0800 Subject: [PATCH 3/7] Save mc dropout output --- medsegpy/data/data_loader.py | 9 ++++++++- medsegpy/evaluation/sem_seg_evaluation.py | 3 +++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/medsegpy/data/data_loader.py b/medsegpy/data/data_loader.py index 3208192a..7883e2f1 100644 --- a/medsegpy/data/data_loader.py +++ b/medsegpy/data/data_loader.py @@ -353,6 +353,13 @@ def inference(self, model: Model, **kwargs): ) time_elapsed = time.perf_counter() - start + preds_mc_dropout = None + if isinstance(preds, dict): + preds_mc_dropout = np.squeeze(preds['preds_mc_dropout']) + preds_mc_dropout = preds_mc_dropout.transpose((1, 2, 3, 0)) + + preds = preds['preds'] + x, y, preds = self._restructure_data((x, y, preds)) input = {"x": x, "scan_id": scan_id} @@ -363,7 +370,7 @@ def inference(self, model: Model, **kwargs): } input.update(scan_params) - output = {"y_pred": preds, "y_true": y, "time_elapsed": time_elapsed} + output = {"y_pred": preds, "y_mc_dropout":preds_mc_dropout, "y_true": y, "time_elapsed": time_elapsed} yield input, output diff --git a/medsegpy/evaluation/sem_seg_evaluation.py b/medsegpy/evaluation/sem_seg_evaluation.py index 6b8c192a..5327021e 100644 --- a/medsegpy/evaluation/sem_seg_evaluation.py +++ b/medsegpy/evaluation/sem_seg_evaluation.py @@ -134,6 +134,7 @@ def process(self, inputs, outputs): if includes_bg: y_true = output["y_true"][..., 1:] y_pred = output["y_pred"][..., 1:] + y_mc_dropout = None if output["y_mc_dropout"] is None else output["y_pred"][..., 1:] labels = labels[..., 1:] # if y_true.ndim == 3: # y_true = y_true[..., np.newaxis] @@ -141,6 +142,7 @@ def process(self, inputs, outputs): # labels = labels[..., np.newaxis] output["y_true"] = y_true output["y_pred"] = y_pred + output["y_mc_dropout"] = y_mc_dropout time_elapsed = output["time_elapsed"] if self.stream_evaluation: @@ -178,6 +180,7 @@ def eval_single_scan(self, input, output, labels, time_elapsed): with h5py.File(save_name, "w") as h5f: h5f.create_dataset("probs", data=output["y_pred"]) h5f.create_dataset("labels", data=labels) + h5f.create_dataset("mc_dropout", data=output["y_mc_dropout"]) def evaluate(self): """Evaluates popular medical segmentation metrics specified in config. From 4d45970682cc6cf49a38d14ee3e63fa37b36a890 Mon Sep 17 00:00:00 2001 From: Isabel Gallegos Date: Wed, 2 Feb 2022 22:58:28 -0800 Subject: [PATCH 4/7] Make mc dropout configurable --- medsegpy/config.py | 4 ++++ medsegpy/data/data_loader.py | 8 ++++++-- medsegpy/evaluation/sem_seg_evaluation.py | 3 ++- medsegpy/modeling/model.py | 16 +++++++++++----- 4 files changed, 23 insertions(+), 8 deletions(-) diff --git a/medsegpy/config.py b/medsegpy/config.py index 39b9e7fc..f7d378d2 100644 --- a/medsegpy/config.py +++ b/medsegpy/config.py @@ -108,6 +108,8 @@ class Config(object): # Dropout rate DROPOUT_RATE = 0.0 + MC_DROPOUT = False + MC_DROPOUT_T = 100 # Batch sizes TRAIN_BATCH_SIZE = 12 @@ -593,6 +595,8 @@ def summary(self, additional_vars=None): "EARLY_STOPPING_CRITERION" if self.USE_EARLY_STOPPING else "", "", "DROPOUT_RATE", + "MC_DROPOUT", + "MC_DROPOUT_T" if self.MC_DROPOUT else "" "", "KERNEL_INITIALIZER", "SEED" if self.SEED else "", diff --git a/medsegpy/data/data_loader.py b/medsegpy/data/data_loader.py index 7883e2f1..3370c0f2 100644 --- a/medsegpy/data/data_loader.py +++ b/medsegpy/data/data_loader.py @@ -337,6 +337,10 @@ def inference(self, model: Model, **kwargs): workers = kwargs.pop("workers", self._cfg.NUM_WORKERS) use_multiprocessing = kwargs.pop("use_multiprocessing", workers > 1) + + kwargs["mc_dropout"] = self._cfg.MC_DROPOUT + kwargs["mc_dropout_T"] = self._cfg.MC_DROPOUT_T + for scan_id in scan_ids: self._dataset_dicts = scan_to_dict_mapping[scan_id] @@ -355,8 +359,8 @@ def inference(self, model: Model, **kwargs): preds_mc_dropout = None if isinstance(preds, dict): - preds_mc_dropout = np.squeeze(preds['preds_mc_dropout']) - preds_mc_dropout = preds_mc_dropout.transpose((1, 2, 3, 0)) + if preds['preds_mc_dropout'] is not None: + preds_mc_dropout = np.squeeze(preds['preds_mc_dropout']).transpose((1, 2, 3, 0)) preds = preds['preds'] diff --git a/medsegpy/evaluation/sem_seg_evaluation.py b/medsegpy/evaluation/sem_seg_evaluation.py index 5327021e..493143ad 100644 --- a/medsegpy/evaluation/sem_seg_evaluation.py +++ b/medsegpy/evaluation/sem_seg_evaluation.py @@ -180,7 +180,8 @@ def eval_single_scan(self, input, output, labels, time_elapsed): with h5py.File(save_name, "w") as h5f: h5f.create_dataset("probs", data=output["y_pred"]) h5f.create_dataset("labels", data=labels) - h5f.create_dataset("mc_dropout", data=output["y_mc_dropout"]) + if output["y_mc_dropout"] is not None: + h5f.create_dataset("mc_dropout", data=output["y_mc_dropout"]) def evaluate(self): """Evaluates popular medical segmentation metrics specified in config. diff --git a/medsegpy/modeling/model.py b/medsegpy/modeling/model.py index 17e16fcf..944d12f0 100644 --- a/medsegpy/modeling/model.py +++ b/medsegpy/modeling/model.py @@ -42,7 +42,9 @@ def inference_generator( max_queue_size=10, workers=1, use_multiprocessing=False, - verbose=0, + mc_dropout=False, + mc_dropout_T=100, + verbose=0 ): return self.inference_generator_static( self, generator, steps, max_queue_size, workers, use_multiprocessing, verbose @@ -57,7 +59,9 @@ def inference_generator_static( max_queue_size=10, workers=1, use_multiprocessing=False, - verbose=0, + mc_dropout=False, + mc_dropout_T=100, + verbose=0 ): """Generates predictions for the input samples from a data generator and returns inputs, ground truth, and predictions. @@ -252,6 +256,8 @@ def _inference_generator_tf2( max_queue_size=10, workers=1, use_multiprocessing=False, + mc_dropout=False, + mc_dropout_T=100 ): """Inference generator for TensorFlow 2.""" outputs = [] @@ -299,8 +305,8 @@ def _inference_generator_tf2( tmp_batch_outputs_mc_dropout = None - if True: #mc_dropout: - tmp_batch_outputs_mc_dropout = np.stack([model(batch_x, training=True) for _ in range(100)]) # TODO: 100 -> T + if mc_dropout: + tmp_batch_outputs_mc_dropout = np.stack([model(batch_x, training=True) for _ in range(mc_dropout_T)]) if data_handler.should_sync: context.async_wait() # noqa: F821 @@ -326,7 +332,7 @@ def _inference_generator_tf2( all_xs = nest.map_structure_up_to(batch_x, np.concatenate, xs) all_ys = nest.map_structure_up_to(batch_y, np.concatenate, ys) all_outputs = nest.map_structure_up_to(batch_outputs, np.concatenate, outputs) - all_outputs_mc_dropout = nest.map_structure_up_to(batch_outputs_mc_dropout, np.concatenate, outputs_mc_dropout) + all_outputs_mc_dropout = nest.map_structure_up_to(batch_outputs_mc_dropout, np.concatenate, outputs_mc_dropout) if mc_dropout else None outputs = {'preds': all_outputs, 'preds_mc_dropout': all_outputs_mc_dropout} From acffc062d1c16d6b5ac25cfe60065b1270c86d8b Mon Sep 17 00:00:00 2001 From: i-gallegos Date: Thu, 3 Feb 2022 16:45:32 -0800 Subject: [PATCH 5/7] Fix mc dropout bug --- medsegpy/data/data_loader.py | 2 +- medsegpy/modeling/model.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/medsegpy/data/data_loader.py b/medsegpy/data/data_loader.py index 3370c0f2..c60dc47d 100644 --- a/medsegpy/data/data_loader.py +++ b/medsegpy/data/data_loader.py @@ -337,7 +337,7 @@ def inference(self, model: Model, **kwargs): workers = kwargs.pop("workers", self._cfg.NUM_WORKERS) use_multiprocessing = kwargs.pop("use_multiprocessing", workers > 1) - + kwargs["mc_dropout"] = self._cfg.MC_DROPOUT kwargs["mc_dropout_T"] = self._cfg.MC_DROPOUT_T diff --git a/medsegpy/modeling/model.py b/medsegpy/modeling/model.py index 944d12f0..de0cdd2c 100644 --- a/medsegpy/modeling/model.py +++ b/medsegpy/modeling/model.py @@ -119,6 +119,8 @@ def inference_generator_static( max_queue_size=max_queue_size, workers=workers, use_multiprocessing=use_multiprocessing, + mc_dropout=mc_dropout, + mc_dropout_T=mc_dropout_T, verbose=verbose, ) else: From b6878377af79b0c0939d187605d3525bb852bc03 Mon Sep 17 00:00:00 2001 From: i-gallegos Date: Fri, 11 Feb 2022 01:25:04 -0800 Subject: [PATCH 6/7] Save y_true --- medsegpy/evaluation/sem_seg_evaluation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/medsegpy/evaluation/sem_seg_evaluation.py b/medsegpy/evaluation/sem_seg_evaluation.py index 493143ad..4994ab3f 100644 --- a/medsegpy/evaluation/sem_seg_evaluation.py +++ b/medsegpy/evaluation/sem_seg_evaluation.py @@ -180,6 +180,7 @@ def eval_single_scan(self, input, output, labels, time_elapsed): with h5py.File(save_name, "w") as h5f: h5f.create_dataset("probs", data=output["y_pred"]) h5f.create_dataset("labels", data=labels) + h5f.create_dataset("true", data=output["y_true"]) if output["y_mc_dropout"] is not None: h5f.create_dataset("mc_dropout", data=output["y_mc_dropout"]) From 2937be59a0e67bb0c922011d9d3fd78472d86dcd Mon Sep 17 00:00:00 2001 From: i-gallegos Date: Tue, 17 May 2022 13:25:21 -0700 Subject: [PATCH 7/7] Bug fix --- medsegpy/evaluation/sem_seg_evaluation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/medsegpy/evaluation/sem_seg_evaluation.py b/medsegpy/evaluation/sem_seg_evaluation.py index 4994ab3f..641e17c7 100644 --- a/medsegpy/evaluation/sem_seg_evaluation.py +++ b/medsegpy/evaluation/sem_seg_evaluation.py @@ -134,7 +134,7 @@ def process(self, inputs, outputs): if includes_bg: y_true = output["y_true"][..., 1:] y_pred = output["y_pred"][..., 1:] - y_mc_dropout = None if output["y_mc_dropout"] is None else output["y_pred"][..., 1:] + y_mc_dropout = None if output["y_mc_dropout"] is None else output["y_mc_dropout"][..., 1:] labels = labels[..., 1:] # if y_true.ndim == 3: # y_true = y_true[..., np.newaxis]