From 7b475029466a99af035f9e8227c4c24def68c423 Mon Sep 17 00:00:00 2001
From: Karl Higley
Date: Tue, 21 Jun 2016 21:39:05 -0400
Subject: [PATCH] Add support for using Annoy as an external similarity index w/ Doc2Vec

---
 gensim/similarities/index.py     |  9 ++++++++
 gensim/test/test_similarities.py | 39 ++++++++++++++++++++++++++++++++
 2 files changed, 48 insertions(+)

diff --git a/gensim/similarities/index.py b/gensim/similarities/index.py
index 331d73d268..6eac18b856 100644
--- a/gensim/similarities/index.py
+++ b/gensim/similarities/index.py
@@ -16,6 +16,15 @@ def build_from_word2vec(cls, model, num_trees):
         model.init_sims()
         return cls._build_from_model(model.syn0norm, model.index2word, model.vector_size, num_trees)
 
+    @classmethod
+    def build_from_doc2vec(cls, model, num_trees):
+        """Build an Annoy index using document vectors from a Doc2Vec model"""
+
+        docvecs = model.docvecs
+        docvecs.init_sims()
+        labels = [docvecs.index_to_doctag(i) for i in range(docvecs.count)]
+        return cls._build_from_model(docvecs.doctag_syn0norm, labels, model.vector_size, num_trees)
+
     @classmethod
     def _build_from_model(cls, vectors, labels, num_features, num_trees):
         index = AnnoyIndex(num_features)
diff --git a/gensim/test/test_similarities.py b/gensim/test/test_similarities.py
index c4ec4d9f60..f9ac5d44b4 100644
--- a/gensim/test/test_similarities.py
+++ b/gensim/test/test_similarities.py
@@ -19,6 +19,7 @@
 
 from gensim.corpora import mmcorpus, Dictionary
 from gensim.models import word2vec
+from gensim.models import doc2vec
 from gensim import matutils, utils, similarities
 from gensim.models import Word2Vec
 
@@ -45,6 +46,9 @@
 dictionary = Dictionary(texts)
 corpus = [dictionary.doc2bow(text) for text in texts]
 
+sentences = [doc2vec.TaggedDocument(words, [i])
+             for i, words in enumerate(texts)]
+
 
 def testfile():
     # temporary data will be stored to this file
@@ -455,6 +459,41 @@ def testApproxNeighborsMatchExact(self):
         self.assertEqual(approx_words, exact_words)
 
 
+class TestDoc2VecSimilarityIndex(unittest.TestCase):
+
+    def setUp(self):
+        try:
+            import annoy  # noqa: F401 -- imported only to probe availability
+        except ImportError:
+            raise unittest.SkipTest("Annoy library is not available")
+
+        from gensim.similarities.index import SimilarityIndex
+
+        self.model = doc2vec.Doc2Vec(sentences, min_count=1)
+        self.model.init_sims()
+        self.index = SimilarityIndex.build_from_doc2vec(self.model, 10)
+
+    def testDocumentIsSimilarToItself(self):
+        vector = self.model.docvecs.doctag_syn0norm[0]
+
+        approx_neighbors = self.index.most_similar(vector, 1)
+        doc, similarity = approx_neighbors[0]
+
+        self.assertEqual(doc, 0)
+        self.assertAlmostEqual(similarity, 1.0, places=2)
+
+    def testApproxNeighborsMatchExact(self):
+        vector = self.model.docvecs.doctag_syn0norm[0]
+        approx_neighbors = self.index.most_similar(vector, 5)
+        exact_neighbors = self.model.docvecs.most_similar(
+            positive=[vector], topn=5)
+
+        approx_docs = [neighbor[0] for neighbor in approx_neighbors]
+        exact_docs = [neighbor[0] for neighbor in exact_neighbors]
+
+        self.assertEqual(approx_docs, exact_docs)
+
+
 if __name__ == '__main__':
     logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
     unittest.main()