diff --git a/monai/losses/segcalib.py b/monai/losses/segcalib.py
index 5b4de7916f..38dc97f7d6 100644
--- a/monai/losses/segcalib.py
+++ b/monai/losses/segcalib.py
@@ -12,6 +12,7 @@
 from __future__ import annotations
 
 import math
+import warnings
 
 import torch
 import torch.nn as nn
@@ -21,6 +22,16 @@
 from monai.utils import pytorch_after
 
 
+def get_mean_kernel_2d(ksize: int = 3) -> torch.Tensor:
+    mean_kernel = torch.ones([ksize, ksize]) / (ksize**2)
+    return mean_kernel
+
+
+def get_mean_kernel_3d(ksize: int = 3) -> torch.Tensor:
+    mean_kernel = torch.ones([ksize, ksize, ksize]) / (ksize**3)
+    return mean_kernel
+
+
 def get_gaussian_kernel_2d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor:
     x_grid = torch.arange(ksize).repeat(ksize).view(ksize, ksize)
     y_grid = x_grid.t()
@@ -101,6 +112,50 @@ def forward(self, x):
         return self.svls_layer(x) / self.svls_kernel.sum()
 
 
+class MeanFilter(torch.nn.Module):
+    # Depthwise box-filter (uniform smoothing) counterpart of GaussianFilter:
+    # one replicate-padded grouped conv per channel with a frozen mean kernel.
+    def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> None:
+        super(MeanFilter, self).__init__()
+
+        if dim == 2:
+            self.svls_kernel = get_mean_kernel_2d(ksize=ksize)
+            svls_kernel_2d = self.svls_kernel.view(1, 1, ksize, ksize)
+            svls_kernel_2d = svls_kernel_2d.repeat(channels, 1, 1, 1)
+            padding = int(ksize / 2)
+
+            self.svls_layer = torch.nn.Conv2d(
+                in_channels=channels,
+                out_channels=channels,
+                kernel_size=ksize,
+                groups=channels,
+                bias=False,
+                padding=padding,
+                padding_mode="replicate",
+            )
+            self.svls_layer.weight.data = svls_kernel_2d
+            self.svls_layer.weight.requires_grad = False
+
+        if dim == 3:
+            self.svls_kernel = get_mean_kernel_3d(ksize=ksize)
+            # The 3-D kernel has ksize**3 elements: reshape/repeat must be 5-D
+            # (out_ch, in_ch/groups, kD, kH, kW) to match Conv3d's weight.
+            svls_kernel_3d = self.svls_kernel.view(1, 1, ksize, ksize, ksize)
+            svls_kernel_3d = svls_kernel_3d.repeat(channels, 1, 1, 1, 1)
+            padding = int(ksize / 2)
+
+            self.svls_layer = torch.nn.Conv3d(
+                in_channels=channels,
+                out_channels=channels,
+                kernel_size=ksize,
+                groups=channels,
+                bias=False,
+                padding=padding,
+                padding_mode="replicate",
+            )
+            self.svls_layer.weight.data = svls_kernel_3d
+            self.svls_layer.weight.requires_grad = False
+
+    def forward(self, x):
+        return self.svls_layer(x) / self.svls_kernel.sum()
+
+
 class NACLLoss(_Loss):
     """
     Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation.
@@ -118,6 +173,7 @@ def __init__(
         classes: int,
         dim: int,
         kernel_size: int = 3,
+        kernel_ops: str = "mean",
         distance_type: str = "l1",
         alpha: float = 0.1,
         sigma: float = 1.0,
@@ -133,6 +189,9 @@ def __init__(
 
         super().__init__()
 
+        if kernel_ops not in ["mean", "gaussian"]:
+            raise ValueError("Kernel ops must be either mean or gaussian")
+
         if dim not in [2, 3]:
             raise ValueError("Supoorts 2d and 3d")
 
@@ -146,7 +205,10 @@ def __init__(
         self.alpha = alpha
         self.ks = kernel_size
 
-        self.svls_layer = GaussianFilter(dim=dim, ksize=kernel_size, sigma=sigma, channels=classes)
+        if kernel_ops == "mean":
+            self.svls_layer = MeanFilter(dim=dim, ksize=kernel_size, channels=classes)
+        if kernel_ops == "gaussian":
+            self.svls_layer = GaussianFilter(dim=dim, ksize=kernel_size, sigma=sigma, channels=classes)
 
         self.old_pt_ver = not pytorch_after(1, 10)
 
@@ -173,24 +235,17 @@ def __init__(
 
     #     return self.cross_entropy(input, target)  # type: ignore[no-any-return]
 
     def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor:
         if self.dim == 2:
-
-            oh_labels = (
-                F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float()
-            )
+            oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float()
             rmask = self.svls_layer(oh_labels)
 
         if self.dim == 3:
-
-            oh_labels = (
-                F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float()
-            )
+            oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float()
             rmask = self.svls_layer(oh_labels)
 
         return rmask
 
-
     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
         loss_ce = self.cross_entropy(inputs, targets)
diff --git a/tests/test_nacl_loss.py b/tests/test_nacl_loss.py index 19a2ef6336..1a4772dcb8 100644 --- a/tests/test_nacl_loss.py +++ b/tests/test_nacl_loss.py @@ -20,127 +20,203 @@ from monai.losses import NACLLoss TEST_CASES = [ - [ # shape: (2, 2, 3), (2, 2, 3) + [ {"classes": 3, "dim": 2}, { - "inputs": torch.tensor([[[[0.1498, 0.1158, 0.3996, 0.3730], - [0.2155, 0.1585, 0.8541, 0.8579], - [0.6640, 0.2424, 0.0774, 0.0324], - [0.0580, 0.2180, 0.3447, 0.8722]], - [[0.3908, 0.9366, 0.1779, 0.1003], - [0.9630, 0.6118, 0.4405, 0.7916], - [0.5782, 0.9515, 0.4088, 0.3946], - [0.7860, 0.3910, 0.0324, 0.9568]], - [[0.0759, 0.0238, 0.5570, 0.1691], - [0.2703, 0.7722, 0.1611, 0.6431], - [0.8051, 0.6596, 0.4121, 0.1125], - [0.5283, 0.6746, 0.5528, 0.7913]]]]), - "targets": torch.tensor([[[1, 1, 1, 1], - [1, 1, 1, 0], - [0, 0, 0, 0], - [0, 0, 0, 0]]]), + "inputs": torch.tensor( + [ + [ + [ + [0.1498, 0.1158, 0.3996, 0.3730], + [0.2155, 0.1585, 0.8541, 0.8579], + [0.6640, 0.2424, 0.0774, 0.0324], + [0.0580, 0.2180, 0.3447, 0.8722], + ], + [ + [0.3908, 0.9366, 0.1779, 0.1003], + [0.9630, 0.6118, 0.4405, 0.7916], + [0.5782, 0.9515, 0.4088, 0.3946], + [0.7860, 0.3910, 0.0324, 0.9568], + ], + [ + [0.0759, 0.0238, 0.5570, 0.1691], + [0.2703, 0.7722, 0.1611, 0.6431], + [0.8051, 0.6596, 0.4121, 0.1125], + [0.5283, 0.6746, 0.5528, 0.7913], + ], + ] + ] + ), + "targets": torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]), }, - 1.1850, # the result equals to -1 + np.log(1 + np.exp(1)) + 1.1820, ], - [ # shape: (2, 2, 3), (2, 2, 3) - {"classes": 3, "dim": 2}, + [ + {"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, { - "inputs": torch.tensor([[[[0.0411, 0.3386, 0.8352, 0.2741], - [0.4821, 0.0519, 0.2561, 0.9391], - [0.5954, 0.4184, 0.9160, 0.7977], - [0.0588, 0.9156, 0.1307, 0.9914]], - [[0.8481, 0.0892, 0.2044, 0.8470], - [0.3558, 0.3569, 0.0979, 0.4491], - [0.0876, 0.0929, 0.4040, 0.8384], - [0.5313, 0.3927, 0.4165, 0.1107]], - [[0.7993, 0.6938, 0.3151, 0.8728], 
- [0.7332, 0.4111, 0.3862, 0.9988], - [0.2622, 0.5002, 0.1905, 0.1644], - [0.6354, 0.0047, 0.1649, 0.7112]]]]), - - "targets": torch.tensor([[[1, 2, 0, 1], - [0, 2, 1, 2], - [0, 0, 2, 1], - [1, 1, 1, 2]]]), + "inputs": torch.tensor( + [ + [ + [ + [0.1498, 0.1158, 0.3996, 0.3730], + [0.2155, 0.1585, 0.8541, 0.8579], + [0.6640, 0.2424, 0.0774, 0.0324], + [0.0580, 0.2180, 0.3447, 0.8722], + ], + [ + [0.3908, 0.9366, 0.1779, 0.1003], + [0.9630, 0.6118, 0.4405, 0.7916], + [0.5782, 0.9515, 0.4088, 0.3946], + [0.7860, 0.3910, 0.0324, 0.9568], + ], + [ + [0.0759, 0.0238, 0.5570, 0.1691], + [0.2703, 0.7722, 0.1611, 0.6431], + [0.8051, 0.6596, 0.4121, 0.1125], + [0.5283, 0.6746, 0.5528, 0.7913], + ], + ] + ] + ), + "targets": torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]), }, - 1.0375, # the result equals to -1 + np.log(1 + np.exp(1)) + 1.1850, ], - [ # shape: (2, 2, 3), (2, 2, 3) - {"classes": 3, "dim": 3}, + [ + {"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, { - "inputs": torch.tensor([[[[[0.5977, 0.2767, 0.0591, 0.1675], - [0.4835, 0.3778, 0.8406, 0.3065], - [0.6047, 0.2860, 0.9742, 0.2013], - [0.9128, 0.8368, 0.6711, 0.4384]], - [[0.9797, 0.1863, 0.5584, 0.6652], - [0.2272, 0.2004, 0.7914, 0.4224], - [0.5097, 0.8818, 0.2581, 0.3495], - [0.1054, 0.5483, 0.3732, 0.3587]], - [[0.3060, 0.7066, 0.7922, 0.4689], - [0.1733, 0.8902, 0.6704, 0.2037], - [0.8656, 0.5561, 0.2701, 0.0092], - [0.1866, 0.7714, 0.6424, 0.9791]], - [[0.5067, 0.3829, 0.6156, 0.8985], - [0.5192, 0.8347, 0.2098, 0.2260], - [0.8887, 0.3944, 0.6400, 0.5345], - [0.1207, 0.3763, 0.5282, 0.7741]]], - [[[0.8499, 0.4759, 0.1964, 0.5701], - [0.3190, 0.1238, 0.2368, 0.9517], - [0.0797, 0.6185, 0.0135, 0.8672], - [0.4116, 0.1683, 0.1355, 0.0545]], - [[0.7533, 0.2658, 0.5955, 0.4498], - [0.9500, 0.2317, 0.2825, 0.9763], - [0.1493, 0.1558, 0.3743, 0.8723], - [0.1723, 0.7980, 0.8816, 0.0133]], - [[0.8426, 0.2666, 0.2077, 0.3161], - [0.1725, 0.8414, 0.1515, 0.2825], - [0.4882, 0.5159, 
0.4120, 0.1585], - [0.2551, 0.9073, 0.7691, 0.9898]], - [[0.4633, 0.8717, 0.8537, 0.2899], - [0.3693, 0.7953, 0.1183, 0.4596], - [0.0087, 0.7925, 0.0989, 0.8385], - [0.8261, 0.6920, 0.7069, 0.4464]]], - [[[0.0110, 0.1608, 0.4814, 0.6317], - [0.0194, 0.9669, 0.3259, 0.0028], - [0.5674, 0.8286, 0.0306, 0.5309], - [0.3973, 0.8183, 0.0238, 0.1934]], - [[0.8947, 0.6629, 0.9439, 0.8905], - [0.0072, 0.1697, 0.4634, 0.0201], - [0.7184, 0.2424, 0.0820, 0.7504], - [0.3937, 0.1424, 0.4463, 0.5779]], - [[0.4123, 0.6227, 0.0523, 0.8826], - [0.0051, 0.0353, 0.3662, 0.7697], - [0.4867, 0.8986, 0.2510, 0.5316], - [0.1856, 0.2634, 0.9140, 0.9725]], - [[0.2041, 0.4248, 0.2371, 0.7256], - [0.2168, 0.5380, 0.4538, 0.7007], - [0.9013, 0.2623, 0.0739, 0.2998], - [0.1366, 0.5590, 0.2952, 0.4592]]]]]), - - "targets": torch.tensor([[[[0, 1, 0, 1], - [1, 2, 1, 0], - [2, 1, 1, 1], - [1, 1, 0, 1]], - [[2, 1, 0, 2], - [1, 2, 0, 2], - [1, 0, 1, 1], - [1, 1, 0, 0]], - [[1, 0, 2, 1], - [0, 2, 2, 1], - [1, 0, 1, 1], - [0, 0, 2, 1]], - [[2, 1, 1, 0], - [1, 0, 0, 2], - [1, 0, 2, 1], - [2, 1, 0, 1]]]]), - }, - 1.1504, # the result equals to -1 + np.log(1 + np.exp(1)) + "inputs": torch.tensor( + [ + [ + [ + [0.0411, 0.3386, 0.8352, 0.2741], + [0.4821, 0.0519, 0.2561, 0.9391], + [0.5954, 0.4184, 0.9160, 0.7977], + [0.0588, 0.9156, 0.1307, 0.9914], + ], + [ + [0.8481, 0.0892, 0.2044, 0.8470], + [0.3558, 0.3569, 0.0979, 0.4491], + [0.0876, 0.0929, 0.4040, 0.8384], + [0.5313, 0.3927, 0.4165, 0.1107], + ], + [ + [0.7993, 0.6938, 0.3151, 0.8728], + [0.7332, 0.4111, 0.3862, 0.9988], + [0.2622, 0.5002, 0.1905, 0.1644], + [0.6354, 0.0047, 0.1649, 0.7112], + ], + ] + ] + ), + "targets": torch.tensor([[[1, 2, 0, 1], [0, 2, 1, 2], [0, 0, 2, 1], [1, 1, 1, 2]]]), + }, + 1.0375, + ], + [ + {"classes": 3, "dim": 3, "kernel_ops": "gaussian"}, + { + "inputs": torch.tensor( + [ + [ + [ + [ + [0.5977, 0.2767, 0.0591, 0.1675], + [0.4835, 0.3778, 0.8406, 0.3065], + [0.6047, 0.2860, 0.9742, 0.2013], + [0.9128, 0.8368, 
0.6711, 0.4384], + ], + [ + [0.9797, 0.1863, 0.5584, 0.6652], + [0.2272, 0.2004, 0.7914, 0.4224], + [0.5097, 0.8818, 0.2581, 0.3495], + [0.1054, 0.5483, 0.3732, 0.3587], + ], + [ + [0.3060, 0.7066, 0.7922, 0.4689], + [0.1733, 0.8902, 0.6704, 0.2037], + [0.8656, 0.5561, 0.2701, 0.0092], + [0.1866, 0.7714, 0.6424, 0.9791], + ], + [ + [0.5067, 0.3829, 0.6156, 0.8985], + [0.5192, 0.8347, 0.2098, 0.2260], + [0.8887, 0.3944, 0.6400, 0.5345], + [0.1207, 0.3763, 0.5282, 0.7741], + ], + ], + [ + [ + [0.8499, 0.4759, 0.1964, 0.5701], + [0.3190, 0.1238, 0.2368, 0.9517], + [0.0797, 0.6185, 0.0135, 0.8672], + [0.4116, 0.1683, 0.1355, 0.0545], + ], + [ + [0.7533, 0.2658, 0.5955, 0.4498], + [0.9500, 0.2317, 0.2825, 0.9763], + [0.1493, 0.1558, 0.3743, 0.8723], + [0.1723, 0.7980, 0.8816, 0.0133], + ], + [ + [0.8426, 0.2666, 0.2077, 0.3161], + [0.1725, 0.8414, 0.1515, 0.2825], + [0.4882, 0.5159, 0.4120, 0.1585], + [0.2551, 0.9073, 0.7691, 0.9898], + ], + [ + [0.4633, 0.8717, 0.8537, 0.2899], + [0.3693, 0.7953, 0.1183, 0.4596], + [0.0087, 0.7925, 0.0989, 0.8385], + [0.8261, 0.6920, 0.7069, 0.4464], + ], + ], + [ + [ + [0.0110, 0.1608, 0.4814, 0.6317], + [0.0194, 0.9669, 0.3259, 0.0028], + [0.5674, 0.8286, 0.0306, 0.5309], + [0.3973, 0.8183, 0.0238, 0.1934], + ], + [ + [0.8947, 0.6629, 0.9439, 0.8905], + [0.0072, 0.1697, 0.4634, 0.0201], + [0.7184, 0.2424, 0.0820, 0.7504], + [0.3937, 0.1424, 0.4463, 0.5779], + ], + [ + [0.4123, 0.6227, 0.0523, 0.8826], + [0.0051, 0.0353, 0.3662, 0.7697], + [0.4867, 0.8986, 0.2510, 0.5316], + [0.1856, 0.2634, 0.9140, 0.9725], + ], + [ + [0.2041, 0.4248, 0.2371, 0.7256], + [0.2168, 0.5380, 0.4538, 0.7007], + [0.9013, 0.2623, 0.0739, 0.2998], + [0.1366, 0.5590, 0.2952, 0.4592], + ], + ], + ] + ] + ), + "targets": torch.tensor( + [ + [ + [[0, 1, 0, 1], [1, 2, 1, 0], [2, 1, 1, 1], [1, 1, 0, 1]], + [[2, 1, 0, 2], [1, 2, 0, 2], [1, 0, 1, 1], [1, 1, 0, 0]], + [[1, 0, 2, 1], [0, 2, 2, 1], [1, 0, 1, 1], [0, 0, 2, 1]], + [[2, 1, 1, 0], [1, 0, 0, 2], [1, 0, 2, 
1], [2, 1, 0, 1]], + ] + ] + ), + }, + 1.1504, ], ] class TestNACLLoss(unittest.TestCase): - @parameterized.expand(TEST_CASES) def test_result(self, input_param, input_data, expected_val): loss = NACLLoss(**input_param)