added test for sync_state (#2959)
* added test for sync_state

* fix for ldamulticore spec
sezanzeb authored Feb 12, 2021
1 parent cec8974 commit 0502284
Showing 2 changed files with 22 additions and 2 deletions.
4 changes: 2 additions & 2 deletions gensim/models/ldamodel.py
@@ -542,6 +542,7 @@ def init_dir_prior(self, prior, name):
        name : {'alpha', 'eta'}
            Whether the `prior` is parameterized by the alpha vector (1 parameter per topic)
            or by the eta (1 parameter per unique term in the vocabulary).
        """
        if prior is None:
            prior = 'symmetric'
@@ -609,8 +610,8 @@ def sync_state(self, current_Elogbeta=None):
        current_Elogbeta: numpy.ndarray
            Posterior probabilities for each topic, optional.
            If omitted, it will get Elogbeta from state.
        """
        """
        if current_Elogbeta is None:
            current_Elogbeta = self.state.get_Elogbeta()
        self.expElogbeta = np.exp(current_Elogbeta)
@@ -1201,7 +1202,6 @@ def show_topic(self, topicid, topn=10):
    def get_topics(self):
        """Get the term-topic matrix learned during inference.
        Returns
        -------
        numpy.ndarray
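For context on what the new test exercises: sync_state refreshes the model's cached expElogbeta from the inner LdaState (or from an explicitly passed current_Elogbeta), so the cache stays equal to the exponential of the state's expected log topic-word distribution. Below is a minimal sketch of that invariant; the toy documents and variable names are illustrative assumptions, not code from the repository.

import numpy as np
from gensim.corpora import Dictionary
from gensim.models import LdaModel

# Tiny illustrative corpus (hypothetical data, not from the test suite).
texts = [
    ["human", "computer", "interface"],
    ["graph", "trees", "minors"],
    ["system", "user", "interface"],
]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]

lda = LdaModel(corpus=corpus, id2word=dictionary, num_topics=2, passes=10)

# sync_state recomputes expElogbeta from the sufficient statistics held in
# lda.state, so afterwards the cached matrix matches exp(E[log beta]).
lda.sync_state()
np.testing.assert_allclose(lda.expElogbeta, np.exp(lda.state.get_Elogbeta()), rtol=1e-5)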
20 changes: 20 additions & 0 deletions gensim/test/test_ldamodel.py
@@ -13,6 +13,7 @@
import numbers
import os
import unittest
import copy

import numpy as np
from numpy.testing import assert_allclose
@@ -41,6 +42,25 @@ def setUp(self):
        self.class_ = ldamodel.LdaModel
        self.model = self.class_(corpus, id2word=dictionary, num_topics=2, passes=100)

    def test_sync_state(self):
        model2 = self.class_(corpus=self.corpus, id2word=dictionary, num_topics=2, passes=1)
        model2.state = copy.deepcopy(self.model.state)
        model2.sync_state()

        assert_allclose(self.model.get_term_topics(2), model2.get_term_topics(2), rtol=1e-5)
        assert_allclose(self.model.get_topics(), model2.get_topics(), rtol=1e-5)

        # properly continues training on the new state
        self.model.random_state = np.random.RandomState(0)
        model2.random_state = np.random.RandomState(0)
        self.model.passes = 1
        model2.passes = 1
        self.model.update(self.corpus)
        model2.update(self.corpus)

        assert_allclose(self.model.get_term_topics(2), model2.get_term_topics(2), rtol=1e-5)
        assert_allclose(self.model.get_topics(), model2.get_topics(), rtol=1e-5)

    def test_transform(self):
        passed = False
        # sometimes, LDA training gets stuck at a local minimum
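As a usage note: because the new test constructs both models through self.class_, it should also be exercised by any multicore test class that reuses TestLdaModel, which is presumably what the "fix for ldamulticore spec" bullet in the commit message refers to. To run only this test locally, pytest's keyword filter can typically be used, e.g. pytest gensim/test/test_ldamodel.py -k test_sync_state; the exact invocation depends on the local environment and is an assumption, not a command documented in the repository.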
