Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 21 additions & 16 deletions Baseline/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,15 @@ def __init__(self):
if not output_path.exists():
output_path.mkdir()

#self._input_path = Path("/input/images/brain-mri/")
self._segmentation_output_path = Path("/output/images/white-matter-multiple-sclerosis-lesion-segmentation/")
self._uncertainty_output_path = Path("/output/images/white-matter-multiple-sclerosis-lesion-uncertainty-map/")

#self._segmentation_output_path = Path("/output/segmentation/")
#self._uncertainty_output_path = Path("/output/uncertainty/")

self.device = get_default_device()

K = 3
models = []

# TODO: change to your model
for i in range(K):
models.append(UNet(
spatial_dims=3,
Expand All @@ -82,19 +80,21 @@ def __init__(self):
strides=(2, 2, 2, 2),
num_res_units=0).to(self.device)
)
self.th = 0.35
# --------------------------------------

for i, model in enumerate(models):
model.load_state_dict(torch.load('./model'+str(i+1)+'.pth', map_location=self.device))
model.eval()

self.models = models
self.act = torch.nn.Softmax(dim=1)
self.th = 0.35
self.roi_size = (96, 96, 96)
self.sw_batch_size = 4


def process_case(self, *, idx, case):
""" Please do not change """
# Load and test the image for this case
input_image, input_image_file_path = self._load_input_image(case=case)

Expand Down Expand Up @@ -129,22 +129,20 @@ def process_case(self, *, idx, case):


def predict(self, *, input_image: SimpleITK.Image) -> SimpleITK.Image:

""" Inference of a single file """
image = SimpleITK.GetArrayFromImage(input_image)
image = np.transpose(np.array(image))


# The image must be normalized as that is what we did with monai for training of the model
# only normalize non-zero values (i.e. not the background)

# TODO: change to preprocessing specific to your model
non_zeros = image != 0
mu = np.mean(image[non_zeros])
sigma = np.std(image[non_zeros])
image[non_zeros] = (image[non_zeros] - mu) / sigma
# ----------------------------------------------

# run inference for each model in ensemble
with torch.no_grad():

image = torch.unsqueeze(torch.unsqueeze(torch.from_numpy(image).to(self.device), axis=0), axis=0)

all_outputs = []
for model in self.models:
outputs = sliding_window_inference(image, self.roi_size, self.sw_batch_size, model, mode='gaussian')
Expand All @@ -153,22 +151,29 @@ def predict(self, *, input_image: SimpleITK.Image) -> SimpleITK.Image:
all_outputs.append(outputs)
all_outputs = np.asarray(all_outputs)

# apply probability threshold to your model to generate binary segmentation mask
seg = np.mean(all_outputs, axis=0)
seg[seg>self.th]=1
seg[seg<=self.th]=0
seg = np.squeeze(seg)

# TODO: apply post-processing to the models outputs
# removes all connected components with less than 10 voxels
seg = remove_connected_components(seg)
# ------------------------------------------------

uncs = ensemble_uncertainties_classification( np.concatenate( (np.expand_dims(all_outputs, axis=-1), np.expand_dims(1.-all_outputs, axis=-1)), axis=-1) )
# TODO: change to your proposed uncertainty measure
uncs = ensemble_uncertainties_classification(
np.concatenate((np.expand_dims(all_outputs, axis=-1), np.expand_dims(1.-all_outputs, axis=-1)), axis=-1)
)
unc_rmi = uncs["reverse_mutual_information"]
# -------------------------------------------------

# convert 3D numpy.ndarrays to the format required by evaluation system
out_seg = SimpleITK.GetImageFromArray(seg)
out_unc = SimpleITK.GetImageFromArray(unc_rmi)
return out_seg, out_unc





if __name__ == "__main__":
Baseline().process()