[MRG] Lda training visualization in visdom #1399
@@ -0,0 +1,97 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup Visdom\n",
    "\n",
    "Install it with:\n",
    "\n",
    "`pip install visdom`\n",
    "\n",
    "Start the server:\n",
    "\n",
    "`python -m visdom.server`\n",
    "\n",
    "Visdom can now be accessed at http://localhost:8097 in the browser.\n",
    "\n",
    "To monitor the training progress of LDA live in the Visdom browser, set `viz=True` in the LDA function call:\n",
    "\n",
    "`model = ldamodel.LdaModel(corpus=corpus, id2word=dictionary, passes=50, viz=True)`\n",
    "\n",
    "Once the model is training, you can open http://localhost:8097 to follow the progress.\n",
    "\n",
    "\n",
    "## LDA Training Visualization\n",
    "\n",
    "Four types of graphs are plotted for LDA:\n",
    "\n",
    "1. **Coherence**\n",
    "\n",
    "   Coherence is a measure used to evaluate topic models. A good model generates coherent topics, i.e., topics with high topic coherence scores. Good topics are topics that can be described by a short label based on their top terms.\n",
    "\n",
    "   <img src=\"Coherence.gif\">\n",
    "\n",
    "   This graph, along with the others explained below, can be used to decide when to stop training: once the value stops changing after some epochs, the model has reached the highest coherence it can achieve.\n",
    "\n",
    "\n",
    "2. **Perplexity**\n",
Reviewer comment: The numbers will not be rendered correctly, please remove them
"\n", | ||
" Perplexity is a measurement of how well a probability distribution or probability model predicts a sample. In LDA, topics are described by a probability distribution over vocabulary words. So, perplexity can be used to compare probabilistic models like LDA.\n", | ||
" \n", | ||
" <img src=\"Perplexity.gif\">\n", | ||
"\n", | ||
" For a good model, perplexity should be as low as possible.\n", | ||
" \n", | ||
" \n", | ||
"3. **Topic Difference**\n", | ||
"\n", | ||
" Topic Diff calculates the distance between two LDA models. This distance is calculated based on the topics, by either using their probability distribution over vocabulary words (kullback_leibler, hellinger) or by simply using the common vocabulary words between the topics from both model.\n", | ||
"\n", | ||
" <img src=\"Diff.gif\">\n", | ||
" \n", | ||
" In the heatmap, X-axis define the Epoch no. and Y-axis define the distance between the identical topic from consecutive epochs. For ex. a particular cell in the heatmap with values (x=3, y=5, z=0.4) represent the distance(=0.4) between the topic 5 from 3rd epoch and topic 5 from 2nd epoch. With increasing epochs, the distance between the identical topics should decrease.\n", | ||
" \n", | ||
" \n", | ||
"4. **Convergence**\n", | ||
"\n", | ||
" Convergence is the sum of the difference between all the identical topics from two consecutive epochs. It is basically the sum of column values in the heatmap above.\n", | ||
"\n", | ||
" <img src=\"Convergence.gif\">\n", | ||
"\n", | ||
" The model is said to be converged when the convergence value stops descending with increasing epochs." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.4.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
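
As a quick orientation, here is a minimal end-to-end sketch of the workflow the notebook describes. It assumes the branch from this PR is installed and a Visdom server is running locally; the toy corpus is illustrative only:

```python
from gensim.corpora import Dictionary
from gensim.models import ldamodel

documents = [['human', 'interface', 'computer'],
             ['survey', 'user', 'computer', 'system', 'response', 'time'],
             ['eps', 'user', 'interface', 'system'],
             ['system', 'human', 'system', 'eps'],
             ['user', 'response', 'time']]
dictionary = Dictionary(documents)
corpus = [dictionary.doc2bow(doc) for doc in documents]

# With `python -m visdom.server` running, the coherence, perplexity,
# topic-diff and convergence plots update live at http://localhost:8097
# as the passes complete.
model = ldamodel.LdaModel(corpus=corpus, id2word=dictionary, num_topics=5,
                          passes=50, viz=True)
```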
gensim/models/ldamodel.py
@@ -35,11 +35,14 @@
import numbers
from random import sample
import os
import gensim
import copy

from gensim import interfaces, utils, matutils
from gensim.matutils import dirichlet_expectation
from gensim.models import basemodel
from gensim.matutils import kullback_leibler, hellinger, jaccard_distance
from visdom import Visdom

from itertools import chain
from scipy.special import gammaln, psi  # gamma function utils
@@ -192,10 +195,11 @@ class LdaModel(interfaces.TransformationABC, basemodel.BaseTopicModel):
    """
    def __init__(self, corpus=None, num_topics=100, id2word=None,
                 distributed=False, chunksize=2000, passes=1, update_every=1,
                 alpha='symmetric', eta=None, decay=0.5, offset=1.0,
                 eval_every=10, iterations=50, gamma_threshold=0.001,
                 minimum_probability=0.01, random_state=None, ns_conf={},
                 minimum_phi_value=0.01, per_word_topics=False):
                 alpha='symmetric', eta=None, decay=0.5, offset=1.0, eval_every=10,
                 iterations=50, gamma_threshold=0.001, minimum_probability=0.01,
                 random_state=None, ns_conf={}, minimum_phi_value=0.01,
                 per_word_topics=False, viz=False, env=None, distance="kullback_leibler",
                 coherence="u_mass", texts=None, window_size=None, topn=10):

Reviewer comment: Move all parameters to
Reviewer comment: and don't forget about args validation (write tests)
Author reply: Validation of every arg's object type? Or of the valid values of 'diff_distance' and 'coherence' (though there are already tests for those in the respective Diff and CoherenceModel functions)?
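
A minimal sketch of the kind of args-validation test the reviewer asks for. The test class and name are hypothetical, and it assumes `LdaModel` would raise a `ValueError` for an unknown `distance` value, which the PR does not yet implement:

```python
import unittest

from gensim.corpora import Dictionary
from gensim.models import ldamodel


class TestVizArgs(unittest.TestCase):  # hypothetical test class
    def setUp(self):
        texts = [['human', 'interface', 'computer'],
                 ['survey', 'user', 'computer', 'system']]
        self.dictionary = Dictionary(texts)
        self.corpus = [self.dictionary.doc2bow(t) for t in texts]

    def test_invalid_distance_rejected(self):
        # assumes explicit validation of `distance` in __init__
        with self.assertRaises(ValueError):
            ldamodel.LdaModel(corpus=self.corpus, id2word=self.dictionary,
                              num_topics=2, viz=True,
                              distance='no_such_distance')


if __name__ == '__main__':
    unittest.main()
```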
@@ -239,6 +243,31 @@ def __init__(self, corpus=None, num_topics=100, id2word=None,

        `random_state` can be a np.random.RandomState object or the seed for one

        `viz` set to True to visualize the LDA training statistics live in Visdom

        `env` the Visdom environment to use in the browser

        `distance` the distance measure to be used for the Diff plot visualization

Reviewer comment: Edit docstring in accordance with comment about

        `coherence` the coherence measure to be used for the Coherence plot visualization

        `texts` : tokenized texts, needed if sliding-window-based coherence measures ('c_v', 'c_uci', 'c_npmi') are chosen for visualization, eg::
            texts = [['system', 'human', 'system', 'eps'],
                     ['user', 'response', 'time'],
                     ['trees'],
                     ['graph', 'trees'],
                     ['graph', 'minors', 'trees'],
                     ['graph', 'minors', 'survey']]

        `window_size` : the size of the window to be used for coherence measures using a boolean sliding window as their
            probability estimator. For 'u_mass' this doesn't matter.
            If left 'None', the default window sizes are used, which are:
                'c_v' : 110
                'c_uci' : 10
                'c_npmi' : 10

        `topn` : integer corresponding to the number of top words to be extracted from each topic for coherence logging.

        Example:

        >>> lda = LdaModel(corpus, num_topics=100)  # train model
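
For reference, the `coherence`, `texts`, `window_size` and `topn` arguments are forwarded to gensim's `CoherenceModel`; a standalone sketch of that call, reusing the toy corpus from the docstring above:

```python
from gensim.corpora import Dictionary
from gensim.models import CoherenceModel, LdaModel

texts = [['system', 'human', 'system', 'eps'],
         ['user', 'response', 'time'],
         ['trees'],
         ['graph', 'trees'],
         ['graph', 'minors', 'trees'],
         ['graph', 'minors', 'survey']]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]

lda = LdaModel(corpus=corpus, id2word=dictionary, num_topics=2, passes=10)

# 'c_v' is a sliding-window-based measure, so `texts` is required;
# window_size=None falls back to the measure's default (110 for 'c_v').
cm = CoherenceModel(model=lda, texts=texts, dictionary=dictionary,
                    coherence='c_v', window_size=None, topn=10)
print(cm.get_coherence())
```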
@@ -280,6 +309,14 @@ def __init__(self, corpus=None, num_topics=100, id2word=None,
        self.eval_every = eval_every
        self.minimum_phi_value = minimum_phi_value
        self.per_word_topics = per_word_topics
        self.viz = viz
        # set unconditionally, so update() can fall back to these defaults
        # even when the model was constructed with viz=False
        self.env = env
        self.distance = distance
        self.texts = texts
        self.coherence = coherence
        self.window_size = window_size
        self.topn = topn

        self.alpha, self.optimize_alpha = self.init_dir_prior(alpha, 'alpha')
@@ -529,9 +566,9 @@ def log_perplexity(self, chunk, total_docs=None):
                    (perwordbound, np.exp2(-perwordbound), len(chunk), corpus_words))
        return perwordbound

    def update(self, corpus, chunksize=None, decay=None, offset=None,
               passes=None, update_every=None, eval_every=None, iterations=None,
               gamma_threshold=None, chunks_as_numpy=False):
    def update(self, corpus, chunksize=None, decay=None, offset=None, passes=None, update_every=None,
               eval_every=None, iterations=None, gamma_threshold=None, chunks_as_numpy=False,
               viz=None, env=None, distance=None, coherence=None, texts=None, window_size=None, topn=None):

Reviewer comment: Please add

        """
        Train the model with new documents, by EM-iterating over `corpus` until
        the topics converge (or until the maximum number of allowed iterations
@@ -593,6 +630,22 @@ def update(self, corpus, chunksize=None, decay=None, offset=None,

        self.state.numdocs += lencorpus

        if viz is None:
            viz = self.viz
        if viz:
            if env is None:
                env = self.env
            if distance is None:
                distance = self.distance
            if coherence is None:
                coherence = self.coherence
            if texts is None:
                texts = self.texts
            if window_size is None:
                window_size = self.window_size
            if topn is None:
                topn = self.topn

        if update_every:
            updatetype = "online"
            if passes == 1:
@@ -626,6 +679,11 @@ def update(self, corpus, chunksize=None, decay=None, offset=None,
        def rho():
            return pow(offset + pass_ + (self.num_updates / chunksize), -decay)

        if viz:
            viz_window = Visdom()
            # save the initial random state of the model, for the Diff calculation against the first epoch
            previous = copy.deepcopy(self)

        for pass_ in xrange(passes):
            if self.dispatcher:
                logger.info('initializing %s workers' % self.numworkers)
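
For context, a tiny standalone Visdom sketch using the same client calls this PR relies on; it assumes a server is running at the default http://localhost:8097 (note that later visdom releases replaced `updateTrace` with `line(..., update='append')`):

```python
import numpy as np
from visdom import Visdom

viz = Visdom()  # connects to http://localhost:8097 by default

# a line plot window; the returned handle is reused to update the window later
win = viz.line(Y=np.array([0.5]), X=np.array([0]),
               opts=dict(xlabel='Epochs', ylabel='Coherence', title='Coherence (u_mass)'))

# a heatmap window, topics on the Y-axis and epochs on the X-axis
viz.heatmap(X=np.random.rand(5, 3),
            opts=dict(xlabel='Epochs', ylabel='Topic', title='Diff (kullback_leibler)'))
```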
@@ -674,9 +732,42 @@ def rho():
                    other = LdaState(self.eta, self.state.sstats.shape)
                    dirty = False
            # endfor single corpus iteration

            if reallen != lencorpus:
                raise RuntimeError("input corpus size changed during training (don't use generators as input)")

            if viz:
                # calculate coherence

Reviewer comment: move this block into a separate function (with the stat calc)

                cm = gensim.models.CoherenceModel(model=self, corpus=corpus, texts=texts, coherence=coherence, window_size=window_size, topn=topn)
                Coherence = np.array([cm.get_coherence()])

Reviewer comment: the variable name should be in lowercase (here and everywhere)

                # calculate perplexity
                corpus_words = sum(cnt for document in corpus for _, cnt in document)
                perwordbound = self.bound(corpus) / corpus_words
                Perplexity = np.array([np.exp2(-perwordbound)])

                # calculate diff
                diff_matrix = self.diff(previous, distance=distance)[0]
                diff_diagonal = np.diagonal(diff_matrix)
                previous = copy.deepcopy(self)
                Convergence = np.array([np.sum(diff_diagonal)])

                if pass_ == 0:
                    # initial plot windows
                    Diff_mat = np.array([diff_diagonal])
                    viz_coherence = viz_window.line(Y=Coherence, X=np.array([pass_]), env=env, opts=dict(xlabel='Epochs', ylabel='Coherence', title='Coherence (%s)' % coherence))
                    viz_perplexity = viz_window.line(Y=Perplexity, X=np.array([pass_]), env=env, opts=dict(xlabel='Epochs', ylabel='Perplexity', title='Perplexity'))
                    viz_convergence = viz_window.line(Y=Convergence, X=np.array([pass_]), env=env, opts=dict(xlabel='Epochs', ylabel='Convergence', title='Convergence (%s)' % distance))
                    viz_diff = viz_window.heatmap(X=np.array(Diff_mat).T, env=env, opts=dict(xlabel='Epochs', ylabel='Topic', title='Diff (%s)' % distance))
                else:
                    # update the plots with each epoch
                    Diff_mat = np.concatenate((Diff_mat, np.array([diff_diagonal])))
                    viz_window.updateTrace(Y=Coherence, X=np.array([pass_]), env=env, win=viz_coherence)
                    viz_window.updateTrace(Y=Perplexity, X=np.array([pass_]), env=env, win=viz_perplexity)
                    viz_window.updateTrace(Y=Convergence, X=np.array([pass_]), env=env, win=viz_convergence)
                    viz_window.heatmap(X=np.array(Diff_mat).T, env=env, win=viz_diff, opts=dict(xlabel='Epochs', ylabel='Topic', title='Diff (%s)' % distance))

            if dirty:
                # finish any remaining updates
                if self.dispatcher:
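
A sketch of how the statistics block above might be pulled out into a helper, as the reviewer suggests; the function name and signature are hypothetical, but the body mirrors the PR's calculations:

```python
import numpy as np

import gensim


def compute_viz_stats(model, previous, corpus, texts, coherence, window_size, topn, distance):
    """Hypothetical helper returning the per-epoch stats plotted above."""
    # coherence of the current model state
    cm = gensim.models.CoherenceModel(model=model, corpus=corpus, texts=texts,
                                      coherence=coherence,
                                      window_size=window_size, topn=topn)
    coherence_score = cm.get_coherence()

    # perplexity = 2 ** (-bound / total token count)
    corpus_words = sum(cnt for document in corpus for _, cnt in document)
    perplexity = np.exp2(-model.bound(corpus) / corpus_words)

    # distance of each topic from its counterpart in the previous epoch;
    # convergence is the sum of those per-topic distances
    diff_diagonal = np.diagonal(model.diff(previous, distance=distance)[0])
    convergence = np.sum(diff_diagonal)

    return coherence_score, perplexity, diff_diagonal, convergence
```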
@@ -686,6 +777,7 @@ def rho():
                self.do_mstep(rho(), other, pass_ > 0)
                del other
                dirty = False

        # endfor entire corpus update

    def do_mstep(self, rho, other, extra_pass=False):
Reviewer comment: Please add an example with logger="shell" in the notebook (and show the logging output in the notebook)