Skip to content

Commit

Permalink
fixed #119
Browse files Browse the repository at this point in the history
  • Loading branch information
bab2min committed Sep 1, 2021
1 parent 78ab9f7 commit 5141849
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/TopicModel/PT.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace tomoto

struct PTArgs : public LDAArgs
{
size_t p = 100;
size_t p = 0;
Float lambda = 0.01;
};

Expand All @@ -30,5 +30,7 @@ namespace tomoto
bool scalarRng = false);

virtual size_t getP() const = 0;
virtual std::vector<Float> getTopicsFromPseudoDoc(const DocumentBase* doc, bool normalize = true) const = 0;
virtual std::vector<std::pair<Tid, Float>> getTopicsFromPseudoDocSorted(const DocumentBase* doc, size_t topN) const = 0;
};
}
19 changes: 19 additions & 0 deletions src/TopicModel/PTModel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,25 @@ namespace tomoto
return ret;
}

std::vector<Float> getTopicsFromPseudoDoc(const DocumentBase* _doc, bool normalize) const override
{
auto& doc = *static_cast<const _DocType*>(_doc);
if (!doc.numByTopic.size()) return {};
std::vector<Float> ret(this->K);
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K };
m = doc.numByTopic.array().template cast<Float>() + this->alphas.array();
if (normalize)
{
m /= m.sum();
}
return ret;
}

std::vector<std::pair<Tid, Float>> getTopicsFromPseudoDocSorted(const DocumentBase* doc, size_t topN) const override
{
return extractTopN<Tid>(getTopicsFromPseudoDoc(doc, true), topN);
}

void updateDocs()
{
for (auto& doc : this->docs)
Expand Down
15 changes: 15 additions & 0 deletions src/python/py_PT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,18 @@ TopicModelTypeObject PT_type = { {
PyType_GenericAlloc,
PyType_GenericNew,
}};


PyObject* Document_getTopicsFromPseudoDoc(DocumentObject* self, size_t topN)
{
tomoto::IPTModel* mdl = dynamic_cast<tomoto::IPTModel*>(self->corpus->tm->inst);
if (!mdl) throw py::ValueError{ "`from_pseudo_doc` is valid for only `tomotopy.PTModel`." };
return py::buildPyValue(self->corpus->tm->inst->getTopicsByDocSorted(self->getBoundDoc(), topN));
}

PyObject* Document_getTopicDistFromPseudoDoc(DocumentObject* self, bool normalize)
{
tomoto::IPTModel* mdl = dynamic_cast<tomoto::IPTModel*>(self->corpus->tm->inst);
if (!mdl) throw py::ValueError{ "`from_pseudo_doc` is valid for only `tomotopy.PTModel`." };
return py::buildPyValue(self->corpus->tm->inst->getTopicsByDoc(self->getBoundDoc(), !!normalize));
}
14 changes: 11 additions & 3 deletions src/python/py_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1094,27 +1094,35 @@ PyObject* DocumentObject::repr(DocumentObject* self)
static PyObject* Document_getTopics(DocumentObject* self, PyObject* args, PyObject* kwargs)
{
size_t topN = 10;
static const char* kwlist[] = { "top_n", nullptr };
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|n", (char**)kwlist, &topN)) return nullptr;
size_t fromPseudoDoc = 0;
static const char* kwlist[] = { "top_n", "from_pseudo_doc", nullptr};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|np", (char**)kwlist, &topN, &fromPseudoDoc)) return nullptr;
return py::handleExc([&]()
{
if (self->corpus->isIndependent()) throw py::RuntimeError{ "This method can only be called by documents bound to the topic model." };
if (!self->corpus->tm->inst) throw py::RuntimeError{ "inst is null" };
if (!self->corpus->tm->isPrepared) throw py::RuntimeError{ "train() should be called first for calculating the topic distribution" };
#ifdef TM_PT
if (fromPseudoDoc) return Document_getTopicsFromPseudoDoc(self, topN);
#endif
return py::buildPyValue(self->corpus->tm->inst->getTopicsByDocSorted(self->getBoundDoc(), topN));
});
}

static PyObject* Document_getTopicDist(DocumentObject* self, PyObject* args, PyObject* kwargs)
{
size_t normalize = 1;
static const char* kwlist[] = { "normalize", nullptr };
size_t fromPseudoDoc = 0;
static const char* kwlist[] = { "normalize", "from_pseudo_doc", nullptr };
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|p", (char**)kwlist, &normalize)) return nullptr;
return py::handleExc([&]()
{
if (self->corpus->isIndependent()) throw py::RuntimeError{ "This method can only be called by documents bound to the topic model." };
if (!self->corpus->tm->inst) throw py::RuntimeError{ "inst is null" };
if (!self->corpus->tm->isPrepared) throw py::RuntimeError{ "train() should be called first for calculating the topic distribution" };
#ifdef TM_PT
if (fromPseudoDoc) return Document_getTopicDistFromPseudoDoc(self, !!normalize);
#endif
return py::buildPyValue(self->corpus->tm->inst->getTopicsByDoc(self->getBoundDoc(), !!normalize));
});
}
Expand Down
4 changes: 4 additions & 0 deletions src/python/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,10 @@ PyObject* Document_getSubTopicDist(DocumentObject* self, PyObject* args, PyObjec

PyObject* Document_getCountVector(DocumentObject* self);

PyObject* Document_getTopicsFromPseudoDoc(DocumentObject* self, size_t topN);
PyObject* Document_getTopicDistFromPseudoDoc(DocumentObject* self, bool normalize);


template<typename _Target, typename _Order>
PyObject* buildPyValueReorder(const _Target& target, const _Order& order)
{
Expand Down

0 comments on commit 5141849

Please sign in to comment.