big doc-vector refactor/enhancements #356

Merged · 51 commits · Jun 28, 2015

Commits (51)
a93f8e6
initial inference support
gojomo Mar 2, 2015
627604e
support for doc2vec dm_concat (concatenative PV-DM) model: null_word …
gojomo Mar 8, 2015
f1ad6b6
pure-python doc2vec dm_concat (concatenative PV-DM) model: train_sent…
gojomo Mar 8, 2015
305ae2b
infer_vector() on Doc2Vec
gojomo Mar 11, 2015
5aa0458
missed rename in sg path
gojomo Mar 11, 2015
9f3d28b
only swap dot/saxpy for detected blas – reducing code duplication
gojomo Mar 14, 2015
ff4cb98
rename for clarity
gojomo May 6, 2015
d851078
merge pre-trained vectors; optionally lock syn0 indexes
gojomo May 8, 2015
0a8bff5
dbow_words option; parameter name/doc cleanup; expect cython dm_concat
gojomo May 8, 2015
eb04a73
cythonized dm_concat; dbow word cotraining; syn0locks support; parame…
gojomo May 8, 2015
bc6287b
parameters to support doc2vec inference modes
gojomo May 9, 2015
de9eafb
train_sentence_* refactoring, parameterization to support inference v…
gojomo May 9, 2015
6e85df5
compact_name
gojomo May 16, 2015
f33bb27
rename merge_ to intersect_
gojomo May 16, 2015
0e587c3
for dm-sum, divide error over all contributing vectors
gojomo May 16, 2015
d28b0b1
rm unnecessary pretrain()
gojomo May 16, 2015
8d3d0f3
doclbls/docvecs separate from vocab/syn0; rename syn0locks syn0_lockf…
gojomo May 16, 2015
13b7ee2
delegate docvecs to (memmappable, replaceable) DocvecsInArray
gojomo May 18, 2015
a75553d
fix thread perf crash from randint()-per-word
gojomo Jun 6, 2015
89ad0ff
thread count in compact_name
gojomo Jun 6, 2015
156ea06
rename [lbl,label,LabeledSentence] -> [tag,tag,TaggedDocument]
gojomo Jun 8, 2015
5229139
reset_from, borrow_from to share vocab/etc between models in testing
gojomo Jun 8, 2015
012823f
most_similar, etc on docvecs
gojomo Jun 9, 2015
93a6272
initial doc2vec unit tests
gojomo Jun 9, 2015
f171b52
fix off-by-1 risking segfaults
gojomo Jun 9, 2015
9305980
corrections, tolerance tuning
gojomo Jun 9, 2015
15241a8
looser float matching
gojomo Jun 9, 2015
c44cf58
clarify shrunken sentence_len
gojomo Jun 9, 2015
0878db8
expand deterministic tests
gojomo Jun 10, 2015
45c8151
comment cleanup; doc_locks in job batches
gojomo Jun 10, 2015
a902dd8
don't clobber weights (ruining inference, among other things)
gojomo Jun 10, 2015
d142640
rm stray printing
gojomo Jun 10, 2015
882cddd
inference in sanity checks; cleanup
gojomo Jun 10, 2015
2f085b6
IMDB sentiment experiments
gojomo Jun 10, 2015
cb13723
np.allclose for float checks
gojomo Jun 10, 2015
ec47ec6
notebook tweaks
gojomo Jun 10, 2015
6c2c4e9
nest with:s for py2.6
gojomo Jun 11, 2015
125e8ef
minimize imports, simplify logging
gojomo Jun 11, 2015
e445e3e
touch w/ comment
gojomo Jun 11, 2015
f5b4e30
rm stray import breaking py2.6 build
gojomo Jun 11, 2015
53d8645
wget --quiet (two dashes)
gojomo Jun 11, 2015
09a30b3
comments; sentence->document; ipynb tweaks
gojomo Jun 16, 2015
46c81a3
Merge remote-tracking branch 'upstream/develop' into bigdocvec_pr
gojomo Jun 24, 2015
f88beab
recursive SaveLoad for DocvecsArray numpys
gojomo Jun 24, 2015
19faaab
don't (try to) share __doc__
gojomo Jun 24, 2015
a1ed490
reorder to respect ignores; move mmap_error (fixes unit tests)
gojomo Jun 24, 2015
d02b574
only swap dot/saxpy – reduce redundancy
gojomo Jun 24, 2015
1ed5e49
for cbow-sum, divide error over all contributing vectors
gojomo Jun 24, 2015
739fe31
_lockf support in cython; test
gojomo Jun 24, 2015
356c53a
pep8 & python2 fixes to doc2vec notebook
piskvorky Jun 28, 2015
b558262
Merge pull request #6 from piskvorky/bigdocvec_pr
gojomo Jun 28, 2015
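
Taken together, these commits change the user-facing Doc2Vec workflow: LabeledSentence becomes TaggedDocument, per-document vectors live under model.docvecs, and trained models gain infer_vector(). A minimal usage sketch of that API, assuming only what the commit messages above name; the corpus file name and parameter values are illustrative, not settings from this PR:

from gensim.models import doc2vec
from gensim import utils

# illustrative corpus: one TaggedDocument per line of a plain-text file
def read_corpus(path='my_corpus.txt'):  # hypothetical file name
    with open(path) as f:
        for i, line in enumerate(f):
            yield doc2vec.TaggedDocument(utils.simple_preprocess(line), [i])

documents = list(read_corpus())
model = doc2vec.Doc2Vec(documents, size=100, window=8, min_count=5, workers=4)

# infer a vector for unseen text, then look up the nearest trained documents
vec = model.infer_vector(['human', 'interface', 'computer'])
print(model.docvecs.most_similar([vec], topn=5))

The files changed in this PR are summarized below.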
1,843 changes: 1,843 additions & 0 deletions docs/notebooks/doc2vec-IMDB.ipynb

Large diffs are not rendered by default.

670 changes: 546 additions & 124 deletions gensim/models/doc2vec.py

Large diffs are not rendered by default.

10,114 changes: 4,558 additions & 5,556 deletions gensim/models/doc2vec_inner.c

Large diffs are not rendered by default.

1,090 changes: 455 additions & 635 deletions gensim/models/doc2vec_inner.pyx

Large diffs are not rendered by default.

198 changes: 143 additions & 55 deletions gensim/models/word2vec.py

Large diffs are not rendered by default.

4,545 changes: 1,061 additions & 3,484 deletions gensim/models/word2vec_inner.c

Large diffs are not rendered by default.

492 changes: 78 additions & 414 deletions gensim/models/word2vec_inner.pyx

Large diffs are not rendered by default.

366 changes: 366 additions & 0 deletions gensim/test/test_doc2vec.py
@@ -0,0 +1,366 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2010 Radim Rehurek <[email protected]>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Automated tests for checking transformation algorithms (the models package).
"""

from __future__ import with_statement

import logging
import unittest
import os
import tempfile
import itertools
import bz2
from six import iteritems, iterkeys
from six.moves import xrange, zip as izip
from collections import namedtuple

import numpy as np

from gensim import utils, matutils
from gensim.models import doc2vec

module_path = os.path.dirname(__file__) # needed because sample data files are located in the same folder
datapath = lambda fname: os.path.join(module_path, 'test_data', fname)


class DocsLeeCorpus(object):
    def __init__(self, string_tags=False):
        self.string_tags = string_tags

    def _tag(self, i):
        return i if not self.string_tags else '_*%d' % i

    def __iter__(self):
        with open(datapath('lee_background.cor')) as f:
            for i, line in enumerate(f):
                yield doc2vec.TaggedDocument(utils.simple_preprocess(line), [self._tag(i)])

list_corpus = list(DocsLeeCorpus())

sentences = [
    ['human', 'interface', 'computer'],
    ['survey', 'user', 'computer', 'system', 'response', 'time'],
    ['eps', 'user', 'interface', 'system'],
    ['system', 'human', 'system', 'eps'],
    ['user', 'response', 'time'],
    ['trees'],
    ['graph', 'trees'],
    ['graph', 'minors', 'trees'],
    ['graph', 'minors', 'survey']
]

sentences = [doc2vec.TaggedDocument(words, [i]) for i, words in enumerate(sentences)]


def testfile():
    # temporary data will be stored to this file
    return os.path.join(tempfile.gettempdir(), 'gensim_doc2vec.tst')


class TestDoc2VecModel(unittest.TestCase):
    def test_persistence(self):
        """Test storing/loading the entire model."""
        model = doc2vec.Doc2Vec(DocsLeeCorpus(), min_count=1)
        model.save(testfile())
        self.models_equal(model, doc2vec.Doc2Vec.load(testfile()))

    def test_load_mmap(self):
        """Test storing/loading the entire model."""
        model = doc2vec.Doc2Vec(sentences, min_count=1)

        # test storing the internal arrays into separate files
        model.save(testfile(), sep_limit=0)
        self.models_equal(model, doc2vec.Doc2Vec.load(testfile()))

        # make sure mmaping the arrays back works, too
        self.models_equal(model, doc2vec.Doc2Vec.load(testfile(), mmap='r'))

    def test_int_doctags(self):
        """Test doc2vec doctag alternatives"""
        corpus = DocsLeeCorpus()

        model = doc2vec.Doc2Vec(min_count=1)
        model.build_vocab(corpus)
        self.assertEqual(len(model.docvecs.doctag_syn0), 300)
        self.assertEqual(model.docvecs[0].shape, (300,))
        self.assertRaises(KeyError, model.__getitem__, '_*0')

    def test_string_doctags(self):
        """Test doc2vec doctag alternatives"""
        corpus = DocsLeeCorpus(True)

        model = doc2vec.Doc2Vec(min_count=1)
        model.build_vocab(corpus)
        self.assertEqual(len(model.docvecs.doctag_syn0), 300)
        self.assertEqual(model.docvecs[0].shape, (300,))
        self.assertEqual(model.docvecs['_*0'].shape, (300,))
        self.assertTrue(all(model.docvecs['_*0'] == model.docvecs[0]))

    def test_empty_errors(self):
        # no input => "RuntimeError: you must first build vocabulary before training the model"
        self.assertRaises(RuntimeError, doc2vec.Doc2Vec, [])

        # input not empty, but rather completely filtered out
        self.assertRaises(RuntimeError, doc2vec.Doc2Vec, list_corpus, min_count=10000)

    def model_sanity(self, model):
        """Any non-trivial model on DocsLeeCorpus can pass these sanity checks"""
        fire1 = 0  # doc 0 sydney fires
        fire2 = 8  # doc 8 sydney fires
        tennis1 = 6  # doc 6 tennis

        # inferred vector should be top10 close to bulk-trained one
        doc0_inferred = model.infer_vector(list(DocsLeeCorpus())[0].words)
        sims_to_infer = model.docvecs.most_similar([doc0_inferred])
        self.assertTrue(fire1 in [match[0] for match in sims_to_infer])

        # fire2 (doc 8) should be top20 close to fire1
        sims = model.docvecs.most_similar(fire1, topn=20)
        self.assertTrue(fire2 in [match[0] for match in sims])

        # same sims should appear in lookup by vec as by index
        doc0_vec = model.docvecs[fire1]
        sims2 = model.docvecs.most_similar(positive=[doc0_vec], topn=21)
        sims2 = sims2[1:]  # ignore first element of sims2, which is doc itself
        self.assertEqual(list(zip(*sims))[0], list(zip(*sims2))[0])  # same doc ids
        self.assertTrue(np.allclose(list(zip(*sims))[1], list(zip(*sims2))[1]))  # close-enough dists

        # tennis doc should be out-of-place among fire news
        self.assertEqual(model.docvecs.doesnt_match([fire1, tennis1, fire2]), tennis1)

        # fire docs should be closer than fire-tennis
        self.assertTrue(model.docvecs.similarity(fire1, fire2) > model.docvecs.similarity(fire1, tennis1))

    def test_training(self):
        """Test doc2vec training."""
        corpus = DocsLeeCorpus()
        model = doc2vec.Doc2Vec(size=100, min_count=2, iter=20)
        model.build_vocab(corpus)
        self.assertEqual(model.docvecs.doctag_syn0.shape, (300, 100))
        model.train(corpus)

        self.model_sanity(model)

        # build vocab and train in one step; must be the same as above
        model2 = doc2vec.Doc2Vec(corpus, size=100, min_count=2, iter=20)
        self.models_equal(model, model2)

    def test_dbow_hs(self):
        """Test DBOW doc2vec training."""
        model = doc2vec.Doc2Vec(list_corpus, dm=0, hs=1, negative=0, min_count=2, iter=20)
        self.model_sanity(model)

    def test_dmm_hs(self):
        """Test DM/mean doc2vec training."""
        model = doc2vec.Doc2Vec(list_corpus, dm=1, dm_mean=1, size=24, window=4, hs=1, negative=0,
                                min_count=2, iter=20)
        self.model_sanity(model)

    def test_dms_hs(self):
        """Test DM/sum doc2vec training."""
        model = doc2vec.Doc2Vec(list_corpus, dm=1, dm_mean=0, size=24, window=4, hs=1, negative=0,
                                min_count=2, iter=20)
        self.model_sanity(model)

    def test_dmc_hs(self):
        """Test DM/concatenate doc2vec training."""
        model = doc2vec.Doc2Vec(list_corpus, dm=1, dm_concat=1, size=24, window=4, hs=1, negative=0,
                                min_count=2, iter=20)
        self.model_sanity(model)

    def test_dbow_neg(self):
        """Test DBOW doc2vec training."""
        model = doc2vec.Doc2Vec(list_corpus, dm=0, hs=0, negative=10, min_count=2, iter=20)
        self.model_sanity(model)

    def test_dmm_neg(self):
        """Test DM/mean doc2vec training."""
        model = doc2vec.Doc2Vec(list_corpus, dm=1, dm_mean=1, size=24, window=4, hs=0, negative=10,
                                min_count=2, iter=20)
        self.model_sanity(model)

    def test_dms_neg(self):
        """Test DM/sum doc2vec training."""
        model = doc2vec.Doc2Vec(list_corpus, dm=1, dm_mean=0, size=24, window=4, hs=0, negative=10,
                                min_count=2, iter=20)
        self.model_sanity(model)

    def test_dmc_neg(self):
        """Test DM/concatenate doc2vec training."""
        model = doc2vec.Doc2Vec(list_corpus, dm=1, dm_concat=1, size=24, window=4, hs=0, negative=10,
                                min_count=2, iter=20)
        self.model_sanity(model)

    def test_parallel(self):
        """Test doc2vec parallel training."""
        if doc2vec.FAST_VERSION < 0:  # don't test the plain NumPy version for parallelism (too slow)
            return

        corpus = utils.RepeatCorpus(DocsLeeCorpus(), 10000)

        for workers in [2, 4]:
            model = doc2vec.Doc2Vec(corpus, workers=workers)
            self.model_sanity(model)

    def test_deterministic_hs(self):
        """Test doc2vec results identical with identical RNG seed."""
        # hs
        model = doc2vec.Doc2Vec(DocsLeeCorpus(), seed=42, workers=1)
        model2 = doc2vec.Doc2Vec(DocsLeeCorpus(), seed=42, workers=1)
        self.models_equal(model, model2)

    def test_deterministic_neg(self):
        """Test doc2vec results identical with identical RNG seed."""
        # neg
        model = doc2vec.Doc2Vec(DocsLeeCorpus(), hs=0, negative=3, seed=42, workers=1)
        model2 = doc2vec.Doc2Vec(DocsLeeCorpus(), hs=0, negative=3, seed=42, workers=1)
        self.models_equal(model, model2)

    def test_deterministic_dmc(self):
        """Test doc2vec results identical with identical RNG seed."""
        # bigger, dmc
        model = doc2vec.Doc2Vec(DocsLeeCorpus(), dm=1, dm_concat=1, size=24, window=4, hs=1, negative=3,
                                seed=42, workers=1)
        model2 = doc2vec.Doc2Vec(DocsLeeCorpus(), dm=1, dm_concat=1, size=24, window=4, hs=1, negative=3,
                                 seed=42, workers=1)
        self.models_equal(model, model2)

    def models_equal(self, model, model2):
        # check words/hidden-weights
        self.assertEqual(len(model.vocab), len(model2.vocab))
        self.assertTrue(np.allclose(model.syn0, model2.syn0))
        if model.hs:
            self.assertTrue(np.allclose(model.syn1, model2.syn1))
        if model.negative:
            self.assertTrue(np.allclose(model.syn1neg, model2.syn1neg))
        # check docvecs
        self.assertEqual(len(model.docvecs.doctags), len(model2.docvecs.doctags))
        self.assertEqual(len(model.docvecs.index2doctag), len(model2.docvecs.index2doctag))
        self.assertTrue(np.allclose(model.docvecs.doctag_syn0, model2.docvecs.doctag_syn0))

#endclass TestDoc2VecModel

# following code is useful for reproducing paragraph-vectors paper sentiment experiments

class ConcatenatedDoc2Vec(object):
    """
    Concatenation of multiple models for reproducing the Paragraph Vectors paper.
    Models must have exactly-matching vocabulary and document IDs. (Models should
    be trained separately; this wrapper just returns concatenated results.)
    """
    def __init__(self, models):
        self.models = models
        if hasattr(models[0], 'docvecs'):
            self.docvecs = ConcatenatedDocvecs([model.docvecs for model in models])

    def __getitem__(self, token):
        return np.concatenate([model[token] for model in self.models])

    def infer_vector(self, document, alpha=0.1, min_alpha=0.0001, steps=5):
        return np.concatenate([model.infer_vector(document, alpha, min_alpha, steps) for model in self.models])

    def train(self, ignored):
        pass  # train subcomponents individually


class ConcatenatedDocvecs(object):
    def __init__(self, models):
        self.models = models

    def __getitem__(self, token):
        return np.concatenate([model[token] for model in self.models])


SentimentDocument = namedtuple('SentimentDocument','words tags split sentiment')


def read_su_sentiment_rotten_tomatoes(dirname, lowercase=True):
    """
    Read and return documents from the Stanford Sentiment Treebank
    corpus (Rotten Tomatoes reviews), from http://nlp.Stanford.edu/sentiment/

    Initialize the corpus from a given directory, where
    http://nlp.stanford.edu/~socherr/stanfordSentimentTreebank.zip
    has been expanded. It's not too big, so compose entirely into memory.
    """
    logging.info("loading corpus from %s" % dirname)

    # many mangled chars in sentences (datasetSentences.txt)
    chars_sst_mangled = ['à', 'á', 'â', 'ã', 'æ', 'ç', 'è', 'é', 'í',
                         'í', 'ï', 'ñ', 'ó', 'ô', 'ö', 'û', 'ü']
    sentence_fixups = [(char.encode('utf-8').decode('latin1'), char) for char in chars_sst_mangled]
    # more junk, and the replace necessary for sentence-phrase consistency
    sentence_fixups.extend([
        ('Â', ''),
        ('\xa0', ' '),
        ('-LRB-', '('),
        ('-RRB-', ')'),
    ])
    # only this junk in phrases (dictionary.txt)
    phrase_fixups = [('\xa0', ' ')]

    # sentence_id and split are only positive for the full sentences

    # read sentences to temp {sentence -> (id,split)} dict, to correlate with dictionary.txt
    info_by_sentence = {}
    with open(os.path.join(dirname, 'datasetSentences.txt'), 'r') as sentences:
        with open(os.path.join(dirname, 'datasetSplit.txt'), 'r') as splits:
            next(sentences)  # legend
            next(splits)  # legend
            for sentence_line, split_line in izip(sentences, splits):
                (id, text) = sentence_line.split('\t')
                id = int(id)
                text = text.rstrip()
                for junk, fix in sentence_fixups:
                    text = text.replace(junk, fix)
                (id2, split_i) = split_line.split(',')
                assert id == int(id2)
                if text not in info_by_sentence:  # discard duplicates
                    info_by_sentence[text] = (id, int(split_i))

    # read all phrase text
    phrases = [None] * 239232  # known size of phrases
    with open(os.path.join(dirname, 'dictionary.txt'), 'r') as phrase_lines:
        for line in phrase_lines:
            (text, id) = line.split('|')
            for junk, fix in phrase_fixups:
                text = text.replace(junk, fix)
            phrases[int(id)] = text.rstrip()  # for 1st pass just string

    SentimentPhrase = namedtuple('SentimentPhrase', SentimentDocument._fields + ('sentence_id',))
    # add sentiment labels, correlate with sentences
    with open(os.path.join(dirname, 'sentiment_labels.txt'), 'r') as sentiments:
        next(sentiments)  # legend
        for line in sentiments:
            (id, sentiment) = line.split('|')
            id = int(id)
            sentiment = float(sentiment)
            text = phrases[id]
            words = text.split()
            if lowercase:
                words = [word.lower() for word in words]
            (sentence_id, split_i) = info_by_sentence.get(text, (None, 0))
            split = [None, 'train', 'test', 'dev'][split_i]
            phrases[id] = SentimentPhrase(words, [id], split, sentiment, sentence_id)

    assert len([phrase for phrase in phrases if phrase.sentence_id is not None]) == len(info_by_sentence)  # all
    # counts don't match 8544, 2210, 1101 because 13 TRAIN and 1 DEV sentences are duplicates
    assert len([phrase for phrase in phrases if phrase.split == 'train']) == 8531  # 'train'
    assert len([phrase for phrase in phrases if phrase.split == 'test']) == 2210  # 'test'
    assert len([phrase for phrase in phrases if phrase.split == 'dev']) == 1100  # 'dev'

    logging.info("loaded corpus with %i sentences and %i phrases from %s"
                 % (len(info_by_sentence), len(phrases), dirname))

    return phrases


if __name__ == '__main__':
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
    logging.info("using optimization %s" % doc2vec.FAST_VERSION)
    unittest.main()
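
For the paper-reproduction helpers above, a minimal sketch of how ConcatenatedDoc2Vec might combine two separately trained models over the same tagged corpus (the parameter values here are illustrative assumptions, not settings from this PR; list_corpus is the Lee corpus defined in the test module):

from gensim.models import doc2vec

# train a DBOW model and a DM model on the same tagged corpus
m_dbow = doc2vec.Doc2Vec(list_corpus, dm=0, size=100, hs=0, negative=5, min_count=2, iter=20)
m_dm = doc2vec.Doc2Vec(list_corpus, dm=1, size=100, hs=0, negative=5, min_count=2, iter=20)

combined = ConcatenatedDoc2Vec([m_dbow, m_dm])
vec = combined.infer_vector(['some', 'unseen', 'words'])  # 200-dim concatenated inference
doc0_vec = combined.docvecs[0]                            # concatenated per-document vector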
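Similarly, read_su_sentiment_rotten_tomatoes expects the unzipped stanfordSentimentTreebank directory; a hedged usage sketch (the directory path is an assumption, and the split counts come from the assertions above):

from gensim.test.test_doc2vec import read_su_sentiment_rotten_tomatoes

phrases = read_su_sentiment_rotten_tomatoes('stanfordSentimentTreebank')  # path is illustrative
train_docs = [p for p in phrases if p.split == 'train']  # 8531 full train sentences
test_docs = [p for p in phrases if p.split == 'test']    # 2210 full test sentences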
19 changes: 19 additions & 0 deletions gensim/test/test_word2vec.py
@@ -149,6 +149,25 @@ def testTraining(self):
        model2 = word2vec.Word2Vec(sentences, size=2, min_count=1)
        self.models_equal(model, model2)


    def testLocking(self):
        """Test word2vec training doesn't change locked vectors."""
        corpus = LeeCorpus()
        # build vocabulary, don't train yet
        for sg in range(2):  # test both cbow and sg
            model = word2vec.Word2Vec(size=4, hs=1, negative=5, min_count=1, sg=sg, window=5)
            model.build_vocab(corpus)

            # remember two vectors
            locked0 = numpy.copy(model.syn0[0])
            unlocked1 = numpy.copy(model.syn0[1])
            # lock the vector in slot 0 against change
            model.syn0_lockf[0] = 0.0

            model.train(corpus)
            self.assertFalse((unlocked1 == model.syn0[1]).all())  # unlocked vector should vary
            self.assertTrue((locked0 == model.syn0[0]).all())  # locked vector should not vary

    def testTrainingCbow(self):
        """Test CBOW word2vec training."""
        # to test training, make the corpus larger by repeating its sentences over and over