-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ML] PyTorch Command Processor #1770
Changes from 13 commits
31b644f
f5b7a2e
1ed29a3
85b5ebc
ec8b883
afbb3ba
6adf034
441b2fc
2e48245
4e77ca3
83e6172
fc00887
3a268ff
89e4424
2df2cb9
b819c19
2910f87
a83d667
ed8dbff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License; | ||
* you may not use this file except in compliance with the Elastic License. | ||
*/ | ||
|
||
#include "CCommandParser.h" | ||
|
||
#include <core/CLogger.h> | ||
|
||
#include <rapidjson/error/en.h> | ||
#include <rapidjson/istreamwrapper.h> | ||
#include <rapidjson/stringbuffer.h> | ||
#include <rapidjson/writer.h> | ||
|
||
#include <istream> | ||
|
||
namespace rapidjson { | ||
|
||
std::ostream& operator<<(std::ostream& os, const rapidjson::Document& doc) { | ||
rapidjson::StringBuffer buffer; | ||
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer); | ||
doc.Accept(writer); | ||
return os << buffer.GetString(); | ||
} | ||
} | ||
|
||
namespace ml { | ||
namespace torch { | ||
|
||
const std::string CCommandParser::REQUEST_ID{"request_id"}; | ||
const std::string CCommandParser::TOKENS{"tokens"}; | ||
const std::string CCommandParser::VAR_ARG_PREFIX{"arg_"}; | ||
|
||
CCommandParser::CCommandParser(std::istream& strmIn) : m_StrmIn(strmIn) { | ||
} | ||
|
||
bool CCommandParser::ioLoop(const TRequestHandlerFunc& requestHandler) { | ||
|
||
rapidjson::IStreamWrapper isw{m_StrmIn}; | ||
|
||
while (true) { | ||
rapidjson::Document doc; | ||
rapidjson::ParseResult parseResult = | ||
doc.ParseStream<rapidjson::kParseStopWhenDoneFlag>(isw); | ||
|
||
if (static_cast<bool>(parseResult) == false) { | ||
if (m_StrmIn.eof()) { | ||
break; | ||
} | ||
|
||
LOG_ERROR(<< "Error parsing command from JSON: " | ||
<< rapidjson::GetParseError_En(parseResult.Code()) | ||
<< ". At offset: " << parseResult.Offset()); | ||
|
||
return false; | ||
} | ||
|
||
if (validateJson(doc) == false) { | ||
continue; | ||
} | ||
|
||
LOG_TRACE(<< "Inference command: " << doc); | ||
jsonToRequest(doc); | ||
if (requestHandler(m_Request) == false) { | ||
LOG_ERROR(<< "Request handler forced exit"); | ||
return false; | ||
} | ||
} | ||
|
||
return true; | ||
} | ||
|
||
bool CCommandParser::validateJson(const rapidjson::Document& doc) const { | ||
if (doc.HasMember(REQUEST_ID) == false) { | ||
LOG_ERROR(<< "Invalid command: missing field [" << REQUEST_ID << "]"); | ||
return false; | ||
} | ||
|
||
if (doc.HasMember(TOKENS) == false) { | ||
LOG_ERROR(<< "Invalid command: missing field [" << TOKENS << "]"); | ||
return false; | ||
} | ||
|
||
const rapidjson::Value& tokens = doc[TOKENS]; | ||
if (tokens.IsArray() == false) { | ||
LOG_ERROR(<< "Invalid command: expected an array [" << TOKENS << "]"); | ||
return false; | ||
} | ||
|
||
// check optional args | ||
std::uint64_t varCount{1}; | ||
std::string varArgName = VAR_ARG_PREFIX + std::to_string(varCount); | ||
while (doc.HasMember(varArgName)) { | ||
const rapidjson::Value& value = doc[varArgName]; | ||
if (value.IsArray() == false) { | ||
LOG_ERROR(<< "Invalid command: argument [" << varArgName << "] is not an array"); | ||
return false; | ||
} | ||
|
||
++varCount; | ||
varArgName = VAR_ARG_PREFIX + std::to_string(varCount); | ||
} | ||
|
||
return true; | ||
} | ||
|
||
void CCommandParser::jsonToRequest(const rapidjson::Document& doc) { | ||
|
||
m_Request.s_RequestId = doc[REQUEST_ID].GetString(); | ||
const rapidjson::Value& arr = doc[TOKENS]; | ||
// wipe any previous | ||
m_Request.s_Tokens.clear(); | ||
m_Request.s_Tokens.reserve(arr.Size()); | ||
|
||
for (auto itr = arr.Begin(); itr != arr.End(); ++itr) { | ||
droberts195 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
m_Request.s_Tokens.push_back(itr->GetUint()); | ||
tveasey marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
std::uint64_t varCount{1}; | ||
std::string varArgName = VAR_ARG_PREFIX + std::to_string(varCount); | ||
|
||
// wipe any previous | ||
m_Request.s_SecondaryArguments.clear(); | ||
TUint32Vec arg; | ||
while (doc.HasMember(varArgName)) { | ||
const rapidjson::Value& v = doc[varArgName]; | ||
for (auto itr = v.Begin(); itr != v.End(); ++itr) { | ||
arg.push_back(itr->GetUint()); | ||
tveasey marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
m_Request.s_SecondaryArguments.push_back(arg); | ||
arg.clear(); | ||
++varCount; | ||
varArgName = VAR_ARG_PREFIX + std::to_string(varCount); | ||
} | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License; | ||
* you may not use this file except in compliance with the Elastic License. | ||
*/ | ||
|
||
#ifndef INCLUDED_ml_torch_CCommandParser_h | ||
#define INCLUDED_ml_torch_CCommandParser_h | ||
|
||
#include <rapidjson/document.h> | ||
|
||
#include <functional> | ||
#include <iosfwd> | ||
#include <string> | ||
#include <vector> | ||
|
||
namespace ml { | ||
namespace torch { | ||
|
||
//! \brief | ||
//! Reads JSON documents from a stream calling the request handler | ||
//! for each parsed document. | ||
//! | ||
//! DESCRIPTION:\n | ||
//! Validation on the input documents is light. It is expected the input | ||
//! comes from another process which tightly controls what is sent. | ||
//! Input from an outside source that has not been sanitized should never | ||
//! be sent. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is saying to a hacker, "If you can manage to send dodgy input to this process we'll give you a shell prompt on the system." I think the input is actually validated to the extent of preventing array bounds overwrites. So instead the comment could be more along the lines of, "Validation exists to prevent memory violations from malicious input, but no more. The caller is responsible for sending input that will not result in errors from libTorch and will produce meaningful results." There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
//! | ||
//! IMPLEMENTATION DECISIONS:\n | ||
//! RapidJSON will natively parse a stream of rootless JSON documents | ||
//! given the correct parse flags. The documents may be separated by | ||
//! whitespace but no other delineator is allowed. | ||
//! | ||
//! The parsed request is a member of this class and will be modified when | ||
//! a new command is parsed. The function handler passed to ioLoop must | ||
//! not keep a reference to the request object beyond the scope of the | ||
//! handle function as the request will change. | ||
//! | ||
//! The input stream is held by reference. They must outlive objects of | ||
//! this class, which, in practice, means that the CIoManager object managing | ||
//! them must outlive this object. | ||
//! | ||
class CCommandParser { | ||
public: | ||
static const std::string REQUEST_ID; | ||
static const std::string TOKENS; | ||
static const std::string VAR_ARG_PREFIX; | ||
|
||
using TUint32Vec = std::vector<std::uint32_t>; | ||
using TUint32VecVec = std::vector<TUint32Vec>; | ||
|
||
struct SRequest { | ||
std::string s_RequestId; | ||
TUint32Vec s_Tokens; | ||
TUint32VecVec s_SecondaryArguments; | ||
|
||
void clear(); | ||
}; | ||
|
||
using TRequestHandlerFunc = std::function<bool(SRequest&)>; | ||
|
||
public: | ||
CCommandParser(std::istream& strmIn); | ||
|
||
//! Pass input to the processor until it's consumed as much as it can. | ||
bool ioLoop(const TRequestHandlerFunc& requestHandler); | ||
|
||
CCommandParser(const CCommandParser&) = delete; | ||
CCommandParser& operator=(const CCommandParser&) = delete; | ||
|
||
private: | ||
bool validateJson(const rapidjson::Document& doc) const; | ||
void jsonToRequest(const rapidjson::Document& doc); | ||
|
||
private: | ||
std::istream& m_StrmIn; | ||
SRequest m_Request; | ||
}; | ||
} | ||
} | ||
|
||
#endif // INCLUDED_ml_torch_CCommandParser_h |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually there is one security hole in the validation, which is that we need to confirm
doc[REQUEST_ID].IsString()
. Without this additional check, sending an integer for this field instead would be a way to get a pointer of choice dereferenced.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added this and also the checks that the token arrays contain unsigned ints