diff --git a/CHANGELOG.md b/CHANGELOG.md
index fd73c4f5fa..14a3657972 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,7 @@
 Changes
 =======
+0.13.5, 2016-11-12
+* Add delete_temporary_training_data() function to word2vec and doc2vec models. (@deepmipt-VladZhukov, [#987](https://github.com/RaRe-Technologies/gensim/pull/987))
 
 0.13.4, 2016-10-25
 * Passed all the params through the apply call in lda.get_document_topics(), test case to use the per_word_topics through the corpus in test_ldamodel (@parthoiiitm, [#978](https://github.com/RaRe-Technologies/gensim/pull/978))
diff --git a/gensim/models/doc2vec.py b/gensim/models/doc2vec.py
index d6b0997744..48807a8813 100644
--- a/gensim/models/doc2vec.py
+++ b/gensim/models/doc2vec.py
@@ -778,6 +778,19 @@ def __str__(self):
         segments.append('t%d' % self.workers)
         return '%s(%s)' % (self.__class__.__name__, ','.join(segments))
 
+    def delete_temporary_training_data(self, keep_doctags_vectors=True, keep_inference=True):
+        """
+        Discard parameters that are used in training and scoring. Use if you're sure you're done training a model.
+        Set `keep_doctags_vectors` to False if you don't want to save the doctag vectors;
+        in that case you won't be able to use the docvecs' most_similar, similarity, etc. methods.
+        Set `keep_inference` to False if you don't want to store the parameters that are used for the infer_vector method.
+        """
+        if not keep_inference:
+            self._minimize_model(False, False, False)
+        if self.docvecs and hasattr(self.docvecs, 'doctag_syn0') and not keep_doctags_vectors:
+            del self.docvecs.doctag_syn0
+        if self.docvecs and hasattr(self.docvecs, 'doctag_syn0_lockf'):
+            del self.docvecs.doctag_syn0_lockf
 
 class TaggedBrownCorpus(object):
     """Iterate over documents from the Brown corpus (part of NLTK data), yielding
diff --git a/gensim/models/word2vec.py b/gensim/models/word2vec.py
index 18bcf262f2..97f98c7614 100644
--- a/gensim/models/word2vec.py
+++ b/gensim/models/word2vec.py
@@ -462,7 +462,7 @@ def __init__(
         self.total_train_time = 0
         self.sorted_vocab = sorted_vocab
         self.batch_words = batch_words
-
+        self.model_trimmed_post_training = False
         if sentences is not None:
             if isinstance(sentences, GeneratorType):
                 raise TypeError("You can't pass a generator as the sentences argument. Try an iterator.")
@@ -754,6 +754,8 @@ def train(self, sentences, total_words=None, word_count=0,
         sentences are the same as those that were used to initially build the vocabulary.
 
         """
+        if self.model_trimmed_post_training:
+            raise RuntimeError("Parameters for training were discarded using the delete_temporary_training_data method")
         if FAST_VERSION < 0:
             import warnings
             warnings.warn("C extension not loaded for Word2Vec, training will be slow. "
@@ -1751,6 +1753,25 @@ def accuracy(self, questions, restrict_vocab=30000, most_similar=most_similar, case_insensitive=True):
     def __str__(self):
         return "%s(vocab=%s, size=%s, alpha=%s)" % (self.__class__.__name__, len(self.index2word), self.vector_size, self.alpha)
 
+    def _minimize_model(self, save_syn1=False, save_syn1neg=False, save_syn0_lockf=False):
+        if hasattr(self, 'syn1') and not save_syn1:
+            del self.syn1
+        if hasattr(self, 'syn1neg') and not save_syn1neg:
+            del self.syn1neg
+        if hasattr(self, 'syn0_lockf') and not save_syn0_lockf:
+            del self.syn0_lockf
+        self.model_trimmed_post_training = True
+
+    def delete_temporary_training_data(self, replace_word_vectors_with_normalized=False):
+        """
+        Discard parameters that are used in training and scoring. Use if you're sure you're done training a model.
+        If `replace_word_vectors_with_normalized` is set, forget the original vectors and only keep the normalized
+        ones, which saves lots of memory!
+        """
+        if replace_word_vectors_with_normalized:
+            self.init_sims(replace=True)
+        self._minimize_model()
+
     def save(self, *args, **kwargs):
         # don't bother storing the cached normalized vectors, recalculable table
         kwargs['ignore'] = kwargs.get('ignore', ['syn0norm', 'table', 'cum_table'])
diff --git a/gensim/test/test_doc2vec.py b/gensim/test/test_doc2vec.py
index a695b1a724..51392889d9 100644
--- a/gensim/test/test_doc2vec.py
+++ b/gensim/test/test_doc2vec.py
@@ -287,6 +287,34 @@ def models_equal(self, model, model2):
         self.assertEqual(len(model.docvecs.offset2doctag), len(model2.docvecs.offset2doctag))
         self.assertTrue(np.allclose(model.docvecs.doctag_syn0, model2.docvecs.doctag_syn0))
 
+    def test_delete_temporary_training_data(self):
+        """Test doc2vec model after delete_temporary_training_data"""
+        for i in [0, 1]:
+            for j in [0, 1]:
+                model = doc2vec.Doc2Vec(sentences, size=5, min_count=1, window=4, hs=i, negative=j)
+                if i:
+                    self.assertTrue(hasattr(model, 'syn1'))
+                if j:
+                    self.assertTrue(hasattr(model, 'syn1neg'))
+                self.assertTrue(hasattr(model, 'syn0_lockf'))
+                model.delete_temporary_training_data(keep_doctags_vectors=False, keep_inference=False)
+                self.assertEqual(len(model['human']), 5)
+                self.assertEqual(model.vocab['graph'].count, 3)
+                self.assertFalse(hasattr(model, 'syn1'))
+                self.assertFalse(hasattr(model, 'syn1neg'))
+                self.assertFalse(hasattr(model, 'syn0_lockf'))
+                self.assertTrue(model.docvecs and not hasattr(model.docvecs, 'doctag_syn0'))
+                self.assertTrue(model.docvecs and not hasattr(model.docvecs, 'doctag_syn0_lockf'))
+        model = doc2vec.Doc2Vec(list_corpus, dm=1, dm_mean=1, size=24, window=4, hs=1, negative=0, alpha=0.05, min_count=2, iter=20)
+        model.delete_temporary_training_data(keep_doctags_vectors=True, keep_inference=True)
+        self.assertTrue(model.docvecs and hasattr(model.docvecs, 'doctag_syn0'))
+        self.assertTrue(hasattr(model, 'syn1'))
+        self.model_sanity(model)
+        model = doc2vec.Doc2Vec(list_corpus, dm=1, dm_mean=1, size=24, window=4, hs=0, negative=1, alpha=0.05, min_count=2, iter=20)
+        model.delete_temporary_training_data(keep_doctags_vectors=True, keep_inference=True)
+        self.model_sanity(model)
+        self.assertTrue(hasattr(model, 'syn1neg'))
+
     @log_capture()
     def testBuildVocabWarning(self, l):
         """Test if logger warning is raised on non-ideal input to a doc2vec model"""
diff --git a/gensim/test/test_word2vec.py b/gensim/test/test_word2vec.py
index e7369835e8..57d47a98cf 100644
--- a/gensim/test/test_word2vec.py
+++ b/gensim/test/test_word2vec.py
@@ -434,7 +434,7 @@ def testSimilarities(self):
         model = word2vec.Word2Vec(size=2, min_count=1, sg=0, hs=0, negative=2)
         model.build_vocab(sentences)
         model.train(sentences)
-        
+
         self.assertTrue(model.n_similarity(['graph', 'trees'], ['trees', 'graph']))
         self.assertTrue(model.n_similarity(['graph'], ['trees']) == model.similarity('graph', 'trees'))
         self.assertRaises(ZeroDivisionError, model.n_similarity, ['graph', 'trees'], [])
@@ -482,6 +482,31 @@ def models_equal(self, model, model2):
         most_common_word = max(model.vocab.items(), key=lambda item: item[1].count)[0]
         self.assertTrue(np.allclose(model[most_common_word], model2[most_common_word]))
 
+    def testDeleteTemporaryTrainingData(self):
+        """Test word2vec model after delete_temporary_training_data"""
+        for i in [0, 1]:
+            for j in [0, 1]:
+                model = word2vec.Word2Vec(sentences, size=10, min_count=0, seed=42, hs=i, negative=j)
+                if i:
+                    self.assertTrue(hasattr(model, 'syn1'))
+                if j:
+                    self.assertTrue(hasattr(model, 'syn1neg'))
+                self.assertTrue(hasattr(model, 'syn0_lockf'))
+                model.delete_temporary_training_data(replace_word_vectors_with_normalized=True)
+                self.assertEqual(len(model['human']), 10)
+                self.assertEqual(len(model.vocab), 12)
+                self.assertEqual(model.vocab['graph'].count, 3)
+                self.assertFalse(hasattr(model, 'syn1'))
+                self.assertFalse(hasattr(model, 'syn1neg'))
+                self.assertFalse(hasattr(model, 'syn0_lockf'))
+
+    def testNormalizeAfterTrainingData(self):
+        model = word2vec.Word2Vec(sentences, min_count=1)
+        model.save_word2vec_format(testfile(), binary=True)
+        norm_only_model = word2vec.Word2Vec.load_word2vec_format(testfile(), binary=True)
+        norm_only_model.delete_temporary_training_data(replace_word_vectors_with_normalized=True)
+        self.assertFalse(np.allclose(model['human'], norm_only_model['human']))
+
     @log_capture()
     def testBuildVocabWarning(self, l):
         """Test if warning is raised on non-ideal input to a word2vec model"""
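Usage note (illustrative, not part of the patch): a minimal sketch of how the new `delete_temporary_training_data()` calls are intended to be used, assuming a gensim checkout with the changes above applied; the toy corpus below is a stand-in, not gensim's test data.

```python
# Minimal usage sketch for the new delete_temporary_training_data() API.
from gensim.models.word2vec import Word2Vec
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

corpus = [
    ['human', 'interface', 'computer'],
    ['graph', 'trees'],
    ['graph', 'minors', 'trees'],
]

# Word2Vec: once training is done, trim syn1/syn1neg/syn0_lockf to shrink the
# model; replace_word_vectors_with_normalized=True also drops the raw vectors
# and keeps only the unit-normalized ones.
w2v = Word2Vec(corpus, size=10, min_count=1)
w2v.delete_temporary_training_data(replace_word_vectors_with_normalized=True)
print(w2v.most_similar('graph'))  # querying still works
# w2v.train(corpus)               # would now raise RuntimeError

# Doc2Vec: keep the doctag vectors and the inference parameters, drop the rest.
docs = [TaggedDocument(words, [i]) for i, words in enumerate(corpus)]
d2v = Doc2Vec(docs, size=10, min_count=1, iter=10)
d2v.delete_temporary_training_data(keep_doctags_vectors=True, keep_inference=True)
print(d2v.docvecs.most_similar(0))           # doctag queries still work
print(d2v.infer_vector(['graph', 'trees']))  # inference still works
```

The trimming is one-way: once the temporary parameters are discarded, `train()` raises a RuntimeError, so these methods should only be called on models that will not be updated again.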