preprocess.py
import numpy as np
import data_io as pio
import nltk
from bert_serving.client import BertClient

# bert-as-service client; only used by Preprocessor.bert_encode
bc = BertClient(ip='ess76.wisers.com', port=16995, port_out=16996)

vocb_path = 'data/vocab.txt'
glove_path = 'data/glove.6B.300d.txt'
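
# NOTE (added): char_encode / word_encode below rely only on the local charset
# and GloVe vocabulary; the BertClient above is needed only when bert_encode is
# called, and assumes a reachable bert-as-service server at the configured
# host/port.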


class Preprocessor:
    def __init__(self, datasets_fp, max_length=384, stride=128):
        self.datasets_fp = datasets_fp
        self.max_length = max_length
        self.max_clen = 25
        self.max_qlen = 25
        self.max_char_len = 5
        self.stride = stride
        self.charset = set()
        self.word_list = []
        self.embeddings_index = {}
        self.embeddings_matrix = []
        # Build the GloVe vocabulary and the character/word id maps up front:
        # char_encode/word_encode (and the demo in __main__) need ch2id and w2id.
        self.load_glove(glove_path)
        self.build_charset()
        self.build_wordset()

    def build_charset(self):
        for fp in self.datasets_fp:
            self.charset |= self.dataset_info(fp)

        self.charset = sorted(list(self.charset))
        self.charset = ['[PAD]', '[CLS]', '[SEP]'] + self.charset + ['[UNK]']
        idx = list(range(len(self.charset)))
        self.ch2id = dict(zip(self.charset, idx))
        self.id2ch = dict(zip(idx, self.charset))
        # print(self.ch2id, self.id2ch)

    def build_wordset(self):
        # all words in vocab
        idx = list(range(len(self.word_list)))
        self.w2id = dict(zip(self.word_list, idx))
        self.id2w = dict(zip(idx, self.word_list))

    def dataset_info(self, inn):
        charset = set()
        dataset = pio.load(inn)

        for _, context, question, answer, _ in self.iter_cqa(dataset):
            charset |= set(context) | set(question) | set(answer)
            # self.max_clen = max(self.max_clen, len(context))
            # self.max_qlen = max(self.max_clen, len(question))

        return charset

    def iter_cqa(self, dataset):
        for data in dataset['data']:
            for paragraph in data['paragraphs']:
                context = paragraph['context']
                for qa in paragraph['qas']:
                    qid = qa['id']
                    question = qa['question']
                    for answer in qa['answers']:
                        text = answer['text']
                        answer_start = answer['answer_start']
                        yield qid, context, question, text, answer_start

    def char_encode(self, context, question):
        # Character-level encoding: [question | context] packed into a single
        # max_length-row block, one row of max_char_len char ids per token.
        q_seg_list = self.tokenize(question)
        c_seg_list = self.tokenize(context)
        question_encode = self.convert2id_char(q_seg_list, max_char_len=self.max_char_len, maxlen=100, begin=True, end=True)
        print(np.array(question_encode))
        left_length = self.max_length - len(question_encode)
        context_encode = self.convert2id_char(c_seg_list, max_char_len=self.max_char_len, maxlen=left_length, end=True)
        cq_encode = question_encode + context_encode

        assert len(cq_encode) == self.max_length

        return cq_encode

    def word_encode(self, context, question):
        q_seg_list = self.tokenize(question)
        c_seg_list = self.tokenize(context)
        question_encode = self.convert2id_word(q_seg_list, begin=True, end=True)
        left_length = self.max_length - len(question_encode)
        context_encode = self.convert2id_word(c_seg_list, maxlen=left_length, end=True)
        cq_encode = question_encode + context_encode

        assert len(cq_encode) == self.max_length

        return cq_encode

    def convert2id_char(self, seg_list, max_char_len=None, maxlen=None, begin=False, end=False):
        char_list = []
        # optional leading [CLS] row, padded out to max_char_len
        char_list = [[self.get_id_char('[CLS]')] + [self.get_id_char('[PAD]')] * (max_char_len - 1)] * begin + char_list

        for word in seg_list:
            ch = [ch for ch in word]
            if max_char_len is not None:
                ch = ch[:max_char_len]
            ids = list(map(self.get_id_char, ch))
            while len(ids) < max_char_len:
                ids.append(self.get_id_char('[PAD]'))
            char_list.append(ids)

        if maxlen is not None:
            # truncate (reserving one row for [SEP] when end=True), append the
            # [SEP] row to mirror convert2id_word, then pad with all-[PAD]
            # rows up to maxlen
            char_list = char_list[:maxlen - 1 * end]
            char_list += [[self.get_id_char('[SEP]')] + [self.get_id_char('[PAD]')] * (max_char_len - 1)] * end
            char_list += [[self.get_id_char('[PAD]')] * max_char_len] * (maxlen - len(char_list))

        return char_list
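
    # Rough illustration (not from the original file): with max_char_len=5,
    # maxlen=4, begin=True and end=True, convert2id_char(['to', 'whom'], ...)
    # yields four rows of five ids:
    #   [CLS] PAD PAD PAD PAD
    #   t     o   PAD PAD PAD
    #   w     h   o   m   PAD
    #   [SEP] PAD PAD PAD PAD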

    def convert2id_word(self, seg_list, maxlen=None, begin=False, end=False):
        word = [word for word in seg_list]
        word = ['cls'] * begin + word

        if maxlen is not None:
            word = word[:maxlen - 1 * end]
            word += ['sep'] * end
            word += ['pad'] * (maxlen - len(word))
        else:
            word += ['sep'] * end

        ids = list(map(self.get_id_word, word))

        return ids

    def get_id_char(self, ch):
        return self.ch2id.get(ch, self.ch2id['[UNK]'])

    def get_id_word(self, word):
        return self.w2id.get(word, self.w2id['unk'])

    def get_dataset(self, ds_fp):
        ccs, qcs, cws, qws, be = [], [], [], [], []
        for _, cc, qc, cw, qw, b, e in self.get_data(ds_fp):
            ccs.append(cc)
            qcs.append(qc)
            cws.append(cw)
            qws.append(qw)
            be.append((b, e))
        return map(np.array, (ccs, qcs, cws, qws, be))
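
    # Usage note (illustrative; `preprocessor` is any constructed instance and
    # the path is assumed): the return value is a lazy map(), so unpack it in
    # one go, e.g.
    #   cc, qc, cw, qw, be = preprocessor.get_dataset('./data/squad/dev-v1.1.json')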

    def get_data(self, ds_fp):
        dataset = pio.load(ds_fp)
        for qid, context, question, text, answer_start in self.iter_cqa(dataset):
            c_seg_list = self.tokenize(context)
            q_seg_list = self.tokenize(question)
            c_char_ids = self.get_sent_ids_char(c_seg_list, self.max_clen)
            q_char_ids = self.get_sent_ids_char(q_seg_list, self.max_qlen)
            c_word_ids = self.get_sent_ids_word(c_seg_list, self.max_clen)
            q_word_ids = self.get_sent_ids_word(q_seg_list, self.max_qlen)
            b, e = answer_start, answer_start + len(text)

            # Map the character-level answer span [b, e) onto token indices
            # (nb, ne).  The mapping is approximate: it counts only the
            # characters of the tokens and ignores whitespace stripped by
            # tokenization.
            nb = -1
            ne = -1
            len_all_char = 0
            for i, w in enumerate(c_seg_list):
                if nb == -1 and len_all_char <= b < len_all_char + len(w):
                    nb = i
                if ne == -1 and len_all_char < e <= len_all_char + len(w):
                    ne = i
                len_all_char += len(w)
            if nb == -1 or ne == -1:
                # answer could not be located in the tokenized context
                nb = ne = 0

            yield qid, c_char_ids, q_char_ids, c_word_ids, q_word_ids, nb, ne

    def get_sent_ids_char(self, sent, maxlen):
        return self.convert2id_char(sent, max_char_len=self.max_char_len, maxlen=maxlen, end=True)

    def get_sent_ids_word(self, sent, maxlen):
        return self.convert2id_word(sent, maxlen=maxlen)

    def tokenize(self, sequence, do_lowercase=True):
        if do_lowercase:
            tokens = [token.replace("``", '"').replace("''", '"').lower()
                      for token in nltk.word_tokenize(sequence)]
        else:
            tokens = [token.replace("``", '"').replace("''", '"')
                      for token in nltk.word_tokenize(sequence)]
        return tokens

    def load_glove(self, glove_path):
        # Each GloVe line is "<word> <space-separated float components>".
        with open(glove_path, encoding='utf-8') as fr:
            for line in fr:
                word, coefs = line.split(maxsplit=1)
                coefs = np.fromstring(coefs, sep=' ')
                self.embeddings_index[word] = coefs
                self.word_list.append(word)
                self.embeddings_matrix.append(coefs)
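
    # Sketch (an assumption, not part of the original file): once load_glove
    # has run, the collected vectors can be stacked into a weight matrix for
    # an embedding layer, e.g.
    #   weights = np.array(self.embeddings_matrix)  # shape (vocab_size, 300)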

    def bert_encode(self, ds_fp):
        cs, qs, be = [], [], []
        dataset = pio.load(ds_fp)
        for qid, context, question, text, answer_start in self.iter_cqa(dataset):
            # note: the slices truncate to max_clen/max_qlen *characters* before encoding
            cc = bc.encode([context[:self.max_clen]])[0]   # max_seq, emb_size
            qc = bc.encode([question[:self.max_qlen]])[0]  # max_seq, emb_size
            b, e = answer_start, answer_start + len(text)
            cs.append(cc)
            qs.append(qc)
            be.append((b, e))
        return map(np.array, (cs, qs, be))


if __name__ == '__main__':
    p = Preprocessor([
        # './data/squad/train-v1.1.json',
        # './data/squad/dev-v1.1.json',
        './data/squad/dev-v1.1.json'
    ])
    print(p.char_encode('modern stone statue of Mary', 'To whom did the Virgin Mary '))
    print(p.word_encode('modern stone statue of Mary', 'To whom did the Virgin Mary '))
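
    # Further usage (illustrative; the SQuAD and GloVe files are assumed to
    # exist under ./data/): build full arrays or BERT features, e.g.
    #   cc, qc, cw, qw, be = p.get_dataset('./data/squad/dev-v1.1.json')
    #   cs, qs, be = p.bert_encode('./data/squad/dev-v1.1.json')  # needs bert-as-service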