From bc6c867185639e46323e2e1527e4d69e9fb34490 Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Wed, 13 May 2026 14:55:42 -0400 Subject: [PATCH 1/3] Adding per component support to Surface Distance metric Signed-off-by: Vijay Vignesh Prasad Rao --- monai/metrics/surface_distance.py | 124 ++++++++++++++++++++++--- tests/metrics/test_surface_distance.py | 54 +++++++++++ 2 files changed, 166 insertions(+), 12 deletions(-) diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 3cb336d6a0..14fe131661 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -17,7 +17,13 @@ import numpy as np import torch -from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing +from monai.metrics.utils import ( + compute_voronoi_regions_fast, + do_metric_reduction, + get_edge_surface_distance, + ignore_background, + prepare_spacing, +) from monai.utils import MetricReduction, convert_data_type from .metric import CumulativeIterationMetric @@ -34,6 +40,19 @@ class SurfaceDistanceMetric(CumulativeIterationMetric): Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. + + The ``per_component=True`` approach computes the Surface Distance on a per-connected component basis in the ground + truth segmentation. This ensures that each component contributes equally to the final metric, regardless of its size. + Traditional Surface Distance can be dominated by large structures, but the per-component method gives a more + balanced evaluation, particularly for small or fragmented objects. This provides a granular assessment of segmentation + quality, which is especially important in cases with multiple disconnected foreground components. + Note: + - The input prediction (`y_pred`) and ground truth (`y`) must both have 2 channels (foreground/background), + with binary segmentation (0 for background, 1 for foreground). That is, this assumes the shape of both prediction + and ground truth is B2HW[D]. + - This method cannot be used with multiclass segmentation. + For more information, refer to the original paper: https://arxiv.org/abs/2410.18684 + Args: include_background: whether to include distance computation on the first channel of the predicted output. Defaults to ``False``. @@ -46,6 +65,7 @@ class SurfaceDistanceMetric(CumulativeIterationMetric): ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. + per_component: whether to compute the Surface Distance on a per-connected component basis. Defaults to ``False``. """ @@ -56,6 +76,7 @@ def __init__( distance_metric: str = "euclidean", reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, + per_component: bool = False, ) -> None: super().__init__() self.include_background = include_background @@ -63,6 +84,7 @@ def __init__( self.symmetric = symmetric self.reduction = reduction self.get_not_nans = get_not_nans + self.per_component = per_component def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] """ @@ -88,7 +110,17 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) """ if y_pred.dim() < 3: raise ValueError("y_pred should have at least three dimensions.") - + if self.per_component: + if y_pred.ndim not in (4, 5) or y.ndim not in (4, 5) or y_pred.shape[1] != 2 or y.shape[1] != 2: + same_rank = y_pred.ndim == y.ndim and y_pred.ndim in (4, 5) + binary_channels = y_pred.shape[1] == 2 and y.shape[1] == 2 + same_shape = y_pred.shape == y.shape + if not (same_rank and binary_channels and same_shape): + raise ValueError( + "per_component requires matching 4D/5D binary tensors " + "(B, 2, H, W) or (B, 2, D, H, W). " + f"Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}." + ) # compute (BxC) for each channel for each batch return compute_average_surface_distance( y_pred=y_pred, @@ -97,6 +129,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) symmetric=self.symmetric, distance_metric=self.distance_metric, spacing=kwargs.get("spacing"), + per_component=self.per_component, ) def aggregate( @@ -127,6 +160,7 @@ def compute_average_surface_distance( symmetric: bool = False, distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None, + per_component: bool = False, ) -> torch.Tensor: """ This function is used to compute the Average Surface Distance from `y_pred` to `y` @@ -154,6 +188,7 @@ def compute_average_surface_distance( If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. + per_component: whether to compute the Surface Distance on a per-connected component basis. Defaults to ``False``. """ if not include_background: @@ -172,15 +207,80 @@ def compute_average_surface_distance( spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim) for b, c in np.ndindex(batch_size, n_class): - _, distances, _ = get_edge_surface_distance( - y_pred[b, c], - y[b, c], - distance_metric=distance_metric, - spacing=spacing_list[b], - symmetric=symmetric, - class_index=c, - ) - surface_distance = torch.cat(distances) - asd[b, c] = torch.tensor(np.nan) if surface_distance.shape == (0,) else surface_distance.mean() + if per_component: + pred_empty = y_pred[b, c].sum() == 0 + label_empty = y[b, c].sum() == 0 + if pred_empty and label_empty: + asd[b, c] = 0.0 if (pred_empty and label_empty) else float("nan") + continue + cc_assignment = compute_voronoi_regions_fast(y[b, c].cpu().numpy()) + if cc_assignment.device != y_pred[b, c].device: + cc_assignment = cc_assignment.to(y_pred[b, c].device) + component_scores = [] + for cc_id in torch.unique(cc_assignment.view(-1)): + cc_mask = cc_assignment == cc_id + coords = torch.nonzero(cc_mask, as_tuple=False) + min_corner_idx = coords.min(dim=0).values + max_corner_idx = coords.max(dim=0).values + + crop_pred = ( + y_pred[b, c][ + min_corner_idx[0] : max_corner_idx[0] + 1, + min_corner_idx[1] : max_corner_idx[1] + 1, + min_corner_idx[2] : max_corner_idx[2] + 1, + ] + if y_pred.ndim == 5 + else y_pred[b, c][ + min_corner_idx[0] : max_corner_idx[0] + 1, min_corner_idx[1] : max_corner_idx[1] + 1 + ] + ) + + crop_label = ( + y[b, c][ + min_corner_idx[0] : max_corner_idx[0] + 1, + min_corner_idx[1] : max_corner_idx[1] + 1, + min_corner_idx[2] : max_corner_idx[2] + 1, + ] + if y.ndim == 5 + else y[b, c][min_corner_idx[0] : max_corner_idx[0] + 1, min_corner_idx[1] : max_corner_idx[1] + 1] + ) + + cc_crop_mask = ( + cc_mask[ + min_corner_idx[0] : max_corner_idx[0] + 1, + min_corner_idx[1] : max_corner_idx[1] + 1, + min_corner_idx[2] : max_corner_idx[2] + 1, + ] + if y_pred.ndim == 5 + else cc_mask[min_corner_idx[0] : max_corner_idx[0] + 1, min_corner_idx[1] : max_corner_idx[1] + 1] + ) + + pred_masked = crop_pred * cc_crop_mask + label_masked = crop_label * cc_crop_mask + + _, distances, _ = get_edge_surface_distance( + pred_masked, + label_masked, + distance_metric=distance_metric, + spacing=spacing_list[b], + symmetric=symmetric, + class_index=c, + ) + surface_distance = torch.cat(distances) + component_scores.append( + torch.tensor(np.nan) if surface_distance.shape == (0,) else surface_distance.mean() + ) + asd[b, c] = torch.nanmean(torch.stack(component_scores)) if component_scores else 0.0 + else: + _, distances, _ = get_edge_surface_distance( + y_pred[b, c], + y[b, c], + distance_metric=distance_metric, + spacing=spacing_list[b], + symmetric=symmetric, + class_index=c, + ) + surface_distance = torch.cat(distances) + asd[b, c] = torch.tensor(np.nan) if surface_distance.shape == (0,) else surface_distance.mean() return convert_data_type(asd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] diff --git a/tests/metrics/test_surface_distance.py b/tests/metrics/test_surface_distance.py index 85db389f80..115afd372d 100644 --- a/tests/metrics/test_surface_distance.py +++ b/tests/metrics/test_surface_distance.py @@ -141,6 +141,46 @@ def create_spherical_seg_3d( ] +TEST_CASES_CC_METRICS = [] +y = torch.zeros((2, 2, 32, 32, 32), device=_device) +y_hat = torch.zeros((2, 2, 32, 32, 32), device=_device) +TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.0], [0.0]]]) + +y = torch.zeros((2, 2, 32, 32, 32), device=_device) +y_hat = torch.zeros((2, 2, 32, 32, 32), device=_device) +y_hat[0, 1, 5:10, 5:10, 5:10] = 1 +y_hat[0, 0] = 1 - y_hat[0, 1] +TEST_CASES_CC_METRICS.append([[y, y_hat], [[float("inf")], [0.0]]]) + +y = torch.zeros((2, 2, 32, 32, 32), device=_device) +y_hat = torch.zeros((2, 2, 32, 32, 32), device=_device) +y[0, 1, 10:15, 10:15, 10:15] = 1 +y[0, 0] = 1 - y[0, 1] +y_hat[0, 1, 10:15, 10:15, 10:15] = 1 +y_hat[0, 0] = 1 - y_hat[0, 1] +TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.0], [0.0]]]) + +y = torch.zeros((2, 2, 32, 32, 32), device=_device) +y_hat = torch.zeros((2, 2, 32, 32, 32), device=_device) +y[0, 1, 10:15, 10:15, 10:15] = 1 +y[0, 1, 20:25, 20:25, 20:25] = 1 +y[0, 0] = 1 - y[0, 1] +y_hat[0, 1, 11:16, 10:15, 10:15] = 1 +y_hat[0, 1, 11:16, 19:24, 20:25] = 1 +y_hat[0, 0] = 1 - y_hat[0, 1] +TEST_CASES_CC_METRICS.append([[y, y_hat], [[3.6829], [0.0]]]) + +y = torch.zeros((2, 2, 32, 32), device=_device) +y_hat = torch.zeros((2, 2, 32, 32), device=_device) +y[0, 1, 10:15, 10:15] = 1 +y[0, 1, 20:25, 20:25] = 1 +y[0, 0] = 1 - y[0, 1] +y_hat[0, 1, 10:15, 10:15] = 1 +y_hat[0, 1, 21:26, 19:24] = 1 +y_hat[0, 0] = 1 - y_hat[0, 1] +TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.4504], [0.0]]]) + + class TestAllSurfaceMetrics(unittest.TestCase): @parameterized.expand(TEST_CASES) @@ -181,6 +221,20 @@ def test_nans(self, input_data): np.testing.assert_allclose(0, result, rtol=1e-5) np.testing.assert_allclose(0, not_nans, rtol=1e-5) + @parameterized.expand(TEST_CASES_CC_METRICS) + def test_cc_metrics(self, input_data, expected_value): + [seg_1, seg_2] = input_data + seg_1 = torch.tensor(seg_1) + seg_2 = torch.tensor(seg_2) + hd_metric = SurfaceDistanceMetric(per_component=True) + hd_metric(seg_1, seg_2) + result = hd_metric.aggregate(reduction="none") + np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) + + def test_channel_dimensions(self): + with self.assertRaises(ValueError): + SurfaceDistanceMetric(per_component=True)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 144, 144])) + if __name__ == "__main__": unittest.main() From a34ca8aef155ff8c82513c969c9ffcbc01cd9340 Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Wed, 13 May 2026 14:58:53 -0400 Subject: [PATCH 2/3] Modifying variable name Signed-off-by: Vijay Vignesh Prasad Rao --- tests/metrics/test_surface_distance.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/metrics/test_surface_distance.py b/tests/metrics/test_surface_distance.py index 115afd372d..a24818a0da 100644 --- a/tests/metrics/test_surface_distance.py +++ b/tests/metrics/test_surface_distance.py @@ -226,9 +226,9 @@ def test_cc_metrics(self, input_data, expected_value): [seg_1, seg_2] = input_data seg_1 = torch.tensor(seg_1) seg_2 = torch.tensor(seg_2) - hd_metric = SurfaceDistanceMetric(per_component=True) - hd_metric(seg_1, seg_2) - result = hd_metric.aggregate(reduction="none") + sd_metric = SurfaceDistanceMetric(per_component=True) + sd_metric(seg_1, seg_2) + result = sd_metric.aggregate(reduction="none") np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) def test_channel_dimensions(self): From a3cf84aed60fddf37ac3336f3cdb646aba9feb8a Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Wed, 13 May 2026 15:22:54 -0400 Subject: [PATCH 3/3] Resolving coderabbitai bugs Signed-off-by: Vijay Vignesh Prasad Rao --- monai/metrics/surface_distance.py | 21 +++++++------- tests/metrics/test_surface_distance.py | 40 +++++++++++++------------- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 14fe131661..4c505c588d 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -111,16 +111,15 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) if y_pred.dim() < 3: raise ValueError("y_pred should have at least three dimensions.") if self.per_component: - if y_pred.ndim not in (4, 5) or y.ndim not in (4, 5) or y_pred.shape[1] != 2 or y.shape[1] != 2: - same_rank = y_pred.ndim == y.ndim and y_pred.ndim in (4, 5) - binary_channels = y_pred.shape[1] == 2 and y.shape[1] == 2 - same_shape = y_pred.shape == y.shape - if not (same_rank and binary_channels and same_shape): - raise ValueError( - "per_component requires matching 4D/5D binary tensors " - "(B, 2, H, W) or (B, 2, D, H, W). " - f"Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}." - ) + same_rank = y_pred.ndim == y.ndim and y_pred.ndim in (4, 5) + binary_channels = y_pred.shape[1] == 2 and y.shape[1] == 2 + same_shape = y_pred.shape == y.shape + if not (same_rank and binary_channels and same_shape): + raise ValueError( + "per_component requires matching 4D/5D binary tensors " + "(B, 2, H, W) or (B, 2, D, H, W). " + f"Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}." + ) # compute (BxC) for each channel for each batch return compute_average_surface_distance( y_pred=y_pred, @@ -210,7 +209,7 @@ def compute_average_surface_distance( if per_component: pred_empty = y_pred[b, c].sum() == 0 label_empty = y[b, c].sum() == 0 - if pred_empty and label_empty: + if pred_empty or label_empty: asd[b, c] = 0.0 if (pred_empty and label_empty) else float("nan") continue cc_assignment = compute_voronoi_regions_fast(y[b, c].cpu().numpy()) diff --git a/tests/metrics/test_surface_distance.py b/tests/metrics/test_surface_distance.py index a24818a0da..91096f10e1 100644 --- a/tests/metrics/test_surface_distance.py +++ b/tests/metrics/test_surface_distance.py @@ -143,42 +143,42 @@ def create_spherical_seg_3d( TEST_CASES_CC_METRICS = [] y = torch.zeros((2, 2, 32, 32, 32), device=_device) -y_hat = torch.zeros((2, 2, 32, 32, 32), device=_device) -TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.0], [0.0]]]) +y_pred = torch.zeros((2, 2, 32, 32, 32), device=_device) +TEST_CASES_CC_METRICS.append([[y_pred, y], [[0.0], [0.0]]]) y = torch.zeros((2, 2, 32, 32, 32), device=_device) -y_hat = torch.zeros((2, 2, 32, 32, 32), device=_device) -y_hat[0, 1, 5:10, 5:10, 5:10] = 1 -y_hat[0, 0] = 1 - y_hat[0, 1] -TEST_CASES_CC_METRICS.append([[y, y_hat], [[float("inf")], [0.0]]]) +y_pred = torch.zeros((2, 2, 32, 32, 32), device=_device) +y_pred[0, 1, 5:10, 5:10, 5:10] = 1 +y_pred[0, 0] = 1 - y_pred[0, 1] +TEST_CASES_CC_METRICS.append([[y_pred, y], [[float("nan")], [0.0]]]) y = torch.zeros((2, 2, 32, 32, 32), device=_device) -y_hat = torch.zeros((2, 2, 32, 32, 32), device=_device) +y_pred = torch.zeros((2, 2, 32, 32, 32), device=_device) y[0, 1, 10:15, 10:15, 10:15] = 1 y[0, 0] = 1 - y[0, 1] -y_hat[0, 1, 10:15, 10:15, 10:15] = 1 -y_hat[0, 0] = 1 - y_hat[0, 1] -TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.0], [0.0]]]) +y_pred[0, 1, 10:15, 10:15, 10:15] = 1 +y_pred[0, 0] = 1 - y_pred[0, 1] +TEST_CASES_CC_METRICS.append([[y_pred, y], [[0.0], [0.0]]]) y = torch.zeros((2, 2, 32, 32, 32), device=_device) -y_hat = torch.zeros((2, 2, 32, 32, 32), device=_device) +y_pred = torch.zeros((2, 2, 32, 32, 32), device=_device) y[0, 1, 10:15, 10:15, 10:15] = 1 y[0, 1, 20:25, 20:25, 20:25] = 1 y[0, 0] = 1 - y[0, 1] -y_hat[0, 1, 11:16, 10:15, 10:15] = 1 -y_hat[0, 1, 11:16, 19:24, 20:25] = 1 -y_hat[0, 0] = 1 - y_hat[0, 1] -TEST_CASES_CC_METRICS.append([[y, y_hat], [[3.6829], [0.0]]]) +y_pred[0, 1, 11:16, 10:15, 10:15] = 1 +y_pred[0, 1, 11:16, 19:24, 20:25] = 1 +y_pred[0, 0] = 1 - y_pred[0, 1] +TEST_CASES_CC_METRICS.append([[y_pred, y], [[3.7987], [0.0]]]) y = torch.zeros((2, 2, 32, 32), device=_device) -y_hat = torch.zeros((2, 2, 32, 32), device=_device) +y_pred = torch.zeros((2, 2, 32, 32), device=_device) y[0, 1, 10:15, 10:15] = 1 y[0, 1, 20:25, 20:25] = 1 y[0, 0] = 1 - y[0, 1] -y_hat[0, 1, 10:15, 10:15] = 1 -y_hat[0, 1, 21:26, 19:24] = 1 -y_hat[0, 0] = 1 - y_hat[0, 1] -TEST_CASES_CC_METRICS.append([[y, y_hat], [[0.4504], [0.0]]]) +y_pred[0, 1, 10:15, 10:15] = 1 +y_pred[0, 1, 21:26, 19:24] = 1 +y_pred[0, 0] = 1 - y_pred[0, 1] +TEST_CASES_CC_METRICS.append([[y_pred, y], [[0.4504], [0.0]]]) class TestAllSurfaceMetrics(unittest.TestCase):