diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index b5dc51b9e7ef..4a363cba5c79 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -67,7 +67,17 @@ namespace { namespace tvm { namespace runtime { RPCEnv::RPCEnv() { +#ifndef _WIN32 + char cwd[PATH_MAX]; + if (char *rc = getcwd(cwd, sizeof(cwd))) { + base_ = std::string(cwd) + "/rpc"; + } else { + base_ = "./rpc"; + } +#else base_ = "./rpc"; +#endif + mkdir(base_.c_str(), 0777); TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([](TVMArgs args, TVMRetValue* rv) { static RPCEnv env; diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index 57a68f452d3d..2c8bdfae0168 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -66,6 +66,22 @@ static pid_t waitPidEintr(int* status) { } #endif +#ifdef __ANDROID__ +static std::string getNextString(std::stringstream* iss) { + std::string str = iss->str(); + size_t start = iss->tellg(); + size_t len = str.size(); + // Skip leading spaces. + while (start < len && isspace(str[start])) start++; + + size_t end = start; + while (end < len && !isspace(str[end])) end++; + + iss->seekg(end); + return str.substr(start, end-start); +} +#endif + /*! * \brief RPCServer RPC Server class. * \param host The hostname of the server, Default=0.0.0.0 @@ -164,9 +180,9 @@ class RPCServer { int status = 0; const pid_t finished_first = waitPidEintr(&status); if (finished_first == timer_pid) { - kill(worker_pid, SIGKILL); + kill(worker_pid, SIGTERM); } else if (finished_first == worker_pid) { - kill(timer_pid, SIGKILL); + kill(timer_pid, SIGTERM); } else { LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue."; } @@ -260,7 +276,12 @@ class RPCServer { std::stringstream ssin(remote_key); std::string arg0; +#ifndef __ANDROID__ ssin >> arg0; +#else + arg0 = getNextString(&ssin); +#endif + if (arg0 != expect_header) { code = kRPCMismatch; CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); @@ -274,7 +295,11 @@ class RPCServer { CHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); CHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen); LOG(INFO) << "Connection success " << addr->AsString(); +#ifndef __ANDROID__ ssin >> *opts; +#else + *opts = getNextString(&ssin); +#endif *conn_sock = conn; return; } @@ -301,8 +326,9 @@ class RPCServer { int GetTimeOutFromOpts(const std::string& opts) const { const std::string option = "-timeout="; - if (opts.find(option) == 0) { - const std::string cmd = opts.substr(opts.find_last_of(option) + 1); + size_t pos = opts.rfind(option); + if (pos != std::string::npos) { + const std::string cmd = opts.substr(pos + option.size()); CHECK(support::IsNumber(cmd)) << "Timeout is not valid"; return std::stoi(cmd); }