From 6b932d2b0d4855e72d9d25315cab6728ab160e02 Mon Sep 17 00:00:00 2001 From: Minchul Lee Date: Mon, 26 Apr 2021 01:49:06 +0900 Subject: [PATCH] fixed DTModel.copy() --- src/TopicModel/DTModel.hpp | 10 ++++++++++ src/python/py_LLDA.cpp | 10 +++++----- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/TopicModel/DTModel.hpp b/src/TopicModel/DTModel.hpp index 1841017..1649622 100644 --- a/src/TopicModel/DTModel.hpp +++ b/src/TopicModel/DTModel.hpp @@ -477,6 +477,16 @@ namespace tomoto return cnt; } + void updateForCopy() + { + BaseClass::updateForCopy(); + size_t docId = 0; + for (auto& doc : this->docs) + { + doc.eta.init((Float*)etaByDoc.col(docId++).data(), this->K, 1); + } + } + public: DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, T, shapeA, shapeB, shapeC, alphaVar, etaVar, phiVar, alphas, etaByDoc, phi); diff --git a/src/python/py_LLDA.cpp b/src/python/py_LLDA.cpp index e5a738b..f9189f8 100644 --- a/src/python/py_LLDA.cpp +++ b/src/python/py_LLDA.cpp @@ -129,22 +129,22 @@ PyObject* Document_labels(DocumentObject* self, void* closure) if (self->corpus->isIndependent()) throw py::AttributeError{ "doc doesn't has `labels` field!" }; if (!self->doc) throw py::RuntimeError{ "doc is null!" }; - if (auto* r = docVisit(self->getBoundDoc(), [&](auto* doc) + if (auto* ret = docVisit(self->getBoundDoc(), [&](auto* doc) { auto inst = dynamic_cast(self->corpus->tm->inst); auto dict = inst->getTopicLabelDict(); - vector>> ret; + vector>> r; auto topicDist = inst->getTopicsByDoc(doc); for (size_t i = 0; i < dict.size(); ++i) { if (doc->labelMask[i * inst->getNumTopicsPerLabel()]) { - ret.emplace_back(inst->getTopicLabelDict().toWord(i), + r.emplace_back(inst->getTopicLabelDict().toWord(i), vector{ &topicDist[i * inst->getNumTopicsPerLabel()], &topicDist[(i + 1) * inst->getNumTopicsPerLabel()] }); } } - return py::buildPyValue(ret); - })) return r; + return py::buildPyValue(r); + })) return ret; throw py::AttributeError{ "doc doesn't has `labels` field!" }; });