Skip to content

Commit

Permalink
Initial commit -- Adding calibration loss specific to segmentation (#…
Browse files Browse the repository at this point in the history
…7819)

### Description

Model calibration has helped in developing reliable deep learning
models. In this pull request, I have added a new loss function NACL
(https://arxiv.org/abs/2303.06268, https://arxiv.org/abs/2401.14487)
which has shown promising results for both discriminative and
calibration in segmentation.

**Future Plans:** Currently, MONAI has some of the alternative loss
functions (Label Smoothing, and Focal Loss), but it doesn't have the
calibration specific loss functions (https://arxiv.org/abs/2111.15430,
https://arxiv.org/abs/2209.09641). Besides, these methods are better
evaluated with calibration metrics, Expected Calibration Error
(https://lightning.ai/docs/torchmetrics/stable/classification/calibration_error.html).

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Balamurali <[email protected]>
Signed-off-by: bala93 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
4 people authored Aug 8, 2024
1 parent 49a1e34 commit 660891f
Show file tree
Hide file tree
Showing 4 changed files with 311 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ Segmentation Losses
.. autoclass:: SoftDiceclDiceLoss
:members:

`NACLLoss`
~~~~~~~~~~
.. autoclass:: NACLLoss
:members:

Registration Losses
-------------------

Expand Down
1 change: 1 addition & 0 deletions monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .hausdorff_loss import HausdorffDTLoss, LogHausdorffDTLoss
from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
from .multi_scale import MultiScaleLoss
from .nacl_loss import NACLLoss
from .perceptual import PerceptualLoss
from .spatial_mask import MaskedLoss
from .spectral_loss import JukeboxLoss
Expand Down
139 changes: 139 additions & 0 deletions monai/losses/nacl_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from monai.networks.layers import GaussianFilter, MeanFilter


class NACLLoss(_Loss):
"""
Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation.
NACL computes standard cross-entropy loss with a linear penalty that enforces the logit distributions
to match a soft class proportion of surrounding pixel.
Murugesan, Balamurali, et al.
"Trust your neighbours: Penalty-based constraints for model calibration."
International Conference on Medical Image Computing and Computer-Assisted Intervention, MICCAI 2023.
https://arxiv.org/abs/2303.06268
Murugesan, Balamurali, et al.
"Neighbor-Aware Calibration of Segmentation Networks with Penalty-Based Constraints."
https://arxiv.org/abs/2401.14487
"""

def __init__(
self,
classes: int,
dim: int,
kernel_size: int = 3,
kernel_ops: str = "mean",
distance_type: str = "l1",
alpha: float = 0.1,
sigma: float = 1.0,
) -> None:
"""
Args:
classes: number of classes
dim: dimension of data (supports 2d and 3d)
kernel_size: size of the spatial kernel
distance_type: l1/l2 distance between spatial kernel and predicted logits
alpha: weightage between cross entropy and logit constraint
sigma: sigma of gaussian
"""

super().__init__()

if kernel_ops not in ["mean", "gaussian"]:
raise ValueError("Kernel ops must be either mean or gaussian")

if dim not in [2, 3]:
raise ValueError(f"Support 2d and 3d, got dim={dim}.")

if distance_type not in ["l1", "l2"]:
raise ValueError(f"Distance type must be either L1 or L2, got {distance_type}")

self.nc = classes
self.dim = dim
self.cross_entropy = nn.CrossEntropyLoss()
self.distance_type = distance_type
self.alpha = alpha
self.ks = kernel_size
self.svls_layer: Any

if kernel_ops == "mean":
self.svls_layer = MeanFilter(spatial_dims=dim, size=kernel_size)
self.svls_layer.filter = self.svls_layer.filter / (kernel_size**dim)
if kernel_ops == "gaussian":
self.svls_layer = GaussianFilter(spatial_dims=dim, sigma=sigma)

def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor:
"""
Converts the mask to one hot represenation and is smoothened with the selected spatial filter.
Args:
mask: the shape should be BH[WD].
Returns:
torch.Tensor: the shape would be BNH[WD], N being number of classes.
"""
rmask: torch.Tensor

if self.dim == 2:
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float()
rmask = self.svls_layer(oh_labels)

if self.dim == 3:
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float()
rmask = self.svls_layer(oh_labels)

return rmask

def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""
Computes standard cross-entropy loss and constraints it neighbor aware logit penalty.
Args:
inputs: the shape should be BNH[WD], where N is the number of classes.
targets: the shape should be BH[WD].
Returns:
torch.Tensor: value of the loss.
Example:
>>> import torch
>>> from monai.losses import NACLLoss
>>> B, N, H, W = 8, 3, 64, 64
>>> input = torch.rand(B, N, H, W)
>>> target = torch.randint(0, N, (B, H, W))
>>> criterion = NACLLoss(classes = N, dim = 2)
>>> loss = criterion(input, target)
"""

loss_ce = self.cross_entropy(inputs, targets)

utargets = self.get_constr_target(targets)

if self.distance_type == "l1":
loss_conf = utargets.sub(inputs).abs_().mean()
elif self.distance_type == "l2":
loss_conf = utargets.sub(inputs).pow_(2).abs_().mean()

loss: torch.Tensor = loss_ce + self.alpha * loss_conf

return loss
166 changes: 166 additions & 0 deletions tests/test_nacl_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses import NACLLoss

inputs = torch.tensor(
[
[
[
[0.1498, 0.1158, 0.3996, 0.3730],
[0.2155, 0.1585, 0.8541, 0.8579],
[0.6640, 0.2424, 0.0774, 0.0324],
[0.0580, 0.2180, 0.3447, 0.8722],
],
[
[0.3908, 0.9366, 0.1779, 0.1003],
[0.9630, 0.6118, 0.4405, 0.7916],
[0.5782, 0.9515, 0.4088, 0.3946],
[0.7860, 0.3910, 0.0324, 0.9568],
],
[
[0.0759, 0.0238, 0.5570, 0.1691],
[0.2703, 0.7722, 0.1611, 0.6431],
[0.8051, 0.6596, 0.4121, 0.1125],
[0.5283, 0.6746, 0.5528, 0.7913],
],
]
]
)
targets = torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]])

TEST_CASES = [
[{"classes": 3, "dim": 2}, {"inputs": inputs, "targets": targets}, 1.1442],
[{"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, {"inputs": inputs, "targets": targets}, 1.1433],
[{"classes": 3, "dim": 2, "kernel_ops": "gaussian", "sigma": 0.5}, {"inputs": inputs, "targets": targets}, 1.1469],
[{"classes": 3, "dim": 2, "distance_type": "l2"}, {"inputs": inputs, "targets": targets}, 1.1269],
[{"classes": 3, "dim": 2, "alpha": 0.2}, {"inputs": inputs, "targets": targets}, 1.1790],
[
{"classes": 3, "dim": 3, "kernel_ops": "gaussian"},
{
"inputs": torch.tensor(
[
[
[
[
[0.5977, 0.2767, 0.0591, 0.1675],
[0.4835, 0.3778, 0.8406, 0.3065],
[0.6047, 0.2860, 0.9742, 0.2013],
[0.9128, 0.8368, 0.6711, 0.4384],
],
[
[0.9797, 0.1863, 0.5584, 0.6652],
[0.2272, 0.2004, 0.7914, 0.4224],
[0.5097, 0.8818, 0.2581, 0.3495],
[0.1054, 0.5483, 0.3732, 0.3587],
],
[
[0.3060, 0.7066, 0.7922, 0.4689],
[0.1733, 0.8902, 0.6704, 0.2037],
[0.8656, 0.5561, 0.2701, 0.0092],
[0.1866, 0.7714, 0.6424, 0.9791],
],
[
[0.5067, 0.3829, 0.6156, 0.8985],
[0.5192, 0.8347, 0.2098, 0.2260],
[0.8887, 0.3944, 0.6400, 0.5345],
[0.1207, 0.3763, 0.5282, 0.7741],
],
],
[
[
[0.8499, 0.4759, 0.1964, 0.5701],
[0.3190, 0.1238, 0.2368, 0.9517],
[0.0797, 0.6185, 0.0135, 0.8672],
[0.4116, 0.1683, 0.1355, 0.0545],
],
[
[0.7533, 0.2658, 0.5955, 0.4498],
[0.9500, 0.2317, 0.2825, 0.9763],
[0.1493, 0.1558, 0.3743, 0.8723],
[0.1723, 0.7980, 0.8816, 0.0133],
],
[
[0.8426, 0.2666, 0.2077, 0.3161],
[0.1725, 0.8414, 0.1515, 0.2825],
[0.4882, 0.5159, 0.4120, 0.1585],
[0.2551, 0.9073, 0.7691, 0.9898],
],
[
[0.4633, 0.8717, 0.8537, 0.2899],
[0.3693, 0.7953, 0.1183, 0.4596],
[0.0087, 0.7925, 0.0989, 0.8385],
[0.8261, 0.6920, 0.7069, 0.4464],
],
],
[
[
[0.0110, 0.1608, 0.4814, 0.6317],
[0.0194, 0.9669, 0.3259, 0.0028],
[0.5674, 0.8286, 0.0306, 0.5309],
[0.3973, 0.8183, 0.0238, 0.1934],
],
[
[0.8947, 0.6629, 0.9439, 0.8905],
[0.0072, 0.1697, 0.4634, 0.0201],
[0.7184, 0.2424, 0.0820, 0.7504],
[0.3937, 0.1424, 0.4463, 0.5779],
],
[
[0.4123, 0.6227, 0.0523, 0.8826],
[0.0051, 0.0353, 0.3662, 0.7697],
[0.4867, 0.8986, 0.2510, 0.5316],
[0.1856, 0.2634, 0.9140, 0.9725],
],
[
[0.2041, 0.4248, 0.2371, 0.7256],
[0.2168, 0.5380, 0.4538, 0.7007],
[0.9013, 0.2623, 0.0739, 0.2998],
[0.1366, 0.5590, 0.2952, 0.4592],
],
],
]
]
),
"targets": torch.tensor(
[
[
[[0, 1, 0, 1], [1, 2, 1, 0], [2, 1, 1, 1], [1, 1, 0, 1]],
[[2, 1, 0, 2], [1, 2, 0, 2], [1, 0, 1, 1], [1, 1, 0, 0]],
[[1, 0, 2, 1], [0, 2, 2, 1], [1, 0, 1, 1], [0, 0, 2, 1]],
[[2, 1, 1, 0], [1, 0, 0, 2], [1, 0, 2, 1], [2, 1, 0, 1]],
]
]
),
},
1.15035,
],
]


class TestNACLLoss(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_result(self, input_param, input_data, expected_val):
loss = NACLLoss(**input_param)
result = loss(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)


if __name__ == "__main__":
unittest.main()

0 comments on commit 660891f

Please sign in to comment.