Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

4103 enhances surface Dice to use subvoxel borders #6681

Merged
merged 13 commits into from
Jul 5, 2023
79 changes: 46 additions & 33 deletions monai/metrics/surface_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ class SurfaceDiceMetric(CumulativeIterationMetric):
Computes the Normalized Surface Dice (NSD) for each batch sample and class of
predicted segmentations `y_pred` and corresponding reference segmentations `y` according to equation :eq:`nsd`.
This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images.
Be aware that the computation of boundaries is different from DeepMind's implementation
https://github.com/deepmind/surface-distance. In this implementation, the length/area of a segmentation boundary is
Be aware that by default (`use_subvoxels=False`), the computation of boundaries is different from DeepMind's
mplementation https://github.com/deepmind/surface-distance.
In this implementation, the length/area of a segmentation boundary is
interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary
depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430).
This issue is discussed here: https://github.com/Project-MONAI/MONAI/issues/4103.
Expand Down Expand Up @@ -86,7 +87,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any)
It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].
y: Reference segmentation.
It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].
kwargs: additional parameters, e.g. ``spacing`` should be passed to correctly compute the metric.
kwargs: additional parameters: ``spacing`` should be passed to correctly compute the metric.
``spacing``: spacing of pixel (or voxel). This parameter is relevant only
if ``distance_metric`` is set to ``"euclidean"``.
If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers,
Expand All @@ -96,6 +97,8 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any)
If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,
else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used
for all images in batch. Defaults to ``None``.
use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``.


Returns:
Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch
Expand All @@ -108,6 +111,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any)
include_background=self.include_background,
distance_metric=self.distance_metric,
spacing=kwargs.get("spacing"),
use_subvoxels=kwargs.get("use_subvoxels", False),
)

def aggregate(
Expand Down Expand Up @@ -141,13 +145,14 @@ def compute_surface_dice(
include_background: bool = False,
distance_metric: str = "euclidean",
spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None,
use_subvoxels: bool = False,
) -> torch.Tensor:
r"""
This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as
:math:`\hat{Y}`) and `y` (referred to as :math:`Y`). This metric determines which fraction of a segmentation
boundary is correctly predicted. A boundary element is considered correctly predicted if the closest distance to the
reference boundary is smaller than or equal to the specified threshold related to the acceptable amount of deviation in
pixels. The NSD is bounded between 0 and 1.
reference boundary is smaller than or equal to the specified threshold related to the acceptable amount of deviation
in pixels. The NSD is bounded between 0 and 1.

This implementation supports multi-class tasks with an individual threshold :math:`\tau_c` for each class :math:`c`.
The class-specific NSD for batch index :math:`b`, :math:`\operatorname {NSD}_{b,c}`, is computed using the function:
Expand All @@ -159,24 +164,23 @@ def compute_surface_dice(
:label: nsd

with :math:`\mathcal{D}_{Y_{b,c}}` and :math:`\mathcal{D}_{\hat{Y}_{b,c}}` being two sets of nearest-neighbor
distances. :math:`\mathcal{D}_{Y_{b,c}}` is computed from the predicted segmentation boundary towards the reference segmentation
boundary and vice-versa for :math:`\mathcal{D}_{\hat{Y}_{b,c}}`. :math:`\mathcal{D}_{Y_{b,c}}^{'}` and
distances. :math:`\mathcal{D}_{Y_{b,c}}` is computed from the predicted segmentation boundary towards the reference
segmentation boundary and vice-versa for :math:`\mathcal{D}_{\hat{Y}_{b,c}}`. :math:`\mathcal{D}_{Y_{b,c}}^{'}` and
:math:`\mathcal{D}_{\hat{Y}_{b,c}}^{'}` refer to the subsets of distances that are smaller or equal to the
acceptable distance :math:`\tau_c`:

.. math::
\mathcal{D}_{Y_{b,c}}^{'} = \{ d \in \mathcal{D}_{Y_{b,c}} \, | \, d \leq \tau_c \}.


In the case of a class neither being present in the predicted segmentation, nor in the reference segmentation, a nan value
will be returned for this class. In the case of a class being present in only one of predicted segmentation or
reference segmentation, the class NSD will be 0.
In the case of a class neither being present in the predicted segmentation, nor in the reference segmentation,
a nan value will be returned for this class. In the case of a class being present in only one of predicted
segmentation or reference segmentation, the class NSD will be 0.

This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images.
Be aware that the computation of boundaries is different from DeepMind's implementation
https://github.com/deepmind/surface-distance. In this implementation, the length of a segmentation boundary is
interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary
depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430).
The computation of boundaries follows DeepMind's implementation
https://github.com/deepmind/surface-distance when `use_subvoxels=True`; Otherwise the length of a segmentation
boundary is interpreted as the number of its edge pixels.

Args:
y_pred: Predicted segmentation, typically segmentation model output.
Expand All @@ -198,6 +202,7 @@ def compute_surface_dice(
If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,
else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used
for all images in batch. Defaults to ``None``.
use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``.

Raises:
ValueError: If `y_pred` and/or `y` are not PyTorch tensors.
Expand Down Expand Up @@ -227,11 +232,6 @@ def compute_surface_dice(
f"y_pred and y should have same shape, but instead, shapes are {y_pred.shape} (y_pred) and {y.shape} (y)."
)

if not torch.all(y_pred.byte() == y_pred) or not torch.all(y.byte() == y):
raise ValueError("y_pred and y should be binarized tensors (e.g. torch.int64).")
if torch.any(y_pred > 1) or torch.any(y > 1):
raise ValueError("y_pred and y should be one-hot encoded.")

y = y.float()
y_pred = y_pred.float()

Expand All @@ -254,24 +254,37 @@ def compute_surface_dice(
spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)

for b, c in np.ndindex(batch_size, n_class):
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=False)
if not use_subvoxels:
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=True)
distances_pred_gt = get_surface_distance(
edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b]
)
distances_gt_pred = get_surface_distance(
edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b]
)

boundary_complete = len(distances_pred_gt) + len(distances_gt_pred)
boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum(
distances_gt_pred <= class_thresholds[c]
)
else:
_spacing = spacing_list[b] if spacing_list[b] is not None else [1] * img_dim
areas_pred: np.ndarray
areas_gt: np.ndarray
edges_pred, edges_gt, areas_pred, areas_gt = get_mask_edges( # type: ignore
y_pred[b, c], y[b, c], crop=True, spacing=_spacing # type: ignore
)
dist_pred_to_gt = get_surface_distance(edges_pred, edges_gt, distance_metric, spacing=spacing_list[b])
dist_gt_to_pred = get_surface_distance(edges_gt, edges_pred, distance_metric, spacing=spacing_list[b])
areas_gt, areas_pred = areas_gt[edges_gt], areas_pred[edges_pred]
boundary_complete = areas_gt.sum() + areas_pred.sum()
gt_true = areas_gt[dist_gt_to_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0
pred_true = areas_pred[dist_pred_to_gt <= class_thresholds[c]].sum() if len(areas_pred) > 0 else 0.0
boundary_correct = gt_true + pred_true
if not np.any(edges_gt):
warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
if not np.any(edges_pred):
warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")

distances_pred_gt = get_surface_distance(
edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b]
)
distances_gt_pred = get_surface_distance(
edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b]
)

boundary_complete = len(distances_pred_gt) + len(distances_gt_pred)
boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum(
distances_gt_pred <= class_thresholds[c]
)

if boundary_complete == 0:
# the class is neither present in the prediction, nor in the reference segmentation
nsd[b, c] = np.nan
Expand Down
Loading