[ML] PyTorch Command Processor #1770
Conversation
Force-pushed from c4a2378 to 6adf034
The macOS aarch64 build failed because PyTorch has not been set up on the machine.
bin/pytorch_inference/Main.cc (outdated):

```cpp
torch::Tensor tensor =
    torch::from_blob(static_cast<void*>(args.data()),
                     {1, static_cast<std::int64_t>(args.size())},
                     at::dtype(torch::kInt32))
        .to(torch::kInt64);

inputs.push_back(tensor);
```
This is copying the tensors at the moment. It may be possible to emplace:

```cpp
inputs.emplace_back(
    torch::from_blob(static_cast<void*>(args.data()),
                     {1, static_cast<std::int64_t>(args.size())},
                     at::dtype(torch::kInt32))
        .to(torch::kInt64));
```
Or if that doesn't work for some reason, you could at least move when adding, i.e. `inputs.push_back(std::move(tensor));`.
`torch::Tensor` is a wrapper around `at::Tensor` and internally has a smart pointer to its data. Passing by value/copying is the way to use it by design. Some discussion here: https://discuss.pytorch.org/t/tensor-move-semantics-in-c-frontend/77901/5
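For context, a minimal sketch (not from the PR) of why copying is cheap: copying a `torch::Tensor` copies the handle, not the underlying storage, so both handles see the same data.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
    torch::Tensor a = torch::zeros({2, 2});
    torch::Tensor b = a;   // copies the smart pointer, not the 2x2 storage
    b[0][0] = 1.0;         // the write is visible through `a` as well
    std::cout << a[0][0].item<float>() << '\n'; // prints 1
}
```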
bin/pytorch_inference/Main.cc (outdated):

```cpp
torch::Tensor tokensTensor =
    torch::from_blob(static_cast<void*>(request.s_Tokens.data()),
                     {1, static_cast<std::int64_t>(request.s_Tokens.size())},
                     at::dtype(torch::kInt32))
        .to(torch::kInt64);

std::vector<torch::jit::IValue> inputs;
inputs.push_back(tokensTensor);
```
As below, it looks like this is copying the tensor. If it compiles this should be more efficient:

```cpp
std::vector<torch::jit::IValue> inputs;
inputs.emplace_back(
    torch::from_blob(static_cast<void*>(request.s_Tokens.data()),
                     {1, static_cast<std::int64_t>(request.s_Tokens.size())},
                     at::dtype(torch::kInt32))
        .to(torch::kInt64));
```
Hopefully this will be possible in PyTorch 1.8, which is not too far off now - see pytorch/pytorch#51886 (comment)
I did a pass through. Overall it looks nice and clean. I made some minor suggestions. My main observation is that since this is long-running and you are going to be streaming data to this executable, I'd try to avoid all the temporary large heap objects. I don't feel this would complicate matters, and as it stands it feels like premature pessimization.
```cpp
void debug(const rapidjson::Document& doc) {
    rapidjson::StringBuffer buffer;
    rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
    doc.Accept(writer);
    LOG_TRACE(<< buffer.GetString());
}
```
I wonder if we should move this to `core/CRapidJsonUtils.h`. This pattern comes up fairly often and it would be good to have a single definition for it.
I added `operator<<` to this file. It is a trivial function that has to live in the `rapidjson` namespace, so I don't think it belongs in `core/CRapidJsonUtils.h`.
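A minimal sketch of what such an operator might look like (an assumed shape, not the PR's exact code); argument-dependent lookup is why it has to live in the `rapidjson` namespace:

```cpp
#include <ostream>
#include <rapidjson/document.h>
#include <rapidjson/stringbuffer.h>
#include <rapidjson/writer.h>

namespace rapidjson {
// Serialise a Document and stream it out; found via ADL from logging macros.
inline std::ostream& operator<<(std::ostream& os, const Document& doc) {
    StringBuffer buffer;
    Writer<StringBuffer> writer(buffer);
    doc.Accept(writer);
    return os << buffer.GetString();
}
}
```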
It's this stuff I meant:

```cpp
rapidjson::StringBuffer buffer;
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
doc.Accept(writer);
```

i.e. conversion to a string, which I'm sure I've put somewhere local myself in the past. I find I have to remind myself every time how to convert a `rapidjson::Document` to a string, and thought perhaps it was time to make this a utility somewhere. That said, you don't have to make this change, and perhaps we should hunt for other cases and address them all in one go.
I think an issue got introduced in the refactor, plus I do think input validation warrants a code comment.
bin/pytorch_inference/Main.cc (outdated):

```cpp
torch::NoGradGuard noGrad;
auto tuple = module.forward(inputs).toTuple();
auto predictions = tuple->elements()[0].toTensor();
for (auto args : request.s_SecondaryArguments) {
```
One more thing...

```cpp
for (const auto& args : request.s_SecondaryArguments) {
```
👍 A reference, yes, but it can't be `const` because later there is non-const access to `.data()`.
Interestingly, this should have caused undefined behaviour, because the tensor is then referencing memory from the loop's copy of the request vector. I wonder if this points to a missing test with non-empty secondary arguments.
Actually, calling `to()` is probably what saves you, since it'll create a copy. This makes me wonder, should we be doing this here? I'd have thought it would be better to write the values into a `std::vector<std::uint64_t>` and stick with just having a reference to this memory for the tensor, assuming they need to be 64 bit. It may need some alignment shenanigans to make the most of library optimisations, but that should be manageable. Something to investigate in a follow-up anyway.
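A rough sketch of that suggestion, with assumed names (`s_Tokens` standing in for the request's token field): the tokens are widened once into a 64-bit buffer that the tensor then references directly, instead of paying for the extra copy made by `.to(torch::kInt64)`.

```cpp
#include <torch/torch.h>
#include <cstdint>
#include <vector>

int main() {
    // Hypothetical stand-in for the request's token field.
    std::vector<std::int32_t> s_Tokens{101, 2023, 2003, 102};

    // Widen once; the tensor references this buffer rather than copying,
    // so the buffer must outlive the tensor and any use of it.
    std::vector<std::int64_t> tokens64(s_Tokens.begin(), s_Tokens.end());
    torch::Tensor tokensTensor =
        torch::from_blob(tokens64.data(),
                         {1, static_cast<std::int64_t>(tokens64.size())},
                         at::dtype(torch::kInt64));
}
```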
Force-pushed from 69b9e6c to 3a268ff
Thanks for iterating. LGTM.
LGTM if you could just change a couple of things related to securing the input.
```cpp
//! Validation on the input documents is light. It is expected the input
//! comes from another process which tightly controls what is sent.
//! Input from an outside source that has not been sanitized should never
//! be sent.
```
This is saying to a hacker, "If you can manage to send dodgy input to this process we'll give you a shell prompt on the system."
I think the input is actually validated to the extent of preventing array bounds overwrites. So instead the comment could be more along the lines of, "Validation exists to prevent memory violations from malicious input, but no more. The caller is responsible for sending input that will not result in errors from libTorch and will produce meaningful results."
👍
```cpp
}

bool CCommandParser::validateJson(const rapidjson::Document& doc) const {
    if (doc.HasMember(REQUEST_ID) == false) {
```
Actually there is one security hole in the validation, which is that we need to confirm `doc[REQUEST_ID].IsString()`. Without this additional check, sending an integer for this field instead would be a way to get a pointer of choice dereferenced.
I added this and also checks that the token arrays contain unsigned ints.
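A sketch of what those checks might look like (the field names `request_id` and `tokens` are assumptions, not the PR's actual constants):

```cpp
#include <rapidjson/document.h>

// Hypothetical shape of the hardened validation: the request ID must be a
// string and every token must be an unsigned integer.
bool validateJson(const rapidjson::Document& doc) {
    if (doc.HasMember("request_id") == false || doc["request_id"].IsString() == false) {
        return false;
    }
    if (doc.HasMember("tokens") == false || doc["tokens"].IsArray() == false) {
        return false;
    }
    for (const auto& token : doc["tokens"].GetArray()) {
        if (token.IsUint() == false) {
            return false;
        }
    }
    return true;
}
```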
Windows is showing up a problem.

It seems that:

is still used in

So I was wrong when I said:

could be removed from the
I merged this before CI was green because the macOS ARM build will not pass until PyTorch has been added to the build machine, and Windows was failing for unrelated reasons. Given this is a feature branch PR it is fine to merge.
Defines the input and output documents for the PyTorch 3rd party model app and adds a command processor which parses JSON documents from a stream, then calls a handler function for each request. This all happens in a single thread; the output will be written before the next request document is parsed.
Input
Models accept a variable number of arguments depending on their purpose. In Python PyTorch these are named arguments; in LibTorch an array of input tensors is used. All BERT models take a list of tokens; the other parameters are passed in the fields `arg_1`, `arg_2`, etc. This program knows nothing about the expected number of arguments, it simply consumes all fields starting with `arg_` and forwards them to the model.

RapidJSON supports reading multiple documents from a stream if the `kParseStopWhenDoneFlag` flag is used. The docs don't have to have a common root (e.g. in an array or nested inside a wrapper object). Docs can optionally be separated by whitespace, but any other separator is invalid.
Input Validation
The input will come from Elasticsearch, never from a client; we control the comms with Elasticsearch, so minimal validation is required. If the request is not correctly formed then something catastrophic has happened (broken pipe).
Typically the model throws a `std::runtime_error` if the input is not right; this is caught and returned to the caller.

Output
All BERT models output a tuple, the first element of which is the output tensor. The remaining elements are model dependent (they might be logits or labels); we have not found a use case requiring the full tuple yet, so the output response will only contain the tensor (for now). The tensor must have 2 dimensions or be reducible to 2 dimensions.
The output is a JSON document containing the tensor as an array of arrays.
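A rough sketch of how a 2-D tensor could be serialised that way (a hypothetical helper, not the PR's actual code; it assumes a float tensor):

```cpp
#include <torch/torch.h>
#include <rapidjson/stringbuffer.h>
#include <rapidjson/writer.h>
#include <cstdint>

// Write a 2-D float tensor as a JSON array of arrays.
void writeTensor(const torch::Tensor& tensor,
                 rapidjson::Writer<rapidjson::StringBuffer>& writer) {
    auto accessor = tensor.accessor<float, 2>();
    writer.StartArray();
    for (std::int64_t i = 0; i < accessor.size(0); ++i) {
        writer.StartArray();
        for (std::int64_t j = 0; j < accessor.size(1); ++j) {
            writer.Double(accessor[i][j]);
        }
        writer.EndArray();
    }
    writer.EndArray();
}
```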
In the case of an error the output doc has an `error` field. It is envisaged that any errors will be returned directly to the client.
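For illustration only, a success response and an error response might look like this (the `request_id` and `inference` field names are assumptions):

```json
{"request_id": "a", "inference": [[0.12, 0.88], [0.97, 0.03]]}
{"request_id": "b", "error": "..."}
```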
Tests
The inputs to `evaluate.py` have been reworked so that a single JSON file contains both the input and expected output. Invoking the test and adding new examples is now much easier.

Closes #1700
Closes #1701