Skip to content

Commit

Permalink
fixed reproducibility issue of DMR & GDMR
Browse files Browse the repository at this point in the history
  • Loading branch information
bab2min committed Jul 17, 2022
1 parent 3d7c849 commit a07a945
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/TopicModel/DMRModel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ namespace tomoto
const size_t chStride = pool.getNumWorkers() * 8;
for (size_t ch = 0; ch < chStride; ++ch)
{
res.emplace_back(pool.enqueue([&](size_t threadId)
res.emplace_back(pool.enqueue([&, ch](size_t threadId)
{
auto& tmpK = localData[threadId].tmpK;
if (!tmpK.size()) tmpK.resize(this->K);
Expand Down
26 changes: 26 additions & 0 deletions test/unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,28 @@ def train0_without_optim(cls, inputFile, mdFields, f, kargs, ps):
mdl.train(2000, parallel=ps)
mdl.summary(file=sys.stderr)

def reproducibility(cls, inputFile, mdFields, f, kargs, ps):
print('Test reproducibility')
tw = 0
results = []
for _ in range(3):
print('Initialize model %s with TW=%s ...' % (str(cls), ['one', 'idf', 'pmi'][tw]))
mdl = cls(tw=tw, min_df=2, rm_top=2, seed=42, **kargs)
print('Adding docs...')
for n, line in enumerate(open(inputFile, encoding='utf-8')):
ch = line.strip().split()
if len(ch) < mdFields + 1: continue
if mdFields: mdl.add_doc(ch[mdFields:], f(ch[:mdFields]))
else: mdl.add_doc(ch)
mdl.train(1000, workers=1)
if isinstance(mdl, tp.DTModel):
results.append([mdl.get_topic_words(k, timepoint=0) for k in range(mdl.k)])
else:
results.append([mdl.get_topic_words(k) for k in range(mdl.k)])

assert results[0] == results[1]
assert results[0] == results[2]

def save_and_load(cls, inputFile, mdFields, f, kargs, ps):
print('Test save & load')
tw = 0
Expand Down Expand Up @@ -603,6 +625,10 @@ def test_purge_dead_topics():
copy_train, uninit_doc,
]:
locals()['test_{}_{}_{}'.format(model_case[0].__name__, func.__name__, ps.name)] = (lambda f, mc, ps: lambda: f(*(mc + (ps,))))(func, model_case[:-1], ps)

func = reproducibility
ps = tp.ParallelScheme.NONE
locals()['test_{}_{}_{}'.format(model_case[0].__name__, func.__name__, ps.name)] = (lambda f, mc, ps: lambda: f(*(mc + (ps,))))(func, model_case[:-1], ps)

for model_case in model_asym_cases:
pss = model_case[5]
Expand Down

0 comments on commit a07a945

Please sign in to comment.