forked from xingruiyu/coteaching_plus
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
83 lines (64 loc) · 3.23 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from numpy.testing import assert_array_almost_equal
# Loss functions
def loss_coteaching(y_1, y_2, t, forget_rate, ind, noise_or_not):
loss_1 = F.cross_entropy(y_1, t, reduction='none')
ind_1_sorted = np.argsort(loss_1.cpu().data).cuda()
loss_1_sorted = loss_1[ind_1_sorted]
loss_2 = F.cross_entropy(y_2, t, reduction='none')
ind_2_sorted = np.argsort(loss_2.cpu().data).cuda()
loss_2_sorted = loss_2[ind_2_sorted]
remember_rate = 1 - forget_rate
num_remember = int(remember_rate * len(loss_1_sorted))
ind_1_update=ind_1_sorted[:num_remember].cpu()
ind_2_update=ind_2_sorted[:num_remember].cpu()
if len(ind_1_update) == 0:
ind_1_update = ind_1_sorted.cpu().numpy()
ind_2_update = ind_2_sorted.cpu().numpy()
num_remember = ind_1_update.shape[0]
pure_ratio_1 = np.sum(noise_or_not[ind[ind_1_update]])/float(num_remember)
pure_ratio_2 = np.sum(noise_or_not[ind[ind_2_update]])/float(num_remember)
loss_1_update = F.cross_entropy(y_1[ind_2_update], t[ind_2_update])
loss_2_update = F.cross_entropy(y_2[ind_1_update], t[ind_1_update])
return torch.sum(loss_1_update)/num_remember, torch.sum(loss_2_update)/num_remember, pure_ratio_1, pure_ratio_2
def loss_coteaching_plus(logits, logits2, labels, forget_rate, ind, noise_or_not, step):
outputs = F.softmax(logits, dim=1)
outputs2 = F.softmax(logits2, dim=1)
_, pred1 = torch.max(logits.data, 1)
_, pred2 = torch.max(logits2.data, 1)
pred1, pred2 = pred1.cpu().numpy(), pred2.cpu().numpy()
logical_disagree_id=np.zeros(labels.size(), dtype=bool)
disagree_id = []
for idx, p1 in enumerate(pred1):
if p1 != pred2[idx]:
disagree_id.append(idx)
logical_disagree_id[idx] = True
temp_disagree = ind*logical_disagree_id.astype(np.int64)
ind_disagree = np.asarray([i for i in temp_disagree if i != 0]).transpose()
try:
assert ind_disagree.shape[0]==len(disagree_id)
except:
disagree_id = disagree_id[:ind_disagree.shape[0]]
_update_step = np.logical_or(logical_disagree_id, step < 5000).astype(np.float32)
update_step = Variable(torch.from_numpy(_update_step)).cuda()
if len(disagree_id) > 0:
update_labels = labels[disagree_id]
update_outputs = outputs[disagree_id]
update_outputs2 = outputs2[disagree_id]
loss_1, loss_2, pure_ratio_1, pure_ratio_2 = loss_coteaching(update_outputs, update_outputs2, update_labels, forget_rate, ind_disagree, noise_or_not)
else:
update_labels = labels
update_outputs = outputs
update_outputs2 = outputs2
cross_entropy_1 = F.cross_entropy(update_outputs, update_labels)
cross_entropy_2 = F.cross_entropy(update_outputs2, update_labels)
loss_1 = torch.sum(update_step*cross_entropy_1)/labels.size()[0]
loss_2 = torch.sum(update_step*cross_entropy_2)/labels.size()[0]
pure_ratio_1 = np.sum(noise_or_not[ind])/ind.shape[0]
pure_ratio_2 = np.sum(noise_or_not[ind])/ind.shape[0]
return loss_1, loss_2, pure_ratio_1, pure_ratio_2