Skip to content

Commit

Permalink
added warning related to reproducibility (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
bab2min committed Jul 17, 2022
1 parent a07a945 commit a9e9fa8
Show file tree
Hide file tree
Showing 15 changed files with 76 additions and 40 deletions.
2 changes: 1 addition & 1 deletion src/python/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ struct TopicModelObject
{
PyObject_HEAD;
tomoto::ITopicModel* inst;
bool isPrepared;
bool isPrepared, seedGiven;
size_t minWordCnt, minWordDf;
size_t removeTopWord;
PyObject* initParams;
Expand Down
8 changes: 5 additions & 3 deletions src/python/py_CT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,23 @@ static int CT_init(TopicModelObject *self, PyObject *args, PyObject *kwargs)
tomoto::CTArgs margs;

PyObject* objCorpus = nullptr, *objTransform = nullptr;
PyObject* objAlpha = nullptr;
PyObject* objAlpha = nullptr, *objSeed = nullptr;
static const char* kwlist[] = { "tw", "min_cf", "min_df", "rm_top", "k", "smoothing_alpha", "eta",
"seed", "corpus", "transform", nullptr };
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnOfnOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.k, &objAlpha, &margs.eta, &margs.seed, &objCorpus, &objTransform)) return -1;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnOfOOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.k, &objAlpha, &margs.eta, &objSeed, &objCorpus, &objTransform)) return -1;
return py::handleExc([&]()
{
if (objAlpha) margs.alpha = broadcastObj<tomoto::Float>(objAlpha, margs.k,
[=]() { return "`smoothing_alpha` must be an instance of `float` or `List[float]` with length `k` (given " + py::repr(objAlpha) + ")"; }
);
if (objSeed) margs.seed = py::toCpp<size_t>(objSeed, "`seed` must be an integer or None.");

tomoto::ITopicModel* inst = tomoto::ICTModel::create((tomoto::TermWeight)tw, margs);
if (!inst) throw py::ValueError{ "unknown `tw` value" };
self->inst = inst;
self->isPrepared = false;
self->seedGiven = !!objSeed;
self->minWordCnt = minCnt;
self->minWordDf = minDf;
self->removeTopWord = rmTop;
Expand Down
8 changes: 5 additions & 3 deletions src/python/py_DMR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,23 @@ static int DMR_init(TopicModelObject *self, PyObject *args, PyObject *kwargs)
size_t tw = 0, minCnt = 0, minDf = 0, rmTop = 0;
tomoto::DMRArgs margs;
PyObject* objCorpus = nullptr, *objTransform = nullptr;
PyObject* objAlpha = nullptr;
PyObject* objAlpha = nullptr, *objSeed = nullptr;
static const char* kwlist[] = { "tw", "min_cf", "min_df", "rm_top", "k", "alpha", "eta", "sigma", "alpha_epsilon",
"seed", "corpus", "transform", nullptr };
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnOfffnOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.k, &objAlpha, &margs.eta, &margs.sigma, &margs.alphaEps, &margs.seed, &objCorpus, &objTransform)) return -1;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnOfffOOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.k, &objAlpha, &margs.eta, &margs.sigma, &margs.alphaEps, &objSeed, &objCorpus, &objTransform)) return -1;
return py::handleExc([&]()
{
if (objAlpha) margs.alpha = broadcastObj<tomoto::Float>(objAlpha, margs.k,
[=]() { return "`alpha` must be an instance of `float` or `List[float]` with length `k` (given " + py::repr(objAlpha) + ")"; }
);
if (objSeed) margs.seed = py::toCpp<size_t>(objSeed, "`seed` must be an integer or None.");

tomoto::ITopicModel* inst = tomoto::IDMRModel::create((tomoto::TermWeight)tw, margs);
if (!inst) throw py::ValueError{ "unknown `tw` value" };
self->inst = inst;
self->isPrepared = false;
self->seedGiven = !!objSeed;
self->minWordCnt = minCnt;
self->minWordDf = minDf;
self->removeTopWord = rmTop;
Expand Down
8 changes: 6 additions & 2 deletions src/python/py_DT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,23 @@ static int DT_init(TopicModelObject *self, PyObject *args, PyObject *kwargs)
size_t tw = 0, minCnt = 0, minDf = 0, rmTop = 0;
tomoto::DTArgs margs;
PyObject* objCorpus = nullptr, *objTransform = nullptr;
PyObject* objSeed = nullptr;
static const char* kwlist[] = { "tw", "min_cf", "min_df", "rm_top", "k", "t",
"alpha_var", "eta_var", "phi_var", "lr_a", "lr_b", "lr_c",
"seed", "corpus", "transform", nullptr };
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnnffffffnOO", (char**)kwlist,
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnnffffffOOO", (char**)kwlist,
&tw, &minCnt, &minDf, &rmTop, &margs.k, &margs.t,
&margs.alpha[0], &margs.eta, &margs.phi, &margs.shapeA, &margs.shapeB, &margs.shapeC,
&margs.seed, &objCorpus, &objTransform)) return -1;
&objSeed, &objCorpus, &objTransform)) return -1;
return py::handleExc([&]()
{
if (objSeed) margs.seed = py::toCpp<size_t>(objSeed, "`seed` must be an integer or None.");

tomoto::ITopicModel* inst = tomoto::IDTModel::create((tomoto::TermWeight)tw, margs);
if (!inst) throw py::RuntimeError{ "unknown `tw` value" };
self->inst = inst;
self->isPrepared = false;
self->seedGiven = !!objSeed;
self->minWordCnt = minCnt;
self->minWordDf = minDf;
self->removeTopWord = rmTop;
Expand Down
8 changes: 5 additions & 3 deletions src/python/py_GDMR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,20 @@ static int GDMR_init(TopicModelObject *self, PyObject *args, PyObject *kwargs)
tomoto::GDMRArgs margs;
PyObject* objCorpus = nullptr, *objTransform = nullptr,
*objDegrees = nullptr, *objRange = nullptr;
PyObject* objAlpha = nullptr;
PyObject* objAlpha = nullptr, *objSeed = nullptr;
static const char* kwlist[] = { "tw", "min_cf", "min_df", "rm_top", "k",
"degrees", "alpha", "eta", "sigma", "sigma0", "alpha_epsilon",
"decay", "metadata_range", "seed", "corpus", "transform", nullptr };
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnOOfffffOnOO", (char**)kwlist,
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnOOfffffOOOO", (char**)kwlist,
&tw, &minCnt, &minDf, &rmTop, &margs.k,
&objDegrees, &objAlpha, &margs.eta, &margs.sigma, &margs.sigma0, &margs.alphaEps,
&margs.orderDecay, &objRange, &margs.seed, &objCorpus, &objTransform)) return -1;
&margs.orderDecay, &objRange, &objSeed, &objCorpus, &objTransform)) return -1;
return py::handleExc([&]()
{
if (objAlpha) margs.alpha = broadcastObj<tomoto::Float>(objAlpha, margs.k,
[=]() { return "`alpha` must be an instance of `float` or `List[float]` with length `k` (given " + py::repr(objAlpha) + ")"; }
);
if (objSeed) margs.seed = py::toCpp<size_t>(objSeed, "`seed` must be an integer or None.");

if (objDegrees)
{
Expand All @@ -46,6 +47,7 @@ static int GDMR_init(TopicModelObject *self, PyObject *args, PyObject *kwargs)
if (!inst) throw py::ValueError{ "unknown `tw` value" };
self->inst = inst;
self->isPrepared = false;
self->seedGiven = !!objSeed;
self->minWordCnt = minCnt;
self->minWordDf = minDf;
self->removeTopWord = rmTop;
Expand Down
8 changes: 6 additions & 2 deletions src/python/py_HDP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,20 @@ static int HDP_init(TopicModelObject *self, PyObject *args, PyObject *kwargs)
size_t tw = 0, minCnt = 0, minDf = 0, rmTop = 0;
tomoto::HDPArgs margs;
PyObject* objCorpus = nullptr, *objTransform = nullptr;
PyObject* objSeed = nullptr;
static const char* kwlist[] = { "tw", "min_cf", "min_df", "rm_top", "initial_k", "alpha", "eta", "gamma",
"seed", "corpus", "transform", nullptr };
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnfffnOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.k, &margs.alpha[0], &margs.eta, &margs.gamma, &margs.seed, &objCorpus, &objTransform)) return -1;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnfffOOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.k, &margs.alpha[0], &margs.eta, &margs.gamma, &objSeed, &objCorpus, &objTransform)) return -1;
return py::handleExc([&]()
{
if (objSeed) margs.seed = py::toCpp<size_t>(objSeed, "`seed` must be an integer or None.");

tomoto::ITopicModel* inst = tomoto::IHDPModel::create((tomoto::TermWeight)tw, margs);
if (!inst) throw py::ValueError{ "unknown `tw` value" };
self->inst = inst;
self->isPrepared = false;
self->seedGiven = !!objSeed;
self->minWordCnt = minCnt;
self->minWordDf = minDf;
self->removeTopWord = rmTop;
Expand Down
8 changes: 5 additions & 3 deletions src/python/py_HLDA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,23 @@ static int HLDA_init(TopicModelObject *self, PyObject *args, PyObject *kwargs)
size_t tw = 0, minCnt = 0, minDf = 0, rmTop = 0;
tomoto::HLDAArgs margs;
PyObject* objCorpus = nullptr, *objTransform = nullptr;
PyObject* objAlpha = nullptr;
PyObject* objAlpha = nullptr, *objSeed = nullptr;
static const char* kwlist[] = { "tw", "min_cf", "min_df", "rm_top", "depth", "alpha", "eta", "gamma",
"seed", "corpus", "transform", nullptr };
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnOffnOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.k, &objAlpha, &margs.eta, &margs.gamma, &margs.seed, &objCorpus, &objTransform)) return -1;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnOffOOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.k, &objAlpha, &margs.eta, &margs.gamma, &objSeed, &objCorpus, &objTransform)) return -1;
return py::handleExc([&]()
{
if (objAlpha) margs.alpha = broadcastObj<tomoto::Float>(objAlpha, margs.k,
[=]() { return "`alpha` must be an instance of `float` or `List[float]` with length `depth` (given " + py::repr(objAlpha) + ")"; }
);
if (objSeed) margs.seed = py::toCpp<size_t>(objSeed, "`seed` must be an integer or None.");

tomoto::ITopicModel* inst = tomoto::IHLDAModel::create((tomoto::TermWeight)tw, margs);
if (!inst) throw py::ValueError{ "unknown `tw` value" };
self->inst = inst;
self->isPrepared = false;
self->seedGiven = !!objSeed;
self->minWordCnt = minCnt;
self->minWordDf = minDf;
self->removeTopWord = rmTop;
Expand Down
9 changes: 6 additions & 3 deletions src/python/py_HPA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ static int HPA_init(TopicModelObject *self, PyObject *args, PyObject *kwargs)
size_t tw = 0, minCnt = 0, minDf = 0, rmTop = 0;
tomoto::HPAArgs margs;
PyObject* objCorpus = nullptr, *objTransform = nullptr;
PyObject* objAlpha = nullptr, * objSubAlpha = nullptr;
PyObject* objAlpha = nullptr, * objSubAlpha = nullptr, *objSeed = nullptr;
static const char* kwlist[] = { "tw", "min_cf", "min_df", "rm_top", "k1", "k2", "alpha", "subalpha", "eta",
"seed", "corpus", "transform", nullptr };
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnnOOfnOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.k, &margs.k2, &objAlpha, &objSubAlpha, &margs.eta, &margs.seed, &objCorpus, &objTransform)) return -1;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnnOOfOOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.k, &margs.k2, &objAlpha, &objSubAlpha, &margs.eta, &objSeed, &objCorpus, &objTransform)) return -1;
return py::handleExc([&]()
{
if (objAlpha) margs.alpha = broadcastObj<tomoto::Float>(objAlpha, margs.k + 1,
Expand All @@ -24,11 +24,14 @@ static int HPA_init(TopicModelObject *self, PyObject *args, PyObject *kwargs)
if (objSubAlpha) margs.subalpha = broadcastObj<tomoto::Float>(objSubAlpha, margs.k2 + 1,
[=]() { return "`subalpha` must be an instance of `float` or `List[float]` with length `k2 + 1` (given " + py::repr(objSubAlpha) + ")"; }
);
if (objSeed) margs.seed = py::toCpp<size_t>(objSeed, "`seed` must be an integer or None.");

tomoto::ITopicModel* inst = tomoto::IHPAModel::create((tomoto::TermWeight)tw,
false, margs);
if (!inst) throw py::ValueError{ "unknown `tw` value" };
self->inst = inst;
self->isPrepared = false;
self->seedGiven = !!objSeed;
self->minWordCnt = minCnt;
self->minWordDf = minDf;
self->removeTopWord = rmTop;
Expand Down
9 changes: 6 additions & 3 deletions src/python/py_LDA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,23 @@ static int LDA_init(TopicModelObject *self, PyObject *args, PyObject *kwargs)
size_t tw = 0, minCnt = 0, minDf = 0, rmTop = 0;
tomoto::LDAArgs margs;
PyObject* objCorpus = nullptr, *objTransform = nullptr;
PyObject* objAlpha = nullptr;
PyObject* objAlpha = nullptr, *objSeed = nullptr;
static const char* kwlist[] = { "tw", "min_cf", "min_df", "rm_top", "k", "alpha", "eta", "seed",
"corpus", "transform", nullptr };
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnOfnOO", (char**)kwlist,
&tw, &minCnt, &minDf, &rmTop, &margs.k, &objAlpha, &margs.eta, &margs.seed, &objCorpus, &objTransform)) return -1;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnOfOOO", (char**)kwlist,
&tw, &minCnt, &minDf, &rmTop, &margs.k, &objAlpha, &margs.eta, &objSeed, &objCorpus, &objTransform)) return -1;
return py::handleExc([&]()
{
if (objAlpha) margs.alpha = broadcastObj<tomoto::Float>(objAlpha, margs.k,
[=]() { return "`alpha` must be an instance of `float` or `List[float]` with length `k` (given " + py::repr(objAlpha) + ")"; }
);
if (objSeed) margs.seed = py::toCpp<size_t>(objSeed, "`seed` must be an integer or None.");

tomoto::ITopicModel* inst = tomoto::ILDAModel::create((tomoto::TermWeight)tw, margs);
if (!inst) throw py::ValueError{ "unknown tw value" };
self->inst = inst;
self->isPrepared = false;
self->seedGiven = !!objSeed;
self->minWordCnt = minCnt;
self->minWordDf = minDf;
self->removeTopWord = rmTop;
Expand Down Expand Up @@ -159,6 +161,7 @@ static PyObject* LDA_train(TopicModelObject* self, PyObject* args, PyObject *kwa
size_t iteration = 10, workers = 0, ps = 0, fixed = 0;
static const char* kwlist[] = { "iter", "workers", "parallel", "freeze_topics", nullptr };
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnp", (char**)kwlist, &iteration, &workers, &ps, &fixed)) return nullptr;
if (self->seedGiven && workers != 1 && PyErr_WarnEx(PyExc_RuntimeWarning, "The training result may differ even with fixed seed if `workers` != 1.", 1)) return nullptr;
return py::handleExc([&]()
{
if (!self->inst) throw py::RuntimeError{ "inst is null" };
Expand Down
8 changes: 5 additions & 3 deletions src/python/py_LLDA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ static int LLDA_init(TopicModelObject *self, PyObject *args, PyObject *kwargs)
size_t tw = 0, minCnt = 0, minDf = 0, rmTop = 0;
tomoto::LDAArgs margs;
PyObject* objCorpus = nullptr, *objTransform = nullptr;
PyObject* objAlpha = nullptr;
PyObject* objAlpha = nullptr, *objSeed = nullptr;
static const char* kwlist[] = { "tw", "min_cf", "min_df", "rm_top", "k", "alpha", "eta",
"seed", "corpus", "transform", nullptr };
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnOfnOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.k, &objAlpha, &margs.eta, &margs.seed, &objCorpus, &objTransform)) return -1;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnOfOOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.k, &objAlpha, &margs.eta, &objSeed, &objCorpus, &objTransform)) return -1;

if (PyErr_WarnEx(PyExc_DeprecationWarning, "`tomotopy.LLDAModel` is deprecated. Please use `tomotopy.PLDAModel` instead.", 1)) return -1;

Expand All @@ -30,11 +30,13 @@ static int LLDA_init(TopicModelObject *self, PyObject *args, PyObject *kwargs)
if (objAlpha) margs.alpha = broadcastObj<tomoto::Float>(objAlpha, margs.k,
[=]() { return "`alpha` must be an instance of `float` or `List[float]` with length `k` (given " + py::repr(objAlpha) + ")"; }
);
if (objSeed) margs.seed = py::toCpp<size_t>(objSeed, "`seed` must be an integer or None.");

tomoto::ITopicModel* inst = tomoto::ILLDAModel::create((tomoto::TermWeight)tw, margs);
if (!inst) throw py::ValueError{ "unknown `tw` value" };
self->inst = inst;
self->isPrepared = false;
self->seedGiven = !!objSeed;
self->minWordCnt = minCnt;
self->minWordDf = minDf;
self->removeTopWord = rmTop;
Expand Down
8 changes: 6 additions & 2 deletions src/python/py_MGLDA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,21 @@ static int MGLDA_init(TopicModelObject *self, PyObject *args, PyObject *kwargs)
size_t tw = 0, minCnt = 0, minDf = 0, rmTop = 0;
tomoto::MGLDAArgs margs;
PyObject* objCorpus = nullptr, *objTransform = nullptr;
PyObject* objSeed = nullptr;
static const char* kwlist[] = { "tw", "min_cf", "min_df", "rm_top", "k_g", "k_l", "t", "alpha_g", "alpha_l", "alpha_mg", "alpha_ml",
"eta_g", "eta_l", "gamma", "seed", "corpus", "transform", nullptr };
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnnnfffffffnOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnnnfffffffOOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.k, &margs.kL, &margs.t, &margs.alpha[0], &margs.alphaL[0], &margs.alphaMG, &margs.alphaML, &margs.eta, &margs.etaL, &margs.gamma,
&margs.seed, &objCorpus, &objTransform)) return -1;
&objSeed, &objCorpus, &objTransform)) return -1;
return py::handleExc([&]()
{
if (objSeed) margs.seed = py::toCpp<size_t>(objSeed, "`seed` must be an integer or None.");

tomoto::ITopicModel* inst = tomoto::IMGLDAModel::create((tomoto::TermWeight)tw, margs);
if (!inst) throw py::ValueError{ "unknown `tw` value" };
self->inst = inst;
self->isPrepared = false;
self->seedGiven = !!objSeed;
self->minWordCnt = minCnt;
self->minWordDf = minDf;
self->removeTopWord = rmTop;
Expand Down
8 changes: 5 additions & 3 deletions src/python/py_PA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ static int PA_init(TopicModelObject *self, PyObject *args, PyObject *kwargs)
tomoto::PAArgs margs;
size_t K = 1, K2 = 1;
PyObject* objCorpus = nullptr, *objTransform = nullptr;
PyObject* objAlpha = nullptr, *objSubAlpha = nullptr;
PyObject* objAlpha = nullptr, *objSubAlpha = nullptr, *objSeed = nullptr;
static const char* kwlist[] = { "tw", "min_cf", "min_df", "rm_top", "k1", "k2", "alpha", "subalpha", "eta",
"seed", "corpus", "transform", nullptr };
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnnOOfnOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.k, &margs.k2, &objAlpha, &objSubAlpha, &margs.eta, &margs.seed, &objCorpus, &objTransform)) return -1;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnnOOfOOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.k, &margs.k2, &objAlpha, &objSubAlpha, &margs.eta, &objSeed, &objCorpus, &objTransform)) return -1;
return py::handleExc([&]()
{
if (objAlpha) margs.alpha = broadcastObj<tomoto::Float>(objAlpha, margs.k,
Expand All @@ -25,11 +25,13 @@ static int PA_init(TopicModelObject *self, PyObject *args, PyObject *kwargs)
if (objSubAlpha) margs.subalpha = broadcastObj<tomoto::Float>(objSubAlpha, margs.k2,
[=]() { return "`subalpha` must be an instance of `float` or `List[float]` with length `k2` (given " + py::repr(objSubAlpha) + ")"; }
);
if (objSeed) margs.seed = py::toCpp<size_t>(objSeed, "`seed` must be an integer or None.");

tomoto::ITopicModel* inst = tomoto::IPAModel::create((tomoto::TermWeight)tw, margs);
if (!inst) throw py::ValueError{ "unknown `tw` value" };
self->inst = inst;
self->isPrepared = false;
self->seedGiven = !!objSeed;
self->minWordCnt = minCnt;
self->minWordDf = minDf;
self->removeTopWord = rmTop;
Expand Down
8 changes: 5 additions & 3 deletions src/python/py_PLDA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,23 @@ static int PLDA_init(TopicModelObject *self, PyObject *args, PyObject *kwargs)
size_t tw = 0, minCnt = 0, minDf = 0, rmTop = 0;
tomoto::PLDAArgs margs;
PyObject* objCorpus = nullptr, *objTransform = nullptr;
PyObject* objAlpha = nullptr;
PyObject* objAlpha = nullptr, *objSeed = nullptr;
static const char* kwlist[] = { "tw", "min_cf", "min_df", "rm_top", "latent_topics", "topics_per_label", "alpha", "eta",
"seed", "corpus", "transform", nullptr };
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnnOfnOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.numLatentTopics, &margs.numTopicsPerLabel, &objAlpha, &margs.eta, &margs.seed, &objCorpus, &objTransform)) return -1;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|nnnnnnOfOOO", (char**)kwlist, &tw, &minCnt, &minDf, &rmTop,
&margs.numLatentTopics, &margs.numTopicsPerLabel, &objAlpha, &margs.eta, &objSeed, &objCorpus, &objTransform)) return -1;
return py::handleExc([&]()
{
if (objAlpha) margs.alpha = broadcastObj<tomoto::Float>(objAlpha, margs.k,
[=]() { return "`alpha` must be an instance of `float` or `List[float]` with length `k` (given " + py::repr(objAlpha) + ")"; }
);
if (objSeed) margs.seed = py::toCpp<size_t>(objSeed, "`seed` must be an integer or None.");

tomoto::ITopicModel* inst = tomoto::IPLDAModel::create((tomoto::TermWeight)tw, margs);
if (!inst) throw py::ValueError{ "unknown `tw` value" };
self->inst = inst;
self->isPrepared = false;
self->seedGiven = !!objSeed;
self->minWordCnt = minCnt;
self->minWordDf = minDf;
self->removeTopWord = rmTop;
Expand Down
Loading

0 comments on commit a9e9fa8

Please sign in to comment.