Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] PyTorch Command Processor #1770

Merged
merged 19 commits into from
Mar 2, 2021
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions bin/pytorch_inference/CCommandParser.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

#include "CCommandParser.h"

#include <core/CLogger.h>

#include <rapidjson/error/en.h>
#include <rapidjson/istreamwrapper.h>
#include <rapidjson/stringbuffer.h>
#include <rapidjson/writer.h>

#include <istream>

namespace rapidjson {

// Stream insertion for JSON documents, used when logging parsed commands.
// The document is serialised into an in-memory buffer with a
// rapidjson::Writer and the resulting text is written to os.
//
// Declared in namespace rapidjson so that it is found by argument
// dependent lookup on rapidjson::Document.
std::ostream& operator<<(std::ostream& os, const rapidjson::Document& doc) {
    StringBuffer json;
    Writer<StringBuffer> jsonWriter(json);
    doc.Accept(jsonWriter);
    os << json.GetString();
    return os;
}
}

namespace ml {
namespace torch {

// Name of the JSON field holding the request identifier.
const std::string CCommandParser::REQUEST_ID{"request_id"};
// Name of the JSON field holding the array of tokens.
const std::string CCommandParser::TOKENS{"tokens"};
// Prefix of the optional argument fields. A counter starting at 1 is
// appended to form the field names, i.e. arg_1, arg_2, ...
const std::string CCommandParser::VAR_ARG_PREFIX{"arg_"};

// Stores a reference to strmIn; nothing is read from the stream until
// ioLoop() is called.
CCommandParser::CCommandParser(std::istream& strmIn) : m_StrmIn(strmIn) {
}

bool CCommandParser::ioLoop(const TRequestHandlerFunc& requestHandler) {

rapidjson::IStreamWrapper isw{m_StrmIn};

while (true) {
rapidjson::Document doc;
rapidjson::ParseResult parseResult =
doc.ParseStream<rapidjson::kParseStopWhenDoneFlag>(isw);

if (static_cast<bool>(parseResult) == false) {
if (m_StrmIn.eof()) {
break;
}

LOG_ERROR(<< "Error parsing command from JSON: "
<< rapidjson::GetParseError_En(parseResult.Code())
<< ". At offset: " << parseResult.Offset());

return false;
}

if (validateJson(doc) == false) {
continue;
}

LOG_TRACE(<< "Inference command: " << doc);
jsonToRequest(doc);
if (requestHandler(m_Request) == false) {
LOG_ERROR(<< "Request handler forced exit");
return false;
}
}

return true;
}

// Check that doc has exactly the shape jsonToRequest() relies on.
//
// jsonToRequest() reads the request id with GetString() and every array
// element with GetUint() without any further checks, so the value types
// must be verified here. In particular, calling GetString() on a value
// that is not a string (for example an integer) would dereference a
// pointer of the sender's choosing.
//
// Returns true if doc is a JSON object containing a string request id,
// a tokens array of unsigned ints and zero or more consecutively
// numbered arg_N arrays of unsigned ints. Otherwise the reason is logged
// and false is returned.
bool CCommandParser::validateJson(const rapidjson::Document& doc) const {
    // HasMember() and operator[] are only valid on JSON objects, but a
    // successfully parsed document may be any JSON value.
    if (doc.IsObject() == false) {
        LOG_ERROR(<< "Invalid command: expected a JSON object");
        return false;
    }

    if (doc.HasMember(REQUEST_ID) == false) {
        LOG_ERROR(<< "Invalid command: missing field [" << REQUEST_ID << "]");
        return false;
    }

    if (doc[REQUEST_ID].IsString() == false) {
        LOG_ERROR(<< "Invalid command: [" << REQUEST_ID << "] field is not a string");
        return false;
    }

    if (doc.HasMember(TOKENS) == false) {
        LOG_ERROR(<< "Invalid command: missing field [" << TOKENS << "]");
        return false;
    }

    // True if every element of the array can be read with GetUint()
    auto containsOnlyUInts = [](const rapidjson::Value& array) {
        for (auto itr = array.Begin(); itr != array.End(); ++itr) {
            if (itr->IsUint() == false) {
                return false;
            }
        }
        return true;
    };

    const rapidjson::Value& tokens = doc[TOKENS];
    if (tokens.IsArray() == false) {
        LOG_ERROR(<< "Invalid command: expected an array [" << TOKENS << "]");
        return false;
    }

    if (containsOnlyUInts(tokens) == false) {
        LOG_ERROR(<< "Invalid command: array [" << TOKENS
                  << "] contains values that are not unsigned integers");
        return false;
    }

    // check optional args
    std::uint64_t varCount{1};
    std::string varArgName = VAR_ARG_PREFIX + std::to_string(varCount);
    while (doc.HasMember(varArgName)) {
        const rapidjson::Value& value = doc[varArgName];
        if (value.IsArray() == false) {
            LOG_ERROR(<< "Invalid command: argument [" << varArgName << "] is not an array");
            return false;
        }

        if (containsOnlyUInts(value) == false) {
            LOG_ERROR(<< "Invalid command: array [" << varArgName
                      << "] contains values that are not unsigned integers");
            return false;
        }

        ++varCount;
        varArgName = VAR_ARG_PREFIX + std::to_string(varCount);
    }

    return true;
}

// Overwrite m_Request with the contents of doc.
//
// No type checking is performed here: ioLoop() only calls this function
// for documents that validateJson() has accepted.
void CCommandParser::jsonToRequest(const rapidjson::Document& doc) {

    m_Request.s_RequestId = doc[REQUEST_ID].GetString();

    // Replace the tokens left over from the previous request
    const rapidjson::Value& tokenArray = doc[TOKENS];
    TUint32Vec& tokens = m_Request.s_Tokens;
    tokens.clear();
    tokens.reserve(tokenArray.Size());
    for (auto token = tokenArray.Begin(); token != tokenArray.End(); ++token) {
        tokens.push_back(token->GetUint());
    }

    // Replace the secondary arguments left over from the previous request.
    // They are read from the fields arg_1, arg_2, ... stopping at the
    // first number for which no field exists.
    TUint32VecVec& secondaryArgs = m_Request.s_SecondaryArguments;
    secondaryArgs.clear();
    for (std::uint64_t argNumber = 1;; ++argNumber) {
        const std::string argName = VAR_ARG_PREFIX + std::to_string(argNumber);
        if (doc.HasMember(argName) == false) {
            break;
        }

        const rapidjson::Value& argArray = doc[argName];
        TUint32Vec values;
        for (auto element = argArray.Begin(); element != argArray.End(); ++element) {
            values.push_back(element->GetUint());
        }
        secondaryArgs.push_back(values);
    }
}
}
}
83 changes: 83 additions & 0 deletions bin/pytorch_inference/CCommandParser.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

