From 84c78cd1bbf6eb4eae5348b2eb4a5062b6b008b7 Mon Sep 17 00:00:00 2001 From: Anthony Gatti Date: Sun, 24 Apr 2022 18:42:59 -0700 Subject: [PATCH 1/2] fixing path error in GETTING_STARTED.MD --- GETTING_STARTED.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 5125d88a..d69b5b40 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -76,7 +76,7 @@ machineB.save() ### Training & Evaluation in Command Line -We provide a script in "medsegpy/train_net.py", that is made to train +We provide a script in "tools/train_net.py", that is made to train all the configs provided in medsegpy. You may want to use it as a reference to write your own training script for new research. From 73f66f598cc7a6f7d76dcb532aa302f57e3000ac Mon Sep 17 00:00:00 2001 From: Anthony Gatti Date: Wed, 13 Jul 2022 22:21:40 -0700 Subject: [PATCH 2/2] Add softmax, multiple losses, & single class loss - `config.py` updated to take a list called `LOSS_METRICS`. This is a list of extra loss metrics to run during train/val steps. Each item in the list is structured as: [[loss_type, activation], weights]. The first item is a list that mimics the current `cfg.LOSS` input into `build_loss` and the second item is the weights. - `losses.py` now explicitly includes "softmax" version of "avg_dice_no_reduce" `("avg_dice_no_reduce", "softmax")` which calls `DiceLoss`. Though, this might be redundant, becuase it does the exact same thing as calling `("avg_dice_no_reduce", "sigmoid")`. From this standpoint, it might be useful to break `LOSS` into the actual loss part (`'avg_dice_no_reduce'`) and the activation (`'softmax'`/`'sigmoid'`). The current implementation was a bit confusing to me - I thought I had to pass the `softmax` as the activation to `DiceLoss` which caused issues. Breaking it up into these parts would be clearer. - `losses.py` got a new function that creates a one-hot-encoded set of `weights` if an integer is inputted instead of a list of weights. This is useful it the goal is to just find the loss of a single tissue/class during training. - `build_loss` was updated to enable building the additional loss functions mentioned in `config.py` above. - `trainer.py` was updated so that it builds a list of the loss metrics mentioned in `config.py` and enabled by the updates to `build_loss`. - `reduce_tensor` in `utils.py` was updated so that if the reduce was `'none'` and the weights were one-hot encoded, it scales the weights value so that it makes sense. Otherwise, when keras does its built in reduce it averages over all of the zero dims making the loss seem artificially low (by a factor of the number of categories/classes in the loss). --- medsegpy/config.py | 4 ++++ medsegpy/engine/trainer.py | 9 ++++++- medsegpy/loss/utils.py | 4 ++++ medsegpy/losses.py | 48 ++++++++++++++++++++++++++++++++++---- 4 files changed, 59 insertions(+), 6 deletions(-) diff --git a/medsegpy/config.py b/medsegpy/config.py index 44d0e9ca..3809d42d 100644 --- a/medsegpy/config.py +++ b/medsegpy/config.py @@ -70,6 +70,10 @@ class Config(object): # Class name for robust loss computation ROBUST_LOSS_NAME = "" ROBUST_LOSS_STEP_SIZE = 1e-1 + # Additonal loss functions to run during training + # [[(id_1, output_mode_1), class_weights_1], + # [(id_2, output_mode_2), class_weights_2] ... ] + LOSS_METRICS = [] # PIDS to include, None = all pids PIDS = None diff --git a/medsegpy/engine/trainer.py b/medsegpy/engine/trainer.py index e7948ece..690121d9 100644 --- a/medsegpy/engine/trainer.py +++ b/medsegpy/engine/trainer.py @@ -176,7 +176,14 @@ def _train_model(self): # TODO: Add more options for metrics. optimizer = solver.build_optimizer(cfg) loss_func = self.build_loss() - metrics = [lr_callback(optimizer), dice_loss] + + loss_metrics = [] + if len(cfg.LOSS_METRICS) > 0: + for loss_idx, loss_metric in enumerate(cfg.LOSS_METRICS): + new_metric = build_loss(cfg, build_additional_metric=True, additional_metric=loss_metric) + new_metric.name = f'{loss_metric[0][0]}_{loss_idx}' + loss_metrics.append(new_metric) + metrics = [lr_callback(optimizer), dice_loss] + loss_metrics callbacks = self.build_callbacks() if isinstance(loss_func, kc.Callback): diff --git a/medsegpy/loss/utils.py b/medsegpy/loss/utils.py index c2f5d69d..5406d3d3 100644 --- a/medsegpy/loss/utils.py +++ b/medsegpy/loss/utils.py @@ -65,6 +65,10 @@ def reduce_tensor(x, reduction="mean", axis=None, weights=None): use_weights = weights is not None if use_weights: x *= weights + if (reduction in ("none", None)) and (len(tf.where(weights==0)) == (len(weights) - 1)): + # if one of the weights = 1 and rest = 0, then only want loss of that single value + # need to scale by factor len(weights) because final reduction is a mean + return x * len(weights) if reduction == "mean" and use_weights: ndim = K.ndim(x) diff --git a/medsegpy/losses.py b/medsegpy/losses.py index 7ee28f02..ff0d43a0 100755 --- a/medsegpy/losses.py +++ b/medsegpy/losses.py @@ -17,6 +17,7 @@ AVG_DICE_LOSS = ("avg_dice", "sigmoid") AVG_DICE_LOSS_SOFTMAX = ("avg_dice", "softmax") AVG_DICE_NO_REDUCE = ("avg_dice_no_reduce", "sigmoid") +AVG_DICE_NO_REDUCE_SOFTMAX = ("avg_dice_no_reduce", "softmax") WEIGHTED_CROSS_ENTROPY_LOSS = ("weighted_cross_entropy", "softmax") WEIGHTED_CROSS_ENTROPY_SIGMOID_LOSS = ("weighted_cross_entropy_sigmoid", "sigmoid") @@ -36,6 +37,7 @@ "AVG_DICE_LOSS", "AVG_DICE_LOSS_SOFTMAX", "AVG_DICE_NO_REDUCE", + "AVG_DICE_NO_REDUCE_SOFTMAX", "WEIGHTED_CROSS_ENTROPY_LOSS", "WEIGHTED_CROSS_ENTROPY_SIGMOID_LOSS", "BINARY_CROSS_ENTROPY_LOSS", @@ -46,11 +48,30 @@ ] -def build_loss(cfg): - loss = cfg.LOSS +def build_loss(cfg, build_additional_metric=False, additional_metric: list = None): + if build_additional_metric is False: + loss = cfg.LOSS + robust_loss_cls = cfg.ROBUST_LOSS_NAME + robust_step_size = cfg.ROBUST_LOSS_STEP_SIZE + class_weights = cfg.CLASS_WEIGHTS + elif build_additional_metric is True: + loss = additional_metric[0] + # yaml giving trouble importing list of tuples - need to conver manually? + if type(loss) == list: + loss = tuple(loss) + class_weights = additional_metric[1] + # not supporting robust loss for additional metrics (for now). + robust_loss_cls = False + robust_step_size = None + num_classes = len(cfg.CATEGORIES) - robust_loss_cls = cfg.ROBUST_LOSS_NAME - robust_step_size = cfg.ROBUST_LOSS_STEP_SIZE + + # allow config to specify weights as integer indicating we only want + # to test one of the classes. + if type(class_weights) in (list, tuple): + pass + elif type(class_weights) is int: + class_weights = get_class_weights_from_int(class_weights, num_classes) if robust_loss_cls: reduction = "class" @@ -64,7 +85,7 @@ def build_loss(cfg): pass loss = get_training_loss( loss, - weights=cfg.CLASS_WEIGHTS, + weights=class_weights, # Remove computation on the background class. remove_background=cfg.INCLUDE_BACKGROUND, reduce=reduction, @@ -79,6 +100,12 @@ def build_loss(cfg): else: raise ValueError(f"{robust_loss_cls} not supported") +def get_class_weights_from_int(label, num_classes): + """Returns class_weights for an integer label.""" + class_weights = [0] * num_classes + class_weights[label] = 1 + return class_weights + # TODO (arjundd): Add ability to exclude specific indices from loss function. def get_training_loss_from_str(loss_str: str): @@ -91,6 +118,8 @@ def get_training_loss_from_str(loss_str: str): return AVG_DICE_LOSS elif loss_str == "AVG_DICE_NO_REDUCE": return AVG_DICE_NO_REDUCE + elif loss_str == "AVG_DICE_NO_REDUCE_SOFTMAX": + return AVG_DICE_NO_REDUCE_SOFTMAX elif loss_str == "WEIGHTED_CROSS_ENTROPY_LOSS": return WEIGHTED_CROSS_ENTROPY_LOSS elif loss_str == "WEIGHTED_CROSS_ENTROPY_SIGMOID_LOSS": @@ -134,6 +163,15 @@ def get_training_loss(loss, **kwargs): kwargs.pop("reduce", None) kwargs["reduction"] = "none" return DiceLoss(**kwargs) + elif loss == AVG_DICE_NO_REDUCE_SOFTMAX: + # Below is actually the same as the above, we could/should amalgamate? + kwargs.pop("reduce", None) + kwargs["reduction"] = "none" + # we don't need to add the softmax activation here - + # it should already be added here: + # (https://github.com/ad12/MedSegPy/blob/0c316baaf040c22d562940a198a0e48eef2d36a8/medsegpy/modeling/meta_arch/unet.py#L152) + # kwargs["activation"] = "softmax" + return DiceLoss(**kwargs) else: raise ValueError("Loss type not supported")