-
Notifications
You must be signed in to change notification settings - Fork 12
/
utils.py
126 lines (119 loc) · 5.31 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import datetime
import os
import pickle
import threading
import warnings

import torch
import torch.nn.functional as F
from flask import jsonify

from TransferTransfo.train_ironman import SPECIAL_TOKENS, build_input_from_segments, add_special_tokens_
# End-of-sequence token emitted by the GPT-2 style model.
EOS = "<|endoftext|>"

# User-facing messages (Korean chatbot platform, simpleText payloads).
WELCOME = "Welcome to MPTI-Sherlock Chatbot, {}🙌 Enjoy your chatting with the great Sherlock Holmes. \n\n❗️Please Do not enter your personal information as recent conversation records are saved in the server.❗️"
# BUG FIX: corrected the "regiser" typo in the user-facing instruction.
REGISTER = "Please register your name first. 🧐\nYou can register your name with command🙏: !register [YOUR NAME HERE]"

# Lowercase-token → properly-capitalized replacement map, applied to the
# model's lowercased output before it is shown to the user.
CAPITAL = {'i': 'I', "i'll": "I'll", "i'm": "I'm", "i'd": "I'd", "i've": "I've",
           'tony': 'Tony',
           'morgan': 'Morgan',
           'pepper': 'Pepper',
           'potts': 'Potts',
           'ironman': 'Ironman', 'iron': 'Iron',
           'stark': 'Stark',
           'avengers': 'Avengers',
           'thanos': 'Thanos',
           'jarvis': 'JARVIS',
           'manhattan': 'Manhattan',
           'new': 'New', 'york': 'York',
           'california': 'California',
           'usa': 'USA',
           'howard': 'Howard',
           'may': 'May',
           }

# Persona sentences fed to the model as its "personality" context.
PERSONALTXT = [
    'my name is tony stark .',
    'i am iron man .',
    'i am a billionaire industrialist .',
    'i am a superhero .',
    'i have a daughter named morgan .',
    'i had saved the world countless times .',
    'i like American cheeseburger .',
    'i am married with pepper potts .',
    'i killed thanos .',
    'i am a genius .',
    'i own the stark industries .',
    'i programmed JARVIS .',
    'i put on my armored suit to protect the world as Iron Man .',
    'i was born on may 29, 1970 .',
    'i am a genius inventor .',
    'i was born in manhattan, new york .',
    'my father is howard stark .',
    # BUG FIX: a missing comma here caused implicit string concatenation,
    # silently merging the next two sentences into one persona entry.
    'i am the founding member of the Avengers .',
    'i love to be the center of attention .'
]
def load(filename):
    """Return the pickled object stored in *filename*, or {} if the file does not exist."""
    if os.path.exists(filename):
        with open(filename, 'rb') as fh:
            return pickle.load(fh)
    return {}
def nameFromDB(db, uid):
    """Look up *uid* in *db*; return its stored entry, or (None, None) when absent."""
    return db.get(uid, (None, None))
def save(db, filename):
    """Pickle *db* to *filename* and schedule the next save in 10 minutes.

    Prints a timestamped status line. Returns the started
    ``threading.Timer`` so callers can cancel the periodic re-save
    (backward-compatible: the previous version returned None, which
    callers ignored).
    """
    now_time = '({})'.format(datetime.datetime.now().strftime('%y/%m/%d %H:%M:%S'))
    # Pessimistic default; overwritten only when the dump succeeds.
    output = now_time + " !!!!!!Something is Wrong!!!!!!"
    with open(filename, 'wb') as handle:
        pickle.dump(db, handle, protocol=pickle.HIGHEST_PROTOCOL)
        output = now_time + " Successfully Saved " + filename
    print(output)
    # BUG FIX: the original passed args=[filename] only, so the rescheduled
    # call crashed with TypeError (save() requires both db and filename).
    timer = threading.Timer(600, save, args=[db, filename])
    timer.start()
    return timer
