diff --git a/gensim/sklearn_integration/sklearn_wrapper_gensim_d2vmodel.py b/gensim/sklearn_integration/sklearn_wrapper_gensim_d2vmodel.py new file mode 100644 index 0000000000..ec7e8c4a39 --- /dev/null +++ b/gensim/sklearn_integration/sklearn_wrapper_gensim_d2vmodel.py @@ -0,0 +1,93 @@ +from gensim.models.doc2vec import TaggedDocument, Doc2Vec +from gensim.models.word2vec import MAX_WORDS_IN_BATCH +from numpy import vstack +from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.utils.validation import check_is_fitted + + +class SklD2VModel(BaseEstimator, TransformerMixin): + + _PARAMS = [ + 'size', 'alpha', 'window', 'min_count', + 'max_vocab_size', 'sample', 'seed', 'workers', + 'min_alpha', 'hs', 'negative', 'hashfxn', 'iter', + 'trim_rule', 'sorted_vocab', 'batch_words', 'compute_loss', + 'dm_mean', 'dm', 'dbow_words', 'dm_concat', + 'dm_tag_count' + ] + + def __init__(self, size=100, alpha=0.025, window=5, min_count=5, + max_vocab_size=None, sample=1e-3, seed=1, workers=3, + min_alpha=0.0001, hs=0, negative=5, hashfxn=hash, iter=5, + trim_rule=None, sorted_vocab=1, + batch_words=MAX_WORDS_IN_BATCH, compute_loss=False, + dm_mean=None, dm=1, dbow_words=0, dm_concat=0, + dm_tag_count=1): + for param in self._PARAMS: + setattr(self, param, locals()[param]) + + def fit(self, X, y=None): + """ + Train the model while manually decreasing the learning rate + see https://rare-technologies.com/doc2vec-tutorial/#training + for details + + :param X: an iterable of gensim.models.doc2vec.TaggedDocument + :return: a trained SklD2VModel instance + """ + # initialize model with doc2vec parameters + params = {param: getattr(self, param, None) for param in self._PARAMS} + self.gensim_model_ = Doc2Vec(**params) + # learn the vocabulary from X + self.gensim_model_.build_vocab(X) + corpus_count = self.gensim_model_.corpus_count + # train the model while manually controlling alpha + alpha_step = (self.alpha - self.min_alpha) / self.iter + for i in range(self.iter): + self.gensim_model_.train(X, total_examples=corpus_count, epochs=1) + self.gensim_model_.alpha -= alpha_step + self.gensim_model_.min_alpha = self.gensim_model_.alpha + return self + + def transform(self, X): + """ + Transform TaggedDocument to their doc2vec representation + + :param X: an iterable of gensim.models.doc2vec.TaggedDocument + :return: an array of shape (n_samples, n_features) + """ + check_is_fitted(self, 'gensim_model_') + return vstack([ + self.gensim_model_.infer_vector( + x.words, self.alpha, self.min_alpha, self.iter) + for x in X + ]) + + def fit_transform(self, X, y=None, **fit_params): + """ + Train the model on TaggedDocument and + then convert the latter to their doc2vec representation + + :param X: an iterable of gensim.models.doc2vec.TaggedDocument + :return: an array of shape (n_samples, n_features) + """ + self.fit(X) + return self.gensim_model_.docvecs[range(len(X))] + + +class TaggedDocumentTransformer(BaseEstimator, TransformerMixin): + + def fit(self, X=None, y=None): + return self + + def transform(self, X): + """ + Take a list of tokenized documents, and tag them with their index + + :param X: an iterable of tokenized documents (i.e. list of strings) + :return: an list of gensim.models.doc2vec.TaggedDocumentof, shape (n_samples) + """ + return [ + TaggedDocument(words=list(x), tags=[i]) + for i, x in enumerate(X) + ] diff --git a/gensim/test/test_sklearn_integration.py b/gensim/test/test_sklearn_integration.py index 6cbfe902a3..07b5a1c689 100644 --- a/gensim/test/test_sklearn_integration.py +++ b/gensim/test/test_sklearn_integration.py @@ -20,6 +20,7 @@ from gensim.sklearn_integration.sklearn_wrapper_gensim_lsimodel import SklLsiModel from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaseqmodel import SklLdaSeqModel from gensim.sklearn_integration.sklearn_wrapper_gensim_w2vmodel import SklW2VModel +from gensim.sklearn_integration.sklearn_wrapper_gensim_d2vmodel import SklD2VModel, TaggedDocumentTransformer from gensim.sklearn_integration.sklearn_wrapper_gensim_atmodel import SklATModel from gensim.corpora import mmcorpus, Dictionary from gensim import matutils @@ -482,6 +483,74 @@ def testModelNotFitted(self): self.assertRaises(NotFittedError, w2vmodel_wrapper.transform, word) +class TestSklD2VModelWrapper(unittest.TestCase): + def setUp(self): + numpy.random.seed(0) + self.model = SklD2VModel(size=10, min_count=0, seed=42) + self.tagged_texts = TaggedDocumentTransformer().transform(texts) + self.model.fit(self.tagged_texts) + + def testTransform(self): + # transform one sentence + sentence = self.tagged_texts[0] + matrix = self.model.transform([sentence]) + self.assertEqual(matrix.shape[0], 1) + self.assertEqual(matrix.shape[1], self.model.size) + # transform multiple sentences + matrix = self.model.transform(self.tagged_texts) + self.assertEqual(matrix.shape[0], len(texts)) + self.assertEqual(matrix.shape[1], self.model.size) + + def testPipeline(self): + # create the pipeline + tagger = TaggedDocumentTransformer() + model = SklD2VModel(size=10, min_count=1) + clf = linear_model.LogisticRegression(penalty='l2', C=0.1) + text_d2v = Pipeline([('tags', tagger), ('features', model), ('classifier', clf)]) + + # give the test w2v_texts target labels and train + # 0 = physics, 1 = mathematics + w2v_targets = [1, 1, 1, 1, 1, 0, 0, 0, 0] + text_d2v.fit(w2v_texts, w2v_targets) + + # make sure the pipeline learned something useful + test_texts = [ + 'natural', 'nuclear', 'science', 'electromagnetism', + 'calculus', 'mathematical', 'geometry', 'operations', 'curves' + ] + test_targets = [ + 0, 0, 0, 0, + 1, 1, 1, 1, 1 + ] + score = text_d2v.score(test_texts, test_targets) + self.assertGreater(score, 0.5) + + def testPersistence(self): + model_dump = pickle.dumps(self.model) + model_load = pickle.loads(model_dump) + + sentence = self.tagged_texts[0] + loaded_transformed_vecs = model_load.transform([sentence]) + + # sanity check for transformation operation + self.assertEqual(loaded_transformed_vecs.shape[0], 1) + self.assertEqual(loaded_transformed_vecs.shape[1], model_load.size) + + # comparing the original and loaded models + original_transformed_vecs = self.model.transform([sentence]) + passed = numpy.allclose( + sorted(loaded_transformed_vecs), + sorted(original_transformed_vecs), + atol=1e-1) + self.assertTrue(passed) + + def testModelNotFitted(self): + d2vmodel_wrapper = SklD2VModel(size=10, min_count=0, seed=42) + sentence = self.tagged_texts[0] + with self.assertRaises(NotFittedError): + d2vmodel_wrapper.transform([sentence]) + + class TestSklATModelWrapper(unittest.TestCase): def setUp(self): self.model = SklATModel(id2word=dictionary, author2doc=author2doc, num_topics=2, passes=100)