
Faster analogies #340

Closed · wants to merge 8 commits
216 changes: 169 additions & 47 deletions gensim/models/word2vec.py
@@ -81,7 +81,7 @@

from numpy import exp, dot, zeros, outer, random, dtype, float32 as REAL,\
    uint32, seterr, array, uint8, vstack, argsort, fromstring, sqrt, newaxis,\
    ndarray, empty, sum as np_sum, prod
    ndarray, empty, sum as np_sum, prod, inf, argmax, tanh

logger = logging.getLogger("gensim.models.word2vec")

@@ -829,15 +829,58 @@ def init_sims(self, replace=False):
        else:
            self.syn0norm = (self.syn0 / sqrt((self.syn0 ** 2).sum(-1))[..., newaxis]).astype(REAL)

    @staticmethod
    def log_accuracy(section):
        correct, incorrect = len(section['correct']), len(section['incorrect'])
        if correct + incorrect > 0:
            logger.info("%s: %.1f%% (%i/%i)" %
                (section['section'], 100.0 * correct / (correct + incorrect),
                correct, correct + incorrect))
    def index_by_count(self):
        """
        Reorder the vocabulary so that frequent words
        have low indices.
        """

        i2w = [pair[0] for pair in sorted(iteritems(self.vocab),
                                          key=lambda item: -item[1].count)]

Wouldn't it be clearer if you put reverse=True: sorted(iteritems(self.vocab), key=lambda item: item[1].count, reverse=True) ?

It's more verbose but the intention is clear at least
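
For illustration, the reviewer's suggestion as a self-contained snippet (the toy vocabulary is made up; gensim's real `vocab` maps words to objects with a `count` attribute):

```python
from collections import namedtuple

# Toy stand-in for gensim's vocab entries (hypothetical data).
Vocab = namedtuple('Vocab', ['count', 'index'])
vocab = {'the': Vocab(100, 0), 'cat': Vocab(7, 1), 'sat': Vocab(3, 2)}

# The two sorts are equivalent; reverse=True states the intent directly.
by_negation = [w for w, v in sorted(vocab.items(), key=lambda item: -item[1].count)]
by_reverse = [w for w, v in sorted(vocab.items(), key=lambda item: item[1].count, reverse=True)]
assert by_negation == by_reverse == ['the', 'cat', 'sat']
```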


        old_w2i = {}
        for i, word in enumerate(i2w):
            old_w2i[word] = self.vocab[word].index  # remember each word's old row
            self.vocab[word].index = i  # new index = rank by frequency

        replaced = self.syn0 is self.syn0norm

        syn0 = empty(self.syn0.shape, dtype=REAL)
        for i, word in enumerate(i2w):
            syn0[i] = self.syn0[old_w2i[word]]
        self.syn0 = syn0

        if replaced:
            self.syn0norm = self.syn0
        elif self.syn0norm is not None:
            syn0norm = empty(self.syn0norm.shape, dtype=REAL)
            for i, word in enumerate(i2w):
                syn0norm[i] = self.syn0norm[old_w2i[word]]
            self.syn0norm = syn0norm

    def accuracy(self, questions, restrict_vocab=30000, most_similar=most_similar):
        try:
            self.syn1  # probe; raises AttributeError if hierarchical softmax is off
            syn1 = empty(self.syn1.shape, dtype=REAL)
            for i, word in enumerate(i2w):
                syn1[i] = self.syn1[old_w2i[word]]
            self.syn1 = syn1
        except AttributeError:
            pass

        try:
            self.syn1neg  # probe; raises AttributeError if negative sampling is off
            syn1neg = empty(self.syn1neg.shape, dtype=REAL)
            for i, word in enumerate(i2w):
                syn1neg[i] = self.syn1neg[old_w2i[word]]
            self.syn1neg = syn1neg
        except AttributeError:
            pass

        self.index2word = i2w

        self.indexed_by_count = True
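
A minimal sketch of what the reindexing buys, assuming a trained model on disk (the path is hypothetical): once the rows of `syn0` are sorted by frequency, restricting a search to the top-N words becomes a contiguous slice instead of a per-word index lookup.

```python
# Hypothetical usage; 'my_model' is a placeholder path.
from gensim.models import word2vec

model = word2vec.Word2Vec.load('my_model')
model.init_sims()        # build the unit-normalized syn0norm
model.index_by_count()   # row i now belongs to the i-th most frequent word

# "top 30000 words" is now a cheap contiguous slice:
restrict_syn0norm = model.syn0norm[:30000].T  # shape (vector_size, 30000)
```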

    def accuracy(self, questions, restrict_vocab=30000, batchsize=2000, non_lin=None):
        """
        Compute accuracy of the model. `questions` is a filename where lines are
        4-tuples of words, split into sections by ": SECTION NAME" lines.
@@ -848,63 +891,142 @@ def accuracy(self, questions, restrict_vocab=30000, most_similar=most_similar):

        Use `restrict_vocab` to ignore all questions containing a word whose frequency
        is not in the top-N most frequent words (default top 30,000).

        If you run into memory issues, lower `batchsize`.

        `non_lin` is a function applied after computing the cosine similarities.
        Use `None` for the usual additive combination (3CosAdd).
        Use `'mul'` for the 3CosMul objective proposed in [4]_.
        Use `'tanh'` for (possibly) marginally better results.
        You may also pass any Python function that maps one real input to one
        real output; see the usage sketch after this method.

        This method corresponds to the `compute-accuracy` script of the original C word2vec.

        If there are words that are identical except for capitalization,
        the reported results may differ from those obtained with the C tool.
        """
        ok_vocab = dict(sorted(iteritems(self.vocab),
                               key=lambda item: -item[1].count)[:restrict_vocab])
        ok_index = set(v.index for v in itervalues(ok_vocab))
        self.init_sims()

        if not hasattr(self, 'indexed_by_count'):
            self.indexed_by_count = False
        if not self.indexed_by_count:
            self.index_by_count()

        # after index_by_count(), the top-`restrict_vocab` words occupy the first rows
        restrict_syn0norm = self.syn0norm[:restrict_vocab].T

        sections, section = [], None
        section_a, section_b, section_c, section_expected = [], [], [], []
        all_a, all_b, all_c, all_expected = [], [], [], []
        sections = []
        for line_no, line in enumerate(utils.smart_open(questions)):
            # TODO: use level3 BLAS (=evaluate multiple questions at once), for speed
            line = utils.to_unicode(line)
            if line.startswith(': '):
                # a new section starts => store the old section
                if section:
                    sections.append(section)
                    self.log_accuracy(section)
                section = {'section': line.lstrip(': ').strip(), 'correct': [], 'incorrect': []}
                if len(section_a):
                    all_a.append(section_a)
                    all_b.append(section_b)
                    all_c.append(section_c)
                    all_expected.append(section_expected)
                    sections.append({'section': category_name})
                category_name = line.lstrip(': ').strip()
                section_a, section_b, section_c, section_expected = [], [], [], []
            else:
                if not section:
                    raise ValueError("missing section header before line #%i in %s" % (line_no, questions))
                try:
                    a, b, c, expected = [word.lower() for word in line.split()]  # TODO assumes vocabulary preprocessing uses lowercase, too...
                except:
                    logger.info("skipping invalid line #%i in %s" % (line_no, questions))
                if a not in ok_vocab or b not in ok_vocab or c not in ok_vocab or expected not in ok_vocab:
                a, b, c, expected = [word.lower() for word in line.split()]
                if a not in self.vocab or b not in self.vocab \
                        or c not in self.vocab or expected not in self.vocab:
                    logger.debug("skipping line #%i with OOV words: %s" % (line_no, line.strip()))
                    continue

                ignore = set(self.vocab[v].index for v in [a, b, c])  # indexes of words to ignore
                predicted = None
                # find the most likely prediction, ignoring OOV words and input words
                for index in argsort(most_similar(self, positive=[b, c], negative=[a], topn=False))[::-1]:
                    if index in ok_index and index not in ignore:
                        predicted = self.index2word[index]
                        if predicted != expected:
                            logger.debug("%s: expected %s, predicted %s" % (line.strip(), expected, predicted))
                        break
                if predicted == expected:
                    section['correct'].append((a, b, c, expected))
                if self.vocab[a].index >= restrict_vocab or self.vocab[b].index >= restrict_vocab \
                        or self.vocab[c].index >= restrict_vocab or self.vocab[expected].index >= restrict_vocab:
                    logger.debug("skipping line #%i with words below the frequency cutoff: %s" % (line_no, line.strip()))
                    continue
                section_a.append(self.vocab[a].index)
                section_b.append(self.vocab[b].index)
                section_c.append(self.vocab[c].index)
                section_expected.append(self.vocab[expected].index)

        if len(section_a):
            # store the last section, too
            all_a.append(section_a)
            all_b.append(section_b)
            all_c.append(section_c)
            all_expected.append(section_expected)
            sections.append({'section': category_name})

        for i in xrange(len(all_a)):
            correct = 0
            num_questions = len(all_a[i])
            # ceiling division: one extra batch for the remainder, if any
            num_batches = num_questions // batchsize + bool(num_questions % batchsize)

            for j in xrange(num_batches):
                batch_a = all_a[i][j * batchsize:(j + 1) * batchsize]
                batch_b = all_b[i][j * batchsize:(j + 1) * batchsize]
                batch_c = all_c[i][j * batchsize:(j + 1) * batchsize]

                expected = all_expected[i][j * batchsize:(j + 1) * batchsize]

                num_examples = len(batch_a)

                if non_lin is None:
                    # 3CosAdd: rank words d by cos(d, b) + cos(d, c) - cos(d, a)
                    vecs = -self.syn0norm[batch_a] + self.syn0norm[batch_b] \
                        + self.syn0norm[batch_c]

                    sims = dot(vecs, restrict_syn0norm)

                elif non_lin == 'mul':
                    # 3CosMul: rank words d by cos(d, b) * cos(d, c) / (cos(d, a) + eps).
                    # Shifting each cosine by 1 makes it non-negative (no need to divide by 2,
                    # since a constant positive factor does not change the argmax); the extra
                    # 0.002 in the denominator guards against division by zero.
                    sim_a = dot(self.syn0norm[batch_a], restrict_syn0norm) + 1.002
                    sim_b = dot(self.syn0norm[batch_b], restrict_syn0norm) + 1.
                    sim_c = dot(self.syn0norm[batch_c], restrict_syn0norm) + 1.

                    sims = sim_b * sim_c / sim_a

                elif non_lin == 'tanh':
                    sim_a = tanh(dot(self.syn0norm[batch_a], restrict_syn0norm))
                    sim_b = tanh(dot(self.syn0norm[batch_b], restrict_syn0norm))
                    sim_c = tanh(dot(self.syn0norm[batch_c], restrict_syn0norm))

                    sims = sim_b + sim_c - sim_a

                else:
                    section['incorrect'].append((a, b, c, expected))
        if section:
            # store the last section, too
            sections.append(section)
            self.log_accuracy(section)
                    # assumes `non_lin` is a Python function
                    # with one real input and one real output
                    sim_a = non_lin(dot(self.syn0norm[batch_a], restrict_syn0norm))
                    sim_b = non_lin(dot(self.syn0norm[batch_b], restrict_syn0norm))
                    sim_c = non_lin(dot(self.syn0norm[batch_c], restrict_syn0norm))

                    sims = sim_b + sim_c - sim_a

                # never predict one of the three question words
                sims[range(num_examples), batch_a] = -inf
                sims[range(num_examples), batch_b] = -inf
                sims[range(num_examples), batch_c] = -inf

                preds = argmax(sims, axis=1)

                correct += np_sum(preds == expected)

            logger.info("%s: %.1f%% (%i/%i)" %
                (sections[i]['section'], 100 * float(correct) / num_questions,
                correct, num_questions))

            sections[i]['correct'] = correct
            sections[i]['incorrect'] = num_questions - correct

        total = {
            'section': 'total',
            'correct': sum((s['correct'] for s in sections), []),
            'incorrect': sum((s['incorrect'] for s in sections), []),
            'correct': sum(s['correct'] for s in sections),
            'incorrect': sum(s['incorrect'] for s in sections),
        }
        self.log_accuracy(total)

        if total['correct'] + total['incorrect']:
            logger.info("%s: %.1f%% (%i/%i)" %
                ('total', 100 * float(total['correct']) / (total['correct'] + total['incorrect']),
                total['correct'], (total['correct'] + total['incorrect'])))
        else:
            logger.info("No valid questions")

        sections.append(total)
        return sections
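
For reference, a hedged usage sketch of the new signature; the model and question-file paths below are placeholders, not part of this PR:

```python
# Hypothetical usage of the batched accuracy() from this PR.
from gensim.models import word2vec

model = word2vec.Word2Vec.load('my_model')  # placeholder path

# default additive objective (3CosAdd):
sections = model.accuracy('questions-words.txt')

# 3CosMul, with smaller batches if memory is tight:
sections = model.accuracy('questions-words.txt', non_lin='mul', batchsize=500)

# any elementwise one-real-in, one-real-out function also works:
sections = model.accuracy('questions-words.txt', non_lin=lambda x: x ** 3)

total = sections[-1]  # the aggregate entry appended last
print("total: %i/%i correct" % (total['correct'], total['correct'] + total['incorrect']))
```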


    def __str__(self):
        return "Word2Vec(vocab=%s, size=%s, alpha=%s)" % (len(self.index2word), self.layer1_size, self.alpha)
