Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extend SurfaceDiceMetric for 3D images #6549

Merged
merged 6 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions monai/metrics/surface_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,14 @@

class SurfaceDiceMetric(CumulativeIterationMetric):
"""
Computes the Normalized Surface Distance (NSD) for each batch sample and class of
Computes the Normalized Surface Dice (NSD) for each batch sample and class of
predicted segmentations `y_pred` and corresponding reference segmentations `y` according to equation :eq:`nsd`.
This implementation supports 2D images. For 3D images, please refer to DeepMind's implementation
https://github.com/deepmind/surface-distance.
This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images.
Be aware that the computation of boundaries is different from DeepMind's implementation
https://github.com/deepmind/surface-distance. In this implementation, the length/area of a segmentation boundary is
interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary
depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430).
This issue is discussed here: https://github.com/Project-MONAI/MONAI/issues/4103.

The class- and batch sample-wise NSD values can be aggregated with the function `aggregate`.

Expand Down Expand Up @@ -79,9 +83,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any)
r"""
Args:
y_pred: Predicted segmentation, typically segmentation model output.
It must be a one-hot encoded, batch-first tensor [B,C,H,W].
It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].
y: Reference segmentation.
It must be a one-hot encoded, batch-first tensor [B,C,H,W].
It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].
kwargs: additional parameters, e.g. ``spacing`` should be passed to correctly compute the metric.
``spacing``: spacing of pixel (or voxel). This parameter is relevant only
if ``distance_metric`` is set to ``"euclidean"``.
Expand Down Expand Up @@ -168,17 +172,17 @@ def compute_surface_dice(
will be returned for this class. In the case of a class being present in only one of predicted segmentation or
reference segmentation, the class NSD will be 0.

This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D images.
This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images.
Be aware that the computation of boundaries is different from DeepMind's implementation
https://github.com/deepmind/surface-distance. In this implementation, the length of a segmentation boundary is
interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary
depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430).

Args:
y_pred: Predicted segmentation, typically segmentation model output.
It must be a one-hot encoded, batch-first tensor [B,C,H,W].
It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].
y: Reference segmentation.
It must be a one-hot encoded, batch-first tensor [B,C,H,W].
It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].
class_thresholds: List of class-specific thresholds.
The thresholds relate to the acceptable amount of deviation in the segmentation boundary in pixels.
Each threshold needs to be a finite, non-negative number.
Expand Down Expand Up @@ -215,8 +219,8 @@ def compute_surface_dice(
if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
raise ValueError("y_pred and y must be PyTorch Tensor.")

if y_pred.ndimension() != 4 or y.ndimension() != 4:
raise ValueError("y_pred and y should have four dimensions: [B,C,H,W].")
if y_pred.ndimension() not in (4, 5) or y.ndimension() not in (4, 5):
raise ValueError("y_pred and y should be one-hot encoded: [B,C,H,W] or [B,C,H,W,D].")

if y_pred.shape != y.shape:
raise ValueError(
Expand Down
51 changes: 49 additions & 2 deletions tests/test_surface_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,53 @@ def test_tolerance_euclidean_distance(self):
np.testing.assert_array_equal(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0))
np.testing.assert_equal(not_nans.cpu(), torch.tensor(2))

def test_tolerance_euclidean_distance_3d(self):
batch_size = 2
n_class = 2
predictions = torch.zeros((batch_size, 200, 110, 80), dtype=torch.int64, device=_device)
labels = torch.zeros((batch_size, 200, 110, 80), dtype=torch.int64, device=_device)
predictions[0, :, :, 20:] = 1
labels[0, :, :, 30:] = 1 # offset by 10
predictions_hot = F.one_hot(predictions, num_classes=n_class).permute(0, 4, 1, 2, 3)
labels_hot = F.one_hot(labels, num_classes=n_class).permute(0, 4, 1, 2, 3)

sd0 = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True)
res0 = sd0(predictions_hot, labels_hot)
agg0 = sd0.aggregate() # aggregation: nanmean across image then nanmean across batch
sd0_nans = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True, get_not_nans=True)
res0_nans = sd0_nans(predictions_hot, labels_hot)
agg0_nans, not_nans = sd0_nans.aggregate()

np.testing.assert_array_equal(res0.cpu(), res0_nans.cpu())
np.testing.assert_equal(res0.device, predictions.device)
np.testing.assert_array_equal(agg0.cpu(), agg0_nans.cpu())
np.testing.assert_equal(agg0.device, predictions.device)

res1 = SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, labels_hot)
res10 = SurfaceDiceMetric(class_thresholds=[10, 10], include_background=True)(predictions_hot, labels_hot)
res11 = SurfaceDiceMetric(class_thresholds=[11, 11], include_background=True)(predictions_hot, labels_hot)

for res in [res0, res1, res10, res11]:
assert res.shape == torch.Size([2, 2])

assert res0[0, 0] < res1[0, 0] < res10[0, 0]
assert res0[0, 1] < res1[0, 1] < res10[0, 1]
np.testing.assert_array_equal(res10.cpu(), res11.cpu())

expected_res0 = np.zeros((batch_size, n_class))
expected_res0[0, 1] = 1 - (200 * 110 + 198 * 108 + 9 * 200 * 2 + 9 * 108 * 2) / (
200 * 110 * 4 + (58 + 48) * 200 * 2 + (58 + 48) * 108 * 2
)
expected_res0[0, 0] = 1 - (200 * 110 + 198 * 108 + 9 * 200 * 2 + 9 * 108 * 2) / (
200 * 110 * 4 + (28 + 18) * 200 * 2 + (28 + 18) * 108 * 2
)
expected_res0[1, 0] = 1
expected_res0[1, 1] = np.nan
for b, c in np.ndindex(batch_size, n_class):
np.testing.assert_allclose(expected_res0[b, c], res0[b, c].cpu())
np.testing.assert_array_equal(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0))
np.testing.assert_equal(not_nans.cpu(), torch.tensor(2))

def test_tolerance_all_distances(self):
batch_size = 1
n_class = 2
Expand Down Expand Up @@ -262,10 +309,10 @@ def test_asserts(self):
# wrong dimensions
with self.assertRaises(ValueError) as context:
SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions, labels_hot)
self.assertEqual("y_pred and y should have four dimensions: [B,C,H,W].", str(context.exception))
self.assertEqual("y_pred and y should be one-hot encoded: [B,C,H,W] or [B,C,H,W,D].", str(context.exception))
with self.assertRaises(ValueError) as context:
SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, labels)
self.assertEqual("y_pred and y should have four dimensions: [B,C,H,W].", str(context.exception))
self.assertEqual("y_pred and y should be one-hot encoded: [B,C,H,W] or [B,C,H,W,D].", str(context.exception))

# mismatch of shape of input tensors
input_bad_shape = torch.clone(predictions_hot)
Expand Down