-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpriority_train_eval.py
146 lines (135 loc) · 6.29 KB
/
priority_train_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# coding: UTF-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
from utils import get_time_dif
from pytorch_pretrained.optimization import BertAdam
import os
class MultiCEFocalLoss(torch.nn.Module):
    """Multi-class focal loss computed from raw logits.

    For every sample with true class ``c`` and predicted probability ``p_c``
    the loss is ``-alpha[c] * (1 - p_c) ** gamma * log(p_c)``.

    Args:
        class_num: Number of classes; used to build the default ``alpha``.
        gamma: Focusing exponent applied to ``(1 - p_c)``.
        alpha: Optional per-class weights (tensor or sequence). It is
            flattened to 1-D, so both ``(class_num,)`` and ``(class_num, 1)``
            layouts are accepted. Defaults to a weight of 1 for every class.
        reduction: ``'mean'`` or ``'sum'`` reduce over the samples; any other
            value returns the unreduced per-sample losses of shape ``(N, 1)``.
    """

    def __init__(self, class_num, gamma=2, alpha=None, reduction='mean'):
        super(MultiCEFocalLoss, self).__init__()
        if alpha is None:
            # Uniform class weights (the old code used the un-imported,
            # deprecated ``Variable`` wrapper here and raised NameError).
            alpha = torch.ones(class_num)
        else:
            alpha = torch.as_tensor(alpha, dtype=torch.float)
        # Always keep the weights 1-D so indexing by target yields shape (N,).
        self.alpha = alpha.reshape(-1)
        self.gamma = gamma
        self.reduction = reduction
        self.class_num = class_num

    def forward(self, predict, target):
        """Return the focal loss for ``predict`` logits against ``target``.

        Args:
            predict: Logits; softmax is taken over dimension 1.
            target: Integer class indices, one per row of ``predict``.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'`` reduction, otherwise a
            tensor of shape ``(N, 1)`` with one loss per sample.
        """
        ids = target.reshape(-1).long().to(predict.device)
        # log_softmax stays finite where softmax(...).log() would underflow
        # to log(0) = -inf for very confident wrong predictions.
        log_p = F.log_softmax(predict, dim=1).gather(1, ids.view(-1, 1))
        probs = log_p.exp()
        # Per-sample class weight, reshaped to (N, 1) so it lines up with
        # probs/log_p element-wise instead of broadcasting to (N, N).
        alpha = self.alpha.to(device=predict.device, dtype=predict.dtype)
        alpha = alpha[ids].view(-1, 1)
        # Cross-entropy scaled by the dynamic focusing factor (1 - p)^gamma.
        loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss
def train(config, model, train_iter, dev_iter, test_iter):
    """Train ``model`` with focal loss, validating periodically, then test it.

    Every 100 batches the model is evaluated on ``dev_iter``; whenever the
    validation loss improves, the weights are saved to ``config.save_path``.
    Training stops early once more than ``config.require_improvement``
    batches have passed since the last improvement. Afterwards ``test`` is
    called, which reloads the saved checkpoint and writes a report.

    Args:
        config: Object providing ``log_path``, ``learning_rate``,
            ``num_epochs``, ``require_improvement`` and ``save_path``.
        model: Network to optimise; called as ``model(batch)``.
        train_iter: Sized iterable of ``(inputs, labels)`` training batches;
            ``labels[0]`` is used as the target tensor.
        dev_iter: Validation batches, passed to ``evaluate``.
        test_iter: Test batches, passed to ``test``.
    """
    start_time = time.time()
    param_optimizer = list(model.named_parameters())
    # Parameters whose name contains one of these get no weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    log_path = config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime())
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=config.learning_rate,
                         warmup=0.05,
                         t_total=len(train_iter) * config.num_epochs)
    # The loss module and its class weights never change, so build them once
    # instead of once per batch. NOTE(review): class count and weights are
    # hard-coded for a 3-class task.
    criterion = MultiCEFocalLoss(class_num=3, alpha=torch.tensor([0.2, 0.3, 0.5]))
    total_batch = 0  # number of batches processed so far
    dev_best_loss = float('inf')
    last_improve = 0  # batch index at which the validation loss last improved
    flag = False  # set when training should stop early
    model.train()
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        for trains, labels in train_iter:
            outputs = model(trains)
            model.zero_grad()
            # The loss is computed on CPU copies; gradients still flow back
            # through the .cpu() transfer to the model parameters.
            loss = criterion(outputs.cpu(), labels[0].cpu())
            loss.backward()
            optimizer.step()
            if total_batch % 100 == 0:
                # Periodically report accuracy on the current training batch
                # and loss/accuracy on the validation set.
                true1 = labels[0].data.cpu()
                predic1 = torch.max(outputs.data, 1)[1].cpu()
                train_acc1 = metrics.accuracy_score(true1, predic1)
                dev_acc1, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc_pri: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc_pri: {4:>6.2%}, Time: {5} {6}'
                print(msg.format(total_batch, loss.item(), train_acc1, dev_loss, dev_acc1, time_dif, improve))
                # evaluate() switched the model to eval mode; switch it back.
                model.train()
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                # Validation loss has not improved for too many batches.
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
    test(config, model, test_iter, log_path)
