build_attn_annot_aokvqa.py (forked from aditya10/VLC-BERT)
import json
import random
import string

import numpy as np
from nltk.corpus import stopwords

from external.pytorch_pretrained_bert import BertTokenizer

# nltk.download('stopwords')  # uncomment on first run (also needs `import nltk`)

DATASET = 'aokvqa'
EXP_NAME = 'semqo'
MAX_COMMONSENSE_LEN = 5
RANDOM_SEED = 12345

random.seed(RANDOM_SEED)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
s = set(stopwords.words('english'))  # English stopwords, filtered out below


def _load_json(path):
    with open(path, 'r') as f:
        return json.load(f)


def filename(exp_name):
    # Insert a dot before the last character of the experiment name, to
    # match the expansion files' naming scheme.
    return (exp_name[:-1] + "." + exp_name[-1]).lower()
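# Sanity check of the naming helper (derived from the code above):
#   filename('semqo') == 'semq.o'
# so with the defaults, the expansions are read from
# 'data/coco/aokvqa/commonsense/expansions/semq.o_aokvqa_train.json'.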
def build_automatic():
    # Load the commonsense expansions and the ground-truth answers.
    # A sentence that shares a (non-stopword) token with an answer gets a
    # high attention weight (0.8); all other sentences keep a low baseline
    # weight (0.05). See auto_annotator below.
    if DATASET == 'aokvqa':
        annotations = _load_json('data/coco/aokvqa/aokvqa_v1p0_train.json')
        expansions = _load_json('data/coco/aokvqa/commonsense/expansions/' + filename(EXP_NAME) + '_aokvqa_train.json')

    annot_size = 4000
    annotations_subset = random.sample(annotations, annot_size)

    attn_annot = {}
    good_counter = 0
    total_counter = 0
    bad_capacity = 500  # keep at most this many annotations with no answer overlap

    for annot in annotations_subset:
        question_id = annot['question_id']
        image_id = str(annot['image_id'])
        direct_answers = annot['direct_answers']

        # Expansions are keyed by the zero-padded COCO filename, then by question id.
        exp = expansions['{:012d}.jpg'.format(annot['image_id'])][str(annot['question_id'])][0]

        # Split the expansion into sentences, then truncate or pad with empty
        # strings to exactly MAX_COMMONSENSE_LEN entries.
        exp = exp.split('.')
        exp = [e.strip() for e in exp]
        exp = [e for e in exp if e != '']
        if len(exp) > MAX_COMMONSENSE_LEN:
            exp = exp[:MAX_COMMONSENSE_LEN]
        else:
            exp = exp + [''] * (MAX_COMMONSENSE_LEN - len(exp))

        weights, good = auto_annotator(exp, direct_answers)

        # Once the budget for "bad" (no-overlap) annotations is spent, skip them.
        if not good and bad_capacity <= 0:
            continue
        if not good:
            bad_capacity -= 1

        if image_id not in attn_annot:
            attn_annot[image_id] = {}
        attn_annot[image_id][question_id] = weights

        total_counter += 1
        good_counter += 1 if good else 0

    with open('data/coco/aokvqa/' + EXP_NAME + '_aokvqa_train_attn_annot_' + str(MAX_COMMONSENSE_LEN) + '.json', 'w') as f:
        json.dump(attn_annot, f)

    print('Good: {}'.format(good_counter))
    print('Total: {}'.format(total_counter))
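# The emitted JSON maps image id -> question id -> a list of
# MAX_COMMONSENSE_LEN + 1 normalized weights, e.g. (hypothetical ids):
#   {"26148": {"q1a2b3": [0.762, 0.048, 0.048, 0.048, 0.048, 0.048]}}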
def auto_annotator(expansion_list, ans_list):
    # Tokenize all answers once, stripping punctuation and stopwords.
    ans_text = ' '.join(ans_list)
    ans_text = ans_text.translate(str.maketrans('', '', string.punctuation))
    ans_text = ans_text.lower()
    ans_tokens = tokenizer.tokenize(ans_text)
    ans_tokens = [t for t in ans_tokens if t not in s]

    # Every sentence starts at the 0.05 baseline; any token overlap with the
    # answers raises it to 0.8.
    final_weights = [0.05] * len(expansion_list)
    for i, expansion in enumerate(expansion_list):
        exp_text = expansion.translate(str.maketrans('', '', string.punctuation))
        exp_text = exp_text.lower()
        exp_tokens = tokenizer.tokenize(exp_text)
        exp_tokens = [t for t in exp_tokens if t not in s]
        for token in ans_tokens:
            if token in exp_tokens:
                final_weights[i] = 0.8
                break

    # Append one extra weight slot, then normalize so the weights sum to 1.
    # If at least one sentence matched an answer, the extra slot stays small
    # (0.05); otherwise it is 0.25 and absorbs half the mass.
    good = False
    if np.sum(final_weights) > (0.05 * len(expansion_list)):
        final_weights = np.array(final_weights + [0.05])
        final_weights = final_weights / np.sum(final_weights)
        good = True
    else:
        final_weights = np.array(final_weights + [0.25])
        final_weights = final_weights / np.sum(final_weights)

    assert len(final_weights) == MAX_COMMONSENSE_LEN + 1
    return final_weights.tolist(), good
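# Worked example (a sketch; assumes BERT wordpiece tokenization keeps
# 'horse' as a single token):
#   auto_annotator(['the man is riding a horse', '', '', '', ''], ['horse'])
# Sentence 0 shares the token 'horse' with the answers, so the raw weights
# are [0.8, 0.05, 0.05, 0.05, 0.05] + [0.05] (sum 1.05), which normalizes to
# roughly [0.762, 0.048, 0.048, 0.048, 0.048, 0.048] with good = True.
# With no overlap, [0.05]*5 + [0.25] normalizes to [0.1]*5 + [0.5], good = False.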
if __name__ == '__main__':
    build_automatic()
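# A minimal spot-check sketch (the path assumes the defaults above):
#   import json
#   with open('data/coco/aokvqa/semqo_aokvqa_train_attn_annot_5.json') as f:
#       attn = json.load(f)
#   image_id, per_question = next(iter(attn.items()))
#   question_id, weights = next(iter(per_question.items()))
#   print(image_id, question_id, weights)  # 6 floats summing to ~1.0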