-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathloss_functions.py
109 lines (95 loc) · 4.1 KB
/
loss_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torch.nn.functional as F
class DiceLoss(torch.nn.Module):
    """Soft Dice loss computed over every element of the input at once.

    The overlap and the total mass are summed over the whole tensor (all
    batch items, channels and spatial positions together), so the result
    is a single scalar ``1 - (2 * |P * T| + smooth) / (|P| + |T| + smooth)``.

    Args:
        smooth: additive smoothing term; keeps the ratio defined when both
            ``pred`` and ``target`` are empty.
    """

    def __init__(self, smooth=1.):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return the scalar soft Dice loss between ``pred`` and ``target``."""
        overlap = torch.sum(pred * target)
        total_mass = torch.sum(pred) + torch.sum(target)
        score = (overlap * 2. + self.smooth) / (total_mass + self.smooth)
        return 1. - score
class clDiceLoss(torch.nn.Module):
    """Weighted sum of a soft Dice loss and the soft clDice (centerline Dice) loss.

    ``forward`` returns ``(1 - alpha) * dice_loss + alpha * soft_cldice_loss``.
    The defaults reproduce the original fixed ``0.8 * dice + 0.2 * cldice`` mix.

    Inputs are ``(batch, channel, height, width)`` tensors, or
    ``(batch, channel, depth, height, width)`` for volumes. They presumably
    hold probabilities / binary masks in [0, 1] (TODO: confirm with callers).

    Args:
        smooth: additive smoothing term used by ``dice_loss``.
        alpha: weight of the clDice term; the Dice term gets ``1 - alpha``.
        thresh_width: number of soft-skeletonization iterations used when
            skeletons are computed inside the loss (roughly the maximal
            expected structure width).
    """

    def __init__(self, smooth=1., alpha=0.2, thresh_width=10):
        super(clDiceLoss, self).__init__()
        self.smooth = smooth
        self.alpha = alpha
        self.thresh_width = thresh_width

    def soft_cldice_loss(self, pred, target, target_skeleton=None):
        """Return ``1 - clDice`` between ``pred`` and ``target``.

        Topology precision is the fraction of the predicted skeleton that lies
        inside ``target``. Topology sensitivity is the fraction of the target
        skeleton that lies inside ``pred``. Both are computed per
        (batch, channel) pair and combined as
        ``2 * sum(prec * sens) / sum(prec + sens)``.

        Because ``target`` is fixed, its skeleton can be precomputed while the
        batch is formed (e.g. with OpenCV on numpy arrays) and passed in as
        ``target_skeleton``. When it is ``None``, it is approximated here with
        ``soft_skeletonize``.
        """
        pred_skeleton = self.soft_skeletonize(pred, self.thresh_width)
        if target_skeleton is None:
            target_skeleton = self.soft_skeletonize(target, self.thresh_width)
        topo_precision = self.norm_intersection(pred_skeleton, target)
        topo_sensitivity = self.norm_intersection(target_skeleton, pred)
        intersection = (topo_precision * topo_sensitivity).sum()
        denominator = (topo_precision + topo_sensitivity).sum()
        return 1. - (2. * intersection) / denominator

    def dice_loss(self, pred, target):
        """Return the soft Dice loss of ``pred`` against ``target`` as a scalar.

        Overlap and mass are summed over the whole tensor at once (all batch
        items, channels and spatial positions). No per-sample or per-channel
        value is produced.
        """
        intersection = (pred * target).sum()
        denominator = pred.sum() + target.sum()
        dice_score = (2. * intersection + self.smooth) / (denominator + self.smooth)
        dice_loss = 1. - dice_score
        return dice_loss

    def soft_skeletonize(self, x, thresh_width=10):
        """Differentiable approximation of morphological skeletonization.

        Each iteration removes a soft "contour" from ``x``: the soft opening
        of the eroded map minus the eroded map. Thick regions are peeled away
        while thin structures survive.

        Args:
            x: 4-D ``(N, C, H, W)`` or 5-D ``(N, C, D, H, W)`` tensor.
            thresh_width: number of peeling iterations, i.e. roughly the
                maximal expected width of a vessel.
        """
        # Volumes use 3-D pooling; every other input goes through 2-D pooling.
        max_pool = F.max_pool3d if x.dim() == 5 else F.max_pool2d
        for _ in range(thresh_width):
            # Soft erosion: min-pooling written as a negated max-pool.
            min_pool_x = -max_pool(-x, 3, 1, 1)
            contour = F.relu(max_pool(min_pool_x, 3, 1, 1) - min_pool_x)
            x = F.relu(x - contour)
        return x

    def norm_intersection(self, center_line, vessel):
        """Return the per-(batch, channel) fraction of ``center_line`` covered by ``vessel``.

        ``center_line`` is meant to be a skeleton (of pred or gt) and
        ``vessel`` a mask (gt or pred). The output has shape
        ``center_line.shape[:2]``.
        """
        # Fixed smoothing of 1 (independent of self.smooth) keeps the ratio
        # defined for empty skeletons.
        smooth = 1.
        # reshape (not view) so that non-contiguous inputs are accepted.
        clf = center_line.reshape(*center_line.shape[:2], -1)
        vf = vessel.reshape(*vessel.shape[:2], -1)
        intersection = (clf * vf).sum(-1)
        return (intersection + smooth) / (clf.sum(-1) + smooth)

    def forward(self, pred, target, target_skeleton=None):
        """Return ``(1 - alpha) * dice + alpha * clDice``.

        ``target_skeleton`` is forwarded to ``soft_cldice_loss`` so that a
        precomputed ground-truth skeleton can be used.
        """
        dice = self.dice_loss(pred, target)
        cldice = self.soft_cldice_loss(pred, target, target_skeleton)
        return (1. - self.alpha) * dice + self.alpha * cldice
class FocalLoss(torch.nn.Module):
    """Focal loss built on top of multi-class cross-entropy.

    Each element's cross-entropy ``ce`` is scaled by
    ``alpha * (1 - p_t) ** gamma``, where ``p_t = exp(-ce)`` is the predicted
    probability of the true class. This down-weights easy, well-classified
    examples.

    Args:
        alpha: constant scale applied to every element.
        gamma: focusing exponent; ``gamma=0`` reduces to scaled cross-entropy.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            per-element losses unreduced.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of logits ``inputs`` against class ``targets``."""
        per_item_ce = F.cross_entropy(inputs, targets, reduction='none')
        # Probability assigned to the true class, recovered from the CE value.
        p_true = per_item_ce.neg().exp()
        modulating = (1 - p_true) ** self.gamma
        loss = self.alpha * modulating * per_item_ce
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        return loss
class MCELoss(torch.nn.Module):
    """Misclassification error: share of elements whose argmax class differs from the target.

    This is a metric rather than a trainable loss. ``argmax`` and the boolean
    comparison carry no gradient, so the output cannot be backpropagated.

    ``inputs`` are class scores with the class axis at dim 1 (e.g. ``(N, C)``
    or ``(N, C, H, W)``). ``targets`` presumably hold class indices shaped like
    ``inputs`` without dim 1 (TODO: confirm with callers).

    Args:
        reduction: ``'mean'`` returns the error rate, ``'sum'`` the number of
            misclassified elements, and any other value (e.g. ``'none'``) the
            per-element 0/1 error tensor.
    """

    def __init__(self, reduction='mean'):
        super(MCELoss, self).__init__()
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the (reduced) misclassification error of ``inputs`` vs ``targets``."""
        predicted_classes = inputs.argmax(dim=1)
        errors = (predicted_classes != targets).float()
        if self.reduction == 'mean':
            return errors.mean()
        elif self.reduction == 'sum':
            # Count every misclassified element, including spatial positions.
            # The old mean * batch_size undercounted for (N, C, H, W) inputs.
            return errors.sum()
        else:
            return errors