evaluate.py
import os
import time
from collections import Counter

from models.crf import CRFModel
from models.bilstm_crf import BILSTM_Model
from utils import save_model, flatten_lists
from evaluating import Metrics


def crf_train_eval(train_data, test_data, output_dir, remove_O=False):
    """Train a CRF model, save it, and report its scores on the test set."""
    train_word_lists, train_tag_lists = train_data
    test_word_lists, test_tag_lists = test_data

    # Train the CRF model and persist it for later reuse
    crf_model = CRFModel()
    crf_model.train(train_word_lists, train_tag_lists)
    save_model(crf_model, os.path.join(output_dir, 'crf.pkl'))

    # Predict tags on the test set and report per-tag scores and the confusion matrix
    pred_tag_lists = crf_model.test(test_word_lists)
    metrics = Metrics(test_tag_lists, pred_tag_lists, remove_O=remove_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()

    return pred_tag_lists


def bilstm_train_and_eval(train_data, dev_data, test_data,
                          word2id, tag2id, output_dir, crf=True, remove_O=False):
    """Train a BiLSTM tagger (with an optional CRF layer), save it, and evaluate it."""
    train_word_lists, train_tag_lists = train_data
    dev_word_lists, dev_tag_lists = dev_data
    test_word_lists, test_tag_lists = test_data

    start = time.time()
    vocab_size = len(word2id)
    out_size = len(tag2id)
    bilstm_model = BILSTM_Model(vocab_size, out_size, crf=crf)
    bilstm_model.train(train_word_lists, train_tag_lists,
                       dev_word_lists, dev_tag_lists, word2id, tag2id)

    model_name = "bilstm_crf" if crf else "bilstm"
    save_model(bilstm_model, os.path.join(output_dir, model_name + ".pkl"))

    print("Training finished, took {} seconds.".format(int(time.time() - start)))
    print("Evaluating the {} model...".format(model_name))
    # test() returns both the predictions and the gold tag lists it scored against
    pred_tag_lists, test_tag_lists = bilstm_model.test(
        test_word_lists, test_tag_lists, word2id, tag2id)

    metrics = Metrics(test_tag_lists, pred_tag_lists, remove_O=remove_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()

    return pred_tag_lists


def ensemble_evaluate(results, targets, remove_O=False):
    """Ensemble the predictions of several models by per-token majority vote."""
    # Flatten each model's sentence-level predictions into a single tag sequence
    for i in range(len(results)):
        results[i] = flatten_lists(results[i])

    # For every token, keep the tag predicted most often across the models
    pred_tags = []
    for result in zip(*results):
        ensemble_tag = Counter(result).most_common(1)[0][0]
        pred_tags.append(ensemble_tag)

    targets = flatten_lists(targets)
    assert len(pred_tags) == len(targets)

    print("The ensemble results of the three models are as follows:")
    metrics = Metrics(targets, pred_tags, remove_O=remove_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()
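

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed by this module): one way these
# helpers could be wired together in a driver script. The corpus loader
# `load_corpus`, the "./ckpts" output directory, and the variable names below
# are hypothetical placeholders -- this file does not define or require them.
# ---------------------------------------------------------------------------
#
#     train_data, dev_data, test_data, word2id, tag2id = load_corpus()
#     output_dir = "./ckpts"
#
#     crf_pred = crf_train_eval(train_data, test_data, output_dir)
#     lstm_pred = bilstm_train_and_eval(train_data, dev_data, test_data,
#                                       word2id, tag2id, output_dir, crf=False)
#     lstm_crf_pred = bilstm_train_and_eval(train_data, dev_data, test_data,
#                                           word2id, tag2id, output_dir, crf=True)
#
#     # Majority-vote ensemble of the three models against the gold test tags
#     ensemble_evaluate([crf_pred, lstm_pred, lstm_crf_pred], test_data[1])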