From f14a2be2422779df8a9ed20e274392920e1bfb4b Mon Sep 17 00:00:00 2001
From: bab2min
Date: Sun, 17 Jul 2022 20:55:09 +0900
Subject: [PATCH] fixed coherence of DTModel (#164)

---

Notes (reviewer commentary placed after the three-dash line, so it is not
part of the commit message):

    What this patch does
    * Coherence.__init__: handles `tp.DTModel` before `tp.LDAModel` and
      builds the default `targets` from the top words of every
      (topic, timepoint) pair via `get_topic_words(k, t, top_n=top_n)`.
    * Coherence.get_score: adds `timepoint=None`. For a `DTModel`,
      `topic_id` and `timepoint` must be given together (ValueError
      otherwise); when both are omitted and `words` is None the score is
      averaged over all k * num_timepoints topic/timepoint pairs. For any
      other model a non-None `timepoint` raises ValueError.
    * test/unit_test.py: adds `test_coherence_dtm`.

    Review observations (payload intentionally left unchanged)
    * `test_coherence_dtm` iterates `open(...)` without closing the file and
      never uses `n` from `enumerate`; a `with open(...) as f:` loop would
      avoid the leaked handle.
    * The DTModel and non-DTModel branches of `get_score` repeat the same
      "average over topics / score one topic" logic and could share a
      helper.

    NOTE(review): this file was rebuilt from a whitespace-collapsed copy.
    Indentation of context lines, whitespace-only blank lines and trailing
    spaces are reconstructed, not original - verify against the target tree,
    or apply with `git apply --ignore-whitespace`.

 test/unit_test.py     | 13 ++++++++++++
 tomotopy/coherence.py | 46 ++++++++++++++++++++++++++++++++-----------
 2 files changed, 48 insertions(+), 11 deletions(-)

diff --git a/test/unit_test.py b/test/unit_test.py
index 3b52a9e..4ff1dc9 100644
--- a/test/unit_test.py
+++ b/test/unit_test.py
@@ -521,6 +521,19 @@ def test_coherence():
         coherence = tp.coherence.Coherence(corpus=mdl, coherence=coh)
         print(coherence.get_score())
 
+def test_coherence_dtm():
+    mdl = tp.DTModel(k=10, t=13)
+    for n, line in enumerate(open(curpath + '/sample_tp.txt', encoding='utf-8')):
+        ch = line.strip().split()
+        if len(ch) < 2: continue
+        mdl.add_doc(ch[1:], timepoint=int(ch[0]))
+    mdl.train(100)
+    coh = tp.coherence.Coherence(mdl)
+
+    print(coh.get_score(topic_id=0, timepoint=0))
+
+    print(coh.get_score())
+
 def test_corpus_save_load():
     corpus = tp.utils.Corpus()
     # data_feeder yields a tuple of (raw string, user data) or a str (raw string)
diff --git a/tomotopy/coherence.py b/tomotopy/coherence.py
index d0068b3..d060675 100644
--- a/tomotopy/coherence.py
+++ b/tomotopy/coherence.py
@@ -102,7 +102,12 @@ def __init__(self, corpus, coherence='u_mass', window_size=0, targets=None, top_
         import tomotopy as tp
         import itertools
         self._top_n = top_n
-        if isinstance(corpus, tp.LDAModel):
+        if isinstance(corpus, tp.DTModel):
+            self._topic_model = corpus
+            if targets is None:
+                targets = itertools.chain(*((w for w, _ in corpus.get_topic_words(k, t, top_n=top_n)) for k in range(corpus.k) for t in range(corpus.num_timepoints)))
+            corpus = corpus.docs
+        elif isinstance(corpus, tp.LDAModel):
             self._topic_model = corpus
             if targets is None:
                 targets = itertools.chain(*((w for w, _ in corpus.get_topic_words(k, top_n=top_n)) for k in range(corpus.k)))
@@ -131,7 +136,7 @@ def __init__(self, corpus, coherence='u_mass', window_size=0, targets=None, top_

         super().__init__(corpus, pe=pe, seg=seg, cm=cm, im=im, window_size=window_size or w, targets=targets, eps=eps, gamma=gamma)

-    def get_score(self, words=None, topic_id=None):
+    def get_score(self, words=None, topic_id=None, timepoint=None):
         '''Calculate the coherence score for given `words` or `topic_id`

 Parameters
@@ -144,18 +149,37 @@ def get_score(self, words=None, topic_id=None):
     An id of the topic from which words are extracted.
     This parameter is valid when `tomotopy.coherence.Coherence` was initialized using `corpus` as `tomotopy.LDAModel` or its descendants.
     If this is omitted, the average score of all topics is returned.
+timepoint : int
+    A timepoint of the topic from which words are extracted. (Only for `DTModel`)
         '''
+        import tomotopy as tp
         if words is None and self._topic_model is None:
             raise ValueError("`words` must be provided if `Coherence` is not bound to an instance of topic model.")
-        if words is None and topic_id is None:
-            c = []
-            for k in range(self._topic_model.k):
-                c.append(super().get_score((w for w, _ in self._topic_model.get_topic_words(k, top_n=self._top_n))))
-            return sum(c) / len(c)
-
-        if words is None:
-            words = (w for w, _ in self._topic_model.get_topic_words(topic_id, top_n=self._top_n))
-        return super().get_score(words)
+        if isinstance(self._topic_model, tp.DTModel):
+            if int(topic_id is None) + int(timepoint is None) == 1:
+                raise ValueError("Both `topic_id` and `timepoint` should be given.")
+            if words is None and topic_id is None:
+                c = []
+                for k in range(self._topic_model.k):
+                    for t in range(self._topic_model.num_timepoints):
+                        c.append(super().get_score((w for w, _ in self._topic_model.get_topic_words(k, timepoint=t, top_n=self._top_n))))
+                return sum(c) / len(c)
+
+            if words is None:
+                words = (w for w, _ in self._topic_model.get_topic_words(topic_id, timepoint=timepoint, top_n=self._top_n))
+            return super().get_score(words)
+        else:
+            if timepoint is not None:
+                raise ValueError("`timepoint` is valid for only `DTModel`.")
+            if words is None and topic_id is None:
+                c = []
+                for k in range(self._topic_model.k):
+                    c.append(super().get_score((w for w, _ in self._topic_model.get_topic_words(k, top_n=self._top_n))))
+                return sum(c) / len(c)
+
+            if words is None:
+                words = (w for w, _ in self._topic_model.get_topic_words(topic_id, top_n=self._top_n))
+            return super().get_score(words)

 import os
 if os.environ.get('TOMOTOPY_LANG') == 'kr':