diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py
index 6061a56deb..ebd6d1aabb 100644
--- a/monai/handlers/mean_dice.py
+++ b/monai/handlers/mean_dice.py
@@ -27,6 +27,7 @@ def __init__(
         self,
         include_background: bool = True,
         reduction: MetricReduction | str = MetricReduction.MEAN,
+        num_classes: int | None = None,
         output_transform: Callable = lambda x: x,
         save_details: bool = True,
     ) -> None:
@@ -38,6 +39,9 @@ def __init__(
             reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
                 available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
                 ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
+            num_classes: number of input channels (always including the background). When this is None,
+                ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
+                single-channel class indices and the number of classes is not automatically inferred from data.
             output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
                 construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
                 lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.
@@ -50,5 +54,5 @@ def __init__(
         See also:
             :py:meth:`monai.metrics.meandice.compute_dice`
         """
-        metric_fn = DiceMetric(include_background=include_background, reduction=reduction)
+        metric_fn = DiceMetric(include_background=include_background, reduction=reduction, num_classes=num_classes)
         super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)
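For context, a minimal sketch of how the new handler argument can be driven, mirroring the `test_compute_n_class` case added further below. This is not part of the change itself; the toy tensors and the expected score are illustrative assumptions (`from_engine` is the stock MONAI output transform):

```python
import torch
from ignite.engine import Engine, Events

from monai.handlers import MeanDice, from_engine

# placeholder process function; outputs are injected manually below
engine = Engine(lambda e, b: None)

# with num_classes=2, single-channel class-index inputs are one-hot encoded internally
metric = MeanDice(num_classes=2, output_transform=from_engine(["pred", "label"]))
metric.attach(engine=engine, name="mean_dice")

# y_pred is a list of channel-first tensors, y is a batch-first tensor;
# both hold class indices rather than one-hot channels
engine.state.output = {"pred": [torch.Tensor([[1]]), torch.Tensor([[0]])], "label": torch.Tensor([[[1]], [[0]]])}
engine.fire_event(Events.ITERATION_COMPLETED)
engine.fire_event(Events.EPOCH_COMPLETED)

print(engine.state.metrics["mean_dice"])  # expected 1.0: predictions match the labels exactly
```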
diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py
index 97a26f6a4e..936ae396be 100644
--- a/monai/metrics/meandice.py
+++ b/monai/metrics/meandice.py
@@ -47,6 +47,9 @@ class DiceMetric(CumulativeIterationMetric):
         ignore_empty: whether to ignore empty ground truth cases during calculation.
             If `True`, NaN value will be set for empty ground truth cases.
             If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
+        num_classes: number of input channels (always including the background). When this is None,
+            ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
+            single-channel class indices and the number of classes is not automatically inferred from data.

     """

@@ -56,18 +59,21 @@ def __init__(
         reduction: MetricReduction | str = MetricReduction.MEAN,
         get_not_nans: bool = False,
         ignore_empty: bool = True,
+        num_classes: int | None = None,
     ) -> None:
         super().__init__()
         self.include_background = include_background
         self.reduction = reduction
         self.get_not_nans = get_not_nans
         self.ignore_empty = ignore_empty
+        self.num_classes = num_classes
         self.dice_helper = DiceHelper(
             include_background=self.include_background,
             reduction=MetricReduction.NONE,
             get_not_nans=False,
             softmax=False,
             ignore_empty=self.ignore_empty,
+            num_classes=self.num_classes,
         )

     def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
@@ -110,20 +116,26 @@ def aggregate(

 def compute_dice(
-    y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_empty: bool = True
+    y_pred: torch.Tensor,
+    y: torch.Tensor,
+    include_background: bool = True,
+    ignore_empty: bool = True,
+    num_classes: int | None = None,
 ) -> torch.Tensor:
     """Computes Dice score metric for a batch of predictions.

     Args:
         y_pred: input data to compute, typical segmentation model output.
-            It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
-            should be binarized.
+            `y_pred` can be single-channel class indices or in the one-hot format.
         y: ground truth to compute mean dice metric. `y` can be single-channel class indices or in the
             one-hot format.
         include_background: whether to skip Dice computation on the first channel of the predicted output.
             Defaults to True.
         ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, NaN value
             will be set for empty ground truth cases. If `False`, 1 will be set if the predictions of empty
             ground truth cases are also empty.
+        num_classes: number of input channels (always including the background). When this is None,
+            ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
+            single-channel class indices and the number of classes is not automatically inferred from data.

     Returns:
         Dice scores per batch and per class, (shape: [batch_size, num_classes]).
@@ -135,13 +147,14 @@ def compute_dice(
         get_not_nans=False,
         softmax=False,
         ignore_empty=ignore_empty,
+        num_classes=num_classes,
     )(y_pred=y_pred, y=y)


 class DiceHelper:
     """
     Compute Dice score between two tensors `y_pred` and `y`.
-    `y_pred` must have N channels, `y` can be single-channel class indices or in the one-hot format.
+    `y_pred` and `y` can be single-channel class indices or in the one-hot format.

     Example:

@@ -170,6 +183,7 @@ def __init__(
         get_not_nans: bool = True,
         reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
         ignore_empty: bool = True,
+        num_classes: int | None = None,
     ) -> None:
         """

@@ -186,6 +200,9 @@ def __init__(
             reduction: define mode of reduction to the metrics
             ignore_empty: if `True`, NaN value will be set for empty ground truth cases.
                 If `False`, 1 will be set if the Union of ``y_pred`` and ``y`` is empty.
+            num_classes: number of input channels (always including the background). When this is None,
+                ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
+                single-channel class indices and the number of classes is not automatically inferred from data.
         """
         self.sigmoid = sigmoid
         self.reduction = reduction
@@ -194,6 +211,7 @@ def __init__(
         self.softmax = not sigmoid if softmax is None else softmax
         self.activate = activate
         self.ignore_empty = ignore_empty
+        self.num_classes = num_classes

     def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         """"""
@@ -211,17 +229,23 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
-            y_pred: input predictions with shape (batch_size, num_classes, spatial_dims...).
-                the number of channels is inferred from ``y_pred.shape[1]``.
+            y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
+                the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
             y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).

         """
-        n_pred_ch = y_pred.shape[1]
-
-        if self.softmax:
+        _softmax, _sigmoid = self.softmax, self.sigmoid
+        if self.num_classes is None:
+            n_pred_ch = y_pred.shape[1]  # y_pred is in one-hot format or multi-channel scores
+        else:
+            n_pred_ch = self.num_classes
+            if y_pred.shape[1] == 1 and self.num_classes > 1:  # y_pred is single-channel class indices
+                _softmax = _sigmoid = False
+
+        if _softmax:
             if n_pred_ch > 1:
                 y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
-        elif self.sigmoid:
+        elif _sigmoid:
             if self.activate:
                 y_pred = torch.sigmoid(y_pred)
             y_pred = y_pred > 0.5
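A small sketch of the new `compute_dice` path with class-index inputs. The toy arrays and the quoted scores are assumptions worked out by hand, not outputs taken from this diff:

```python
import torch
from monai.metrics import compute_dice

# batch of one 2x2 "segmentation"; both inputs are single-channel class indices
y_pred = torch.tensor([[[[0, 1], [2, 2]]]])  # shape (1, 1, 2, 2)
y = torch.tensor([[[[0, 1], [1, 2]]]])       # shape (1, 1, 2, 2)

# without num_classes, shape[1] == 1 would be read as a single channel;
# num_classes=3 tells the helper to one-hot both inputs on the fly
scores = compute_dice(y_pred, y, num_classes=3)
print(scores)  # per-batch, per-class Dice, shape (1, 3): roughly [[1.0000, 0.6667, 0.6667]]
```

Note that the guard in `__call__` only disables the `softmax`/`sigmoid` handling when `y_pred` is single-channel and `num_classes > 1`, so multi-channel score maps keep their previous behavior.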
""" - n_pred_ch = y_pred.shape[1] - - if self.softmax: + _softmax, _sigmoid = self.softmax, self.sigmoid + if self.num_classes is None: + n_pred_ch = y_pred.shape[1] # y_pred is in one-hot format or multi-channel scores + else: + n_pred_ch = self.num_classes + if y_pred.shape[1] == 1 and self.num_classes > 1: # y_pred is single-channel class indices + _softmax = _sigmoid = False + + if _softmax: if n_pred_ch > 1: y_pred = torch.argmax(y_pred, dim=1, keepdim=True) - elif self.sigmoid: + elif _sigmoid: if self.activate: y_pred = torch.sigmoid(y_pred) y_pred = y_pred > 0.5 diff --git a/tests/test_compute_meandice.py b/tests/test_compute_meandice.py index 425ee59183..794e318bfc 100644 --- a/tests/test_compute_meandice.py +++ b/tests/test_compute_meandice.py @@ -207,12 +207,16 @@ def test_helper(self, input_data, _unused): result = DiceHelper(softmax=True, get_not_nans=False)(**vals) np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0], atol=1e-4) + num_classes = vals["y_pred"].shape[1] + vals["y_pred"] = torch.argmax(vals["y_pred"], dim=1, keepdim=True) + result = DiceHelper(sigmoid=True, num_classes=num_classes)(**vals) + np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4) + # DiceMetric class tests @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_10]) def test_value_class(self, input_data, expected_value): # same test as for compute_dice - vals = {} - vals["y_pred"] = input_data.pop("y_pred") + vals = {"y_pred": input_data.pop("y_pred")} vals["y"] = input_data.pop("y") dice_metric = DiceMetric(**input_data) dice_metric(**vals) diff --git a/tests/test_handler_mean_dice.py b/tests/test_handler_mean_dice.py index 10cf981f02..6f91b6d3af 100644 --- a/tests/test_handler_mean_dice.py +++ b/tests/test_handler_mean_dice.py @@ -36,8 +36,6 @@ class TestHandlerMeanDice(unittest.TestCase): def test_compute(self, input_params, expected_avg, details_shape): dice_metric = MeanDice(**input_params) - # set up engine - def _val_func(engine, batch): pass @@ -71,6 +69,30 @@ def test_shape_mismatch(self, input_params, _expected_avg, _details_shape): y = torch.ones((3, 2)) dice_metric.update([y_pred, y]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_compute_n_class(self, input_params, expected_avg, details_shape): + dice_metric = MeanDice(num_classes=2, **input_params) + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + dice_metric.attach(engine=engine, name="mean_dice") + # test input a list of channel-first tensor + y_pred = [torch.Tensor([[1]]), torch.Tensor([[0]])] + y = torch.Tensor([[[0], [1]], [[0], [1]]]) + engine.state.output = {"pred": y_pred, "label": y} + engine.fire_event(Events.ITERATION_COMPLETED) + + y_pred = [torch.Tensor([[1]]), torch.Tensor([[0]])] # class indices y_pred + y = torch.Tensor([[[1]], [[0]]]) # class indices y + engine.state.output = {"pred": y_pred, "label": y} + engine.fire_event(Events.ITERATION_COMPLETED) + + engine.fire_event(Events.EPOCH_COMPLETED) + assert_allclose(engine.state.metrics["mean_dice"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False) + self.assertTupleEqual(tuple(engine.state.metric_details["mean_dice"].shape), details_shape) + if __name__ == "__main__": unittest.main()