From f5fd1e1d6f7072d91be85f5d49676907f08ec2b2 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 7 Sep 2024 12:37:04 +0200 Subject: [PATCH 01/21] a beginning of something new --- CHANGELOG.md | 6 ++++ .../functional/segmentation/__init__.py | 3 +- .../functional/segmentation/dice.py | 24 +++++++++++++ src/torchmetrics/segmentation/__init__.py | 3 +- src/torchmetrics/segmentation/dice.py | 34 +++++++++++++++++++ 5 files changed, 68 insertions(+), 2 deletions(-) create mode 100644 src/torchmetrics/functional/segmentation/dice.py create mode 100644 src/torchmetrics/segmentation/dice.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b0b0022476..0dd149f3149 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added multi-output support for MAE metric ([#2605](https://github.com/Lightning-AI/torchmetrics/pull/2605)) +- Added Dice metric to segmentation metrics + ### Changed - Tracker higher is better integration ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649)) @@ -31,6 +33,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - update `InfoLM` class to dynamically set `higher_is_better` ([#2674](https://github.com/Lightning-AI/torchmetrics/pull/2674)) +### Deprecated + +- Deprecated Dice from classification metrics + ### Removed diff --git a/src/torchmetrics/functional/segmentation/__init__.py b/src/torchmetrics/functional/segmentation/__init__.py index 3d23192a36a..af693d4478a 100644 --- a/src/torchmetrics/functional/segmentation/__init__.py +++ b/src/torchmetrics/functional/segmentation/__init__.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from torchmetrics.functional.segmentation.dice import dice_score from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score from torchmetrics.functional.segmentation.mean_iou import mean_iou -__all__ = ["generalized_dice_score", "mean_iou"] +__all__ = ["generalized_dice_score", "mean_iou", "dice_score"] diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py new file mode 100644 index 00000000000..d2a1f113b49 --- /dev/null +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -0,0 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch
+from torch import Tensor
+from typing_extensions import Literal
+
+from torchmetrics.functional.segmentation.utils import _ignore_background
+from torchmetrics.utilities.checks import _check_same_shape
+from torchmetrics.utilities.compute import _safe_divide
+
+
+def dice_score():
+    pass
diff --git a/src/torchmetrics/segmentation/__init__.py b/src/torchmetrics/segmentation/__init__.py
index 5b609c2c738..2cd46580c9e 100644
--- a/src/torchmetrics/segmentation/__init__.py
+++ b/src/torchmetrics/segmentation/__init__.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from torchmetrics.segmentation.dice import DiceScore
 from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore
 from torchmetrics.segmentation.mean_iou import MeanIoU
 
-__all__ = ["GeneralizedDiceScore", "MeanIoU"]
+__all__ = ["GeneralizedDiceScore", "MeanIoU", "DiceScore"]
diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py
new file mode 100644
index 00000000000..9cafcd50db6
--- /dev/null
+++ b/src/torchmetrics/segmentation/dice.py
@@ -0,0 +1,34 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Optional, Sequence, Union
+
+import torch
+from torch import Tensor
+from typing_extensions import Literal
+
+from torchmetrics.functional.segmentation.generalized_dice import (
+    _generalized_dice_compute,
+    _generalized_dice_update,
+    _generalized_dice_validate_args,
+)
+from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
+from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
+
+if not _MATPLOTLIB_AVAILABLE:
+    __doctest_skip__ = ["DiceScore.plot"]
+
+
+class DiceScore(Metric):
+    pass

From 0874184ffdb3c5b059ba5cc3fe584f2bf63d2ec8 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Sat, 7 Sep 2024 12:37:59 +0200
Subject: [PATCH 02/21] docs

---
 docs/source/segmentation/dice.rst | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)
 create mode 100644 docs/source/segmentation/dice.rst

diff --git a/docs/source/segmentation/dice.rst b/docs/source/segmentation/dice.rst
new file mode 100644
index 00000000000..c4db784dae2
--- /dev/null
+++ b/docs/source/segmentation/dice.rst
@@ -0,0 +1,22 @@
+.. customcarditem::
+    :header: Dice Score
+    :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
+    :tags: Segmentation
+
+.. include:: ../links.rst
+
+##########
+Dice Score
+##########
+
+Module Interface
+________________
+
+.. autoclass:: torchmetrics.segmentation.DiceScore
+    :noindex:
+
+Functional Interface
+____________________
+
+..
autofunction:: torchmetrics.functional.segmentation.dice_score
+    :noindex:

From e0ea066d5044a8fc564385e41e4e769e8b48fab3 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Sun, 8 Sep 2024 14:41:08 +0200
Subject: [PATCH 03/21] deprecate old implementation

---
 src/torchmetrics/classification/dice.py            | 15 +++++++++++++++
 .../functional/classification/dice.py              | 15 +++++++++++++++
 tests/unittests/test_deprecated.py                 | 14 ++++++++++++++
 3 files changed, 44 insertions(+)

diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py
index 39bc2acabcd..eb6b228778a 100644
--- a/src/torchmetrics/classification/dice.py
+++ b/src/torchmetrics/classification/dice.py
@@ -20,6 +20,7 @@
 from torchmetrics.functional.classification.dice import _dice_compute
 from torchmetrics.functional.classification.stat_scores import _stat_scores_update
 from torchmetrics.metric import Metric
+from torchmetrics.utilities import rank_zero_warn
 from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod
 from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
 from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
@@ -114,6 +115,12 @@ class Dice(Metric):
 
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
+    .. warning::
+        The `dice` metric is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will be
+        removed in v1.7.0. Please instead consider using `f1score` metric from the classification subpackage as it
+        provides the same functionality. Additionally, we are going to re-add the `dice` metric in the segmentation
+        domain in v1.6.0 with slight modifications to functionality.
+
     Raises:
         ValueError:
             If ``average`` is none of ``"micro"``, ``"macro"``, ``"samples"``, ``"none"``, ``None``.
@@ -155,6 +162,14 @@ def __init__(
         multiclass: Optional[bool] = None,
         **kwargs: Any,
     ) -> None:
+        rank_zero_warn(
+            "The `dice` metric is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and"
+            " will be removed in v1.7.0. Please instead consider using `f1score` metric from the classification"
+            " subpackage as it provides the same functionality. Additionally, we are going to re-add the `dice` metric"
+            " in the segmentation domain in v1.6.0 with slight modifications to functionality.",
+            DeprecationWarning,
+        )
+
         super().__init__(**kwargs)
         allowed_average = ("micro", "macro", "samples", "none", None)
         if average not in allowed_average:
diff --git a/src/torchmetrics/functional/classification/dice.py b/src/torchmetrics/functional/classification/dice.py
index 49d66ea9361..3aa26212fa5 100644
--- a/src/torchmetrics/functional/classification/dice.py
+++ b/src/torchmetrics/functional/classification/dice.py
@@ -17,6 +17,7 @@
 from torch import Tensor
 
 from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update
+from torchmetrics.utilities import rank_zero_warn
 from torchmetrics.utilities.checks import _input_squeeze
 from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod
 
@@ -150,6 +151,12 @@ def dice(
             Used only in certain special cases, where you want to treat inputs as a different type
             than what they appear to be.
 
+    .. warning::
+        The `dice` metric is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will be
+        removed in v1.7.0. Please instead consider using `f1score` metric from the classification subpackage as it
+        provides the same functionality. Additionally, we are going to re-add the `dice` metric in the segmentation
+        domain in v1.6.0 with slight modifications to functionality.
+
     Return:
         The shape of the returned tensor depends on the ``average`` parameter
 
@@ -174,6 +181,14 @@ def dice(
         tensor(0.2500)
 
     """
+    rank_zero_warn(
+        "The `dice` metric is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will"
+        " be removed in v1.7.0. Please instead consider using `f1score` metric from the classification subpackage as"
+        " it provides the same functionality. Additionally, we are going to re-add the `dice` metric in the"
+        " segmentation domain in v1.6.0 with slight modifications to functionality.",
+        DeprecationWarning,
+    )
+
     allowed_average = ("micro", "macro", "weighted", "samples", "none", None)
     if average not in allowed_average:
         raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
diff --git a/tests/unittests/test_deprecated.py b/tests/unittests/test_deprecated.py
index f126fa06561..d4c8a0d34e4 100644
--- a/tests/unittests/test_deprecated.py
+++ b/tests/unittests/test_deprecated.py
@@ -1,5 +1,7 @@
 import pytest
 import torch
+from torchmetrics.classification import Dice
+from torchmetrics.functional.classification import dice
 from torchmetrics.functional.regression import kl_divergence
 from torchmetrics.regression import KLDivergence
 
@@ -14,3 +16,15 @@ def test_deprecated_kl_divergence_input_order():
 
     with pytest.deprecated_call(match="The input order and naming in metric `KLDivergence` is set to be deprecated.*"):
         KLDivergence()
+
+
+def test_deprecated_dice_from_classification():
+    """Ensure that the deprecated `dice` metric from classification raises a warning."""
+    preds = torch.randn(10, 2)
+    target = torch.randint(0, 2, (10,))
+
+    with pytest.deprecated_call(match="The `dice` metric is being deprecated from the classification subpackage.*"):
+        dice(preds, target)
+
+    with pytest.deprecated_call(match="The `dice` metric is being deprecated from the classification subpackage.*"):
+        Dice()

From 7851630d8f8b0d5acbf90c70ba72a43a6285845d Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Sun, 8 Sep 2024 14:41:32 +0200
Subject: [PATCH 04/21] fix small mistakes in generalized dice

---
 src/torchmetrics/functional/segmentation/generalized_dice.py | 4 ++--
 tests/unittests/segmentation/test_generalized_dice_score.py  | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py
index 47c5f30964b..e0de9f1821e 100644
--- a/src/torchmetrics/functional/segmentation/generalized_dice.py
+++ b/src/torchmetrics/functional/segmentation/generalized_dice.py
@@ -28,7 +28,7 @@ def _generalized_dice_validate_args(
     input_format: Literal["one-hot", "index"],
 ) -> None:
     """Validate the arguments of the metric."""
-    if num_classes <= 0:
+    if not isinstance(num_classes, int) or num_classes <= 0:
         raise ValueError(f"Expected argument `num_classes` must be a positive integer, but got {num_classes}.")
     if not isinstance(include_background, bool):
         raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.")
@@ -116,7 +116,7 @@ def generalized_dice_score(
         target: Ground truth values
         num_classes: Number of classes
         include_background: Whether to include the background class in the computation
-        per_class: Whether to compute the IoU for each class separately, else average over all classes
+        per_class: Whether to compute the score
for each class separately, else average over all classes weight_type: Type of weight factor to apply to the classes. One of ``"square"``, ``"simple"``, or ``"linear"`` input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors or ``"index"`` for index tensors diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index 3f8acec842a..31e00f0e26e 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -67,7 +67,7 @@ def _reference_generalized_dice( ) @pytest.mark.parametrize("include_background", [True, False]) class TestGeneralizedDiceScore(MetricTester): - """Test class for `MeanIoU` metric.""" + """Test class for `GeneralizedDiceScore` metric.""" @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) def test_generalized_dice_class(self, preds, target, input_format, include_background, ddp): From 35ca256ae5aa5c3230fba7d796c1d619f68c6ed7 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 9 Sep 2024 08:49:58 +0200 Subject: [PATCH 05/21] changelog --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0dd149f3149..b9d3c791dd3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added multi-output support for MAE metric ([#2605](https://github.com/Lightning-AI/torchmetrics/pull/2605)) -- Added Dice metric to segmentation metrics +- Added Dice metric to segmentation metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725)) ### Changed @@ -35,7 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Deprecated -- Deprecated Dice from classification metrics +- Deprecated Dice from classification metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725)) ### Removed From 14a77aa725f064a5925f874b1950a7fee3c48358 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 9 Sep 2024 08:50:24 +0200 Subject: [PATCH 06/21] initial new implementation --- .../functional/segmentation/dice.py | 88 ++++++++++++++++++- src/torchmetrics/segmentation/dice.py | 62 +++++++++++-- tests/unittests/segmentation/test_dice.py | 68 ++++++++++++++ 3 files changed, 211 insertions(+), 7 deletions(-) create mode 100644 tests/unittests/segmentation/test_dice.py diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index d2a1f113b49..e72e665c328 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -13,6 +13,7 @@ # limitations under the License. 
import torch from torch import Tensor +from typing import Optional from typing_extensions import Literal from torchmetrics.functional.segmentation.utils import _ignore_background @@ -20,5 +21,88 @@ from torchmetrics.utilities.compute import _safe_divide -def dice_score(): - pass +def _dice_score_validate_args( + num_classes: int, + include_background: bool, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + input_format: Literal["one-hot", "index"] = "one-hot" +) -> None: + """Validate the arguments of the metric.""" + if not isinstance(num_classes, int) or num_classes <= 0: + raise ValueError(f"Expected argument `num_classes` must be a positive integer, but got {num_classes}.") + if not isinstance(include_background, bool): + raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.") + allowed_average = ["micro", "macro", "weighted", "none"] + if average is not None and average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average} or None, but got {average}.") + if input_format not in ["one-hot", "index"]: + raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.") + +def _dice_score_update( + preds: Tensor, + target: Tensor, + num_classes: int, + include_background: bool, + input_format: Literal["one-hot", "index"] = "one-hot", +) -> Tensor: + _check_same_shape(preds, target) + if preds.ndim < 3: + raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.") + + if input_format == "index": + preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1) + target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1) + + if not include_background: + preds, target = _ignore_background(preds, target, num_classes) + + reduce_axis = list(range(2, preds.ndim)) + intersection = torch.sum(preds * target, dim=reduce_axis) + target_sum = torch.sum(target, dim=reduce_axis) + pred_sum = torch.sum(preds, dim=reduce_axis) + + numerator = 2 * intersection + denominator = pred_sum + target_sum + return numerator, denominator + + +def _dice_score_compute( + numerator: Tensor, + denominator: Tensor, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", +) -> Tensor: + if average == "micro": + numerator = torch.sum(numerator, dim=0) + denominator = torch.sum(denominator, dim=0) + dice = _safe_divide(numerator, denominator, zero_division=1.0) + if average == "macro": + dice = torch.mean(dice) + elif average == "weighted": + weights = _safe_divide(denominator, torch.sum(denominator), zero_division=1.0) + dice = torch.sum(dice * weights) + return dice + + +def dice_score( + preds: Tensor, + target: Tensor, + num_classes: int, + include_background: bool = True, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + input_format: Literal["one-hot", "index"] = "one-hot", +) -> Tensor: + """Compute the Dice score for semantic segmentation. + + preds: Predictions from model + target: Ground truth values + num_classes: Number of classes + include_background: Whether to include the background class in the computation + input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors + or ``"index"`` for index tensors + + Returns: + The Dice score. 
+    """
+    _dice_score_validate_args(num_classes, include_background, average, input_format)
+    numerator, denominator = _dice_score_update(preds, target, num_classes, include_background, input_format)
+    return _dice_score_compute(numerator, denominator, average)
\ No newline at end of file
diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py
index 9cafcd50db6..b333bee516f 100644
--- a/src/torchmetrics/segmentation/dice.py
+++ b/src/torchmetrics/segmentation/dice.py
@@ -17,10 +17,10 @@
 from torch import Tensor
 from typing_extensions import Literal
 
-from torchmetrics.functional.segmentation.generalized_dice import (
-    _generalized_dice_compute,
-    _generalized_dice_update,
-    _generalized_dice_validate_args,
+from torchmetrics.functional.segmentation.dice import (
+    _dice_score_validate_args,
+    _dice_score_update,
+    _dice_score_compute,
 )
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
 from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
@@ -31,4 +31,56 @@
 
 
 class DiceScore(Metric):
-    pass
+    r"""Compute `Dice Score`_.
+
+    The metric can be used to evaluate the performance of image segmentation models. The Dice Score is defined as:
+
+    .. math::
+        DS = \frac{2 \sum_{i=1}^{N} t_i p_i}{\sum_{i=1}^{N} t_i + \sum_{i=1}^{N} p_i}
+
+    where :math:`N` is the number of classes, :math:`t_i` is the target tensor, and :math:`p_i` is the prediction
+    tensor. In general the Dice Score can be interpreted as the overlap between the prediction and target tensors
+    divided by the total number of elements in the tensors.
+
+    As input to ``forward`` and ``update`` the metric accepts the following input:
+
+    - ``preds`` (:class:`~torch.Tensor`): A one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
+      the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
+      can be provided, where the integer values correspond to the class index. The input type can be controlled
+      with the ``input_format`` argument.
+    - ``target`` (:class:`~torch.Tensor`): A one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
+      the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
+      can be provided, where the integer values correspond to the class index. The input type can be controlled
+      with the ``input_format`` argument.
+ + As output to ``forward`` and ``compute`` the metric returns the following output: + + """ + + score: Tensor + samples: Tensor + full_state_update: bool = False + is_differentiable: bool = False + higher_is_better: bool = True + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + + def __init__(self, + num_classes: int, + include_background: bool = True, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + input_format: Literal["one-hot", "index"] = "one-hot", + **kwargs: Any + ) -> None: + super().__init__(**kwargs) + _dice_score_validate_args(num_classes, include_background, average, input_format) + self.num_classes = num_classes + self.include_background = include_background + self.average = average + self.input_format = input_format + + num_classes = num_classes - 1 if include_background else num_classes + self.add_state("score", default=torch.zeros(num_classes), dist_reduce_fx="sum") + self.add_state("samples", default=torch.zeros(1), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: \ No newline at end of file diff --git a/tests/unittests/segmentation/test_dice.py b/tests/unittests/segmentation/test_dice.py new file mode 100644 index 00000000000..bdc6c8ae25f --- /dev/null +++ b/tests/unittests/segmentation/test_dice.py @@ -0,0 +1,68 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from functools import partial
+
+import pytest
+import torch
+
+from monai.metrics.meandice import compute_dice
+from torchmetrics.functional.segmentation.dice import dice_score
+from torchmetrics.segmentation.dice import DiceScore
+from sklearn.metrics import jaccard_score
+
+from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, _Input
+from unittests._helpers import seed_all
+from unittests._helpers.testers import MetricTester
+
+seed_all(42)
+
+_inputs1 = _Input(
+    preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)),
+    target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)),
+)
+_inputs2 = _Input(
+    preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)),
+    target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)),
+)
+_inputs3 = _Input(
+    preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)),
+    target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)),
+)
+
+def sklearn_dice(*args, **kwargs):
+    js = jaccard_score(*args, **kwargs)
+    return 2 * js / (1 + js)
+
+
+def _reference_dice_score(
+    preds: torch.Tensor,
+    target: torch.Tensor,
+    input_format: str,
+    include_background: bool = True,
+    reduce: bool = True,
+):
+    pass
+
+
+@pytest.mark.parametrize(
+    "preds, target, input_format",
+    [
+        (_inputs1.preds, _inputs1.target, "one-hot"),
+        (_inputs2.preds, _inputs2.target, "one-hot"),
+        (_inputs3.preds, _inputs3.target, "index"),
+    ],
+)
+@pytest.mark.parametrize("include_background", [True, False])
+class TestDiceScore(MetricTester):
+    """Test class for `DiceScore` metric."""
\ No newline at end of file

From 591665c648b6e5ef575b3761dd6c37bfcd51e97a Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 9 Sep 2024 06:50:45 +0000
Subject: [PATCH 07/21] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/torchmetrics/functional/segmentation/dice.py | 13 ++++++++-----
 src/torchmetrics/segmentation/dice.py            | 12 ++++++------
 tests/unittests/segmentation/test_dice.py        |  6 +++---
 3 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py
index e72e665c328..e4c48ca3482 100644
--- a/src/torchmetrics/functional/segmentation/dice.py
+++ b/src/torchmetrics/functional/segmentation/dice.py
@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional
+
 import torch
 from torch import Tensor
-from typing import Optional
 from typing_extensions import Literal
 
 from torchmetrics.functional.segmentation.utils import _ignore_background
@@ -26,7 +26,7 @@ def _dice_score_validate_args(
     num_classes: int,
     include_background: bool,
     average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
-    input_format: Literal["one-hot", "index"] = "one-hot"
+    input_format: Literal["one-hot", "index"] = "one-hot",
 ) -> None:
     """Validate the arguments of the metric."""
     if not isinstance(num_classes, int) or num_classes <= 0:
@@ -38,6 +39,7 @@ def _dice_score_validate_args(
     if input_format not in ["one-hot", "index"]:
         raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.")
 
+
 def _dice_score_update(
     preds: Tensor,
     target: Tensor,
@@ -55,7 +57,7 @@ def _dice_score_update(
 
     if not include_background:
         preds, target = _ignore_background(preds, target, num_classes)
-
+
     reduce_axis = list(range(2, preds.ndim))
     intersection = torch.sum(preds * target, dim=reduce_axis)
     target_sum = torch.sum(target, dim=reduce_axis)
@@ -92,7 +94,7 @@ def dice_score(
     input_format: Literal["one-hot", "index"] = "one-hot",
 ) -> Tensor:
     """Compute the Dice score for semantic segmentation.
-
+
     preds: Predictions from model
     target: Ground truth values
     num_classes: Number of classes
@@ -102,7 +104,8 @@ def dice_score(
 
     Returns:
         The Dice score.
+
     """
     _dice_score_validate_args(num_classes, include_background, average, input_format)
     numerator, denominator = _dice_score_update(preds, target, num_classes, include_background, input_format)
-    return _dice_score_compute(numerator, denominator, average)
\ No newline at end of file
+    return _dice_score_compute(numerator, denominator, average)
diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py
index b333bee516f..d63f85d2f02 100644
--- a/src/torchmetrics/segmentation/dice.py
+++ b/src/torchmetrics/segmentation/dice.py
@@ -32,13 +32,13 @@
 
 class DiceScore(Metric):
     r"""Compute `Dice Score`_.
-
+
     The metric can be used to evaluate the performance of image segmentation models. The Dice Score is defined as:
 
     .. math::
        DS = \frac{2 \sum_{i=1}^{N} t_i p_i}{\sum_{i=1}^{N} t_i + \sum_{i=1}^{N} p_i}
 
-    where :math:`N` is the number of classes, :math:`t_i` is the target tensor, and :math:`p_i` is the prediction
+    where :math:`N` is the number of classes, :math:`t_i` is the target tensor, and :math:`p_i` is the prediction
     tensor. In general the Dice Score can be interpreted as the overlap between the prediction and target tensors
     divided by the total number of elements in the tensors.
 
@@ -54,7 +54,7 @@ class DiceScore(Metric):
       with the ``input_format`` argument.
 
 As output to ``forward`` and ``compute`` the metric returns the following output:
-
+
     """
 
     score: Tensor
@@ -65,7 +65,7 @@ class DiceScore(Metric):
     plot_lower_bound: float = 0.0
     plot_upper_bound: float = 1.0
 
-    def __init__(self, 
+    def __init__(
+        self,
         num_classes: int,
         include_background: bool = True,
         average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         input_format: Literal["one-hot", "index"] = "one-hot",
-        **kwargs: Any
+        **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
         _dice_score_validate_args(num_classes, include_background, average, input_format)
@@ -82,5 +82,5 @@ def __init__(
         num_classes = num_classes - 1 if include_background else num_classes
         self.add_state("score", default=torch.zeros(num_classes), dist_reduce_fx="sum")
         self.add_state("samples", default=torch.zeros(1), dist_reduce_fx="sum")
-
-    def update(self, preds: Tensor, target: Tensor) -> None:
\ No newline at end of file
+
+    def update(self, preds: Tensor, target: Tensor) -> None:
diff --git a/tests/unittests/segmentation/test_dice.py b/tests/unittests/segmentation/test_dice.py
index bdc6c8ae25f..768640c6013 100644
--- a/tests/unittests/segmentation/test_dice.py
+++ b/tests/unittests/segmentation/test_dice.py
@@ -15,11 +15,10 @@
 
 import pytest
 import torch
-
 from monai.metrics.meandice import compute_dice
+from sklearn.metrics import jaccard_score
 from torchmetrics.functional.segmentation.dice import dice_score
 from torchmetrics.segmentation.dice import DiceScore
-from sklearn.metrics import jaccard_score
 
 from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, _Input
 from unittests._helpers import seed_all
@@ -40,6 +39,7 @@
     target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)),
 )
 
+
 def sklearn_dice(*args, **kwargs):
     js = jaccard_score(*args, **kwargs)
     return 2 * js / (1 + js)
@@ -65,4 +65,4 @@
 )
 @pytest.mark.parametrize("include_background", [True, False])
 class TestDiceScore(MetricTester):
-    """Test class for `DiceScore` metric."""
\ No newline at end of file
+    """Test class for `DiceScore` metric."""

From 9baf6e6c68ba26e173d4945dcc4a712ae5465738 Mon Sep 17 00:00:00 2001
From: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Date: Fri, 13 Sep 2024 20:15:20 +0200
Subject: [PATCH 08/21] Apply suggestions from code review

---
 src/torchmetrics/classification/dice.py            | 6 +++---
 src/torchmetrics/functional/classification/dice.py | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py
index eb6b228778a..cbc8a84987e 100644
--- a/src/torchmetrics/classification/dice.py
+++ b/src/torchmetrics/classification/dice.py
@@ -116,9 +116,9 @@ class Dice(Metric):
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     .. warning::
-        The `dice` metric is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will be
-        removed in v1.7.0. Please instead consider using `f1score` metric from the classification subpackage as it
-        provides the same functionality. Additionally, we are going to re-add the `dice` metric in the segmentation
+        The ``dice`` metric is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will be
+        removed in v1.7.0. Please instead consider using ``f1score`` metric from the classification subpackage as it
+        provides the same functionality. Additionally, we are going to re-add the ``dice`` metric in the segmentation
         domain in v1.6.0 with slight modifications to functionality.
 
     Raises:
diff --git a/src/torchmetrics/functional/classification/dice.py b/src/torchmetrics/functional/classification/dice.py
index 3aa26212fa5..5c08a028572 100644
--- a/src/torchmetrics/functional/classification/dice.py
+++ b/src/torchmetrics/functional/classification/dice.py
@@ -152,9 +152,9 @@ def dice(
         than what they appear to be.
 
     .. warning::
-        The `dice` metric is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will be
-        removed in v1.7.0. Please instead consider using `f1score` metric from the classification subpackage as it
-        provides the same functionality. Additionally, we are going to re-add the `dice` metric in the segmentation
+        The ``dice`` metric is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will be
+        removed in v1.7.0. Please instead consider using ``f1score`` metric from the classification subpackage as it
+        provides the same functionality. Additionally, we are going to re-add the ``dice`` metric in the segmentation
         domain in v1.6.0 with slight modifications to functionality.
 
     Return:

From 9be617258fe764c83c2a1340943105fc216895e9 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Sat, 14 Sep 2024 09:48:02 +0200
Subject: [PATCH 09/21] more code

---
 .../functional/segmentation/dice.py              | 17 ++++++++++++++---
 .../functional/segmentation/generalized_dice.py | 15 ++++++++++++++-
 2 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py
index e4c48ca3482..6b6a79dd44e 100644
--- a/src/torchmetrics/functional/segmentation/dice.py
+++ b/src/torchmetrics/functional/segmentation/dice.py
@@ -74,8 +74,8 @@ def _dice_score_compute(
     average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
 ) -> Tensor:
     if average == "micro":
-        numerator = torch.sum(numerator, dim=0)
-        denominator = torch.sum(denominator, dim=0)
+        numerator = torch.sum(numerator, dim=1)
+        denominator = torch.sum(denominator, dim=1)
     dice = _safe_divide(numerator, denominator, zero_division=1.0)
     if average == "macro":
         dice = torch.mean(dice)
@@ -95,7 +95,8 @@ def dice_score(
 ) -> Tensor:
     """Compute the Dice score for semantic segmentation.
 
-    preds: Predictions from model
+    Args:
+        preds: Predictions from model
         target: Ground truth values
         num_classes: Number of classes
         include_background: Whether to include the background class in the computation
@@ -105,6 +106,16 @@ def dice_score(
 
     Returns:
         The Dice score.
+ Example (with one-hot encoded tensors): + >>> from torch import randint + >>> from torchmetrics.functional.segmentation import dice_score + >>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> dice_score(preds, target, num_classes=5) + tensor([0.4872, 0.5000, 0.5019, 0.4891, 0.4926]) + + Example (with index tensors): + """ _dice_score_validate_args(num_classes, include_background, average, input_format) numerator, denominator = _dice_score_update(preds, target, num_classes, include_background, input_format) diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index e0de9f1821e..69a417bfdd8 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -124,7 +124,7 @@ def generalized_dice_score( Returns: The Generalized Dice Score - Example: + Example (with one-hot encoded tensors): >>> from torch import randint >>> from torchmetrics.functional.segmentation import generalized_dice_score >>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction @@ -136,6 +136,19 @@ def generalized_dice_score( [0.4571, 0.4980, 0.5191, 0.4380, 0.5649], [0.5428, 0.4904, 0.5358, 0.4830, 0.4724], [0.4715, 0.4925, 0.4797, 0.5267, 0.4788]]) + + Example (with index tensors): + >>> from torch import randint + >>> from torchmetrics.functional.segmentation import generalized_dice_score + >>> preds = randint(0, 5, (4, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = randint(0, 5, (4, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> generalized_dice_score(preds, target, num_classes=5) + tensor([0.4830, 0.4935, 0.5044, 0.4880]) + >>> generalized_dice_score(preds, target, num_classes=5, per_class=True) + tensor([[0.4724, 0.5185, 0.4710, 0.5062, 0.4500], + [0.4571, 0.4980, 0.5191, 0.4380, 0.5649], + [0.5428, 0.4904, 0.5358, 0.4830, 0.4724], + [0.4715, 0.4925, 0.4797, 0.5267, 0.4788]]) """ _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format) From 88cc2c5bd6678e7d34e0f646999a2259a72abcf0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 14 Sep 2024 07:48:29 +0000 Subject: [PATCH 10/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/segmentation/generalized_dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index 69a417bfdd8..30bb9452289 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -136,7 +136,7 @@ def generalized_dice_score( [0.4571, 0.4980, 0.5191, 0.4380, 0.5649], [0.5428, 0.4904, 0.5358, 0.4830, 0.4724], [0.4715, 0.4925, 0.4797, 0.5267, 0.4788]]) - + Example (with index tensors): >>> from torch import randint >>> from torchmetrics.functional.segmentation import generalized_dice_score From 32128c400b76d55a8868adbeb9a6efc3d1368b0f Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 22 Oct 2024 11:24:31 +0200 Subject: [PATCH 11/21] fix --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 
b5de93bc7be..010e7d855ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -62,6 +62,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed corner case in `IoU` metric for single empty prediction tensors ([#2780](https://github.com/Lightning-AI/torchmetrics/pull/2780)) - Fixed `PSNR` calculation for integer type input images ([#2788](https://github.com/Lightning-AI/torchmetrics/pull/2788)) +--- + ## [1.4.3] - 2024-10-10 ### Fixed From ba0124d8657f0b338bbd8d9ca667b22fb0d3585d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 22 Oct 2024 11:42:05 +0200 Subject: [PATCH 12/21] update doctests --- .../functional/segmentation/dice.py | 28 +++++++++++++++++-- .../segmentation/generalized_dice.py | 14 +++++----- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index 6b6a79dd44e..891ab7841b1 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -100,8 +100,10 @@ def dice_score( target: Ground truth values num_classes: Number of classes include_background: Whether to include the background class in the computation + average: The method to average the dice score. Options are ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` + or ``None``. This determines how to average the dice score across different classes. input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors - or ``"index"`` for index tensors + or ``"index"`` for index tensors Returns: The Dice score. @@ -111,10 +113,30 @@ def dice_score( >>> from torchmetrics.functional.segmentation import dice_score >>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction >>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target - >>> dice_score(preds, target, num_classes=5) - tensor([0.4872, 0.5000, 0.5019, 0.4891, 0.4926]) + >>> # dice score micro averaged over all classes + >>> dice_score(preds, target, num_classes=5, average="micro") + tensor([0.4842, 0.4968, 0.5053, 0.4902]) + >>> # dice score per sample and class + >>> dice_score(preds, target, num_classes=5, average="none") + tensor([[0.4724, 0.5185, 0.4710, 0.5062, 0.4500], + [0.4571, 0.4980, 0.5191, 0.4380, 0.5649], + [0.5428, 0.4904, 0.5358, 0.4830, 0.4724], + [0.4715, 0.4925, 0.4797, 0.5267, 0.4788]]) Example (with index tensors): + >>> from torch import randint + >>> from torchmetrics.functional.segmentation import dice_score + >>> preds = randint(0, 5, (4, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = randint(0, 5, (4, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> # dice score micro averaged over all classes + >>> dice_score(preds, target, num_classes=5, average="micro", input_format="index") + tensor([0.2031, 0.1914, 0.2500, 0.2266]) + >>> # dice score per sample and class + >>> dice_score(preds, target, num_classes=5, average="none", input_format="index") + tensor([[0.1714, 0.2500, 0.1304, 0.2524, 0.2069], + [0.1837, 0.2162, 0.0962, 0.2692, 0.1895], + [0.3866, 0.1348, 0.2526, 0.2301, 0.2083], + [0.1978, 0.2804, 0.1714, 0.1915, 0.2783]]) """ _dice_score_validate_args(num_classes, include_background, average, input_format) diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index 30bb9452289..8bfc9bab18a 100644 --- 
a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -142,13 +142,13 @@ def generalized_dice_score( >>> from torchmetrics.functional.segmentation import generalized_dice_score >>> preds = randint(0, 5, (4, 16, 16)) # 4 samples, 5 classes, 16x16 prediction >>> target = randint(0, 5, (4, 16, 16)) # 4 samples, 5 classes, 16x16 target - >>> generalized_dice_score(preds, target, num_classes=5) - tensor([0.4830, 0.4935, 0.5044, 0.4880]) - >>> generalized_dice_score(preds, target, num_classes=5, per_class=True) - tensor([[0.4724, 0.5185, 0.4710, 0.5062, 0.4500], - [0.4571, 0.4980, 0.5191, 0.4380, 0.5649], - [0.5428, 0.4904, 0.5358, 0.4830, 0.4724], - [0.4715, 0.4925, 0.4797, 0.5267, 0.4788]]) + >>> generalized_dice_score(preds, target, num_classes=5, input_format="index") + tensor([0.1991, 0.1971, 0.2350, 0.2216]) + >>> generalized_dice_score(preds, target, num_classes=5, per_class=True, input_format="index") + tensor([[0.1714, 0.2500, 0.1304, 0.2524, 0.2069], + [0.1837, 0.2162, 0.0962, 0.2692, 0.1895], + [0.3866, 0.1348, 0.2526, 0.2301, 0.2083], + [0.1978, 0.2804, 0.1714, 0.1915, 0.2783]]) """ _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format) From 8346d7d1acf2f2c0eb8c70f57d2d5a0f42db9225 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 22 Oct 2024 14:42:49 +0200 Subject: [PATCH 13/21] doctests --- src/torchmetrics/segmentation/dice.py | 37 +++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index d63f85d2f02..b55056dc3cc 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -55,6 +55,42 @@ class DiceScore(Metric): As output to ``forward`` and ``compute`` the metric returns the following output: + - ``gds`` (:class:`~torch.Tensor`): The dice score. If ``average`` is set to ``None`` or ``"none"`` the output + will be a tensor of shape ``(C,)`` with the dice score for each class. If ``average`` is set to + ``"micro"``, ``"macro"``, or ``"weighted"`` the output will be a scalar tensor. + + Args: + num_classes: The number of classes in the segmentation problem. + include_background: Whether to include the background class in the computation. + average: The method to average the dice score. Options are ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` + or ``None``. This determines how to average the dice score across different classes. + input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors + or ``"index"`` for index tensors + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
+ + Raises: + ValueError: + If ``num_classes`` is not a positive integer + ValueError: + If ``include_background`` is not a boolean + ValueError: + If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` or ``None`` + ValueError: + If ``input_format`` is not one of ``"one-hot"`` or ``"index"`` + + Example: + >>> from torch import randint + >>> from torchmetrics.segmentation import DiceScore + >>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction + >>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target + >>> dice_score = DiceScore(num_classes=5, average="micro") + >>> dice_score(preds, target) + tensor(0.4993) + >>> dice_score = DiceScore(num_classes=5, average="none") + >>> dice_score(preds, target) + tensor([0.4993, 0.5002, 0.5004, 0.4996, 0.5000]) + + """ score: Tensor @@ -84,3 +120,4 @@ def __init__(self, self.add_state("samples", default=torch.zeros(1), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: + pass \ No newline at end of file From 051187d61ad47caf9fee7ed81b3cea199e4353d7 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 22 Oct 2024 14:45:16 +0200 Subject: [PATCH 14/21] update implementations --- .../functional/segmentation/dice.py | 6 +- src/torchmetrics/segmentation/dice.py | 78 +++++++++++++++---- 2 files changed, 68 insertions(+), 16 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index 891ab7841b1..3d73c89640b 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -58,7 +58,7 @@ def _dice_score_update( if not include_background: preds, target = _ignore_background(preds, target, num_classes) - reduce_axis = list(range(2, preds.ndim)) + reduce_axis = list(range(2, target.ndim)) intersection = torch.sum(preds * target, dim=reduce_axis) target_sum = torch.sum(target, dim=reduce_axis) pred_sum = torch.sum(preds, dim=reduce_axis) @@ -74,8 +74,8 @@ def _dice_score_compute( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", ) -> Tensor: if average == "micro": - numerator = torch.sum(numerator, dim=1) - denominator = torch.sum(denominator, dim=1) + numerator = torch.sum(numerator, dim=-1) + denominator = torch.sum(denominator, dim=-1) dice = _safe_divide(numerator, denominator, zero_division=1.0) if average == "macro": dice = torch.mean(dice) diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index b55056dc3cc..7dc14e5c914 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -11,16 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Union +from typing import Any, List, Optional, Sequence, Union import torch from torch import Tensor from typing_extensions import Literal from torchmetrics.functional.segmentation.dice import ( - _dice_score_validate_args, - _dice_score_update, _dice_score_compute, + _dice_score_update, + _dice_score_validate_args, ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE @@ -55,7 +55,7 @@ class DiceScore(Metric): As output to ``forward`` and ``compute`` the metric returns the following output: - - ``gds`` (:class:`~torch.Tensor`): The dice score. 
If ``average`` is set to ``None`` or ``"none"`` the output + - ``gds`` (:class:`~torch.Tensor`): The dice score. If ``average`` is set to ``None`` or ``"none"`` the output will be a tensor of shape ``(C,)`` with the dice score for each class. If ``average`` is set to ``"micro"``, ``"macro"``, or ``"weighted"`` the output will be a scalar tensor. @@ -90,23 +90,24 @@ class DiceScore(Metric): >>> dice_score(preds, target) tensor([0.4993, 0.5002, 0.5004, 0.4996, 0.5000]) - """ - score: Tensor - samples: Tensor full_state_update: bool = False is_differentiable: bool = False higher_is_better: bool = True plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - def __init__(self, + numerator: List[Tensor] + denominator: List[Tensor] + + def __init__( + self, num_classes: int, include_background: bool = True, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", input_format: Literal["one-hot", "index"] = "one-hot", - **kwargs: Any + **kwargs: Any, ) -> None: super().__init__(**kwargs) _dice_score_validate_args(num_classes, include_background, average, input_format) @@ -115,9 +116,60 @@ def __init__(self, self.average = average self.input_format = input_format - num_classes = num_classes - 1 if include_background else num_classes - self.add_state("score", default=torch.zeros(num_classes), dist_reduce_fx="sum") - self.add_state("samples", default=torch.zeros(1), dist_reduce_fx="sum") + num_classes = num_classes - 1 if not include_background else num_classes + self.add_state("numerator", torch.zeros(num_classes), dist_reduce_fx="sum") + self.add_state("denominator", torch.zeros(num_classes), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: - pass \ No newline at end of file + """Update the state with new data.""" + numerator, denominator = _dice_score_update( + preds, target, self.num_classes, self.include_background, self.input_format + ) + self.numerator += numerator.sum(dim=0) + self.denominator += denominator.sum(dim=0) + + def compute(self) -> Tensor: + """Computes the Dice Score.""" + return _dice_score_compute(self.numerator, self.denominator, self.average) + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.segmentation import DiceScore + >>> metric = DiceScore(num_classes=3) + >>> metric.update(torch.randint(0, 2, (10, 3, 128, 128)), torch.randint(0, 2, (10, 3, 128, 128))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.segmentation import DiceScore + >>> metric = DiceScore(num_classes=3) + >>> values = [ ] + >>> for _ in range(10): + ... values.append( + ... metric(torch.randint(0, 2, (10, 3, 128, 128)), torch.randint(0, 2, (10, 3, 128, 128))) + ... 
)
+        >>> fig_, ax_ = metric.plot(values)

From 42b18241afad125c55299d6eb88c56f4c41515f1 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 22 Oct 2024 15:11:26 +0200
Subject: [PATCH 15/21] somewhat working tests

---
 tests/unittests/segmentation/test_dice.py          | 64 +++++++++++++++--
 .../segmentation/test_generalized_dice_score.py    |  2 +-
 2 files changed, 58 insertions(+), 8 deletions(-)

diff --git a/tests/unittests/segmentation/test_dice.py b/tests/unittests/segmentation/test_dice.py
index 768640c6013..63236d31c63 100644
--- a/tests/unittests/segmentation/test_dice.py
+++ b/tests/unittests/segmentation/test_dice.py
@@ -15,8 +15,8 @@
 
 import pytest
 import torch
+from sklearn.metrics import f1_score
 from monai.metrics.meandice import compute_dice
-from sklearn.metrics import jaccard_score
 from torchmetrics.functional.segmentation.dice import dice_score
 from torchmetrics.segmentation.dice import DiceScore
 
@@ -40,19 +40,26 @@
 )
 
 
-def sklearn_dice(*args, **kwargs):
-    js = jaccard_score(*args, **kwargs)
-    return 2 * js / (1 + js)
-
-
 def _reference_dice_score(
     preds: torch.Tensor,
     target: torch.Tensor,
     input_format: str,
     include_background: bool = True,
+    average: str = "micro",
     reduce: bool = True,
 ):
-    pass
+    """Calculate reference metric for dice score"""
+    if input_format == "one-hot":
+        preds = preds.argmax(dim=1)
+        target = target.argmax(dim=1)
+    preds = preds.cpu().numpy()
+    target = target.cpu().numpy()
+
+    labels = list(range(1, NUM_CLASSES) if not include_background else range(NUM_CLASSES))
+    if reduce:
+        return f1_score(target.flatten(), preds.flatten(), average=average, labels=labels)
+    val = [f1_score(t, p, average=average, labels=labels) for t, p in zip(target, preds)]
+    return val
 
 
 @pytest.mark.parametrize(
@@ -64,5 +71,44 @@ def _reference_dice_score(
     ],
 )
 @pytest.mark.parametrize("include_background", [True, False])
+@pytest.mark.parametrize("average", ["micro", "macro", "weighted", "none"])
 class TestDiceScore(MetricTester):
     """Test class for `DiceScore` metric."""
+
+    # @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
+    # def test_dice_score_class(self, preds, target, input_format, include_background, ddp):
+    #     """Test class implementation of metric."""
+    #     self.run_class_metric_test(
+    #         ddp=ddp,
+    #         preds=preds,
+    #         target=target,
+    #         metric_class=DiceScore,
+    #         reference_metric=partial(
+    #             _reference_dice_score,
+    #             input_format=input_format,
+    #             include_background=include_background,
+    #             reduce=True,
+    #         ),
+    #         metric_args={
+    #             "num_classes": NUM_CLASSES,
+    #             "include_background": include_background,
+    #             "input_format": input_format,
+    #         },
+    #     )
+
+    def test_dice_score_functional(self, preds, target, input_format, include_background, average):
+        """Test functional implementation of metric."""
+        self.run_functional_metric_test(
+            preds=preds,
+            target=target,
+            metric_functional=dice_score,
+            reference_metric=partial(
+                _reference_dice_score,
+                input_format=input_format,
+                include_background=include_background,
+                average=average,
+                reduce=False,
+            ),
+            metric_args={
+                "num_classes": NUM_CLASSES,
+                "include_background": include_background,
+                "average": average,
+                "input_format": input_format,
+            },
+        )
diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py
index 5347b6109b1..7958435bc64 100644
---
a/tests/unittests/segmentation/test_generalized_dice_score.py
+++ b/tests/unittests/segmentation/test_generalized_dice_score.py
@@ -47,7 +47,7 @@ def _reference_generalized_dice(
     include_background: bool = True,
     reduce: bool = True,
 ):
-    """Calculate reference metric for `MeanIoU`."""
+    """Calculate reference metric for generalized dice metric."""
     if input_format == "index":
         preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1)
         target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1)

From b7a4f0009fb1670931376758b6781f9e3a341192 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 22 Oct 2024 13:11:57 +0000
Subject: [PATCH 16/21] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/unittests/segmentation/test_dice.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/unittests/segmentation/test_dice.py b/tests/unittests/segmentation/test_dice.py
index 63236d31c63..5727e48ff14 100644
--- a/tests/unittests/segmentation/test_dice.py
+++ b/tests/unittests/segmentation/test_dice.py
@@ -15,8 +15,8 @@
 
 import pytest
 import torch
-from sklearn.metrics import f1_score
 from monai.metrics.meandice import compute_dice
+from sklearn.metrics import f1_score
 from torchmetrics.functional.segmentation.dice import dice_score
 from torchmetrics.segmentation.dice import DiceScore
 
@@ -48,7 +48,7 @@ def _reference_dice_score(
     average: str = "micro",
     reduce: bool = True,
 ):
-    """Calculate reference metric for dice score"""
+    """Calculate reference metric for dice score."""
     if input_format == "one-hot":
         preds = preds.argmax(dim=1)
         target = target.argmax(dim=1)

From 13038ce836b16f8a29391d5378ff4412d85d03d7 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Thu, 24 Oct 2024 12:00:53 +0200
Subject: [PATCH 17/21] fix implementation

---
 .../functional/segmentation/dice.py   | 20 ++++++++------
 src/torchmetrics/segmentation/dice.py | 26 +++++++++++++------
 2 files changed, 30 insertions(+), 16 deletions(-)

diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py
index 3d73c89640b..c4b033fb5ec 100644
--- a/src/torchmetrics/functional/segmentation/dice.py
+++ b/src/torchmetrics/functional/segmentation/dice.py
@@ -47,6 +47,7 @@ def _dice_score_update(
     include_background: bool,
     input_format: Literal["one-hot", "index"] = "one-hot",
 ) -> Tensor:
+    """Update the state with the current prediction and target."""
     _check_same_shape(preds, target)
     if preds.ndim < 3:
         raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")
@@ -57,14 +58,15 @@ def _dice_score_update(
 
     if not include_background:
-        preds, target = _ignore_background(preds, target, num_classes)
+        preds, target = _ignore_background(preds, target)
 
     reduce_axis = list(range(2, target.ndim))
     intersection = torch.sum(preds * target, dim=reduce_axis)
     target_sum = torch.sum(target, dim=reduce_axis)
     pred_sum = torch.sum(preds, dim=reduce_axis)
 
     numerator = 2 * intersection
     denominator = pred_sum + target_sum
-    return numerator, denominator
+    support = target_sum
+    return numerator,
denominator, support def _dice_score_compute( numerator: Tensor, denominator: Tensor, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + support: Optional[Tensor] = None, ) -> Tensor: + """Compute the Dice score from the numerator and denominator.""" if average == "micro": numerator = torch.sum(numerator, dim=-1) denominator = torch.sum(denominator, dim=-1) dice = _safe_divide(numerator, denominator, zero_division=1.0) if average == "macro": - dice = torch.mean(dice) - elif average == "weighted": - weights = _safe_divide(denominator, torch.sum(denominator), zero_division=1.0) - dice = torch.sum(dice * weights) + dice = torch.mean(dice, dim=-1) + elif average == "weighted" and support is not None: + weights = _safe_divide(support, torch.sum(support, dim=-1, keepdim=True), zero_division=1.0) + dice = torch.sum(dice * weights, dim=-1) return dice @@ -140,5 +144,5 @@ def dice_score( """ _dice_score_validate_args(num_classes, include_background, average, input_format) - numerator, denominator = _dice_score_update(preds, target, num_classes, include_background, input_format) - return _dice_score_compute(numerator, denominator, average) + numerator, denominator, support = _dice_score_update(preds, target, num_classes, include_background, input_format) + return _dice_score_compute(numerator, denominator, average, support=support) diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index 7dc14e5c914..ef8f4e1db74 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import Any, List, Optional, Sequence, Union -import torch from torch import Tensor from typing_extensions import Literal @@ -23,6 +22,7 @@ _dice_score_validate_args, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE @@ -57,7 +57,8 @@ class DiceScore(Metric): - ``gds`` (:class:`~torch.Tensor`): The dice score. If ``average`` is set to ``None`` or ``"none"`` the output will be a tensor of shape ``(C,)`` with the dice score for each class. If ``average`` is set to - ``"micro"``, ``"macro"``, or ``"weighted"`` the output will be a scalar tensor. + ``"micro"``, ``"macro"``, or ``"weighted"`` the output will be a scalar tensor. The score is an average over + all samples. Args: num_classes: The number of classes in the segmentation problem. 
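Before the class-side hunks, it is worth spelling out what the functional change above does: `_dice_score_update` now also returns the per-class support (the target sum), and `_dice_score_compute` uses it for the ``weighted`` average while keeping per-sample scores for ``macro`` and ``none``. A minimal, self-contained sketch of that compute step (an illustration, not code from the patch: the `(num_samples, num_classes)` shapes are inferred from the diff, and `safe_divide` below is a stand-in for torchmetrics' `_safe_divide`):

    import torch

    def sketch_dice_compute(numerator, denominator, average="micro", support=None):
        """Mirror the averaging logic of `_dice_score_compute` (sketch only)."""

        def safe_divide(num, denom, zero_division=1.0):
            # elementwise num / denom, returning `zero_division` where denom == 0
            num, denom = num.float(), denom.float()
            return torch.where(denom == 0, torch.full_like(num, zero_division), num / denom)

        if average == "micro":
            # pool the statistics over classes before dividing
            return safe_divide(numerator.sum(dim=-1), denominator.sum(dim=-1))
        dice = safe_divide(numerator, denominator)  # per-sample, per-class scores
        if average == "macro":
            return dice.mean(dim=-1)  # unweighted mean over classes
        if average == "weighted" and support is not None:
            # weight each class by its share of the target pixels in that sample
            weights = safe_divide(support, support.sum(dim=-1, keepdim=True))
            return (dice * weights).sum(dim=-1)
        return dice  # average=None / "none": leave per-class scores untouched

Because the states become lists concatenated across batches (note the switch to `dist_reduce_fx="cat"` in the hunks below), the score is computed per sample and then averaged, rather than accumulated as global sums — which is exactly what the added docstring sentence "The score is an average over all samples" describes.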
@@ -100,6 +101,7 @@ class DiceScore(Metric): numerator: List[Tensor] denominator: List[Tensor] + support: List[Tensor] def __init__( self, @@ -117,20 +119,28 @@ def __init__( self.input_format = input_format num_classes = num_classes - 1 if not include_background else num_classes - self.add_state("numerator", torch.zeros(num_classes), dist_reduce_fx="sum") - self.add_state("denominator", torch.zeros(num_classes), dist_reduce_fx="sum") + self.add_state("numerator", [], dist_reduce_fx="cat") + self.add_state("denominator", [], dist_reduce_fx="cat") + self.add_state("support", [], dist_reduce_fx="cat") def update(self, preds: Tensor, target: Tensor) -> None: """Update the state with new data.""" - numerator, denominator = _dice_score_update( + numerator, denominator, support = _dice_score_update( preds, target, self.num_classes, self.include_background, self.input_format ) - self.numerator += numerator.sum(dim=0) - self.denominator += denominator.sum(dim=0) + self.numerator.append(numerator) + self.denominator.append(denominator) + if self.average == "weighted": + self.support.append(support) def compute(self) -> Tensor: """Computes the Dice Score.""" - return _dice_score_compute(self.numerator, self.denominator, self.average) + return _dice_score_compute( + dim_zero_cat(self.numerator), + dim_zero_cat(self.denominator), + self.average, + support=dim_zero_cat(self.support) if self.average == "weighted" else None, + ).mean(dim=0) def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. From abc55e4b1be5c5ec1cdd192b1740fb9efd717e88 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 24 Oct 2024 12:02:52 +0200 Subject: [PATCH 18/21] centralize inputs and fix testing --- tests/unittests/segmentation/inputs.py | 22 ++++-- tests/unittests/segmentation/test_dice.py | 75 +++++++------------ .../test_generalized_dice_score.py | 16 +--- tests/unittests/segmentation/test_mean_iou.py | 16 +--- 4 files changed, 48 insertions(+), 81 deletions(-) diff --git a/tests/unittests/segmentation/inputs.py b/tests/unittests/segmentation/inputs.py index 996b8364e9c..b773ba29ebd 100644 --- a/tests/unittests/segmentation/inputs.py +++ b/tests/unittests/segmentation/inputs.py @@ -13,16 +13,24 @@ # limitations under the License. 
__all__ = ["_Input"] -from typing import NamedTuple - -from torch import Tensor +import torch +from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, _Input from unittests._helpers import seed_all seed_all(42) +to_one_hot = lambda x: torch.nn.functional.one_hot(x, NUM_CLASSES).permute(0, 1, 4, 2, 3) -# extrinsic input for clustering metrics that requires predicted clustering labels and target clustering labels -class _Input(NamedTuple): - preds: Tensor - target: Tensor +_inputs1 = _Input( + preds=to_one_hot(torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 16, 16))), + target=to_one_hot(torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 16, 16))), +) +_inputs2 = _Input( + preds=to_one_hot(torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32))), + target=to_one_hot(torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32))), +) +_inputs3 = _Input( + preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), + target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), +) diff --git a/tests/unittests/segmentation/test_dice.py b/tests/unittests/segmentation/test_dice.py index 5727e48ff14..d5bfc08b4ae 100644 --- a/tests/unittests/segmentation/test_dice.py +++ b/tests/unittests/segmentation/test_dice.py @@ -15,30 +15,17 @@ import pytest import torch -from monai.metrics.meandice import compute_dice from sklearn.metrics import f1_score from torchmetrics.functional.segmentation.dice import dice_score from torchmetrics.segmentation.dice import DiceScore -from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, _Input +from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3 seed_all(42) -_inputs1 = _Input( - preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)), - target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)), -) -_inputs2 = _Input( - preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), - target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), -) -_inputs3 = _Input( - preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), - target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), -) - def _reference_dice_score( preds: torch.Tensor, @@ -49,9 +36,6 @@ def _reference_dice_score( reduce: bool = True, ): """Calculate reference metric for dice score.""" - import pdb - - pdb.set_trace() if input_format == "one-hot": preds = preds.argmax(dim=1) target = target.argmax(dim=1) @@ -59,12 +43,9 @@ def _reference_dice_score( target = target.cpu().numpy() labels = list(range(1, NUM_CLASSES) if not include_background else range(NUM_CLASSES)) + val = [f1_score(t.flatten(), p.flatten(), average=average, labels=labels) for t, p in zip(target, preds)] if reduce: - return f1_score(target.flatten(), preds.flatten(), average=average, labels=labels) - import pdb - - pdb.set_trace() - val = [f1_score(t, p, average=average, labels=labels) for t, p in zip(target, preds)] + val = torch.tensor(val).mean(dim=0) return val @@ -77,32 +58,34 @@ def _reference_dice_score( ], ) @pytest.mark.parametrize("include_background", [True, False]) -@pytest.mark.parametrize("average", ["micro", "macro", "weighted", "none"]) -class TestGeneralizedDiceScore(MetricTester): +@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) +class TestDiceScore(MetricTester): """Test class for `DiceScore` metric.""" - # 
@pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - # def test_generalized_dice_class(self, preds, target, input_format, include_background, ddp): - # """Test class implementation of metric.""" - # self.run_class_metric_test( - # ddp=ddp, - # preds=preds, - # target=target, - # metric_class=GeneralizedDiceScore, - # reference_metric=partial( - # _reference_generalized_dice, - # input_format=input_format, - # include_background=include_background, - # reduce=True, - # ), - # metric_args={ - # "num_classes": NUM_CLASSES, - # "include_background": include_background, - # "input_format": input_format, - # }, - # ) + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_dice_score_class(self, preds, target, input_format, include_background, average, ddp): + """Test class implementation of metric.""" + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=DiceScore, + reference_metric=partial( + _reference_dice_score, + input_format=input_format, + include_background=include_background, + average=average, + reduce=True, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "include_background": include_background, + "average": average, + "input_format": input_format, + }, + ) - def test_generalized_dice_functional(self, preds, target, input_format, include_background, average): + def test_dice_score_functional(self, preds, target, input_format, include_background, average): """Test functional implementation of metric.""" self.run_functional_metric_test( preds=preds, diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index 7958435bc64..02c7ec1a859 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -20,25 +20,13 @@ from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore -from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, _Input +from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3 seed_all(42) -_inputs1 = _Input( - preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)), - target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)), -) -_inputs2 = _Input( - preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), - target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), -) -_inputs3 = _Input( - preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), - target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), -) - def _reference_generalized_dice( preds: torch.Tensor, diff --git a/tests/unittests/segmentation/test_mean_iou.py b/tests/unittests/segmentation/test_mean_iou.py index 68c2b060a9e..8c21d5c70c3 100644 --- a/tests/unittests/segmentation/test_mean_iou.py +++ b/tests/unittests/segmentation/test_mean_iou.py @@ -20,21 +20,9 @@ from torchmetrics.functional.segmentation.mean_iou import mean_iou from torchmetrics.segmentation.mean_iou import MeanIoU -from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, _Input +from unittests import NUM_CLASSES from unittests._helpers.testers import MetricTester - -_inputs1 = _Input( - preds=torch.randint(0, 2, (NUM_BATCHES, 
BATCH_SIZE, NUM_CLASSES, 16)), - target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 16)), -) -_inputs2 = _Input( - preds=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), - target=torch.randint(0, 2, (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, 32, 32)), -) -_inputs3 = _Input( - preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), - target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), -) +from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3 def _reference_mean_iou( From 7ad0069ad5b8a998b3e3e5ccff2a3c8d7f9dc6fb Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 24 Oct 2024 12:10:32 +0200 Subject: [PATCH 19/21] fix typing --- src/torchmetrics/functional/segmentation/dice.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index c4b033fb5ec..87b3b699fc0 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Tuple import torch from torch import Tensor @@ -46,7 +46,7 @@ def _dice_score_update( num_classes: int, include_background: bool, input_format: Literal["one-hot", "index"] = "one-hot", -) -> Tensor: +) -> Tuple[Tensor, Tensor, Tensor]: """Update the state with the current prediction and target.""" _check_same_shape(preds, target) if preds.ndim < 3: From 4020b29ac0eaa747b2d1023d6e3d464577d5b4ed Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 24 Oct 2024 12:13:21 +0200 Subject: [PATCH 20/21] fixes --- src/torchmetrics/classification/dice.py | 4 ++-- src/torchmetrics/functional/classification/dice.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index cbc8a84987e..59601104e9e 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -116,8 +116,8 @@ class Dice(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. .. warning:: - The ``dice`` metrics is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will be - removed in v1.7.0. Please instead consider using ``f1score`` metric from the classification subpackage as it + The ``dice`` metrics is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will + be removed in v1.7.0. Please instead consider using ``f1score`` metric from the classification subpackage as it provides the same functionality. Additionally, we are going to re-add the ``dice`` metric in the segmentation domain in v1.6.0 with slight modifications to functionality. diff --git a/src/torchmetrics/functional/classification/dice.py b/src/torchmetrics/functional/classification/dice.py index 5c08a028572..845ed7162d7 100644 --- a/src/torchmetrics/functional/classification/dice.py +++ b/src/torchmetrics/functional/classification/dice.py @@ -152,8 +152,8 @@ def dice( than what they appear to be. .. warning:: - The ``dice`` metrics is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will be - removed in v1.7.0. 
Please instead consider using ``f1score`` metric from the classification subpackage as it + The ``dice`` metrics is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will + be removed in v1.7.0. Please instead consider using ``f1score`` metric from the classification subpackage as it provides the same functionality. Additionally, we are going to re-add the ``dice`` metric in the segmentation domain in v1.6.0 with slight modifications to functionality. From 6d5ab7046ce77701fa1e1c4c27db9ea27ea0088e Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 24 Oct 2024 12:19:00 +0200 Subject: [PATCH 21/21] fix doctests --- src/torchmetrics/segmentation/dice.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index ef8f4e1db74..fc8cadd8c3a 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -86,10 +86,10 @@ class DiceScore(Metric): >>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target >>> dice_score = DiceScore(num_classes=5, average="micro") >>> dice_score(preds, target) - tensor(0.4993) + tensor(0.4941) >>> dice_score = DiceScore(num_classes=5, average="none") >>> dice_score(preds, target) - tensor([0.4993, 0.5002, 0.5004, 0.4996, 0.5000]) + tensor([0.4860, 0.4999, 0.5014, 0.4885, 0.4915]) """
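
With the doctest values updated for the new per-sample averaging, the series is complete. An end-to-end usage sketch of the resulting API (argument names are taken from the diffs above; treating `include_background=True` and `average="micro"` as the defaults is an assumption based on the signatures shown, and the printed value holds only under the doctest's seeded RNG):

    import torch
    from torchmetrics.functional.segmentation.dice import dice_score
    from torchmetrics.segmentation.dice import DiceScore

    # "one-hot" format: (num_samples, num_classes, H, W) tensors of 0/1
    preds = torch.randint(0, 2, (4, 5, 16, 16))
    target = torch.randint(0, 2, (4, 5, 16, 16))

    metric = DiceScore(num_classes=5, average="micro")
    metric.update(preds, target)
    metric.compute()  # scalar tensor, e.g. tensor(0.4941) in the doctest above

    # "index" format: (num_samples, H, W) tensors of class indices
    preds_idx = torch.randint(0, 5, (4, 16, 16))
    target_idx = torch.randint(0, 5, (4, 16, 16))
    dice_score(preds_idx, target_idx, num_classes=5, average="none", input_format="index")
    # -> per-sample, per-class scores of shape (4, 5)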