diff --git a/Baseline/process.py b/Baseline/process.py index f2211d1..bafbc1c 100644 --- a/Baseline/process.py +++ b/Baseline/process.py @@ -62,17 +62,15 @@ def __init__(self): if not output_path.exists(): output_path.mkdir() - #self._input_path = Path("/input/images/brain-mri/") self._segmentation_output_path = Path("/output/images/white-matter-multiple-sclerosis-lesion-segmentation/") self._uncertainty_output_path = Path("/output/images/white-matter-multiple-sclerosis-lesion-uncertainty-map/") - #self._segmentation_output_path = Path("/output/segmentation/") - #self._uncertainty_output_path = Path("/output/uncertainty/") - self.device = get_default_device() K = 3 models = [] + + # TODO: change to your model for i in range(K): models.append(UNet( spatial_dims=3, @@ -82,6 +80,8 @@ def __init__(self): strides=(2, 2, 2, 2), num_res_units=0).to(self.device) ) + self.th = 0.35 + # -------------------------------------- for i, model in enumerate(models): model.load_state_dict(torch.load('./model'+str(i+1)+'.pth', map_location=self.device)) @@ -89,12 +89,12 @@ def __init__(self): self.models = models self.act = torch.nn.Softmax(dim=1) - self.th = 0.35 self.roi_size = (96, 96, 96) self.sw_batch_size = 4 def process_case(self, *, idx, case): + """ Please do not change """ # Load and test the image for this case input_image, input_image_file_path = self._load_input_image(case=case) @@ -129,22 +129,20 @@ def process_case(self, *, idx, case): def predict(self, *, input_image: SimpleITK.Image) -> SimpleITK.Image: - + """ Inference of a single file """ image = SimpleITK.GetArrayFromImage(input_image) image = np.transpose(np.array(image)) - - - # The image must be normalized as that is what we did with monai for training of the model - # only normalize non-zero values (i.e. not the background) + + # TODO: change to preprocessing specific to your model non_zeros = image != 0 mu = np.mean(image[non_zeros]) sigma = np.std(image[non_zeros]) image[non_zeros] = (image[non_zeros] - mu) / sigma + # ---------------------------------------------- + # run inference for each model in ensemble with torch.no_grad(): - image = torch.unsqueeze(torch.unsqueeze(torch.from_numpy(image).to(self.device), axis=0), axis=0) - all_outputs = [] for model in self.models: outputs = sliding_window_inference(image, self.roi_size, self.sw_batch_size, model, mode='gaussian') @@ -153,22 +151,29 @@ def predict(self, *, input_image: SimpleITK.Image) -> SimpleITK.Image: all_outputs.append(outputs) all_outputs = np.asarray(all_outputs) + # apply probability threshold to your model to generate binary segmentation mask seg = np.mean(all_outputs, axis=0) seg[seg>self.th]=1 seg[seg<=self.th]=0 seg = np.squeeze(seg) + + # TODO: apply post-processing to the models outputs + # removes all connected components with less than 10 voxels seg = remove_connected_components(seg) + # ------------------------------------------------ - uncs = ensemble_uncertainties_classification( np.concatenate( (np.expand_dims(all_outputs, axis=-1), np.expand_dims(1.-all_outputs, axis=-1)), axis=-1) ) + # TODO: change to your proposed uncertainty measure + uncs = ensemble_uncertainties_classification( + np.concatenate((np.expand_dims(all_outputs, axis=-1), np.expand_dims(1.-all_outputs, axis=-1)), axis=-1) + ) unc_rmi = uncs["reverse_mutual_information"] + # ------------------------------------------------- + # convert 3D numpy.ndarrays to the format required by evaluation system out_seg = SimpleITK.GetImageFromArray(seg) out_unc = SimpleITK.GetImageFromArray(unc_rmi) return out_seg, out_unc - - - if __name__ == "__main__": Baseline().process()