-
Notifications
You must be signed in to change notification settings - Fork 1
/
losses.py
27 lines (22 loc) · 924 Bytes
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import torch.nn.functional as F
class SCELoss(torch.nn.Module):
    """Symmetric Cross Entropy loss: ``alpha * CE + beta * RCE``.

    CE is the standard cross entropy between the logits and the integer
    labels. RCE ("reverse" cross entropy) swaps the roles of prediction and
    target: it takes the log of the (clamped) one-hot label distribution and
    weights it by the predicted class probabilities.

    Args:
        alpha: Weight of the cross-entropy term.
        beta: Weight of the reverse cross-entropy term.
        num_classes: Number of classes used to one-hot encode ``labels``.
    """

    def __init__(self, alpha, beta, num_classes=10):
        super().__init__()
        # Kept for backward compatibility with code that reads this attribute.
        # forward() does not use it; it follows the device of its inputs instead.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.alpha = alpha
        self.beta = beta
        self.num_classes = num_classes
        self.cross_entropy = torch.nn.CrossEntropyLoss()

    def forward(self, pred, labels):
        """Compute the weighted SCE loss.

        Args:
            pred: Unnormalised logits. Softmax is taken over dim 1, so presumably
                shape (batch, num_classes) -- confirm against callers.
            labels: Integer class indices, as accepted by ``F.one_hot`` and
                ``torch.nn.CrossEntropyLoss``.

        Returns:
            Scalar tensor ``alpha * CE + beta * mean(RCE)``.
        """
        # CCE: CrossEntropyLoss() is built with its default reduction ('mean').
        ce = self.cross_entropy(pred, labels)

        # RCE
        probs = F.softmax(pred, dim=1)
        probs = torch.clamp(probs, min=1e-7, max=1.0)
        # Put the one-hot target on the predictions' device rather than a device
        # fixed at construction time. Otherwise CPU inputs fail whenever CUDA
        # merely happens to be available.
        label_one_hot = F.one_hot(labels, self.num_classes).float().to(probs.device)
        # Clamp the zeros up to 1e-4 so log() stays finite. The true class then
        # contributes log(1) = 0, and every other class contributes log(1e-4).
        label_one_hot = torch.clamp(label_one_hot, min=1e-4, max=1.0)
        rce = -torch.sum(probs * torch.log(label_one_hot), dim=1)

        # Loss
        return self.alpha * ce + self.beta * rce.mean()