diff --git a/src/TopicModel/CTModel.hpp b/src/TopicModel/CTModel.hpp index dd7b3c6..da2cf62 100644 --- a/src/TopicModel/CTModel.hpp +++ b/src/TopicModel/CTModel.hpp @@ -252,6 +252,7 @@ namespace tomoto std::vector _getTopicsByDoc(const _DocType& doc, bool normalize) const { + if (!doc.numByTopic.size()) return {}; std::vector ret(this->K); Eigen::Map> m{ ret.data(), this->K }; if (normalize) diff --git a/src/TopicModel/DMRModel.hpp b/src/TopicModel/DMRModel.hpp index fd5dd02..a5696b6 100644 --- a/src/TopicModel/DMRModel.hpp +++ b/src/TopicModel/DMRModel.hpp @@ -454,6 +454,7 @@ namespace tomoto std::vector _getTopicsByDoc(const _DocType& doc, bool normalize) const { + if (!doc.numByTopic.size()) return {}; std::vector ret(this->K); auto alphaDoc = getCachedAlpha(doc); Eigen::Map> m{ ret.data(), this->K }; diff --git a/src/TopicModel/HDPModel.hpp b/src/TopicModel/HDPModel.hpp index e3639b5..e7877e5 100644 --- a/src/TopicModel/HDPModel.hpp +++ b/src/TopicModel/HDPModel.hpp @@ -492,6 +492,7 @@ namespace tomoto std::vector _getTopicsByDoc(const _DocType& doc, bool normalize) const { + if (!doc.numByTopic.size()) return {}; std::vector ret(this->K); Eigen::Map> m{ ret.data(), this->K }; if (normalize) @@ -538,8 +539,11 @@ namespace tomoto auto d = lda->_makeFromRawDoc(doc); lda->_addDoc(d); } - - lda->prepare(true, this->minWordCf, this->minWordDf, this->removeTopN); + + lda->realV = this->realV; + lda->realN = this->realN; + lda->weightedN = this->weightedN; + lda->prepare(true, 0, 0, 0, false); auto selectFirst = [&](const std::pair& p) { return std::max(p.first / sum - topicThreshold, 0.f); }; std::discrete_distribution randomTopic{ diff --git a/src/TopicModel/HPAModel.hpp b/src/TopicModel/HPAModel.hpp index a8a8411..5683264 100644 --- a/src/TopicModel/HPAModel.hpp +++ b/src/TopicModel/HPAModel.hpp @@ -542,6 +542,7 @@ namespace tomoto std::vector _getTopicsByDoc(const _DocType& doc, bool normalize) const { + if (!doc.numByTopic.size()) return {}; std::vector ret(1 + 
this->K + K2); Float sum = doc.getSumWordWeight() + this->alphas.sum(); if (!normalize) sum = 1; diff --git a/src/TopicModel/LDACVB0Model.hpp b/src/TopicModel/LDACVB0Model.hpp index f8ce26e..8cec1bc 100644 --- a/src/TopicModel/LDACVB0Model.hpp +++ b/src/TopicModel/LDACVB0Model.hpp @@ -366,9 +366,9 @@ namespace tomoto } } - void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) override + void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0, bool updateStopwords = true) override { - if (initDocs) this->removeStopwords(minWordCnt, minWordDf, removeTopN); + if (initDocs) this->removeStopwords(minWordCnt, minWordDf, removeTopN, updateStopwords); static_cast(this)->updateWeakArray(); static_cast(this)->initGlobalState(initDocs); diff --git a/src/TopicModel/LDAModel.hpp b/src/TopicModel/LDAModel.hpp index ccc9f48..af30431 100644 --- a/src/TopicModel/LDAModel.hpp +++ b/src/TopicModel/LDAModel.hpp @@ -1057,9 +1057,9 @@ namespace tomoto } } - void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) override + void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0, bool updateStopwords = true) override { - if (initDocs) this->removeStopwords(minWordCnt, minWordDf, removeTopN); + if (initDocs && updateStopwords) this->removeStopwords(minWordCnt, minWordDf, removeTopN); static_cast(this)->updateWeakArray(); static_cast(this)->initGlobalState(initDocs); static_cast(this)->prepareWordPriors(); @@ -1116,7 +1116,7 @@ namespace tomoto for (auto& doc : this->docs) doc.updateSumWordWeight(this->realV); } static_cast(this)->prepareShared(); - BaseClass::prepare(initDocs, minWordCnt, minWordDf, removeTopN); + BaseClass::prepare(initDocs, minWordCnt, minWordDf, removeTopN, updateStopwords); } std::vector getCountByTopic() const override @@ -1126,6 +1126,7 @@ namespace tomoto std::vector 
_getTopicsByDoc(const _DocType& doc, bool normalize) const { + if (!doc.numByTopic.size()) return {}; std::vector ret(K); Eigen::Map> m{ ret.data(), K }; if (normalize) diff --git a/src/TopicModel/LLDAModel.hpp b/src/TopicModel/LLDAModel.hpp index 5b4e2b3..d6e5d0a 100644 --- a/src/TopicModel/LLDAModel.hpp +++ b/src/TopicModel/LLDAModel.hpp @@ -176,6 +176,7 @@ namespace tomoto std::vector _getTopicsByDoc(const _DocType& doc, bool normalize) const { + if (!doc.numByTopic.size()) return {}; std::vector ret(this->K); auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast().array(); Eigen::Map> m{ ret.data(), this->K }; diff --git a/src/TopicModel/MGLDAModel.hpp b/src/TopicModel/MGLDAModel.hpp index 27035da..72993ff 100644 --- a/src/TopicModel/MGLDAModel.hpp +++ b/src/TopicModel/MGLDAModel.hpp @@ -529,6 +529,7 @@ namespace tomoto std::vector _getTopicsByDoc(const _DocType& doc, bool normalize) const { + if (!doc.numByTopic.size()) return {}; std::vector ret(this->K + KL); Eigen::Map> m{ ret.data(), this->K + KL }; if (normalize) diff --git a/src/TopicModel/PLDAModel.hpp b/src/TopicModel/PLDAModel.hpp index 5480447..9b981f7 100644 --- a/src/TopicModel/PLDAModel.hpp +++ b/src/TopicModel/PLDAModel.hpp @@ -183,6 +183,7 @@ namespace tomoto std::vector _getTopicsByDoc(const _DocType& doc, bool normalize) const { + if (!doc.numByTopic.size()) return {}; std::vector ret(this->K); auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast().array(); Eigen::Map> m{ ret.data(), this->K }; diff --git a/src/TopicModel/PT.h b/src/TopicModel/PT.h index aaa80d0..1ac32e4 100644 --- a/src/TopicModel/PT.h +++ b/src/TopicModel/PT.h @@ -18,7 +18,7 @@ namespace tomoto struct PTArgs : public LDAArgs { - size_t p = 100; + size_t p = 0; Float lambda = 0.01; }; @@ -30,5 +30,7 @@ namespace tomoto bool scalarRng = false); virtual size_t getP() const = 0; + virtual std::vector getTopicsFromPseudoDoc(const DocumentBase* doc, bool normalize = true) const = 0; + virtual 
std::vector> getTopicsFromPseudoDocSorted(const DocumentBase* doc, size_t topN) const = 0; }; } diff --git a/src/TopicModel/PTModel.hpp b/src/TopicModel/PTModel.hpp index bc113ef..df5c365 100644 --- a/src/TopicModel/PTModel.hpp +++ b/src/TopicModel/PTModel.hpp @@ -266,6 +266,7 @@ namespace tomoto std::vector _getTopicsByDoc(const _DocType& doc, bool normalize) const { + if (doc.Zs.empty()) return {}; std::vector ret(this->K); Eigen::Map> m{ ret.data(), this->K }; m = this->alphas.array(); @@ -280,6 +281,25 @@ namespace tomoto return ret; } + std::vector getTopicsFromPseudoDoc(const DocumentBase* _doc, bool normalize) const override + { + auto& doc = *static_cast(_doc); + if (!doc.numByTopic.size()) return {}; + std::vector ret(this->K); + Eigen::Map> m{ ret.data(), this->K }; + m = doc.numByTopic.array().template cast() + this->alphas.array(); + if (normalize) + { + m /= m.sum(); + } + return ret; + } + + std::vector> getTopicsFromPseudoDocSorted(const DocumentBase* doc, size_t topN) const override + { + return extractTopN(getTopicsFromPseudoDoc(doc, true), topN); + } + void updateDocs() { for (auto& doc : this->docs) diff --git a/src/TopicModel/TopicModel.hpp b/src/TopicModel/TopicModel.hpp index 9c98377..2d7816a 100644 --- a/src/TopicModel/TopicModel.hpp +++ b/src/TopicModel/TopicModel.hpp @@ -254,7 +254,7 @@ namespace tomoto virtual int train(size_t iteration, size_t numWorkers, ParallelScheme ps = ParallelScheme::default_, bool freeze_topics = false) = 0; virtual size_t getGlobalStep() const = 0; - virtual void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) = 0; + virtual void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0, bool updateStopwords = true) = 0; virtual size_t getK() const = 0; virtual std::vector getWidsByTopic(size_t tid, bool normalize = true) const = 0; @@ -605,7 +605,7 @@ namespace tomoto return empty; } - void prepare(bool initDocs = true, 
size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) override + void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0, bool updateStopwords = true) override { auto p = countRealN(); realN = p.first; diff --git a/src/python/py_PT.cpp b/src/python/py_PT.cpp index c4c121c..97083dc 100644 --- a/src/python/py_PT.cpp +++ b/src/python/py_PT.cpp @@ -22,6 +22,8 @@ static int PT_init(TopicModelObject *self, PyObject *args, PyObject *kwargs) [=]() { return "`alpha` must be an instance of `float` or `List[float]` with length `k` (given " + py::repr(objAlpha) + ")"; } ); + if (margs.p == 0) margs.p = margs.k * 10; + tomoto::ITopicModel* inst = tomoto::IPTModel::create((tomoto::TermWeight)tw, margs); if (!inst) throw py::ValueError{ "unknown `tw` value" }; self->inst = inst; @@ -99,3 +101,18 @@ TopicModelTypeObject PT_type = { { PyType_GenericAlloc, PyType_GenericNew, }}; + + +PyObject* Document_getTopicsFromPseudoDoc(DocumentObject* self, size_t topN) +{ + tomoto::IPTModel* mdl = dynamic_cast<tomoto::IPTModel*>(self->corpus->tm->inst); + if (!mdl) throw py::ValueError{ "`from_pseudo_doc` is valid for only `tomotopy.PTModel`." }; + return py::buildPyValue(mdl->getTopicsFromPseudoDocSorted(self->getBoundDoc(), topN)); +} + +PyObject* Document_getTopicDistFromPseudoDoc(DocumentObject* self, bool normalize) +{ + tomoto::IPTModel* mdl = dynamic_cast<tomoto::IPTModel*>(self->corpus->tm->inst); + if (!mdl) throw py::ValueError{ "`from_pseudo_doc` is valid for only `tomotopy.PTModel`." 
}; + return py::buildPyValue(mdl->getTopicsFromPseudoDoc(self->getBoundDoc(), !!normalize)); +} \ No newline at end of file diff --git a/src/python/py_utils.cpp b/src/python/py_utils.cpp index 4d9e7c1..8181c0b 100644 --- a/src/python/py_utils.cpp +++ b/src/python/py_utils.cpp @@ -1094,13 +1094,17 @@ PyObject* DocumentObject::repr(DocumentObject* self) static PyObject* Document_getTopics(DocumentObject* self, PyObject* args, PyObject* kwargs) { size_t topN = 10; - static const char* kwlist[] = { "top_n", nullptr }; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|n", (char**)kwlist, &topN)) return nullptr; + size_t fromPseudoDoc = 0; + static const char* kwlist[] = { "top_n", "from_pseudo_doc", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|np", (char**)kwlist, &topN, &fromPseudoDoc)) return nullptr; return py::handleExc([&]() { if (self->corpus->isIndependent()) throw py::RuntimeError{ "This method can only be called by documents bound to the topic model." }; if (!self->corpus->tm->inst) throw py::RuntimeError{ "inst is null" }; if (!self->corpus->tm->isPrepared) throw py::RuntimeError{ "train() should be called first for calculating the topic distribution" }; +#ifdef TM_PT + if (fromPseudoDoc) return Document_getTopicsFromPseudoDoc(self, topN); +#endif return py::buildPyValue(self->corpus->tm->inst->getTopicsByDocSorted(self->getBoundDoc(), topN)); }); } @@ -1108,13 +1112,17 @@ static PyObject* Document_getTopics(DocumentObject* self, PyObject* args, PyObje static PyObject* Document_getTopicDist(DocumentObject* self, PyObject* args, PyObject* kwargs) { size_t normalize = 1; - static const char* kwlist[] = { "normalize", nullptr }; + size_t fromPseudoDoc = 0; + static const char* kwlist[] = { "normalize", "from_pseudo_doc", nullptr }; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|p", (char**)kwlist, &normalize)) return nullptr; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|pp", (char**)kwlist, &normalize, &fromPseudoDoc)) return nullptr; return py::handleExc([&]() { if (self->corpus->isIndependent()) throw py::RuntimeError{ "This method can only
be called by documents bound to the topic model." }; if (!self->corpus->tm->inst) throw py::RuntimeError{ "inst is null" }; if (!self->corpus->tm->isPrepared) throw py::RuntimeError{ "train() should be called first for calculating the topic distribution" }; +#ifdef TM_PT + if (fromPseudoDoc) return Document_getTopicDistFromPseudoDoc(self, !!normalize); +#endif return py::buildPyValue(self->corpus->tm->inst->getTopicsByDoc(self->getBoundDoc(), !!normalize)); }); } diff --git a/src/python/utils.h b/src/python/utils.h index aaa4af6..901e5fa 100644 --- a/src/python/utils.h +++ b/src/python/utils.h @@ -342,6 +342,10 @@ PyObject* Document_getSubTopicDist(DocumentObject* self, PyObject* args, PyObjec PyObject* Document_getCountVector(DocumentObject* self); +PyObject* Document_getTopicsFromPseudoDoc(DocumentObject* self, size_t topN); +PyObject* Document_getTopicDistFromPseudoDoc(DocumentObject* self, bool normalize); + + template PyObject* buildPyValueReorder(const _Target& target, const _Order& order) { diff --git a/test/unit_test.py b/test/unit_test.py index 8e7e78c..f9b865b 100644 --- a/test/unit_test.py +++ b/test/unit_test.py @@ -295,6 +295,32 @@ def train_multi_corpus(cls, inputFile, mdFields, f, kargs, ps): print('Corpus2') for d in tcorpus2[:10]: print(d.get_ll()) +def uninit_doc(cls, inputFile, mdFields, f, kargs, ps): + print('Test uninitialized doc') + tw = 0 + print('Initialize model %s with TW=%s ...' 
% (str(cls), ['one', 'idf', 'pmi'][tw])) + mdl = cls(tw=tw, min_df=2, rm_top=2, **kargs) + print('Adding docs...') + unseen_docs = [] + for n, line in enumerate(open(inputFile, encoding='utf-8')): + ch = line.strip().split() + if len(ch) < mdFields + 1: continue + if n < 20: unseen_docs.append(line) + else: + if mdFields: + mdl.add_doc(ch[mdFields:], f(ch[:mdFields])) + else: + mdl.add_doc(ch) + mdl.train(20, parallel=ps) + for n, line in enumerate(unseen_docs): + ch = line.strip().split() + if mdFields: + unseen_docs[n] = mdl.make_doc(ch[mdFields:], f(ch[:mdFields])) + else: + unseen_docs[n] = mdl.make_doc(ch) + unseen_docs[n].get_topics() + unseen_docs[n].get_topic_dist() + def test_empty_uid(): cps = tp.utils.Corpus() cps.add_doc("test text".split()) @@ -489,7 +515,7 @@ def test_corpus_save_load(): for ps in pss: for func in [null_doc, train1, train4, train0, save_and_load, infer, infer_together, - copy_train, + copy_train, uninit_doc, ]: locals()['test_{}_{}_{}'.format(model_case[0].__name__, func.__name__, ps.name)] = (lambda f, mc, ps: lambda: f(*(mc + (ps,))))(func, model_case[:-1], ps)