From 31b644fe7fd1d5a88edd8c26c60f29b22f6a0a6b Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 11 Feb 2021 14:06:03 +0000 Subject: [PATCH 01/19] Add command parser --- .../CBufferedIStreamAdapter.cc | 3 + .../CBufferedIStreamAdapter.h | 1 + bin/pytorch_inference/CCommandParser.cc | 100 ++++++++++++++++++ bin/pytorch_inference/CCommandParser.h | 72 +++++++++++++ bin/pytorch_inference/Main.cc | 65 ++++-------- bin/pytorch_inference/Makefile | 1 + bin/pytorch_inference/unittest/.gitignore | 2 + .../unittest/.objs/.gitignore | 0 .../unittest/CCommandParserTest.cc | 86 +++++++++++++++ bin/pytorch_inference/unittest/Main.cc | 23 ++++ bin/pytorch_inference/unittest/Makefile | 25 +++++ 11 files changed, 336 insertions(+), 42 deletions(-) create mode 100644 bin/pytorch_inference/CCommandParser.cc create mode 100644 bin/pytorch_inference/CCommandParser.h create mode 100644 bin/pytorch_inference/unittest/.gitignore create mode 100644 bin/pytorch_inference/unittest/.objs/.gitignore create mode 100644 bin/pytorch_inference/unittest/CCommandParserTest.cc create mode 100644 bin/pytorch_inference/unittest/Main.cc create mode 100644 bin/pytorch_inference/unittest/Makefile diff --git a/bin/pytorch_inference/CBufferedIStreamAdapter.cc b/bin/pytorch_inference/CBufferedIStreamAdapter.cc index be97d04d93..d225351a5a 100644 --- a/bin/pytorch_inference/CBufferedIStreamAdapter.cc +++ b/bin/pytorch_inference/CBufferedIStreamAdapter.cc @@ -22,6 +22,9 @@ CBufferedIStreamAdapter::CBufferedIStreamAdapter(std::istream& inputStream) : m_InputStream(inputStream) { } +CBufferedIStreamAdapter::~CBufferedIStreamAdapter() { +} + bool CBufferedIStreamAdapter::init() { if (parseSizeFromStream(m_Size) == false) { LOG_ERROR(<< "Failed to read model size"); diff --git a/bin/pytorch_inference/CBufferedIStreamAdapter.h b/bin/pytorch_inference/CBufferedIStreamAdapter.h index 23b1c40231..efc47e35d7 100644 --- a/bin/pytorch_inference/CBufferedIStreamAdapter.h +++ b/bin/pytorch_inference/CBufferedIStreamAdapter.h @@ -38,6 +38,7 @@ namespace torch { class CBufferedIStreamAdapter : public caffe2::serialize::ReadAdapterInterface { public: CBufferedIStreamAdapter(std::istream& inputStream); + ~CBufferedIStreamAdapter() override; //! True if the model is successfully read. //! Must be called before read or size diff --git a/bin/pytorch_inference/CCommandParser.cc b/bin/pytorch_inference/CCommandParser.cc new file mode 100644 index 0000000000..7c1a5106d3 --- /dev/null +++ b/bin/pytorch_inference/CCommandParser.cc @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +#include "CCommandParser.h" + +#include + +#include + +#include +#include +#include +#include + +namespace ml { +namespace torch { + +namespace { +void debug(const rapidjson::Document& doc) { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc.Accept(writer); + LOG_INFO(<< buffer.GetString()); +} +} + +const std::string CCommandParser::REQUEST_ID{"request_id"}; +const std::string CCommandParser::TOKENS{"tokens"}; + +CCommandParser::CCommandParser(std::istream& strmIn) : m_StrmIn(strmIn) { + +} + + +bool CCommandParser::ioLoop(const TRequestHandlerFunc& requestHandler) const { + + rapidjson::IStreamWrapper isw(m_StrmIn); + + while (true) { + rapidjson::Document doc; + rapidjson::ParseResult parseResult = doc.ParseStream(isw); + + if (static_cast(parseResult) == false) { + if (m_StrmIn.eof()) { + break; + } + + LOG_ERROR(<< "Error parsing command from JSON: " << rapidjson::GetParseError_En(parseResult.Code()) + << ". At offset: " << parseResult.Offset()); + + return false; + } + + + if (validateJson(doc) == false) { + continue; + } + + + debug(doc); + requestHandler(jsonToRequest(doc)); + } + + return true; +} + +bool CCommandParser::validateJson(const rapidjson::Document& doc) const { + if (doc.HasMember(REQUEST_ID) == false) { + LOG_ERROR(<< "Malformed command request: missing field [" << REQUEST_ID << "]"); + return false; + } + + if (doc.HasMember(TOKENS) == false) { + LOG_ERROR(<< "Malformed command request: missing field [" << TOKENS << "]"); + return false; + } + + const rapidjson::Value& tokens = doc[TOKENS]; + if (tokens.IsArray() == false) { + LOG_ERROR(<< "Malformed command request: expected an array [" << TOKENS << "]"); + return false; + } + + return true; +} + +CCommandParser::SRequest CCommandParser::jsonToRequest(const rapidjson::Document& doc) const { + std::vector tokens; + const rapidjson::Value& arr = doc[TOKENS]; + for (auto itr = arr.Begin(); itr != arr.End(); ++itr) { + tokens.push_back(itr->GetUint()); + } + return {doc[REQUEST_ID].GetString(), tokens}; +} + +} +} diff --git a/bin/pytorch_inference/CCommandParser.h b/bin/pytorch_inference/CCommandParser.h new file mode 100644 index 0000000000..ad563fc111 --- /dev/null +++ b/bin/pytorch_inference/CCommandParser.h @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +#ifndef INCLUDED_ml_torch_CCommandParser_h +#define INCLUDED_ml_torch_CCommandParser_h + +#include +#include +#include +#include + +#include + + +namespace ml { +namespace torch { + + +//! \brief +//! Reads JSON documents from a stream emitting a request for +//! each parsed document. +//! +//! DESCRIPTION:\n +//! +//! +//! IMPLEMENTATION DECISIONS:\n +//! RapidJSON will natively parse a stream of rootless JSON documents +//! given the correct parse flags. The documents may be separated by +//! whitespace but no other delineator is allowed. +//! +//! The input stream is held by reference. They must outlive objects of +//! this class, which, in practice, means that the CIoManager object managing +//! them must outlive this object. +//! +class CCommandParser { +public: + + static const std::string REQUEST_ID; + static const std::string TOKENS; + + struct SRequest { + std::string s_RequestId; + std::vector s_Tokens; + }; + + using TRequestHandlerFunc = std::function; + + +public: + CCommandParser(std::istream& strmIn); + + //! Pass input to the processor until it's consumed as much as it can. + bool ioLoop(const TRequestHandlerFunc& requestHandler) const; + + CCommandParser(const CCommandParser&) = delete; + CCommandParser& operator=(const CCommandParser&) = delete; + +private: + bool validateJson(const rapidjson::Document& doc) const; + SRequest jsonToRequest(const rapidjson::Document& doc) const; +private: + //! + std::istream& m_StrmIn; +}; + +} +} + +#endif // INCLUDED_ml_torch_CCommandParser_h diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc index 55764490ca..a97f1de19d 100644 --- a/bin/pytorch_inference/Main.cc +++ b/bin/pytorch_inference/Main.cc @@ -18,6 +18,7 @@ #include "CBufferedIStreamAdapter.h" #include "CCmdLineParser.h" +#include "CCommandParser.h" #include #include @@ -54,44 +55,12 @@ torch::Tensor infer(torch::jit::script::Module& module, TFloatVec& data) { return torch::argmax(predictions, 2); } -bool readUInt32(std::istream& stream, std::uint32_t& num) { - std::uint32_t netNum{0}; - stream.read(reinterpret_cast(&netNum), sizeof(std::uint32_t)); - num = ntohl(netNum); - return stream.good(); -} - -boost::optional readTokens(std::istream& inputStream) { - if (inputStream.eof()) { - LOG_ERROR(<< "Unexpected end of stream reading tokens"); - return boost::none; - } - - // return a float vector rather than integers because - // float is needed to create the tensor - TFloatVec tokens; - std::uint32_t numTokens; - if (readUInt32(inputStream, numTokens) == false) { - LOG_ERROR(<< "Error reading the number of tokens"); - return boost::none; - } - - for (uint32_t i = 0; i < numTokens; ++i) { - std::uint32_t token; - if (readUInt32(inputStream, token) == false) { - LOG_ERROR(<< "Error reading token"); - return boost::none; - } - tokens.push_back(token); - } - - return tokens; -} - -void writePrediction(torch::Tensor& prediction, std::ostream& outputStream) { +void writePrediction(const torch::Tensor& prediction, const std::string& requestId, std::ostream& outputStream) { rapidjson::OStreamWrapper writeStream(outputStream); ml::core::CRapidJsonLineWriter jsonWriter(writeStream); jsonWriter.StartObject(); + jsonWriter.Key(ml::torch::CCommandParser::REQUEST_ID); + jsonWriter.String(requestId); jsonWriter.Key("inference"); jsonWriter.StartArray(); auto arr = prediction.accessor(); @@ -102,6 +71,19 @@ void writePrediction(torch::Tensor& prediction, std::ostream& outputStream) { jsonWriter.EndObject(); } + +bool handleRequest(const ml::torch::CCommandParser::SRequest& request, + torch::jit::script::Module& module, + std::ostream& outputStream) { + + // A float vector is needed to create the tensor + TFloatVec tokens{request.s_Tokens.begin(), request.s_Tokens.end()}; + torch::Tensor results = infer(module, tokens); + writePrediction(results, request.s_RequestId, outputStream); + + return true; +} + int main(int argc, char** argv) { // command line options std::string modelId; @@ -186,14 +168,13 @@ int main(int argc, char** argv) { return EXIT_FAILURE; } - boost::optional tokens = readTokens(ioMgr.inputStream()); - if (!tokens) { - LOG_ERROR(<< "Cannot infer, failed to read input tokens"); - return EXIT_FAILURE; - } - torch::Tensor results = infer(module, *tokens); - writePrediction(results, ioMgr.outputStream()); + ml::torch::CCommandParser commandParser{ioMgr.inputStream()}; + + commandParser.ioLoop([&module, &ioMgr](const ml::torch::CCommandParser::SRequest& request){ + return handleRequest(request, module, ioMgr.outputStream()); + }); + LOG_DEBUG(<< "ML Torch model prototype exiting"); diff --git a/bin/pytorch_inference/Makefile b/bin/pytorch_inference/Makefile index 34e7ea5700..7e2fba77b7 100644 --- a/bin/pytorch_inference/Makefile +++ b/bin/pytorch_inference/Makefile @@ -25,6 +25,7 @@ SRCS= \ Main.cc \ CBufferedIStreamAdapter.cc \ CCmdLineParser.cc \ + CCommandParser.cc \ NO_TEST_CASES=1 diff --git a/bin/pytorch_inference/unittest/.gitignore b/bin/pytorch_inference/unittest/.gitignore new file mode 100644 index 0000000000..2eb0aff1ca --- /dev/null +++ b/bin/pytorch_inference/unittest/.gitignore @@ -0,0 +1,2 @@ +# Ignore temporary test files +test.log diff --git a/bin/pytorch_inference/unittest/.objs/.gitignore b/bin/pytorch_inference/unittest/.objs/.gitignore new file mode 100644 index 0000000000..e69de29bb2 diff --git a/bin/pytorch_inference/unittest/CCommandParserTest.cc b/bin/pytorch_inference/unittest/CCommandParserTest.cc new file mode 100644 index 0000000000..a65d89b9e0 --- /dev/null +++ b/bin/pytorch_inference/unittest/CCommandParserTest.cc @@ -0,0 +1,86 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +#include "../CCommandParser.h" + +#include + +BOOST_AUTO_TEST_SUITE(CCommandParserTest) + +BOOST_AUTO_TEST_CASE(testParsingStream) { + + std::vector parsed; + + std::string command{"{\"request_id\": \"foo\", \"tokens\": [1, 2, 3]}" + "{\"request_id\": \"bar\", \"tokens\": [4, 5]}"}; + std::istringstream commandStream{command}; + + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE( + processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request){ + parsed.push_back(request); + return true; + } )); + + BOOST_REQUIRE_EQUAL(2, parsed.size()); + { + BOOST_REQUIRE_EQUAL("foo", parsed[0].s_RequestId); + std::vector expected{1, 2, 3}; + BOOST_REQUIRE_EQUAL_COLLECTIONS(parsed[0].s_Tokens.begin(), parsed[0].s_Tokens.end(), + expected.begin(), expected.end()); + } + { + BOOST_REQUIRE_EQUAL("bar", parsed[1].s_RequestId); + std::vector expected{4, 5}; + BOOST_REQUIRE_EQUAL_COLLECTIONS(parsed[1].s_Tokens.begin(), parsed[1].s_Tokens.end(), + expected.begin(), expected.end()); + } +} + +BOOST_AUTO_TEST_CASE(testParsingInvalidDoc) { + + std::vector parsed; + + std::string command{"{\"foo\": 1, }"}; + + std::istringstream commandStream{command}; + + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE( + processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request){ + parsed.push_back(request); + return true; + }) == false); + + BOOST_REQUIRE_EQUAL(0, parsed.size()); +} + +BOOST_AUTO_TEST_CASE(testParsingWhitespaceSeparatedDocs) { + + std::vector parsed; + + std::string command{"{\"request_id\": \"foo\", \"tokens\": [1, 2, 3]}\t" + "{\"request_id\": \"bar\", \"tokens\": [1, 2, 3]}\n" + "{\"request_id\": \"foo2\", \"tokens\": [1, 2, 3]} " + "{\"request_id\": \"bar2\", \"tokens\": [1, 2, 3]}"}; + std::istringstream commandStream{command}; + + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE( + processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request){ + parsed.push_back(request); + return true; + } )); + + BOOST_REQUIRE_EQUAL(4, parsed.size()); + BOOST_REQUIRE_EQUAL("foo", parsed[0].s_RequestId); + BOOST_REQUIRE_EQUAL("bar", parsed[1].s_RequestId); + BOOST_REQUIRE_EQUAL("foo2", parsed[2].s_RequestId); + BOOST_REQUIRE_EQUAL("bar2", parsed[3].s_RequestId); +} + + +BOOST_AUTO_TEST_SUITE_END() diff --git a/bin/pytorch_inference/unittest/Main.cc b/bin/pytorch_inference/unittest/Main.cc new file mode 100644 index 0000000000..ed5004a50e --- /dev/null +++ b/bin/pytorch_inference/unittest/Main.cc @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +#define BOOST_TEST_MODULE bin.pytorch +// Defining BOOST_TEST_MODULE usually auto-generates main(), but we don't want +// this as we need custom initialisation to allow for output in both console and +// Boost.Test XML formats +#define BOOST_TEST_NO_MAIN + +#include +#include + +#include + +int main(int argc, char** argv) { + ml::test::CTestObserver observer; + boost::unit_test::framework::register_observer(observer); + return boost::unit_test::unit_test_main(&ml::test::CBoostTestXmlOutput::init, argc, argv); + boost::unit_test::framework::deregister_observer(observer); +} diff --git a/bin/pytorch_inference/unittest/Makefile b/bin/pytorch_inference/unittest/Makefile new file mode 100644 index 0000000000..4916872652 --- /dev/null +++ b/bin/pytorch_inference/unittest/Makefile @@ -0,0 +1,25 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. +# +include $(CPP_SRC_HOME)/mk/defines.mk + +TARGET=ml_test$(EXE_EXT) + +USE_BOOST=1 +USE_BOOST_PROGRAMOPTIONS_LIBS=1 +USE_BOOST_TEST_LIBS=1 +USE_RAPIDJSON=1 +USE_TORCH=1 + +all: build + +LIBS=\ + ../.objs/C*$(OBJECT_FILE_EXT) \ + +SRCS=\ + Main.cc \ + CCommandParserTest.cc \ + +include $(CPP_SRC_HOME)/mk/stdboosttest.mk \ No newline at end of file From f5b7a2e7b3658e6467b2af2e956739a796efcdb9 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Fri, 19 Feb 2021 15:27:14 +0000 Subject: [PATCH 02/19] Run multiple requests in evaluate.py --- bin/pytorch_inference/CCommandParser.cc | 2 +- bin/pytorch_inference/evaluate.py | 74 +++++++++++-------- .../examples/ner/expected_response.json | 3 - bin/pytorch_inference/examples/ner/input.json | 3 - .../examples/ner/test_run.json | 12 +++ 5 files changed, 58 insertions(+), 36 deletions(-) delete mode 100644 bin/pytorch_inference/examples/ner/expected_response.json delete mode 100644 bin/pytorch_inference/examples/ner/input.json create mode 100644 bin/pytorch_inference/examples/ner/test_run.json diff --git a/bin/pytorch_inference/CCommandParser.cc b/bin/pytorch_inference/CCommandParser.cc index 7c1a5106d3..ee685b3002 100644 --- a/bin/pytorch_inference/CCommandParser.cc +++ b/bin/pytorch_inference/CCommandParser.cc @@ -23,7 +23,7 @@ void debug(const rapidjson::Document& doc) { rapidjson::StringBuffer buffer; rapidjson::Writer writer(buffer); doc.Accept(writer); - LOG_INFO(<< buffer.GetString()); + LOG_DEBUG(<< buffer.GetString()); } } diff --git a/bin/pytorch_inference/evaluate.py b/bin/pytorch_inference/evaluate.py index 97f03845e8..8edaa74d84 100644 --- a/bin/pytorch_inference/evaluate.py +++ b/bin/pytorch_inference/evaluate.py @@ -3,13 +3,25 @@ app, together with the encoded tokens from the input_tokens file. Then it checks the model's response matches the expected. -This script first prepares the input files, then launches the C++ -pytorch_inference program which handles them in batch. +This script reads the input files and expected outputs, then +launches the C++ pytorch_inference program which handles and +sends the request. The reponse is checked against the expected +defined in the test file Run this script with input from one of the example directories, for example: -python3 evaluate.py /path/to/conll03_traced_ner.pt examples/ner/input.json examples/ner/expected_response.json +python3 evaluate.py /path/to/conll03_traced_ner.pt examples/ner/test_run.json + + +The test file must have the format: +[ + { + "input": {"request_id": "foo", "tokens": [1, 2, 3]}, + "expected_output": {"request_id": "foo", "inference": [1, 2, 3]} + }, + ... +] ''' import argparse @@ -22,8 +34,8 @@ def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('model', help='A TorchScript model with .pt extension') - parser.add_argument('input_tokens', help='JSON file with an array field "tokens"') - parser.add_argument('expected_output', help='Expected output. Another JSON file with an array field "tokens"') + parser.add_argument('test_file', help='JSON file with an array of objects each ' + 'containing "input" and "expected_output" subobjects') parser.add_argument('--restore_file', default='restore_file') parser.add_argument('--input_file', default='input_file') parser.add_argument('--output_file', default='output_file') @@ -64,18 +76,14 @@ def stream_file(source, destination) : destination.write(piece) -def write_tokens(destination, tokens): - - num_tokens = len(tokens) - destination.write(num_tokens.to_bytes(4, 'big')) - for token in tokens: - destination.write(token.to_bytes(4, 'big')) +def write_request(request, destination): + json.dump(request, destination) def main(): args = parse_arguments() - try: + try: # create the restore file with open(args.restore_file, 'wb') as restore_file: file_stats = os.stat(args.model) @@ -90,28 +98,35 @@ def main(): with open(args.model, 'rb') as source_file: stream_file(source_file, restore_file) - with open(args.input_file, 'wb') as input_file: - with open(args.input_tokens) as token_file: - input_tokens = json.load(token_file) + with open(args.input_file, 'w') as input_file: + with open(args.test_file) as test_file: + test_evaluation = json.load(test_file) print("writing query", flush=True) - write_tokens(input_file, input_tokens['tokens']) - - # one shot inference + for doc in test_evaluation: + write_request(doc['input'], input_file) + launch_pytorch_app(args) - print("reading results", flush=True) - with open(args.expected_output) as expected_output_file: - expected = json.load(expected_output_file) - + print("reading results", flush=True) with open(args.output_file) as output_file: - results = json.load(output_file) + + doc_count = 0 + results_match = True + # output is NDJSON + for jsonline in output_file: + result = json.loads(jsonline) + expected = test_evaluation[doc_count]['expected_output'] - # compare to expected - if results['inference'] == expected['tokens']: - print('inference results match expected results') - else: - print('ERROR: inference results do not match expected results') - print(results) + # compare to expected + if result != expected: + print('ERROR: inference result [{}] does not match expected results'.format(doc_count)) + print(result, expected) + results_match = False + + doc_count = doc_count +1 + + if results_match: + print('SUCCESS: inference results match expected') finally: if os.path.isfile(args.restore_file): @@ -122,6 +137,7 @@ def main(): os.remove(args.output_file) + if __name__ == "__main__": main() diff --git a/bin/pytorch_inference/examples/ner/expected_response.json b/bin/pytorch_inference/examples/ner/expected_response.json deleted file mode 100644 index d908590edc..0000000000 --- a/bin/pytorch_inference/examples/ner/expected_response.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "tokens": [0, 6, 6, 6, 6, 0, 0, 0, 0, 0, 0, 8, 8, 8, 0, 0, 0, 0, 0, 8, 8, 8, 0, 0, 0, 0, 0, 0, 0, 8, 8, 0, 0] -} diff --git a/bin/pytorch_inference/examples/ner/input.json b/bin/pytorch_inference/examples/ner/input.json deleted file mode 100644 index a35aad7d4f..0000000000 --- a/bin/pytorch_inference/examples/ner/input.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "tokens": [101, 20164, 10932, 10289, 3561, 119, 1110, 170, 1419, 1359, 1107, 1203, 1365, 1392, 119, 2098, 3834, 1132, 1107, 141, 25810, 23904, 117, 3335, 1304, 1665, 20316, 1106, 1103, 6545, 3640, 119, 102] -} diff --git a/bin/pytorch_inference/examples/ner/test_run.json b/bin/pytorch_inference/examples/ner/test_run.json new file mode 100644 index 0000000000..b385bfad66 --- /dev/null +++ b/bin/pytorch_inference/examples/ner/test_run.json @@ -0,0 +1,12 @@ +[ + { + "source_text": "Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO, therefore very close to the Manhattan Bridge.", + "input": {"request_id": "one", "tokens": [101, 20164, 10932, 10289, 3561, 119, 1110, 170, 1419, 1359, 1107, 1203, 1365, 1392, 119, 2098, 3834, 1132, 1107, 141, 25810, 23904, 117, 3335, 1304, 1665, 20316, 1106, 1103, 6545, 3640, 119, 102]}, + "expected_output": {"request_id": "one", "inference": [0, 6, 6, 6, 6, 0, 0, 0, 0, 0, 0, 8, 8, 8, 0, 0, 0, 0, 0, 8, 8, 8, 0, 0, 0, 0, 0, 0, 0, 8, 8, 0, 0]} + }, + { + "source_text": "Jim bought 300 shares of Acme Corp. in 2006", + "input": {"request_id": "two", "tokens": [101, 3104, 3306, 3127, 6117, 1104, 138, 1665, 3263, 13619, 119, 1107, 1386, 102]}, + "expected_output": {"request_id": "two", "inference": [0, 4, 0, 0, 0, 0, 6, 6, 6, 6, 6, 0, 0, 0]} + } +] \ No newline at end of file From 1ed29a37d8b40b9271939e77d64cc7d08a190a9a Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 22 Feb 2021 12:55:57 +0000 Subject: [PATCH 03/19] Pass all attention mask etc as command arguments --- bin/pytorch_inference/CCommandParser.cc | 38 +++++++++++- bin/pytorch_inference/CCommandParser.h | 9 ++- bin/pytorch_inference/Main.cc | 39 ++++++++---- bin/pytorch_inference/evaluate.py | 16 ++--- .../examples/ner/test_run.json | 18 +++++- .../unittest/CCommandParserTest.cc | 62 ++++++++++++++++++- 6 files changed, 152 insertions(+), 30 deletions(-) diff --git a/bin/pytorch_inference/CCommandParser.cc b/bin/pytorch_inference/CCommandParser.cc index ee685b3002..242a9c9c49 100644 --- a/bin/pytorch_inference/CCommandParser.cc +++ b/bin/pytorch_inference/CCommandParser.cc @@ -29,6 +29,7 @@ void debug(const rapidjson::Document& doc) { const std::string CCommandParser::REQUEST_ID{"request_id"}; const std::string CCommandParser::TOKENS{"tokens"}; +const std::string CCommandParser::VAR_ARG_PREFIX{"arg_"}; CCommandParser::CCommandParser(std::istream& strmIn) : m_StrmIn(strmIn) { @@ -61,7 +62,8 @@ bool CCommandParser::ioLoop(const TRequestHandlerFunc& requestHandler) const { debug(doc); - requestHandler(jsonToRequest(doc)); + CCommandParser::SRequest request = jsonToRequest(doc); + requestHandler(request); } return true; @@ -84,16 +86,46 @@ bool CCommandParser::validateJson(const rapidjson::Document& doc) const { return false; } + // check optional args + std::uint64_t varCount{1}; + std::string varArgName = VAR_ARG_PREFIX + std::to_string(varCount); + while (doc.HasMember(varArgName)) { + const rapidjson::Value& value = doc[varArgName]; + if (value.IsArray() == false) { + LOG_ERROR(<< "Malformed command request: argument [" << varArgName << "] is not an array"); + return false; + } + + ++varCount; + varArgName = VAR_ARG_PREFIX + std::to_string(varCount); + } + return true; } CCommandParser::SRequest CCommandParser::jsonToRequest(const rapidjson::Document& doc) const { - std::vector tokens; + TUint32Vec tokens; const rapidjson::Value& arr = doc[TOKENS]; for (auto itr = arr.Begin(); itr != arr.End(); ++itr) { tokens.push_back(itr->GetUint()); } - return {doc[REQUEST_ID].GetString(), tokens}; + + std::uint64_t varCount{1}; + std::string varArgName = VAR_ARG_PREFIX + std::to_string(varCount); + TUint32VecVec args; + + while (doc.HasMember(varArgName)) { + TUint32Vec arg; + const rapidjson::Value& v = doc[varArgName]; + for (auto itr = v.Begin(); itr != v.End(); ++itr) { + arg.push_back(itr->GetUint()); + } + + args.push_back(arg); + ++varCount; + varArgName = VAR_ARG_PREFIX + std::to_string(varCount); + } + return {doc[REQUEST_ID].GetString(), tokens, args}; } } diff --git a/bin/pytorch_inference/CCommandParser.h b/bin/pytorch_inference/CCommandParser.h index ad563fc111..a2b2a50d99 100644 --- a/bin/pytorch_inference/CCommandParser.h +++ b/bin/pytorch_inference/CCommandParser.h @@ -40,13 +40,18 @@ class CCommandParser { static const std::string REQUEST_ID; static const std::string TOKENS; + static const std::string VAR_ARG_PREFIX; + + using TUint32Vec = std::vector; + using TUint32VecVec = std::vector; struct SRequest { std::string s_RequestId; - std::vector s_Tokens; + TUint32Vec s_Tokens; + TUint32VecVec s_SecondaryArguments; }; - using TRequestHandlerFunc = std::function; + using TRequestHandlerFunc = std::function; public: diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc index a97f1de19d..d447ea3d27 100644 --- a/bin/pytorch_inference/Main.cc +++ b/bin/pytorch_inference/Main.cc @@ -35,18 +35,32 @@ #include #endif -using TFloatVec = std::vector; +static const std::string INFERENCE{"inference"}; + +torch::Tensor infer(torch::jit::script::Module& module, + ml::torch::CCommandParser::SRequest& request) { + + + ml::torch::CCommandParser::TUint32Vec data{request.s_Tokens.begin(), request.s_Tokens.end()}; -torch::Tensor infer(torch::jit::script::Module& module, TFloatVec& data) { torch::Tensor tokensTensor = - torch::from_blob(data.data(), {1, static_cast(data.size())}) + torch::from_blob(static_cast(request.s_Tokens.data()), + {1, static_cast(request.s_Tokens.size())}, + at::dtype(torch::kInt32)) .to(torch::kInt64); + std::vector inputs; inputs.push_back(tokensTensor); - inputs.push_back(torch::ones({1, static_cast(data.size())})); // attention mask - inputs.push_back( - torch::zeros({1, static_cast(data.size())}).to(torch::kInt64)); // token type ids - inputs.push_back(torch::arange(static_cast(data.size())).to(torch::kInt64)); // position ids + + for (auto args : request.s_SecondaryArguments) { + torch::Tensor tensor = + torch::from_blob(static_cast(args.data()), + {1, static_cast(args.size())}, + at::dtype(torch::kInt32)) + .to(torch::kInt64); + + inputs.push_back(tensor); + } torch::NoGradGuard noGrad; auto tuple = module.forward(inputs).toTuple(); @@ -61,7 +75,7 @@ void writePrediction(const torch::Tensor& prediction, const std::string& request jsonWriter.StartObject(); jsonWriter.Key(ml::torch::CCommandParser::REQUEST_ID); jsonWriter.String(requestId); - jsonWriter.Key("inference"); + jsonWriter.Key(INFERENCE); jsonWriter.StartArray(); auto arr = prediction.accessor(); for (int i = 0; i < arr.size(1); i++) { @@ -72,13 +86,12 @@ void writePrediction(const torch::Tensor& prediction, const std::string& request } -bool handleRequest(const ml::torch::CCommandParser::SRequest& request, +bool handleRequest(ml::torch::CCommandParser::SRequest& request, torch::jit::script::Module& module, std::ostream& outputStream) { - // A float vector is needed to create the tensor - TFloatVec tokens{request.s_Tokens.begin(), request.s_Tokens.end()}; - torch::Tensor results = infer(module, tokens); + + torch::Tensor results = infer(module, request); writePrediction(results, request.s_RequestId, outputStream); return true; @@ -171,7 +184,7 @@ int main(int argc, char** argv) { ml::torch::CCommandParser commandParser{ioMgr.inputStream()}; - commandParser.ioLoop([&module, &ioMgr](const ml::torch::CCommandParser::SRequest& request){ + commandParser.ioLoop([&module, &ioMgr](ml::torch::CCommandParser::SRequest& request){ return handleRequest(request, module, ioMgr.outputStream()); }); diff --git a/bin/pytorch_inference/evaluate.py b/bin/pytorch_inference/evaluate.py index 8edaa74d84..ceb295da07 100644 --- a/bin/pytorch_inference/evaluate.py +++ b/bin/pytorch_inference/evaluate.py @@ -5,15 +5,9 @@ This script reads the input files and expected outputs, then launches the C++ pytorch_inference program which handles and -sends the request. The reponse is checked against the expected +sends the request. The response is checked against the expected defined in the test file -Run this script with input from one of the example directories, -for example: - -python3 evaluate.py /path/to/conll03_traced_ner.pt examples/ner/test_run.json - - The test file must have the format: [ { @@ -22,6 +16,14 @@ }, ... ] + + +EXAMPLES +-------- +Run this script with input from one of the example directories, +for example: + +python3 evaluate.py /path/to/conll03_traced_ner.pt examples/ner/test_run.json ''' import argparse diff --git a/bin/pytorch_inference/examples/ner/test_run.json b/bin/pytorch_inference/examples/ner/test_run.json index b385bfad66..1339669b0b 100644 --- a/bin/pytorch_inference/examples/ner/test_run.json +++ b/bin/pytorch_inference/examples/ner/test_run.json @@ -1,12 +1,24 @@ [ { "source_text": "Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO, therefore very close to the Manhattan Bridge.", - "input": {"request_id": "one", "tokens": [101, 20164, 10932, 10289, 3561, 119, 1110, 170, 1419, 1359, 1107, 1203, 1365, 1392, 119, 2098, 3834, 1132, 1107, 141, 25810, 23904, 117, 3335, 1304, 1665, 20316, 1106, 1103, 6545, 3640, 119, 102]}, + "input": { + "request_id": "one", + "tokens": [101, 20164, 10932, 10289, 3561, 119, 1110, 170, 1419, 1359, 1107, 1203, 1365, 1392, 119, 2098, 3834, 1132, 1107, 141, 25810, 23904, 117, 3335, 1304, 1665, 20316, 1106, 1103, 6545, 3640, 119, 102], + "arg_1": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "arg_2": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + "arg_3": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32] + }, "expected_output": {"request_id": "one", "inference": [0, 6, 6, 6, 6, 0, 0, 0, 0, 0, 0, 8, 8, 8, 0, 0, 0, 0, 0, 8, 8, 8, 0, 0, 0, 0, 0, 0, 0, 8, 8, 0, 0]} }, { "source_text": "Jim bought 300 shares of Acme Corp. in 2006", - "input": {"request_id": "two", "tokens": [101, 3104, 3306, 3127, 6117, 1104, 138, 1665, 3263, 13619, 119, 1107, 1386, 102]}, + "input": { + "request_id": "two", + "tokens": [101, 3104, 3306, 3127, 6117, 1104, 138, 1665, 3263, 13619, 119, 1107, 1386, 102], + "arg_1": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "arg_2": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + "arg_3": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + }, "expected_output": {"request_id": "two", "inference": [0, 4, 0, 0, 0, 0, 6, 6, 6, 6, 6, 0, 0, 0]} - } + } ] \ No newline at end of file diff --git a/bin/pytorch_inference/unittest/CCommandParserTest.cc b/bin/pytorch_inference/unittest/CCommandParserTest.cc index a65d89b9e0..bc6b9ca7db 100644 --- a/bin/pytorch_inference/unittest/CCommandParserTest.cc +++ b/bin/pytorch_inference/unittest/CCommandParserTest.cc @@ -28,13 +28,13 @@ BOOST_AUTO_TEST_CASE(testParsingStream) { BOOST_REQUIRE_EQUAL(2, parsed.size()); { BOOST_REQUIRE_EQUAL("foo", parsed[0].s_RequestId); - std::vector expected{1, 2, 3}; + ml::torch::CCommandParser::TUint32Vec expected{1, 2, 3}; BOOST_REQUIRE_EQUAL_COLLECTIONS(parsed[0].s_Tokens.begin(), parsed[0].s_Tokens.end(), expected.begin(), expected.end()); } { BOOST_REQUIRE_EQUAL("bar", parsed[1].s_RequestId); - std::vector expected{4, 5}; + ml::torch::CCommandParser::TUint32Vec expected{4, 5}; BOOST_REQUIRE_EQUAL_COLLECTIONS(parsed[1].s_Tokens.begin(), parsed[1].s_Tokens.end(), expected.begin(), expected.end()); } @@ -82,5 +82,63 @@ BOOST_AUTO_TEST_CASE(testParsingWhitespaceSeparatedDocs) { BOOST_REQUIRE_EQUAL("bar2", parsed[3].s_RequestId); } +BOOST_AUTO_TEST_CASE(testParsingVariableArguments) { + + std::vector parsed; + + std::string command{"{\"request_id\": \"foo\", \"tokens\": [1, 2], \"arg_1\": [0, 0], \"arg_2\": [0, 1]}" + "{\"request_id\": \"bar\", \"tokens\": [3, 4], \"arg_1\": [1, 0], \"arg_2\": [1, 1]}"}; + std::istringstream commandStream{command}; + + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE( + processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request){ + parsed.push_back(request); + return true; + } )); + + BOOST_REQUIRE_EQUAL(2, parsed.size()); + { + ml::torch::CCommandParser::TUint32Vec expectedArg1{0, 0}; + ml::torch::CCommandParser::TUint32Vec expectedArg2{0, 1}; + + ml::torch::CCommandParser::TUint32VecVec extraArgs = parsed[0].s_SecondaryArguments; + BOOST_REQUIRE_EQUAL(2, extraArgs.size()); + + BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[0].begin(), extraArgs[0].end(), + expectedArg1.begin(), expectedArg1.end()); + BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[1].begin(), extraArgs[1].end(), + expectedArg2.begin(), expectedArg2.end()); + } + { + ml::torch::CCommandParser::TUint32Vec expectedArg1{1, 0}; + ml::torch::CCommandParser::TUint32Vec expectedArg2{1, 1}; + + ml::torch::CCommandParser::TUint32VecVec extraArgs = parsed[1].s_SecondaryArguments; + BOOST_REQUIRE_EQUAL(2, extraArgs.size()); + + BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[0].begin(), extraArgs[0].end(), + expectedArg1.begin(), expectedArg1.end()); + BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[1].begin(), extraArgs[1].end(), + expectedArg2.begin(), expectedArg2.end()); + } +} + +BOOST_AUTO_TEST_CASE(testParsingInvalidVarArg) { + + std::vector parsed; + + std::string command{"{\"request_id\": \"foo\", \"tokens\": [1, 2], \"arg_1\": \"not_an_array\"}"}; + std::istringstream commandStream{command}; + + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE( + processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request){ + parsed.push_back(request); + return true; + } )); + + BOOST_REQUIRE_EQUAL(0, parsed.size()); +} BOOST_AUTO_TEST_SUITE_END() From 85b5ebc4ea36bcec2e9f3fc8d7ddbaa51a326a8b Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 22 Feb 2021 15:29:20 +0000 Subject: [PATCH 04/19] Catch and output inference errors --- bin/pytorch_inference/Main.cc | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc index d447ea3d27..10c90825a6 100644 --- a/bin/pytorch_inference/Main.cc +++ b/bin/pytorch_inference/Main.cc @@ -36,13 +36,11 @@ #endif static const std::string INFERENCE{"inference"}; +static const std::string ERROR{"error"}; torch::Tensor infer(torch::jit::script::Module& module, ml::torch::CCommandParser::SRequest& request) { - - ml::torch::CCommandParser::TUint32Vec data{request.s_Tokens.begin(), request.s_Tokens.end()}; - torch::Tensor tokensTensor = torch::from_blob(static_cast(request.s_Tokens.data()), {1, static_cast(request.s_Tokens.size())}, @@ -64,8 +62,7 @@ torch::Tensor infer(torch::jit::script::Module& module, torch::NoGradGuard noGrad; auto tuple = module.forward(inputs).toTuple(); - auto predictions = tuple->elements()[0].toTensor(); - + auto predictions = tuple->elements()[0].toTensor(); return torch::argmax(predictions, 2); } @@ -85,14 +82,29 @@ void writePrediction(const torch::Tensor& prediction, const std::string& request jsonWriter.EndObject(); } +void writeError(const std::string& requestId, const std::string& message, + std::ostream& outputStream) { + rapidjson::OStreamWrapper writeStream(outputStream); + ml::core::CRapidJsonLineWriter jsonWriter(writeStream); + jsonWriter.StartObject(); + jsonWriter.Key(ml::torch::CCommandParser::REQUEST_ID); + jsonWriter.String(requestId); + jsonWriter.Key(ERROR); + jsonWriter.String(message); + jsonWriter.EndObject(); +} + bool handleRequest(ml::torch::CCommandParser::SRequest& request, torch::jit::script::Module& module, std::ostream& outputStream) { - - torch::Tensor results = infer(module, request); - writePrediction(results, request.s_RequestId, outputStream); + try { + torch::Tensor results = infer(module, request); + writePrediction(results, request.s_RequestId, outputStream); + } catch (std::runtime_error& e) { + writeError(request.s_RequestId, e.what(), outputStream); + } return true; } From ec8b8835c0e38fde2f1391aa71987ac1e18e0bca Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 23 Feb 2021 14:44:13 +0000 Subject: [PATCH 05/19] Add the sentiment analysis example --- bin/pytorch_inference/CCommandParser.cc | 169 +++++++++--------- bin/pytorch_inference/CCommandParser.h | 47 +++-- bin/pytorch_inference/Main.cc | 84 ++++----- bin/pytorch_inference/evaluate.py | 69 +++++-- .../examples/ner/test_run.json | 4 +- .../examples/sentiment_analysis/test_run.json | 20 +++ .../unittest/CCommandParserTest.cc | 128 +++++++------ 7 files changed, 298 insertions(+), 223 deletions(-) create mode 100644 bin/pytorch_inference/examples/sentiment_analysis/test_run.json diff --git a/bin/pytorch_inference/CCommandParser.cc b/bin/pytorch_inference/CCommandParser.cc index 242a9c9c49..f080f27c15 100644 --- a/bin/pytorch_inference/CCommandParser.cc +++ b/bin/pytorch_inference/CCommandParser.cc @@ -20,113 +20,114 @@ namespace torch { namespace { void debug(const rapidjson::Document& doc) { - rapidjson::StringBuffer buffer; - rapidjson::Writer writer(buffer); - doc.Accept(writer); - LOG_DEBUG(<< buffer.GetString()); + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc.Accept(writer); + LOG_TRACE(<< buffer.GetString()); +} } -} const std::string CCommandParser::REQUEST_ID{"request_id"}; const std::string CCommandParser::TOKENS{"tokens"}; const std::string CCommandParser::VAR_ARG_PREFIX{"arg_"}; CCommandParser::CCommandParser(std::istream& strmIn) : m_StrmIn(strmIn) { - } - bool CCommandParser::ioLoop(const TRequestHandlerFunc& requestHandler) const { - rapidjson::IStreamWrapper isw(m_StrmIn); + rapidjson::IStreamWrapper isw(m_StrmIn); - while (true) { - rapidjson::Document doc; - rapidjson::ParseResult parseResult = doc.ParseStream(isw); - - if (static_cast(parseResult) == false) { - if (m_StrmIn.eof()) { - break; - } + while (true) { + rapidjson::Document doc; + rapidjson::ParseResult parseResult = + doc.ParseStream(isw); - LOG_ERROR(<< "Error parsing command from JSON: " << rapidjson::GetParseError_En(parseResult.Code()) - << ". At offset: " << parseResult.Offset()); + if (static_cast(parseResult) == false) { + if (m_StrmIn.eof()) { + break; + } - return false; - } + LOG_ERROR(<< "Error parsing command from JSON: " + << rapidjson::GetParseError_En(parseResult.Code()) + << ". At offset: " << parseResult.Offset()); + return false; + } - if (validateJson(doc) == false) { - continue; - } + if (validateJson(doc) == false) { + continue; + } - - debug(doc); - CCommandParser::SRequest request = jsonToRequest(doc); - requestHandler(request); - } - - return true; + // TODO if logger.trace_enabled then + debug(doc); + CCommandParser::SRequest request = jsonToRequest(doc); + if (requestHandler(request) == false) { + LOG_ERROR(<< "Request handler forced exit"); + return false; + } + } + + return true; } bool CCommandParser::validateJson(const rapidjson::Document& doc) const { - if (doc.HasMember(REQUEST_ID) == false) { - LOG_ERROR(<< "Malformed command request: missing field [" << REQUEST_ID << "]"); - return false; - } - - if (doc.HasMember(TOKENS) == false) { - LOG_ERROR(<< "Malformed command request: missing field [" << TOKENS << "]"); - return false; - } - - const rapidjson::Value& tokens = doc[TOKENS]; - if (tokens.IsArray() == false) { - LOG_ERROR(<< "Malformed command request: expected an array [" << TOKENS << "]"); - return false; - } - - // check optional args - std::uint64_t varCount{1}; - std::string varArgName = VAR_ARG_PREFIX + std::to_string(varCount); - while (doc.HasMember(varArgName)) { - const rapidjson::Value& value = doc[varArgName]; - if (value.IsArray() == false) { - LOG_ERROR(<< "Malformed command request: argument [" << varArgName << "] is not an array"); - return false; - } - - ++varCount; - varArgName = VAR_ARG_PREFIX + std::to_string(varCount); - } - - return true; + if (doc.HasMember(REQUEST_ID) == false) { + LOG_ERROR(<< "Invalid command: missing field [" << REQUEST_ID << "]"); + return false; + } + + if (doc.HasMember(TOKENS) == false) { + LOG_ERROR(<< "Invalid command: missing field [" << TOKENS << "]"); + return false; + } + + const rapidjson::Value& tokens = doc[TOKENS]; + if (tokens.IsArray() == false) { + LOG_ERROR(<< "Invalid command: expected an array [" << TOKENS << "]"); + return false; + } + + // check optional args + std::uint64_t varCount{1}; + std::string varArgName = VAR_ARG_PREFIX + std::to_string(varCount); + while (doc.HasMember(varArgName)) { + const rapidjson::Value& value = doc[varArgName]; + if (value.IsArray() == false) { + LOG_ERROR(<< "Invalid command: argument [" << varArgName << "] is not an array"); + return false; + } + + ++varCount; + varArgName = VAR_ARG_PREFIX + std::to_string(varCount); + } + + return true; } CCommandParser::SRequest CCommandParser::jsonToRequest(const rapidjson::Document& doc) const { - TUint32Vec tokens; - const rapidjson::Value& arr = doc[TOKENS]; - for (auto itr = arr.Begin(); itr != arr.End(); ++itr) { - tokens.push_back(itr->GetUint()); - } - - std::uint64_t varCount{1}; - std::string varArgName = VAR_ARG_PREFIX + std::to_string(varCount); - TUint32VecVec args; - - while (doc.HasMember(varArgName)) { - TUint32Vec arg; - const rapidjson::Value& v = doc[varArgName]; - for (auto itr = v.Begin(); itr != v.End(); ++itr) { - arg.push_back(itr->GetUint()); - } - - args.push_back(arg); - ++varCount; - varArgName = VAR_ARG_PREFIX + std::to_string(varCount); - } - return {doc[REQUEST_ID].GetString(), tokens, args}; + TUint32Vec tokens; + const rapidjson::Value& arr = doc[TOKENS]; + for (auto itr = arr.Begin(); itr != arr.End(); ++itr) { + tokens.push_back(itr->GetUint()); + } + + std::uint64_t varCount{1}; + std::string varArgName = VAR_ARG_PREFIX + std::to_string(varCount); + TUint32VecVec args; + + while (doc.HasMember(varArgName)) { + TUint32Vec arg; + const rapidjson::Value& v = doc[varArgName]; + for (auto itr = v.Begin(); itr != v.End(); ++itr) { + arg.push_back(itr->GetUint()); + } + + args.push_back(arg); + ++varCount; + varArgName = VAR_ARG_PREFIX + std::to_string(varCount); + } + return {doc[REQUEST_ID].GetString(), tokens, args}; } - } } diff --git a/bin/pytorch_inference/CCommandParser.h b/bin/pytorch_inference/CCommandParser.h index a2b2a50d99..0947315b70 100644 --- a/bin/pytorch_inference/CCommandParser.h +++ b/bin/pytorch_inference/CCommandParser.h @@ -7,52 +7,48 @@ #ifndef INCLUDED_ml_torch_CCommandParser_h #define INCLUDED_ml_torch_CCommandParser_h -#include #include +#include #include #include #include - namespace ml { namespace torch { - //! \brief -//! Reads JSON documents from a stream emitting a request for -//! each parsed document. +//! Reads JSON documents from a stream calling the request handler +//! for each parsed document. //! //! DESCRIPTION:\n -//! +//! //! //! IMPLEMENTATION DECISIONS:\n //! RapidJSON will natively parse a stream of rootless JSON documents -//! given the correct parse flags. The documents may be separated by +//! given the correct parse flags. The documents may be separated by //! whitespace but no other delineator is allowed. -//! +//! //! The input stream is held by reference. They must outlive objects of //! this class, which, in practice, means that the CIoManager object managing //! them must outlive this object. //! -class CCommandParser { +class CCommandParser { public: + static const std::string REQUEST_ID; + static const std::string TOKENS; + static const std::string VAR_ARG_PREFIX; - static const std::string REQUEST_ID; - static const std::string TOKENS; - static const std::string VAR_ARG_PREFIX; - - using TUint32Vec = std::vector; - using TUint32VecVec = std::vector; - - struct SRequest { - std::string s_RequestId; - TUint32Vec s_Tokens; - TUint32VecVec s_SecondaryArguments; - }; + using TUint32Vec = std::vector; + using TUint32VecVec = std::vector; - using TRequestHandlerFunc = std::function; + struct SRequest { + std::string s_RequestId; + TUint32Vec s_Tokens; + TUint32VecVec s_SecondaryArguments; + }; + using TRequestHandlerFunc = std::function; public: CCommandParser(std::istream& strmIn); @@ -64,13 +60,12 @@ class CCommandParser { CCommandParser& operator=(const CCommandParser&) = delete; private: - bool validateJson(const rapidjson::Document& doc) const; - SRequest jsonToRequest(const rapidjson::Document& doc) const; + bool validateJson(const rapidjson::Document& doc) const; + SRequest jsonToRequest(const rapidjson::Document& doc) const; + private: - //! std::istream& m_StrmIn; }; - } } diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc index 10c90825a6..2467f35f6b 100644 --- a/bin/pytorch_inference/Main.cc +++ b/bin/pytorch_inference/Main.cc @@ -4,17 +4,13 @@ * you may not use this file except in compliance with the Elastic License. */ +#include #include #include #include #include -#include - -#include - -#include - #include +#include #include "CBufferedIStreamAdapter.h" #include "CCmdLineParser.h" @@ -23,28 +19,19 @@ #include #include -#include - #include #include -// For ntohl -#ifdef Windows -#include -#else -#include -#endif - static const std::string INFERENCE{"inference"}; static const std::string ERROR{"error"}; -torch::Tensor infer(torch::jit::script::Module& module, - ml::torch::CCommandParser::SRequest& request) { +torch::Tensor infer(torch::jit::script::Module& module, + ml::torch::CCommandParser::SRequest& request) { torch::Tensor tokensTensor = - torch::from_blob(static_cast(request.s_Tokens.data()), - {1, static_cast(request.s_Tokens.size())}, - at::dtype(torch::kInt32)) + torch::from_blob(static_cast(request.s_Tokens.data()), + {1, static_cast(request.s_Tokens.size())}, + at::dtype(torch::kInt32)) .to(torch::kInt64); std::vector inputs; @@ -52,9 +39,9 @@ torch::Tensor infer(torch::jit::script::Module& module, for (auto args : request.s_SecondaryArguments) { torch::Tensor tensor = - torch::from_blob(static_cast(args.data()), - {1, static_cast(args.size())}, - at::dtype(torch::kInt32)) + torch::from_blob(static_cast(args.data()), + {1, static_cast(args.size())}, + at::dtype(torch::kInt32)) .to(torch::kInt64); inputs.push_back(tensor); @@ -62,11 +49,27 @@ torch::Tensor infer(torch::jit::script::Module& module, torch::NoGradGuard noGrad; auto tuple = module.forward(inputs).toTuple(); - auto predictions = tuple->elements()[0].toTensor(); - return torch::argmax(predictions, 2); + return tuple->elements()[0].toTensor(); } -void writePrediction(const torch::Tensor& prediction, const std::string& requestId, std::ostream& outputStream) { +void writePrediction(const torch::Tensor& prediction, + const std::string& requestId, + std::ostream& outputStream) { + + torch::Tensor view; + auto sizes = prediction.sizes(); + if (sizes.size() == 3 && sizes[0] == 1) { + view = prediction[0]; + } else { + view = prediction; + } + + // creating the accessor will throw if view does not + // have exactly 2 dimensions. Do this before writing + // any output so the error message isn't mingled with + // a partial result + auto accessor = view.accessor(); + rapidjson::OStreamWrapper writeStream(outputStream); ml::core::CRapidJsonLineWriter jsonWriter(writeStream); jsonWriter.StartObject(); @@ -74,35 +77,38 @@ void writePrediction(const torch::Tensor& prediction, const std::string& request jsonWriter.String(requestId); jsonWriter.Key(INFERENCE); jsonWriter.StartArray(); - auto arr = prediction.accessor(); - for (int i = 0; i < arr.size(1); i++) { - jsonWriter.Int64(arr[0][i]); + + for (int i = 0; i < accessor.size(0); ++i) { + jsonWriter.StartArray(); + for (int j = 0; j < accessor.size(1); ++j) { + jsonWriter.Double(static_cast(accessor[i][j])); + } + jsonWriter.EndArray(); } + jsonWriter.EndArray(); jsonWriter.EndObject(); } -void writeError(const std::string& requestId, const std::string& message, - std::ostream& outputStream) { +void writeError(const std::string& requestId, const std::string& message, std::ostream& outputStream) { rapidjson::OStreamWrapper writeStream(outputStream); ml::core::CRapidJsonLineWriter jsonWriter(writeStream); jsonWriter.StartObject(); jsonWriter.Key(ml::torch::CCommandParser::REQUEST_ID); jsonWriter.String(requestId); - jsonWriter.Key(ERROR); + jsonWriter.Key(ERROR); jsonWriter.String(message); jsonWriter.EndObject(); } - bool handleRequest(ml::torch::CCommandParser::SRequest& request, - torch::jit::script::Module& module, - std::ostream& outputStream) { + torch::jit::script::Module& module, + std::ostream& outputStream) { - try { + try { torch::Tensor results = infer(module, request); writePrediction(results, request.s_RequestId, outputStream); - } catch (std::runtime_error& e) { + } catch (std::runtime_error& e) { writeError(request.s_RequestId, e.what(), outputStream); } @@ -193,14 +199,12 @@ int main(int argc, char** argv) { return EXIT_FAILURE; } - ml::torch::CCommandParser commandParser{ioMgr.inputStream()}; - commandParser.ioLoop([&module, &ioMgr](ml::torch::CCommandParser::SRequest& request){ + commandParser.ioLoop([&module, &ioMgr](ml::torch::CCommandParser::SRequest& request) { return handleRequest(request, module, ioMgr.outputStream()); }); - LOG_DEBUG(<< "ML Torch model prototype exiting"); return EXIT_SUCCESS; diff --git a/bin/pytorch_inference/evaluate.py b/bin/pytorch_inference/evaluate.py index ceb295da07..6feda24aa8 100644 --- a/bin/pytorch_inference/evaluate.py +++ b/bin/pytorch_inference/evaluate.py @@ -3,9 +3,9 @@ app, together with the encoded tokens from the input_tokens file. Then it checks the model's response matches the expected. -This script reads the input files and expected outputs, then -launches the C++ pytorch_inference program which handles and -sends the request. The response is checked against the expected +This script reads the input files and expected outputs, then +launches the C++ pytorch_inference program which handles and +sends the request. The response is checked against the expected defined in the test file The test file must have the format: @@ -23,11 +23,12 @@ Run this script with input from one of the example directories, for example: -python3 evaluate.py /path/to/conll03_traced_ner.pt examples/ner/test_run.json +python3 evaluate.py /path/to/conll03_traced_ner.pt examples/ner/test_run.json ''' import argparse import json +import math import os import platform import stat @@ -78,14 +79,41 @@ def stream_file(source, destination) : destination.write(piece) -def write_request(request, destination): +def write_request(request, destination): json.dump(request, destination) + +def compare_results(expected, actual): + if expected['request_id'] != actual['request_id']: + print("request_ids do not match [{}], [{}]".format(expected['request_id'], actual['request_id']), flush=True) + return False + + if len(expected['inference']) != len(actual['inference']): + print("len(inference) does not match [{}], [{}]".format(len(expected['inference']), len(actual['inference'])), flush=True) + return False + + for i in range(len(expected['inference'])): + expected_row = expected['inference'][i] + actual_row = actual['inference'][i] + + if len(expected_row) != len(actual_row): + print("row [{}] lengths are not equal [{}], [{}]".format(i, len(expected_row), len(actual_row)), flush=True) + return False + + are_close = True + for j in range(len(expected_row)): + are_close = are_close and math.isclose(expected_row[j], actual_row[j], rel_tol=1e-04) + + if are_close == False: + print("row [{}] values are not close {}, {}".format(i, expected_row, actual_row), flush=True) + return False + + def main(): args = parse_arguments() - try: + try: # create the restore file with open(args.restore_file, 'wb') as restore_file: file_stats = os.stat(args.model) @@ -106,37 +134,46 @@ def main(): print("writing query", flush=True) for doc in test_evaluation: write_request(doc['input'], input_file) - + launch_pytorch_app(args) - print("reading results", flush=True) + print() + print("reading results...", flush=True) with open(args.output_file) as output_file: - + doc_count = 0 - results_match = True + results_match = True # output is NDJSON for jsonline in output_file: - result = json.loads(jsonline) + try: + result = json.loads(jsonline) + except: + print("Error parsing json: ", jsonline) + return + expected = test_evaluation[doc_count]['expected_output'] # compare to expected - if result != expected: + if compare_results(expected, result) == False: + print() print('ERROR: inference result [{}] does not match expected results'.format(doc_count)) - print(result, expected) + print() results_match = False doc_count = doc_count +1 if results_match: - print('SUCCESS: inference results match expected') + print() + print('SUCCESS: inference results match expected', flush=True) + print() - finally: + finally: if os.path.isfile(args.restore_file): os.remove(args.restore_file) if os.path.isfile(args.input_file): os.remove(args.input_file) if os.path.isfile(args.output_file): - os.remove(args.output_file) + os.remove(args.output_file) diff --git a/bin/pytorch_inference/examples/ner/test_run.json b/bin/pytorch_inference/examples/ner/test_run.json index 1339669b0b..07cb92ad86 100644 --- a/bin/pytorch_inference/examples/ner/test_run.json +++ b/bin/pytorch_inference/examples/ner/test_run.json @@ -8,7 +8,7 @@ "arg_2": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "arg_3": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32] }, - "expected_output": {"request_id": "one", "inference": [0, 6, 6, 6, 6, 0, 0, 0, 0, 0, 0, 8, 8, 8, 0, 0, 0, 0, 0, 8, 8, 8, 0, 0, 0, 0, 0, 0, 0, 8, 8, 0, 0]} + "expected_output": {"request_id": "one", "inference": [[9.445713996887207,-2.5186009407043459,-1.6298730373382569,-2.0444109439849855,-2.2338523864746095,-1.7153222560882569,-0.43147391080856326,-2.0465660095214845,1.1865826845169068],[0.792238175868988,-2.8607232570648195,-0.7416678071022034,-3.2710094451904299,-0.8123045563697815,-1.7240391969680787,9.034605026245118,-2.5430126190185549,-0.4717179536819458],[1.9150646924972535,-2.288753032684326,0.10271981358528137,-3.2836503982543947,-0.05480371415615082,-1.2911514043807984,6.839619159698486,-2.345627546310425,-0.7598909139633179],[0.6548398733139038,-2.8548803329467775,-0.22843939065933228,-3.415384531021118,0.2715289294719696,-1.3837164640426636,7.829202175140381,-2.879647970199585,-0.38423842191696169],[1.1488219499588013,-2.967827796936035,-0.8474732041358948,-3.521238327026367,-1.603212594985962,-1.894327998161316,9.006779670715332,-2.138070821762085,-0.03918422386050224],[6.845504283905029,-2.78581166267395,-1.3969510793685914,-3.771115779876709,-0.6792833209037781,-1.7580525875091553,4.772181510925293,-2.548997640609741,-0.09716915339231491],[10.898977279663086,-2.2232580184936525,-0.8633253574371338,-2.6009225845336916,-1.4352948665618897,-1.5839565992355347,0.8288266062736511,-2.17624831199646,-0.6236083507537842],[11.136425018310547,-2.27126145362854,-0.8518160581588745,-2.6662473678588869,-1.4811440706253052,-1.636898398399353,0.7194949388504028,-2.180851697921753,-0.6165107488632202],[10.732818603515625,-2.5183522701263429,-0.7929745316505432,-2.885248899459839,-1.4805352687835694,-1.9814093112945557,1.8594862222671509,-2.201752185821533,-0.5392075777053833],[10.873189926147461,-2.306696891784668,-0.80250483751297,-2.8811705112457277,-1.4863512516021729,-1.8671237230300904,0.9172942638397217,-2.1708977222442629,-0.20861388742923737],[10.7890625,-2.5134470462799074,-0.7714695334434509,-2.6941256523132326,-1.6425325870513917,-1.7182999849319459,0.4641777276992798,-2.154723882675171,0.11145751923322678],[0.2922537624835968,-2.2648208141326906,-1.1112738847732545,-2.363692283630371,-1.1536873579025269,-2.16975474357605,0.1696488857269287,-1.7420969009399415,8.619940757751465],[-0.05837444216012955,-2.2322306632995607,-1.2456691265106202,-2.393845796585083,-1.0231869220733643,-2.1384024620056154,0.36509764194488528,-1.6257678270339966,8.386515617370606],[-0.40553534030914309,-2.0154948234558107,-0.9165335297584534,-2.3343887329101564,-1.008164644241333,-2.404677152633667,-0.25008249282836916,-0.8966042399406433,8.402207374572754],[9.445691108703614,-2.5186126232147219,-1.6298660039901734,-2.044417381286621,-2.2338428497314455,-1.7153276205062867,-0.43144431710243227,-2.0465352535247804,1.1865744590759278],[10.766592025756836,-2.483231782913208,-0.7467246651649475,-2.816481590270996,-1.5419433116912842,-1.7198463678359986,0.8594813942909241,-2.4609286785125734,-0.38469821214675906],[11.087119102478028,-2.369004726409912,-1.0071995258331299,-2.8232693672180177,-1.5680956840515137,-1.817236304283142,0.4988541603088379,-2.2202258110046388,-0.38811907172203066],[11.105840682983399,-2.1901695728302,-0.8257009387016296,-2.5521652698516847,-1.4372351169586182,-1.8002771139144898,0.3943113684654236,-2.086343288421631,-0.3007338047027588],[10.404936790466309,-2.4211761951446535,-0.6679232120513916,-2.7030136585235597,-1.702601671218872,-1.7996376752853394,0.34499865770339968,-1.9718276262283326,0.5988494753837586],[0.04125037044286728,-2.1516361236572267,-0.6540370583534241,-2.9895687103271486,-0.02026049792766571,-2.607717990875244,0.7579661011695862,-2.63039493560791,5.886351585388184],[0.41723012924194338,-1.6220901012420655,-0.6683351397514343,-3.1531808376312258,-0.7741800546646118,-2.0307929515838625,1.8664846420288087,-2.6268656253814699,4.985666275024414],[1.1279197931289673,-2.7916486263275148,-1.4165221452713013,-3.2945871353149416,-0.24917204678058625,-2.006779432296753,2.0364503860473635,-1.8233060836791993,4.838916301727295],[10.858016967773438,-2.210500478744507,-0.7385379672050476,-2.6571877002716066,-1.7322543859481812,-1.5595816373825074,0.38503119349479678,-2.160937547683716,-0.2505747079849243],[11.213296890258789,-2.2317054271698,-0.7556230425834656,-2.6072943210601808,-1.5168147087097169,-1.735519289970398,-0.2140665054321289,-1.9741536378860474,-0.3461124897003174],[11.128496170043946,-2.286959409713745,-0.6008748412132263,-2.680781841278076,-1.637223482131958,-1.8050802946090699,0.04603640362620354,-2.0937740802764894,-0.39641517400741579],[5.765664100646973,-2.3779571056365969,0.5095824599266052,-3.020923137664795,-0.835684597492218,-2.7159736156463625,1.3586626052856446,-2.0834007263183595,1.0498584508895875],[10.664918899536133,-2.2987592220306398,-0.45505377650260928,-2.8677072525024416,-1.7201433181762696,-1.8049241304397584,0.18862102925777436,-2.447751760482788,-0.1371004730463028],[10.836210250854493,-1.9615545272827149,-0.39490482211112978,-2.8520052433013918,-1.5771089792251588,-1.796066403388977,-0.2166888266801834,-2.1943840980529787,-0.14454032480716706],[11.126060485839844,-2.092292308807373,-0.7535718679428101,-2.7322869300842287,-1.737945318222046,-1.915206789970398,-0.05707954242825508,-2.2171850204467775,0.04100846126675606],[-0.9580754637718201,-1.4011685848236085,0.3523239195346832,-2.691606044769287,-0.3759573698043823,-1.531064510345459,0.32374972105026247,-2.8274593353271486,5.610057830810547],[-0.3176293671131134,-1.7536091804504395,0.2829216420650482,-2.8517212867736818,-0.298392117023468,-2.5902841091156008,0.05927262827754021,-2.667252779006958,6.318171977996826],[9.445697784423829,-2.5185813903808595,-1.629878282546997,-2.0444087982177736,-2.233858108520508,-1.7153252363204957,-0.4314804971218109,-2.0465519428253176,1.1865817308425904],[9.445714950561524,-2.5186026096343996,-1.6298726797103882,-2.0444114208221437,-2.2338526248931886,-1.7153226137161255,-0.431472510099411,-2.0465657711029054,1.1865825653076172]]} }, { "source_text": "Jim bought 300 shares of Acme Corp. in 2006", @@ -19,6 +19,6 @@ "arg_2": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "arg_3": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] }, - "expected_output": {"request_id": "two", "inference": [0, 4, 0, 0, 0, 0, 6, 6, 6, 6, 6, 0, 0, 0]} + "expected_output": {"request_id": "two", "inference": [[9.265243530273438,-2.515533685684204,-1.3969738483428956,-2.351750135421753,-1.803475260734558,-1.8310259580612183,-0.09857147932052612,-1.8631094694137574,0.39207878708839419],[1.2384196519851685,-1.9405957460403443,-1.775153398513794,-3.3273582458496095,6.286568641662598,-2.179285764694214,2.0314316749572756,-2.7113730907440187,-0.45493167638778689],[11.42744255065918,-2.162276268005371,-0.843875527381897,-2.559823513031006,-1.0762289762496949,-1.8819851875305176,-0.05885419622063637,-2.0086324214935304,-1.1396600008010865],[11.531952857971192,-2.2821600437164308,-0.5459890961647034,-2.5652623176574709,-1.4509466886520386,-2.0525619983673097,-0.25311681628227236,-2.0138297080993654,-1.1384060382843018],[11.342474937438965,-2.217491388320923,-0.8007739186286926,-2.817258834838867,-1.4931764602661133,-2.072322368621826,0.4753860533237457,-2.071307897567749,-1.1917411088943482],[11.184587478637696,-2.2941224575042726,-1.1137279272079468,-2.842801570892334,-1.3545138835906983,-2.021228075027466,0.747635006904602,-2.145432472229004,-0.8665369749069214],[0.21825885772705079,-2.547029733657837,-0.33246660232543948,-2.996917724609375,-1.280576229095459,-1.7315617799758912,9.094033241271973,-2.1498544216156008,-0.9704827666282654],[1.5527149438858033,-2.109332323074341,0.9170875549316406,-2.693459987640381,0.07074408233165741,-1.8306814432144166,4.762491226196289,-1.9375808238983155,-0.8089945316314697],[0.49461689591407778,-2.785339593887329,-0.35393503308296206,-3.1697468757629396,-1.1621716022491456,-1.5924493074417115,8.098134994506836,-2.1945927143096926,-0.2739206850528717],[0.8808980584144592,-2.962925910949707,-0.7409814596176148,-3.4761435985565187,-1.8204796314239503,-1.6338881254196168,8.950613975524903,-1.90260910987854,0.25867289304733279],[3.3241844177246095,-2.2221860885620119,0.010276625864207745,-3.562873125076294,-0.6271736025810242,-1.7994204759597779,3.548079252243042,-2.3455047607421877,0.8453909754753113],[11.597189903259278,-2.2594125270843508,-0.805946409702301,-2.462068796157837,-1.2340333461761475,-1.948832392692566,-0.4201577305793762,-1.6598224639892579,-1.0664936304092408],[11.329129219055176,-1.987001657485962,-0.5631923079490662,-2.7865068912506105,-1.0649499893188477,-2.278918981552124,-0.09693071246147156,-2.0265250205993654,-0.8837844729423523],[9.265239715576172,-2.5155346393585207,-1.3969736099243165,-2.351749897003174,-1.8034745454788209,-1.8310260772705079,-0.09857088327407837,-1.863108515739441,0.39208000898361208]]} } ] \ No newline at end of file diff --git a/bin/pytorch_inference/examples/sentiment_analysis/test_run.json b/bin/pytorch_inference/examples/sentiment_analysis/test_run.json new file mode 100644 index 0000000000..76aaa0ac51 --- /dev/null +++ b/bin/pytorch_inference/examples/sentiment_analysis/test_run.json @@ -0,0 +1,20 @@ +[ + { + "source_text": "The cat was sick on the bed", + "input": { + "request_id": "one", + "tokens": [101, 1996, 4937, 2001, 5305, 2006, 1996, 2793, 102], + "arg_1": [1, 1, 1, 1, 1, 1, 1, 1, 1] + }, + "expected_output": {"request_id": "one", "inference": [[3.9489, -3.2416]]} + }, + { + "source_text": "The movie was awesome!!", + "input": { + "request_id": "two", + "tokens": [101, 1996, 3185, 2001, 12476, 999, 999, 102], + "arg_1": [1, 1, 1, 1, 1, 1, 1, 1] + }, + "expected_output": {"request_id": "two", "inference": [[-4.2720, 4.6515]]} + } +] \ No newline at end of file diff --git a/bin/pytorch_inference/unittest/CCommandParserTest.cc b/bin/pytorch_inference/unittest/CCommandParserTest.cc index bc6b9ca7db..4e83fc19a6 100644 --- a/bin/pytorch_inference/unittest/CCommandParserTest.cc +++ b/bin/pytorch_inference/unittest/CCommandParserTest.cc @@ -15,28 +15,29 @@ BOOST_AUTO_TEST_CASE(testParsingStream) { std::vector parsed; std::string command{"{\"request_id\": \"foo\", \"tokens\": [1, 2, 3]}" - "{\"request_id\": \"bar\", \"tokens\": [4, 5]}"}; + "{\"request_id\": \"bar\", \"tokens\": [4, 5]}"}; std::istringstream commandStream{command}; - ml::torch::CCommandParser processor{commandStream}; - BOOST_TEST_REQUIRE( - processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request){ - parsed.push_back(request); - return true; - } )); + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE(processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) { + parsed.push_back(request); + return true; + })); - BOOST_REQUIRE_EQUAL(2, parsed.size()); + BOOST_REQUIRE_EQUAL(2, parsed.size()); { BOOST_REQUIRE_EQUAL("foo", parsed[0].s_RequestId); ml::torch::CCommandParser::TUint32Vec expected{1, 2, 3}; - BOOST_REQUIRE_EQUAL_COLLECTIONS(parsed[0].s_Tokens.begin(), parsed[0].s_Tokens.end(), - expected.begin(), expected.end()); + BOOST_REQUIRE_EQUAL_COLLECTIONS(parsed[0].s_Tokens.begin(), + parsed[0].s_Tokens.end(), + expected.begin(), expected.end()); } { - BOOST_REQUIRE_EQUAL("bar", parsed[1].s_RequestId); + BOOST_REQUIRE_EQUAL("bar", parsed[1].s_RequestId); ml::torch::CCommandParser::TUint32Vec expected{4, 5}; - BOOST_REQUIRE_EQUAL_COLLECTIONS(parsed[1].s_Tokens.begin(), parsed[1].s_Tokens.end(), - expected.begin(), expected.end()); + BOOST_REQUIRE_EQUAL_COLLECTIONS(parsed[1].s_Tokens.begin(), + parsed[1].s_Tokens.end(), + expected.begin(), expected.end()); } } @@ -45,14 +46,13 @@ BOOST_AUTO_TEST_CASE(testParsingInvalidDoc) { std::vector parsed; std::string command{"{\"foo\": 1, }"}; - + std::istringstream commandStream{command}; - ml::torch::CCommandParser processor{commandStream}; - BOOST_TEST_REQUIRE( - processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request){ - parsed.push_back(request); - return true; + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE(processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) { + parsed.push_back(request); + return true; }) == false); BOOST_REQUIRE_EQUAL(0, parsed.size()); @@ -63,65 +63,64 @@ BOOST_AUTO_TEST_CASE(testParsingWhitespaceSeparatedDocs) { std::vector parsed; std::string command{"{\"request_id\": \"foo\", \"tokens\": [1, 2, 3]}\t" - "{\"request_id\": \"bar\", \"tokens\": [1, 2, 3]}\n" - "{\"request_id\": \"foo2\", \"tokens\": [1, 2, 3]} " - "{\"request_id\": \"bar2\", \"tokens\": [1, 2, 3]}"}; + "{\"request_id\": \"bar\", \"tokens\": [1, 2, 3]}\n" + "{\"request_id\": \"foo2\", \"tokens\": [1, 2, 3]} " + "{\"request_id\": \"bar2\", \"tokens\": [1, 2, 3]}"}; std::istringstream commandStream{command}; - ml::torch::CCommandParser processor{commandStream}; - BOOST_TEST_REQUIRE( - processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request){ - parsed.push_back(request); - return true; - } )); + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE(processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) { + parsed.push_back(request); + return true; + })); BOOST_REQUIRE_EQUAL(4, parsed.size()); BOOST_REQUIRE_EQUAL("foo", parsed[0].s_RequestId); BOOST_REQUIRE_EQUAL("bar", parsed[1].s_RequestId); BOOST_REQUIRE_EQUAL("foo2", parsed[2].s_RequestId); - BOOST_REQUIRE_EQUAL("bar2", parsed[3].s_RequestId); + BOOST_REQUIRE_EQUAL("bar2", parsed[3].s_RequestId); } BOOST_AUTO_TEST_CASE(testParsingVariableArguments) { std::vector parsed; - std::string command{"{\"request_id\": \"foo\", \"tokens\": [1, 2], \"arg_1\": [0, 0], \"arg_2\": [0, 1]}" - "{\"request_id\": \"bar\", \"tokens\": [3, 4], \"arg_1\": [1, 0], \"arg_2\": [1, 1]}"}; + std::string command{ + "{\"request_id\": \"foo\", \"tokens\": [1, 2], \"arg_1\": [0, 0], \"arg_2\": [0, 1]}" + "{\"request_id\": \"bar\", \"tokens\": [3, 4], \"arg_1\": [1, 0], \"arg_2\": [1, 1]}"}; std::istringstream commandStream{command}; - ml::torch::CCommandParser processor{commandStream}; - BOOST_TEST_REQUIRE( - processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request){ - parsed.push_back(request); - return true; - } )); + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE(processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) { + parsed.push_back(request); + return true; + })); BOOST_REQUIRE_EQUAL(2, parsed.size()); - { + { ml::torch::CCommandParser::TUint32Vec expectedArg1{0, 0}; ml::torch::CCommandParser::TUint32Vec expectedArg2{0, 1}; ml::torch::CCommandParser::TUint32VecVec extraArgs = parsed[0].s_SecondaryArguments; BOOST_REQUIRE_EQUAL(2, extraArgs.size()); - BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[0].begin(), extraArgs[0].end(), - expectedArg1.begin(), expectedArg1.end()); - BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[1].begin(), extraArgs[1].end(), - expectedArg2.begin(), expectedArg2.end()); - } - { + BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[0].begin(), extraArgs[0].end(), + expectedArg1.begin(), expectedArg1.end()); + BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[1].begin(), extraArgs[1].end(), + expectedArg2.begin(), expectedArg2.end()); + } + { ml::torch::CCommandParser::TUint32Vec expectedArg1{1, 0}; ml::torch::CCommandParser::TUint32Vec expectedArg2{1, 1}; ml::torch::CCommandParser::TUint32VecVec extraArgs = parsed[1].s_SecondaryArguments; BOOST_REQUIRE_EQUAL(2, extraArgs.size()); - BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[0].begin(), extraArgs[0].end(), - expectedArg1.begin(), expectedArg1.end()); - BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[1].begin(), extraArgs[1].end(), - expectedArg2.begin(), expectedArg2.end()); - } + BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[0].begin(), extraArgs[0].end(), + expectedArg1.begin(), expectedArg1.end()); + BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[1].begin(), extraArgs[1].end(), + expectedArg2.begin(), expectedArg2.end()); + } } BOOST_AUTO_TEST_CASE(testParsingInvalidVarArg) { @@ -131,14 +130,33 @@ BOOST_AUTO_TEST_CASE(testParsingInvalidVarArg) { std::string command{"{\"request_id\": \"foo\", \"tokens\": [1, 2], \"arg_1\": \"not_an_array\"}"}; std::istringstream commandStream{command}; - ml::torch::CCommandParser processor{commandStream}; - BOOST_TEST_REQUIRE( - processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request){ - parsed.push_back(request); - return true; - } )); + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE(processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) { + parsed.push_back(request); + return true; + })); BOOST_REQUIRE_EQUAL(0, parsed.size()); } +BOOST_AUTO_TEST_CASE(testRequestHandlerExitsLoop) { + + std::vector parsed; + + std::string command{"{\"request_id\": \"foo\", \"tokens\": [1, 2, 3]}" + "{\"request_id\": \"bar\", \"tokens\": [4, 5]}"}; + std::istringstream commandStream{command}; + + ml::torch::CCommandParser processor{commandStream}; + // handler returns false + BOOST_TEST_REQUIRE( + false == processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) { + parsed.push_back(request); + return false; + })); + + // ioloop should exit after the first call to the handler + BOOST_REQUIRE_EQUAL(1, parsed.size()); +} + BOOST_AUTO_TEST_SUITE_END() From afbb3bab4a74fb63269a8ab0060c69b9f004e460 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 23 Feb 2021 17:04:58 +0000 Subject: [PATCH 06/19] tidy up --- bin/pytorch_inference/CBufferedIStreamAdapter.cc | 3 --- bin/pytorch_inference/CBufferedIStreamAdapter.h | 1 - bin/pytorch_inference/CCommandParser.cc | 7 ++++--- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/bin/pytorch_inference/CBufferedIStreamAdapter.cc b/bin/pytorch_inference/CBufferedIStreamAdapter.cc index d225351a5a..be97d04d93 100644 --- a/bin/pytorch_inference/CBufferedIStreamAdapter.cc +++ b/bin/pytorch_inference/CBufferedIStreamAdapter.cc @@ -22,9 +22,6 @@ CBufferedIStreamAdapter::CBufferedIStreamAdapter(std::istream& inputStream) : m_InputStream(inputStream) { } -CBufferedIStreamAdapter::~CBufferedIStreamAdapter() { -} - bool CBufferedIStreamAdapter::init() { if (parseSizeFromStream(m_Size) == false) { LOG_ERROR(<< "Failed to read model size"); diff --git a/bin/pytorch_inference/CBufferedIStreamAdapter.h b/bin/pytorch_inference/CBufferedIStreamAdapter.h index efc47e35d7..23b1c40231 100644 --- a/bin/pytorch_inference/CBufferedIStreamAdapter.h +++ b/bin/pytorch_inference/CBufferedIStreamAdapter.h @@ -38,7 +38,6 @@ namespace torch { class CBufferedIStreamAdapter : public caffe2::serialize::ReadAdapterInterface { public: CBufferedIStreamAdapter(std::istream& inputStream); - ~CBufferedIStreamAdapter() override; //! True if the model is successfully read. //! Must be called before read or size diff --git a/bin/pytorch_inference/CCommandParser.cc b/bin/pytorch_inference/CCommandParser.cc index f080f27c15..32e8f4728c 100644 --- a/bin/pytorch_inference/CCommandParser.cc +++ b/bin/pytorch_inference/CCommandParser.cc @@ -63,8 +63,8 @@ bool CCommandParser::ioLoop(const TRequestHandlerFunc& requestHandler) const { debug(doc); CCommandParser::SRequest request = jsonToRequest(doc); if (requestHandler(request) == false) { - LOG_ERROR(<< "Request handler forced exit"); - return false; + LOG_ERROR(<< "Request handler forced exit"); + return false; } } @@ -94,7 +94,8 @@ bool CCommandParser::validateJson(const rapidjson::Document& doc) const { while (doc.HasMember(varArgName)) { const rapidjson::Value& value = doc[varArgName]; if (value.IsArray() == false) { - LOG_ERROR(<< "Invalid command: argument [" << varArgName << "] is not an array"); + LOG_ERROR(<< "Invalid command: argument [" << varArgName + << "] is not an array"); return false; } From 6adf03498e3903069e817dd5fe6082ab71499263 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 24 Feb 2021 09:58:12 +0000 Subject: [PATCH 07/19] clang format --- bin/pytorch_inference/CCommandParser.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bin/pytorch_inference/CCommandParser.cc b/bin/pytorch_inference/CCommandParser.cc index 32e8f4728c..f080f27c15 100644 --- a/bin/pytorch_inference/CCommandParser.cc +++ b/bin/pytorch_inference/CCommandParser.cc @@ -63,8 +63,8 @@ bool CCommandParser::ioLoop(const TRequestHandlerFunc& requestHandler) const { debug(doc); CCommandParser::SRequest request = jsonToRequest(doc); if (requestHandler(request) == false) { - LOG_ERROR(<< "Request handler forced exit"); - return false; + LOG_ERROR(<< "Request handler forced exit"); + return false; } } @@ -94,8 +94,7 @@ bool CCommandParser::validateJson(const rapidjson::Document& doc) const { while (doc.HasMember(varArgName)) { const rapidjson::Value& value = doc[varArgName]; if (value.IsArray() == false) { - LOG_ERROR(<< "Invalid command: argument [" << varArgName - << "] is not an array"); + LOG_ERROR(<< "Invalid command: argument [" << varArgName << "] is not an array"); return false; } From 441b2fcee1925b35b5c7bfecaac8ff6c32ff732d Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 24 Feb 2021 17:36:14 +0000 Subject: [PATCH 08/19] Address review comments --- bin/pytorch_inference/CCommandParser.cc | 45 ++++++++++++++----------- bin/pytorch_inference/CCommandParser.h | 11 +++--- bin/pytorch_inference/Main.cc | 14 ++++++-- bin/pytorch_inference/Makefile | 3 -- 4 files changed, 43 insertions(+), 30 deletions(-) diff --git a/bin/pytorch_inference/CCommandParser.cc b/bin/pytorch_inference/CCommandParser.cc index f080f27c15..69d186fe17 100644 --- a/bin/pytorch_inference/CCommandParser.cc +++ b/bin/pytorch_inference/CCommandParser.cc @@ -8,25 +8,26 @@ #include -#include - #include #include #include #include -namespace ml { -namespace torch { +#include -namespace { -void debug(const rapidjson::Document& doc) { +namespace rapidjson { + +std::ostream& operator<<(std::ostream& os, const rapidjson::Document& doc) { rapidjson::StringBuffer buffer; rapidjson::Writer writer(buffer); doc.Accept(writer); - LOG_TRACE(<< buffer.GetString()); + return os << buffer.GetString(); } } +namespace ml { +namespace torch { + const std::string CCommandParser::REQUEST_ID{"request_id"}; const std::string CCommandParser::TOKENS{"tokens"}; const std::string CCommandParser::VAR_ARG_PREFIX{"arg_"}; @@ -34,9 +35,9 @@ const std::string CCommandParser::VAR_ARG_PREFIX{"arg_"}; CCommandParser::CCommandParser(std::istream& strmIn) : m_StrmIn(strmIn) { } -bool CCommandParser::ioLoop(const TRequestHandlerFunc& requestHandler) const { +bool CCommandParser::ioLoop(const TRequestHandlerFunc& requestHandler) { - rapidjson::IStreamWrapper isw(m_StrmIn); + rapidjson::IStreamWrapper isw{m_StrmIn}; while (true) { rapidjson::Document doc; @@ -59,10 +60,9 @@ bool CCommandParser::ioLoop(const TRequestHandlerFunc& requestHandler) const { continue; } - // TODO if logger.trace_enabled then - debug(doc); - CCommandParser::SRequest request = jsonToRequest(doc); - if (requestHandler(request) == false) { + LOG_TRACE(<< "Inference command: " << doc); + jsonToRequest(doc); + if (requestHandler(m_Request) == false) { LOG_ERROR(<< "Request handler forced exit"); return false; } @@ -105,29 +105,34 @@ bool CCommandParser::validateJson(const rapidjson::Document& doc) const { return true; } -CCommandParser::SRequest CCommandParser::jsonToRequest(const rapidjson::Document& doc) const { - TUint32Vec tokens; +void CCommandParser::jsonToRequest(const rapidjson::Document& doc) { + + m_Request.s_RequestId = doc[REQUEST_ID].GetString(); const rapidjson::Value& arr = doc[TOKENS]; + // overwrite any previous + m_Request.s_Tokens.resize(arr.Size()); + for (auto itr = arr.Begin(); itr != arr.End(); ++itr) { - tokens.push_back(itr->GetUint()); + m_Request.s_Tokens.push_back(itr->GetUint()); } std::uint64_t varCount{1}; std::string varArgName = VAR_ARG_PREFIX + std::to_string(varCount); - TUint32VecVec args; + // wipe any previous + m_Request.s_SecondaryArguments.clear(); + TUint32Vec arg; while (doc.HasMember(varArgName)) { - TUint32Vec arg; const rapidjson::Value& v = doc[varArgName]; for (auto itr = v.Begin(); itr != v.End(); ++itr) { arg.push_back(itr->GetUint()); } - args.push_back(arg); + m_Request.s_SecondaryArguments.push_back(arg); + arg.clear(); ++varCount; varArgName = VAR_ARG_PREFIX + std::to_string(varCount); } - return {doc[REQUEST_ID].GetString(), tokens, args}; } } } diff --git a/bin/pytorch_inference/CCommandParser.h b/bin/pytorch_inference/CCommandParser.h index 0947315b70..66782e5589 100644 --- a/bin/pytorch_inference/CCommandParser.h +++ b/bin/pytorch_inference/CCommandParser.h @@ -7,13 +7,13 @@ #ifndef INCLUDED_ml_torch_CCommandParser_h #define INCLUDED_ml_torch_CCommandParser_h +#include + #include #include #include #include -#include - namespace ml { namespace torch { @@ -46,6 +46,8 @@ class CCommandParser { std::string s_RequestId; TUint32Vec s_Tokens; TUint32VecVec s_SecondaryArguments; + + void clear(); }; using TRequestHandlerFunc = std::function; @@ -54,17 +56,18 @@ class CCommandParser { CCommandParser(std::istream& strmIn); //! Pass input to the processor until it's consumed as much as it can. - bool ioLoop(const TRequestHandlerFunc& requestHandler) const; + bool ioLoop(const TRequestHandlerFunc& requestHandler); CCommandParser(const CCommandParser&) = delete; CCommandParser& operator=(const CCommandParser&) = delete; private: bool validateJson(const rapidjson::Document& doc) const; - SRequest jsonToRequest(const rapidjson::Document& doc) const; + void jsonToRequest(const rapidjson::Document& doc); private: std::istream& m_StrmIn; + SRequest m_Request; }; } } diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc index 2467f35f6b..628031c2ea 100644 --- a/bin/pytorch_inference/Main.cc +++ b/bin/pytorch_inference/Main.cc @@ -4,14 +4,17 @@ * you may not use this file except in compliance with the Elastic License. */ -#include #include #include #include #include + #include + #include +#include + #include "CBufferedIStreamAdapter.h" #include "CCmdLineParser.h" #include "CCommandParser.h" @@ -22,8 +25,10 @@ #include #include -static const std::string INFERENCE{"inference"}; -static const std::string ERROR{"error"}; +namespace { +const std::string INFERENCE{"inference"}; +const std::string ERROR{"error"}; +} torch::Tensor infer(torch::jit::script::Module& module, ml::torch::CCommandParser::SRequest& request) { @@ -35,6 +40,7 @@ torch::Tensor infer(torch::jit::script::Module& module, .to(torch::kInt64); std::vector inputs; + inputs.reserve(1 + request.s_SecondaryArguments.size()); inputs.push_back(tokensTensor); for (auto args : request.s_SecondaryArguments) { @@ -58,6 +64,8 @@ void writePrediction(const torch::Tensor& prediction, torch::Tensor view; auto sizes = prediction.sizes(); + // Some models return a 3D tensor in which case + // the first dimension must have size == 1 if (sizes.size() == 3 && sizes[0] == 1) { view = prediction[0]; } else { diff --git a/bin/pytorch_inference/Makefile b/bin/pytorch_inference/Makefile index 7e2fba77b7..f41e540960 100644 --- a/bin/pytorch_inference/Makefile +++ b/bin/pytorch_inference/Makefile @@ -13,7 +13,6 @@ ML_LIBS=$(LIB_ML_CORE) $(LIB_ML_API) USE_BOOST=1 USE_BOOST_PROGRAMOPTIONS_LIBS=1 -USE_NET=1 USE_RAPIDJSON=1 USE_TORCH=1 @@ -27,7 +26,5 @@ SRCS= \ CCmdLineParser.cc \ CCommandParser.cc \ -NO_TEST_CASES=1 - include $(CPP_SRC_HOME)/mk/stdapp.mk From 2e48245b7941e05281df470228e3c1bb6ed2f918 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 24 Feb 2021 22:37:08 +0000 Subject: [PATCH 09/19] Review comments round 2 --- bin/pytorch_inference/CCommandParser.cc | 4 ++-- bin/pytorch_inference/CCommandParser.h | 10 +++++++++- bin/pytorch_inference/Main.cc | 13 +++++-------- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/bin/pytorch_inference/CCommandParser.cc b/bin/pytorch_inference/CCommandParser.cc index 69d186fe17..fe5eed7ca7 100644 --- a/bin/pytorch_inference/CCommandParser.cc +++ b/bin/pytorch_inference/CCommandParser.cc @@ -109,8 +109,8 @@ void CCommandParser::jsonToRequest(const rapidjson::Document& doc) { m_Request.s_RequestId = doc[REQUEST_ID].GetString(); const rapidjson::Value& arr = doc[TOKENS]; - // overwrite any previous - m_Request.s_Tokens.resize(arr.Size()); + // wipe any previous + m_Request.s_Tokens.clear(); for (auto itr = arr.Begin(); itr != arr.End(); ++itr) { m_Request.s_Tokens.push_back(itr->GetUint()); diff --git a/bin/pytorch_inference/CCommandParser.h b/bin/pytorch_inference/CCommandParser.h index 66782e5589..3c12c1dbb6 100644 --- a/bin/pytorch_inference/CCommandParser.h +++ b/bin/pytorch_inference/CCommandParser.h @@ -22,13 +22,21 @@ namespace torch { //! for each parsed document. //! //! DESCRIPTION:\n -//! +//! Validation on the input documents is light. It is expected the input +//! comes from another process which tightly controls what is sent. +//! Input from an outside source that has not been sanitized should never +//! be sent. //! //! IMPLEMENTATION DECISIONS:\n //! RapidJSON will natively parse a stream of rootless JSON documents //! given the correct parse flags. The documents may be separated by //! whitespace but no other delineator is allowed. //! +//! The parsed request is a member of this class and will be modified when +//! a new command is parsed. The function handler passed to ioLoop must +//! not keep a reference to the request object beyond the scope of the +//! handle function as the request will change. +//! //! The input stream is held by reference. They must outlive objects of //! this class, which, in practice, means that the CIoManager object managing //! them must outlive this object. diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc index 628031c2ea..d2d4ef6694 100644 --- a/bin/pytorch_inference/Main.cc +++ b/bin/pytorch_inference/Main.cc @@ -43,14 +43,11 @@ torch::Tensor infer(torch::jit::script::Module& module, inputs.reserve(1 + request.s_SecondaryArguments.size()); inputs.push_back(tokensTensor); - for (auto args : request.s_SecondaryArguments) { - torch::Tensor tensor = - torch::from_blob(static_cast(args.data()), - {1, static_cast(args.size())}, - at::dtype(torch::kInt32)) - .to(torch::kInt64); - - inputs.push_back(tensor); + for (auto& args : request.s_SecondaryArguments) { + inputs.emplace_back(torch::from_blob(static_cast(args.data()), + {1, static_cast(args.size())}, + at::dtype(torch::kInt32)) + .to(torch::kInt64)); } torch::NoGradGuard noGrad; From 4e77ca38036fe38e6a2577a4ee99479b9898badb Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 24 Feb 2021 22:37:54 +0000 Subject: [PATCH 10/19] Rebuild unit tests when the object files change --- bin/pytorch_inference/unittest/Makefile | 2 +- mk/stdboosttest.mk | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bin/pytorch_inference/unittest/Makefile b/bin/pytorch_inference/unittest/Makefile index 4916872652..deefef9bb5 100644 --- a/bin/pytorch_inference/unittest/Makefile +++ b/bin/pytorch_inference/unittest/Makefile @@ -15,7 +15,7 @@ USE_TORCH=1 all: build -LIBS=\ +LIB_OBJS=\ ../.objs/C*$(OBJECT_FILE_EXT) \ SRCS=\ diff --git a/mk/stdboosttest.mk b/mk/stdboosttest.mk index 83bb497e93..c1e2239c5e 100644 --- a/mk/stdboosttest.mk +++ b/mk/stdboosttest.mk @@ -18,8 +18,8 @@ CPPFLAGS+=$(INCLUDE_PATH) LDFLAGS:=$(UTLDFLAGS) $(LDFLAGS) $(LIB_PATH) $(ML_VER_LDFLAGS) PICFLAGS=$(PLATPIEFLAGS) -$(TARGET): $(OBJS) $(RESOURCE_FILE) - $(CXX) $(LINK_OUT_FLAG)$@ $(PDB_FLAGS) $(OBJS) $(RESOURCE_FILE) $(LDFLAGS) $(LIBS) +$(TARGET): $(OBJS) $(LIB_OBJS) $(RESOURCE_FILE) + $(CXX) $(LINK_OUT_FLAG)$@ $(PDB_FLAGS) $(OBJS) $(RESOURCE_FILE) $(LDFLAGS) $(LIBS) $(LIB_OBJS) build: $(TARGET) From 83e617242f717630c27b1ed52767c10aa1cbdc86 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 25 Feb 2021 08:55:46 +0000 Subject: [PATCH 11/19] Revert "Review comments round 2" This reverts commit 2e48245b7941e05281df470228e3c1bb6ed2f918. --- bin/pytorch_inference/CCommandParser.cc | 4 ++-- bin/pytorch_inference/CCommandParser.h | 10 +--------- bin/pytorch_inference/Main.cc | 13 ++++++++----- 3 files changed, 11 insertions(+), 16 deletions(-) diff --git a/bin/pytorch_inference/CCommandParser.cc b/bin/pytorch_inference/CCommandParser.cc index fe5eed7ca7..69d186fe17 100644 --- a/bin/pytorch_inference/CCommandParser.cc +++ b/bin/pytorch_inference/CCommandParser.cc @@ -109,8 +109,8 @@ void CCommandParser::jsonToRequest(const rapidjson::Document& doc) { m_Request.s_RequestId = doc[REQUEST_ID].GetString(); const rapidjson::Value& arr = doc[TOKENS]; - // wipe any previous - m_Request.s_Tokens.clear(); + // overwrite any previous + m_Request.s_Tokens.resize(arr.Size()); for (auto itr = arr.Begin(); itr != arr.End(); ++itr) { m_Request.s_Tokens.push_back(itr->GetUint()); diff --git a/bin/pytorch_inference/CCommandParser.h b/bin/pytorch_inference/CCommandParser.h index 3c12c1dbb6..66782e5589 100644 --- a/bin/pytorch_inference/CCommandParser.h +++ b/bin/pytorch_inference/CCommandParser.h @@ -22,21 +22,13 @@ namespace torch { //! for each parsed document. //! //! DESCRIPTION:\n -//! Validation on the input documents is light. It is expected the input -//! comes from another process which tightly controls what is sent. -//! Input from an outside source that has not been sanitized should never -//! be sent. +//! //! //! IMPLEMENTATION DECISIONS:\n //! RapidJSON will natively parse a stream of rootless JSON documents //! given the correct parse flags. The documents may be separated by //! whitespace but no other delineator is allowed. //! -//! The parsed request is a member of this class and will be modified when -//! a new command is parsed. The function handler passed to ioLoop must -//! not keep a reference to the request object beyond the scope of the -//! handle function as the request will change. -//! //! The input stream is held by reference. They must outlive objects of //! this class, which, in practice, means that the CIoManager object managing //! them must outlive this object. diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc index d2d4ef6694..628031c2ea 100644 --- a/bin/pytorch_inference/Main.cc +++ b/bin/pytorch_inference/Main.cc @@ -43,11 +43,14 @@ torch::Tensor infer(torch::jit::script::Module& module, inputs.reserve(1 + request.s_SecondaryArguments.size()); inputs.push_back(tokensTensor); - for (auto& args : request.s_SecondaryArguments) { - inputs.emplace_back(torch::from_blob(static_cast(args.data()), - {1, static_cast(args.size())}, - at::dtype(torch::kInt32)) - .to(torch::kInt64)); + for (auto args : request.s_SecondaryArguments) { + torch::Tensor tensor = + torch::from_blob(static_cast(args.data()), + {1, static_cast(args.size())}, + at::dtype(torch::kInt32)) + .to(torch::kInt64); + + inputs.push_back(tensor); } torch::NoGradGuard noGrad; From fc008875dcce8e2679acef2dbe79305dfbcce340 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 25 Feb 2021 08:56:24 +0000 Subject: [PATCH 12/19] Revert "Rebuild unit tests when the object files change" This reverts commit 4e77ca38036fe38e6a2577a4ee99479b9898badb. --- bin/pytorch_inference/unittest/Makefile | 2 +- mk/stdboosttest.mk | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bin/pytorch_inference/unittest/Makefile b/bin/pytorch_inference/unittest/Makefile index deefef9bb5..4916872652 100644 --- a/bin/pytorch_inference/unittest/Makefile +++ b/bin/pytorch_inference/unittest/Makefile @@ -15,7 +15,7 @@ USE_TORCH=1 all: build -LIB_OBJS=\ +LIBS=\ ../.objs/C*$(OBJECT_FILE_EXT) \ SRCS=\ diff --git a/mk/stdboosttest.mk b/mk/stdboosttest.mk index c1e2239c5e..83bb497e93 100644 --- a/mk/stdboosttest.mk +++ b/mk/stdboosttest.mk @@ -18,8 +18,8 @@ CPPFLAGS+=$(INCLUDE_PATH) LDFLAGS:=$(UTLDFLAGS) $(LDFLAGS) $(LIB_PATH) $(ML_VER_LDFLAGS) PICFLAGS=$(PLATPIEFLAGS) -$(TARGET): $(OBJS) $(LIB_OBJS) $(RESOURCE_FILE) - $(CXX) $(LINK_OUT_FLAG)$@ $(PDB_FLAGS) $(OBJS) $(RESOURCE_FILE) $(LDFLAGS) $(LIBS) $(LIB_OBJS) +$(TARGET): $(OBJS) $(RESOURCE_FILE) + $(CXX) $(LINK_OUT_FLAG)$@ $(PDB_FLAGS) $(OBJS) $(RESOURCE_FILE) $(LDFLAGS) $(LIBS) build: $(TARGET) From 3a268fff119c4f6dbbd6f0c0ffaf0b681b556913 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 25 Feb 2021 08:56:39 +0000 Subject: [PATCH 13/19] Revert "Revert "Review comments round 2"" This reverts commit 83e617242f717630c27b1ed52767c10aa1cbdc86. --- bin/pytorch_inference/CCommandParser.cc | 5 +++-- bin/pytorch_inference/CCommandParser.h | 10 +++++++++- bin/pytorch_inference/Main.cc | 13 +++++-------- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/bin/pytorch_inference/CCommandParser.cc b/bin/pytorch_inference/CCommandParser.cc index 69d186fe17..9b69f504b6 100644 --- a/bin/pytorch_inference/CCommandParser.cc +++ b/bin/pytorch_inference/CCommandParser.cc @@ -109,8 +109,9 @@ void CCommandParser::jsonToRequest(const rapidjson::Document& doc) { m_Request.s_RequestId = doc[REQUEST_ID].GetString(); const rapidjson::Value& arr = doc[TOKENS]; - // overwrite any previous - m_Request.s_Tokens.resize(arr.Size()); + // wipe any previous + m_Request.s_Tokens.clear(); + m_Request.s_Tokens.reserve(arr.Size()); for (auto itr = arr.Begin(); itr != arr.End(); ++itr) { m_Request.s_Tokens.push_back(itr->GetUint()); diff --git a/bin/pytorch_inference/CCommandParser.h b/bin/pytorch_inference/CCommandParser.h index 66782e5589..3c12c1dbb6 100644 --- a/bin/pytorch_inference/CCommandParser.h +++ b/bin/pytorch_inference/CCommandParser.h @@ -22,13 +22,21 @@ namespace torch { //! for each parsed document. //! //! DESCRIPTION:\n -//! +//! Validation on the input documents is light. It is expected the input +//! comes from another process which tightly controls what is sent. +//! Input from an outside source that has not been sanitized should never +//! be sent. //! //! IMPLEMENTATION DECISIONS:\n //! RapidJSON will natively parse a stream of rootless JSON documents //! given the correct parse flags. The documents may be separated by //! whitespace but no other delineator is allowed. //! +//! The parsed request is a member of this class and will be modified when +//! a new command is parsed. The function handler passed to ioLoop must +//! not keep a reference to the request object beyond the scope of the +//! handle function as the request will change. +//! //! The input stream is held by reference. They must outlive objects of //! this class, which, in practice, means that the CIoManager object managing //! them must outlive this object. diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc index 628031c2ea..d2d4ef6694 100644 --- a/bin/pytorch_inference/Main.cc +++ b/bin/pytorch_inference/Main.cc @@ -43,14 +43,11 @@ torch::Tensor infer(torch::jit::script::Module& module, inputs.reserve(1 + request.s_SecondaryArguments.size()); inputs.push_back(tokensTensor); - for (auto args : request.s_SecondaryArguments) { - torch::Tensor tensor = - torch::from_blob(static_cast(args.data()), - {1, static_cast(args.size())}, - at::dtype(torch::kInt32)) - .to(torch::kInt64); - - inputs.push_back(tensor); + for (auto& args : request.s_SecondaryArguments) { + inputs.emplace_back(torch::from_blob(static_cast(args.data()), + {1, static_cast(args.size())}, + at::dtype(torch::kInt32)) + .to(torch::kInt64)); } torch::NoGradGuard noGrad; From 89e442450460ca8b3d7ec2dbe3568eb0e0839b62 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 25 Feb 2021 11:42:09 +0000 Subject: [PATCH 14/19] Print errors from process in python script --- bin/pytorch_inference/evaluate.py | 43 ++++++++++++++++++------------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/bin/pytorch_inference/evaluate.py b/bin/pytorch_inference/evaluate.py index 6feda24aa8..256bc3861c 100644 --- a/bin/pytorch_inference/evaluate.py +++ b/bin/pytorch_inference/evaluate.py @@ -84,29 +84,36 @@ def write_request(request, destination): def compare_results(expected, actual): - if expected['request_id'] != actual['request_id']: - print("request_ids do not match [{}], [{}]".format(expected['request_id'], actual['request_id']), flush=True) - return False + try: + if expected['request_id'] != actual['request_id']: + print("request_ids do not match [{}], [{}]".format(expected['request_id'], actual['request_id']), flush=True) + return False - if len(expected['inference']) != len(actual['inference']): - print("len(inference) does not match [{}], [{}]".format(len(expected['inference']), len(actual['inference'])), flush=True) - return False + if len(expected['inference']) != len(actual['inference']): + print("len(inference) does not match [{}], [{}]".format(len(expected['inference']), len(actual['inference'])), flush=True) + return False - for i in range(len(expected['inference'])): - expected_row = expected['inference'][i] - actual_row = actual['inference'][i] + for i in range(len(expected['inference'])): + expected_row = expected['inference'][i] + actual_row = actual['inference'][i] - if len(expected_row) != len(actual_row): - print("row [{}] lengths are not equal [{}], [{}]".format(i, len(expected_row), len(actual_row)), flush=True) - return False + if len(expected_row) != len(actual_row): + print("row [{}] lengths are not equal [{}], [{}]".format(i, len(expected_row), len(actual_row)), flush=True) + return False - are_close = True - for j in range(len(expected_row)): - are_close = are_close and math.isclose(expected_row[j], actual_row[j], rel_tol=1e-04) + are_close = True + for j in range(len(expected_row)): + are_close = are_close and math.isclose(expected_row[j], actual_row[j], rel_tol=1e-04) + + if are_close == False: + print("row [{}] values are not close {}, {}".format(i, expected_row, actual_row), flush=True) + return False + except KeyError as e: + print("ERROR: comparing results {}. Actual = {}".format(e, actual)) + return False + + return True - if are_close == False: - print("row [{}] values are not close {}, {}".format(i, expected_row, actual_row), flush=True) - return False def main(): From 2df2cb9d78ba96e3574d417934cd9defb8323835 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 25 Feb 2021 11:45:48 +0000 Subject: [PATCH 15/19] Parse tokens as Uint64 to avoid copying tensor --- bin/pytorch_inference/CCommandParser.cc | 6 +++--- bin/pytorch_inference/CCommandParser.h | 8 ++++---- bin/pytorch_inference/Main.cc | 10 ++++------ .../unittest/CCommandParserTest.cc | 16 ++++++++-------- 4 files changed, 19 insertions(+), 21 deletions(-) diff --git a/bin/pytorch_inference/CCommandParser.cc b/bin/pytorch_inference/CCommandParser.cc index 9b69f504b6..4b47cb0959 100644 --- a/bin/pytorch_inference/CCommandParser.cc +++ b/bin/pytorch_inference/CCommandParser.cc @@ -114,7 +114,7 @@ void CCommandParser::jsonToRequest(const rapidjson::Document& doc) { m_Request.s_Tokens.reserve(arr.Size()); for (auto itr = arr.Begin(); itr != arr.End(); ++itr) { - m_Request.s_Tokens.push_back(itr->GetUint()); + m_Request.s_Tokens.push_back(itr->GetUint64()); } std::uint64_t varCount{1}; @@ -122,11 +122,11 @@ void CCommandParser::jsonToRequest(const rapidjson::Document& doc) { // wipe any previous m_Request.s_SecondaryArguments.clear(); - TUint32Vec arg; + TUint64Vec arg; while (doc.HasMember(varArgName)) { const rapidjson::Value& v = doc[varArgName]; for (auto itr = v.Begin(); itr != v.End(); ++itr) { - arg.push_back(itr->GetUint()); + arg.push_back(itr->GetUint64()); } m_Request.s_SecondaryArguments.push_back(arg); diff --git a/bin/pytorch_inference/CCommandParser.h b/bin/pytorch_inference/CCommandParser.h index 3c12c1dbb6..264a05dfa0 100644 --- a/bin/pytorch_inference/CCommandParser.h +++ b/bin/pytorch_inference/CCommandParser.h @@ -47,13 +47,13 @@ class CCommandParser { static const std::string TOKENS; static const std::string VAR_ARG_PREFIX; - using TUint32Vec = std::vector; - using TUint32VecVec = std::vector; + using TUint64Vec = std::vector; + using TUint64VecVec = std::vector; struct SRequest { std::string s_RequestId; - TUint32Vec s_Tokens; - TUint32VecVec s_SecondaryArguments; + TUint64Vec s_Tokens; + TUint64VecVec s_SecondaryArguments; void clear(); }; diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc index d2d4ef6694..2724a446f8 100644 --- a/bin/pytorch_inference/Main.cc +++ b/bin/pytorch_inference/Main.cc @@ -36,18 +36,16 @@ torch::Tensor infer(torch::jit::script::Module& module, torch::Tensor tokensTensor = torch::from_blob(static_cast(request.s_Tokens.data()), {1, static_cast(request.s_Tokens.size())}, - at::dtype(torch::kInt32)) - .to(torch::kInt64); + at::dtype(torch::kInt64)); std::vector inputs; inputs.reserve(1 + request.s_SecondaryArguments.size()); inputs.push_back(tokensTensor); for (auto& args : request.s_SecondaryArguments) { - inputs.emplace_back(torch::from_blob(static_cast(args.data()), - {1, static_cast(args.size())}, - at::dtype(torch::kInt32)) - .to(torch::kInt64)); + inputs.emplace_back(torch::from_blob( + static_cast(args.data()), + {1, static_cast(args.size())}, at::dtype(torch::kInt64))); } torch::NoGradGuard noGrad; diff --git a/bin/pytorch_inference/unittest/CCommandParserTest.cc b/bin/pytorch_inference/unittest/CCommandParserTest.cc index 4e83fc19a6..b9145d8e98 100644 --- a/bin/pytorch_inference/unittest/CCommandParserTest.cc +++ b/bin/pytorch_inference/unittest/CCommandParserTest.cc @@ -27,14 +27,14 @@ BOOST_AUTO_TEST_CASE(testParsingStream) { BOOST_REQUIRE_EQUAL(2, parsed.size()); { BOOST_REQUIRE_EQUAL("foo", parsed[0].s_RequestId); - ml::torch::CCommandParser::TUint32Vec expected{1, 2, 3}; + ml::torch::CCommandParser::TUint64Vec expected{1, 2, 3}; BOOST_REQUIRE_EQUAL_COLLECTIONS(parsed[0].s_Tokens.begin(), parsed[0].s_Tokens.end(), expected.begin(), expected.end()); } { BOOST_REQUIRE_EQUAL("bar", parsed[1].s_RequestId); - ml::torch::CCommandParser::TUint32Vec expected{4, 5}; + ml::torch::CCommandParser::TUint64Vec expected{4, 5}; BOOST_REQUIRE_EQUAL_COLLECTIONS(parsed[1].s_Tokens.begin(), parsed[1].s_Tokens.end(), expected.begin(), expected.end()); @@ -98,10 +98,10 @@ BOOST_AUTO_TEST_CASE(testParsingVariableArguments) { BOOST_REQUIRE_EQUAL(2, parsed.size()); { - ml::torch::CCommandParser::TUint32Vec expectedArg1{0, 0}; - ml::torch::CCommandParser::TUint32Vec expectedArg2{0, 1}; + ml::torch::CCommandParser::TUint64Vec expectedArg1{0, 0}; + ml::torch::CCommandParser::TUint64Vec expectedArg2{0, 1}; - ml::torch::CCommandParser::TUint32VecVec extraArgs = parsed[0].s_SecondaryArguments; + ml::torch::CCommandParser::TUint64VecVec extraArgs = parsed[0].s_SecondaryArguments; BOOST_REQUIRE_EQUAL(2, extraArgs.size()); BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[0].begin(), extraArgs[0].end(), @@ -110,10 +110,10 @@ BOOST_AUTO_TEST_CASE(testParsingVariableArguments) { expectedArg2.begin(), expectedArg2.end()); } { - ml::torch::CCommandParser::TUint32Vec expectedArg1{1, 0}; - ml::torch::CCommandParser::TUint32Vec expectedArg2{1, 1}; + ml::torch::CCommandParser::TUint64Vec expectedArg1{1, 0}; + ml::torch::CCommandParser::TUint64Vec expectedArg2{1, 1}; - ml::torch::CCommandParser::TUint32VecVec extraArgs = parsed[1].s_SecondaryArguments; + ml::torch::CCommandParser::TUint64VecVec extraArgs = parsed[1].s_SecondaryArguments; BOOST_REQUIRE_EQUAL(2, extraArgs.size()); BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[0].begin(), extraArgs[0].end(), From b819c197035c0ad65e584e46eaaa9fb8e790b784 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 25 Feb 2021 13:06:04 +0000 Subject: [PATCH 16/19] Validate token arrays contain unsigned ints --- bin/pytorch_inference/CCommandParser.cc | 25 +++++++++ bin/pytorch_inference/CCommandParser.h | 1 + .../unittest/CCommandParserTest.cc | 52 +++++++++++++++++++ 3 files changed, 78 insertions(+) diff --git a/bin/pytorch_inference/CCommandParser.cc b/bin/pytorch_inference/CCommandParser.cc index 4b47cb0959..a057ff9443 100644 --- a/bin/pytorch_inference/CCommandParser.cc +++ b/bin/pytorch_inference/CCommandParser.cc @@ -77,6 +77,11 @@ bool CCommandParser::validateJson(const rapidjson::Document& doc) const { return false; } + if (doc[REQUEST_ID].IsString() == false) { + LOG_ERROR(<< "Invalid command: [" << REQUEST_ID << "] field is not a string"); + return false; + } + if (doc.HasMember(TOKENS) == false) { LOG_ERROR(<< "Invalid command: missing field [" << TOKENS << "]"); return false; @@ -88,6 +93,11 @@ bool CCommandParser::validateJson(const rapidjson::Document& doc) const { return false; } + if (checkArrayContainsUInts(tokens) == false) { + LOG_ERROR(<< "Invalid command: array [" << TOKENS << "] contains values that are not unsigned integers"); + return false; + } + // check optional args std::uint64_t varCount{1}; std::string varArgName = VAR_ARG_PREFIX + std::to_string(varCount); @@ -98,6 +108,11 @@ bool CCommandParser::validateJson(const rapidjson::Document& doc) const { return false; } + if (checkArrayContainsUInts(value) == false) { + LOG_ERROR(<< "Invalid command: array [" << varArgName << "] contains values that are not unsigned integers"); + return false; + } + ++varCount; varArgName = VAR_ARG_PREFIX + std::to_string(varCount); } @@ -105,6 +120,16 @@ bool CCommandParser::validateJson(const rapidjson::Document& doc) const { return true; } +bool CCommandParser::checkArrayContainsUInts(const rapidjson::Value& arr) const { + bool allInts{true}; + + for (auto itr = arr.Begin(); itr != arr.End(); ++itr) { + allInts = allInts && itr->IsUint64(); + } + + return allInts; +} + void CCommandParser::jsonToRequest(const rapidjson::Document& doc) { m_Request.s_RequestId = doc[REQUEST_ID].GetString(); diff --git a/bin/pytorch_inference/CCommandParser.h b/bin/pytorch_inference/CCommandParser.h index 264a05dfa0..4888517c84 100644 --- a/bin/pytorch_inference/CCommandParser.h +++ b/bin/pytorch_inference/CCommandParser.h @@ -71,6 +71,7 @@ class CCommandParser { private: bool validateJson(const rapidjson::Document& doc) const; + bool checkArrayContainsUInts(const rapidjson::Value& arr) const; void jsonToRequest(const rapidjson::Document& doc); private: diff --git a/bin/pytorch_inference/unittest/CCommandParserTest.cc b/bin/pytorch_inference/unittest/CCommandParserTest.cc index b9145d8e98..de1e9de779 100644 --- a/bin/pytorch_inference/unittest/CCommandParserTest.cc +++ b/bin/pytorch_inference/unittest/CCommandParserTest.cc @@ -58,6 +58,58 @@ BOOST_AUTO_TEST_CASE(testParsingInvalidDoc) { BOOST_REQUIRE_EQUAL(0, parsed.size()); } +BOOST_AUTO_TEST_CASE(testParsingInvalidRequestId) { + + std::vector parsed; + + std::string command{"{\"request_id\": 1}"}; + + std::istringstream commandStream{command}; + + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE(processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) { + parsed.push_back(request); + return true; + })); + + BOOST_REQUIRE_EQUAL(0, parsed.size()); +} + +BOOST_AUTO_TEST_CASE(testParsingTokenArrayNotInts) { + + std::vector parsed; + + std::string command{"{\"request_id\": \"tokens_should_be_uints\", \"tokens\": [\"a\", \"b\", \"c\"]}"}; + + std::istringstream commandStream{command}; + + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE(processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) { + parsed.push_back(request); + return true; + })); + + BOOST_REQUIRE_EQUAL(0, parsed.size()); +} + +BOOST_AUTO_TEST_CASE(testParsingTokenVarArgsNotInts) { + + std::vector parsed; + + std::string command{"{\"request_id\": \"bad\", \"tokens\": [1, 2], \"arg_1\": [\"a\", \"b\"]}"}; + + std::istringstream commandStream{command}; + + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE(processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) { + parsed.push_back(request); + return true; + })); + + BOOST_REQUIRE_EQUAL(0, parsed.size()); +} + + BOOST_AUTO_TEST_CASE(testParsingWhitespaceSeparatedDocs) { std::vector parsed; From 2910f87fff72e86c1a0d9f38ecc3603bb79f95d9 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 25 Feb 2021 13:12:03 +0000 Subject: [PATCH 17/19] update docs --- bin/pytorch_inference/CCommandParser.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bin/pytorch_inference/CCommandParser.h b/bin/pytorch_inference/CCommandParser.h index 4888517c84..1730877e88 100644 --- a/bin/pytorch_inference/CCommandParser.h +++ b/bin/pytorch_inference/CCommandParser.h @@ -22,12 +22,12 @@ namespace torch { //! for each parsed document. //! //! DESCRIPTION:\n -//! Validation on the input documents is light. It is expected the input -//! comes from another process which tightly controls what is sent. -//! Input from an outside source that has not been sanitized should never -//! be sent. //! //! IMPLEMENTATION DECISIONS:\n +//! Validation exists to prevent memory violations from malicious input, +//! but no more. The caller is responsible for sending input that will +//! not result in errors from libTorch and will produce meaningful results. +//! //! RapidJSON will natively parse a stream of rootless JSON documents //! given the correct parse flags. The documents may be separated by //! whitespace but no other delineator is allowed. From a83d6677af89edde27902a386d3661841b926d17 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 1 Mar 2021 09:25:10 +0000 Subject: [PATCH 18/19] Network is required --- bin/pytorch_inference/Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/bin/pytorch_inference/Makefile b/bin/pytorch_inference/Makefile index f41e540960..ea1913e756 100644 --- a/bin/pytorch_inference/Makefile +++ b/bin/pytorch_inference/Makefile @@ -13,6 +13,7 @@ ML_LIBS=$(LIB_ML_CORE) $(LIB_ML_API) USE_BOOST=1 USE_BOOST_PROGRAMOPTIONS_LIBS=1 +USE_NET=1 USE_RAPIDJSON=1 USE_TORCH=1 From ed8dbffe5248d511fc70ee2118ab2e33aba13316 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 1 Mar 2021 10:01:21 +0000 Subject: [PATCH 19/19] clang format --- bin/pytorch_inference/CCommandParser.cc | 6 ++++-- bin/pytorch_inference/unittest/CCommandParserTest.cc | 1 - 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/bin/pytorch_inference/CCommandParser.cc b/bin/pytorch_inference/CCommandParser.cc index a057ff9443..77de247b41 100644 --- a/bin/pytorch_inference/CCommandParser.cc +++ b/bin/pytorch_inference/CCommandParser.cc @@ -94,7 +94,8 @@ bool CCommandParser::validateJson(const rapidjson::Document& doc) const { } if (checkArrayContainsUInts(tokens) == false) { - LOG_ERROR(<< "Invalid command: array [" << TOKENS << "] contains values that are not unsigned integers"); + LOG_ERROR(<< "Invalid command: array [" << TOKENS + << "] contains values that are not unsigned integers"); return false; } @@ -109,7 +110,8 @@ bool CCommandParser::validateJson(const rapidjson::Document& doc) const { } if (checkArrayContainsUInts(value) == false) { - LOG_ERROR(<< "Invalid command: array [" << varArgName << "] contains values that are not unsigned integers"); + LOG_ERROR(<< "Invalid command: array [" << varArgName + << "] contains values that are not unsigned integers"); return false; } diff --git a/bin/pytorch_inference/unittest/CCommandParserTest.cc b/bin/pytorch_inference/unittest/CCommandParserTest.cc index de1e9de779..e9fee56a63 100644 --- a/bin/pytorch_inference/unittest/CCommandParserTest.cc +++ b/bin/pytorch_inference/unittest/CCommandParserTest.cc @@ -109,7 +109,6 @@ BOOST_AUTO_TEST_CASE(testParsingTokenVarArgsNotInts) { BOOST_REQUIRE_EQUAL(0, parsed.size()); } - BOOST_AUTO_TEST_CASE(testParsingWhitespaceSeparatedDocs) { std::vector parsed;