Skip to content

Commit

Permalink
Add mean loss and resolve formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Bala93 committed Aug 3, 2024
1 parent 0067953 commit dccde47
Show file tree
Hide file tree
Showing 2 changed files with 249 additions and 119 deletions.
76 changes: 65 additions & 11 deletions monai/losses/segcalib.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

import math
import warnings

import torch
import torch.nn as nn
Expand All @@ -21,6 +22,16 @@
from monai.utils import pytorch_after


def get_mean_kernel_2d(ksize: int = 3) -> torch.Tensor:
mean_kernel = torch.ones([ksize, ksize]) / (ksize**2)
return mean_kernel


def get_mean_kernel_3d(ksize: int = 3) -> torch.Tensor:
mean_kernel = torch.ones([ksize, ksize, ksize]) / (ksize**3)
return mean_kernel


def get_gaussian_kernel_2d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor:
x_grid = torch.arange(ksize).repeat(ksize).view(ksize, ksize)
y_grid = x_grid.t()
Expand Down Expand Up @@ -101,6 +112,50 @@ def forward(self, x):
return self.svls_layer(x) / self.svls_kernel.sum()


class MeanFilter(torch.nn.Module):
def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> torch.Tensor:
super(MeanFilter, self).__init__()

if dim == 2:
self.svls_kernel = get_mean_kernel_2d(ksize=ksize)
svls_kernel_2d = self.svls_kernel.view(1, 1, ksize, ksize)
svls_kernel_2d = svls_kernel_2d.repeat(channels, 1, 1, 1)
padding = int(ksize / 2)

self.svls_layer = torch.nn.Conv2d(
in_channels=channels,
out_channels=channels,
kernel_size=ksize,
groups=channels,
bias=False,
padding=padding,
padding_mode="replicate",
)
self.svls_layer.weight.data = svls_kernel_2d
self.svls_layer.weight.requires_grad = False

if dim == 3:
self.svls_kernel = get_mean_kernel_3d(ksize=ksize)
svls_kernel_3d = self.svls_kernel.view(1, 1, ksize, ksize)
svls_kernel_3d = svls_kernel_3d.repeat(channels, 1, 1, 1)
padding = int(ksize / 2)

self.svls_layer = torch.nn.Conv3d(
in_channels=channels,
out_channels=channels,
kernel_size=ksize,
groups=channels,
bias=False,
padding=padding,
padding_mode="replicate",
)
self.svls_layer.weight.data = svls_kernel_3d
self.svls_layer.weight.requires_grad = False

def forward(self, x):
return self.svls_layer(x) / self.svls_kernel.sum()


class NACLLoss(_Loss):
"""
Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation.
Expand All @@ -118,6 +173,7 @@ def __init__(
classes: int,
dim: int,
kernel_size: int = 3,
kernel_ops: str = "mean",
distance_type: str = "l1",
alpha: float = 0.1,
sigma: float = 1.0,
Expand All @@ -133,6 +189,9 @@ def __init__(

super().__init__()

if kernel_ops not in ["mean", "gaussian"]:
raise ValueError("Kernel ops must be either mean or gaussian")

if dim not in [2, 3]:
raise ValueError("Supoorts 2d and 3d")

Expand All @@ -146,7 +205,10 @@ def __init__(
self.alpha = alpha
self.ks = kernel_size

self.svls_layer = GaussianFilter(dim=dim, ksize=kernel_size, sigma=sigma, channels=classes)
if kernel_ops == "mean":
self.svls_layer = MeanFilter(dim=dim, ksize=kernel_size, channels=classes)
if kernel_ops == "gaussian":
self.svls_layer = GaussianFilter(dim=dim, ksize=kernel_size, sigma=sigma, channels=classes)

self.old_pt_ver = not pytorch_after(1, 10)

Expand All @@ -173,24 +235,16 @@ def __init__(
# return self.cross_entropy(input, target) # type: ignore[no-any-return]

def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor:

if self.dim == 2:

oh_labels = (
F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float()
)
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float()
rmask = self.svls_layer(oh_labels)

if self.dim == 3:

oh_labels = (
F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float()
)
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float()
rmask = self.svls_layer(oh_labels)

return rmask


def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
loss_ce = self.cross_entropy(inputs, targets)

Expand Down
Loading

0 comments on commit dccde47

Please sign in to comment.