-
Notifications
You must be signed in to change notification settings - Fork 4
/
test.py
62 lines (51 loc) · 2.43 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from utils import load_model, extend_maps, prepocess_data_for_lstmcrf
from data_process import build_corpus
from evaluating import Metrics
from evaluate import ensemble_evaluate
import os
checkpoint_path = "ckpts/param_128_128_1_pku"
data_tag = "pku"
CRF_MODEL_PATH = os.path.join(checkpoint_path, 'crf.pkl')
BiLSTM_MODEL_PATH = os.path.join(checkpoint_path, 'bilstm.pkl')
BiLSTMCRF_MODEL_PATH = os.path.join(checkpoint_path, 'bilstm_crf.pkl')
REMOVE_O = False # 在评估的时候是否去除O标记
def main():
print("读取数据...")
train_word_lists, train_tag_lists, word2id, tag2id = build_corpus(data_tag, "train")
# dev_word_lists, dev_tag_lists = build_corpus(data_tag, "dev", make_vocab=False)
test_word_lists, test_tag_lists = build_corpus(data_tag, "test", make_vocab=False)
# 加载并评估CRF模型
print("加载并评估crf模型...")
crf_model = load_model(CRF_MODEL_PATH)
crf_pred = crf_model.test(test_word_lists)
metrics = Metrics(test_tag_lists, crf_pred, remove_O=REMOVE_O)
metrics.report_scores()
metrics.report_confusion_matrix()
# bilstm模型
print("加载并评估bilstm模型...")
bilstm_word2id, bilstm_tag2id = extend_maps(word2id, tag2id, for_crf=False)
bilstm_model = load_model(BiLSTM_MODEL_PATH)
bilstm_model.model.bilstm.flatten_parameters() # remove warning
lstm_pred, target_tag_list = bilstm_model.test(test_word_lists, test_tag_lists,
bilstm_word2id, bilstm_tag2id)
metrics = Metrics(target_tag_list, lstm_pred, remove_O=REMOVE_O)
metrics.report_scores()
metrics.report_confusion_matrix()
print("加载并评估bilstm+crf模型...")
crf_word2id, crf_tag2id = extend_maps(word2id, tag2id, for_crf=True)
bilstm_model = load_model(BiLSTMCRF_MODEL_PATH)
bilstm_model.model.bilstm.bilstm.flatten_parameters() # remove warning
test_word_lists, test_tag_lists = prepocess_data_for_lstmcrf(
test_word_lists, test_tag_lists, test=True
)
lstmcrf_pred, target_tag_list = bilstm_model.test(test_word_lists, test_tag_lists,
crf_word2id, crf_tag2id)
metrics = Metrics(target_tag_list, lstmcrf_pred, remove_O=REMOVE_O)
metrics.report_scores()
metrics.report_confusion_matrix()
ensemble_evaluate(
[crf_pred, lstm_pred, lstmcrf_pred],
test_tag_lists
)
if __name__ == "__main__":
main()