fixed coherence of DTModel (#164)
bab2min committed Jul 17, 2022
1 parent 6e89172 commit f14a2be
Showing 2 changed files with 48 additions and 11 deletions.
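
For context, a minimal usage sketch of the per-timepoint scoring this commit enables, adapted from the test added below. The sample file path, k, t, and train(100) are illustrative values taken from the test, not a recommended configuration; the input is assumed to be whitespace-tokenized lines that start with a timepoint index.

import tomotopy as tp

# Train a small dynamic topic model: k topics over t timepoints.
# Each line of the sample file is "<timepoint> <token> <token> ...".
mdl = tp.DTModel(k=10, t=13)
for line in open('sample_tp.txt', encoding='utf-8'):
    ch = line.strip().split()
    if len(ch) < 2: continue
    mdl.add_doc(ch[1:], timepoint=int(ch[0]))
mdl.train(100)

# After this fix, Coherence accepts a DTModel directly.
coh = tp.coherence.Coherence(mdl)
print(coh.get_score(topic_id=0, timepoint=0))  # one topic at one timepoint
print(coh.get_score())                         # average over all topics and timepoints
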
13 changes: 13 additions & 0 deletions test/unit_test.py
@@ -521,6 +521,19 @@ def test_coherence():
coherence = tp.coherence.Coherence(corpus=mdl, coherence=coh)
print(coherence.get_score())

def test_coherence_dtm():
mdl = tp.DTModel(k=10, t=13)
for n, line in enumerate(open(curpath + '/sample_tp.txt', encoding='utf-8')):
ch = line.strip().split()
if len(ch) < 2: continue
mdl.add_doc(ch[1:], timepoint=int(ch[0]))
mdl.train(100)
coh = tp.coherence.Coherence(mdl)

print(coh.get_score(topic_id=0, timepoint=0))

print(coh.get_score())

def test_corpus_save_load():
corpus = tp.utils.Corpus()
# data_feeder yields a tuple of (raw string, user data) or a str (raw string)
46 changes: 35 additions & 11 deletions tomotopy/coherence.py
@@ -102,7 +102,12 @@ def __init__(self, corpus, coherence='u_mass', window_size=0, targets=None, top_
import tomotopy as tp
import itertools
self._top_n = top_n
if isinstance(corpus, tp.LDAModel):
if isinstance(corpus, tp.DTModel):
self._topic_model = corpus
if targets is None:
targets = itertools.chain(*((w for w, _ in corpus.get_topic_words(k, t, top_n=top_n)) for k in range(corpus.k) for t in range(corpus.num_timepoints)))
corpus = corpus.docs
elif isinstance(corpus, tp.LDAModel):
self._topic_model = corpus
if targets is None:
targets = itertools.chain(*((w for w, _ in corpus.get_topic_words(k, top_n=top_n)) for k in range(corpus.k)))
@@ -131,7 +136,7 @@ def __init__(self, corpus, coherence='u_mass', window_size=0, targets=None, top_

super().__init__(corpus, pe=pe, seg=seg, cm=cm, im=im, window_size=window_size or w, targets=targets, eps=eps, gamma=gamma)

def get_score(self, words=None, topic_id=None):
def get_score(self, words=None, topic_id=None, timepoint=None):
'''Calculate the coherence score for given `words` or `topic_id`
Parameters
@@ -144,18 +149,37 @@ def get_score(self, words=None, topic_id=None):
An id of the topic from which words are extracted.
This parameter is valid when `tomotopy.coherence.Coherence` was initialized using `corpus` as `tomotopy.LDAModel` or its descendants.
If this is omitted, the average score of all topics is returned.
timepoint : int
A timepoint of the topic from which words are extracted. (Only for `DTModel`)
'''
import tomotopy as tp
if words is None and self._topic_model is None:
raise ValueError("`words` must be provided if `Coherence` is not bound to an instance of topic model.")
if words is None and topic_id is None:
c = []
for k in range(self._topic_model.k):
c.append(super().get_score((w for w, _ in self._topic_model.get_topic_words(k, top_n=self._top_n))))
return sum(c) / len(c)

if words is None:
words = (w for w, _ in self._topic_model.get_topic_words(topic_id, top_n=self._top_n))
return super().get_score(words)
if isinstance(self._topic_model, tp.DTModel):
if int(topic_id is None) + int(timepoint is None) == 1:
raise ValueError("Both `topic_id` and `timepoint` should be given.")
if words is None and topic_id is None:
c = []
for k in range(self._topic_model.k):
for t in range(self._topic_model.num_timepoints):
c.append(super().get_score((w for w, _ in self._topic_model.get_topic_words(k, timepoint=t, top_n=self._top_n))))
return sum(c) / len(c)

if words is None:
words = (w for w, _ in self._topic_model.get_topic_words(topic_id, timepoint=timepoint, top_n=self._top_n))
return super().get_score(words)
else:
if timepoint is not None:
raise ValueError("`timepoint` is valid for only `DTModel`.")
if words is None and topic_id is None:
c = []
for k in range(self._topic_model.k):
c.append(super().get_score((w for w, _ in self._topic_model.get_topic_words(k, top_n=self._top_n))))
return sum(c) / len(c)

if words is None:
words = (w for w, _ in self._topic_model.get_topic_words(topic_id, top_n=self._top_n))
return super().get_score(words)

import os
if os.environ.get('TOMOTOPY_LANG') == 'kr':
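
A note on the new argument validation in get_score (a sketch of expected behavior under the diff above, not part of the commit): for a DTModel-bound Coherence, `topic_id` and `timepoint` must be supplied together, while `timepoint` is rejected for non-dynamic models.

# Assuming `coh` is bound to a DTModel as in the test above:
coh.get_score(topic_id=0, timepoint=0)   # OK: one topic at one timepoint
coh.get_score()                          # OK: average over every (topic, timepoint) pair
try:
    coh.get_score(topic_id=0)            # timepoint missing
except ValueError as e:
    print(e)                             # "Both `topic_id` and `timepoint` should be given."

# For an LDAModel-bound Coherence, passing `timepoint` raises
# ValueError("`timepoint` is valid for only `DTModel`.")
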
