-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy path BatchAverage.py
52 lines (43 loc) · 1.56 KB
/
BatchAverage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from torch.autograd import Function
from torch import nn
import math
import numpy as np
class BatchCriterion(nn.Module):
    '''Instance-discrimination (NCE-style) loss computed within one batch.

    The forward batch is treated as two halves that are positive pairs of
    each other: row ``i`` is paired with row ``i + batch/2`` (presumably an
    augmented view of the same sample — confirm against the data loader).
    Embeddings are presumably L2-normalized so the inner products are
    cosine similarities — TODO confirm at the call site.

    Args:
        negM: weight applied to the negative term of the loss
            (``negM == 1`` means a plain sum over all pairs).
        T: temperature used to scale the inner products before ``exp``.
        batchSize: half the number of rows the forward pass will receive;
            the self-similarity mask is built for ``2 * batchSize`` rows.
    '''
    def __init__(self, negM, T, batchSize):
        super(BatchCriterion, self).__init__()
        self.negM = negM
        self.T = T
        # Register the (1 - I) self-similarity mask as a buffer instead of
        # hard-coding .cuda(): the mask now follows the module across
        # devices via .cuda()/.cpu()/.to(device) and the criterion also
        # works on CPU-only machines. Callers should move the criterion
        # together with the model (e.g. criterion.cuda()).
        self.register_buffer('diag_mat', 1 - torch.eye(batchSize * 2))

    def forward(self, x, targets):
        '''Compute the scalar loss for a batch of embeddings.

        Args:
            x: tensor of shape ``(2 * batchSize, dim)``; the two halves are
                positive pairs of each other (see class docstring).
            targets: unused; kept so the criterion matches the usual
                ``criterion(output, target)`` calling convention.

        Returns:
            Scalar loss tensor (negative mean log-likelihood).
        '''
        batchSize = x.size(0)
        # Positive inner products: swap the two halves so row i meets its
        # counterpart at i + batchSize//2.
        reordered_x = torch.cat((x.narrow(0, batchSize // 2, batchSize // 2),
                                 x.narrow(0, 0, batchSize // 2)), 0)
        # NOTE(review): .data detaches one side of each product, so
        # gradients flow through only one operand — this matches the
        # original reference implementation and is kept deliberately.
        pos = (x * reordered_x.data).sum(1).div_(self.T).exp_()

        # All pairwise similarities; multiplying by (1 - I) zeroes the
        # self-similarity diagonal.
        all_prob = torch.mm(x, x.t().data).div_(self.T).exp_() * self.diag_mat

        if self.negM == 1:
            all_div = all_prob.sum(1)
        else:
            # Re-weight the negatives by negM while keeping the positive
            # term at weight 1.
            all_div = (all_prob.sum(1) - pos) * self.negM + pos

        # Probability of the positive pair under the per-row normalizer.
        lnPmt = torch.div(pos, all_div)

        # Negative probabilities P(j | i): divide each row i of all_prob
        # by all_div[i] (Pon_div.t() broadcasts all_div down the rows).
        Pon_div = all_div.repeat(batchSize, 1)
        lnPon = torch.div(all_prob, Pon_div.t())
        lnPon = -lnPon.add(-1)  # 1 - P: probability of NOT being a match
        # Equation 7 in ref. A (NCE paper): sum log(1 - P) over negatives.
        lnPon.log_()
        # The positive pair's (1 - P) term is still inside the row sum
        # (only the diagonal was masked); subtract it back out.
        lnPon = lnPon.sum(1) - (-lnPmt.add(-1)).log_()

        lnPmt.log_()
        lnPmtsum = lnPmt.sum(0)
        lnPonsum = lnPon.sum(0)
        # Scale the negative contribution by negM.
        lnPonsum = lnPonsum * self.negM
        loss = -(lnPmtsum + lnPonsum) / batchSize
        return loss