Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial commit -- Adding calibration loss specific to segmentation #7819

Merged
merged 54 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
8fbec82
Initial commit -- Adding calibration loss specific to segmentation
Bala93 Jun 2, 2024
23b897b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2024
b2ec62b
Update __init__.py
Bala93 Jun 2, 2024
93ee114
Update segcalib.py
Bala93 Jun 2, 2024
42e732b
Update segcalib.py
Bala93 Jun 2, 2024
187053d
Update segcalib.py
Bala93 Jun 2, 2024
1d27ec5
Update segcalib.py
Bala93 Jun 2, 2024
d499134
Update segcalib.py
Bala93 Jun 3, 2024
1e3f755
Update segcalib.py
Bala93 Jun 3, 2024
9dedfba
Update segcalib.py
Bala93 Jun 4, 2024
59959ce
Update monai/losses/segcalib.py
Bala93 Jun 14, 2024
cf1d044
Update monai/losses/segcalib.py
Bala93 Jun 14, 2024
0926851
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2024
5317706
Update segcalib.py
Bala93 Jun 15, 2024
3155433
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 15, 2024
7c121a0
Add specific to gaussian for both 2d and 3d
Bala93 Aug 3, 2024
24efd85
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 3, 2024
0067953
Merge branch 'Project-MONAI:dev' into model-calibration
Bala93 Aug 3, 2024
dccde47
Add mean loss and resolve formatting
Bala93 Aug 3, 2024
44e8065
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 3, 2024
57686d7
Merge branch 'dev' into model-calibration
Bala93 Aug 3, 2024
5cd9a33
Update segcalib.py
Bala93 Aug 3, 2024
b547c4e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 3, 2024
42a0215
Update segcalib.py
Bala93 Aug 3, 2024
7e36ca1
Update segcalib.py
Bala93 Aug 3, 2024
6dbd53d
Update segcalib.py
Bala93 Aug 3, 2024
354056c
Update segcalib.py
Bala93 Aug 4, 2024
7eb911f
Update segcalib.py
Bala93 Aug 4, 2024
0b1209b
Update segcalib.py
Bala93 Aug 4, 2024
035c92e
Update segcalib.py
Bala93 Aug 4, 2024
c1de5f1
Rename segcalib.py to nacl_loss.py
Bala93 Aug 5, 2024
91dd1b9
Update __init__.py
Bala93 Aug 5, 2024
9702c02
Update test_nacl_loss.py
Bala93 Aug 5, 2024
4462379
Update nacl_loss.py
Bala93 Aug 5, 2024
c4f8283
Update test_nacl_loss.py
Bala93 Aug 5, 2024
bc6b995
Update test_nacl_loss.py
Bala93 Aug 5, 2024
51e15fe
Added missing parameters in doc
Bala93 Aug 5, 2024
3a00aec
Formatting check with monai
Bala93 Aug 5, 2024
818b42b
Update test_nacl_loss.py
Bala93 Aug 5, 2024
6647708
Added mypy fixes
Bala93 Aug 5, 2024
7e579dd
DCO Remediation Commit for bala93 <[email protected]>
Bala93 Aug 5, 2024
4f8abf1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
b72e478
Update docs/source/losses.rst
Bala93 Aug 6, 2024
747681d
* Include test cases covering more cases
Bala93 Aug 7, 2024
3b15554
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2024
877139c
Update monai/losses/nacl_loss.py
Bala93 Aug 7, 2024
4679456
Update monai/losses/nacl_loss.py
Bala93 Aug 7, 2024
7c5217e
* Add docstring with better explanations
Bala93 Aug 7, 2024
d33f435
* Maintain the dimension consistency.
Bala93 Aug 7, 2024
7deb2cc
Update nacl_loss.py
Bala93 Aug 7, 2024
91ce50b
Update nacl_loss.py
Bala93 Aug 7, 2024
7f87e0c
Merge branch 'model-calibration' of https://github.com/Bala93/MONAI i…
Bala93 Aug 7, 2024
0e880a8
Modify docstring
Bala93 Aug 7, 2024
db9daeb
Merge branch 'dev' into model-calibration
KumoLiu Aug 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ Segmentation Losses
.. autoclass:: SoftDiceclDiceLoss
:members:

`NACLLoss`
~~~~~~~~~~
.. autoclass:: NACLLoss
:members:

Registration Losses
-------------------

Expand Down
1 change: 1 addition & 0 deletions monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .hausdorff_loss import HausdorffDTLoss, LogHausdorffDTLoss
from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
from .multi_scale import MultiScaleLoss
from .nacl_loss import NACLLoss
from .perceptual import PerceptualLoss
from .spatial_mask import MaskedLoss
from .spectral_loss import JukeboxLoss
Expand Down
139 changes: 139 additions & 0 deletions monai/losses/nacl_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from monai.networks.layers import GaussianFilter, MeanFilter


