Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 111 additions & 12 deletions monai/metrics/surface_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
import numpy as np
import torch

from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing
from monai.metrics.utils import (
compute_voronoi_regions_fast,
do_metric_reduction,
get_edge_surface_distance,
ignore_background,
prepare_spacing,
)
from monai.utils import MetricReduction, convert_data_type

from .metric import CumulativeIterationMetric
Expand All @@ -34,6 +40,19 @@ class SurfaceDistanceMetric(CumulativeIterationMetric):

Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.


The ``per_component=True`` approach computes the Surface Distance on a per-connected component basis in the ground
truth segmentation. This ensures that each component contributes equally to the final metric, regardless of its size.
Traditional Surface Distance can be dominated by large structures, but the per-component method gives a more
balanced evaluation, particularly for small or fragmented objects. This provides a granular assessment of segmentation
quality, which is especially important in cases with multiple disconnected foreground components.
Note:
    - The input prediction (`y_pred`) and ground truth (`y`) must both have 2 channels (foreground/background),
      with binary segmentation (0 for background, 1 for foreground). That is, this assumes the shape of both
      prediction and ground truth is B2HW[D].
    - This method cannot be used with multiclass segmentation.

For more information, refer to the original paper: https://arxiv.org/abs/2410.18684

Args:
include_background: whether to include distance computation on the first channel of
the predicted output. Defaults to ``False``.
Expand All @@ -46,6 +65,7 @@ class SurfaceDistanceMetric(CumulativeIterationMetric):
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.
per_component: whether to compute the Surface Distance on a per-connected component basis. Defaults to ``False``.