#ifndef INCLUDED_ml_torch_CCommandParser_h
#define INCLUDED_ml_torch_CCommandParser_h

#include <rapidjson/document.h>

#include <cstdint>
#include <functional>
#include <iosfwd>
#include <string>
#include <vector>

namespace ml {
namespace torch {

//! \brief
//! Reads JSON documents from a stream calling the request handler
//! for each parsed document.
//!
//! DESCRIPTION:\n
//! Validation of the input documents is light. The input is expected
//! to come from another process that tightly controls what is sent.
//! Unsanitized input from an outside source must never be sent.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is saying to a hacker, "If you can manage to send dodgy input to this process we'll give you a shell prompt on the system."

I think the input is actually validated to the extent of preventing array bounds overwrites. So instead the comment could be more along the lines of, "Validation exists to prevent memory violations from malicious input, but no more. The caller is responsible for sending input that will not result in errors from libTorch and will produce meaningful results."

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

//!
//! IMPLEMENTATION DECISIONS:\n
//! RapidJSON will natively parse a stream of rootless JSON documents
//! given the correct parse flags. The documents may be separated by
//! whitespace but no other delimiter is allowed.
//!
//! The parsed request is a member of this class and is overwritten each
//! time a new command is parsed. The handler function passed to ioLoop
//! must not keep a reference to the request object beyond the scope of
//! the handler call, as the request will change.
//!
//! The input stream is held by reference. It must outlive objects of
//! this class, which, in practice, means that the CIoManager object managing
//! it must outlive this object.
//!
class CCommandParser {
public:
    //! Name of the JSON field holding the request identifier.
    static const std::string REQUEST_ID;
    //! Name of the JSON field holding the array of tokens.
    static const std::string TOKENS;
    //! Prefix of the optional, consecutively numbered argument fields.
    static const std::string VAR_ARG_PREFIX;

    using TUint32Vec = std::vector<std::uint32_t>;
    using TUint32VecVec = std::vector<TUint32Vec>;

    //! The contents of a single parsed command.
    struct SRequest {
        //! Value of the request id field.
        std::string s_RequestId;
        //! Values of the tokens array.
        TUint32Vec s_Tokens;
        //! One vector per optional argument field, in field number order.
        TUint32VecVec s_SecondaryArguments;

        //! NOTE(review): no definition of this function is visible
        //! alongside the parser implementation - confirm it is defined
        //! before calling it.
        void clear();
    };

    //! Called once for each valid command. The request is only valid for
    //! the duration of the call. Returning false stops ioLoop().
    using TRequestHandlerFunc = std::function<bool(SRequest&)>;

public:
    //! \param strmIn The stream commands are read from. It is held by
    //! reference and must outlive this object.
    //!
    //! Explicit so that a stream is never silently converted to a parser.
    explicit CCommandParser(std::istream& strmIn);

    //! Pass input to the processor until it's consumed as much as it can.
    //!
    //! \param requestHandler Called with each valid command.
    //! \return true if the end of the input was reached, false if a
    //! document could not be parsed or \p requestHandler returned false.
    bool ioLoop(const TRequestHandlerFunc& requestHandler);

    CCommandParser(const CCommandParser&) = delete;
    CCommandParser& operator=(const CCommandParser&) = delete;

private:
    //! Check \p doc contains the fields required to build a request.
    bool validateJson(const rapidjson::Document& doc) const;
    //! Overwrite m_Request with the contents of \p doc.
    void jsonToRequest(const rapidjson::Document& doc);

private:
    //! The stream commands are read from.
    std::istream& m_StrmIn;
    //! The most recently parsed request, reused for every command.
    SRequest m_Request;
};
}
}

#endif // INCLUDED_ml_torch_CCommandParser_h
Loading