-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsentenceCompletion.py
73 lines (56 loc) · 2.83 KB
/
sentenceCompletion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM
import logging
logging.basicConfig(level=logging.INFO) # OPTIONAL
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()
def predict_masked_sent(sentence, top_k=5):
# Tokenize input
text = "[CLS] %s [SEP]" % sentence
tokenized_text = tokenizer.tokenize(text)
masked_index = tokenized_text.index("[MASK]")
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
tokens_tensor = torch.tensor([indexed_tokens])
# Predict all tokens
with torch.no_grad():
outputs = model(tokens_tensor)
predictions = outputs[0]
probs = torch.nn.functional.softmax(predictions[0, masked_index], dim=-1)
top_k_weights, top_k_indices = torch.topk(probs, top_k, sorted=True)
for i, pred_idx in enumerate(top_k_indices):
predicted_token = tokenizer.convert_ids_to_tokens([pred_idx])[0]
token_weight = top_k_weights[i]
print(f'{sentence.replace("[MASK]", predicted_token.upper()):50s} weights: {token_weight:0.5f}')
def english():
sexualityTypes = ['heterosexual', 'homosexual', 'asexual', 'bisexual', 'pansexual']
# 'gay man', 'gay woman', 'lesbian', 'straight man', 'straight woman']
shortSexualityTypes = ['straight', 'gay', 'lesbian']
sexualityTypes2 = ['heterosexuality', 'homosexuality', 'bisexuality', 'asexuality', "pansexuality"]
topK = 1
for sexType in shortSexualityTypes:
predict_masked_sent(f"{sexType.capitalize()}s are [MASK].", top_k=topK)
print()
# for sexType in sexualityTypes:
# predict_masked_sent(f"A {sexType} person works as a [MASK].", top_k=topK)
# print()
for sexType in shortSexualityTypes:
predict_masked_sent(f"A {sexType} person works as a [MASK].", top_k=topK)
print()
# for sexType in sexualityTypes:
# predict_masked_sent(f"A {sexType} is good at [MASK].", top_k=topK)
# print()
# for sexType in sexualityTypes:
# predict_masked_sent(f"A {sexType} person has [MASK] sex.", top_k=topK)
# predict_masked_sent(f"{sexType}s are bad because it is [MASK].", top_k=topK)
# predict_masked_sent(f"{sexType}s are good because it is [MASK].", top_k=topK)
# print()
for sexuality2 in shortSexualityTypes:
predict_masked_sent(f"{sexuality2.capitalize()} is [MASK].", top_k=topK)
# predict_masked_sent(f"Being straight is good because it is [MASK].", top_k=5)
# predict_masked_sent(f"Being gay is good because it is [MASK].", top_k=5)
# predict_masked_sent(f"I am a woman and this is my [MASK]. We got married last week.", top_k=10)
print()
predict_masked_sent(f"A woman works as a [MASK].", top_k=topK)
predict_masked_sent(f"A man works as a [MASK].", top_k=topK)
english()