class NACLLoss(_Loss):
"""
Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation.
NACL computes standard cross-entropy loss with a linear penalty that enforces the logit distributions
to match a soft class proportion of surrounding pixel.

Murugesan, Balamurali, et al.
"Trust your neighbours: Penalty-based constraints for model calibration."
International Conference on Medical Image Computing and Computer-Assisted Intervention, MICCAI 2023.
https://arxiv.org/abs/2303.06268

Murugesan, Balamurali, et al.
"Neighbor-Aware Calibration of Segmentation Networks with Penalty-Based Constraints."
https://arxiv.org/abs/2401.14487
"""

def __init__(
self,
classes: int,
dim: int,
kernel_size: int = 3,
kernel_ops: str = "mean",
distance_type: str = "l1",
alpha: float = 0.1,
sigma: float = 1.0,
) -> None:
"""
Args:
classes: number of classes
dim: dimension of data (supports 2d and 3d)
kernel_size: size of the spatial kernel
distance_type: l1/l2 distance between spatial kernel and predicted logits
alpha: weightage between cross entropy and logit constraint
sigma: sigma of gaussian
"""

super().__init__()

if kernel_ops not in ["mean", "gaussian"]:
raise ValueError("Kernel ops must be either mean or gaussian")

if dim not in [2, 3]:
raise ValueError(f"Support 2d and 3d, got dim={dim}.")

if distance_type not in ["l1", "l2"]:
raise ValueError(f"Distance type must be either L1 or L2, got {distance_type}")

self.nc = classes
self.dim = dim
self.cross_entropy = nn.CrossEntropyLoss()
self.distance_type = distance_type
self.alpha = alpha
self.ks = kernel_size
self.svls_layer: Any

if kernel_ops == "mean":
self.svls_layer = MeanFilter(spatial_dims=dim, size=kernel_size)
self.svls_layer.filter = self.svls_layer.filter / (kernel_size**dim)
if kernel_ops == "gaussian":
self.svls_layer = GaussianFilter(spatial_dims=dim, sigma=sigma)

def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor:
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
"""
Converts the mask to one hot represenation and is smoothened with the selected spatial filter.

Args:
mask: the shape should be BH[WD].

Returns:
torch.Tensor: the shape would be BNH[WD], N being number of classes.
"""
rmask: torch.Tensor

if self.dim == 2:
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float()
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
rmask = self.svls_layer(oh_labels)

if self.dim == 3:
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float()
rmask = self.svls_layer(oh_labels)

return rmask

def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""
Computes standard cross-entropy loss and constraints it neighbor aware logit penalty.

Args:
inputs: the shape should be BNH[WD], where N is the number of classes.
targets: the shape should be BH[WD].

Returns:
torch.Tensor: value of the loss.

