forked from Cyanogenoid/pytorch-vqa
-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocess-vocab.py
51 lines (40 loc) · 1.42 KB
/
preprocess-vocab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import json
from collections import Counter
import itertools
import config
import data
import utils
def extract_vocab(iterable, top_k=None, start=0):
    """ Turns an iterable of list of tokens into a vocabulary.

    These tokens could be single answers or word tokens in questions.

    Args:
        iterable: an iterable whose items are themselves iterables of
            hashable, mutually comparable tokens (e.g. lists of strings).
        top_k: if truthy, keep only the ``top_k`` most frequent tokens.
            ``None`` (or any other falsy value, including ``0``) keeps
            every token.
        start: index assigned to the first (most frequent) token.

    Returns:
        dict mapping each kept token to a consecutive integer index that
        begins at ``start``. Tokens are ordered by descending count.
        Tokens with equal counts are ordered in *reverse* lexicographical
        order, because the whole ``(count, token)`` key is sorted with
        ``reverse=True``.
    """
    counter = Counter(itertools.chain.from_iterable(iterable))
    if top_k:
        # Counter.most_common orders equal counts by first occurrence, so
        # which equally-frequent tokens survive the top_k cut-off depends
        # on the order of the input data.
        candidates = [token for token, _ in counter.most_common(top_k)]
    else:
        candidates = list(counter)
    # Descending by count; ties are broken in REVERSE lexicographical order
    # (reverse=True applies to the token component of the key as well).
    tokens = sorted(candidates, key=lambda token: (counter[token], token), reverse=True)
    return {token: index for index, token in enumerate(tokens, start=start)}
def main():
    """Build question and answer vocabularies from the training split and save them.

    Loads the training questions/answers JSON files located via
    ``utils.path_for``, tokenises them with ``data.prepare_questions`` and
    ``data.prepare_answers``, builds one vocabulary for each with
    ``extract_vocab``, and writes ``{'question': ..., 'answer': ...}`` as
    JSON to ``config.vocabulary_path``.
    """
    questions_path = utils.path_for(train=True, question=True)
    answers_path = utils.path_for(train=True, answer=True)
    # JSON text is UTF-8; be explicit so decoding does not depend on the
    # platform's locale default encoding.
    with open(questions_path, 'r', encoding='utf-8') as fd:
        questions = json.load(fd)
    with open(answers_path, 'r', encoding='utf-8') as fd:
        answers = json.load(fd)

    questions = data.prepare_questions(questions)
    answers = data.prepare_answers(answers)

    # Question indices start at 1 -- presumably so that 0 stays free for
    # padding; NOTE(review): confirm against data.py.
    question_vocab = extract_vocab(questions, start=1)
    # Answer vocabulary keeps only the config.max_answers most frequent
    # answers (every answer if config.max_answers is falsy).
    answer_vocab = extract_vocab(answers, top_k=config.max_answers)

    vocabs = {
        'question': question_vocab,
        'answer': answer_vocab,
    }
    with open(config.vocabulary_path, 'w', encoding='utf-8') as fd:
        json.dump(vocabs, fd)
if __name__ == '__main__':
    # Build and save the vocabularies only when executed as a script,
    # not when this module is imported.
    main()