def toResponse(output):
    """Wrap *output* in a version-2.0 simpleText response envelope and jsonify it."""
    envelope = {
        "version": "2.0",
        "template": {
            "outputs": [{"simpleText": {"text": output}}],
        },
    }
    return jsonify(envelope)
def top_filtering(logits, top_k=0., top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
    """Filter a logits distribution using top-k, nucleus (top-p) and threshold filtering.

    Args:
        logits: 1-D tensor of logits over the vocabulary (batch size 1 only).
        top_k: keep only the top-k highest-probability tokens (0 disables).
        top_p: keep the smallest set of tokens whose cumulative probability
            exceeds top_p (nucleus filtering).
        threshold: additionally remove any logit strictly below this value.
        filter_value: value written into removed positions.

    Returns:
        The input tensor, modified in place, with filtered logits set to
        ``filter_value``.
    """
    assert logits.dim() == 1  # Only work for batch size 1 for now - could update but it would obfuscate a bit the code
    top_k = min(top_k, logits.size(-1))
    # BUG FIX: top_k was clamped but never applied in the original, making the
    # parameter a silent no-op. Apply it as in the reference implementation.
    if top_k > 0:
        # Remove every token whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, int(top_k))[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    # Compute cumulative probabilities of sorted tokens
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # Remove tokens with cumulative probability above the threshold
    sorted_indices_to_remove = cumulative_probabilities > top_p
    # Shift the indices to the right to keep also the first token above the threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    # Back to unsorted indices and set them to -infinity
    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    logits[indices_to_remove] = filter_value
    # Finally drop anything below the absolute threshold.
    indices_to_remove = logits < threshold
    logits[indices_to_remove] = filter_value
    return logits
def sample_sequence(personality, history, tokenizer, model, current_output=None,
                    max_length=20, temperature=0.7, top_k=0, top_p=0.9, device="cuda"):
    """Autoregressively sample a reply from *model* conditioned on persona and history.

    Args:
        personality: tokenized persona sentences passed to build_input_from_segments.
        history: tokenized dialog history.
        tokenizer: tokenizer providing convert_tokens_to_ids for SPECIAL_TOKENS.
        model: language model returning logits for (input_ids, token_type_ids).
        current_output: optional list of already-generated token ids to continue.
        max_length: maximum number of tokens to generate (default 20, as before).
        temperature: softmax temperature applied to the logits (default 0.7, as before).
        top_k / top_p: sampling filters forwarded to top_filtering.
        device: torch device for the input tensors (default "cuda", as before).

    Returns:
        The list of generated token ids (current_output, extended in place).
    """
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    if current_output is None:
        current_output = []
    for i in range(max_length):
        instance = build_input_from_segments(personality, history, current_output, tokenizer, with_eos=False)
        input_ids = torch.tensor(instance["input_ids"], device=device).unsqueeze(0)
        token_type_ids = torch.tensor(instance["token_type_ids"], device=device).unsqueeze(0)
        logits = model(input_ids, token_type_ids=token_type_ids).logits  # modified
        if isinstance(logits, tuple):  # for gpt2 and maybe others
            logits = logits[0]
        # Take the logits of the last position and apply temperature.
        logits = logits[0, -1, :] / temperature
        logits = top_filtering(logits, top_k=top_k, top_p=top_p)
        probs = F.softmax(logits, dim=-1)
        ## no greedy decoding, do sampling
        prev = torch.multinomial(probs, 1)
        # Resample if the very first token is a special token.
        if i < 1 and prev.item() in special_tokens_ids:
            while prev.item() in special_tokens_ids:
                if probs.max().item() == 1:
                    warnings.warn("Warning: model generating special token with probability 1.")
                    break  # avoid infinitely looping over special token
                prev = torch.multinomial(probs, num_samples=1)
        if prev.item() in special_tokens_ids:
            break
        current_output.append(prev.item())
    return current_output