Skip to content

Commit

Permalink
Word2Vec/Doc2Vec offer model-minimization method. Fix issue #446 (#987)
Browse files Browse the repository at this point in the history
* issue #446

add finished_training method

* private _minimize_model, tests

We can't just call «the super method in word2vec explicitly» without
adding the flag to save syn0_lockf, which as is necessary to save in
d2v.

* fix_print

* flag finished_training fix

* fix_bug with docvecs, controllability

* rename flag, flag move, init_sims

* renaming the RuntimeError message

* fix, add more tests

* fix, i == j

* fix

* tests_fix

* delete useless code

* numpy fix

* hs,neg in tests; assert parameters existance

* changelog update

* rename replace, description fix
  • Loading branch information
pum-purum-pum-pum authored and tmylk committed Nov 13, 2016
1 parent 5ee30fc commit 284a9f7
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
Changes
=======
0.13.5, 2016-11-12
* Add delete_temporary_training_data() function to word2vec and doc2vec models. (@deepmipt-VladZhukov, [#987](https://github.com/RaRe-Technologies/gensim/pull/987))

0.13.4, 2016-10-25
* Passed all the params through the apply call in lda.get_document_topics(), test case to use the per_word_topics through the corpus in test_ldamodel (@parthoiiitm, [#978](https://github.com/RaRe-Technologies/gensim/pull/978))
Expand Down
13 changes: 13 additions & 0 deletions gensim/models/doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,19 @@ def __str__(self):
segments.append('t%d' % self.workers)
return '%s(%s)' % (self.__class__.__name__, ','.join(segments))

def delete_temporary_training_data(self, keep_doctags_vectors=True, keep_inference=True):
"""
Discard parameters that are used in training and score. Use if you're sure you're done training a model.
Set `keep_doctags_vectors` to False if you don't want to save doctags vectors,
in this case you can't to use docvecs's most_similar, similarity etc. methods.
Set `keep_inference` to False if you don't want to store parameters that is used for infer_vector method
"""
if not keep_inference:
self._minimize_model(False, False, False)
if self.docvecs and hasattr(self.docvecs, 'doctag_syn0') and not keep_doctags_vectors:
del self.docvecs.doctag_syn0
if self.docvecs and hasattr(self.docvecs, 'doctag_syn0_lockf'):
del self.docvecs.doctag_syn0_lockf

class TaggedBrownCorpus(object):
"""Iterate over documents from the Brown corpus (part of NLTK data), yielding
Expand Down
23 changes: 22 additions & 1 deletion gensim/models/word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def __init__(
self.total_train_time = 0
self.sorted_vocab = sorted_vocab
self.batch_words = batch_words

self.model_trimmed_post_training = False
if sentences is not None:
if isinstance(sentences, GeneratorType):
raise TypeError("You can't pass a generator as the sentences argument. Try an iterator.")
Expand Down Expand Up @@ -745,6 +745,8 @@ def train(self, sentences, total_words=None, word_count=0,
sentences are the same as those that were used to initially build the vocabulary.
"""
if (self.model_trimmed_post_training):
raise RuntimeError("Parameters for training were discarded using model_trimmed_post_training method")
if FAST_VERSION < 0:
import warnings
warnings.warn("C extension not loaded for Word2Vec, training will be slow. "
Expand Down Expand Up @@ -1393,6 +1395,25 @@ def accuracy(self, questions, restrict_vocab=30000, most_similar=None, case_inse
def __str__(self):
return "%s(vocab=%s, size=%s, alpha=%s)" % (self.__class__.__name__, len(self.wv.index2word), self.vector_size, self.alpha)

def _minimize_model(self, save_syn1 = False, save_syn1neg = False, save_syn0_lockf = False):
if hasattr(self, 'syn1') and not save_syn1:
del self.syn1
if hasattr(self, 'syn1neg') and not save_syn1neg:
del self.syn1neg
if hasattr(self, 'syn0_lockf') and not save_syn0_lockf:
del self.syn0_lockf
self.model_trimmed_post_training = True

def delete_temporary_training_data(self, replace_word_vectors_with_normalized=False):
"""
Discard parameters that are used in training and score. Use if you're sure you're done training a model.
If `replace_word_vectors_with_normalized` is set, forget the original vectors and only keep the normalized
ones = saves lots of memory!
"""
if replace_word_vectors_with_normalized:
self.init_sims(replace=True)
self._minimize_model()

def save(self, *args, **kwargs):
# don't bother storing the cached normalized vectors, recalculable table
kwargs['ignore'] = kwargs.get('ignore', ['syn0norm', 'table', 'cum_table'])
Expand Down
28 changes: 28 additions & 0 deletions gensim/test/test_doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,34 @@ def models_equal(self, model, model2):
self.assertEqual(len(model.docvecs.offset2doctag), len(model2.docvecs.offset2doctag))
self.assertTrue(np.allclose(model.docvecs.doctag_syn0, model2.docvecs.doctag_syn0))

def test_delete_temporary_training_data(self):
"""Test doc2vec model after delete_temporary_training_data"""
for i in [0, 1]:
for j in [0, 1]:
model = doc2vec.Doc2Vec(sentences, size=5, min_count=1, window=4, hs=i, negative=j)
if i:
self.assertTrue(hasattr(model, 'syn1'))
if j:
self.assertTrue(hasattr(model, 'syn1neg'))
self.assertTrue(hasattr(model, 'syn0_lockf'))
model.delete_temporary_training_data(keep_doctags_vectors=False, keep_inference=False)
self.assertTrue(len(model['human']), 10)
self.assertTrue(model.vocab['graph'].count, 5)
self.assertTrue(not hasattr(model, 'syn1'))
self.assertTrue(not hasattr(model, 'syn1neg'))
self.assertTrue(not hasattr(model, 'syn0_lockf'))
self.assertTrue(model.docvecs and not hasattr(model.docvecs, 'doctag_syn0'))
self.assertTrue(model.docvecs and not hasattr(model.docvecs, 'doctag_syn0_lockf'))
model = doc2vec.Doc2Vec(list_corpus, dm=1, dm_mean=1, size=24, window=4, hs=1, negative=0, alpha=0.05, min_count=2, iter=20)
model.delete_temporary_training_data(keep_doctags_vectors=True, keep_inference=True)
self.assertTrue(model.docvecs and hasattr(model.docvecs, 'doctag_syn0'))
self.assertTrue(hasattr(model, 'syn1'))
self.model_sanity(model)
model = doc2vec.Doc2Vec(list_corpus, dm=1, dm_mean=1, size=24, window=4, hs=0, negative=1, alpha=0.05, min_count=2, iter=20)
model.delete_temporary_training_data(keep_doctags_vectors=True, keep_inference=True)
self.model_sanity(model)
self.assertTrue(hasattr(model, 'syn1neg'))

@log_capture()
def testBuildVocabWarning(self, l):
"""Test if logger warning is raised on non-ideal input to a doc2vec model"""
Expand Down
27 changes: 26 additions & 1 deletion gensim/test/test_word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def testSimilarities(self):
model = word2vec.Word2Vec(size=2, min_count=1, sg=0, hs=0, negative=2)
model.build_vocab(sentences)
model.train(sentences)

self.assertTrue(model.n_similarity(['graph', 'trees'], ['trees', 'graph']))
self.assertTrue(model.n_similarity(['graph'], ['trees']) == model.similarity('graph', 'trees'))
self.assertRaises(ZeroDivisionError, model.n_similarity, ['graph', 'trees'], [])
Expand Down Expand Up @@ -536,6 +536,31 @@ def models_equal(self, model, model2):
most_common_word = max(model.vocab.items(), key=lambda item: item[1].count)[0]
self.assertTrue(np.allclose(model[most_common_word], model2[most_common_word]))

def testDeleteTemporaryTrainingData(self):
"""Test word2vec model after delete_temporary_training_data"""
for i in [0, 1]:
for j in [0, 1]:
model = word2vec.Word2Vec(sentences, size=10, min_count=0, seed=42, hs=i, negative=j)
if i:
self.assertTrue(hasattr(model, 'syn1'))
if j:
self.assertTrue(hasattr(model, 'syn1neg'))
self.assertTrue(hasattr(model, 'syn0_lockf'))
model.delete_temporary_training_data(replace_word_vectors_with_normalized=True)
self.assertTrue(len(model['human']), 10)
self.assertTrue(len(model.vocab), 12)
self.assertTrue(model.vocab['graph'].count, 3)
self.assertTrue(not hasattr(model, 'syn1'))
self.assertTrue(not hasattr(model, 'syn1neg'))
self.assertTrue(not hasattr(model, 'syn0_lockf'))

def testNormalizeAfterTrainingData(self):
model = word2vec.Word2Vec(sentences, min_count=1)
model.save_word2vec_format(testfile(), binary=True)
norm_only_model = word2vec.Word2Vec.load_word2vec_format(testfile(), binary=True)
norm_only_model.delete_temporary_training_data(replace_word_vectors_with_normalized=True)
self.assertFalse(np.allclose(model['human'], norm_only_model['human']))

@log_capture()
def testBuildVocabWarning(self, l):
"""Test if warning is raised on non-ideal input to a word2vec model"""
Expand Down

0 comments on commit 284a9f7

Please sign in to comment.