Update sklearn API for Gensim models (#1473)
* renamed sklearn wrapper classes

* added newline for flake8 check

* renamed sklearn api files

* updated tests for sklearn api

* updated ipynb for sklearn api

* PEP8 changes

* updated docstrings for sklearn wrappers

* added 'testPersistence' and 'testModelNotFitted' tests for author topic model

* removed 'set_params' function from all wrappers

* removed 'get_params' function from base class

* removed 'get_params' function from all api classes

* removed 'partial_fit()' from base class

* updated error message

* updated error message for 'partial_fit' function in W2VTransformer

* removed 'BaseTransformer' class

* updated error message for 'partial_fit' in 'W2VTransformer'

* added checks for setting attributes after calling 'fit'

* flake8 fix

* using 'sparse2full' in 'transform' function

* added missing imports

* added comment about returning dense representation in 'transform' function

* added 'testConsistencyWithGensimModel' for ldamodel

* updated ipynb

* updated 'testPartialFit' for Lda and Lsi transformers

* added author info

* added 'testConsistencyWithGensimModel' for w2v transformer
chinmayapancholi13 authored and menshikh-iv committed Aug 10, 2017
1 parent 9c43ef5 commit 718b1c6
Showing 11 changed files with 288 additions and 366 deletions.
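In short, the scikit-learn wrappers move out of the `gensim.sklearn_integration` package (classes named `Skl*Model`) into the new flat `gensim.sklearn_api` package (classes named `*Transformer`). A minimal before/after sketch of the rename; the toy corpus is invented here purely for illustration, while the constructor arguments mirror the updated notebook below:

```python
from gensim.corpora import Dictionary
from gensim.sklearn_api import LdaTransformer

# Tiny toy corpus, for illustration only.
texts = [
    ["human", "interface", "computer"],
    ["graph", "trees", "computer"],
    ["graph", "minors", "survey"],
]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]

# Before this commit:
#   from gensim.sklearn_integration.sklearn_wrapper_gensim_ldamodel import SklLdaModel
#   model = SklLdaModel(num_topics=2, id2word=dictionary, iterations=20, random_state=1)
# After this commit:
model = LdaTransformer(num_topics=2, id2word=dictionary, iterations=20, random_state=1)
model.fit(corpus)
docvecs = model.transform(corpus)  # dense array, one row of topic weights per document
```

The per-model modules (`ldamodel`, `lsimodel`, `rpmodel`, `ldaseqmodel`, `w2vmodel`, `atmodel`) are re-exported from `gensim/sklearn_api/__init__.py`, shown at the end of this diff.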
@@ -19,13 +19,13 @@
"metadata": {},
"source": [
"The wrappers available (as of now) are :\n",
"* LdaModel (```gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel.SklLdaModel```),which implements gensim's ```LDA Model``` in a scikit-learn interface\n",
"* LdaModel (```gensim.sklearn_api.ldamodel.LdaTransformer```),which implements gensim's ```LDA Model``` in a scikit-learn interface\n",
"\n",
"* LsiModel (```gensim.sklearn_integration.sklearn_wrapper_gensim_lsiModel.SklLsiModel```),which implements gensim's ```LSI Model``` in a scikit-learn interface\n",
"* LsiModel (```gensim.sklearn_api.lsimodel.LsiTransformer```),which implements gensim's ```LSI Model``` in a scikit-learn interface\n",
"\n",
"* RpModel (```gensim.sklearn_integration.sklearn_wrapper_gensim_rpmodel.SklRpModel```),which implements gensim's ```Random Projections Model``` in a scikit-learn interface\n",
"* RpModel (```gensim.sklearn_api.rpmodel.RpTransformer```),which implements gensim's ```Random Projections Model``` in a scikit-learn interface\n",
"\n",
"* LDASeq Model (```gensim.sklearn_integration.sklearn_wrapper_gensim_lsiModel.SklLdaSeqModel```),which implements gensim's ```LdaSeqModel``` in a scikit-learn interface"
"* LDASeq Model (```gensim.sklearn_api.ldaseqmodel.LdaSeqTransformer```),which implements gensim's ```LdaSeqModel``` in a scikit-learn interface"
]
},
{
@@ -56,7 +56,7 @@
}
],
"source": [
"from gensim.sklearn_integration import SklLdaModel"
"from gensim.sklearn_api import LdaTransformer"
]
},
{
@@ -105,15 +105,15 @@
{
"data": {
"text/plain": [
"array([[ 0.85275314, 0.14724686],\n",
" [ 0.12390183, 0.87609817],\n",
" [ 0.4612995 , 0.5387005 ],\n",
" [ 0.84924177, 0.15075823],\n",
"array([[ 0.85275316, 0.14724687],\n",
" [ 0.12390183, 0.87609816],\n",
" [ 0.46129951, 0.53870052],\n",
" [ 0.84924179, 0.15075824],\n",
" [ 0.49180096, 0.50819904],\n",
" [ 0.40086923, 0.59913077],\n",
" [ 0.28454427, 0.71545573],\n",
" [ 0.88776198, 0.11223802],\n",
" [ 0.84210373, 0.15789627]])"
" [ 0.40086922, 0.59913075],\n",
" [ 0.28454426, 0.71545571],\n",
" [ 0.88776201, 0.11223802],\n",
" [ 0.84210372, 0.15789627]], dtype=float32)"
]
},
"execution_count": 3,
@@ -122,7 +122,7 @@
}
],
"source": [
"model = SklLdaModel(num_topics=2, id2word=dictionary, iterations=20, random_state=1)\n",
"model = LdaTransformer(num_topics=2, id2word=dictionary, iterations=20, random_state=1)\n",
"model.fit(corpus)\n",
"model.transform(corpus)"
]
@@ -145,7 +145,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {
"collapsed": true
},
@@ -155,22 +155,19 @@
"from gensim import matutils\n",
"from gensim.models.ldamodel import LdaModel\n",
"from sklearn.datasets import fetch_20newsgroups\n",
"from gensim.sklearn_integration.sklearn_wrapper_gensim_ldamodel import SklLdaModel"
"from gensim.sklearn_api.ldamodel import LdaTransformer"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"rand = np.random.mtrand.RandomState(1) # set seed for getting same result\n",
"cats = ['alt.atheism',\n",
" 'comp.graphics',\n",
" 'rec.autos'\n",
" ]\n",
"cats = ['rec.sport.baseball', 'sci.crypt']\n",
"data = fetch_20newsgroups(subset='train', categories=cats, shuffle=True)"
]
},
@@ -183,7 +180,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {
"collapsed": true
},
@@ -203,13 +200,13 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"obj = SklLdaModel(id2word=id2word, num_topics=5, iterations=20)\n",
"obj = LdaTransformer(id2word=id2word, num_topics=5, iterations=20)\n",
"lda = obj.fit(corpus)"
]
},
@@ -224,7 +221,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {
"collapsed": true
},
@@ -242,7 +239,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -251,13 +248,13 @@
"{'iterations': 20, 'num_topics': 2}"
]
},
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"obj = SklLdaModel(id2word=id2word, num_topics=2, iterations=5, scorer='u_mass') # here 'scorer' can be 'perplexity' or 'u_mass'\n",
"obj = LdaTransformer(id2word=id2word, num_topics=2, iterations=5, scorer='u_mass') # here 'scorer' can be 'perplexity' or 'u_mass'\n",
"parameters = {'num_topics': (2, 3, 5, 10), 'iterations': (1, 20, 50)}\n",
"\n",
"# set `scoring` as `None` to use the inbuilt score function of `SklLdaModel` class\n",
@@ -276,16 +273,16 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'iterations': 50, 'num_topics': 2}"
"{'iterations': 20, 'num_topics': 2}"
]
},
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@@ -298,7 +295,7 @@
" goodcm = CoherenceModel(model=estimator.gensim_model, texts=data_texts, dictionary=estimator.gensim_model.id2word, coherence='c_v')\n",
" return goodcm.get_coherence()\n",
"\n",
"obj = SklLdaModel(id2word=id2word, num_topics=5, iterations=5)\n",
"obj = LdaTransformer(id2word=id2word, num_topics=5, iterations=5)\n",
"parameters = {'num_topics': (2, 3, 5, 10), 'iterations': (1, 20, 50)}\n",
"\n",
"# set `scoring` as your custom scoring function\n",
@@ -317,7 +314,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {
"collapsed": true
},
@@ -336,24 +333,36 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"id2word = Dictionary([_.split() for _ in data.data])\n",
"corpus = [id2word.doc2bow(i.split()) for i in data.data]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.70383516 -1.02111207 -0.70109648 -0.4570351 0.02175549 -0.20545727\n",
" -0.01517654 0.0717219 0.51160112 0.49580676 0.33125423 0.31986992\n",
" 0.48162159 0.26829541 0.11292571]\n",
"Positive features: hanging:0.51 localized:0.50 course...:0.48 LAST:0.33 wax):0.32 Stoakley):0.27 Signature!:0.11 circuitry:0.07 technique),:0.02\n",
"Negative features: considered.:-1.02 al-Qanawi,:-0.70 alt.autos.karting:-0.70 considered,:-0.46 360.0;:-0.21 talk.origins:-0.02\n",
"0.437876960193\n"
"[ 0.3032212 0.53114732 -0.3556002 0.05528797 -0.23462074 0.10164825\n",
" -0.34895972 -0.07528751 -0.31437197 -0.24760965 -0.27430636 -0.05328458\n",
" 0.1792989 -0.11535102 0.98473296]\n",
"Positive features: >Pat:0.98 considered,:0.53 Fame.:0.30 internet...:0.18 comp.org.eff.talk.:0.10 Keach:0.06\n",
"Negative features: Fame,:-0.36 01101001B:-0.35 circuitry:-0.31 hanging:-0.27 [email protected]:-0.25 comp.org.eff.talk,:-0.23 dome.:-0.12 *best*:-0.08 trawling:-0.05\n",
"0.648489932886\n"
]
}
],
"source": [
"model = SklLdaModel(num_topics=15, id2word=id2word, iterations=10, random_state=37)\n",
"model = LdaTransformer(num_topics=15, id2word=id2word, iterations=10, random_state=37)\n",
"clf = linear_model.LogisticRegression(penalty='l2', C=0.1) # l2 penalty used\n",
"pipe = Pipeline((('features', model,), ('classifier', clf)))\n",
"pipe.fit(corpus, data.target)\n",
@@ -377,13 +386,13 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 15,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from gensim.sklearn_integration import SklLsiModel"
"from gensim.sklearn_api import LsiTransformer"
]
},
{
@@ -395,24 +404,24 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0.00929522 -0.68199243 0.00292464 -0.22145158 0.58612298 0.16492053\n",
" 0.00354379 0.37629786 -0.17202791 -0.03847397 -0.05041424 -0.08953721\n",
" 0.21241931 -0.04824542 0.2287388 ]\n",
"Positive features: technique),:0.59 circuitry:0.38 Signature!:0.23 course...:0.21 360.0;:0.16 al-Qanawi,:0.01 talk.origins:0.00 alt.autos.karting:0.00\n",
"Negative features: considered.:-0.68 considered,:-0.22 hanging:-0.17 wax):-0.09 LAST:-0.05 Stoakley):-0.05 localized:-0.04\n",
"0.683353437877\n"
"[ 0.13653967 -0.00378269 0.02652037 0.08496786 -0.02401959 -0.60089273\n",
" -1.0708177 -0.03932274 -0.43813039 -0.54848409 -0.20147759 0.21781259\n",
" 1.30378972 -0.08678691 -0.17529122]\n",
"Positive features: internet...:1.30 trawling:0.22 Fame.:0.14 Keach:0.08 Fame,:0.03\n",
"Negative features: 01101001B:-1.07 comp.org.eff.talk.:-0.60 [email protected]:-0.55 circuitry:-0.44 hanging:-0.20 >Pat:-0.18 dome.:-0.09 *best*:-0.04 comp.org.eff.talk,:-0.02 considered,:-0.00\n",
"0.865771812081\n"
]
}
],
"source": [
"model = SklLsiModel(num_topics=15, id2word=id2word)\n",
"model = LsiTransformer(num_topics=15, id2word=id2word)\n",
"clf = linear_model.LogisticRegression(penalty='l2', C=0.1) # l2 penalty used\n",
"pipe = Pipeline((('features', model,), ('classifier', clf)))\n",
"pipe.fit(corpus, data.target)\n",
@@ -436,13 +445,13 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from gensim.sklearn_integration import SklRpModel"
"from gensim.sklearn_api import RpTransformer"
]
},
{
@@ -454,22 +463,22 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.00474168 0.01863391]\n",
"Positive features: considered.:0.02\n",
"Negative features: al-Qanawi,:-0.00\n",
"0.434861278649\n"
"[-0.03186506 -0.00872616]\n",
"Positive features: \n",
"Negative features: Fame.:-0.03 considered,:-0.01\n",
"0.621644295302\n"
]
}
],
"source": [
"model = SklRpModel(num_topics=2)\n",
"model = RpTransformer(num_topics=2)\n",
"np.random.mtrand.RandomState(1) # set seed for getting same result\n",
"clf = linear_model.LogisticRegression(penalty='l2', C=0.1) # l2 penalty used\n",
"pipe = Pipeline((('features', model,), ('classifier', clf)))\n",
@@ -494,13 +503,13 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 19,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from gensim.sklearn_integration import SklLdaSeqModel"
"from gensim.sklearn_api import LdaSeqTransformer"
]
},
{
@@ -512,7 +521,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 20,
"metadata": {},
"outputs": [
{
@@ -527,9 +536,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0.04877123 -0.04877123]\n",
"Positive features: In-Reply-To::0.05\n",
"Negative features: nicknames:-0.05\n",
"[-0.04877324 0.04877324]\n",
"Positive features: NLCS:0.05\n",
"Negative features: What:-0.05\n",
"1.0\n"
]
}
@@ -540,7 +549,7 @@
"id2word = Dictionary(map(lambda x: x.split(), test_data))\n",
"corpus = [id2word.doc2bow(i.split()) for i in test_data]\n",
"\n",
"model = SklLdaSeqModel(id2word=id2word, num_topics=2, time_slice=[1, 1, 1], initialize='gensim')\n",
"model = LdaSeqTransformer(id2word=id2word, num_topics=2, time_slice=[1, 1, 1], initialize='gensim')\n",
"clf = linear_model.LogisticRegression(penalty='l2', C=0.1) # l2 penalty used\n",
"pipe = Pipeline((('features', model,), ('classifier', clf)))\n",
"pipe.fit(corpus, test_target)\n",
19 changes: 19 additions & 0 deletions gensim/sklearn_api/__init__.py
@@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Author: Chinmaya Pancholi <[email protected]>
# Copyright (C) 2017 Radim Rehurek <[email protected]>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
"""Scikit learn wrapper for gensim.
Contains various gensim based implementations which match with scikit-learn standards.
See [1] for complete set of conventions.
[1] http://scikit-learn.org/stable/developers/
"""


from .ldamodel import LdaTransformer # noqa: F401
from .lsimodel import LsiTransformer # noqa: F401
from .rpmodel import RpTransformer # noqa: F401
from .ldaseqmodel import LdaSeqTransformer # noqa: F401
from .w2vmodel import W2VTransformer # noqa: F401
from .atmodel import AuthorTopicTransformer # noqa: F401
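
For reference, the pattern exercised throughout the updated notebook is to drop these transformers straight into a scikit-learn `Pipeline`. A condensed, self-contained sketch of that usage; the toy documents and labels are placeholders rather than the 20-newsgroups data the notebook actually uses:

```python
import numpy as np
from gensim.corpora import Dictionary
from gensim.sklearn_api import LsiTransformer
from sklearn import linear_model
from sklearn.pipeline import Pipeline

# Placeholder documents and binary labels, for illustration only.
docs = [
    "the cat sat on the mat",
    "dogs chase cats around the yard",
    "stocks fell sharply in early trading",
    "markets rallied after the earnings news",
]
labels = np.array([0, 0, 1, 1])

id2word = Dictionary(doc.split() for doc in docs)
corpus = [id2word.doc2bow(doc.split()) for doc in docs]

# The notebook exercises LdaTransformer, LsiTransformer, RpTransformer and
# LdaSeqTransformer with this same fit/transform interface.
model = LsiTransformer(num_topics=2, id2word=id2word)
clf = linear_model.LogisticRegression(penalty='l2', C=0.1)  # l2 penalty, as in the notebook
pipe = Pipeline([('features', model), ('classifier', clf)])
pipe.fit(corpus, labels)
print(pipe.score(corpus, labels))
```

With the per-wrapper `get_params`/`set_params` overrides and the old `BaseTransformer` base class removed in this commit, parameter handling for tools like `GridSearchCV` presumably falls back to scikit-learn's standard `BaseEstimator` machinery.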