forked from sordonia/rnn-lm — convert-text2dict.py (105 lines, 81 loc, 3.01 KB)
import collections
import numpy
import operator
import os
import sys
import logging
import cPickle
from collections import Counter
def safe_pickle(obj, filename):
    """Serialize `obj` to `filename` with cPickle at the highest protocol.

    Logs whether an existing file is being replaced or a new one created.
    The file handle is closed deterministically via the context manager.
    """
    # Pick the log template based on whether the target already exists.
    template = "Overwriting %s." if os.path.isfile(filename) else "Saving to %s."
    logger.info(template % filename)
    with open(filename, 'wb') as handle:
        cPickle.dump(obj, handle, protocol=cPickle.HIGHEST_PROTOCOL)
# Configure root logging once for the whole script; all progress messages
# below go through this module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('text2dict')
# NOTE(review): mid-file import — conventionally this belongs at the top of
# the file with the other imports.
import argparse
# Command line: positional input corpus and output prefix, plus an optional
# vocabulary-size cutoff and an optional pre-built external dictionary.
parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, help="Tab separated session file")
parser.add_argument("--cutoff", type=int, default=-1, help="Vocabulary cutoff (optional)")
parser.add_argument("--dict", type=str, default="", help="External dictionary (optional)")
parser.add_argument("output", type=str, help="Pickle binarized session corpus")
args = parser.parse_args()
# Per-word-id document frequencies.  defaultdict(lambda: 1) gives every id an
# implicit base count of 1 (add-one smoothing), so vocabulary entries that are
# never observed in the corpus still report a frequency of at least 1.
freqs = collections.defaultdict(lambda: 1)
###############################
# Part I: Create the dictionary
###############################
if args.dict != "":
    # Load an external dictionary of (word, word_id, ...) tuples.
    assert os.path.isfile(args.dict)
    # BUGFIX: safe_pickle writes with HIGHEST_PROTOCOL (a binary protocol),
    # so the file must be opened in binary mode ('rb', not 'r'); the old code
    # also leaked the file handle.
    with open(args.dict, "rb") as dict_file:
        vocab = cPickle.load(dict_file)
    # Keep only the word -> id mapping from each tuple.
    vocab = dict((x[0], x[1]) for x in vocab)
    # The mandatory special tokens must be present in an external dictionary.
    assert '<unk>' in vocab
    assert '<s>' in vocab
    assert '</s>' in vocab
else:
    # Build the vocabulary from the corpus itself, most frequent words first.
    word_counter = Counter()
    # BUGFIX: initialize so the log line below does not raise NameError when
    # the input file is empty; also close the file via a context manager.
    count = -1
    with open(args.input, 'r') as input_file:
        for count, line in enumerate(input_file):
            word_counter.update(line.strip().split())
    total_freq = sum(word_counter.values())
    logger.info("Read %d sentences " % (count + 1))
    logger.info("Total word frequency in dictionary %d " % total_freq)
    # Optionally truncate the vocabulary to the most common words.
    if args.cutoff != -1:
        logger.info("Cutoff %d" % args.cutoff)
        vocab_count = word_counter.most_common(args.cutoff)
    else:
        vocab_count = word_counter.most_common()
    # Ids 0-2 are reserved for the special tokens; corpus words then get
    # consecutive ids in descending frequency order.
    vocab = {'<unk>': 0, '<s>': 1, '</s>': 2}
    for word, _ in vocab_count:
        vocab[word] = len(vocab)
logger.info("Vocab size %d" % len(vocab))
# Part II: binarize the corpus into lists of word ids, gathering statistics.
# (The unused `mean_sl` accumulator from the original code was removed.)
unknowns = 0.    # tokens not found in the vocabulary (mapped to <unk> = 0)
num_terms = 0.   # total emitted tokens, including the <s>/</s> markers
binarized_corpus = []
# BUGFIX: close the input file deterministically with a context manager
# (the old code leaked the handle and bound an unused enumerate index).
with open(args.input, 'r') as corpus_file:
    for document in corpus_file:
        binarized_document = []
        for word in document.strip().split():
            # Out-of-vocabulary words map to id 0 (<unk>).
            word_id = vocab.get(word, 0)
            if not word_id:
                unknowns += 1
            binarized_document.append(word_id)
            freqs[word_id] += 1
        # Wrap every document with <s> (id 1) and </s> (id 2).
        binarized_document = [1] + binarized_document + [2]
        freqs[1] += 1
        freqs[2] += 1
        num_terms += len(binarized_document)
        binarized_corpus.append(binarized_document)
# Final summary of the run before writing the outputs.
logger.info("Vocab size %d" % len(vocab))
logger.info("Number of unknowns %d" % unknowns)
logger.info("Number of terms %d" % num_terms)
logger.info("Writing training %d documents " % len(binarized_corpus))
# The binarized corpus (list of lists of word ids) goes to <output>.word.pkl.
safe_pickle(binarized_corpus, args.output + ".word.pkl")
# Store triples word, word_id, freq
# Only write the dictionary when it was built here, not when an external
# dictionary was supplied via --dict.
if args.dict == "":
    safe_pickle([(word, word_id, freqs[word_id]) for word, word_id in vocab.items()], args.output + ".dict.pkl")