diff --git a/docs/notebooks/Coherence.gif b/docs/notebooks/Coherence.gif
new file mode 100644
index 0000000000..eb1dd57f98
Binary files /dev/null and b/docs/notebooks/Coherence.gif differ
diff --git a/docs/notebooks/Convergence.gif b/docs/notebooks/Convergence.gif
new file mode 100644
index 0000000000..6b48328981
Binary files /dev/null and b/docs/notebooks/Convergence.gif differ
diff --git a/docs/notebooks/Diff.gif b/docs/notebooks/Diff.gif
new file mode 100644
index 0000000000..bd2d56db26
Binary files /dev/null and b/docs/notebooks/Diff.gif differ
diff --git a/docs/notebooks/Perplexity.gif b/docs/notebooks/Perplexity.gif
new file mode 100644
index 0000000000..b8a96a8d07
Binary files /dev/null and b/docs/notebooks/Perplexity.gif differ
diff --git a/docs/notebooks/Training_visualizations.ipynb b/docs/notebooks/Training_visualizations.ipynb
new file mode 100644
index 0000000000..dfd0d8bb81
--- /dev/null
+++ b/docs/notebooks/Training_visualizations.ipynb
@@ -0,0 +1,378 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Setup Visdom\n",
+ "\n",
+ "Install it with:\n",
+ "\n",
+ "`pip install visdom`\n",
+ "\n",
+ "Start the server:\n",
+ "\n",
+ "`python -m visdom.server`\n",
+ "\n",
+ "Visdom now can be accessed at http://localhost:8097 in the browser.\n",
+ "\n",
+ "\n",
+ "# LDA Training Visualization\n",
+ "\n",
+ "Knowing about the progress and performance of a model, as we train them, could be very helpful in understanding it’s learning process and makes it easier to debug and optimize them. In this notebook, we will learn how to visualize training statistics for LDA topic model in gensim. To monitor the training, a list of Metrics is passed to the LDA function call for plotting their values live as the training progresses. \n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Let's plot the training stats for an LDA model being trained on kaggle's [fake news dataset](https://www.kaggle.com/mrisdal/fake-news). We will use the four evaluation metrics available for topic models in gensim: Coherence, Perplexity, Topic diff and Convergence. (using separate hold_out and test corpus for evaluating the perplexity)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using TensorFlow backend.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from gensim.models import ldamodel\n",
+ "from gensim.corpora import Dictionary\n",
+ "import pandas as pd\n",
+ "import re\n",
+ "from gensim.parsing.preprocessing import remove_stopwords, strip_punctuation\n",
+ "\n",
+ "import numpy as np\n",
+ "\n",
+ "df_fake = pd.read_csv('fake.csv')\n",
+ "df_fake[['title', 'text', 'language']].head()\n",
+ "df_fake = df_fake.loc[(pd.notnull(df_fake.text)) & (df_fake.language == 'english')]\n",
+ "\n",
+ "# remove stopwords and punctuations\n",
+ "def preprocess(row):\n",
+ " return strip_punctuation(remove_stopwords(row.lower()))\n",
+ " \n",
+ "df_fake['text'] = df_fake['text'].apply(preprocess)\n",
+ "\n",
+ "# Convert data to required input format by LDA\n",
+ "texts = []\n",
+ "for line in df_fake.text:\n",
+ " lowered = line.lower()\n",
+ " words = re.findall(r'\\w+', lowered, flags = re.UNICODE | re.LOCALE)\n",
+ " texts.append(words)\n",
+ "\n",
+ "dictionary = Dictionary(texts)\n",
+ "\n",
+ "training_texts = texts[:5000]\n",
+ "holdout_texts = texts[5000:7500]\n",
+ "test_texts = texts[7500:10000]\n",
+ "\n",
+ "training_corpus = [dictionary.doc2bow(text) for text in training_texts]\n",
+ "holdout_corpus = [dictionary.doc2bow(text) for text in holdout_texts]\n",
+ "test_corpus = [dictionary.doc2bow(text) for text in test_texts]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "collapsed": true,
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "from gensim.models.callbacks import CoherenceMetric, DiffMetric, PerplexityMetric, ConvergenceMetric\n",
+ "\n",
+ "# define perplexity callback for hold_out and test corpus\n",
+ "pl_holdout = PerplexityMetric(corpus=holdout_corpus, logger=\"visdom\", title=\"Perplexity (hold_out)\")\n",
+ "pl_test = PerplexityMetric(corpus=test_corpus, logger=\"visdom\", title=\"Perplexity (test)\")\n",
+ "\n",
+ "# define other remaining metrics available\n",
+ "ch_umass = CoherenceMetric(corpus=training_corpus, coherence=\"u_mass\", logger=\"visdom\", title=\"Coherence (u_mass)\")\n",
+ "ch_cv = CoherenceMetric(corpus=training_corpus, texts=training_texts, coherence=\"c_v\", logger=\"visdom\", title=\"Coherence (c_v)\")\n",
+ "diff_kl = DiffMetric(distance=\"kullback_leibler\", logger=\"visdom\", title=\"Diff (kullback_leibler)\")\n",
+ "convergence_kl = ConvergenceMetric(distance=\"jaccard\", logger=\"visdom\", title=\"Convergence (jaccard)\")\n",
+ "\n",
+ "callbacks = [pl_holdout, pl_test, ch_umass, ch_cv, diff_kl, convergence_kl]\n",
+ "\n",
+ "# training LDA model\n",
+ "model = ldamodel.LdaModel(corpus=training_corpus, id2word=dictionary, num_topics=35, passes=50, chunksize=1500, iterations=200, alpha='auto', callbacks=callbacks)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "When the model is set for training, you can open http://localhost:8097 to see the training progress."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-0.259766196856\n"
+ ]
+ }
+ ],
+ "source": [
+ "# to get a metric value on a trained model\n",
+ "print(CoherenceMetric(corpus=training_corpus, coherence=\"u_mass\").get_value(model=model))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The four types of graphs which are plotted for LDA:\n",
+ "\n",
+ "**Coherence**\n",
+ "\n",
+ "Coherence measures are generally based on the idea of computing the sum of pairwise scores of top *n* words w1, ...,wn used to describe the topic. There are four coherence measures available in gensim: `u_mass, c_v, c_uci, c_npmi`. A good model will generate coherent topics, i.e., topics with high topic coherence scores. Good topics can be described by a short label based on the topic terms they spit out. \n",
+ "\n",
+ "\n",
+ "\n",
+ "Now, this graph along with the others explained below, can be used to decide if it's time to stop the training. We can see if the value stops changing after some epochs and that we are able to get the highest possible coherence of our model. \n",
+ "\n",
+ "\n",
+ "**Perplexity**\n",
+ "\n",
+ "Perplexity is a measurement of how well a probability distribution or probability model predicts a sample. In LDA, topics are described by a probability distribution over vocabulary words. So, perplexity can be used to evaluate the topic-term distribution output by LDA.\n",
+ "\n",
+ "\n",
+ "\n",
+ "For a good model, perplexity should be low.\n",
+ "\n",
+ "\n",
+ "**Topic Difference**\n",
+ "\n",
+ "Topic Diff calculates the distance between two LDA models. This distance is calculated based on the topics, by either using their probability distribution over vocabulary words (kullback_leibler, hellinger) or by simply using the common vocabulary words between the topics from both model.\n",
+ "\n",
+ "\n",
+ "\n",
+ "In the heatmap, X-axis define the Epoch no. and Y-axis define the distance between identical topics from consecutive epochs. For ex. a particular cell in the heatmap with values (x=3, y=5, z=0.4) represent the distance(=0.4) between the topic 5 from 3rd epoch and topic 5 from 2nd epoch. With increasing epochs, the distance between the identical topics should decrease.\n",
+ " \n",
+ " \n",
+ "**Convergence**\n",
+ "\n",
+ "Convergence is the sum of the difference between all the identical topics from two consecutive epochs. It is basically the sum of column values in the heatmap above.\n",
+ "\n",
+ "\n",
+ "\n",
+ "The model is said to be converged when the convergence value stops descending with increasing epochs."
+ ]
+ },
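+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To make these numbers concrete, here is a small sketch that reproduces two of the values by hand, using the same formulas the metric callbacks apply internally: the per-word likelihood bound for Perplexity and the sum of the diff diagonal for Convergence. It is a sketch only: it assumes the `model`, `training_corpus` and `test_corpus` objects created above, and it trains `model` for one extra pass."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "import copy\n",
+    "import numpy as np\n",
+    "\n",
+    "# perplexity: 2^(-bound per word), exactly what PerplexityMetric computes\n",
+    "corpus_words = sum(cnt for document in test_corpus for _, cnt in document)\n",
+    "perwordbound = model.bound(test_corpus) / corpus_words\n",
+    "print(np.exp2(-perwordbound))\n",
+    "\n",
+    "# convergence: snapshot the model, train one more pass, sum the diff diagonal\n",
+    "snapshot = copy.deepcopy(model)\n",
+    "model.callbacks = None  # avoid re-triggering the callbacks for this sketch\n",
+    "model.update(training_corpus, passes=1)\n",
+    "diff_diagonal, _ = model.diff(snapshot, distance='jaccard', num_words=100, diagonal=True, annotation=False)\n",
+    "print(np.sum(diff_diagonal))"
+   ]
+  },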
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Training Logs\n",
+ "\n",
+ "We can also log the metric values after every epoch to the shell apart from visualizing them in Visdom. The only change is to define `logger=\"shell\"` instead of `\"visdom\"` in the input callbacks."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "collapsed": false,
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:gensim.models.ldamodel:using symmetric alpha at 0.02857142857142857\n",
+ "INFO:gensim.models.ldamodel:using symmetric eta at 1.0996744963490807e-05\n",
+ "INFO:gensim.models.ldamodel:using serial LDA version on this node\n",
+ "INFO:gensim.models.ldamodel:running online (multi-pass) LDA training, 35 topics, 2 passes over the supplied corpus of 5000 documents, updating model once every 2000 documents, evaluating perplexity every 0 documents, iterating 50x with a convergence threshold of 0.001000\n",
+ "WARNING:gensim.models.ldamodel:too few updates, training might not converge; consider increasing the number of passes or iterations to improve accuracy\n",
+ "INFO:gensim.models.ldamodel:PROGRESS: pass 0, at document #2000/5000\n",
+ "INFO:gensim.models.ldamodel:merging changes from 2000 documents into a model of 5000 documents\n",
+ "INFO:gensim.models.ldamodel:topic #5 (0.029): 0.019*\"s\" + 0.007*\"clinton\" + 0.006*\"hillary\" + 0.005*\"t\" + 0.005*\"it\" + 0.005*\"people\" + 0.004*\"trump\" + 0.003*\"2016\" + 0.003*\"state\" + 0.003*\"new\"\n",
+ "INFO:gensim.models.ldamodel:topic #24 (0.029): 0.024*\"s\" + 0.006*\"it\" + 0.005*\"trump\" + 0.005*\"t\" + 0.004*\"u\" + 0.004*\"clinton\" + 0.003*\"people\" + 0.003*\"world\" + 0.003*\"obama\" + 0.003*\"new\"\n",
+ "INFO:gensim.models.ldamodel:topic #4 (0.029): 0.012*\"s\" + 0.005*\"people\" + 0.004*\"time\" + 0.003*\"trump\" + 0.003*\"t\" + 0.003*\"email\" + 0.003*\"new\" + 0.003*\"it\" + 0.003*\"war\" + 0.002*\"state\"\n",
+ "INFO:gensim.models.ldamodel:topic #34 (0.029): 0.011*\"s\" + 0.007*\"trump\" + 0.005*\"people\" + 0.004*\"t\" + 0.004*\"clinton\" + 0.003*\"new\" + 0.003*\"government\" + 0.003*\"said\" + 0.003*\"it\" + 0.003*\"time\"\n",
+ "INFO:gensim.models.ldamodel:topic #25 (0.029): 0.020*\"s\" + 0.004*\"u\" + 0.004*\"trump\" + 0.004*\"people\" + 0.004*\"government\" + 0.004*\"it\" + 0.003*\"american\" + 0.003*\"t\" + 0.003*\"state\" + 0.003*\"saudi\"\n",
+ "INFO:gensim.models.ldamodel:topic diff=24.655166, rho=1.000000\n",
+ "INFO:gensim.models.ldamodel:PROGRESS: pass 0, at document #4000/5000\n",
+ "INFO:gensim.models.ldamodel:merging changes from 2000 documents into a model of 5000 documents\n",
+ "INFO:gensim.models.ldamodel:topic #25 (0.029): 0.016*\"s\" + 0.004*\"people\" + 0.003*\"u\" + 0.003*\"saudi\" + 0.003*\"government\" + 0.003*\"daily\" + 0.003*\"trump\" + 0.003*\"stormer\" + 0.003*\"it\" + 0.003*\"2016\"\n",
+ "INFO:gensim.models.ldamodel:topic #32 (0.029): 0.010*\"s\" + 0.005*\"world\" + 0.005*\"it\" + 0.004*\"horowitz\" + 0.004*\"ice\" + 0.003*\"people\" + 0.003*\"t\" + 0.003*\"like\" + 0.002*\"white\" + 0.002*\"fukushima\"\n",
+ "INFO:gensim.models.ldamodel:topic #6 (0.029): 0.018*\"s\" + 0.009*\"clinton\" + 0.009*\"trump\" + 0.009*\"t\" + 0.007*\"people\" + 0.007*\"hillary\" + 0.006*\"election\" + 0.005*\"like\" + 0.004*\"new\" + 0.004*\"it\"\n",
+ "INFO:gensim.models.ldamodel:topic #1 (0.029): 0.009*\"s\" + 0.005*\"comment\" + 0.004*\"facebook\" + 0.004*\"article\" + 0.004*\"account\" + 0.004*\"people\" + 0.003*\"war\" + 0.003*\"goat\" + 0.003*\"t\" + 0.003*\"disqus\"\n",
+ "INFO:gensim.models.ldamodel:topic #33 (0.029): 0.009*\"s\" + 0.008*\"people\" + 0.005*\"said\" + 0.004*\"t\" + 0.004*\"it\" + 0.004*\"trump\" + 0.003*\"government\" + 0.003*\"clinton\" + 0.003*\"fbi\" + 0.003*\"police\"\n",
+ "INFO:gensim.models.ldamodel:topic diff=3.855081, rho=0.707107\n",
+ "INFO:gensim.models.ldamodel:PROGRESS: pass 0, at document #5000/5000\n",
+ "INFO:gensim.models.ldamodel:merging changes from 1000 documents into a model of 5000 documents\n",
+ "INFO:gensim.models.ldamodel:topic #7 (0.029): 0.013*\"s\" + 0.007*\"it\" + 0.006*\"nuclear\" + 0.006*\"t\" + 0.006*\"activation\" + 0.005*\"placement\" + 0.004*\"clinton\" + 0.003*\"people\" + 0.003*\"hillary\" + 0.003*\"time\"\n",
+ "INFO:gensim.models.ldamodel:topic #23 (0.029): 0.008*\"s\" + 0.004*\"people\" + 0.004*\"water\" + 0.003*\"halloween\" + 0.003*\"t\" + 0.003*\"it\" + 0.003*\"silver\" + 0.003*\"jesus\" + 0.003*\"world\" + 0.003*\"flickr\"\n",
+ "INFO:gensim.models.ldamodel:topic #31 (0.029): 0.014*\"s\" + 0.006*\"t\" + 0.004*\"it\" + 0.004*\"people\" + 0.004*\"new\" + 0.003*\"girl\" + 0.003*\"child\" + 0.003*\"ellison\" + 0.003*\"know\" + 0.003*\"trump\"\n",
+ "INFO:gensim.models.ldamodel:topic #13 (0.029): 0.014*\"s\" + 0.010*\"t\" + 0.005*\"government\" + 0.004*\"school\" + 0.004*\"i\" + 0.004*\"president\" + 0.003*\"people\" + 0.003*\"trump\" + 0.003*\"like\" + 0.003*\"time\"\n",
+ "INFO:gensim.models.ldamodel:topic #25 (0.029): 0.111*\"utm\" + 0.015*\"force\" + 0.012*\"s\" + 0.007*\"tzrwu\" + 0.007*\"ims\" + 0.007*\"infowarsstore\" + 0.006*\"dr\" + 0.006*\"25\" + 0.006*\"and\" + 0.005*\"vitamin\"\n",
+ "INFO:gensim.models.ldamodel:topic diff=2.676513, rho=0.577350\n",
+ "INFO:gensim.models.ldamodel:Epoch 0: Perplexity (hold_out) estimate: 2891.88630978\n",
+ "INFO:gensim.models.ldamodel:Epoch 0: Perplexity (test) estimate: 1874.8109115\n",
+ "INFO:gensim.topic_coherence.text_analysis:CorpusAccumulator accumulated stats from 1000 documents\n",
+ "INFO:gensim.topic_coherence.text_analysis:CorpusAccumulator accumulated stats from 2000 documents\n",
+ "INFO:gensim.topic_coherence.text_analysis:CorpusAccumulator accumulated stats from 3000 documents\n",
+ "INFO:gensim.topic_coherence.text_analysis:CorpusAccumulator accumulated stats from 4000 documents\n",
+ "INFO:gensim.topic_coherence.text_analysis:CorpusAccumulator accumulated stats from 5000 documents\n",
+ "INFO:gensim.models.ldamodel:Epoch 0: Coherence (u_mass) estimate: -2.03953260715\n",
+ "INFO:gensim.models.ldamodel:Epoch 0: Diff (kullback_leibler) estimate: [ 0.93640518 0.95952878 0.64519803 0.78364027 0.50865099 0.90607399\n",
+ " 0.87627853 0.77941292 0.52529268 0.5848095 0.78475031 0.86041719\n",
+ " 0.98346917 0.79422948 0.82464972 0.53735652 0.79245303 0.64592471\n",
+ " 0.89265179 0.78584557 0.66037967 0.613174 0.8009025 0.68748952\n",
+ " 0.77344932 0.96950451 0.80075146 0.83959967 0.83689681 0.70191726\n",
+ " 1. 0.69856365 0.60616776 0.81112339 0.54781754]\n",
+ "INFO:gensim.models.ldamodel:Epoch 0: Convergence (jaccard) estimate: 34.9295467235\n",
+ "INFO:gensim.models.ldamodel:PROGRESS: pass 1, at document #2000/5000\n",
+ "INFO:gensim.models.ldamodel:merging changes from 2000 documents into a model of 5000 documents\n",
+ "INFO:gensim.models.ldamodel:topic #13 (0.029): 0.013*\"s\" + 0.009*\"t\" + 0.006*\"black\" + 0.005*\"school\" + 0.004*\"government\" + 0.004*\"i\" + 0.003*\"people\" + 0.003*\"president\" + 0.003*\"police\" + 0.003*\"like\"\n",
+ "INFO:gensim.models.ldamodel:topic #12 (0.029): 0.030*\"infowars\" + 0.023*\"brain\" + 0.013*\"force\" + 0.010*\"s\" + 0.009*\"com\" + 0.007*\"life\" + 0.006*\"www\" + 0.005*\"http\" + 0.005*\"content\" + 0.005*\"source\"\n",
+ "INFO:gensim.models.ldamodel:topic #17 (0.029): 0.012*\"obamacare\" + 0.010*\"people\" + 0.009*\"s\" + 0.007*\"care\" + 0.006*\"health\" + 0.006*\"insurance\" + 0.005*\"t\" + 0.005*\"premiums\" + 0.005*\"brock\" + 0.004*\"it\"\n",
+ "INFO:gensim.models.ldamodel:topic #10 (0.029): 0.025*\"s\" + 0.009*\"israel\" + 0.005*\"said\" + 0.005*\"jewish\" + 0.004*\"american\" + 0.004*\"muslim\" + 0.004*\"t\" + 0.003*\"it\" + 0.003*\"anti\" + 0.003*\"israeli\"\n",
+ "INFO:gensim.models.ldamodel:topic #33 (0.029): 0.009*\"s\" + 0.008*\"vaccine\" + 0.007*\"people\" + 0.006*\"flu\" + 0.005*\"vaccines\" + 0.005*\"virus\" + 0.005*\"marijuana\" + 0.005*\"court\" + 0.005*\"medical\" + 0.004*\"zika\"\n",
+ "INFO:gensim.models.ldamodel:topic diff=1.496938, rho=0.471405\n",
+ "INFO:gensim.models.ldamodel:PROGRESS: pass 1, at document #4000/5000\n",
+ "INFO:gensim.models.ldamodel:merging changes from 2000 documents into a model of 5000 documents\n",
+ "INFO:gensim.models.ldamodel:topic #17 (0.029): 0.017*\"obamacare\" + 0.009*\"people\" + 0.008*\"s\" + 0.008*\"insurance\" + 0.007*\"care\" + 0.007*\"premiums\" + 0.006*\"health\" + 0.005*\"t\" + 0.005*\"godlike\" + 0.004*\"brock\"\n",
+ "INFO:gensim.models.ldamodel:topic #28 (0.029): 0.018*\"retired\" + 0.016*\"syria\" + 0.016*\"s\" + 0.013*\"general\" + 0.011*\"syrian\" + 0.010*\"russian\" + 0.010*\"air\" + 0.009*\"russia\" + 0.009*\"force\" + 0.009*\"army\"\n",
+ "INFO:gensim.models.ldamodel:topic #24 (0.029): 0.021*\"s\" + 0.008*\"t\" + 0.006*\"it\" + 0.004*\"people\" + 0.004*\"world\" + 0.004*\"u\" + 0.003*\"that\" + 0.003*\"don\" + 0.003*\"like\" + 0.003*\"i\"\n",
+ "INFO:gensim.models.ldamodel:topic #4 (0.029): 0.019*\"email\" + 0.015*\"posts\" + 0.011*\"new\" + 0.010*\"subscribe\" + 0.009*\"notify\" + 0.009*\"follow\" + 0.007*\"donate\" + 0.006*\"radioactive\" + 0.005*\"up\" + 0.005*\"receive\"\n",
+ "INFO:gensim.models.ldamodel:topic #19 (0.029): 0.013*\"s\" + 0.007*\"said\" + 0.006*\"water\" + 0.005*\"pipeline\" + 0.005*\"police\" + 0.004*\"dakota\" + 0.004*\"year\" + 0.004*\"t\" + 0.003*\"2016\" + 0.003*\"north\"\n",
+ "INFO:gensim.models.ldamodel:topic diff=1.335664, rho=0.471405\n",
+ "INFO:gensim.models.ldamodel:PROGRESS: pass 1, at document #5000/5000\n",
+ "INFO:gensim.models.ldamodel:merging changes from 1000 documents into a model of 5000 documents\n",
+ "INFO:gensim.models.ldamodel:topic #31 (0.029): 0.011*\"s\" + 0.008*\"child\" + 0.008*\"girl\" + 0.007*\"parents\" + 0.006*\"family\" + 0.005*\"old\" + 0.005*\"t\" + 0.005*\"baby\" + 0.005*\"ellison\" + 0.005*\"hospital\"\n",
+ "INFO:gensim.models.ldamodel:topic #16 (0.029): 0.036*\"trump\" + 0.013*\"white\" + 0.011*\"obama\" + 0.009*\"people\" + 0.008*\"s\" + 0.008*\"president\" + 0.007*\"vote\" + 0.006*\"america\" + 0.006*\"black\" + 0.006*\"right\"\n",
+ "INFO:gensim.models.ldamodel:topic #12 (0.029): 0.050*\"infowars\" + 0.042*\"brain\" + 0.026*\"force\" + 0.023*\"com\" + 0.015*\"life\" + 0.013*\"wellness\" + 0.011*\"content\" + 0.011*\"www\" + 0.010*\"source\" + 0.010*\"medium\"\n",
+ "INFO:gensim.models.ldamodel:topic #18 (0.029): 0.015*\"vaccine\" + 0.013*\"s\" + 0.011*\"medical\" + 0.010*\"doctors\" + 0.009*\"dr\" + 0.008*\"vaccines\" + 0.008*\"cdc\" + 0.007*\"children\" + 0.006*\"cancer\" + 0.006*\"autism\"\n",
+ "INFO:gensim.models.ldamodel:topic #13 (0.029): 0.013*\"s\" + 0.010*\"t\" + 0.007*\"government\" + 0.007*\"school\" + 0.004*\"police\" + 0.004*\"black\" + 0.003*\"law\" + 0.003*\"data\" + 0.003*\"at\" + 0.003*\"youtube\"\n",
+ "INFO:gensim.models.ldamodel:topic diff=1.217662, rho=0.471405\n",
+ "INFO:gensim.models.ldamodel:Epoch 1: Perplexity (hold_out) estimate: 1875.59367127\n",
+ "INFO:gensim.models.ldamodel:Epoch 1: Perplexity (test) estimate: 1369.87404449\n",
+ "INFO:gensim.topic_coherence.text_analysis:CorpusAccumulator accumulated stats from 1000 documents\n",
+ "INFO:gensim.topic_coherence.text_analysis:CorpusAccumulator accumulated stats from 2000 documents\n",
+ "INFO:gensim.topic_coherence.text_analysis:CorpusAccumulator accumulated stats from 3000 documents\n",
+ "INFO:gensim.topic_coherence.text_analysis:CorpusAccumulator accumulated stats from 4000 documents\n",
+ "INFO:gensim.topic_coherence.text_analysis:CorpusAccumulator accumulated stats from 5000 documents\n",
+ "INFO:gensim.models.ldamodel:Epoch 1: Coherence (u_mass) estimate: -3.33917142783\n",
+ "INFO:gensim.models.ldamodel:Epoch 1: Diff (kullback_leibler) estimate: [ 0.09301981 0.71566666 0.67676685 0.35730119 1. 0.54864275\n",
+ " 0.24120049 0.45830206 0.96709113 0.78362169 0.33318101 0.38479912\n",
+ " 0.62038738 0.43474664 0.31502652 0.87644947 0.5223633 0.76557247\n",
+ " 0.41316022 0.40029654 0.84127697 0.76384605 0.13714947 0.64728141\n",
+ " 0.3992811 0.70301974 0.17202488 0.33375843 0.55556689 0.68845743\n",
+ " 0.23326017 0.6709471 0.75061125 0.48405455 0.81424992]\n",
+ "INFO:gensim.models.ldamodel:Epoch 1: Convergence (jaccard) estimate: 22.293262955\n"
+ ]
+ }
+ ],
+ "source": [
+ "import logging\n",
+ "from gensim.models.callbacks import CoherenceMetric, DiffMetric, PerplexityMetric, ConvergenceMetric\n",
+ "\n",
+ "logging.basicConfig(level=logging.INFO)\n",
+ "logger = logging.getLogger(__name__)\n",
+ "logger.setLevel(logging.DEBUG)\n",
+ "\n",
+ "# define perplexity callback for hold_out and test corpus\n",
+ "pl_holdout = PerplexityMetric(corpus=holdout_corpus, logger=\"shell\", title=\"Perplexity (hold_out)\")\n",
+ "pl_test = PerplexityMetric(corpus=test_corpus, logger=\"shell\", title=\"Perplexity (test)\")\n",
+ "\n",
+ "# define other remaining metrics available\n",
+ "ch_umass = CoherenceMetric(corpus=training_corpus, coherence=\"u_mass\", logger=\"shell\", title=\"Coherence (u_mass)\")\n",
+ "diff_kl = DiffMetric(distance=\"kullback_leibler\", logger=\"shell\", title=\"Diff (kullback_leibler)\")\n",
+ "convergence_jc = ConvergenceMetric(distance=\"jaccard\", logger=\"shell\", title=\"Convergence (jaccard)\")\n",
+ "\n",
+ "callbacks = [pl_holdout, pl_test, ch_umass, diff_kl, convergence_jc]\n",
+ "\n",
+ "# training LDA model\n",
+ "model = ldamodel.LdaModel(corpus=training_corpus, id2word=dictionary, num_topics=35, passes=2, eval_every=None, callbacks=callbacks)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The metric values can also be accessed from the model instance for custom uses."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "defaultdict(list,\n",
+ " {'Coherence (u_mass)': [-2.0395326071549267, -3.339171427828973],\n",
+ " 'Convergence (jaccard)': [34.929546723516573, 22.293262954969034],\n",
+ " 'Diff (kullback_leibler)': [array([ 0.93640518, 0.95952878, 0.64519803, 0.78364027, 0.50865099,\n",
+ " 0.90607399, 0.87627853, 0.77941292, 0.52529268, 0.5848095 ,\n",
+ " 0.78475031, 0.86041719, 0.98346917, 0.79422948, 0.82464972,\n",
+ " 0.53735652, 0.79245303, 0.64592471, 0.89265179, 0.78584557,\n",
+ " 0.66037967, 0.613174 , 0.8009025 , 0.68748952, 0.77344932,\n",
+ " 0.96950451, 0.80075146, 0.83959967, 0.83689681, 0.70191726,\n",
+ " 1. , 0.69856365, 0.60616776, 0.81112339, 0.54781754]),\n",
+ " array([ 0.09301981, 0.71566666, 0.67676685, 0.35730119, 1. ,\n",
+ " 0.54864275, 0.24120049, 0.45830206, 0.96709113, 0.78362169,\n",
+ " 0.33318101, 0.38479912, 0.62038738, 0.43474664, 0.31502652,\n",
+ " 0.87644947, 0.5223633 , 0.76557247, 0.41316022, 0.40029654,\n",
+ " 0.84127697, 0.76384605, 0.13714947, 0.64728141, 0.3992811 ,\n",
+ " 0.70301974, 0.17202488, 0.33375843, 0.55556689, 0.68845743,\n",
+ " 0.23326017, 0.6709471 , 0.75061125, 0.48405455, 0.81424992])],\n",
+ " 'Perplexity (hold_out)': [2891.8863097795142, 1875.5936712709661],\n",
+ " 'Perplexity (test)': [1874.8109114962135, 1369.8740444934508]})"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.metrics"
+ ]
+  },
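+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a minimal sketch of such a custom use, the next cell plots the logged perplexity per epoch with matplotlib (assumed to be installed; any plotting library would do)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "# each list in model.metrics holds one value per epoch, in training order\n",
+    "plt.plot(model.metrics['Perplexity (hold_out)'], label='hold_out')\n",
+    "plt.plot(model.metrics['Perplexity (test)'], label='test')\n",
+    "plt.xlabel('Epoch')\n",
+    "plt.ylabel('Perplexity')\n",
+    "plt.legend()\n",
+    "plt.show()"
+   ]
+  }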
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 2",
+ "language": "python",
+ "name": "python2"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/notebooks/visdom_graph.png b/docs/notebooks/visdom_graph.png
new file mode 100644
index 0000000000..d4a38c720e
Binary files /dev/null and b/docs/notebooks/visdom_graph.png differ
diff --git a/gensim/models/callbacks.py b/gensim/models/callbacks.py
new file mode 100644
index 0000000000..d4598cca1d
--- /dev/null
+++ b/gensim/models/callbacks.py
@@ -0,0 +1,316 @@
+import gensim
+import logging
+import copy
+import sys
+import numpy as np
+
+if sys.version_info[0] >= 3:
+ from queue import Queue
+else:
+ from Queue import Queue
+
+# Visdom is used for training stats visualization
+try:
+ from visdom import Visdom
+ VISDOM_INSTALLED = True
+except ImportError:
+ VISDOM_INSTALLED = False
+
+
+class Metric(object):
+ """
+ Base Metric class for topic model evaluation metrics
+ """
+    def __str__(self):
+        """
+        Return the metric's display name: the custom title if set, else the class name without the 'Metric' suffix
+        """
+        if self.title is not None:
+            return self.title
+        else:
+            # e.g. "CoherenceMetric" -> "Coherence"
+            return type(self).__name__[:-6]
+
+ def set_parameters(self, **parameters):
+ """
+ Set the parameters
+ """
+ for parameter, value in parameters.items():
+ setattr(self, parameter, value)
+
+    def get_value(self):
+        # implemented by subclasses
+        raise NotImplementedError("Please provide an implementation in a subclass")
+
+
+class CoherenceMetric(Metric):
+ """
+ Metric class for coherence evaluation
+ """
+ def __init__(self, corpus=None, texts=None, dictionary=None, coherence=None, window_size=None, topn=10, logger=None, viz_env=None, title=None):
+ """
+ Args:
+ corpus : Gensim document corpus.
+ texts : Tokenized texts. Needed for coherence models that use sliding window based probability estimator,
+ eg::
+ texts = [['system', 'human', 'system', 'eps'],
+ ['user', 'response', 'time'],
+ ['trees'],
+ ['graph', 'trees'],
+ ['graph', 'minors', 'trees'],
+ ['graph', 'minors', 'survey']]
+
+            dictionary : Gensim dictionary mapping of id to word, used to create the corpus. If model.id2word is present,
+                         this is not needed. If both are provided, dictionary will be used.
+            window_size : size of the window to be used for coherence measures that use a boolean sliding window as their
+                          probability estimator. For 'u_mass' this doesn't matter.
+                          If 'None', the default window sizes are used, which are:
+
+ 'c_v' : 110
+ 'c_uci' : 10
+ 'c_npmi' : 10
+
+ coherence : Coherence measure to be used. Supported values are:
+ 'u_mass'
+ 'c_v'
+ 'c_uci' also popularly known as c_pmi
+ 'c_npmi'
+                        For 'u_mass', a corpus should be provided; if texts is provided, it will be converted
+                        to a corpus using the dictionary. For 'c_v', 'c_uci' and 'c_npmi', texts should be
+                        provided and a corpus is not needed.
+ topn : Integer corresponding to the number of top words to be extracted from each topic.
+ logger : Monitor training process using:
+ "shell" : print coherence value in shell
+ "visdom" : visualize coherence value with increasing epochs in Visdom visualization framework
+ viz_env : Visdom environment to use for plotting the graph
+ title : title of the graph plot
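+
+        Example (a hedged sketch; assumes `training_corpus` is a bag-of-words corpus
+        and `model` is a trained LDA model):
+            >>> cm = CoherenceMetric(corpus=training_corpus, coherence="u_mass")
+            >>> cm.get_value(model=model)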
+ """
+ self.corpus = corpus
+ self.dictionary = dictionary
+ self.coherence = coherence
+ self.texts = texts
+ self.window_size = window_size
+ self.topn = topn
+ self.logger = logger
+ self.viz_env = viz_env
+ self.title = title
+
+ def get_value(self, **kwargs):
+ """
+ Args:
+ model : Pre-trained topic model. Should be provided if topics is not provided.
+ Currently supports LdaModel, LdaMallet wrapper and LdaVowpalWabbit wrapper. Use 'topics'
+ parameter to plug in an as yet unsupported model.
+ topics : List of tokenized topics.
+ eg::
+ topics = [['human', 'machine', 'computer', 'interface'],
+ ['graph', 'trees', 'binary', 'widths']]
+ """
+        # only one of 'model' or 'topics' will be passed in; clear both before applying kwargs
+ self.model = None
+ self.topics = None
+ super(CoherenceMetric, self).set_parameters(**kwargs)
+        cm = gensim.models.CoherenceModel(
+            model=self.model, topics=self.topics, texts=self.texts, corpus=self.corpus,
+            dictionary=self.dictionary, window_size=self.window_size, coherence=self.coherence, topn=self.topn)
+ return cm.get_coherence()
+
+
+class PerplexityMetric(Metric):
+ """
+ Metric class for perplexity evaluation
+ """
+ def __init__(self, corpus=None, logger=None, viz_env=None, title=None):
+ """
+ Args:
+ corpus : Gensim document corpus
+ logger : Monitor training process using:
+ "shell" : print coherence value in shell
+ "visdom" : visualize coherence value with increasing epochs in Visdom visualization framework
+ viz_env : Visdom environment to use for plotting the graph
+ title : title of the graph plot
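+
+        Example (a hedged sketch; assumes a trained LDA `model` and a bag-of-words `test_corpus`):
+            >>> pm = PerplexityMetric(corpus=test_corpus)
+            >>> pm.get_value(model=model)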
+ """
+ self.corpus = corpus
+ self.logger = logger
+ self.viz_env = viz_env
+ self.title = title
+
+ def get_value(self, **kwargs):
+ """
+ Args:
+ model : Trained topic model
+ """
+ super(PerplexityMetric, self).set_parameters(**kwargs)
+ corpus_words = sum(cnt for document in self.corpus for _, cnt in document)
+ perwordbound = self.model.bound(self.corpus) / corpus_words
+        # perplexity = 2^(-bound per word); lower is better
+        return np.exp2(-perwordbound)
+
+
+class DiffMetric(Metric):
+ """
+ Metric class for topic difference evaluation
+ """
+ def __init__(self, distance="jaccard", num_words=100, n_ann_terms=10, diagonal=True, annotation=False, normed=True, logger=None, viz_env=None, title=None):
+ """
+ Args:
+ distance : measure used to calculate difference between any topic pair. Available values:
+ `kullback_leibler`
+ `hellinger`
+ `jaccard`
+            num_words : number of most relevant words used when distance == `jaccard` (also used for annotation)
+            n_ann_terms : max number of words in the intersection/symmetric difference between topics (used for annotation)
+            diagonal : if True, return only the distances between identical topic ids (the diagonal of the difference matrix)
+            annotation : whether to return the intersection or difference of words between topics
+            normed (bool) : if `True`, the distance matrix/array will be normalized
+            logger : Monitor training process using:
+                     "shell" : print diff value in shell
+                     "visdom" : visualize diff value with increasing epochs in Visdom visualization framework
+ viz_env : Visdom environment to use for plotting the graph
+ title : title of the graph plot
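+
+        Example (a hedged sketch; assumes `model` and `previous_model` are snapshots of
+        the same LDA model from consecutive epochs):
+            >>> dm = DiffMetric(distance="jaccard")
+            >>> dm.get_value(model=model, other_model=previous_model)  # one distance per topic id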
+ """
+ self.distance = distance
+ self.num_words = num_words
+ self.n_ann_terms = n_ann_terms
+ self.diagonal = diagonal
+ self.annotation = annotation
+ self.normed = normed
+ self.logger = logger
+ self.viz_env = viz_env
+ self.title = title
+
+ def get_value(self, **kwargs):
+ """
+ Args:
+ model : Trained topic model
+ other_model : second topic model instance to calculate the difference from
+ """
+ super(DiffMetric, self).set_parameters(**kwargs)
+ diff_diagonal, _ = self.model.diff(self.other_model, self.distance, self.num_words, self.n_ann_terms, self.diagonal, self.annotation, self.normed)
+ return diff_diagonal
+
+
+class ConvergenceMetric(Metric):
+ """
+ Metric class for convergence evaluation
+ """
+ def __init__(self, distance="jaccard", num_words=100, n_ann_terms=10, diagonal=True, annotation=False, normed=True, logger=None, viz_env=None, title=None):
+ """
+ Args:
+ distance : measure used to calculate difference between any topic pair. Available values:
+ `kullback_leibler`
+ `hellinger`
+ `jaccard`
+            num_words : number of most relevant words used when distance == `jaccard` (also used for annotation)
+            n_ann_terms : max number of words in the intersection/symmetric difference between topics (used for annotation)
+            diagonal : if True, use only the distances between identical topic ids (the diagonal of the difference matrix)
+            annotation : whether to return the intersection or difference of words between topics
+            normed (bool) : if `True`, the distance matrix/array will be normalized
+            logger : Monitor training process using:
+                     "shell" : print convergence value in shell
+                     "visdom" : visualize convergence value with increasing epochs in Visdom visualization framework
+ viz_env : Visdom environment to use for plotting the graph
+ title : title of the graph plot
+ """
+ self.distance = distance
+ self.num_words = num_words
+ self.n_ann_terms = n_ann_terms
+ self.diagonal = diagonal
+ self.annotation = annotation
+ self.normed = normed
+ self.logger = logger
+ self.viz_env = viz_env
+ self.title = title
+
+ def get_value(self, **kwargs):
+ """
+ Args:
+ model : Trained topic model
+ other_model : second topic model instance to calculate the difference from
+ """
+ super(ConvergenceMetric, self).set_parameters(**kwargs)
+ diff_diagonal, _ = self.model.diff(self.other_model, self.distance, self.num_words, self.n_ann_terms, self.diagonal, self.annotation, self.normed)
+ return np.sum(diff_diagonal)
+
+
+class Callback(object):
+ """
+ Used to log/visualize the evaluation metrics during training. The values are stored at the end of each epoch.
+ """
+ def __init__(self, metrics):
+ """
+ Args:
+            metrics : a list of Metric instances to compute at the end of each epoch. Possible types:
+ "CoherenceMetric"
+ "PerplexityMetric"
+ "DiffMetric"
+ "ConvergenceMetric"
+ """
+        # list of metrics to be plotted/logged
+ self.metrics = metrics
+
+ def set_model(self, model):
+ """
+        Save the model instance and initialize any required variables that will be updated throughout training
+ """
+ self.model = model
+ self.previous = None
+        # check for any metric which needs the model state from the previous epoch
+ if any(isinstance(metric, (DiffMetric, ConvergenceMetric)) for metric in self.metrics):
+ self.previous = copy.deepcopy(model)
+ # store diff diagonals of previous epochs
+ self.diff_mat = Queue()
+ if any(metric.logger == "visdom" for metric in self.metrics):
+ if not VISDOM_INSTALLED:
+ raise ImportError("Please install Visdom for visualization")
+ self.viz = Visdom()
+ # store initial plot windows of every metric (same window will be updated with increasing epochs)
+ self.windows = []
+ if any(metric.logger == "shell" for metric in self.metrics):
+ # set logger for current topic model
+ self.log_type = logging.getLogger('gensim.models.ldamodel')
+
+ def on_epoch_end(self, epoch, topics=None):
+ """
+ Log or visualize current epoch's metric value
+
+ Args:
+ epoch : current epoch no.
+ topics : topic distribution from current epoch (required for coherence of unsupported topic models)
+ """
+ # stores current epoch's metric values
+ current_metrics = {}
+
+ # plot all metrics in current epoch
+ for i, metric in enumerate(self.metrics):
+ label = str(metric)
+ value = metric.get_value(topics=topics, model=self.model, other_model=self.previous)
+
+ current_metrics[label] = value
+
+ if metric.logger == "visdom":
+ if epoch == 0:
+ if value.ndim > 0:
+ diff_mat = np.array([value])
+ viz_metric = self.viz.heatmap(X=diff_mat.T, env=metric.viz_env, opts=dict(xlabel='Epochs', ylabel=label, title=label))
+ # store current epoch's diff diagonal
+ self.diff_mat.put(diff_mat)
+ # saving initial plot window
+ self.windows.append(copy.deepcopy(viz_metric))
+ else:
+ viz_metric = self.viz.line(Y=np.array([value]), X=np.array([epoch]), env=metric.viz_env, opts=dict(xlabel='Epochs', ylabel=label, title=label))
+ # saving initial plot window
+ self.windows.append(copy.deepcopy(viz_metric))
+ else:
+ if value.ndim > 0:
+ # concatenate with previous epoch's diff diagonals
+ diff_mat = np.concatenate((self.diff_mat.get(), np.array([value])))
+ self.viz.heatmap(X=diff_mat.T, env=metric.viz_env, win=self.windows[i], opts=dict(xlabel='Epochs', ylabel=label, title=label))
+ self.diff_mat.put(diff_mat)
+ else:
+                        self.viz.updateTrace(Y=np.array([value]), X=np.array([epoch]), env=metric.viz_env, win=self.windows[i])
+            elif epoch == 0:
+                # keep `windows` index-aligned with `metrics` when visdom and shell loggers are mixed
+                self.windows.append(None)
+
+ if metric.logger == "shell":
+ statement = "".join(("Epoch ", str(epoch), ": ", label, " estimate: ", str(value)))
+ self.log_type.info(statement)
+
+        # update the stored model state if any metric needs the previous epoch's model
+        if any(isinstance(metric, (DiffMetric, ConvergenceMetric)) for metric in self.metrics):
+            self.previous = copy.deepcopy(self.model)
+
+ return current_metrics
diff --git a/gensim/models/ldamodel.py b/gensim/models/ldamodel.py
index 55ce334460..f8c3dfc63f 100755
--- a/gensim/models/ldamodel.py
+++ b/gensim/models/ldamodel.py
@@ -40,11 +40,13 @@
from scipy.special import gammaln, psi # gamma function utils
from scipy.special import polygamma
from six.moves import xrange
+from collections import defaultdict
from gensim import interfaces, utils, matutils
from gensim.matutils import dirichlet_expectation
from gensim.matutils import kullback_leibler, hellinger, jaccard_distance, jensen_shannon
from gensim.models import basemodel, CoherenceModel
+from gensim.models.callbacks import Callback
# log(sum(exp(x))) that tries to avoid overflow
try:
@@ -191,10 +193,10 @@ class LdaModel(interfaces.TransformationABC, basemodel.BaseTopicModel):
"""
def __init__(self, corpus=None, num_topics=100, id2word=None,
distributed=False, chunksize=2000, passes=1, update_every=1,
- alpha='symmetric', eta=None, decay=0.5, offset=1.0,
- eval_every=10, iterations=50, gamma_threshold=0.001,
- minimum_probability=0.01, random_state=None, ns_conf={},
- minimum_phi_value=0.01, per_word_topics=False):
+ alpha='symmetric', eta=None, decay=0.5, offset=1.0, eval_every=10,
+ iterations=50, gamma_threshold=0.001, minimum_probability=0.01,
+ random_state=None, ns_conf={}, minimum_phi_value=0.01,
+ per_word_topics=False, callbacks=None):
"""
If given, start training from the iterable `corpus` straight away. If not given,
the model is left untrained (presumably because you want to call `update()` manually).
@@ -238,6 +240,8 @@ def __init__(self, corpus=None, num_topics=100, id2word=None,
`random_state` can be a np.random.RandomState object or the seed for one
+        `callbacks` a list of metric callbacks to log/visualize the topic model's evaluation metrics during training
+
Example:
>>> lda = LdaModel(corpus, num_topics=100) # train model
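+        >>> # a sketch with a metric callback logged to the shell on every pass
+        >>> # (PerplexityMetric comes from gensim.models.callbacks)
+        >>> lda = LdaModel(corpus, num_topics=100, callbacks=[PerplexityMetric(corpus=corpus, logger="shell")])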
@@ -279,6 +283,7 @@ def __init__(self, corpus=None, num_topics=100, id2word=None,
self.eval_every = eval_every
self.minimum_phi_value = minimum_phi_value
self.per_word_topics = per_word_topics
+ self.callbacks = callbacks
self.alpha, self.optimize_alpha = self.init_dir_prior(alpha, 'alpha')
@@ -625,6 +630,13 @@ def update(self, corpus, chunksize=None, decay=None, offset=None,
def rho():
return pow(offset + pass_ + (self.num_updates / chunksize), -decay)
+ if self.callbacks:
+ # pass the list of input callbacks to Callback class
+ callback = Callback(self.callbacks)
+ callback.set_model(self)
+ # initialize metrics list to store metric values after every epoch
+ self.metrics = defaultdict(list)
+
for pass_ in xrange(passes):
if self.dispatcher:
logger.info('initializing %s workers' % self.numworkers)
@@ -673,9 +685,16 @@ def rho():
other = LdaState(self.eta, self.state.sstats.shape)
dirty = False
# endfor single corpus iteration
+
if reallen != lencorpus:
raise RuntimeError("input corpus size changed during training (don't use generators as input)")
+ # append current epoch's metric values
+ if self.callbacks:
+ current_metrics = callback.on_epoch_end(pass_)
+ for metric, value in current_metrics.items():
+ self.metrics[metric].append(value)
+
if dirty:
# finish any remaining updates
if self.dispatcher:
@@ -685,6 +704,7 @@ def rho():
self.do_mstep(rho(), other, pass_ > 0)
del other
dirty = False
+
# endfor entire corpus update
def do_mstep(self, rho, other, extra_pass=False):