Skip to content

Commit

Permalink
refactor main.cc
Browse files Browse the repository at this point in the history
  • Loading branch information
jeongukjae committed Jun 29, 2020
1 parent d75139a commit 23e2157
Showing 1 changed file with 26 additions and 24 deletions.
50 changes: 26 additions & 24 deletions src/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
cxxopts::ParseResult parseCLIArgs(int argc, char** argv);
std::string getServerStat(std::map<std::string, lgbm_serving::Model*> models);
std::pair<size_t, std::vector<float*>> parse2DFloatArray(const std::string& payload);
std::string serializeModelOutput(int nrows, int nClasses, double* outResult);

int main(int argc, char** argv) {
auto args = parseCLIArgs(argc, argv);
Expand Down Expand Up @@ -103,32 +104,10 @@ int main(int argc, char** argv) {
if (outputLength != nrows * nClasses) {
res.status = 400;
res.set_content("{\"error\": \"invalid shape\"}", "application/json");
for (const auto* feat : features.second)
delete[] feat;
return;
}

rapidjson::Document document;
document.SetArray();

for (size_t i = 0; i < nrows; i++) {
if (nClasses == 1) {
document.PushBack(outResult[i], document.GetAllocator());
} else {
rapidjson::Value value(rapidjson::kArrayType);
for (size_t j = 0; j < nClasses; j++) {
value.PushBack(outResult[i], document.GetAllocator());
}
document.PushBack(value, document.GetAllocator());
}
} else {
res.set_content(serializeModelOutput(nrows, nClasses, outResult.data()), "application/json");
}

rapidjson::StringBuffer buffer;
buffer.Clear();
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
document.Accept(writer);
res.set_content(std::string(buffer.GetString()), "application/json");

for (const auto* feat : features.second)
delete[] feat;
});
Expand Down Expand Up @@ -211,3 +190,26 @@ std::pair<size_t, std::vector<float*>> parse2DFloatArray(const std::string& payl

return std::make_pair(ncol, features);
}

std::string serializeModelOutput(int nrows, int nClasses, double* outResult) {
rapidjson::Document document;
document.SetArray();

for (size_t i = 0; i < nrows; i++) {
if (nClasses == 1) {
document.PushBack(outResult[i], document.GetAllocator());
} else {
rapidjson::Value value(rapidjson::kArrayType);
for (size_t j = 0; j < nClasses; j++) {
value.PushBack(outResult[i], document.GetAllocator());
}
document.PushBack(value, document.GetAllocator());
}
}

rapidjson::StringBuffer buffer;
buffer.Clear();
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
document.Accept(writer);
return std::string(buffer.GetString());
}

0 comments on commit 23e2157

Please sign in to comment.