#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from operator import itemgetter

import numpy as np
import torch
import torch.nn.utils
from sklearn.metrics import roc_auc_score
from surel_gacc import sjoin
from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm

from utils import *
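
# Training and evaluation routines for walk-based (SUREL-style) link
# prediction: train() runs one epoch over walk batches; eval_model()
# scores pairwise link queries; eval_model_horder() scores higher-order
# (u, v, w) queries.
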
def train(model, opti, data, dT):
    """Run one training epoch; return (mean BCE loss, ROC-AUC)."""
    model.train()
    total_loss = 0
    labels, preds = [], []
    for wl, wr, label, x in data:
        labels.append(label)
        # Pair the walk features of the left and right endpoints.
        Tf = torch.stack([dT[wl], dT[wr]])
        opti.zero_grad()
        pred = model(Tf, [wl, wr])
        preds.append(pred.detach().sigmoid())
        target = label.to(pred.device)
        loss = BCEWithLogitsLoss()(pred, target)
        # Backpropagate before clipping: gradients must exist to be clipped.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        opti.step()
        total_loss += loss.item() * len(label)
    predictions = torch.cat(preds).cpu()
    labels = torch.cat(labels)
    return total_loss / len(labels), roc_auc_score(labels, predictions)
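
# Minimal usage sketch (hypothetical names: Net, loader): `data` is any
# iterable yielding (wl, wr, label, x) tuples, and `dT` holds the
# normalized walk features indexed by walk id.
#
#   model = Net(...).to(device)
#   opti = torch.optim.Adam(model.parameters(), lr=1e-3)
#   loss, auc = train(model, opti, loader, dT)
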
def eval_model(model, x_dict, x_set, args, evaluator, device, mode='test', return_predictions=False):
    """Score link queries in x_set[mode]; return a metric dict or raw predictions."""
    model.eval()
    preds = []
    with torch.no_grad():
        x_embed, target = x_set['X'], x_set[mode]['E']
        with tqdm(total=len(target)) as pbar:
            for batch in gen_batch(target, args.batch_num, keep=True):
                Bs = torch.unique(batch).numpy()
                S, K, F = zip(*itemgetter(*Bs)(x_dict))
                S = torch.from_numpy(np.asarray(S)).long()
                # Prepend a zero row so index 0 acts as padding.
                F = np.concatenate(F)
                F = np.concatenate([[[0] * F.shape[-1]], F])
                mF = torch.from_numpy(F).to(device)
                # Join the sampled walk sets of both endpoints in the batch.
                uvw, uvx = sjoin(S, K, batch, return_idx=True)
                uvw = uvw.reshape(2, -1, 2)
                x = torch.from_numpy(uvw)
                gT = normalization(mF, args)
                gT = torch.stack([gT[uvw[0]], gT[uvw[1]]])
                pred = model(gT, x)
                preds.append(pred.sigmoid())
                pbar.update(len(pred))
    predictions = torch.cat(preds, dim=0)
    if return_predictions:
        return predictions
    labels = torch.zeros(len(predictions))
    result_dict = {'metric': args.metric, 'mode': mode}
    if args.metric == 'mrr':
        # Positive edges come first in the MRR split.
        num_pos = x_set[mode]['num_pos']
        labels[:num_pos] = 1
        pred_pos, pred_neg = predictions[:num_pos], predictions[num_pos:]
        result_dict['mrr_list'] = evaluator.eval(
            {"y_pred_pos": pred_pos.view(-1), "y_pred_neg": pred_neg.view(num_pos, -1)})['mrr_list']
    elif 'Hits' in args.metric:
        # Negative edges come first in the Hits split.
        num_neg = x_set[mode]['num_neg']
        labels[num_neg:] = 1
        pred_neg, pred_pos = predictions[:num_neg], predictions[num_neg:]
        result_dict['hits'] = evaluate_hits(pred_pos.view(-1), pred_neg.view(-1), evaluator)
        result_dict['num_pos'] = len(pred_pos)
    else:
        raise NotImplementedError
    result_dict['auc'] = roc_auc_score(labels, predictions.cpu())
    return result_dict
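
# Sketch of an evaluation call (assumes an OGB-style Evaluator and the
# x_dict/x_set structures built elsewhere in this repo):
#
#   res = eval_model(model, x_dict, x_set, args, evaluator, device, mode='valid')
#   print(res['metric'], res['auc'])
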
def eval_model_horder(model, x_dict, x_set, args, evaluator, device, mode='test', return_predictions=False):
    """Variant of eval_model for higher-order (u, v, w) queries: joins walks
    for the (u, w) and (v, w) pairs of each triple before scoring."""
    model.eval()
    preds = []
    with torch.no_grad():
        x_embed, target = x_set['X'], x_set[mode]['E']
        with tqdm(total=len(target)) as pbar:
            for batch in gen_batch(target, args.batch_num, keep=True):
                Bs = torch.unique(batch).numpy()
                S, K, F = zip(*itemgetter(*Bs)(x_dict))
                S = torch.from_numpy(np.asarray(S)).long()
                # Prepend a zero row so index 0 acts as padding.
                F = np.concatenate(F)
                F = np.concatenate([[[0] * F.shape[-1]], F])
                mF = torch.from_numpy(F).to(device)
                # Join walks for the two component pairs of each triple.
                uw = sjoin(S, K, batch[:, [0, 2]], return_idx=False)
                vw = sjoin(S, K, batch[:, [1, 2]], return_idx=False)
                uvw = np.concatenate([uw, vw], axis=1).reshape(2, -1, 2)
                x = torch.from_numpy(uvw)
                gT = normalization(mF, args)
                gT = torch.stack([gT[uvw[0]], gT[uvw[1]]])
                pred = model(gT, x)
                preds.append(pred.sigmoid())
                pbar.update(len(pred))
    predictions = torch.cat(preds, dim=0)
    if return_predictions:
        return predictions
    labels = torch.zeros(len(predictions))
    result_dict = {'metric': args.metric, 'mode': mode}
    if args.metric == 'mrr':
        # Positive triples come first in the MRR split.
        num_pos = x_set[mode]['num_pos']
        labels[:num_pos] = 1
        pred_pos, pred_neg = predictions[:num_pos], predictions[num_pos:]
        result_dict['mrr_list'] = evaluator.eval(
            {"y_pred_pos": pred_pos.view(-1), "y_pred_neg": pred_neg.view(num_pos, -1)})['mrr_list']
    elif 'Hits' in args.metric:
        # Negative triples come first in the Hits split.
        num_neg = x_set[mode]['num_neg']
        labels[num_neg:] = 1
        pred_neg, pred_pos = predictions[:num_neg], predictions[num_neg:]
        result_dict['hits'] = evaluate_hits(pred_pos.view(-1), pred_neg.view(-1), evaluator)
        result_dict['num_pos'] = len(pred_pos)
    else:
        raise NotImplementedError
    result_dict['auc'] = roc_auc_score(labels, predictions.cpu())
    return result_dict
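
# Usage mirrors eval_model, except x_set[mode]['E'] holds (u, v, w)
# triples rather than node pairs (hypothetical call):
#
#   res = eval_model_horder(model, x_dict, x_set, args, evaluator, device)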