diff --git a/gensim/models/ldamodel.py b/gensim/models/ldamodel.py
index 0863bcc578..01a4029705 100755
--- a/gensim/models/ldamodel.py
+++ b/gensim/models/ldamodel.py
@@ -542,6 +542,7 @@ def init_dir_prior(self, prior, name):
         name : {'alpha', 'eta'}
             Whether the `prior` is parameterized by the alpha vector (1 parameter per topic)
             or by the eta (1 parameter per unique term in the vocabulary).
+
         """
         if prior is None:
             prior = 'symmetric'
@@ -609,8 +610,8 @@ def sync_state(self, current_Elogbeta=None):
         current_Elogbeta: numpy.ndarray
             Posterior probabilities for each topic, optional.
             If omitted, it will get Elogbeta from state.
-        """
 
+        """
         if current_Elogbeta is None:
             current_Elogbeta = self.state.get_Elogbeta()
         self.expElogbeta = np.exp(current_Elogbeta)
@@ -1201,7 +1202,6 @@ def show_topic(self, topicid, topn=10):
 
     def get_topics(self):
         """Get the term-topic matrix learned during inference.
-
         Returns
         -------
         numpy.ndarray
diff --git a/gensim/test/test_ldamodel.py b/gensim/test/test_ldamodel.py
index a292ed61d0..b809b39754 100644
--- a/gensim/test/test_ldamodel.py
+++ b/gensim/test/test_ldamodel.py
@@ -13,6 +13,7 @@
 import numbers
 import os
 import unittest
+import copy
 
 import numpy as np
 from numpy.testing import assert_allclose
@@ -41,6 +42,25 @@ def setUp(self):
         self.class_ = ldamodel.LdaModel
         self.model = self.class_(corpus, id2word=dictionary, num_topics=2, passes=100)
 
+    def test_sync_state(self):
+        model2 = self.class_(corpus=self.corpus, id2word=dictionary, num_topics=2, passes=1)
+        model2.state = copy.deepcopy(self.model.state)
+        model2.sync_state()
+
+        assert_allclose(self.model.get_term_topics(2), model2.get_term_topics(2), rtol=1e-5)
+        assert_allclose(self.model.get_topics(), model2.get_topics(), rtol=1e-5)
+
+        # properly continues training on the new state
+        self.model.random_state = np.random.RandomState(0)
+        model2.random_state = np.random.RandomState(0)
+        self.model.passes = 1
+        model2.passes = 1
+        self.model.update(self.corpus)
+        model2.update(self.corpus)
+
+        assert_allclose(self.model.get_term_topics(2), model2.get_term_topics(2), rtol=1e-5)
+        assert_allclose(self.model.get_topics(), model2.get_topics(), rtol=1e-5)
+
     def test_transform(self):
         passed = False
         # sometimes, LDA training gets stuck at a local minimum