-
Notifications
You must be signed in to change notification settings - Fork 1
/
emission_probs.py
204 lines (153 loc) · 6.39 KB
/
emission_probs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
"""Emission and transition probabilities for a trigram HMM tagger.

The EmissionProbEmitter class reads a counts file (same layout as
gene.counts), stores the count of each word/tag pair and of each unigram,
bigram and trigram tag pattern, and derives the emission probability of a
word given a tag from them.
"""


class EmissionProbEmitter(object):
    """Store word/tag and tag n-gram counts and the probabilities built on them.

    Recognised counts-file records (whitespace separated, one per line):

        <count> WORDTAG <tag> <word>
        <count> 1-GRAM <tag>
        <count> 2-GRAM <tag1> <tag2>
        <count> 3-GRAM <tag1> <tag2> <tag3>

    Counts are stored as ints. The n-gram dictionaries are keyed by the tuple
    of tags, e.g. ``self.bigram_counts[('*', 'O')]``.
    """

    def __init__(self):
        self.srcname = ""            # path of the counts file
        self.counted = False         # True once the counts file was parsed
        self.prob_computed = False   # True once word_emm_probs is filled
        self.word_emm_probs = {}     # word -> {tag: e(word | tag)}
        self.word_counts = {}        # word -> {tag: count}
        self.unigram_counts = {}     # (tag,) -> count
        self.bigram_counts = {}      # (tag1, tag2) -> count
        self.trigram_counts = {}     # (tag1, tag2, tag3) -> count

    @staticmethod
    def _prompt(message):
        """Read one line from stdin on both Python 2 and Python 3."""
        try:
            reader = raw_input  # Python 2
        except NameError:
            reader = input      # Python 3: raw_input was renamed to input
        return reader(message)

    def get_sourcename(self):
        """Ask for the counts-file path until a file that can be opened is given.

        The file should have the same layout as gene.counts. The accepted
        path is stored in ``self.srcname``.
        """
        # A loop instead of the former recursion: repeated bad input can no
        # longer exhaust the recursion limit.
        while True:
            print("Please supply a valid filename for the source file.")
            self.srcname = self._prompt('> ')
            try:
                # Only check that the file opens; close it again right away.
                with open(self.srcname):
                    return
            except (IOError, OSError):
                continue

    def _parse_count_line(self, line):
        """Store the count carried by one counts-file line (blank lines ignored)."""
        parts = line.split()
        if not parts:
            return
        count = int(parts[0])
        record = parts[1]
        if record == 'WORDTAG':
            tag, word = parts[2], parts[3]
            self.word_counts.setdefault(word, {})[tag] = count
            return
        table = {
            '1-GRAM': self.unigram_counts,
            '2-GRAM': self.bigram_counts,
            '3-GRAM': self.trigram_counts,
        }.get(record)
        # Unknown record types are skipped rather than misfiled as trigrams.
        if table is not None:
            # Key is the tuple of all tags in the n-gram, in order.
            table[tuple(parts[2:])] = count

    def get_counts_from_file(self):
        """Fill word_counts and the unigram/bigram/trigram count dictionaries.

        Does nothing except print a notice if the counts were already read.
        If ``self.srcname`` cannot be opened, get_sourcename is called first.
        """
        if self.counted:
            print("Already counted")
            return
        try:
            src = open(self.srcname)
        except (IOError, OSError):
            # No usable filename yet: ask the user for one.
            self.get_sourcename()
            src = open(self.srcname)
        with src:
            for line in src:
                self._parse_count_line(line)
        # Only mark as counted after the whole file was parsed successfully.
        self.counted = True

    def calculate_word_probs(self):
        """Fill word_emm_probs with e(word | tag) = count(tag, word) / count(tag).

        Reads the counts file first if needed. A second call only prints a
        notice.
        """
        if not self.counted:
            self.get_counts_from_file()
        if self.prob_computed:
            print("Probabilities already computed")
            return
        for word, tag_counts in self.word_counts.items():
            probs = self.word_emm_probs.setdefault(word, {})
            for tag, count in tag_counts.items():
                probs[tag] = float(count) / float(self.unigram_counts[(tag,)])
        self.prob_computed = True

    def e(self, word, tag):
        """Return the emission probability e(word | tag).

        word -- word to look up
        tag  -- tag the word is emitted from
        Returns 0 when the word/tag pair is not in word_emm_probs.
        """
        try:
            return self.word_emm_probs[word][tag]
        except KeyError:
            return 0

    def best_tag(self, word):
        """Return the tag with the highest emission probability for ``word``.

        Ties go to the tag met first in the dictionary's iteration order.
        Raises KeyError if ``word`` has no entry in word_emm_probs.
        """
        tagdict = self.word_emm_probs[word]
        return max(tagdict, key=tagdict.get)

    def basic_tagger(self, devfile, destfile):
        """Tag each word of ``devfile`` with its best emission tag.

        Input has one word per line, with blank lines between sentences.
        Output lines are ``word tag``, and blank lines are copied through.
        Words without an emission entry are tagged as '_RARE_' would be.
        """
        if not self.prob_computed:
            self.calculate_word_probs()
        with open(devfile) as dev, open(destfile, 'w') as dest:
            for line in dev:
                word = line.strip()
                if not word:
                    dest.write('\n')
                    continue
                key = word if word in self.word_emm_probs else '_RARE_'
                dest.write(word + ' ' + self.best_tag(key) + '\n')

    def q(self, tag1, tag2, tag3):
        """Return q(tag3 | tag1, tag2) = count(tag1, tag2, tag3) / count(tag1, tag2).

        Unseen bigrams or trigrams give 0.0 (the maximum-likelihood value)
        instead of raising KeyError.
        """
        if not self.counted:
            self.get_counts_from_file()
        bi_count = float(self.bigram_counts.get((tag1, tag2), 0))
        tri_count = float(self.trigram_counts.get((tag1, tag2, tag3), 0))
        if bi_count == 0.0 or tri_count == 0.0:
            return 0.0
        return tri_count / bi_count

    def _viterbi_sentence(self, words, tags):
        """Return the most probable tag sequence for one sentence.

        This is a trigram-HMM Viterbi search in log space. A state is the
        pair of the last two tags, and the search starts from ('*', '*').
        When the counts provide it, the sequence is closed with
        q(STOP | u, v).

        words -- tokens of one sentence
        tags  -- candidate tags, tried in this order (ties keep the first)
        """
        import math

        neg_inf = float('-inf')

        def log(p):
            # Zero probability maps to -inf instead of raising ValueError.
            return math.log(p) if p > 0 else neg_inf

        if not words:
            return []
        if not tags:
            return [''] * len(words)

        # best[(u, v)]: best log-probability of a prefix ending in tags u, v.
        best = {('*', '*'): 0.0}
        backpointers = []  # per position: (u, v) -> tag w preceding u
        for word in words:
            word_eff = word if word in self.word_counts else '_RARE_'
            new_best = {}
            back = {}
            for v in tags:
                emission = log(self.e(word_eff, v))
                for (w, u), score in best.items():
                    cand = score + log(self.q(w, u, v)) + emission
                    key = (u, v)
                    if key not in new_best or cand > new_best[key]:
                        new_best[key] = cand
                        back[key] = w
            best = new_best
            backpointers.append(back)

        def closed(state):
            return best[state] + log(self.q(state[0], state[1], 'STOP'))

        last = max(best, key=closed)
        if closed(last) == neg_inf:
            # No path reaches STOP (e.g. counts without STOP trigrams):
            # rank the final states on their own score instead.
            last = max(best, key=best.get)

        # Follow the back-pointers from the last two tags to the start.
        n = len(words)
        result = [None] * n
        result[-1] = last[1]
        if n > 1:
            result[-2] = last[0]
        for k in range(n - 1, 1, -1):
            result[k - 2] = backpointers[k][(result[k - 1], result[k])]
        return result

    def _write_tagged(self, dest, words, tags):
        """Decode one sentence and write its ``word tag`` lines to ``dest``."""
        for word, tag in zip(words, self._viterbi_sentence(words, tags)):
            dest.write(word + ' ' + tag + '\n')

    def viterbi_tagger(self, devfile, destfile):
        """Tag ``devfile`` sentence by sentence with Viterbi decoding.

        Input has one word per line, with blank lines ending sentences.
        Output lines are ``word tag``, and blank lines are copied through.
        Candidate tags are the unigram tags other than '*' and 'STOP'.
        """
        if not self.prob_computed:
            self.calculate_word_probs()
        tags = sorted(key[0] for key in self.unigram_counts
                      if key[0] not in ('*', 'STOP'))
        with open(devfile) as dev, open(destfile, 'w') as dest:
            sentence = []
            for line in dev:
                word = line.strip()
                if word:
                    sentence.append(word)
                    continue
                self._write_tagged(dest, sentence, tags)
                dest.write('\n')
                sentence = []
            # Tag a final sentence that is not followed by a blank line.
            self._write_tagged(dest, sentence, tags)