Skip to content

Commit

Permalink
6409 support class indices y_pred DiceHelper (#6412)
Browse files Browse the repository at this point in the history
Fixes #6409


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli authored Apr 21, 2023
1 parent 4404075 commit 5312a68
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 15 deletions.
6 changes: 5 additions & 1 deletion monai/handlers/mean_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
self,
include_background: bool = True,
reduction: MetricReduction | str = MetricReduction.MEAN,
num_classes: int | None = None,
output_transform: Callable = lambda x: x,
save_details: bool = True,
) -> None:
Expand All @@ -38,6 +39,9 @@ def __init__(
reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
num_classes: number of input channels (always including the background). When this is None,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.
Expand All @@ -50,5 +54,5 @@ def __init__(
See also:
:py:meth:`monai.metrics.meandice.compute_dice`
"""
metric_fn = DiceMetric(include_background=include_background, reduction=reduction)
metric_fn = DiceMetric(include_background=include_background, reduction=reduction, num_classes=num_classes)
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)
44 changes: 34 additions & 10 deletions monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class DiceMetric(CumulativeIterationMetric):
ignore_empty: whether to ignore empty ground truth cases during calculation.
If `True`, NaN value will be set for empty ground truth cases.
If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
num_classes: number of input channels (always including the background). When this is None,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
"""

Expand All @@ -56,18 +59,21 @@ def __init__(
reduction: MetricReduction | str = MetricReduction.MEAN,
get_not_nans: bool = False,
ignore_empty: bool = True,
num_classes: int | None = None,
) -> None:
super().__init__()
self.include_background = include_background
self.reduction = reduction
self.get_not_nans = get_not_nans
self.ignore_empty = ignore_empty
self.num_classes = num_classes
self.dice_helper = DiceHelper(
include_background=self.include_background,
reduction=MetricReduction.NONE,
get_not_nans=False,
softmax=False,
ignore_empty=self.ignore_empty,
num_classes=self.num_classes,
)

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
Expand Down Expand Up @@ -110,20 +116,26 @@ def aggregate(


def compute_dice(
y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_empty: bool = True
y_pred: torch.Tensor,
y: torch.Tensor,
include_background: bool = True,
ignore_empty: bool = True,
num_classes: int | None = None,
) -> torch.Tensor:
"""Computes Dice score metric for a batch of predictions.
Args:
y_pred: input data to compute, typical segmentation model output.
It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
should be binarized.
`y_pred` can be single-channel class indices or in the one-hot format.
y: ground truth to compute mean dice metric. `y` can be single-channel class indices or in the one-hot format.
include_background: whether to skip Dice computation on the first channel of
the predicted output. Defaults to True.
ignore_empty: whether to ignore empty ground truth cases during calculation.
If `True`, NaN value will be set for empty ground truth cases.
If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
num_classes: number of input channels (always including the background). When this is None,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
Returns:
Dice scores per batch and per class, (shape: [batch_size, num_classes]).
Expand All @@ -135,13 +147,14 @@ def compute_dice(
get_not_nans=False,
softmax=False,
ignore_empty=ignore_empty,
num_classes=num_classes,
)(y_pred=y_pred, y=y)


class DiceHelper:
"""
Compute Dice score between two tensors `y_pred` and `y`.
`y_pred` must have N channels, `y` can be single-channel class indices or in the one-hot format.
`y_pred` and `y` can be single-channel class indices or in the one-hot format.
Example:
Expand Down Expand Up @@ -170,6 +183,7 @@ def __init__(
get_not_nans: bool = True,
reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
ignore_empty: bool = True,
num_classes: int | None = None,
) -> None:
"""
Expand All @@ -186,6 +200,9 @@ def __init__(
reduction: define mode of reduction to the metrics
ignore_empty: if `True`, NaN value will be set for empty ground truth cases.
If `False`, 1 will be set if the Union of ``y_pred`` and ``y`` is empty.
num_classes: number of input channels (always including the background). When this is None,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
"""
self.sigmoid = sigmoid
self.reduction = reduction
Expand All @@ -194,6 +211,7 @@ def __init__(
self.softmax = not sigmoid if softmax is None else softmax
self.activate = activate
self.ignore_empty = ignore_empty
self.num_classes = num_classes

def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
""""""
Expand All @@ -211,17 +229,23 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
"""
Args:
y_pred: input predictions with shape (batch_size, num_classes, spatial_dims...).
the number of channels is inferred from ``y_pred.shape[1]``.
y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
"""
n_pred_ch = y_pred.shape[1]

if self.softmax:
_softmax, _sigmoid = self.softmax, self.sigmoid
if self.num_classes is None:
n_pred_ch = y_pred.shape[1] # y_pred is in one-hot format or multi-channel scores
else:
n_pred_ch = self.num_classes
if y_pred.shape[1] == 1 and self.num_classes > 1: # y_pred is single-channel class indices
_softmax = _sigmoid = False

if _softmax:
if n_pred_ch > 1:
y_pred = torch.argmax(y_pred, dim=1, keepdim=True)

elif self.sigmoid:
elif _sigmoid:
if self.activate:
y_pred = torch.sigmoid(y_pred)
y_pred = y_pred > 0.5
Expand Down
8 changes: 6 additions & 2 deletions tests/test_compute_meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,12 +207,16 @@ def test_helper(self, input_data, _unused):
result = DiceHelper(softmax=True, get_not_nans=False)(**vals)
np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0], atol=1e-4)

num_classes = vals["y_pred"].shape[1]
vals["y_pred"] = torch.argmax(vals["y_pred"], dim=1, keepdim=True)
result = DiceHelper(sigmoid=True, num_classes=num_classes)(**vals)
np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4)

# DiceMetric class tests
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_10])
def test_value_class(self, input_data, expected_value):
# same test as for compute_dice
vals = {}
vals["y_pred"] = input_data.pop("y_pred")
vals = {"y_pred": input_data.pop("y_pred")}
vals["y"] = input_data.pop("y")
dice_metric = DiceMetric(**input_data)
dice_metric(**vals)
Expand Down
26 changes: 24 additions & 2 deletions tests/test_handler_mean_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class TestHandlerMeanDice(unittest.TestCase):
def test_compute(self, input_params, expected_avg, details_shape):
dice_metric = MeanDice(**input_params)

