Skip to content

Commit

Permalink
Add an sklearn wrapper for the Doc2Vec model
Browse files Browse the repository at this point in the history
 * TaggedDocumentTransformer transforms tokenized documents to TaggedDocuments
 * SklD2VModel transforms TaggedDocuments to doc2vec vectors
  • Loading branch information
oxymor0n committed Jul 12, 2017
1 parent 3e38e33 commit 9ddcf97
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 0 deletions.
93 changes: 93 additions & 0 deletions gensim/sklearn_integration/sklearn_wrapper_gensim_d2vmodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from gensim.models.doc2vec import TaggedDocument, Doc2Vec
from gensim.models.word2vec import MAX_WORDS_IN_BATCH
from numpy import vstack
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted


class SklD2VModel(BaseEstimator, TransformerMixin):
    """
    Scikit-learn style wrapper around gensim's Doc2Vec model.

    ``fit`` trains a Doc2Vec model on TaggedDocument objects and
    ``transform`` infers one vector per document with the trained model.
    """

    # constructor arguments that are forwarded verbatim to Doc2Vec in fit()
    _PARAMS = [
        'size', 'alpha', 'window', 'min_count',
        'max_vocab_size', 'sample', 'seed', 'workers',
        'min_alpha', 'hs', 'negative', 'hashfxn', 'iter',
        'trim_rule', 'sorted_vocab', 'batch_words', 'compute_loss',
        'dm_mean', 'dm', 'dbow_words', 'dm_concat',
        'dm_tag_count'
    ]

    def __init__(self, size=100, alpha=0.025, window=5, min_count=5,
                 max_vocab_size=None, sample=1e-3, seed=1, workers=3,
                 min_alpha=0.0001, hs=0, negative=5, hashfxn=hash, iter=5,
                 trim_rule=None, sorted_vocab=1,
                 batch_words=MAX_WORDS_IN_BATCH, compute_loss=False,
                 dm_mean=None, dm=1, dbow_words=0, dm_concat=0,
                 dm_tag_count=1):
        """
        Store the Doc2Vec hyper-parameters unchanged; the gensim model
        itself is only built when ``fit`` is called.
        """
        # one attribute per constructor argument, named identically, with
        # no extra logic (the usual scikit-learn estimator convention)
        self.size = size
        self.alpha = alpha
        self.window = window
        self.min_count = min_count
        self.max_vocab_size = max_vocab_size
        self.sample = sample
        self.seed = seed
        self.workers = workers
        self.min_alpha = min_alpha
        self.hs = hs
        self.negative = negative
        self.hashfxn = hashfxn
        self.iter = iter
        self.trim_rule = trim_rule
        self.sorted_vocab = sorted_vocab
        self.batch_words = batch_words
        self.compute_loss = compute_loss
        self.dm_mean = dm_mean
        self.dm = dm
        self.dbow_words = dbow_words
        self.dm_concat = dm_concat
        self.dm_tag_count = dm_tag_count

    def fit(self, X, y=None):
        """
        Train the model while manually decreasing the learning rate
        see https://rare-technologies.com/doc2vec-tutorial/#training
        for details
        :param X: a re-iterable collection of
            gensim.models.doc2vec.TaggedDocument (it is traversed once to
            build the vocabulary and once more per training epoch)
        :param y: ignored, accepted for scikit-learn API compatibility
        :return: a trained SklD2VModel instance
        """
        # initialize model with doc2vec parameters
        params = {name: getattr(self, name) for name in self._PARAMS}
        model = Doc2Vec(**params)
        # learn the vocabulary from X
        model.build_vocab(X)
        corpus_count = model.corpus_count
        # float() avoids integer truncation when the rates are given as ints
        alpha_step = (self.alpha - self.min_alpha) / float(self.iter)
        # A single train() call decays the rate from `alpha` down to
        # `min_alpha`.  Pin both to the same value before the first epoch so
        # every epoch, including the first, runs at one constant rate and
        # the decrease happens only through the manual step below.
        model.min_alpha = model.alpha
        for _ in range(self.iter):
            model.train(X, total_examples=corpus_count, epochs=1)
            model.alpha -= alpha_step
            model.min_alpha = model.alpha
        # publish the model only once training succeeded, so that a failed
        # fit() does not leave a half-trained model that passes
        # check_is_fitted
        self.gensim_model_ = model
        return self

    def transform(self, X):
        """
        Transform TaggedDocument to their doc2vec representation
        :param X: an iterable of gensim.models.doc2vec.TaggedDocument
        :return: an array of shape (n_samples, n_features)
        """
        check_is_fitted(self, 'gensim_model_')
        # only the words of each document are used; vectors are inferred
        # with the estimator's own alpha / min_alpha / iter settings
        return vstack([
            self.gensim_model_.infer_vector(
                doc.words, self.alpha, self.min_alpha, self.iter)
            for doc in X
        ])

    def fit_transform(self, X, y=None, **fit_params):
        """
        Train the model on TaggedDocument and
        then convert the latter to their doc2vec representation

        The returned rows are the document vectors stored on the trained
        model under the keys 0 .. len(X) - 1 (not freshly inferred ones),
        so X must support ``len`` and its documents are expected to be
        tagged with their position, as TaggedDocumentTransformer does.
        :param X: a sized collection of gensim.models.doc2vec.TaggedDocument
        :param y: ignored, accepted for scikit-learn API compatibility
        :param fit_params: ignored, ``fit`` takes no extra parameters
        :return: an array of shape (n_samples, n_features)
        """
        self.fit(X, y)
        return self.gensim_model_.docvecs[range(len(X))]


