diff --git a/include/kiwi/Kiwi.h b/include/kiwi/Kiwi.h index 44a14d5c..b4f846b1 100644 --- a/include/kiwi/Kiwi.h +++ b/include/kiwi/Kiwi.h @@ -614,8 +614,7 @@ namespace kiwi std::vector corpora; size_t minMorphCnt = 10; size_t lmOrder = 4; - size_t lmMinCnt = 1; - size_t lmLastOrderMinCnt = 2; + std::vector lmMinCnts = { 1 }; size_t numWorkers = 1; size_t sbgSize = 1000000; bool useLmTagHistory = true; diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index 10a798e3..e6bfa7a9 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -872,7 +872,8 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) pool.~ThreadPool(); new (&pool) utils::ThreadPool{ args.numWorkers }; } - auto cntNodes = utils::count(sents.begin(), sents.end(), args.lmMinCnt, 1, args.lmOrder, (args.numWorkers > 1 ? &pool : nullptr), &bigramList, args.useLmTagHistory ? &historyTx : nullptr); + size_t lmMinCnt = *std::min(args.lmMinCnts.begin(), args.lmMinCnts.end()); + auto cntNodes = utils::count(sents.begin(), sents.end(), lmMinCnt, 1, args.lmOrder, (args.numWorkers > 1 ? &pool : nullptr), &bigramList, args.useLmTagHistory ? &historyTx : nullptr); // discount for bos node cnt if (args.useLmTagHistory) { @@ -882,8 +883,16 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) { cntNodes.root().getNext(0)->val /= 2; } - std::vector minCnts(args.lmOrder, args.lmMinCnt); - minCnts.back() = args.lmLastOrderMinCnt; + std::vector minCnts; + if (args.lmMinCnts.size() == 1) + { + minCnts.clear(); + minCnts.resize(args.lmOrder, args.lmMinCnts[0]); + } + else if (args.lmMinCnts.size() == args.lmOrder) + { + minCnts = args.lmMinCnts; + } langMdl.knlm = lm::KnLangModelBase::create(lm::KnLangModelBase::build( cntNodes, args.lmOrder, minCnts, diff --git a/tools/model_builder.cpp b/tools/model_builder.cpp index 199e13ab..c9b37302 100644 --- a/tools/model_builder.cpp +++ b/tools/model_builder.cpp @@ -8,6 +8,27 @@ using namespace std; using namespace kiwi; +vector splitMultipleInts(const string& s, const char delim = ',') +{ + vector ret; + size_t p = 0, e = 0; + while (1) + { + size_t t = s.find(delim, p); + if (t == s.npos) + { + ret.emplace_back(atoi(&s[e])); + return ret; + } + else + { + ret.emplace_back(atoi(&s[e])); + p = t + 1; + e = t + 1; + } + } +} + int run(const KiwiBuilder::ModelBuildArgs& args, const string& output, bool skipBigram) { try @@ -49,7 +70,7 @@ int main(int argc, const char* argv[]) ValueArg workers{ "w", "workers", "number of workers", false, 1, "int" }; ValueArg morMinCnt{ "", "morpheme_min_cnt", "min count of morpheme", false, 10, "int" }; ValueArg lmOrder{ "", "order", "order of LM", false, 4, "int" }; - ValueArg lmMinCnt{ "", "min_cnt", "min count of LM", false, 1, "int" }; + ValueArg lmMinCnt{ "", "min_cnt", "min count of LM", false, "1", "multiple ints with comma"}; ValueArg lmLastOrderMinCnt{ "", "last_min_cnt", "min count of the last order of LM", false, 2, "int" }; ValueArg output{ "o", "output", "output model path", true, "", "string" }; ValueArg sbgSize{ "", "sbg_size", "sbg size", false, 1000000, "int" }; @@ -86,10 +107,24 @@ int main(int argc, const char* argv[]) args.useLmTagHistory = tagHistory; args.minMorphCnt = morMinCnt; args.lmOrder = lmOrder; - args.lmMinCnt = lmMinCnt; - args.lmLastOrderMinCnt = lmLastOrderMinCnt; args.numWorkers = workers; args.sbgSize = sbgSize; + + auto v = splitMultipleInts(lmMinCnt.getValue()); + + if (v.empty()) + { + args.lmMinCnts.resize(1, 1); + } + else if (v.size() == 1 || v.size() == lmOrder) + { + args.lmMinCnts = v; + } + else + { + cerr << "error: min_cnt size should be 1 or equal to order" << endl; + return -1; + } return run(args, output, skipBigram); }