diff --git a/medsegpy/config.py b/medsegpy/config.py
index 44d0e9ca..f7d378d2 100644
--- a/medsegpy/config.py
+++ b/medsegpy/config.py
@@ -106,6 +106,11 @@ class Config(object):
     EARLY_STOPPING_PATIENCE = 0
     EARLY_STOPPING_CRITERION = "val_loss"
 
+    # Dropout rate
+    DROPOUT_RATE = 0.0
+    MC_DROPOUT = False
+    MC_DROPOUT_T = 100
+
     # Batch sizes
     TRAIN_BATCH_SIZE = 12
     VALID_BATCH_SIZE = 35
@@ -589,6 +594,10 @@ def summary(self, additional_vars=None):
             "EARLY_STOPPING_PATIENCE" if self.USE_EARLY_STOPPING else "",
             "EARLY_STOPPING_CRITERION" if self.USE_EARLY_STOPPING else "",
             "",
+            "DROPOUT_RATE",
+            "MC_DROPOUT",
+            "MC_DROPOUT_T" if self.MC_DROPOUT else "",
+            "",
             "KERNEL_INITIALIZER",
             "SEED" if self.SEED else "",
             "" "INIT_WEIGHTS",
diff --git a/medsegpy/data/data_loader.py b/medsegpy/data/data_loader.py
index 3208192a..c60dc47d 100644
--- a/medsegpy/data/data_loader.py
+++ b/medsegpy/data/data_loader.py
@@ -337,6 +337,10 @@ def inference(self, model: Model, **kwargs):
         workers = kwargs.pop("workers", self._cfg.NUM_WORKERS)
         use_multiprocessing = kwargs.pop("use_multiprocessing", workers > 1)
+
+        kwargs["mc_dropout"] = self._cfg.MC_DROPOUT
+        kwargs["mc_dropout_T"] = self._cfg.MC_DROPOUT_T
+
         for scan_id in scan_ids:
             self._dataset_dicts = scan_to_dict_mapping[scan_id]
@@ -353,6 +357,13 @@ def inference(self, model: Model, **kwargs):
             )
             time_elapsed = time.perf_counter() - start
 
+            preds_mc_dropout = None
+            if isinstance(preds, dict):
+                if preds["preds_mc_dropout"] is not None:
+                    preds_mc_dropout = np.squeeze(preds["preds_mc_dropout"]).transpose((1, 2, 3, 0))
+
+                preds = preds["preds"]
+
             x, y, preds = self._restructure_data((x, y, preds))
 
             input = {"x": x, "scan_id": scan_id}
@@ -363,7 +374,7 @@ def inference(self, model: Model, **kwargs):
             }
             input.update(scan_params)
 
-            output = {"y_pred": preds, "y_true": y, "time_elapsed": time_elapsed}
+            output = {"y_pred": preds, "y_mc_dropout": preds_mc_dropout, "y_true": y, "time_elapsed": time_elapsed}
 
             yield input, output
diff --git a/medsegpy/evaluation/sem_seg_evaluation.py b/medsegpy/evaluation/sem_seg_evaluation.py
index 6b8c192a..641e17c7 100644
--- a/medsegpy/evaluation/sem_seg_evaluation.py
+++ b/medsegpy/evaluation/sem_seg_evaluation.py
@@ -134,6 +134,7 @@ def process(self, inputs, outputs):
         if includes_bg:
             y_true = output["y_true"][..., 1:]
             y_pred = output["y_pred"][..., 1:]
+            y_mc_dropout = None if output.get("y_mc_dropout") is None else output["y_mc_dropout"][..., 1:]
             labels = labels[..., 1:]
             # if y_true.ndim == 3:
             #     y_true = y_true[..., np.newaxis]
@@ -141,6 +142,7 @@ def process(self, inputs, outputs):
             #     labels = labels[..., np.newaxis]
             output["y_true"] = y_true
             output["y_pred"] = y_pred
+            output["y_mc_dropout"] = y_mc_dropout
 
         time_elapsed = output["time_elapsed"]
         if self.stream_evaluation:
@@ -178,6 +180,9 @@ def eval_single_scan(self, input, output, labels, time_elapsed):
             with h5py.File(save_name, "w") as h5f:
                 h5f.create_dataset("probs", data=output["y_pred"])
                 h5f.create_dataset("labels", data=labels)
