forked from rogerrojur/tianchi-multi-task-nlp
-
Notifications
You must be signed in to change notification settings - Fork 1
/
inference.py
executable file
·121 lines (105 loc) · 5.61 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 5 17:47:24 2020
@author: luokai
"""
from net import Net
import json
import torch
import numpy as np
from transformers import BertModel, BertTokenizer
from utils import get_task_chinese
def test_csv_to_json():
    """Convert each task's tab-separated ``test.csv`` into a one-line ``test.json``.

    For every dataset directory in ``TNEWS``, ``OCNLI`` and ``OCEMOTION`` this
    reads ``./tianchi_datasets/<NAME>/test.csv`` and writes
    ``./tianchi_datasets/<NAME>/test.json``. The output is a single JSON
    object on one line, keyed by the first column of each row::

        {"<col 0>": {"s1": "<col 1>"}}            # TNEWS / OCEMOTION
        {"<col 0>": {"s1": "<col 1>", "s2": "<col 2>"}}   # OCNLI

    Raises:
        FileNotFoundError: if a ``test.csv`` (or its directory) is missing.
        IndexError: if a non-blank row has fewer columns than required.
    """
    for e in ['TNEWS', 'OCNLI', 'OCEMOTION']:
        dataset_dir = './tianchi_datasets/' + e
        json_dict = dict()
        # The datasets contain Chinese text, so decode explicitly as UTF-8
        # instead of relying on the platform's locale encoding.
        with open(dataset_dir + '/test.csv', encoding='utf-8') as fr:
            for line in fr:
                stripped = line.strip()
                if not stripped:
                    # Blank (e.g. trailing) lines carry no record.
                    continue
                tmp_list = stripped.split('\t')
                record = {'s1': tmp_list[1]}
                if e == 'OCNLI':
                    record['s2'] = tmp_list[2]
                json_dict[tmp_list[0]] = record
        # Write only after the whole CSV parsed successfully, so a malformed
        # input does not leave behind a truncated/empty test.json.
        # json.dumps keeps its default ensure_ascii=True, so the file content
        # is pure ASCII and stays on a single line.
        with open(dataset_dir + '/test.json', 'w', encoding='utf-8') as fw:
            fw.write(json.dumps(json_dict))
def inference_warpper(tokenizer_model):
    """Load the test sets, labels, model and tokenizer, then predict all three tasks.

    Reads the first line of each ``test.json`` under ``./tianchi_datasets`` and
    of ``./tianchi_datasets/label.json``, loads ``./saved_best.pt`` with
    ``torch.load``, builds a ``BertTokenizer`` from ``tokenizer_model`` and
    calls ``inference`` once per task (OCNLI, OCEMOTION, TNEWS, in that order),
    writing the predictions under ``./submission/5928/``.

    Args:
        tokenizer_model: value passed to ``BertTokenizer.from_pretrained``.
    """

    def _first_line_as_json(json_path):
        # Only the first line of the file is parsed; an empty file gives {}.
        with open(json_path) as handle:
            for raw_line in handle:
                return json.loads(raw_line)
        return dict()

    ocnli_test = _first_line_as_json('./tianchi_datasets/OCNLI/test.json')
    ocemotion_test = _first_line_as_json('./tianchi_datasets/OCEMOTION/test.json')
    tnews_test = _first_line_as_json('./tianchi_datasets/TNEWS/test.json')
    label_dict = _first_line_as_json('./tianchi_datasets/label.json')

    # NOTE(review): torch.load unpickles the checkpoint; only load files you trust.
    model = torch.load('./saved_best.pt')
    tokenizer = BertTokenizer.from_pretrained(tokenizer_model)

    # (output path, test data, key into label_dict, task_type)
    jobs = (
        ('./submission/5928/ocnli_predict.json', ocnli_test, 'OCNLI', 'ocnli'),
        ('./submission/5928/ocemotion_predict.json', ocemotion_test, 'OCEMOTION', 'ocemotion'),
        ('./submission/5928/tnews_predict.json', tnews_test, 'TNEWS', 'tnews'),
    )
    for out_path, test_data, label_key, task in jobs:
        inference(out_path, test_data, model, tokenizer, label_dict[label_key], task, 'cuda:3', 64, True)
def inference(path, data_dict, model, tokenizer, idx2label, task_type, device='cuda:3', batchSize=64, print_result=True):
    """Predict labels for one task and write them to ``path`` as JSON lines.

    The examples in ``data_dict`` are processed in insertion order, in batches
    of ``batchSize``. Each prediction is written as one JSON object
    ``{"id": <key>, "label": <label>}``; objects are separated by ``'\\n'`` and
    there is no newline after the last one. The file is written as UTF-8.

    Args:
        path: output file path (overwritten).
        data_dict: mapping ``id -> {'s1': str}``; for ``'ocnli'`` each entry
            must also have ``'s2'``.
        model: callable as ``model(input_ids, ocnli_ids, ocemotion_ids,
            tnews_ids, token_type_ids, attention_mask)`` returning three
            outputs ``(ocnli_out, ocemotion_out, tnews_out)``; must support
            ``.to(device, non_blocking=True)`` and ``.eval()``.
        tokenizer: callable returning a mapping with ``'input_ids'``,
            ``'token_type_ids'`` and ``'attention_mask'`` tensors.
        idx2label: indexed with the integer argmax of the task's output to
            obtain the label that is written.
        task_type: one of ``'ocnli'``, ``'ocemotion'``, ``'tnews'``. Any other
            value prints a message and returns without touching ``path``.
        device: device the model and batch tensors are moved to.
        batchSize: number of examples per forward pass; must be positive.
        print_result: when true, print one summary line per example.

    Raises:
        ValueError: if ``batchSize`` is not positive (the batching loop could
            otherwise never advance).
    """
    if task_type not in ('ocnli', 'ocemotion', 'tnews'):
        print('task_type is incorrect!')
        return
    if batchSize <= 0:
        raise ValueError('batchSize must be a positive integer')

    model.to(device, non_blocking=True)
    model.eval()
    ids_list = list(data_dict)
    is_pair_task = task_type == 'ocnli'
    wrote_record = False

    # ensure_ascii=False below emits raw non-ASCII characters, so the output
    # encoding is pinned to UTF-8 rather than left to the platform locale.
    with torch.no_grad(), open(path, 'w', encoding='utf-8') as f:
        for start in range(0, len(ids_list), batchSize):
            cur_ids_list = ids_list[start:start + batchSize]
            first_sentences = [data_dict[idx]['s1'] for idx in cur_ids_list]
            if is_pair_task:
                second_sentences = [data_dict[idx]['s2'] for idx in cur_ids_list]
                encoded = tokenizer(first_sentences, second_sentences, add_special_tokens=True, padding=True, return_tensors='pt')
            else:
                encoded = tokenizer(first_sentences, add_special_tokens=True, padding=True, return_tensors='pt')
            input_ids = encoded['input_ids'].to(device, non_blocking=True)
            token_type_ids = encoded['token_type_ids'].to(device, non_blocking=True)
            attention_mask = encoded['attention_mask'].to(device, non_blocking=True)

            # The active task receives the row indices 0..len(batch)-1 and the
            # other two receive empty tensors.
            # NOTE(review): presumably the model routes rows to task heads
            # using these index tensors -- confirm against the Net definition.
            task_rows = {
                name: torch.tensor([]).to(device, non_blocking=True)
                for name in ('ocnli', 'ocemotion', 'tnews')
            }
            task_rows[task_type] = torch.arange(len(cur_ids_list)).to(device, non_blocking=True)

            ocnli_out, ocemotion_out, tnews_out = model(
                input_ids,
                task_rows['ocnli'],
                task_rows['ocemotion'],
                task_rows['tnews'],
                token_type_ids,
                attention_mask,
            )
            task_out = {'ocnli': ocnli_out, 'ocemotion': ocemotion_out, 'tnews': tnews_out}[task_type]
            pred = torch.argmax(task_out, dim=1)
            pred_final = [idx2label[e] for e in pred.cpu().tolist()]
            torch.cuda.empty_cache()

            for idx, label in zip(cur_ids_list, pred_final):
                if print_result:
                    print_str = f"[ {task_type} : sentence one: {data_dict[idx]['s1']}"
                    if is_pair_task:
                        print_str += f"; sentence two: {data_dict[idx]['s2']}"
                    print_str += f"; result: {label} ]"
                    print(print_str)
                # Separator goes *before* every record except the first, so
                # the file never ends with a trailing newline.
                if wrote_record:
                    f.write('\n')
                f.write(json.dumps({'id': idx, 'label': label}, ensure_ascii=False))
                wrote_record = True
if __name__ == '__main__':
    # Step 1: regenerate the per-task test.json files from the raw test.csv files.
    test_csv_to_json()
    # Step 2: announce the start of prediction, then run all three tasks.
    start_banner = '---------------------------------start inference-----------------------------'
    print(start_banner)
    inference_warpper(tokenizer_model='./robert_pretrain_model')