-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update sklearn API for Gensim models (#1473)
* renamed sklearn wrapper classes * added newline for flake8 check * renamed sklearn api files * updated tests for sklearn api * updated ipynb for sklearn api * PEP8 changes * updated docstrings for sklearn wrappers * added 'testPersistence' and 'testModelNotFitted' tests for author topic model * removed 'set_params' function from all wrappers * removed 'get_params' function from base class * removed 'get_params' function from all api classes * removed 'partial_fit()' from base class * updated error message * updated error message for 'partial_fit' function in W2VTransformer * removed 'BaseTransformer' class * updated error message for 'partial_fit' in 'W2VTransformer' * added checks for setting attributes after calling 'fit' * flake8 fix * using 'sparse2full' in 'transform' function * added missing imports * added comment about returning dense representation in 'transform' function * added 'testConsistencyWithGensimModel' for ldamodel * updated ipynb * updated 'testPartialFit' for Lda and Lsi transformers * added author info * added 'testConsistencyWithGensimModel' for w2v transformer
- Loading branch information
1 parent
9c43ef5
commit 718b1c6
Showing
11 changed files
with
288 additions
and
366 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,13 +19,13 @@ | |
"metadata": {}, | ||
"source": [ | ||
"The wrappers available (as of now) are :\n", | ||
"* LdaModel (```gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel.SklLdaModel```),which implements gensim's ```LDA Model``` in a scikit-learn interface\n", | ||
"* LdaModel (```gensim.sklearn_api.ldamodel.LdaTransformer```),which implements gensim's ```LDA Model``` in a scikit-learn interface\n", | ||
"\n", | ||
"* LsiModel (```gensim.sklearn_integration.sklearn_wrapper_gensim_lsiModel.SklLsiModel```),which implements gensim's ```LSI Model``` in a scikit-learn interface\n", | ||
"* LsiModel (```gensim.sklearn_api.lsimodel.LsiTransformer```),which implements gensim's ```LSI Model``` in a scikit-learn interface\n", | ||
"\n", | ||
"* RpModel (```gensim.sklearn_integration.sklearn_wrapper_gensim_rpmodel.SklRpModel```),which implements gensim's ```Random Projections Model``` in a scikit-learn interface\n", | ||
"* RpModel (```gensim.sklearn_api.rpmodel.RpTransformer```),which implements gensim's ```Random Projections Model``` in a scikit-learn interface\n", | ||
"\n", | ||
"* LDASeq Model (```gensim.sklearn_integration.sklearn_wrapper_gensim_lsiModel.SklLdaSeqModel```),which implements gensim's ```LdaSeqModel``` in a scikit-learn interface" | ||
"* LDASeq Model (```gensim.sklearn_api.ldaseqmodel.LdaSeqTransformer```),which implements gensim's ```LdaSeqModel``` in a scikit-learn interface" | ||
] | ||
}, | ||
{ | ||
|
@@ -56,7 +56,7 @@ | |
} | ||
], | ||
"source": [ | ||
"from gensim.sklearn_integration import SklLdaModel" | ||
"from gensim.sklearn_api import LdaTransformer" | ||
] | ||
}, | ||
{ | ||
|
@@ -105,15 +105,15 @@ | |
{ | ||
"data": { | ||
"text/plain": [ | ||
"array([[ 0.85275314, 0.14724686],\n", | ||
" [ 0.12390183, 0.87609817],\n", | ||
" [ 0.4612995 , 0.5387005 ],\n", | ||
" [ 0.84924177, 0.15075823],\n", | ||
"array([[ 0.85275316, 0.14724687],\n", | ||
" [ 0.12390183, 0.87609816],\n", | ||
" [ 0.46129951, 0.53870052],\n", | ||
" [ 0.84924179, 0.15075824],\n", | ||
" [ 0.49180096, 0.50819904],\n", | ||
" [ 0.40086923, 0.59913077],\n", | ||
" [ 0.28454427, 0.71545573],\n", | ||
" [ 0.88776198, 0.11223802],\n", | ||
" [ 0.84210373, 0.15789627]])" | ||
" [ 0.40086922, 0.59913075],\n", | ||
" [ 0.28454426, 0.71545571],\n", | ||
" [ 0.88776201, 0.11223802],\n", | ||
" [ 0.84210372, 0.15789627]], dtype=float32)" | ||
] | ||
}, | ||
"execution_count": 3, | ||
|
@@ -122,7 +122,7 @@ | |
} | ||
], | ||
"source": [ | ||
"model = SklLdaModel(num_topics=2, id2word=dictionary, iterations=20, random_state=1)\n", | ||
"model = LdaTransformer(num_topics=2, id2word=dictionary, iterations=20, random_state=1)\n", | ||
"model.fit(corpus)\n", | ||
"model.transform(corpus)" | ||
] | ||
|
@@ -145,7 +145,7 @@ | |
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"execution_count": 5, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
|
@@ -155,22 +155,19 @@ | |
"from gensim import matutils\n", | ||
"from gensim.models.ldamodel import LdaModel\n", | ||
"from sklearn.datasets import fetch_20newsgroups\n", | ||
"from gensim.sklearn_integration.sklearn_wrapper_gensim_ldamodel import SklLdaModel" | ||
"from gensim.sklearn_api.ldamodel import LdaTransformer" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"execution_count": 6, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"rand = np.random.mtrand.RandomState(1) # set seed for getting same result\n", | ||
"cats = ['alt.atheism',\n", | ||
" 'comp.graphics',\n", | ||
" 'rec.autos'\n", | ||
" ]\n", | ||
"cats = ['rec.sport.baseball', 'sci.crypt']\n", | ||
"data = fetch_20newsgroups(subset='train', categories=cats, shuffle=True)" | ||
] | ||
}, | ||
|
@@ -183,7 +180,7 @@ | |
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"execution_count": 7, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
|
@@ -203,13 +200,13 @@ | |
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"execution_count": 8, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"obj = SklLdaModel(id2word=id2word, num_topics=5, iterations=20)\n", | ||
"obj = LdaTransformer(id2word=id2word, num_topics=5, iterations=20)\n", | ||
"lda = obj.fit(corpus)" | ||
] | ||
}, | ||
|
@@ -224,7 +221,7 @@ | |
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"execution_count": 9, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
|
@@ -242,7 +239,7 @@ | |
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
|
@@ -251,13 +248,13 @@ | |
"{'iterations': 20, 'num_topics': 2}" | ||
] | ||
}, | ||
"execution_count": 9, | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"obj = SklLdaModel(id2word=id2word, num_topics=2, iterations=5, scorer='u_mass') # here 'scorer' can be 'perplexity' or 'u_mass'\n", | ||
"obj = LdaTransformer(id2word=id2word, num_topics=2, iterations=5, scorer='u_mass') # here 'scorer' can be 'perplexity' or 'u_mass'\n", | ||
"parameters = {'num_topics': (2, 3, 5, 10), 'iterations': (1, 20, 50)}\n", | ||
"\n", | ||
"# set `scoring` as `None` to use the inbuilt score function of `SklLdaModel` class\n", | ||
|
@@ -276,16 +273,16 @@ | |
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'iterations': 50, 'num_topics': 2}" | ||
"{'iterations': 20, 'num_topics': 2}" | ||
] | ||
}, | ||
"execution_count": 10, | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
|
@@ -298,7 +295,7 @@ | |
" goodcm = CoherenceModel(model=estimator.gensim_model, texts=data_texts, dictionary=estimator.gensim_model.id2word, coherence='c_v')\n", | ||
" return goodcm.get_coherence()\n", | ||
"\n", | ||
"obj = SklLdaModel(id2word=id2word, num_topics=5, iterations=5)\n", | ||
"obj = LdaTransformer(id2word=id2word, num_topics=5, iterations=5)\n", | ||
"parameters = {'num_topics': (2, 3, 5, 10), 'iterations': (1, 20, 50)}\n", | ||
"\n", | ||
"# set `scoring` as your custom scoring function\n", | ||
|
@@ -317,7 +314,7 @@ | |
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"execution_count": 12, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
|
@@ -336,24 +333,36 @@ | |
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"execution_count": 13, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"id2word = Dictionary([_.split() for _ in data.data])\n", | ||
"corpus = [id2word.doc2bow(i.split()) for i in data.data]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[-0.70383516 -1.02111207 -0.70109648 -0.4570351 0.02175549 -0.20545727\n", | ||
" -0.01517654 0.0717219 0.51160112 0.49580676 0.33125423 0.31986992\n", | ||
" 0.48162159 0.26829541 0.11292571]\n", | ||
"Positive features: hanging:0.51 localized:0.50 course...:0.48 LAST:0.33 wax):0.32 Stoakley):0.27 Signature!:0.11 circuitry:0.07 technique),:0.02\n", | ||
"Negative features: considered.:-1.02 al-Qanawi,:-0.70 alt.autos.karting:-0.70 considered,:-0.46 360.0;:-0.21 talk.origins:-0.02\n", | ||
"0.437876960193\n" | ||
"[ 0.3032212 0.53114732 -0.3556002 0.05528797 -0.23462074 0.10164825\n", | ||
" -0.34895972 -0.07528751 -0.31437197 -0.24760965 -0.27430636 -0.05328458\n", | ||
" 0.1792989 -0.11535102 0.98473296]\n", | ||
"Positive features: >Pat:0.98 considered,:0.53 Fame.:0.30 internet...:0.18 comp.org.eff.talk.:0.10 Keach:0.06\n", | ||
"Negative features: Fame,:-0.36 01101001B:-0.35 circuitry:-0.31 hanging:-0.27 [email protected]:-0.25 comp.org.eff.talk,:-0.23 dome.:-0.12 *best*:-0.08 trawling:-0.05\n", | ||
"0.648489932886\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"model = SklLdaModel(num_topics=15, id2word=id2word, iterations=10, random_state=37)\n", | ||
"model = LdaTransformer(num_topics=15, id2word=id2word, iterations=10, random_state=37)\n", | ||
"clf = linear_model.LogisticRegression(penalty='l2', C=0.1) # l2 penalty used\n", | ||
"pipe = Pipeline((('features', model,), ('classifier', clf)))\n", | ||
"pipe.fit(corpus, data.target)\n", | ||
|
@@ -377,13 +386,13 @@ | |
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"execution_count": 15, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from gensim.sklearn_integration import SklLsiModel" | ||
"from gensim.sklearn_api import LsiTransformer" | ||
] | ||
}, | ||
{ | ||
|
@@ -395,24 +404,24 @@ | |
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[ 0.00929522 -0.68199243 0.00292464 -0.22145158 0.58612298 0.16492053\n", | ||
" 0.00354379 0.37629786 -0.17202791 -0.03847397 -0.05041424 -0.08953721\n", | ||
" 0.21241931 -0.04824542 0.2287388 ]\n", | ||
"Positive features: technique),:0.59 circuitry:0.38 Signature!:0.23 course...:0.21 360.0;:0.16 al-Qanawi,:0.01 talk.origins:0.00 alt.autos.karting:0.00\n", | ||
"Negative features: considered.:-0.68 considered,:-0.22 hanging:-0.17 wax):-0.09 LAST:-0.05 Stoakley):-0.05 localized:-0.04\n", | ||
"0.683353437877\n" | ||
"[ 0.13653967 -0.00378269 0.02652037 0.08496786 -0.02401959 -0.60089273\n", | ||
" -1.0708177 -0.03932274 -0.43813039 -0.54848409 -0.20147759 0.21781259\n", | ||
" 1.30378972 -0.08678691 -0.17529122]\n", | ||
"Positive features: internet...:1.30 trawling:0.22 Fame.:0.14 Keach:0.08 Fame,:0.03\n", | ||
"Negative features: 01101001B:-1.07 comp.org.eff.talk.:-0.60 [email protected]:-0.55 circuitry:-0.44 hanging:-0.20 >Pat:-0.18 dome.:-0.09 *best*:-0.04 comp.org.eff.talk,:-0.02 considered,:-0.00\n", | ||
"0.865771812081\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"model = SklLsiModel(num_topics=15, id2word=id2word)\n", | ||
"model = LsiTransformer(num_topics=15, id2word=id2word)\n", | ||
"clf = linear_model.LogisticRegression(penalty='l2', C=0.1) # l2 penalty used\n", | ||
"pipe = Pipeline((('features', model,), ('classifier', clf)))\n", | ||
"pipe.fit(corpus, data.target)\n", | ||
|
@@ -436,13 +445,13 @@ | |
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"execution_count": 17, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from gensim.sklearn_integration import SklRpModel" | ||
"from gensim.sklearn_api import RpTransformer" | ||
] | ||
}, | ||
{ | ||
|
@@ -454,22 +463,22 @@ | |
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"execution_count": 18, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[-0.00474168 0.01863391]\n", | ||
"Positive features: considered.:0.02\n", | ||
"Negative features: al-Qanawi,:-0.00\n", | ||
"0.434861278649\n" | ||
"[-0.03186506 -0.00872616]\n", | ||
"Positive features: \n", | ||
"Negative features: Fame.:-0.03 considered,:-0.01\n", | ||
"0.621644295302\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"model = SklRpModel(num_topics=2)\n", | ||
"model = RpTransformer(num_topics=2)\n", | ||
"np.random.mtrand.RandomState(1) # set seed for getting same result\n", | ||
"clf = linear_model.LogisticRegression(penalty='l2', C=0.1) # l2 penalty used\n", | ||
"pipe = Pipeline((('features', model,), ('classifier', clf)))\n", | ||
|
@@ -494,13 +503,13 @@ | |
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 17, | ||
"execution_count": 19, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from gensim.sklearn_integration import SklLdaSeqModel" | ||
"from gensim.sklearn_api import LdaSeqTransformer" | ||
] | ||
}, | ||
{ | ||
|
@@ -512,7 +521,7 @@ | |
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 18, | ||
"execution_count": 20, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
|
@@ -527,9 +536,9 @@ | |
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[ 0.04877123 -0.04877123]\n", | ||
"Positive features: In-Reply-To::0.05\n", | ||
"Negative features: nicknames:-0.05\n", | ||
"[-0.04877324 0.04877324]\n", | ||
"Positive features: NLCS:0.05\n", | ||
"Negative features: What:-0.05\n", | ||
"1.0\n" | ||
] | ||
} | ||
|
@@ -540,7 +549,7 @@ | |
"id2word = Dictionary(map(lambda x: x.split(), test_data))\n", | ||
"corpus = [id2word.doc2bow(i.split()) for i in test_data]\n", | ||
"\n", | ||
"model = SklLdaSeqModel(id2word=id2word, num_topics=2, time_slice=[1, 1, 1], initialize='gensim')\n", | ||
"model = LdaSeqTransformer(id2word=id2word, num_topics=2, time_slice=[1, 1, 1], initialize='gensim')\n", | ||
"clf = linear_model.LogisticRegression(penalty='l2', C=0.1) # l2 penalty used\n", | ||
"pipe = Pipeline((('features', model,), ('classifier', clf)))\n", | ||
"pipe.fit(corpus, test_target)\n", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Author: Chinmaya Pancholi <[email protected]> | ||
# Copyright (C) 2017 Radim Rehurek <[email protected]> | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
"""Scikit learn wrapper for gensim. | ||
Contains various gensim based implementations which match with scikit-learn standards. | ||
See [1] for complete set of conventions. | ||
[1] http://scikit-learn.org/stable/developers/ | ||
""" | ||
|
||
|
||
from .ldamodel import LdaTransformer # noqa: F401 | ||
from .lsimodel import LsiTransformer # noqa: F401 | ||
from .rpmodel import RpTransformer # noqa: F401 | ||
from .ldaseqmodel import LdaSeqTransformer # noqa: F401 | ||
from .w2vmodel import W2VTransformer # noqa: F401 | ||
from .atmodel import AuthorTopicTransformer # noqa: F401 |
Oops, something went wrong.