# set up engine

def _val_func(engine, batch):
pass

Expand Down Expand Up @@ -71,6 +69,30 @@ def test_shape_mismatch(self, input_params, _expected_avg, _details_shape):
y = torch.ones((3, 2))
dice_metric.update([y_pred, y])

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_compute_n_class(self, input_params, expected_avg, details_shape):
dice_metric = MeanDice(num_classes=2, **input_params)

def _val_func(engine, batch):
pass

engine = Engine(_val_func)
dice_metric.attach(engine=engine, name="mean_dice")
# test input a list of channel-first tensor
y_pred = [torch.Tensor([[1]]), torch.Tensor([[0]])]
y = torch.Tensor([[[0], [1]], [[0], [1]]])
engine.state.output = {"pred": y_pred, "label": y}
engine.fire_event(Events.ITERATION_COMPLETED)

y_pred = [torch.Tensor([[1]]), torch.Tensor([[0]])] # class indices y_pred
y = torch.Tensor([[[1]], [[0]]]) # class indices y
engine.state.output = {"pred": y_pred, "label": y}
engine.fire_event(Events.ITERATION_COMPLETED)

engine.fire_event(Events.EPOCH_COMPLETED)
assert_allclose(engine.state.metrics["mean_dice"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False)
self.assertTupleEqual(tuple(engine.state.metric_details["mean_dice"].shape), details_shape)


if __name__ == "__main__":
unittest.main()

0 comments on commit 5312a68

Please sign in to comment.