"""

Expand All @@ -56,13 +76,15 @@ def __init__(
distance_metric: str = "euclidean",
reduction: MetricReduction | str = MetricReduction.MEAN,
get_not_nans: bool = False,
per_component: bool = False,
) -> None:
super().__init__()
self.include_background = include_background
self.distance_metric = distance_metric
self.symmetric = symmetric
self.reduction = reduction
self.get_not_nans = get_not_nans
self.per_component = per_component

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor:  # type: ignore[override]
    """
    Compute the per-(batch, channel) surface distance table for one iteration.

    Args:
        y_pred: predicted binarized segmentation, shape [B, C, H, W] or [B, C, H, W, D].
        y: ground-truth binarized segmentation with the same layout as ``y_pred``.
        kwargs: may carry ``spacing``, forwarded to the distance computation.

    Raises:
        ValueError: if ``y_pred`` has fewer than three dimensions, or if
            ``per_component=True`` and the inputs are not matching two-channel
            binary tensors of shape B2HW[D].
    """
    if y_pred.dim() < 3:
        raise ValueError("y_pred should have at least three dimensions.")

    if self.per_component:
        # per-component mode only supports matching binary (two-channel) 2D/3D batches;
        # identical shapes already imply identical rank, so one shape comparison suffices
        inputs_ok = (
            y_pred.shape == y.shape
            and y_pred.ndim in (4, 5)
            and y_pred.shape[1] == 2
            and y.shape[1] == 2
        )
        if not inputs_ok:
            raise ValueError(
                "per_component requires matching 4D/5D binary tensors "
                "(B, 2, H, W) or (B, 2, D, H, W). "
                f"Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}."
            )
    # compute (BxC) for each channel for each batch
    return compute_average_surface_distance(
        y_pred=y_pred,
        y=y,
        include_background=self.include_background,
        symmetric=self.symmetric,
        distance_metric=self.distance_metric,
        spacing=kwargs.get("spacing"),
        per_component=self.per_component,
    )

def aggregate(
Expand Down Expand Up @@ -127,6 +159,7 @@ def compute_average_surface_distance(
symmetric: bool = False,
distance_metric: str = "euclidean",
spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None,
per_component: bool = False,
) -> torch.Tensor:
"""
This function is used to compute the Average Surface Distance from `y_pred` to `y`
Expand Down Expand Up @@ -154,6 +187,7 @@ def compute_average_surface_distance(
If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,
else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used
for all images in batch. Defaults to ``None``.
per_component: whether to compute the Surface Distance on a per-connected component basis. Defaults to ``False``.
"""

if not include_background:
Expand All @@ -172,15 +206,80 @@ def compute_average_surface_distance(
spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)

for b, c in np.ndindex(batch_size, n_class):
_, distances, _ = get_edge_surface_distance(
y_pred[b, c],
y[b, c],
distance_metric=distance_metric,
spacing=spacing_list[b],
symmetric=symmetric,
class_index=c,
)
surface_distance = torch.cat(distances)
asd[b, c] = torch.tensor(np.nan) if surface_distance.shape == (0,) else surface_distance.mean()
if per_component:
pred_empty = y_pred[b, c].sum() == 0
label_empty = y[b, c].sum() == 0
if pred_empty or label_empty:
asd[b, c] = 0.0 if (pred_empty and label_empty) else float("nan")
continue
Comment thread
coderabbitai[bot] marked this conversation as resolved.
cc_assignment = compute_voronoi_regions_fast(y[b, c].cpu().numpy())
if cc_assignment.device != y_pred[b, c].device:
cc_assignment = cc_assignment.to(y_pred[b, c].device)
component_scores = []
for cc_id in torch.unique(cc_assignment.view(-1)):
cc_mask = cc_assignment == cc_id
coords = torch.nonzero(cc_mask, as_tuple=False)
min_corner_idx = coords.min(dim=0).values
max_corner_idx = coords.max(dim=0).values

crop_pred = (
y_pred[b, c][
min_corner_idx[0] : max_corner_idx[0] + 1,
min_corner_idx[1] : max_corner_idx[1] + 1,
min_corner_idx[2] : max_corner_idx[2] + 1,
]
if y_pred.ndim == 5
else y_pred[b, c][
min_corner_idx[0] : max_corner_idx[0] + 1, min_corner_idx[1] : max_corner_idx[1] + 1
]
)

crop_label = (
y[b, c][
min_corner_idx[0] : max_corner_idx[0] + 1,
min_corner_idx[1] : max_corner_idx[1] + 1,
min_corner_idx[2] : max_corner_idx[2] + 1,
]
if y.ndim == 5
else y[b, c][min_corner_idx[0] : max_corner_idx[0] + 1, min_corner_idx[1] : max_corner_idx[1] + 1]
)

cc_crop_mask = (
cc_mask[
min_corner_idx[0] : max_corner_idx[0] + 1,
min_corner_idx[1] : max_corner_idx[1] + 1,
min_corner_idx[2] : max_corner_idx[2] + 1,
]
if y_pred.ndim == 5
else cc_mask[min_corner_idx[0] : max_corner_idx[0] + 1, min_corner_idx[1] : max_corner_idx[1] + 1]
)

pred_masked = crop_pred * cc_crop_mask
label_masked = crop_label * cc_crop_mask

_, distances, _ = get_edge_surface_distance(
pred_masked,
label_masked,
distance_metric=distance_metric,
spacing=spacing_list[b],
symmetric=symmetric,
class_index=c,
)
surface_distance = torch.cat(distances)
component_scores.append(
torch.tensor(np.nan) if surface_distance.shape == (0,) else surface_distance.mean()
)
asd[b, c] = torch.nanmean(torch.stack(component_scores)) if component_scores else 0.0
else:
_, distances, _ = get_edge_surface_distance(
y_pred[b, c],
y[b, c],
distance_metric=distance_metric,
spacing=spacing_list[b],
symmetric=symmetric,
class_index=c,
)
surface_distance = torch.cat(distances)
asd[b, c] = torch.tensor(np.nan) if surface_distance.shape == (0,) else surface_distance.mean()

return convert_data_type(asd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]
54 changes: 54 additions & 0 deletions tests/metrics/test_surface_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,46 @@ def create_spherical_seg_3d(
]


# Fixture cases for the per-connected-component surface distance (per_component=True).
# Each entry is [[y_pred, y], expected]: the tensors are two-channel one-hot pairs
# (channel 0 = background, channel 1 = foreground) of shape B2HW[D]; `expected` is
# the per-batch metric table (foreground channel only, include_background=False).
TEST_CASES_CC_METRICS = []
# Case 1: prediction and ground truth both empty in every batch item -> metric is 0.0.
y = torch.zeros((2, 2, 32, 32, 32), device=_device)
y_pred = torch.zeros((2, 2, 32, 32, 32), device=_device)
TEST_CASES_CC_METRICS.append([[y_pred, y], [[0.0], [0.0]]])

# Case 2: non-empty prediction against an empty ground truth for batch item 0 -> NaN;
# batch item 1 stays empty/empty -> 0.0.
y = torch.zeros((2, 2, 32, 32, 32), device=_device)
y_pred = torch.zeros((2, 2, 32, 32, 32), device=_device)
y_pred[0, 1, 5:10, 5:10, 5:10] = 1
y_pred[0, 0] = 1 - y_pred[0, 1]
TEST_CASES_CC_METRICS.append([[y_pred, y], [[float("nan")], [0.0]]])

# Case 3: prediction exactly matches the single ground-truth component -> distance 0.0.
y = torch.zeros((2, 2, 32, 32, 32), device=_device)
y_pred = torch.zeros((2, 2, 32, 32, 32), device=_device)
y[0, 1, 10:15, 10:15, 10:15] = 1
y[0, 0] = 1 - y[0, 1]
y_pred[0, 1, 10:15, 10:15, 10:15] = 1
y_pred[0, 0] = 1 - y_pred[0, 1]
TEST_CASES_CC_METRICS.append([[y_pred, y], [[0.0], [0.0]]])

# Case 4 (3D): two disconnected ground-truth cubes; each predicted cube is shifted,
# so both components contribute equally to the precomputed expected average (3.7987).
y = torch.zeros((2, 2, 32, 32, 32), device=_device)
y_pred = torch.zeros((2, 2, 32, 32, 32), device=_device)
y[0, 1, 10:15, 10:15, 10:15] = 1
y[0, 1, 20:25, 20:25, 20:25] = 1
y[0, 0] = 1 - y[0, 1]
y_pred[0, 1, 11:16, 10:15, 10:15] = 1
y_pred[0, 1, 11:16, 19:24, 20:25] = 1
y_pred[0, 0] = 1 - y_pred[0, 1]
TEST_CASES_CC_METRICS.append([[y_pred, y], [[3.7987], [0.0]]])

# Case 5 (2D): two components, one matched exactly and one shifted; the per-component
# average over both yields the precomputed expected value (0.4504).
y = torch.zeros((2, 2, 32, 32), device=_device)
y_pred = torch.zeros((2, 2, 32, 32), device=_device)
y[0, 1, 10:15, 10:15] = 1
y[0, 1, 20:25, 20:25] = 1
y[0, 0] = 1 - y[0, 1]
y_pred[0, 1, 10:15, 10:15] = 1
y_pred[0, 1, 21:26, 19:24] = 1
y_pred[0, 0] = 1 - y_pred[0, 1]
TEST_CASES_CC_METRICS.append([[y_pred, y], [[0.4504], [0.0]]])


class TestAllSurfaceMetrics(unittest.TestCase):

@parameterized.expand(TEST_CASES)
Expand Down Expand Up @@ -181,6 +221,20 @@ def test_nans(self, input_data):
np.testing.assert_allclose(0, result, rtol=1e-5)
np.testing.assert_allclose(0, not_nans, rtol=1e-5)

@parameterized.expand(TEST_CASES_CC_METRICS)
def test_cc_metrics(self, input_data, expected_value):
    """Check per-component surface distance against precomputed expected tables.

    Args:
        input_data: ``[y_pred, y]`` pair of B2HW[D] one-hot tensors.
        expected_value: expected per-batch metric table (NaN allowed; treated as
            equal by ``assert_allclose``).
    """
    [seg_1, seg_2] = input_data
    # torch.as_tensor is a no-op for tensor inputs; torch.tensor(tensor) would
    # copy and emit the "copy construct from a tensor" UserWarning.
    seg_1 = torch.as_tensor(seg_1)
    seg_2 = torch.as_tensor(seg_2)
    sd_metric = SurfaceDistanceMetric(per_component=True)
    sd_metric(seg_1, seg_2)
    result = sd_metric.aggregate(reduction="none")
    np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

def test_channel_dimensions(self):
    """Inputs whose channel dimension is not 2 must be rejected in per-component mode."""
    pred = torch.ones([3, 3, 144, 144])
    label = torch.ones([3, 3, 144, 144])
    metric = SurfaceDistanceMetric(per_component=True)
    with self.assertRaises(ValueError):
        metric(pred, label)


if __name__ == "__main__":
unittest.main()
Loading