class TaggedDocumentTransformer(BaseEstimator, TransformerMixin):
    """
    Stateless transformer that wraps tokenized documents into
    gensim TaggedDocument objects.
    """

    def fit(self, X=None, y=None):
        """
        Nothing is learned from the data.
        :return: this transformer, unchanged
        """
        return self

    def transform(self, X):
        """
        Tag every tokenized document with its position in X
        :param X: an iterable of tokenized documents (i.e. list of strings)
        :return: a list of gensim.models.doc2vec.TaggedDocument, shape (n_samples)
        """
        tagged_documents = []
        for position, tokens in enumerate(X):
            # the tokens are copied into a fresh list; the position is the
            # document's only tag
            document = TaggedDocument(words=list(tokens), tags=[position])
            tagged_documents.append(document)
        return tagged_documents
69 changes: 69 additions & 0 deletions gensim/test/test_sklearn_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from gensim.sklearn_integration.sklearn_wrapper_gensim_lsimodel import SklLsiModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaseqmodel import SklLdaSeqModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_w2vmodel import SklW2VModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_d2vmodel import SklD2VModel, TaggedDocumentTransformer
from gensim.sklearn_integration.sklearn_wrapper_gensim_atmodel import SklATModel
from gensim.corpora import mmcorpus, Dictionary
from gensim import matutils
Expand Down Expand Up @@ -482,6 +483,74 @@ def testModelNotFitted(self):
self.assertRaises(NotFittedError, w2vmodel_wrapper.transform, word)


class TestSklD2VModelWrapper(unittest.TestCase):
    """Tests for the SklD2VModel wrapper and TaggedDocumentTransformer."""

    def setUp(self):
        # seed numpy and the model before fitting on the shared test corpus
        numpy.random.seed(0)
        self.model = SklD2VModel(size=10, min_count=0, seed=42)
        tagger = TaggedDocumentTransformer()
        self.tagged_texts = tagger.transform(texts)
        self.model.fit(self.tagged_texts)

    def testTransform(self):
        # first a single sentence, then the whole corpus
        cases = [
            ([self.tagged_texts[0]], 1),
            (self.tagged_texts, len(texts)),
        ]
        for documents, expected_rows in cases:
            matrix = self.model.transform(documents)
            self.assertEqual(matrix.shape[0], expected_rows)
            self.assertEqual(matrix.shape[1], self.model.size)

    def testPipeline(self):
        # assemble tagger -> doc2vec features -> classifier
        tagger = TaggedDocumentTransformer()
        model = SklD2VModel(size=10, min_count=1)
        clf = linear_model.LogisticRegression(penalty='l2', C=0.1)
        steps = [('tags', tagger), ('features', model), ('classifier', clf)]
        text_d2v = Pipeline(steps)

        # label the training texts and fit the whole pipeline
        # 0 = physics, 1 = mathematics
        w2v_targets = [1, 1, 1, 1, 1, 0, 0, 0, 0]
        text_d2v.fit(w2v_texts, w2v_targets)

        # the pipeline is expected to beat a coin flip on unseen inputs
        # NOTE(review): each entry below is a bare string, and the tagger
        # calls list() on every document, so these look like they get split
        # into single characters rather than treated as one-word documents
        # -- confirm whether [['natural'], ...] was intended.
        test_texts = [
            'natural', 'nuclear', 'science', 'electromagnetism',
            'calculus', 'mathematical', 'geometry', 'operations', 'curves'
        ]
        test_targets = [0, 0, 0, 0, 1, 1, 1, 1, 1]
        score = text_d2v.score(test_texts, test_targets)
        self.assertGreater(score, 0.5)

    def testPersistence(self):
        # round-trip the fitted wrapper through pickle
        restored = pickle.loads(pickle.dumps(self.model))

        documents = [self.tagged_texts[0]]
        restored_vecs = restored.transform(documents)

        # sanity check for transformation operation
        self.assertEqual(restored_vecs.shape[0], 1)
        self.assertEqual(restored_vecs.shape[1], restored.size)

        # the restored model must agree with the original one
        original_vecs = self.model.transform(documents)
        vectors_match = numpy.allclose(
            sorted(restored_vecs),
            sorted(original_vecs),
            atol=1e-1)
        self.assertTrue(vectors_match)

    def testModelNotFitted(self):
        # transform() on a wrapper that was never fitted must raise
        unfitted = SklD2VModel(size=10, min_count=0, seed=42)
        documents = [self.tagged_texts[0]]
        self.assertRaises(NotFittedError, unfitted.transform, documents)


class TestSklATModelWrapper(unittest.TestCase):
def setUp(self):
self.model = SklATModel(id2word=dictionary, author2doc=author2doc, num_topics=2, passes=100)
Expand Down

0 comments on commit 9ddcf97

Please sign in to comment.