From a07a945802d3fe29579313e6840fd621ab8abdc3 Mon Sep 17 00:00:00 2001
From: bab2min
Date: Mon, 18 Jul 2022 01:13:29 +0900
Subject: [PATCH] fixed reproducibility issue of DMR & GDMR

---
 src/TopicModel/DMRModel.hpp |  2 +-
 test/unit_test.py           | 26 ++++++++++++++++++++++++++
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/src/TopicModel/DMRModel.hpp b/src/TopicModel/DMRModel.hpp
index a5696b6..cd0bcc7 100644
--- a/src/TopicModel/DMRModel.hpp
+++ b/src/TopicModel/DMRModel.hpp
@@ -86,7 +86,7 @@ namespace tomoto
 			const size_t chStride = pool.getNumWorkers() * 8;
 			for (size_t ch = 0; ch < chStride; ++ch)
 			{
-				res.emplace_back(pool.enqueue([&](size_t threadId)
+				res.emplace_back(pool.enqueue([&, ch](size_t threadId)
 				{
 					auto& tmpK = localData[threadId].tmpK;
 					if (!tmpK.size()) tmpK.resize(this->K);
diff --git a/test/unit_test.py b/test/unit_test.py
index 4ff1dc9..71aa672 100644
--- a/test/unit_test.py
+++ b/test/unit_test.py
@@ -135,6 +135,28 @@ def train0_without_optim(cls, inputFile, mdFields, f, kargs, ps):
     mdl.train(2000, parallel=ps)
     mdl.summary(file=sys.stderr)
 
+def reproducibility(cls, inputFile, mdFields, f, kargs, ps):
+    print('Test reproducibility')
+    tw = 0
+    results = []
+    for _ in range(3):
+        print('Initialize model %s with TW=%s ...' % (str(cls), ['one', 'idf', 'pmi'][tw]))
+        mdl = cls(tw=tw, min_df=2, rm_top=2, seed=42, **kargs)
+        print('Adding docs...')
+        for n, line in enumerate(open(inputFile, encoding='utf-8')):
+            ch = line.strip().split()
+            if len(ch) < mdFields + 1: continue
+            if mdFields: mdl.add_doc(ch[mdFields:], f(ch[:mdFields]))
+            else: mdl.add_doc(ch)
+        mdl.train(1000, workers=1)
+        if isinstance(mdl, tp.DTModel):
+            results.append([mdl.get_topic_words(k, timepoint=0) for k in range(mdl.k)])
+        else:
+            results.append([mdl.get_topic_words(k) for k in range(mdl.k)])
+
+    assert results[0] == results[1]
+    assert results[0] == results[2]
+
 def save_and_load(cls, inputFile, mdFields, f, kargs, ps):
     print('Test save & load')
     tw = 0
@@ -603,6 +625,10 @@ def test_purge_dead_topics():
             copy_train, uninit_doc,
         ]:
             locals()['test_{}_{}_{}'.format(model_case[0].__name__, func.__name__, ps.name)] = (lambda f, mc, ps: lambda: f(*(mc + (ps,))))(func, model_case[:-1], ps)
+
+    func = reproducibility
+    ps = tp.ParallelScheme.NONE
+    locals()['test_{}_{}_{}'.format(model_case[0].__name__, func.__name__, ps.name)] = (lambda f, mc, ps: lambda: f(*(mc + (ps,))))(func, model_case[:-1], ps)
 
 for model_case in model_asym_cases:
     pss = model_case[5]