Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Segmentation] Added generalized dice score metric #1090

Merged
merged 80 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
80 commits
Select commit Hold shift + click to select a range
d48dd72
Added generalized dice score metric
jlcsilva Jun 15, 2022
fa3612f
Merge branch 'master' into generalized_dice_score
SkafteNicki Jun 20, 2022
1af04cc
Merge branch 'master' into generalized_dice_score
Borda Jun 20, 2022
f3eda70
Merge branch 'master' into generalized_dice_score
justusschock Jun 28, 2022
d6dfe99
Apply suggestions from code review
Borda Jun 29, 2022
47421a3
Merge branch 'master' into generalized_dice_score
Borda Jun 29, 2022
a5985c7
move
Borda Jun 29, 2022
f541731
Merge branch 'master' into generalized_dice_score
Borda Jul 11, 2022
5d946e6
Merge branch 'master' into generalized_dice_score
stancld Oct 5, 2022
6266251
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 5, 2022
fe8be53
fix integration testing
SkafteNicki Nov 5, 2022
c49d6c1
revert
SkafteNicki Nov 6, 2022
180b652
Merge branch 'master' into generalized_dice_score
SkafteNicki Nov 6, 2022
46c7e1f
revert some more
SkafteNicki Nov 6, 2022
3238522
Merge branch 'generalized_dice_score' of https://github.com/jlcsilva/…
SkafteNicki Nov 6, 2022
c23fc00
changelog
SkafteNicki Nov 6, 2022
1d3ac53
deprecate dice
SkafteNicki Nov 6, 2022
427a3ff
transfer to new format
SkafteNicki Nov 6, 2022
f14047e
Merge branch 'master' into generalized_dice_score
justusschock Jan 9, 2023
54afa44
missing import
justusschock Jan 9, 2023
7aa422e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2023
e6f67b0
missing import
justusschock Jan 9, 2023
0a931fb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2023
45a486c
Merge branch 'master' into generalized_dice_score
Borda Feb 20, 2023
e10b914
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2023
0e3ab5d
Apply suggestions from code review
Borda Feb 20, 2023
5e7fe1c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2023
fc93b89
Merge branch 'master' into generalized_dice_score
Borda Feb 20, 2023
783845c
Merge branch 'master' into generalized_dice_score
Borda Feb 28, 2023
653a9e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 28, 2023
d6cde19
Literal
Borda Feb 28, 2023
0541aed
Merge branch 'master' into generalized_dice_score
Borda Mar 6, 2023
e94580d
Merge branch 'master' into generalized_dice_score
Borda Mar 21, 2023
a5a34b9
Merge branch 'master' into generalized_dice_score
Borda Mar 21, 2023
579bad1
Merge branch 'master' into generalized_dice_score
Borda Mar 21, 2023
4b99b1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 21, 2023
60328f6
Merge branch 'master' into generalized_dice_score
Borda Mar 31, 2023
0357adb
Merge branch 'master' into generalized_dice_score
Borda Apr 17, 2023
0435c7c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2023
34c8044
Merge branch 'master' into generalized_dice_score
Borda Apr 17, 2023
b879c3d
Merge branch 'master' into generalized_dice_score
Borda Apr 17, 2023
d8787e3
Merge branch 'master' into generalized_dice_score
Borda Apr 26, 2023
c02c670
Merge branch 'master' into generalized_dice_score
Borda May 9, 2023
534a81d
Merge branch 'master' into generalized_dice_score
Borda May 15, 2023
05a1c37
Merge branch 'master' into generalized_dice_score
Borda May 17, 2023
5434095
Merge branch 'master' into generalized_dice_score
Borda May 23, 2023
4631b65
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2023
6b43738
Merge branch 'master' into generalized_dice_score
Borda Jun 15, 2023
6e29777
Merge branch 'master' into generalized_dice_score
Borda Jun 30, 2023
628696d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 30, 2023
ce3df53
Merge branch 'master' into generalized_dice_score
Borda Jul 3, 2023
206ec2b
Merge branch 'master' into generalized_dice_score
Borda Jul 6, 2023
bc4683b
Merge branch 'master' into generalized_dice_score
Borda Aug 7, 2023
ef708bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2023
ea15142
Merge branch 'master' into generalized_dice_score
Borda Aug 7, 2023
edae668
Merge branch 'master' into generalized_dice_score
Borda Aug 18, 2023
6c4d9dc
Merge branch 'master' into generalized_dice_score
Borda Aug 23, 2023
5b1a941
Merge branch 'master' into generalized_dice_score
Borda Aug 25, 2023
86464ec
Merge branch 'master' into generalized_dice_score
Borda Jan 9, 2024
4c8234e
Merge branch 'master' into generalized_dice_score
Borda Jan 12, 2024
5eaca9c
Merge branch 'master' into generalized_dice_score
Borda Feb 15, 2024
4c6fae7
Merge branch 'master' into generalized_dice_score
SkafteNicki Apr 12, 2024
32e6d1a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2024
978054d
move around files
SkafteNicki Apr 12, 2024
70fabaf
revert some changes from classification
SkafteNicki Apr 12, 2024
bd29835
move content in init files
SkafteNicki Apr 12, 2024
09b03d4
Merge branch 'generalized_dice_score' of https://github.com/jlcsilva/…
SkafteNicki Apr 12, 2024
05ca6e9
implementation
SkafteNicki Apr 22, 2024
75235d8
more implementation
SkafteNicki Apr 22, 2024
23df55f
docstrings + fixing of testing framework
SkafteNicki Apr 22, 2024
7726927
links
SkafteNicki Apr 23, 2024
547f890
fix remaining tests
SkafteNicki Apr 23, 2024
77f151e
revert some changes
SkafteNicki Apr 23, 2024
f30c37b
Merge branch 'master' into generalized_dice_score
SkafteNicki Apr 23, 2024
5ba5d7d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2024
16eaf2e
revert changelog
SkafteNicki Apr 23, 2024
57e26de
Merge branch 'generalized_dice_score' of https://github.com/jlcsilva/…
SkafteNicki Apr 23, 2024
4286283
fixes to docs
SkafteNicki Apr 23, 2024
dc72724
mypy
Borda Apr 23, 2024
bd375c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `GeneralizedDiceScore` to segmentation package ([#1090](https://github.com/Lightning-AI/metrics/pull/1090))


- Added `SensitivityAtSpecificity` metric to classification subpackage ([#2217](https://github.com/Lightning-AI/torchmetrics/pull/2217))


Expand All @@ -34,7 +37,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Deprecated

-


### Fixed
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ covers the following domains:
- Multimodal (Image-Text)
- Nominal
- Regression
- Segmentation
- Text

Each domain may require some additional dependencies which can be installed with `pip install torchmetrics[audio]`,
Expand Down
8 changes: 8 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,14 @@ Or directly from conda

retrieval/*

.. toctree::
:maxdepth: 2
:name: segmentation
:caption: Segmentation
:glob:

segmentation/*

.. toctree::
:maxdepth: 2
:name: text
Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,4 @@
.. _FLORES-200: https://arxiv.org/abs/2207.04672
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
.. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013
.. _Generalized Dice Score: https://arxiv.org/abs/1707.03237
22 changes: 22 additions & 0 deletions docs/source/segmentation/generalized_dice.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
.. customcarditem::
:header: Generalized Dice Score
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Classification

.. include:: ../links.rst

######################
Generalized Dice Score
######################

Module Interface
________________

.. autoclass:: torchmetrics.segmentation.GeneralizedDiceScore
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.segmentation.generalized_dice_score
:noindex:
1 change: 1 addition & 0 deletions src/torchmetrics/functional/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
"confusion_matrix",
"multiclass_confusion_matrix",
"multilabel_confusion_matrix",
"generalized_dice_score",
"dice",
"exact_match",
"multiclass_exact_match",
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/functional/segmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score

__all__ = ["generalized_dice_score"]
138 changes: 138 additions & 0 deletions src/torchmetrics/functional/segmentation/generalized_dice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.segmentation.utils import _ignore_background
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.compute import _safe_divide


def _generalized_dice_validate_args(
num_classes: int,
include_background: bool,
per_class: bool,
weight_type: Literal["square", "simple", "linear"],
) -> None:
"""Validate the arguments of the metric."""
if num_classes <= 0:
raise ValueError(f"Expected argument `num_classes` must be a positive integer, but got {num_classes}.")
if not isinstance(include_background, bool):
raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.")
if not isinstance(per_class, bool):
raise ValueError(f"Expected argument `per_class` must be a boolean, but got {per_class}.")
if weight_type not in ["square", "simple", "linear"]:
raise ValueError(
f"Expected argument `weight_type` to be one of 'square', 'simple', 'linear', but got {weight_type}."
)


def _generalized_dice_update(
preds: Tensor,
target: Tensor,
num_classes: int,
include_background: bool,
weight_type: Literal["square", "simple", "linear"] = "square",
) -> Tensor:
"""Update the state with the current prediction and target."""
_check_same_shape(preds, target)
if preds.ndim < 3:
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")

if (preds.bool() != preds).any(): # preds is an index tensor
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
if (target.bool() != target).any(): # target is an index tensor
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)

if not include_background:
preds, target = _ignore_background(preds, target)

reduce_axis = list(range(2, target.ndim))
intersection = torch.sum(preds * target, dim=reduce_axis)
target_sum = torch.sum(target, dim=reduce_axis)
pred_sum = torch.sum(preds, dim=reduce_axis)
cardinality = target_sum + pred_sum

if weight_type == "simple":
weights = 1.0 / target_sum
elif weight_type == "linear":
weights = torch.ones_like(target_sum)
elif weight_type == "square":
weights = 1.0 / (target_sum**2)
else:
raise ValueError(
f"Expected argument `weight_type` to be one of 'simple', 'linear', 'square', but got {weight_type}."
)

w_shape = weights.shape
weights_flatten = weights.flatten()
infs = torch.isinf(weights_flatten)
weights_flatten[infs] = 0
w_max = torch.max(weights, 0).values.repeat(w_shape[0], 1).T.flatten()
weights_flatten[infs] = w_max[infs]
weights = weights_flatten.reshape(w_shape)

numerator = 2.0 * intersection * weights
denominator = cardinality * weights
return numerator, denominator # type:ignore[return-value]


def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class: bool = True) -> Tensor:
"""Compute the generalized dice score."""
if not per_class:
numerator = torch.sum(numerator, 1)
denominator = torch.sum(denominator, 1)
return _safe_divide(numerator, denominator)


def generalized_dice_score(
preds: Tensor,
target: Tensor,
num_classes: int,
include_background: bool = True,
per_class: bool = False,
weight_type: Literal["square", "simple", "linear"] = "square",
) -> Tensor:
"""Compute the Generalized Dice Score for semantic segmentation.

Args:
preds: Predictions from model
target: Ground truth values
num_classes: Number of classes
include_background: Whether to include the background class in the computation
per_class: Whether to compute the IoU for each class separately, else average over all classes
weight_type: Type of weight factor to apply to the classes. One of ``"square"``, ``"simple"``, or ``"linear"``

Returns:
The Generalized Dice Score

Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.functional.segmentation import generalized_dice_score
>>> preds = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 prediction
>>> target = torch.randint(0, 2, (4, 5, 16, 16)) # 4 samples, 5 classes, 16x16 target
>>> generalized_dice_score(preds, target, num_classes=5)
tensor([0.4830, 0.4935, 0.5044, 0.4880])
>>> generalized_dice_score(preds, target, num_classes=5, per_class=True)
tensor([[0.4724, 0.5185, 0.4710, 0.5062, 0.4500],
[0.4571, 0.4980, 0.5191, 0.4380, 0.5649],
[0.5428, 0.4904, 0.5358, 0.4830, 0.4724],
[0.4715, 0.4925, 0.4797, 0.5267, 0.4788]])

"""
_generalized_dice_validate_args(num_classes, include_background, per_class, weight_type)
numerator, denominator = _generalized_dice_update(preds, target, num_classes, include_background, weight_type)
return _generalized_dice_compute(numerator, denominator, per_class)
7 changes: 7 additions & 0 deletions src/torchmetrics/functional/segmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@
from torchmetrics.utilities.imports import _SCIPY_AVAILABLE


def _ignore_background(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""Ignore the background class in the computation."""
preds = preds[:, 1:] if preds.shape[1] > 1 else preds
target = target[:, 1:] if target.shape[1] > 1 else target
return preds, target


def check_if_binarized(x: Tensor) -> None:
"""Check if the input is binarized.

Expand Down
17 changes: 17 additions & 0 deletions src/torchmetrics/segmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore

__all__ = ["GeneralizedDiceScore"]
Loading
Loading