From b4bd541b767a7120630aeb9c80a2d028cb88ced0 Mon Sep 17 00:00:00 2001 From: Mack Date: Fri, 18 Aug 2017 03:37:30 -0400 Subject: [PATCH] Use `CoherenceModel` for `LdaModel.top_topics`. Fix #1128 (#1427) * Add a `get_topics` method to all topic models, add test coverage for this, and update the `CoherenceModel` to use this for getting topics from models. * Require topics returned from `get_topics` to be probability distributions for the probabilistic topic models. * Replace code in `LdaModel.top_topics` with use of `CoherenceModel`. * Fix docstrings to use Google style throughout PR changes and various LdaModel methods. --- gensim/models/coherencemodel.py | 28 +-- gensim/models/hdpmodel.py | 14 +- gensim/models/ldamodel.py | 258 +++++++++++----------- gensim/models/lsimodel.py | 26 ++- gensim/models/wrappers/ldamallet.py | 25 ++- gensim/models/wrappers/ldavowpalwabbit.py | 17 +- gensim/test/basetests.py | 12 +- gensim/test/test_lsimodel.py | 16 +- 8 files changed, 225 insertions(+), 171 deletions(-) diff --git a/gensim/models/coherencemodel.py b/gensim/models/coherencemodel.py index 8556db1c45..95a5117eee 100644 --- a/gensim/models/coherencemodel.py +++ b/gensim/models/coherencemodel.py @@ -26,8 +26,6 @@ from gensim import interfaces from gensim.matutils import argsort -from gensim.models.ldamodel import LdaModel -from gensim.models.wrappers import LdaVowpalWabbit, LdaMallet from gensim.topic_coherence import (segmentation, probability_estimation, direct_confirmation_measure, indirect_confirmation_measure, aggregation) @@ -268,23 +266,15 @@ def _topics_differ(self, new_topics): def _get_topics(self): """Internal helper function to return topics from a trained topic model.""" - topics = [] - if isinstance(self.model, LdaModel): - for topic in self.model.state.get_lambda(): - bestn = argsort(topic, topn=self.topn, reverse=True) - topics.append(bestn) - elif isinstance(self.model, LdaVowpalWabbit): - for topic in self.model._get_topics(): - bestn = argsort(topic, topn=self.topn, reverse=True) - topics.append(bestn) - elif isinstance(self.model, LdaMallet): - for topic in self.model.word_topics: - bestn = argsort(topic, topn=self.topn, reverse=True) - topics.append(bestn) - else: - raise ValueError("This topic model is not currently supported. Supported topic models " - " are LdaModel, LdaVowpalWabbit and LdaMallet.") - return topics + try: + return [ + argsort(topic, topn=self.topn, reverse=True) for topic in + self.model.get_topics() + ] + except AttributeError: + raise ValueError( + "This topic model is not currently supported. Supported topic models" + " should implement the `get_topics` method.") def segment_topics(self): return self.measure.seg(self.topics) diff --git a/gensim/models/hdpmodel.py b/gensim/models/hdpmodel.py index 6937d928d4..46995549f8 100755 --- a/gensim/models/hdpmodel.py +++ b/gensim/models/hdpmodel.py @@ -36,13 +36,14 @@ import logging import time import warnings + import numpy as np from scipy.special import gammaln, psi # gamma function utils +from six.moves import xrange from gensim import interfaces, utils, matutils from gensim.matutils import dirichlet_expectation from gensim.models import basemodel, ldamodel -from six.moves import xrange logger = logging.getLogger(__name__) @@ -456,6 +457,15 @@ def show_topic(self, topic_id, topn=20, log=False, formatted=False, num_words=No hdp_formatter = HdpTopicFormatter(self.id2word, betas) return hdp_formatter.show_topic(topic_id, topn, log, formatted) + def get_topics(self): + """ + Returns: + np.ndarray: `num_topics` x `vocabulary_size` array of floats which represents + the term topic matrix learned during inference. + """ + topics = self.m_lambda + self.m_eta + return topics / topics.sum(axis=1)[:, None] + def show_topics(self, num_topics=20, num_words=20, log=False, formatted=True): """ Print the `num_words` most probable words for `num_topics` number of topics. @@ -642,7 +652,7 @@ def show_topic(self, topic_id, topn=20, log=False, formatted=False, num_words= N logger.info(topic) else: topic = (topic_id, topic_terms) - + # we only return the topic_terms return topic[1] diff --git a/gensim/models/ldamodel.py b/gensim/models/ldamodel.py index 7b137284cd..200a6cc55e 100755 --- a/gensim/models/ldamodel.py +++ b/gensim/models/ldamodel.py @@ -31,21 +31,20 @@ import logging -import numpy as np import numbers -from random import sample import os +from random import sample -from gensim import interfaces, utils, matutils -from gensim.matutils import dirichlet_expectation -from gensim.models import basemodel -from gensim.matutils import kullback_leibler, hellinger, jaccard_distance - -from itertools import chain +import numpy as np +import six from scipy.special import gammaln, psi # gamma function utils from scipy.special import polygamma from six.moves import xrange -import six + +from gensim import interfaces, utils, matutils +from gensim.matutils import dirichlet_expectation +from gensim.matutils import kullback_leibler, hellinger, jaccard_distance +from gensim.models import basemodel, CoherenceModel # log(sum(exp(x))) that tries to avoid overflow try: @@ -718,10 +717,18 @@ def bound(self, corpus, gamma=None, subsample_ratio=1.0): Estimate the variational bound of documents from `corpus`: E_q[log p(corpus)] - E_q[log q(corpus)] - `gamma` are the variational parameters on topic weights for each `corpus` - document (=2d matrix=what comes out of `inference()`). - If not supplied, will be inferred from the model. - + Args: + corpus: documents to infer variational bounds from. + gamma: the variational parameters on topic weights for each `corpus` + document (=2d matrix=what comes out of `inference()`). + If not supplied, will be inferred from the model. + subsample_ratio (float): If `corpus` is a sample of the whole corpus, + pass this to inform on what proportion of the corpus it represents. + This is used as a multiplicative factor to scale the likelihood + appropriately. + + Returns: + The variational bound score calculated. """ score = 0.0 _lambda = self.state.get_lambda() @@ -763,18 +770,18 @@ def bound(self, corpus, gamma=None, subsample_ratio=1.0): def show_topics(self, num_topics=10, num_words=10, log=False, formatted=True): """ - For `num_topics` number of topics, return `num_words` most significant words - (10 words per topic, by default). - - The topics are returned as a list -- a list of strings if `formatted` is - True, or a list of `(word, probability)` 2-tuples if False. - - If `log` is True, also output this result to log. - - Unlike LSA, there is no natural ordering between the topics in LDA. - The returned `num_topics <= self.num_topics` subset of all topics is therefore - arbitrary and may change between two LDA training runs. - + Args: + num_topics (int): show results for first `num_topics` topics. + Unlike LSA, there is no natural ordering between the topics in LDA. + The returned `num_topics <= self.num_topics` subset of all topics is + therefore arbitrary and may change between two LDA training runs. + num_words (int): include top `num_words` with highest probabilities in topic. + log (bool): If True, log output in addition to returning it. + formatted (bool): If True, format topics as strings, otherwise return them as + `(word, probability) 2-tuples. + Returns: + list: `num_words` most significant words for `num_topics` number of topics + (10 words for top 10 topics, by default). """ if num_topics < 0 or num_topics >= self.num_topics: num_topics = self.num_topics @@ -807,99 +814,86 @@ def show_topics(self, num_topics=10, num_words=10, log=False, formatted=True): def show_topic(self, topicid, topn=10): """ - Return a list of `(word, probability)` 2-tuples for the most probable - words in topic `topicid`. - - Only return 2-tuples for the topn most probable words (ignore the rest). + Args: + topn (int): Only return 2-tuples for the topn most probable words + (ignore the rest). + Returns: + list: of `(word, probability)` 2-tuples for the most probable + words in topic `topicid`. """ return [(self.id2word[id], value) for id, value in self.get_topic_terms(topicid, topn)] - def get_topic_terms(self, topicid, topn=10): + def get_topics(self): """ - Return a list of `(word_id, probability)` 2-tuples for the most - probable words in topic `topicid`. + Returns: + np.ndarray: `num_topics` x `vocabulary_size` array of floats which represents + the term topic matrix learned during inference. + """ + topics = self.state.get_lambda() + return topics / topics.sum(axis=1)[:, None] - Only return 2-tuples for the topn most probable words (ignore the rest). + def get_topic_terms(self, topicid, topn=10): + """ + Args: + topn (int): Only return 2-tuples for the topn most probable words + (ignore the rest). + Returns: + list: `(word_id, probability)` 2-tuples for the most probable words + in topic with id `topicid`. """ - topic = self.state.get_lambda()[topicid] + topic = self.get_topics()[topicid] topic = topic / topic.sum() # normalize to probability distribution bestn = matutils.argsort(topic, topn, reverse=True) return [(id, topic[id]) for id in bestn] - def top_topics(self, corpus, num_words=20): - """ - Calculate the Umass topic coherence for each topic. Algorithm from - **Mimno, Wallach, Talley, Leenders, McCallum: Optimizing Semantic Coherence in Topic Models, CEMNLP 2011.** + def top_topics(self, corpus=None, texts=None, dictionary=None, window_size=None, + coherence='u_mass', topn=20, processes=-1): """ - is_corpus, corpus = utils.is_corpus(corpus) - if not is_corpus: - logger.warning("LdaModel.top_topics() called with an empty corpus") - return + Calculate the coherence for each topic; default is Umass coherence. - topics = [] - str_topics = [] - for topic in self.state.get_lambda(): - topic = topic / topic.sum() # normalize to probability distribution - bestn = matutils.argsort(topic, topn=num_words, reverse=True) - topics.append(bestn) - beststr = [(topic[id], self.id2word[id]) for id in bestn] - str_topics.append(beststr) - - # top_ids are limited to every topics top words. should not exceed the - # vocabulary size. - top_ids = set(chain.from_iterable(topics)) - - # create a document occurence sparse matrix for each word - doc_word_list = {} - for id in top_ids: - id_list = set() - for n, document in enumerate(corpus): - if id in frozenset(x[0] for x in document): - id_list.add(n) - - doc_word_list[id] = id_list - - coherence_scores = [] - for t, top_words in enumerate(topics): - # Calculate each coherence score C(t, top_words) - coherence = 0.0 - # Sum of top words m=2..M - for m in top_words[1:]: - # m_docs is v_m^(t) - m_docs = doc_word_list[m] - m_index = np.where(top_words == m)[0][0] - - # Sum of top words l=1..m - # i.e., all words ranked higher than the current word m - for l in top_words[:m_index]: - # l_docs is v_l^(t) - l_docs = doc_word_list[l] - - # make sure this word appears in some documents. - if len(l_docs) > 0: - # co_doc_frequency is D(v_m^(t), v_l^(t)) - co_doc_frequency = len(m_docs.intersection(l_docs)) - - # add to the coherence sum for these two words m, l - coherence += np.log((co_doc_frequency + 1.0) / len(l_docs)) - - coherence_scores.append((str_topics[t], coherence)) - - top_topics = sorted(coherence_scores, key=lambda t: t[1], reverse=True) - return top_topics - - def get_document_topics(self, bow, minimum_probability=None, minimum_phi_value=None, per_word_topics=False): + See the :class:`gensim.models.CoherenceModel` constructor for more info on the + parameters and the different coherence metrics. + + Returns: + list: tuples with `(topic_repr, coherence_score)`, where `topic_repr` is a list + of representations of the `topn` terms for the topic. The terms are represented + as tuples of `(membership_in_topic, token)`. The `coherence_score` is a float. """ - Return topic distribution for the given document `bow`, as a list of - (topic_id, topic_probability) 2-tuples. + cm = CoherenceModel( + model=self, corpus=corpus, texts=texts, dictionary=dictionary, + window_size=window_size, coherence=coherence, topn=topn, + processes=processes) + coherence_scores = cm.get_coherence_per_topic() - Ignore topics with very low probability (below `minimum_probability`). + str_topics = [] + for topic in self.get_topics(): # topic = array of vocab_size floats, one per term + bestn = matutils.argsort(topic, topn=topn, reverse=True) # top terms for topic + beststr = [(topic[_id], self.id2word[_id]) for _id in bestn] # membership, token + str_topics.append(beststr) # list of topn (float membership, token) tuples - If per_word_topics is True, it also returns a list of topics, sorted in descending order of most likely topics for that word. - It also returns a list of word_ids and each words corresponding topics' phi_values, multiplied by feature length (i.e, word count) + scored_topics = zip(str_topics, coherence_scores) + return sorted(scored_topics, key=lambda tup: tup[1], reverse=True) + def get_document_topics(self, bow, minimum_probability=None, minimum_phi_value=None, + per_word_topics=False): + """ + Args: + bow (list): Bag-of-words representation of the document to get topics for. + minimum_probability (float): Ignore topics with probability below this value + (None by default). If set to None, a value of 1e-8 is used to prevent 0s. + per_word_topics (bool): If True, also returns a list of topics, sorted in + descending order of most likely topics for that word. It also returns a list + of word_ids and each words corresponding topics' phi_values, multiplied by + feature length (i.e, word count). + minimum_phi_value (float): if `per_word_topics` is True, this represents a lower + bound on the term probabilities that are included (None by default). If set + to None, a value of 1e-8 is used to prevent 0s. + + Returns: + topic distribution for the given document `bow`, as a list of + `(topic_id, topic_probability)` 2-tuples. """ if minimum_probability is None: minimum_probability = self.minimum_probability @@ -929,32 +923,38 @@ def get_document_topics(self, bow, minimum_probability=None, minimum_phi_value=N if not per_word_topics: return document_topics - else: - word_topic = [] # contains word and corresponding topic - word_phi = [] # contains word and phi values - for word_type, weight in bow: - phi_values = [] # contains (phi_value, topic) pairing to later be sorted - phi_topic = [] # contains topic and corresponding phi value to be returned 'raw' to user - for topic_id in range(0, self.num_topics): - if phis[topic_id][word_type] >= minimum_phi_value: - # appends phi values for each topic for that word - # these phi values are scaled by feature length - phi_values.append((phis[topic_id][word_type], topic_id)) - phi_topic.append((topic_id, phis[topic_id][word_type])) - - # list with ({word_id => [(topic_0, phi_value), (topic_1, phi_value) ...]). - word_phi.append((word_type, phi_topic)) - # sorts the topics based on most likely topic - # returns a list like ({word_id => [topic_id_most_probable, topic_id_second_most_probable, ...]). - sorted_phi_values = sorted(phi_values, reverse=True) - topics_sorted = [x[1] for x in sorted_phi_values] - word_topic.append((word_type, topics_sorted)) - return (document_topics, word_topic, word_phi) # returns 2-tuple + + word_topic = [] # contains word and corresponding topic + word_phi = [] # contains word and phi values + for word_type, weight in bow: + phi_values = [] # contains (phi_value, topic) pairing to later be sorted + phi_topic = [] # contains topic and corresponding phi value to be returned 'raw' to user + for topic_id in range(0, self.num_topics): + if phis[topic_id][word_type] >= minimum_phi_value: + # appends phi values for each topic for that word + # these phi values are scaled by feature length + phi_values.append((phis[topic_id][word_type], topic_id)) + phi_topic.append((topic_id, phis[topic_id][word_type])) + + # list with ({word_id => [(topic_0, phi_value), (topic_1, phi_value) ...]). + word_phi.append((word_type, phi_topic)) + # sorts the topics based on most likely topic + # returns a list like ({word_id => [topic_id_most_probable, topic_id_second_most_probable, ...]). + sorted_phi_values = sorted(phi_values, reverse=True) + topics_sorted = [x[1] for x in sorted_phi_values] + word_topic.append((word_type, topics_sorted)) + + return document_topics, word_topic, word_phi # returns 2-tuple def get_term_topics(self, word_id, minimum_probability=None): """ - Returns most likely topics for a particular word in vocab. - + Args: + word_id (int): ID of the word to get topic probabilities for. + minimum_probability (float): Only include topic probabilities above this + value (None by default). If set to None, use 1e-8 to prevent including 0s. + Returns: + list: The most likely topics for the given word. Each topic is represented + as a tuple of `(topic_id, term_probability)`. """ if minimum_probability is None: minimum_probability = self.minimum_probability @@ -1013,7 +1013,7 @@ def diff(self, other, distance="kullback_leibler", num_words=100, n_ann_terms=10 raise ValueError("The parameter `other` must be of type `{}`".format(self.__name__)) distance_func = distances[distance] - d1, d2 = self.state.get_lambda(), other.state.get_lambda() + d1, d2 = self.get_topics(), other.get_topics() t1_size, t2_size = d1.shape[0], d2.shape[0] annotation_terms = None @@ -1061,11 +1061,13 @@ def diff(self, other, distance="kullback_leibler", num_words=100, n_ann_terms=10 def __getitem__(self, bow, eps=None): """ - Return topic distribution for the given document `bow`, as a list of - (topic_id, topic_probability) 2-tuples. - - Ignore topics with very low probability (below `eps`). + Args: + bow (list): Bag-of-words representation of a document. + eps (float): Ignore topics with probability below `eps`. + Returns: + topic distribution for the given document `bow`, as a list of + `(topic_id, topic_probability)` 2-tuples. """ return self.get_document_topics(bow, eps, self.minimum_phi_value, self.per_word_topics) diff --git a/gensim/models/lsimodel.py b/gensim/models/lsimodel.py index bc0cc4fb0f..8e326fdc6c 100644 --- a/gensim/models/lsimodel.py +++ b/gensim/models/lsimodel.py @@ -57,13 +57,11 @@ import scipy.linalg import scipy.sparse from scipy.sparse import sparsetools - -from gensim import interfaces, matutils, utils -from gensim.models import basemodel - from six import iterkeys from six.moves import xrange +from gensim import interfaces, matutils, utils +from gensim.models import basemodel logger = logging.getLogger(__name__) @@ -470,6 +468,26 @@ def __getitem__(self, bow, scaled=False, chunksize=512): result = matutils.Dense2Corpus(topic_dist) return result + def get_topics(self): + """ + Returns: + np.ndarray: `num_topics` x `vocabulary_size` array of floats which represents + the term topic matrix learned during inference. + + Note: + The number of topics can actually be smaller than `self.num_topics`, + if there were not enough factors (real rank of input matrix smaller than + `self.num_topics`). + """ + projections = self.projection.u.T + num_topics = len(projections) + topics = [] + for i in range(num_topics): + c = np.asarray(projections[i, :]).flatten() + norm = np.sqrt(np.sum(np.dot(c, c))) + topics.append(1.0 * c / norm) + return np.array(topics) + def show_topic(self, topicno, topn=10): """ Return a specified topic (=left singular vector), 0 <= `topicno` < `self.num_topics`, diff --git a/gensim/models/wrappers/ldamallet.py b/gensim/models/wrappers/ldamallet.py index 5276b035f1..ab4cdc4adf 100644 --- a/gensim/models/wrappers/ldamallet.py +++ b/gensim/models/wrappers/ldamallet.py @@ -30,22 +30,19 @@ import logging +import os import random import tempfile -import os - -import numpy - import xml.etree.ElementTree as et import zipfile -from six import iteritems +import numpy from smart_open import smart_open from gensim import utils, matutils -from gensim.utils import check_output, revdict -from gensim.models.ldamodel import LdaModel from gensim.models import basemodel +from gensim.models.ldamodel import LdaModel +from gensim.utils import check_output, revdict logger = logging.getLogger(__name__) @@ -208,11 +205,21 @@ def load_word_topics(self): def load_document_topics(self): """ - Return an iterator over the topic distribution of training corpus, by reading - the doctopics.txt generated during training. + Returns: + An iterator over the topic distribution of training corpus, by reading + the doctopics.txt generated during training. """ return self.read_doctopics(self.fdoctopics()) + def get_topics(self): + """ + Returns: + np.ndarray: `num_topics` x `vocabulary_size` array of floats which represents + the term topic matrix learned during inference. + """ + topics = self.word_topics + return topics / topics.sum(axis=1)[:, None] + def show_topics(self, num_topics=10, num_words=10, log=False, formatted=True): """ Print the `num_words` most probable words for `num_topics` number of topics. diff --git a/gensim/models/wrappers/ldavowpalwabbit.py b/gensim/models/wrappers/ldavowpalwabbit.py index 6d6ae9e275..8fd0582bae 100644 --- a/gensim/models/wrappers/ldavowpalwabbit.py +++ b/gensim/models/wrappers/ldavowpalwabbit.py @@ -53,15 +53,15 @@ .. [2] http://www.cs.princeton.edu/~mdhoffma/ """ -from __future__ import unicode_literals -from __future__ import print_function from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals -import os import logging -import tempfile +import os import shutil import subprocess +import tempfile import numpy @@ -235,6 +235,15 @@ def log_perplexity(self, chunk): corpus_words) return bound + def get_topics(self): + """ + Returns: + np.ndarray: `num_topics` x `vocabulary_size` array of floats which represents + the term topic matrix learned during inference. + """ + topics = self._get_topics() + return topics / topics.sum(axis=1)[:, None] + def print_topics(self, num_topics=10, num_words=10): return self.show_topics(num_topics, num_words, log=True) diff --git a/gensim/test/basetests.py b/gensim/test/basetests.py index 4032f6ff21..1d22d8b1a8 100644 --- a/gensim/test/basetests.py +++ b/gensim/test/basetests.py @@ -8,8 +8,9 @@ Automated tests for checking transformation algorithms (the models package). """ -import six import numpy as np +import six + class TestBaseTopicModel(object): def testPrintTopic(self): @@ -41,3 +42,12 @@ def testShowTopics(self): for k, v in topic: self.assertTrue(isinstance(k, six.string_types)) self.assertTrue(isinstance(v, (np.floating, float))) + + def testGetTopics(self): + topics = self.model.get_topics() + vocab_size = len(self.model.id2word) + for topic in topics: + self.assertTrue(isinstance(topic, np.ndarray)) + self.assertEqual(topic.dtype, np.float64) + self.assertEqual(vocab_size, topic.shape[0]) + self.assertAlmostEqual(np.sum(topic), 1.0, 5) diff --git a/gensim/test/test_lsimodel.py b/gensim/test/test_lsimodel.py index a703c3a923..012ac6b8f5 100644 --- a/gensim/test/test_lsimodel.py +++ b/gensim/test/test_lsimodel.py @@ -10,21 +10,19 @@ import logging -import unittest import os import os.path import tempfile +import unittest -import six import numpy as np import scipy.linalg +from gensim import matutils from gensim.corpora import mmcorpus, Dictionary from gensim.models import lsimodel -from gensim import matutils from gensim.test import basetests - module_path = os.path.dirname(__file__) # needed because sample data files are located in the same folder @@ -180,6 +178,16 @@ def testDocsProcessed(self): self.assertEqual(self.model.docs_processed, 9) self.assertEqual(self.model.docs_processed, self.corpus.num_docs) + def testGetTopics(self): + topics = self.model.get_topics() + vocab_size = len(self.model.id2word) + for topic in topics: + self.assertTrue(isinstance(topic, np.ndarray)) + self.assertEqual(topic.dtype, np.float64) + self.assertEqual(vocab_size, topic.shape[0]) + # LSI topics are not probability distributions + # self.assertAlmostEqual(np.sum(topic), 1.0, 5) + # endclass TestLsiModel