Skip to content

Commit

Permalink
Add sklearn wrapper for AuthorTopic model (#1403)
Browse files Browse the repository at this point in the history
* added sklearn wrapper for author topic model

* added 'fit', 'partial_fit', 'transform' functions

* added unit-tests for ATModel

* refactored code acc. to composite design pattern

* refactored wrapper and tests

* removed 'self.corpus' attribute

* updates 'self.model' to 'self.gensim_model'

* updated 'fit' and 'transform' functions

* updated 'testTransform' test

* updated 'tranform' function slightly

* updated 'testTransform' test

* PEP8 change

* updated 'testTransform' test

* included 'NotFittedError' error

* added pipeline unittest for ATModel
  • Loading branch information
chinmayapancholi13 authored and menshikh-iv committed Jun 28, 2017
1 parent 39820cf commit 7ccaabc
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 1 deletion.
1 change: 1 addition & 0 deletions gensim/sklearn_integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
from .sklearn_wrapper_gensim_lsimodel import SklLsiModel # noqa: F401
from .sklearn_wrapper_gensim_rpmodel import SklRpModel # noqa: F401
from .sklearn_wrapper_gensim_ldaseqmodel import SklLdaSeqModel # noqa: F401
from .sklearn_wrapper_gensim_atmodel import SklATModel # noqa: F401
118 changes: 118 additions & 0 deletions gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <[email protected]>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Scikit learn interface for gensim for easy use of gensim with scikit-learn
Follows scikit-learn API conventions
"""
import numpy as np
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.exceptions import NotFittedError

from gensim import models
from gensim.sklearn_integration import BaseSklearnWrapper


class SklATModel(BaseSklearnWrapper, TransformerMixin, BaseEstimator):
"""
Base AuthorTopic module
"""

def __init__(self, num_topics=100, id2word=None, author2doc=None, doc2author=None,
chunksize=2000, passes=1, iterations=50, decay=0.5, offset=1.0,
alpha='symmetric', eta='symmetric', update_every=1, eval_every=10,
gamma_threshold=0.001, serialized=False, serialization_path=None,
minimum_probability=0.01, random_state=None):
"""
Sklearn wrapper for AuthorTopic model. Class derived from gensim.models.AuthorTopicModel
"""
self.gensim_model = None
self.num_topics = num_topics
self.id2word = id2word
self.author2doc = author2doc
self.doc2author = doc2author
self.chunksize = chunksize
self.passes = passes
self.iterations = iterations
self.decay = decay
self.offset = offset
self.alpha = alpha
self.eta = eta
self.update_every = update_every
self.eval_every = eval_every
self.gamma_threshold = gamma_threshold
self.serialized = serialized
self.serialization_path = serialization_path
self.minimum_probability = minimum_probability
self.random_state = random_state

def get_params(self, deep=True):
"""
Returns all parameters as dictionary.
"""
return {"num_topics": self.num_topics, "id2word": self.id2word,
"author2doc": self.author2doc, "doc2author": self.doc2author, "chunksize": self.chunksize,
"passes": self.passes, "iterations": self.iterations, "decay": self.decay,
"offset": self.offset, "alpha": self.alpha, "eta": self.eta, "update_every": self.update_every,
"eval_every": self.eval_every, "gamma_threshold": self.gamma_threshold,
"serialized": self.serialized, "serialization_path": self.serialization_path,
"minimum_probability": self.minimum_probability, "random_state": self.random_state}

def set_params(self, **parameters):
"""
Set all parameters.
"""
super(SklATModel, self).set_params(**parameters)
return self

def fit(self, X, y=None):
"""
Fit the model according to the given training data.
Calls gensim.models.AuthorTopicModel
"""
self.gensim_model = models.AuthorTopicModel(corpus=X, num_topics=self.num_topics, id2word=self.id2word,
author2doc=self.author2doc, doc2author=self.doc2author, chunksize=self.chunksize, passes=self.passes,
iterations=self.iterations, decay=self.decay, offset=self.offset, alpha=self.alpha, eta=self.eta,
update_every=self.update_every, eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized,
serialization_path=self.serialization_path, minimum_probability=self.minimum_probability, random_state=self.random_state)
return self

def transform(self, author_names):
"""
Return topic distribution for input authors as a list of
(topic_id, topic_probabiity) 2-tuples.
"""
# The input as array of array
if self.gensim_model is None:
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")

check = lambda x: [x] if not isinstance(x, list) else x
author_names = check(author_names)
X = [[] for _ in range(0, len(author_names))]

for k, v in enumerate(author_names):
transformed_author = self.gensim_model[v]
probs_author = list(map(lambda x: x[1], transformed_author))
# Everything should be equal in length
if len(probs_author) != self.num_topics:
probs_author.extend([1e-12] * (self.num_topics - len(probs_author)))
X[k] = probs_author

return np.reshape(np.array(X), (len(author_names), self.num_topics))

def partial_fit(self, X, author2doc=None, doc2author=None):
"""
Train model over X.
"""
if self.gensim_model is None:
self.gensim_model = models.AuthorTopicModel(corpus=X, num_topics=self.num_topics, id2word=self.id2word,
author2doc=self.author2doc, doc2author=self.doc2author, chunksize=self.chunksize, passes=self.passes,
iterations=self.iterations, decay=self.decay, offset=self.offset, alpha=self.alpha, eta=self.eta,
update_every=self.update_every, eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized,
serialization_path=self.serialization_path, minimum_probability=self.minimum_probability, random_state=self.random_state)

self.gensim_model.update(corpus=X, author2doc=author2doc, doc2author=doc2author)
return self
64 changes: 63 additions & 1 deletion gensim/test/test_sklearn_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.datasets import load_files
from sklearn import linear_model
from sklearn import linear_model, cluster
from sklearn.exceptions import NotFittedError
except ImportError:
raise unittest.SkipTest("Test requires scikit-learn to be installed, which is not available")
Expand All @@ -19,6 +19,7 @@
from gensim.sklearn_integration.sklearn_wrapper_gensim_ldamodel import SklLdaModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_lsimodel import SklLsiModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaseqmodel import SklLdaSeqModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_atmodel import SklATModel
from gensim.corpora import mmcorpus, Dictionary
from gensim import matutils

Expand All @@ -39,6 +40,12 @@
]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]
author2doc = {'john': [0, 1, 2, 3, 4, 5, 6], 'jane': [2, 3, 4, 5, 6, 7, 8], 'jack': [0, 2, 4, 6, 8], 'jill': [1, 3, 5, 7]}

texts_new = texts[0:3]
author2doc_new = {'jill': [0], 'bob': [0, 1], 'sally': [1, 2]}
dictionary_new = Dictionary(texts_new)
corpus_new = [dictionary_new.doc2bow(text) for text in texts_new]

texts_ldaseq = [
[u'senior', u'studios', u'studios', u'studios', u'creators', u'award', u'mobile', u'currently', u'challenges', u'senior', u'summary', u'senior', u'motivated', u'creative', u'senior'],
Expand Down Expand Up @@ -396,5 +403,60 @@ def testModelNotFitted(self):
self.assertRaises(NotFittedError, rpmodel_wrapper.transform, doc)


class TestSklATModelWrapper(unittest.TestCase):
def setUp(self):
self.model = SklATModel(id2word=dictionary, author2doc=author2doc, num_topics=2, passes=100)
self.model.fit(corpus)

def testTransform(self):
# transforming multiple authors
author_list = ['jill', 'jack']
author_topics = self.model.transform(author_list)
self.assertEqual(author_topics.shape[0], 2)
self.assertEqual(author_topics.shape[1], self.model.num_topics)

# transforming one author
jill_topics = self.model.transform('jill')
self.assertEqual(jill_topics.shape[0], 1)
self.assertEqual(jill_topics.shape[1], self.model.num_topics)

def testPartialFit(self):
self.model.partial_fit(corpus_new, author2doc=author2doc_new)

# Did we learn something about Sally?
output_topics = self.model.transform('sally')
sally_topics = output_topics[0] # getting the topics corresponding to 'sally' (from the list of lists)
self.assertTrue(all(sally_topics > 0))

def testSetGetParams(self):
# updating only one param
self.model.set_params(num_topics=3)
model_params = self.model.get_params()
self.assertEqual(model_params["num_topics"], 3)

# updating multiple params
param_dict = {"passes": 5, "iterations": 10}
self.model.set_params(**param_dict)
model_params = self.model.get_params()
for key in param_dict.keys():
self.assertEqual(model_params[key], param_dict[key])

def testPipeline(self):
# train the AuthorTopic model first
model = SklATModel(id2word=dictionary, author2doc=author2doc, num_topics=10, passes=100)
model.fit(corpus)

# create and train clustering model
clstr = cluster.MiniBatchKMeans(n_clusters=2)
authors_full = ['john', 'jane', 'jack', 'jill']
clstr.fit(model.transform(authors_full))

# stack together the two models in a pipeline
text_atm = Pipeline((('features', model,), ('cluster', clstr)))
author_list = ['jane', 'jack', 'jill']
ret_val = text_atm.predict(author_list)
self.assertEqual(len(ret_val), len(author_list))


if __name__ == '__main__':
unittest.main()

0 comments on commit 7ccaabc

Please sign in to comment.