# roberta_connector.py
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from fairseq.models.roberta import RobertaModel
from fairseq import utils
import torch
from lama.modules.base_connector import *
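
# This connector exposes a fairseq RoBERTa checkpoint through LAMA's
# Base_Connector interface: RobertaVocab decodes dictionary indices back to
# surface tokens, and Roberta builds the vocabulary, maps strings to token ids,
# and runs batches of masked sentences through the model, returning per-token
# vocabulary scores together with the positions of the mask tokens.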


class RobertaVocab(object):
    def __init__(self, roberta):
        self.roberta = roberta

    def __getitem__(self, arg):
        value = ""
        try:
            predicted_token_bpe = self.roberta.task.source_dictionary.string([arg])
            if (
                predicted_token_bpe.strip() == ROBERTA_MASK
                or predicted_token_bpe.strip() == ROBERTA_START_SENTENCE
            ):
                value = predicted_token_bpe.strip()
            else:
                value = self.roberta.bpe.decode(str(predicted_token_bpe)).strip()
        except Exception as e:
            print(arg)
            print(predicted_token_bpe)
            print(value)
            print("Exception {} for input {}".format(e, arg))
        return value
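
# Minimal usage sketch (illustrative; `roberta_model` is assumed to be an
# already-loaded fairseq RobertaModel and the index is a placeholder):
#
#   vocab = RobertaVocab(roberta_model)
#   surface_form = vocab[some_index]  # decoded token string; special tokens are returned as-is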


class Roberta(Base_Connector):
    def __init__(self, args):
        super().__init__()

        roberta_model_dir = args.roberta_model_dir
        roberta_model_name = args.roberta_model_name
        roberta_vocab_name = args.roberta_vocab_name
        self.dict_file = "{}/{}".format(roberta_model_dir, roberta_vocab_name)

        self.model = RobertaModel.from_pretrained(
            roberta_model_dir, checkpoint_file=roberta_model_name
        )
        self.bpe = self.model.bpe
        self.task = self.model.task
        self._build_vocab()
        self._init_inverse_vocab()
        self.max_sentence_length = args.max_sentence_length
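
    # The constructor only reads the following fields from `args` (names taken
    # from the assignments above; concrete values depend on the downloaded checkpoint):
    #   args.roberta_model_dir   - directory containing checkpoint and vocabulary files
    #   args.roberta_model_name  - checkpoint file passed to RobertaModel.from_pretrained
    #   args.roberta_vocab_name  - vocabulary file name used to build self.dict_file
    #   args.max_sentence_length - maximum number of tokens kept per encoded sentence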

    def _cuda(self):
        self.model.cuda()

    def _build_vocab(self):
        self.vocab = []
        for key in range(ROBERTA_VOCAB_SIZE):
            predicted_token_bpe = self.task.source_dictionary.string([key])
            try:
                value = self.bpe.decode(predicted_token_bpe)

                if value[0] == " ":  # if the token starts with a whitespace
                    value = value.strip()
                else:
                    # this is subword information
                    value = "_{}_".format(value)

                if value in self.vocab:
                    # print("WARNING: token '{}' is already in the vocab".format(value))
                    value = "{}_{}".format(value, key)

                self.vocab.append(value)

            except Exception as e:
                self.vocab.append(predicted_token_bpe.strip())
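
    # Resulting vocabulary convention (as implemented above): BPE pieces that decode
    # with a leading whitespace are stored stripped (word-initial tokens), other
    # pieces are wrapped as "_piece_" to mark subword units, and duplicate surface
    # forms get the dictionary index appended to keep entries unique.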

    def get_id(self, input_string):
        # Roberta predicts ' London' and not 'London'
        string = " " + str(input_string).strip()
        text_spans_bpe = self.bpe.encode(string.rstrip())
        tokens = self.task.source_dictionary.encode_line(
            text_spans_bpe, append_eos=False
        )
        return [element.item() for element in tokens.long().flatten()]
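
    # Minimal usage sketch (illustrative; `connector` is an instance of this class
    # and the returned ids depend on the loaded dictionary):
    #
    #   ids = connector.get_id("London")  # dictionary ids for " London" (leading space added above)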

    def get_batch_generation(self, sentences_list, logger=None, try_cuda=True):
        if not sentences_list:
            return None
        if try_cuda:
            self.try_cuda()

        tensor_list = []
        masked_indices_list = []
        max_len = 0
        output_tokens_list = []
        for masked_inputs_list in sentences_list:

            tokens_list = []

            for idx, masked_input in enumerate(masked_inputs_list):

                # substitute [MASK] with <mask>
                masked_input = masked_input.replace(MASK, ROBERTA_MASK)

                text_spans = masked_input.split(ROBERTA_MASK)
                text_spans_bpe = (
                    (" {0} ".format(ROBERTA_MASK))
                    .join(
                        [
                            self.bpe.encode(text_span.rstrip())
                            for text_span in text_spans
                        ]
                    )
                    .strip()
                )

                prefix = ""
                if idx == 0:
                    prefix = ROBERTA_START_SENTENCE

                tokens_list.append(
                    self.task.source_dictionary.encode_line(
                        str(prefix + " " + text_spans_bpe).strip(), append_eos=True
                    )
                )

            tokens = torch.cat(tokens_list)[: self.max_sentence_length]
            output_tokens_list.append(tokens.long().cpu().numpy())

            if len(tokens) > max_len:
                max_len = len(tokens)
            tensor_list.append(tokens)

            masked_index = (tokens == self.task.mask_idx).nonzero().numpy()
            for x in masked_index:
                masked_indices_list.append([x[0]])

        # pad every sequence in the batch to the same length with the dictionary pad id
        pad_id = self.task.source_dictionary.pad()
        tokens_list = []
        for tokens in tensor_list:
            pad_length = max_len - len(tokens)
            if pad_length > 0:
                pad_tensor = torch.full([pad_length], pad_id, dtype=torch.int)
                tokens = torch.cat((tokens, pad_tensor))
            tokens_list.append(tokens)

        batch_tokens = torch.stack(tokens_list)

        with torch.no_grad():
            # with utils.eval(self.model.model):
            self.model.eval()
            self.model.model.eval()
            log_probs, extra = self.model.model(
                batch_tokens.long().to(device=self._model_device),
                features_only=False,
                return_all_hiddens=False,
            )

        return log_probs.cpu(), output_tokens_list, masked_indices_list
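
    # Minimal usage sketch (illustrative; MASK and ROBERTA_MASK come from
    # lama.modules.base_connector, and the sentence is a placeholder):
    #
    #   sentences = [["The capital of France is " + MASK + " ."]]
    #   log_probs, token_ids, masked_indices = connector.get_batch_generation(sentences)
    #
    #   log_probs      - CPU tensor with a vocabulary-sized score vector per token position
    #   token_ids      - numpy token ids of each (possibly truncated) input
    #   masked_indices - one [position] entry per mask token found in the batch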

    def get_contextual_embeddings(self, sentences_list, try_cuda=True):
        # TBA
        return None
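

# Minimal construction sketch (illustrative only; the directory and file names are
# placeholders for a downloaded fairseq RoBERTa checkpoint, not part of this module):
#
#   from argparse import Namespace
#
#   args = Namespace(
#       roberta_model_dir="path/to/roberta.large",
#       roberta_model_name="model.pt",
#       roberta_vocab_name="dict.txt",
#       max_sentence_length=100,
#   )
#   connector = Roberta(args)  # then use get_id / get_batch_generation as sketched above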