diff --git a/gensim/sklearn_integration/__init__.py b/gensim/sklearn_integration/__init__.py
index d351e625fc..908ce47380 100644
--- a/gensim/sklearn_integration/__init__.py
+++ b/gensim/sklearn_integration/__init__.py
@@ -15,3 +15,4 @@
 from .sklearn_wrapper_gensim_lsimodel import SklLsiModel  # noqa: F401
 from .sklearn_wrapper_gensim_rpmodel import SklRpModel  # noqa: F401
 from .sklearn_wrapper_gensim_ldaseqmodel import SklLdaSeqModel  # noqa: F401
+from .sklearn_wrapper_gensim_atmodel import SklATModel  # noqa: F401
diff --git a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py
new file mode 100644
index 0000000000..2ef6db3c7a
--- /dev/null
+++ b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py
@@ -0,0 +1,118 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2011 Radim Rehurek
+# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
+
+"""
+Scikit-learn interface for gensim, for easy use of gensim with scikit-learn.
+Follows scikit-learn API conventions.
+"""
+import numpy as np
+from sklearn.base import TransformerMixin, BaseEstimator
+from sklearn.exceptions import NotFittedError
+
+from gensim import models
+from gensim.sklearn_integration import BaseSklearnWrapper
+
+
+class SklATModel(BaseSklearnWrapper, TransformerMixin, BaseEstimator):
+    """
+    Base AuthorTopic module.
+    """
+
+    def __init__(self, num_topics=100, id2word=None, author2doc=None, doc2author=None,
+            chunksize=2000, passes=1, iterations=50, decay=0.5, offset=1.0,
+            alpha='symmetric', eta='symmetric', update_every=1, eval_every=10,
+            gamma_threshold=0.001, serialized=False, serialization_path=None,
+            minimum_probability=0.01, random_state=None):
+        """
+        Sklearn wrapper for the AuthorTopic model. Class derived from gensim.models.AuthorTopicModel.
+        """
+        self.gensim_model = None
+        self.num_topics = num_topics
+        self.id2word = id2word
+        self.author2doc = author2doc
+        self.doc2author = doc2author
+        self.chunksize = chunksize
+        self.passes = passes
+        self.iterations = iterations
+        self.decay = decay
+        self.offset = offset
+        self.alpha = alpha
+        self.eta = eta
+        self.update_every = update_every
+        self.eval_every = eval_every
+        self.gamma_threshold = gamma_threshold
+        self.serialized = serialized
+        self.serialization_path = serialization_path
+        self.minimum_probability = minimum_probability
+        self.random_state = random_state
+
+    def get_params(self, deep=True):
+        """
+        Return all parameters as a dictionary.
+        """
+        return {"num_topics": self.num_topics, "id2word": self.id2word,
+            "author2doc": self.author2doc, "doc2author": self.doc2author, "chunksize": self.chunksize,
+            "passes": self.passes, "iterations": self.iterations, "decay": self.decay,
+            "offset": self.offset, "alpha": self.alpha, "eta": self.eta, "update_every": self.update_every,
+            "eval_every": self.eval_every, "gamma_threshold": self.gamma_threshold,
+            "serialized": self.serialized, "serialization_path": self.serialization_path,
+            "minimum_probability": self.minimum_probability, "random_state": self.random_state}
+
+    def set_params(self, **parameters):
+        """
+        Set all parameters.
+        """
+        super(SklATModel, self).set_params(**parameters)
+        return self
+
+    def fit(self, X, y=None):
+        """
+        Fit the model according to the given training data.
+        Calls gensim.models.AuthorTopicModel.
+        """
+        self.gensim_model = models.AuthorTopicModel(corpus=X, num_topics=self.num_topics, id2word=self.id2word,
+            author2doc=self.author2doc, doc2author=self.doc2author, chunksize=self.chunksize, passes=self.passes,
+            iterations=self.iterations, decay=self.decay, offset=self.offset, alpha=self.alpha, eta=self.eta,
+            update_every=self.update_every, eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized,
+            serialization_path=self.serialization_path, minimum_probability=self.minimum_probability, random_state=self.random_state)
+        return self
+
+    def transform(self, author_names):
+        """
+        Return the topic distribution for the input authors, as a dense
+        numpy array of shape (num_authors, num_topics).
+        """
+        # The input is an array of arrays
+        if self.gensim_model is None:
+            raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")
+
+        check = lambda x: [x] if not isinstance(x, list) else x
+        author_names = check(author_names)
+        X = [[] for _ in range(0, len(author_names))]
+
+        for k, v in enumerate(author_names):
+            transformed_author = self.gensim_model[v]
+            # keep only the topic probabilities, dropping the topic ids
+            probs_author = list(map(lambda x: x[1], transformed_author))
+            # everything should be equal in length; pad topics the model
+            # did not report with a negligible probability
+            if len(probs_author) != self.num_topics:
+                probs_author.extend([1e-12] * (self.num_topics - len(probs_author)))
+            X[k] = probs_author
+
+        return np.reshape(np.array(X), (len(author_names), self.num_topics))
+
+    def partial_fit(self, X, author2doc=None, doc2author=None):
+        """
+        Train model over X.
+        """
+        if self.gensim_model is None:
+            self.gensim_model = models.AuthorTopicModel(corpus=X, num_topics=self.num_topics, id2word=self.id2word,
+                author2doc=self.author2doc, doc2author=self.doc2author, chunksize=self.chunksize, passes=self.passes,
+                iterations=self.iterations, decay=self.decay, offset=self.offset, alpha=self.alpha, eta=self.eta,
+                update_every=self.update_every, eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized,
+                serialization_path=self.serialization_path, minimum_probability=self.minimum_probability, random_state=self.random_state)
+
+        self.gensim_model.update(corpus=X, author2doc=author2doc, doc2author=doc2author)
+        return self
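For orientation before the test changes, here is a minimal, self-contained usage sketch of the new wrapper. This is not part of the patch: the toy corpus and author mapping below are hypothetical, and SklATModel is imported via the __init__.py change above.

    from gensim.corpora import Dictionary
    from gensim.sklearn_integration import SklATModel

    # hypothetical toy corpus: three tiny documents plus an author-to-document mapping
    texts = [['human', 'interface', 'computer'], ['graph', 'trees'], ['graph', 'minors', 'survey']]
    dictionary = Dictionary(texts)
    corpus = [dictionary.doc2bow(text) for text in texts]
    author2doc = {'john': [0, 1], 'jane': [1, 2]}

    model = SklATModel(id2word=dictionary, author2doc=author2doc, num_topics=2, passes=10)
    model.fit(corpus)

    # transform() accepts a single author name or a list of names and returns
    # a dense (n_authors, num_topics) numpy array of topic probabilities
    print(model.transform('jane').shape)            # (1, 2)
    print(model.transform(['john', 'jane']).shape)  # (2, 2)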
diff --git a/gensim/test/test_sklearn_integration.py b/gensim/test/test_sklearn_integration.py
index f18fae044f..09e65dcc8b 100644
--- a/gensim/test/test_sklearn_integration.py
+++ b/gensim/test/test_sklearn_integration.py
@@ -10,7 +10,7 @@
     from sklearn.pipeline import Pipeline
     from sklearn.feature_extraction.text import CountVectorizer
     from sklearn.datasets import load_files
-    from sklearn import linear_model
+    from sklearn import linear_model, cluster
     from sklearn.exceptions import NotFittedError
 except ImportError:
     raise unittest.SkipTest("Test requires scikit-learn to be installed, which is not available")
@@ -19,6 +19,7 @@
 from gensim.sklearn_integration.sklearn_wrapper_gensim_ldamodel import SklLdaModel
 from gensim.sklearn_integration.sklearn_wrapper_gensim_lsimodel import SklLsiModel
 from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaseqmodel import SklLdaSeqModel
+from gensim.sklearn_integration.sklearn_wrapper_gensim_atmodel import SklATModel
 from gensim.corpora import mmcorpus, Dictionary
 from gensim import matutils
 
@@ -39,6 +40,12 @@
 ]
 dictionary = Dictionary(texts)
 corpus = [dictionary.doc2bow(text) for text in texts]
+author2doc = {'john': [0, 1, 2, 3, 4, 5, 6], 'jane': [2, 3, 4, 5, 6, 7, 8], 'jack': [0, 2, 4, 6, 8], 'jill': [1, 3, 5, 7]}
+
+texts_new = texts[0:3]
+author2doc_new = {'jill': [0], 'bob': [0, 1], 'sally': [1, 2]}
+dictionary_new = Dictionary(texts_new)
+corpus_new = [dictionary_new.doc2bow(text) for text in texts_new]
 
 texts_ldaseq = [
     [u'senior', u'studios', u'studios', u'studios', u'creators', u'award', u'mobile', u'currently', u'challenges', u'senior', u'summary', u'senior', u'motivated', u'creative', u'senior'],
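The fixtures just added (texts_new, author2doc_new, corpus_new) feed the online-learning test further down. A sketch of that flow, reusing the fixture names from this hunk and assuming `model` is an SklATModel already fitted on `corpus` (note the test builds corpus_new against a separate Dictionary, so token ids are aligned only by construction of the toy data):

    # continue training on new documents and authors; partial_fit() also
    # creates the underlying AuthorTopicModel on first use if fit() was never called
    model.partial_fit(corpus_new, author2doc=author2doc_new)

    # 'sally' appears only in author2doc_new, so any topic mass she has
    # must come from the incremental update
    sally_topics = model.transform('sally')[0]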
@@ -396,5 +403,60 @@
         self.assertRaises(NotFittedError, rpmodel_wrapper.transform, doc)
 
 
+class TestSklATModelWrapper(unittest.TestCase):
+    def setUp(self):
+        self.model = SklATModel(id2word=dictionary, author2doc=author2doc, num_topics=2, passes=100)
+        self.model.fit(corpus)
+
+    def testTransform(self):
+        # transforming multiple authors
+        author_list = ['jill', 'jack']
+        author_topics = self.model.transform(author_list)
+        self.assertEqual(author_topics.shape[0], 2)
+        self.assertEqual(author_topics.shape[1], self.model.num_topics)
+
+        # transforming one author
+        jill_topics = self.model.transform('jill')
+        self.assertEqual(jill_topics.shape[0], 1)
+        self.assertEqual(jill_topics.shape[1], self.model.num_topics)
+
+    def testPartialFit(self):
+        self.model.partial_fit(corpus_new, author2doc=author2doc_new)
+
+        # did we learn something about Sally?
+        output_topics = self.model.transform('sally')
+        sally_topics = output_topics[0]  # topics corresponding to 'sally' (from the list of lists)
+        self.assertTrue(all(sally_topics > 0))
+
+    def testSetGetParams(self):
+        # updating only one param
+        self.model.set_params(num_topics=3)
+        model_params = self.model.get_params()
+        self.assertEqual(model_params["num_topics"], 3)
+
+        # updating multiple params
+        param_dict = {"passes": 5, "iterations": 10}
+        self.model.set_params(**param_dict)
+        model_params = self.model.get_params()
+        for key in param_dict.keys():
+            self.assertEqual(model_params[key], param_dict[key])
+
+    def testPipeline(self):
+        # train the AuthorTopic model first
+        model = SklATModel(id2word=dictionary, author2doc=author2doc, num_topics=10, passes=100)
+        model.fit(corpus)
+
+        # create and train the clustering model
+        clstr = cluster.MiniBatchKMeans(n_clusters=2)
+        authors_full = ['john', 'jane', 'jack', 'jill']
+        clstr.fit(model.transform(authors_full))
+
+        # stack the two models together in a pipeline
+        text_atm = Pipeline([('features', model), ('cluster', clstr)])
+        author_list = ['jane', 'jack', 'jill']
+        ret_val = text_atm.predict(author_list)
+        self.assertEqual(len(ret_val), len(author_list))
+
+
 if __name__ == '__main__':
     unittest.main()
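As testPipeline above illustrates, the wrapper composes with downstream scikit-learn estimators once fitted. A sketch under the same assumptions, with the module-level dictionary, corpus and author2doc test fixtures in scope:

    from sklearn import cluster
    from sklearn.pipeline import Pipeline

    atm = SklATModel(id2word=dictionary, author2doc=author2doc, num_topics=10, passes=100)
    atm.fit(corpus)

    # cluster authors by their dense topic distributions
    clstr = cluster.MiniBatchKMeans(n_clusters=2)
    clstr.fit(atm.transform(['john', 'jane', 'jack', 'jill']))

    # chain the two fitted steps; predict() runs transform() on the author
    # names, then assigns each author to a cluster
    pipe = Pipeline([('features', atm), ('cluster', clstr)])
    labels = pipe.predict(['jane', 'jack', 'jill'])  # one cluster id per author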