def test(config, model, test_iter, log_path):
    """Evaluate the saved checkpoint on the test set and write a report.

    Loads the weights stored at ``config.save_path`` into ``model``, runs
    ``evaluate`` in test mode, prints loss, accuracy, the classification
    report and the confusion matrix, and writes the same information to
    ``<log_path>/report.txt``.

    Args:
        config: Object providing ``save_path``; also forwarded to ``evaluate``.
        model: Network whose ``state_dict`` matches the saved checkpoint.
        test_iter: Test batches, passed to ``evaluate``.
        log_path: Directory for ``report.txt``; created if it is missing.
    """
    model.load_state_dict(torch.load(config.save_path))
    model.eval()
    start_time = time.time()
    test_acc1, test_loss, test_report1, test_confusion1 = evaluate(config, model, test_iter, test=True)
    msg = 'Test Loss: {0:>5.2}, Test Acc_pri: {1:>6.2%} '
    summary = msg.format(test_loss, test_acc1)
    print(summary)
    print("Precision, Recall and F1-Score...")
    print(test_report1)
    print("Confusion Matrix...")
    print(test_confusion1)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(log_path, exist_ok=True)
    # The context manager guarantees the report is flushed and the handle
    # closed even if one of the writes raises.
    with open(log_path + '/report.txt', 'w', encoding='utf-8') as report_file:
        report_file.write(summary + '\n\n')
        report_file.write('Precision, Recall and F1-Score...\n')
        report_file.write(str(test_report1) + '\n')
        report_file.write('Confusion Matrix...\n')
        report_file.write(str(test_confusion1))
def evaluate(config, model, data_iter, test=False):
    """Compute accuracy and average focal loss of ``model`` over ``data_iter``.

    The model is put in eval mode and left there; callers that continue
    training must call ``model.train()`` afterwards.

    Args:
        config: Unused here; kept so all entry points share one signature.
        model: Network to evaluate; called as ``model(batch)``.
        data_iter: Sized iterable of ``(inputs, labels)`` batches;
            ``labels[0]`` is used as the target tensor.
        test: When true, also build a classification report and a
            confusion matrix.

    Returns:
        ``(accuracy, mean_loss)``, or
        ``(accuracy, mean_loss, report, confusion_matrix)`` when ``test`` is
        true. ``mean_loss`` is the summed per-batch loss divided by
        ``len(data_iter)``.
    """
    model.eval()
    loss_total = 0
    # Build the loss module once rather than once per batch.
    # NOTE(review): class count and weights are hard-coded for a 3-class task.
    criterion = MultiCEFocalLoss(class_num=3, alpha=torch.tensor([0.2, 0.3, 0.5]))
    # Collect per-batch arrays and concatenate once at the end; np.append in
    # the loop would re-copy everything accumulated so far on every batch.
    # The leading empty int array keeps the result well-defined (and integer
    # typed) when data_iter yields no batches.
    label_chunks = [np.array([], dtype=int)]
    predict_chunks = [np.array([], dtype=int)]
    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = model(texts)
            loss_total += criterion(outputs.cpu(), labels[0].cpu())
            label_chunks.append(labels[0].data.cpu().numpy().ravel())
            # Index of the largest logit along dim 1 is the predicted class.
            predict_chunks.append(torch.max(outputs.data, 1)[1].cpu().numpy().ravel())
    labels_all1 = np.concatenate(label_chunks)
    predict_all1 = np.concatenate(predict_chunks)
    acc1 = metrics.accuracy_score(labels_all1, predict_all1)
    if test:
        report1 = metrics.classification_report(labels_all1, predict_all1, digits=4)
        confusion1 = metrics.confusion_matrix(labels_all1, predict_all1)
        return acc1, loss_total / len(data_iter), report1, confusion1
    return acc1, loss_total / len(data_iter)