Skip to content

Commit

Permalink
Add support for using Annoy as an external similarity index w/ Doc2Vec
Browse files: browse the repository at this point in the history
  • Loading branch information
karlhigley committed Jun 22, 2016
1 parent a5dab21 commit 2816523
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
7 changes: 7 additions & 0 deletions gensim/similarities/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ def build_from_word2vec(cls, model, num_trees):
model.init_sims()
return cls._build_from_model(model.syn0norm, model.index2word, model.vector_size, num_trees)

@classmethod
def build_from_doc2vec(cls, model, num_trees):
    """Build an index over the document vectors held by ``model``.

    Args:
        model: object exposing ``docvecs`` and ``vector_size``
            (presumably a gensim Doc2Vec model -- confirm with callers).
        num_trees: passed through unchanged to ``_build_from_model``.

    Returns:
        Whatever ``cls._build_from_model`` returns for the document
        vectors, their doctags, the vector size and ``num_trees``.
    """
    doc_vectors = model.docvecs
    # Must run before ``doctag_syn0norm`` is read below (presumably this
    # is what fills in the normalized vectors -- confirm in docvecs).
    doc_vectors.init_sims()

    # One label per stored document, in index order.
    doctags = list(map(doc_vectors.index_to_doctag, range(doc_vectors.count)))

    return cls._build_from_model(
        doc_vectors.doctag_syn0norm,
        doctags,
        model.vector_size,
        num_trees,
    )

@classmethod
def _build_from_model(cls, vectors, labels, num_features, num_trees):
index = AnnoyIndex(num_features)
Expand Down
39 changes: 39 additions & 0 deletions gensim/test/test_similarities.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from gensim.corpora import mmcorpus, Dictionary
from gensim.models import word2vec
from gensim.models import doc2vec
from gensim import matutils, utils, similarities
from gensim.models import Word2Vec

Expand All @@ -45,6 +46,9 @@
# Shared fixtures, all derived from the module-level ``texts``.
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(tokens) for tokens in texts]

# One TaggedDocument per text, tagged with its position in ``texts``.
sentences = [
    doc2vec.TaggedDocument(tokens, [position])
    for position, tokens in enumerate(texts)
]


def testfile():
# temporary data will be stored to this file
Expand Down Expand Up @@ -455,6 +459,41 @@ def testApproxNeighborsMatchExact(self):
self.assertEqual(approx_words, exact_words)


class TestDoc2VecSimilarityIndex(unittest.TestCase):
    """Tests for the similarity index built from a Doc2Vec model."""

    def setUp(self):
        """Build a Doc2Vec model from ``sentences`` and index its doc vectors.

        Raises:
            unittest.SkipTest: if the optional ``annoy`` package cannot be
                imported.
        """
        try:
            import annoy  # noqa: F401 -- imported only to probe availability
        except ImportError:
            raise unittest.SkipTest("Annoy library is not available")

        # Imported only after the probe above succeeded (presumably because
        # this module itself needs annoy -- confirm in gensim.similarities).
        from gensim.similarities.index import SimilarityIndex

        self.model = doc2vec.Doc2Vec(sentences, min_count=1)
        self.model.init_sims()
        self.index = SimilarityIndex.build_from_doc2vec(self.model, 10)

    def testDocumentIsSimilarToItself(self):
        """The nearest neighbor of document 0's vector is document 0."""
        vector = self.model.docvecs.doctag_syn0norm[0]

        approx_neighbors = self.index.most_similar(vector, 1)
        doc, similarity = approx_neighbors[0]

        self.assertEqual(doc, 0)
        # The similarity is a floating-point value coming from an
        # approximate index, so compare with a tolerance rather than
        # requiring a bit-exact 1.0.
        self.assertAlmostEqual(similarity, 1.0, places=2)

    def testApproxNeighborsMatchExact(self):
        """Index neighbors come back in the same order as the model's own."""
        vector = self.model.docvecs.doctag_syn0norm[0]
        approx_neighbors = self.index.most_similar(vector, 5)
        exact_neighbors = self.model.docvecs.most_similar(
            positive=[vector], topn=5)

        # Compare only the labels (first tuple element); scores may differ.
        approx_tags = [neighbor[0] for neighbor in approx_neighbors]
        exact_tags = [neighbor[0] for neighbor in exact_neighbors]

        self.assertEqual(approx_tags, exact_tags)


# Script entry point: enable verbose logging, then run every test case.
if __name__ == '__main__':
    logging.basicConfig(
        format='%(asctime)s : %(levelname)s : %(message)s',
        level=logging.DEBUG,
    )
    unittest.main()

0 comments on commit 2816523

Please sign in to comment.