+                h5f.create_dataset("true", data=output["y_true"])
+                if output.get("y_mc_dropout") is not None:
+                    h5f.create_dataset("mc_dropout", data=output["y_mc_dropout"])
 
 def evaluate(self):
     """Evaluates popular medical segmentation metrics specified in config.
diff --git a/medsegpy/modeling/meta_arch/unet.py b/medsegpy/modeling/meta_arch/unet.py
index 07c70149..047e3a91 100644
--- a/medsegpy/modeling/meta_arch/unet.py
+++ b/medsegpy/modeling/meta_arch/unet.py
@@ -145,6 +145,7 @@ def build_model(self, input_tensor=None) -> Model:
         seed = cfg.SEED
         depth = cfg.DEPTH
         kernel_size = self.kernel_size
+        dropout_rate = cfg.DROPOUT_RATE
 
         self.use_attention = cfg.USE_ATTENTION
         self.use_deep_supervision = cfg.USE_DEEP_SUPERVISION
@@ -178,7 +179,7 @@ def build_model(self, input_tensor=None) -> Model:
                 num_conv=2,
                 activation="relu",
                 kernel_initializer=kernel_initializer,
-                dropout=0.0,
+                dropout=dropout_rate,
             )
 
         # Maxpool until penultimate depth.
@@ -220,7 +221,7 @@ def build_model(self, input_tensor=None) -> Model:
                 num_conv=2,
                 activation="relu",
                 kernel_initializer=kernel_initializer,
-                dropout=0.0,
+                dropout=dropout_rate,
             )
 
         if self.use_deep_supervision:
diff --git a/medsegpy/modeling/model.py b/medsegpy/modeling/model.py
index d514653d..de0cdd2c 100644
--- a/medsegpy/modeling/model.py
+++ b/medsegpy/modeling/model.py
@@ -42,8 +42,11 @@ def inference_generator(
         max_queue_size=10,
         workers=1,
         use_multiprocessing=False,
-        verbose=0,
+        mc_dropout=False,
+        mc_dropout_T=100,
+        verbose=0,
     ):
         return self.inference_generator_static(
-            self, generator, steps, max_queue_size, workers, use_multiprocessing, verbose
+            self, generator, steps, max_queue_size, workers, use_multiprocessing,
+            mc_dropout, mc_dropout_T, verbose
         )
@@ -57,7 +60,9 @@ def inference_generator_static(
         max_queue_size=10,
         workers=1,
         use_multiprocessing=False,
-        verbose=0,
+        mc_dropout=False,
+        mc_dropout_T=100,
+        verbose=0,
     ):
         """Generates predictions for the input samples from a data generator
         and returns inputs, ground truth, and predictions.
@@ -115,6 +120,8 @@ def inference_generator_static(
                 max_queue_size=max_queue_size,
                 workers=workers,
                 use_multiprocessing=use_multiprocessing,
+                mc_dropout=mc_dropout,
+                mc_dropout_T=mc_dropout_T,
                 verbose=verbose,
             )
         else:
@@ -252,9 +259,12 @@ def _inference_generator_tf2(
     max_queue_size=10,
     workers=1,
     use_multiprocessing=False,
+    mc_dropout=False,
+    mc_dropout_T=100,
 ):
     """Inference generator for TensorFlow 2."""
     outputs = []
+    outputs_mc_dropout = []
     xs = []
     ys = []
     with model.distribute_strategy.scope():
@@ -295,14 +305,21 @@ def _inference_generator_tf2(
                 batch_x, batch_y, batch_x_raw = _extract_inference_inputs(next(iterator))
                 # tmp_batch_outputs = predict_function(iterator)
                 tmp_batch_outputs = model.predict(batch_x)
+
+                # Monte Carlo dropout: T stochastic forward passes with dropout kept active.
+                tmp_batch_outputs_mc_dropout = None
+                if mc_dropout:
+                    tmp_batch_outputs_mc_dropout = np.stack([model(batch_x, training=True) for _ in range(mc_dropout_T)])
+
                 if data_handler.should_sync:
                     context.async_wait()  # noqa: F821
                 batch_outputs = tmp_batch_outputs  # No error, now safe to assign.
+                batch_outputs_mc_dropout = tmp_batch_outputs_mc_dropout
                 if batch_x_raw is not None:
                     batch_x = batch_x_raw
                 for batch, running in zip(
-                    [batch_x, batch_y, batch_outputs], [xs, ys, outputs]
+                    [batch_x, batch_y, batch_outputs, batch_outputs_mc_dropout], [xs, ys, outputs, outputs_mc_dropout]
                 ):
                     nest.map_structure_up_to(
                         batch, lambda x, batch_x: x.append(batch_x), running, batch
@@ -318,7 +335,11 @@ def _inference_generator_tf2(
         all_xs = nest.map_structure_up_to(batch_x, np.concatenate, xs)
         all_ys = nest.map_structure_up_to(batch_y, np.concatenate, ys)
         all_outputs = nest.map_structure_up_to(batch_outputs, np.concatenate, outputs)
-        return all_xs, all_ys, all_outputs
+        all_outputs_mc_dropout = nest.map_structure_up_to(batch_outputs_mc_dropout, np.concatenate, outputs_mc_dropout) if mc_dropout else None
+
+        outputs = {"preds": all_outputs, "preds_mc_dropout": all_outputs_mc_dropout}
+
+        return all_xs, all_ys, outputs
         # all_xs = nest.map_structure_up_to(batch_x, concat, xs)
         # all_ys = nest.map_structure_up_to(batch_y, concat, ys)