Deprecate Dice from classification and re-add to segmentation #2725

Open: wants to merge 31 commits into base: master

Changes from all commits (31 commits):
f5fd1e1  a beginning of something new (SkafteNicki, Sep 7, 2024)
0874184  docs (SkafteNicki, Sep 7, 2024)
e0ea066  deprecate old implementation (SkafteNicki, Sep 8, 2024)
7851630  fix small mistakes in generalized dice (SkafteNicki, Sep 8, 2024)
35ca256  changelog (SkafteNicki, Sep 9, 2024)
14a77aa  initial new implementation (SkafteNicki, Sep 9, 2024)
f925599  Merge branch 'master' into newmetric/move_dice (SkafteNicki, Sep 9, 2024)
591665c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 9, 2024)
c57c11f  Merge branch 'master' into newmetric/move_dice (SkafteNicki, Sep 9, 2024)
48f584f  Merge branch 'master' into newmetric/move_dice (Borda, Sep 13, 2024)
9baf6e6  Apply suggestions from code review (Borda, Sep 13, 2024)
9be6172  more code (SkafteNicki, Sep 14, 2024)
3950b00  Merge branch 'newmetric/move_dice' of https://github.com/Lightning-AI… (SkafteNicki, Sep 14, 2024)
88cc2c5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 14, 2024)
280fd92  Merge branch 'master' into newmetric/move_dice (Borda, Sep 16, 2024)
64d51a9  Merge branch 'master' into newmetric/move_dice (SkafteNicki, Oct 15, 2024)
a24ea13  Merge branch 'master' into newmetric/move_dice (SkafteNicki, Oct 22, 2024)
32128c4  fix (SkafteNicki, Oct 22, 2024)
ba0124d  update doctests (SkafteNicki, Oct 22, 2024)
8346d7d  doctests (SkafteNicki, Oct 22, 2024)
051187d  update implementations (SkafteNicki, Oct 22, 2024)
be0189e  Merge branch 'master' into newmetric/move_dice (SkafteNicki, Oct 22, 2024)
42b1824  somewhat working tests (SkafteNicki, Oct 22, 2024)
b7a4f00  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 22, 2024)
81dc3c3  Merge branch 'master' into newmetric/move_dice (SkafteNicki, Oct 22, 2024)
13038ce  fix implementation (SkafteNicki, Oct 24, 2024)
abc55e4  centralize inputs and fix testing (SkafteNicki, Oct 24, 2024)
c72786d  merge master (SkafteNicki, Oct 24, 2024)
7ad0069  fix typing (SkafteNicki, Oct 24, 2024)
4020b29  fixes (SkafteNicki, Oct 24, 2024)
6d5ab70  fix doctests (SkafteNicki, Oct 24, 2024)
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -15,10 +15,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `NormalizedRootMeanSquaredError` metric to regression subpackage ([#2442](https://github.com/Lightning-AI/torchmetrics/pull/2442))


- Added `Dice` metric to segmentation metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725))


### Changed

- Changed naming and input order arguments in `KLDivergence` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))

### Deprecated

- Deprecated Dice from classification metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725))


### Removed

22 changes: 22 additions & 0 deletions docs/source/segmentation/dice.rst
@@ -0,0 +1,22 @@
.. customcarditem::
:header: Dice Score
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Segmentation

.. include:: ../links.rst

##########
Dice Score
##########

Module Interface
________________

.. autoclass:: torchmetrics.segmentation.DiceScore
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.segmentation.dice_score
:noindex:
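
A short usage sketch of the new module interface documented above. The ``DiceScore`` constructor is not shown in this diff, so the arguments below (``num_classes``, ``average``, ``input_format``) are assumed to mirror the functional ``dice_score`` signature; treat this as illustrative, not authoritative.

```python
# Sketch only: assumes DiceScore takes the same arguments as the functional
# dice_score (num_classes, include_background, average, input_format).
import torch
from torchmetrics.segmentation import DiceScore

metric = DiceScore(num_classes=3, average="macro", input_format="index")
preds = torch.randint(0, 3, (4, 16, 16))   # predicted class index per pixel
target = torch.randint(0, 3, (4, 16, 16))  # ground-truth class index per pixel

metric.update(preds, target)   # accumulate statistics over one batch
print(metric.compute())        # aggregated Dice score for everything seen so far
```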
15 changes: 15 additions & 0 deletions src/torchmetrics/classification/dice.py
@@ -20,6 +20,7 @@
from torchmetrics.functional.classification.dice import _dice_compute
from torchmetrics.functional.classification.stat_scores import _stat_scores_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
@@ -114,6 +115,12 @@ class Dice(Metric):

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

.. warning::
The ``dice`` metric is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will
be removed in v1.7.0. Please instead consider using the ``f1score`` metric from the classification subpackage
as it provides the same functionality. Additionally, we are going to re-add the ``dice`` metric in the
segmentation domain in v1.6.0 with slight modifications to functionality.

Raises:
ValueError:
If ``average`` is none of ``"micro"``, ``"macro"``, ``"samples"``, ``"none"``, ``None``.
@@ -155,6 +162,14 @@ def __init__(
multiclass: Optional[bool] = None,
**kwargs: Any,
) -> None:
rank_zero_warn(
"The `dice` metrics is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and"
" will removed in v1.7.0. Please instead consider using `f1score` metric from the classification subpackage"
" as it provides the same functionality. Additionally, we are going to re-add the `dice` metric in the"
" segmentation domain in v1.6.0 with slight modifications to functionality.",
DeprecationWarning,
)

super().__init__(**kwargs)
allowed_average = ("micro", "macro", "samples", "none", None)
if average not in allowed_average:
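
For users hit by the deprecation above, a possible migration sketch. The warning points at the ``f1score`` metric; micro-averaged F1 and micro-averaged Dice are the same quantity (both equal 2*TP / (2*TP + FP + FN)), so for single-label multiclass inputs ``MulticlassF1Score`` is one plausible drop-in. The tensors below are made up for illustration; check your own ``average`` and ``ignore_index`` settings before switching.

```python
# Possible replacement for the deprecated classification Dice metric.
import torch
from torchmetrics.classification import MulticlassF1Score

preds = torch.tensor([2, 0, 2, 1])   # predicted class labels
target = torch.tensor([1, 0, 2, 1])  # ground-truth class labels

f1 = MulticlassF1Score(num_classes=3, average="micro")
print(f1(preds, target))  # 0.75 here; micro F1 coincides with micro-averaged Dice
```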
15 changes: 15 additions & 0 deletions src/torchmetrics/functional/classification/dice.py
@@ -17,6 +17,7 @@
from torch import Tensor

from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.checks import _input_squeeze
from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod

@@ -150,6 +151,12 @@ def dice(
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be.

.. warning::
The ``dice`` metric is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will
be removed in v1.7.0. Please instead consider using the ``f1score`` metric from the classification subpackage
as it provides the same functionality. Additionally, we are going to re-add the ``dice`` metric in the
segmentation domain in v1.6.0 with slight modifications to functionality.

Return:
The shape of the returned tensor depends on the ``average`` parameter

@@ -174,6 +181,14 @@ def dice(
tensor(0.2500)

"""
rank_zero_warn(
"The `dice` metrics is being deprecated from the classification subpackage in v1.6.0 of torchmetrics and will"
" removed in v1.7.0. Please instead consider using `f1score` metric from the classification subpackage as it"
" provides the same functionality. Additionally, we are going to re-add the `dice` metric in the segmentation"
" domain in v1.6.0 with slight modifications to functionality.",
DeprecationWarning,
)

allowed_average = ("micro", "macro", "weighted", "samples", "none", None)
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
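
The same migration in functional form, again as a hedged sketch rather than part of this PR: ``multiclass_f1_score`` with ``average="micro"`` computes the quantity the deprecated functional ``dice`` reported for single-label multiclass inputs.

```python
# Functional counterpart of the migration sketch above.
import torch
from torchmetrics.functional.classification import multiclass_f1_score

preds = torch.tensor([2, 0, 2, 1])
target = torch.tensor([1, 0, 2, 1])
print(multiclass_f1_score(preds, target, num_classes=3, average="micro"))  # 0.7500
```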
3 changes: 2 additions & 1 deletion src/torchmetrics/functional/segmentation/__init__.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.segmentation.dice import dice_score
from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score
from torchmetrics.functional.segmentation.hausdorff_distance import hausdorff_distance
from torchmetrics.functional.segmentation.mean_iou import mean_iou

__all__ = ["generalized_dice_score", "mean_iou", "hausdorff_distance"]
__all__ = ["generalized_dice_score", "mean_iou", "hausdorff_distance", "dice_score"]
148 changes: 148 additions & 0 deletions src/torchmetrics/functional/segmentation/dice.py
@@ -0,0 +1,148 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.segmentation.utils import _ignore_background
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.compute import _safe_divide


def _dice_score_validate_args(
num_classes: int,
include_background: bool,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
input_format: Literal["one-hot", "index"] = "one-hot",
) -> None:
"""Validate the arguments of the metric."""
if not isinstance(num_classes, int) or num_classes <= 0:
raise ValueError(f"Expected argument `num_classes` must be a positive integer, but got {num_classes}.")
if not isinstance(include_background, bool):
raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.")
allowed_average = ["micro", "macro", "weighted", "none"]
if average is not None and average not in allowed_average:
raise ValueError(f"Expected argument `average` to be one of {allowed_average} or None, but got {average}.")
if input_format not in ["one-hot", "index"]:
raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.")


def _dice_score_update(
preds: Tensor,
target: Tensor,
num_classes: int,
include_background: bool,
input_format: Literal["one-hot", "index"] = "one-hot",
) -> Tuple[Tensor, Tensor, Tensor]:
"""Update the state with the current prediction and target."""
_check_same_shape(preds, target)
if preds.ndim < 3:
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")

if input_format == "index":
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)

if not include_background:
preds, target = _ignore_background(preds, target)

reduce_axis = list(range(2, target.ndim))
intersection = torch.sum(preds * target, dim=reduce_axis)
target_sum = torch.sum(target, dim=reduce_axis)
pred_sum = torch.sum(preds, dim=reduce_axis)

numerator = 2 * intersection
denominator = pred_sum + target_sum
support = target_sum
return numerator, denominator, support


def _dice_score_compute(
numerator: Tensor,
denominator: Tensor,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
support: Optional[Tensor] = None,
) -> Tensor:
"""Compute the Dice score from the numerator and denominator."""
if average == "micro":
numerator = torch.sum(numerator, dim=-1)
denominator = torch.sum(denominator, dim=-1)
dice = _safe_divide(numerator, denominator, zero_division=1.0)
if average == "macro":
dice = torch.mean(dice, dim=-1)
elif average == "weighted" and support is not None:
weights = _safe_divide(support, torch.sum(support, dim=-1, keepdim=True), zero_division=1.0)
dice = torch.sum(dice * weights, dim=-1)
return dice


def dice_score(
preds: Tensor,
target: Tensor,
num_classes: int,
include_background: bool = True,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
input_format: Literal["one-hot", "index"] = "one-hot",
) -> Tensor:
"""Compute the Dice score for semantic segmentation.

Args:
preds: Predictions from model
target: Ground truth values
num_classes: Number of classes
include_background: Whether to include the background class in the computation
average: The method to average the dice score. Options are ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"``
or ``None``. This determines how to average the dice score across different classes.
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
or ``"index"`` for index tensors

Returns:
The Dice score.

Example (with one-hot encoded tensors):
>>> from torch import randint
>>> from torchmetrics.functional.segmentation import dice_score
>>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction
>>> target = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target
>>> # dice score micro averaged over all classes
>>> dice_score(preds, target, num_classes=5, average="micro")
tensor([0.4842, 0.4968, 0.5053, 0.4902])
>>> # dice score per sample and class
>>> dice_score(preds, target, num_classes=5, average="none")
tensor([[0.4724, 0.5185, 0.4710, 0.5062, 0.4500],
[0.4571, 0.4980, 0.5191, 0.4380, 0.5649],
[0.5428, 0.4904, 0.5358, 0.4830, 0.4724],
[0.4715, 0.4925, 0.4797, 0.5267, 0.4788]])

Example (with index tensors):
>>> from torch import randint
>>> from torchmetrics.functional.segmentation import dice_score
>>> preds = randint(0, 5, (4, 16, 16)) # 4 samples, 5 classes, 16x16 prediction
>>> target = randint(0, 5, (4, 16, 16)) # 4 samples, 5 classes, 16x16 target
>>> # dice score micro averaged over all classes
>>> dice_score(preds, target, num_classes=5, average="micro", input_format="index")
tensor([0.2031, 0.1914, 0.2500, 0.2266])
>>> # dice score per sample and class
>>> dice_score(preds, target, num_classes=5, average="none", input_format="index")
tensor([[0.1714, 0.2500, 0.1304, 0.2524, 0.2069],
[0.1837, 0.2162, 0.0962, 0.2692, 0.1895],
[0.3866, 0.1348, 0.2526, 0.2301, 0.2083],
[0.1978, 0.2804, 0.1714, 0.1915, 0.2783]])

"""
_dice_score_validate_args(num_classes, include_background, average, input_format)
numerator, denominator, support = _dice_score_update(preds, target, num_classes, include_background, input_format)
return _dice_score_compute(numerator, denominator, average, support=support)
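
A hand-checked sketch of what the helpers above compute, on a tiny deterministic input so the numbers can be verified directly: per class, ``numerator = 2 * intersection`` and ``denominator = pred_sum + target_sum``; ``"micro"`` sums both over classes before dividing, while ``"macro"`` averages the per-class ratios.

```python
# 1 sample, 2 classes, 2x2 image, index format.
import torch
from torchmetrics.functional.segmentation import dice_score

preds = torch.tensor([[[0, 1],
                       [1, 1]]])   # predicted class index per pixel
target = torch.tensor([[[0, 1],
                        [0, 1]]])  # ground-truth class index per pixel

# Per class: intersection = [1, 2], pred_sum = [1, 3], target_sum = [2, 2]
# => numerator = [2, 4], denominator = [3, 5]
per_class = dice_score(preds, target, num_classes=2, average="none", input_format="index")
assert torch.allclose(per_class, torch.tensor([[2 / 3, 4 / 5]]))

# "micro": sum numerator/denominator over classes first: (2 + 4) / (3 + 5) = 0.75
micro = dice_score(preds, target, num_classes=2, average="micro", input_format="index")
assert torch.allclose(micro, torch.tensor([0.75]))

# "macro": average the per-class scores: (2/3 + 4/5) / 2
macro = dice_score(preds, target, num_classes=2, average="macro", input_format="index")
assert torch.allclose(macro, torch.tensor([(2 / 3 + 4 / 5) / 2]))
```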
19 changes: 16 additions & 3 deletions src/torchmetrics/functional/segmentation/generalized_dice.py
@@ -28,7 +28,7 @@ def _generalized_dice_validate_args(
input_format: Literal["one-hot", "index"],
) -> None:
"""Validate the arguments of the metric."""
if num_classes <= 0:
if not isinstance(num_classes, int) or num_classes <= 0:
raise ValueError(f"Expected argument `num_classes` must be a positive integer, but got {num_classes}.")
if not isinstance(include_background, bool):
raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.")
@@ -116,15 +116,15 @@ def generalized_dice_score(
target: Ground truth values
num_classes: Number of classes
include_background: Whether to include the background class in the computation
per_class: Whether to compute the IoU for each class separately, else average over all classes
per_class: Whether to compute the score for each class separately, else average over all classes
weight_type: Type of weight factor to apply to the classes. One of ``"square"``, ``"simple"``, or ``"linear"``
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
or ``"index"`` for index tensors

Returns:
The Generalized Dice Score

Example:
Example (with one-hot encoded tensors):
>>> from torch import randint
>>> from torchmetrics.functional.segmentation import generalized_dice_score
>>> preds = randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction
@@ -137,6 +137,19 @@ def generalized_dice_score(
[0.5428, 0.4904, 0.5358, 0.4830, 0.4724],
[0.4715, 0.4925, 0.4797, 0.5267, 0.4788]])

Example (with index tensors):
>>> from torch import randint
>>> from torchmetrics.functional.segmentation import generalized_dice_score
>>> preds = randint(0, 5, (4, 16, 16)) # 4 samples, 5 classes, 16x16 prediction
>>> target = randint(0, 5, (4, 16, 16)) # 4 samples, 5 classes, 16x16 target
>>> generalized_dice_score(preds, target, num_classes=5, input_format="index")
tensor([0.1991, 0.1971, 0.2350, 0.2216])
>>> generalized_dice_score(preds, target, num_classes=5, per_class=True, input_format="index")
tensor([[0.1714, 0.2500, 0.1304, 0.2524, 0.2069],
[0.1837, 0.2162, 0.0962, 0.2692, 0.1895],
[0.3866, 0.1348, 0.2526, 0.2301, 0.2083],
[0.1978, 0.2804, 0.1714, 0.1915, 0.2783]])

"""
_generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format)
numerator, denominator = _generalized_dice_update(
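
With this PR both functions accept the same ``index`` inputs, so they can be compared side by side; the difference is that ``generalized_dice_score`` weights classes via ``weight_type`` while ``dice_score`` reduces via ``average``. A quick sketch, with outputs omitted since the inputs are random:

```python
import torch
from torchmetrics.functional.segmentation import dice_score, generalized_dice_score

preds = torch.randint(0, 5, (4, 16, 16))
target = torch.randint(0, 5, (4, 16, 16))

print(dice_score(preds, target, num_classes=5, average="macro", input_format="index"))
print(generalized_dice_score(preds, target, num_classes=5, weight_type="square", input_format="index"))
```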
3 changes: 2 additions & 1 deletion src/torchmetrics/segmentation/__init__.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.segmentation.dice import DiceScore
from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore
from torchmetrics.segmentation.hausdorff_distance import HausdorffDistance
from torchmetrics.segmentation.mean_iou import MeanIoU

__all__ = ["GeneralizedDiceScore", "MeanIoU", "HausdorffDistance"]
__all__ = ["GeneralizedDiceScore", "MeanIoU", "HausdorffDistance", "DiceScore"]