Example:
>>> import torch
>>> from monai.losses import NACLLoss
>>> B, N, H, W = 8, 3, 64, 64
>>> input = torch.rand(B, N, H, W)
>>> target = torch.randint(0, N, (B, H, W))
>>> criterion = NACLLoss(classes = N, dim = 2)
>>> loss = criterion(input, target)
"""

loss_ce = self.cross_entropy(inputs, targets)

utargets = self.get_constr_target(targets)

if self.distance_type == "l1":
loss_conf = utargets.sub(inputs).abs_().mean()
elif self.distance_type == "l2":
loss_conf = utargets.sub(inputs).pow_(2).abs_().mean()

loss: torch.Tensor = loss_ce + self.alpha * loss_conf

return loss
166 changes: 166 additions & 0 deletions tests/test_nacl_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses import NACLLoss

inputs = torch.tensor(
[
[
[
[0.1498, 0.1158, 0.3996, 0.3730],
[0.2155, 0.1585, 0.8541, 0.8579],
[0.6640, 0.2424, 0.0774, 0.0324],
[0.0580, 0.2180, 0.3447, 0.8722],
],
[
[0.3908, 0.9366, 0.1779, 0.1003],
[0.9630, 0.6118, 0.4405, 0.7916],
[0.5782, 0.9515, 0.4088, 0.3946],
[0.7860, 0.3910, 0.0324, 0.9568],
],
[
[0.0759, 0.0238, 0.5570, 0.1691],
[0.2703, 0.7722, 0.1611, 0.6431],
[0.8051, 0.6596, 0.4121, 0.1125],
[0.5283, 0.6746, 0.5528, 0.7913],
],
]
]
)
targets = torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]])

TEST_CASES = [
[{"classes": 3, "dim": 2}, {"inputs": inputs, "targets": targets}, 1.1442],
[{"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, {"inputs": inputs, "targets": targets}, 1.1433],
[{"classes": 3, "dim": 2, "kernel_ops": "gaussian", "sigma": 0.5}, {"inputs": inputs, "targets": targets}, 1.1469],
[{"classes": 3, "dim": 2, "distance_type": "l2"}, {"inputs": inputs, "targets": targets}, 1.1269],
[{"classes": 3, "dim": 2, "alpha": 0.2}, {"inputs": inputs, "targets": targets}, 1.1790],
[
{"classes": 3, "dim": 3, "kernel_ops": "gaussian"},
{
"inputs": torch.tensor(
[
[
[
[
[0.5977, 0.2767, 0.0591, 0.1675],
[0.4835, 0.3778, 0.8406, 0.3065],
[0.6047, 0.2860, 0.9742, 0.2013],
[0.9128, 0.8368, 0.6711, 0.4384],
],
[
[0.9797, 0.1863, 0.5584, 0.6652],
[0.2272, 0.2004, 0.7914, 0.4224],
[0.5097, 0.8818, 0.2581, 0.3495],
[0.1054, 0.5483, 0.3732, 0.3587],
],
[
[0.3060, 0.7066, 0.7922, 0.4689],
[0.1733, 0.8902, 0.6704, 0.2037],
[0.8656, 0.5561, 0.2701, 0.0092],
[0.1866, 0.7714, 0.6424, 0.9791],
],
[
[0.5067, 0.3829, 0.6156, 0.8985],
[0.5192, 0.8347, 0.2098, 0.2260],
[0.8887, 0.3944, 0.6400, 0.5345],
[0.1207, 0.3763, 0.5282, 0.7741],
],
],
[
[
[0.8499, 0.4759, 0.1964, 0.5701],
[0.3190, 0.1238, 0.2368, 0.9517],
[0.0797, 0.6185, 0.0135, 0.8672],
[0.4116, 0.1683, 0.1355, 0.0545],
],
[
[0.7533, 0.2658, 0.5955, 0.4498],
[0.9500, 0.2317, 0.2825, 0.9763],
[0.1493, 0.1558, 0.3743, 0.8723],
[0.1723, 0.7980, 0.8816, 0.0133],
],
[
[0.8426, 0.2666, 0.2077, 0.3161],
[0.1725, 0.8414, 0.1515, 0.2825],
[0.4882, 0.5159, 0.4120, 0.1585],
[0.2551, 0.9073, 0.7691, 0.9898],
],
[
[0.4633, 0.8717, 0.8537, 0.2899],
[0.3693, 0.7953, 0.1183, 0.4596],
[0.0087, 0.7925, 0.0989, 0.8385],
[0.8261, 0.6920, 0.7069, 0.4464],
],
],
[
[
[0.0110, 0.1608, 0.4814, 0.6317],
[0.0194, 0.9669, 0.3259, 0.0028],
[0.5674, 0.8286, 0.0306, 0.5309],
[0.3973, 0.8183, 0.0238, 0.1934],
],
[
[0.8947, 0.6629, 0.9439, 0.8905],
[0.0072, 0.1697, 0.4634, 0.0201],
[0.7184, 0.2424, 0.0820, 0.7504],
[0.3937, 0.1424, 0.4463, 0.5779],
],
[
[0.4123, 0.6227, 0.0523, 0.8826],
[0.0051, 0.0353, 0.3662, 0.7697],
[0.4867, 0.8986, 0.2510, 0.5316],
[0.1856, 0.2634, 0.9140, 0.9725],
],
[
[0.2041, 0.4248, 0.2371, 0.7256],
[0.2168, 0.5380, 0.4538, 0.7007],
[0.9013, 0.2623, 0.0739, 0.2998],
[0.1366, 0.5590, 0.2952, 0.4592],
],
],
]
]
),
"targets": torch.tensor(
[
[
[[0, 1, 0, 1], [1, 2, 1, 0], [2, 1, 1, 1], [1, 1, 0, 1]],
[[2, 1, 0, 2], [1, 2, 0, 2], [1, 0, 1, 1], [1, 1, 0, 0]],
[[1, 0, 2, 1], [0, 2, 2, 1], [1, 0, 1, 1], [0, 0, 2, 1]],
[[2, 1, 1, 0], [1, 0, 0, 2], [1, 0, 2, 1], [2, 1, 0, 1]],
]
]
),
},
1.15035,
],
]


class TestNACLLoss(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_result(self, input_param, input_data, expected_val):
loss = NACLLoss(**input_param)
result = loss(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)


if __name__ == "__main__":
unittest.main()
Loading