diff --git a/bin/pytorch_inference/CCommandParser.cc b/bin/pytorch_inference/CCommandParser.cc
new file mode 100644
index 0000000000..77de247b41
--- /dev/null
+++ b/bin/pytorch_inference/CCommandParser.cc
@@ -0,0 +1,166 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+#include "CCommandParser.h"
+
+#include <core/CLogger.h>
+
+#include <rapidjson/error/en.h>
+#include <rapidjson/istreamwrapper.h>
+#include <rapidjson/stringbuffer.h>
+#include <rapidjson/writer.h>
+
+#include <istream>
+
+namespace rapidjson {
+
+std::ostream& operator<<(std::ostream& os, const rapidjson::Document& doc) {
+    rapidjson::StringBuffer buffer;
+    rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+    doc.Accept(writer);
+    return os << buffer.GetString();
+}
+}
+
+namespace ml {
+namespace torch {
+
+const std::string CCommandParser::REQUEST_ID{"request_id"};
+const std::string CCommandParser::TOKENS{"tokens"};
+const std::string CCommandParser::VAR_ARG_PREFIX{"arg_"};
+
+CCommandParser::CCommandParser(std::istream& strmIn) : m_StrmIn(strmIn) {
+}
+
+bool CCommandParser::ioLoop(const TRequestHandlerFunc& requestHandler) {
+
+    rapidjson::IStreamWrapper isw{m_StrmIn};
+
+    while (true) {
+        rapidjson::Document doc;
+        // kParseStopWhenDoneFlag lets a stream of rootless JSON documents be
+        // parsed one at a time
+        rapidjson::ParseResult parseResult =
+            doc.ParseStream<rapidjson::kParseStopWhenDoneFlag>(isw);
+
+        if (static_cast<bool>(parseResult) == false) {
+            if (m_StrmIn.eof()) {
+                break;
+            }
+
+            LOG_ERROR(<< "Error parsing command from JSON: "
+                      << rapidjson::GetParseError_En(parseResult.Code())
+                      << ". At offset: " << parseResult.Offset());
+
+            return false;
+        }
+
+        if (validateJson(doc) == false) {
+            continue;
+        }
+
+        LOG_TRACE(<< "Inference command: " << doc);
+        jsonToRequest(doc);
+        if (requestHandler(m_Request) == false) {
+            LOG_ERROR(<< "Request handler forced exit");
+            return false;
+        }
+    }
+
+    return true;
+}
+
+bool CCommandParser::validateJson(const rapidjson::Document& doc) const {
+    if (doc.HasMember(REQUEST_ID) == false) {
+        LOG_ERROR(<< "Invalid command: missing field [" << REQUEST_ID << "]");
+        return false;
+    }
+
+    if (doc[REQUEST_ID].IsString() == false) {
+        LOG_ERROR(<< "Invalid command: [" << REQUEST_ID << "] field is not a string");
+        return false;
+    }
+
+    if (doc.HasMember(TOKENS) == false) {
+        LOG_ERROR(<< "Invalid command: missing field [" << TOKENS << "]");
+        return false;
+    }
+
+    const rapidjson::Value& tokens = doc[TOKENS];
+    if (tokens.IsArray() == false) {
+        LOG_ERROR(<< "Invalid command: expected an array [" << TOKENS << "]");
+        return false;
+    }
+
+    if (checkArrayContainsUInts(tokens) == false) {
+        LOG_ERROR(<< "Invalid command: array [" << TOKENS
+                  << "] contains values that are not unsigned integers");
+        return false;
+    }
+
+    // check optional args
+    std::uint64_t varCount{1};
+    std::string varArgName = VAR_ARG_PREFIX + std::to_string(varCount);
+    while (doc.HasMember(varArgName)) {
+        const rapidjson::Value& value = doc[varArgName];
+        if (value.IsArray() == false) {
+            LOG_ERROR(<< "Invalid command: argument [" << varArgName << "] is not an array");
+            return false;
+        }
+
+        if (checkArrayContainsUInts(value) == false) {
+            LOG_ERROR(<< "Invalid command: array [" << varArgName
+                      << "] contains values that are not unsigned integers");
+            return false;
+        }
+
+        ++varCount;
+        varArgName = VAR_ARG_PREFIX + std::to_string(varCount);
+    }
+
+    return true;
+}
+
+bool CCommandParser::checkArrayContainsUInts(const rapidjson::Value& arr) const {
+    bool allInts{true};
+
+    for (auto itr = arr.Begin(); itr != arr.End(); ++itr) {
+        allInts = allInts && itr->IsUint64();
+    }
+
+    return allInts;
+}
+
+void CCommandParser::jsonToRequest(const rapidjson::Document& doc) {
+
+    m_Request.s_RequestId = doc[REQUEST_ID].GetString();
+    const rapidjson::Value& arr = doc[TOKENS];
+    // wipe any previous
+    m_Request.s_Tokens.clear();
+    m_Request.s_Tokens.reserve(arr.Size());
+
+    for (auto itr = arr.Begin(); itr != arr.End(); ++itr) {
+        m_Request.s_Tokens.push_back(itr->GetUint64());
+    }
+
+    std::uint64_t varCount{1};
+    std::string varArgName = VAR_ARG_PREFIX + std::to_string(varCount);
+
+    // wipe any previous
+    m_Request.s_SecondaryArguments.clear();
+    TUint64Vec arg;
+    while (doc.HasMember(varArgName)) {
+        const rapidjson::Value& v = doc[varArgName];
+        for (auto itr = v.Begin(); itr != v.End(); ++itr) {
+            arg.push_back(itr->GetUint64());
+        }
+
+        m_Request.s_SecondaryArguments.push_back(arg);
+        arg.clear();
+        ++varCount;
+        varArgName = VAR_ARG_PREFIX + std::to_string(varCount);
+    }
+}
+}
+}
diff --git a/bin/pytorch_inference/CCommandParser.h b/bin/pytorch_inference/CCommandParser.h
new file mode 100644
index 0000000000..1730877e88
--- /dev/null
+++ b/bin/pytorch_inference/CCommandParser.h
@@ -0,0 +1,84 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+#ifndef INCLUDED_ml_torch_CCommandParser_h
+#define INCLUDED_ml_torch_CCommandParser_h
+
+#include <rapidjson/document.h>
+
+#include <functional>
+#include <iosfwd>
+#include <string>
+#include <vector>
+
+namespace ml {
+namespace torch {
+
+//! \brief
+//! Reads JSON documents from a stream, calling the request handler
+//! for each parsed document.
+//!
+//! DESCRIPTION:\n
+//!
+//! IMPLEMENTATION DECISIONS:\n
+//! Validation exists to prevent memory violations from malicious input,
+//! but no more. The caller is responsible for sending input that will
+//! not result in errors from libTorch and will produce meaningful results.
+//!
+//! RapidJSON will natively parse a stream of rootless JSON documents
+//! given the correct parse flags. The documents may be separated by
+//! whitespace; no other delimiter is allowed.
+//!
+//! The parsed request is a member of this class and will be modified when
+//! a new command is parsed. The function handler passed to ioLoop must
+//! not keep a reference to the request object beyond the scope of the
+//! handler function as the request will change.
+//!
+//! The input stream is held by reference. It must outlive objects of
+//! this class, which, in practice, means that the CIoManager object managing
+//! it must outlive this object.
+//!
+class CCommandParser {
+public:
+    static const std::string REQUEST_ID;
+    static const std::string TOKENS;
+    static const std::string VAR_ARG_PREFIX;
+
+    using TUint64Vec = std::vector<std::uint64_t>;
+    using TUint64VecVec = std::vector<TUint64Vec>;
+
+    struct SRequest {
+        std::string s_RequestId;
+        TUint64Vec s_Tokens;
+        TUint64VecVec s_SecondaryArguments;
+
+        void clear();
+    };
+
+    using TRequestHandlerFunc = std::function<bool(SRequest&)>;
+
+public:
+    CCommandParser(std::istream& strmIn);
+
+    //! Pass input to the processor until it's consumed as much as it can.
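+    //!
+    //! Illustrative input (the values here are made up): the stream carries
+    //! whitespace-separated JSON documents such as
+    //!     {"request_id": "one", "tokens": [101, 102], "arg_1": [1, 1]}
+    //!     {"request_id": "two", "tokens": [103, 104]}
+    //! Each document that passes validation is handed to \p requestHandler.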
+    bool ioLoop(const TRequestHandlerFunc& requestHandler);
+
+    CCommandParser(const CCommandParser&) = delete;
+    CCommandParser& operator=(const CCommandParser&) = delete;
+
+private:
+    bool validateJson(const rapidjson::Document& doc) const;
+    bool checkArrayContainsUInts(const rapidjson::Value& arr) const;
+    void jsonToRequest(const rapidjson::Document& doc);
+
+private:
+    std::istream& m_StrmIn;
+    SRequest m_Request;
+};
+}
+}
+
+#endif // INCLUDED_ml_torch_CCommandParser_h
diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc
index 55764490ca..2724a446f8 100644
--- a/bin/pytorch_inference/Main.cc
+++ b/bin/pytorch_inference/Main.cc
@@ -8,100 +8,116 @@
 #include <core/CLogger.h>
 #include <core/CProcessPriority.h>
 #include <core/CRapidJsonLineWriter.h>
-#include <core/CStringUtils.h>
+
+#include <rapidjson/ostreamwrapper.h>
 #include <seccomp/CSystemCallFilter.h>
 #include <ver/CBuildInfo.h>
-#include <boost/optional.hpp>
-
 #include "CBufferedIStreamAdapter.h"
 #include "CCmdLineParser.h"
+#include "CCommandParser.h"
 #include <ATen/Parallel.h>
 #include <torch/script.h>
-#include <torch/csrc/api/include/torch/types.h>
-
 #include <cstdint>
 #include <iostream>
-// For ntohl
-#ifdef Windows
-#include <winsock2.h>
-#else
-#include <arpa/inet.h>
-#endif
+namespace {
+const std::string INFERENCE{"inference"};
+const std::string ERROR{"error"};
+}
 
-using TFloatVec = std::vector<float>;
+torch::Tensor infer(torch::jit::script::Module& module,
+                    ml::torch::CCommandParser::SRequest& request) {
 
-torch::Tensor infer(torch::jit::script::Module& module, TFloatVec& data) {
     torch::Tensor tokensTensor =
-        torch::from_blob(data.data(), {1, static_cast<std::int64_t>(data.size())})
-            .to(torch::kInt64);
+        torch::from_blob(static_cast<void*>(request.s_Tokens.data()),
+                         {1, static_cast<std::int64_t>(request.s_Tokens.size())},
+                         at::dtype(torch::kInt64));
 
     std::vector<torch::jit::IValue> inputs;
+    inputs.reserve(1 + request.s_SecondaryArguments.size());
     inputs.push_back(tokensTensor);
-    inputs.push_back(torch::ones({1, static_cast<std::int64_t>(data.size())})); // attention mask
-    inputs.push_back(
-        torch::zeros({1, static_cast<std::int64_t>(data.size())}).to(torch::kInt64)); // token type ids
-    inputs.push_back(torch::arange(static_cast<std::int64_t>(data.size())).to(torch::kInt64)); // position ids
+
+    for (auto& args : request.s_SecondaryArguments) {
+        inputs.emplace_back(torch::from_blob(
+            static_cast<void*>(args.data()),
+            {1, static_cast<std::int64_t>(args.size())}, at::dtype(torch::kInt64)));
+    }
 
     torch::NoGradGuard noGrad;
     auto tuple = module.forward(inputs).toTuple();
-    auto predictions = tuple->elements()[0].toTensor();
-
-    return torch::argmax(predictions, 2);
-}
-
-bool readUInt32(std::istream& stream, std::uint32_t& num) {
-    std::uint32_t netNum{0};
-    stream.read(reinterpret_cast<char*>(&netNum), sizeof(std::uint32_t));
-    num = ntohl(netNum);
-    return stream.good();
+    return tuple->elements()[0].toTensor();
 }
 
-boost::optional<TFloatVec> readTokens(std::istream& inputStream) {
-    if (inputStream.eof()) {
-        LOG_ERROR(<< "Unexpected end of stream reading tokens");
-        return boost::none;
+void writePrediction(const torch::Tensor& prediction,
+                     const std::string& requestId,
+                     std::ostream& outputStream) {
+
+    torch::Tensor view;
+    auto sizes = prediction.sizes();
+    // Some models return a 3D tensor in which case
+    // the first dimension must have size == 1
+    if (sizes.size() == 3 && sizes[0] == 1) {
+        view = prediction[0];
+    } else {
+        view = prediction;
     }
 
-    // return a float vector rather than integers because
-    // float is needed to create the tensor
-    TFloatVec tokens;
-    std::uint32_t numTokens;
-    if (readUInt32(inputStream, numTokens) == false) {
-        LOG_ERROR(<< "Error reading the number of tokens");
-        return boost::none;
-    }
+    // creating the accessor will throw if view does not
+    // have exactly 2 dimensions. Do this before writing
+    // any output so the error message isn't mingled with
+    // a partial result
+    auto accessor = view.accessor<float, 2>();
 
-    for (uint32_t i = 0; i < numTokens; ++i) {
-        std::uint32_t token;
-        if (readUInt32(inputStream, token) == false) {
-            LOG_ERROR(<< "Error reading token");
-            return boost::none;
+    rapidjson::OStreamWrapper writeStream(outputStream);
+    ml::core::CRapidJsonLineWriter<rapidjson::OStreamWrapper> jsonWriter(writeStream);
+    jsonWriter.StartObject();
+    jsonWriter.Key(ml::torch::CCommandParser::REQUEST_ID);
+    jsonWriter.String(requestId);
+    jsonWriter.Key(INFERENCE);
+    jsonWriter.StartArray();
+
+    for (int i = 0; i < accessor.size(0); ++i) {
+        jsonWriter.StartArray();
+        for (int j = 0; j < accessor.size(1); ++j) {
+            jsonWriter.Double(static_cast<double>(accessor[i][j]));
         }
-        tokens.push_back(token);
+        jsonWriter.EndArray();
     }
-    return tokens;
+    jsonWriter.EndArray();
+    jsonWriter.EndObject();
 }
 
-void writePrediction(torch::Tensor& prediction, std::ostream& outputStream) {
+void writeError(const std::string& requestId, const std::string& message, std::ostream& outputStream) {
     rapidjson::OStreamWrapper writeStream(outputStream);
     ml::core::CRapidJsonLineWriter<rapidjson::OStreamWrapper> jsonWriter(writeStream);
     jsonWriter.StartObject();
-    jsonWriter.Key("inference");
-    jsonWriter.StartArray();
-    auto arr = prediction.accessor<std::int64_t, 2>();
-    for (int i = 0; i < arr.size(1); i++) {
-        jsonWriter.Int64(arr[0][i]);
-    }
-    jsonWriter.EndArray();
+    jsonWriter.Key(ml::torch::CCommandParser::REQUEST_ID);
+    jsonWriter.String(requestId);
+    jsonWriter.Key(ERROR);
+    jsonWriter.String(message);
     jsonWriter.EndObject();
 }
 
+bool handleRequest(ml::torch::CCommandParser::SRequest& request,
+                   torch::jit::script::Module& module,
+                   std::ostream& outputStream) {
+
+    try {
+        torch::Tensor results = infer(module, request);
+        writePrediction(results, request.s_RequestId, outputStream);
+    } catch (std::runtime_error& e) {
+        writeError(request.s_RequestId, e.what(), outputStream);
+    }
+
+    return true;
+}
+
 int main(int argc, char** argv) {
     // command line options
     std::string modelId;
@@ -186,14 +202,11 @@ int main(int argc, char** argv) {
         return EXIT_FAILURE;
     }
 
-    boost::optional<TFloatVec> tokens = readTokens(ioMgr.inputStream());
-    if (!tokens) {
-        LOG_ERROR(<< "Cannot infer, failed to read input tokens");
-        return EXIT_FAILURE;
-    }
+    ml::torch::CCommandParser commandParser{ioMgr.inputStream()};
 
-    torch::Tensor results = infer(module, *tokens);
-    writePrediction(results, ioMgr.outputStream());
+    commandParser.ioLoop([&module, &ioMgr](ml::torch::CCommandParser::SRequest& request) {
+        return handleRequest(request, module, ioMgr.outputStream());
+    });
 
     LOG_DEBUG(<< "ML Torch model prototype exiting");
diff --git a/bin/pytorch_inference/Makefile b/bin/pytorch_inference/Makefile
index 34e7ea5700..ea1913e756 100644
--- a/bin/pytorch_inference/Makefile
+++ b/bin/pytorch_inference/Makefile
@@ -25,8 +25,7 @@ SRCS= \
 Main.cc \
 CBufferedIStreamAdapter.cc \
 CCmdLineParser.cc \
-
-NO_TEST_CASES=1
+CCommandParser.cc \
 
 include $(CPP_SRC_HOME)/mk/stdapp.mk
diff --git a/bin/pytorch_inference/evaluate.py b/bin/pytorch_inference/evaluate.py
index 97f03845e8..256bc3861c 100644
--- a/bin/pytorch_inference/evaluate.py
+++ b/bin/pytorch_inference/evaluate.py
@@ -3,17 +3,32 @@
 app, together with the encoded tokens from the input_tokens file.
 Then it checks the model's response matches the expected.
 
-This script first prepares the input files, then launches the C++
-pytorch_inference program which handles them in batch.
-
+This script reads the input file and the expected outputs, then
+launches the C++ pytorch_inference program, which processes each
+request and writes a response. Each response is checked against the
+expected output defined in the test file.
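+
+For reference (the values here are illustrative, not real model output),
+pytorch_inference answers each request with a single NDJSON document:
+    {"request_id": "foo", "inference": [[0.1, 0.2], ...]}
+or, if inference fails, with an error document:
+    {"request_id": "foo", "error": "message"}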
+
+The test file must have the format:
+[
+    {
+        "input": {"request_id": "foo", "tokens": [1, 2, 3]},
+        "expected_output": {"request_id": "foo", "inference": [1, 2, 3]}
+    },
+    ...
+]
+
+
+EXAMPLES
+--------
 Run this script with input from one of the example directories,
 for example:
 
-python3 evaluate.py /path/to/conll03_traced_ner.pt examples/ner/input.json examples/ner/expected_response.json
+python3 evaluate.py /path/to/conll03_traced_ner.pt examples/ner/test_run.json
 '''
 
 import argparse
 import json
+import math
 import os
 import platform
 import stat
@@ -22,8 +37,8 @@ def parse_arguments():
     parser = argparse.ArgumentParser()
     parser.add_argument('model', help='A TorchScript model with .pt extension')
-    parser.add_argument('input_tokens', help='JSON file with an array field "tokens"')
-    parser.add_argument('expected_output', help='Expected output. Another JSON file with an array field "tokens"')
+    parser.add_argument('test_file', help='JSON file with an array of objects each '
+                                          'containing "input" and "expected_output" subobjects')
     parser.add_argument('--restore_file', default='restore_file')
     parser.add_argument('--input_file', default='input_file')
     parser.add_argument('--output_file', default='output_file')
@@ -64,12 +79,42 @@ def stream_file(source, destination) :
         destination.write(piece)
 
 
-def write_tokens(destination, tokens):
+def write_request(request, destination):
+    json.dump(request, destination)
+
+
+def compare_results(expected, actual):
+    try:
+        if expected['request_id'] != actual['request_id']:
+            print("request_ids do not match [{}], [{}]".format(expected['request_id'], actual['request_id']), flush=True)
+            return False
+
+        if len(expected['inference']) != len(actual['inference']):
+            print("len(inference) does not match [{}], [{}]".format(len(expected['inference']), len(actual['inference'])), flush=True)
+            return False
+
+        for i in range(len(expected['inference'])):
+            expected_row = expected['inference'][i]
+            actual_row = actual['inference'][i]
+
+            if len(expected_row) != len(actual_row):
+                print("row [{}] lengths are not equal [{}], [{}]".format(i, len(expected_row), len(actual_row)), flush=True)
+                return False
+
+            # results are floating point, so check the values are close
+            # rather than exactly equal
+            are_close = all(math.isclose(e, a, rel_tol=1e-04)
+                            for e, a in zip(expected_row, actual_row))
+            if not are_close:
+                print("row [{}] values are not close {}, {}".format(i, expected_row, actual_row), flush=True)
+                return False
+    except KeyError as e:
+        print("ERROR: comparing results {}. Actual = {}".format(e, actual))
Actual = {}".format(e, actual)) + return False + + return True + - num_tokens = len(tokens) - destination.write(num_tokens.to_bytes(4, 'big')) - for token in tokens: - destination.write(token.to_bytes(4, 'big')) def main(): @@ -90,36 +135,53 @@ def main(): with open(args.model, 'rb') as source_file: stream_file(source_file, restore_file) - with open(args.input_file, 'wb') as input_file: - with open(args.input_tokens) as token_file: - input_tokens = json.load(token_file) + with open(args.input_file, 'w') as input_file: + with open(args.test_file) as test_file: + test_evaluation = json.load(test_file) print("writing query", flush=True) - write_tokens(input_file, input_tokens['tokens']) + for doc in test_evaluation: + write_request(doc['input'], input_file) - # one shot inference launch_pytorch_app(args) - print("reading results", flush=True) - with open(args.expected_output) as expected_output_file: - expected = json.load(expected_output_file) - + print() + print("reading results...", flush=True) with open(args.output_file) as output_file: - results = json.load(output_file) - # compare to expected - if results['inference'] == expected['tokens']: - print('inference results match expected results') - else: - print('ERROR: inference results do not match expected results') - print(results) - - finally: + doc_count = 0 + results_match = True + # output is NDJSON + for jsonline in output_file: + try: + result = json.loads(jsonline) + except: + print("Error parsing json: ", jsonline) + return + + expected = test_evaluation[doc_count]['expected_output'] + + # compare to expected + if compare_results(expected, result) == False: + print() + print('ERROR: inference result [{}] does not match expected results'.format(doc_count)) + print() + results_match = False + + doc_count = doc_count +1 + + if results_match: + print() + print('SUCCESS: inference results match expected', flush=True) + print() + + finally: if os.path.isfile(args.restore_file): os.remove(args.restore_file) if os.path.isfile(args.input_file): os.remove(args.input_file) if os.path.isfile(args.output_file): - os.remove(args.output_file) + os.remove(args.output_file) + if __name__ == "__main__": diff --git a/bin/pytorch_inference/examples/ner/expected_response.json b/bin/pytorch_inference/examples/ner/expected_response.json deleted file mode 100644 index d908590edc..0000000000 --- a/bin/pytorch_inference/examples/ner/expected_response.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "tokens": [0, 6, 6, 6, 6, 0, 0, 0, 0, 0, 0, 8, 8, 8, 0, 0, 0, 0, 0, 8, 8, 8, 0, 0, 0, 0, 0, 0, 0, 8, 8, 0, 0] -} diff --git a/bin/pytorch_inference/examples/ner/input.json b/bin/pytorch_inference/examples/ner/input.json deleted file mode 100644 index a35aad7d4f..0000000000 --- a/bin/pytorch_inference/examples/ner/input.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "tokens": [101, 20164, 10932, 10289, 3561, 119, 1110, 170, 1419, 1359, 1107, 1203, 1365, 1392, 119, 2098, 3834, 1132, 1107, 141, 25810, 23904, 117, 3335, 1304, 1665, 20316, 1106, 1103, 6545, 3640, 119, 102] -} diff --git a/bin/pytorch_inference/examples/ner/test_run.json b/bin/pytorch_inference/examples/ner/test_run.json new file mode 100644 index 0000000000..07cb92ad86 --- /dev/null +++ b/bin/pytorch_inference/examples/ner/test_run.json @@ -0,0 +1,24 @@ +[ + { + "source_text": "Hugging Face Inc. is a company based in New York City. 
Its headquarters are in DUMBO, therefore very close to the Manhattan Bridge.", + "input": { + "request_id": "one", + "tokens": [101, 20164, 10932, 10289, 3561, 119, 1110, 170, 1419, 1359, 1107, 1203, 1365, 1392, 119, 2098, 3834, 1132, 1107, 141, 25810, 23904, 117, 3335, 1304, 1665, 20316, 1106, 1103, 6545, 3640, 119, 102], + "arg_1": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "arg_2": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + "arg_3": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32] + }, + "expected_output": {"request_id": "one", "inference": [[9.445713996887207,-2.5186009407043459,-1.6298730373382569,-2.0444109439849855,-2.2338523864746095,-1.7153222560882569,-0.43147391080856326,-2.0465660095214845,1.1865826845169068],[0.792238175868988,-2.8607232570648195,-0.7416678071022034,-3.2710094451904299,-0.8123045563697815,-1.7240391969680787,9.034605026245118,-2.5430126190185549,-0.4717179536819458],[1.9150646924972535,-2.288753032684326,0.10271981358528137,-3.2836503982543947,-0.05480371415615082,-1.2911514043807984,6.839619159698486,-2.345627546310425,-0.7598909139633179],[0.6548398733139038,-2.8548803329467775,-0.22843939065933228,-3.415384531021118,0.2715289294719696,-1.3837164640426636,7.829202175140381,-2.879647970199585,-0.38423842191696169],[1.1488219499588013,-2.967827796936035,-0.8474732041358948,-3.521238327026367,-1.603212594985962,-1.894327998161316,9.006779670715332,-2.138070821762085,-0.03918422386050224],[6.845504283905029,-2.78581166267395,-1.3969510793685914,-3.771115779876709,-0.6792833209037781,-1.7580525875091553,4.772181510925293,-2.548997640609741,-0.09716915339231491],[10.898977279663086,-2.2232580184936525,-0.8633253574371338,-2.6009225845336916,-1.4352948665618897,-1.5839565992355347,0.8288266062736511,-2.17624831199646,-0.6236083507537842],[11.136425018310547,-2.27126145362854,-0.8518160581588745,-2.6662473678588869,-1.4811440706253052,-1.636898398399353,0.7194949388504028,-2.180851697921753,-0.6165107488632202],[10.732818603515625,-2.5183522701263429,-0.7929745316505432,-2.885248899459839,-1.4805352687835694,-1.9814093112945557,1.8594862222671509,-2.201752185821533,-0.5392075777053833],[10.873189926147461,-2.306696891784668,-0.80250483751297,-2.8811705112457277,-1.4863512516021729,-1.8671237230300904,0.9172942638397217,-2.1708977222442629,-0.20861388742923737],[10.7890625,-2.5134470462799074,-0.7714695334434509,-2.6941256523132326,-1.6425325870513917,-1.7182999849319459,0.4641777276992798,-2.154723882675171,0.11145751923322678],[0.2922537624835968,-2.2648208141326906,-1.1112738847732545,-2.363692283630371,-1.1536873579025269,-2.16975474357605,0.1696488857269287,-1.7420969009399415,8.619940757751465],[-0.05837444216012955,-2.2322306632995607,-1.2456691265106202,-2.393845796585083,-1.0231869220733643,-2.1384024620056154,0.36509764194488528,-1.6257678270339966,8.386515617370606],[-0.40553534030914309,-2.0154948234558107,-0.9165335297584534,-2.3343887329101564,-1.008164644241333,-2.404677152633667,-0.25008249282836916,-0.8966042399406433,8.402207374572754],[9.445691108703614,-2.5186126232147219,-1.6298660039901734,-2.044417381286621,-2.2338428497314455,-1.7153276205062867,-0.43144431710243227,-2.0465352535247804,1.1865744590759278],[10.766592025756836,-2.483231782913208,-0.7467246651649475,-2.816481590270996,-1.5419433116912842,-1.7198463678359986,0.8594813942909241,-2.46092867851
25734,-0.38469821214675906],[11.087119102478028,-2.369004726409912,-1.0071995258331299,-2.8232693672180177,-1.5680956840515137,-1.817236304283142,0.4988541603088379,-2.2202258110046388,-0.38811907172203066],[11.105840682983399,-2.1901695728302,-0.8257009387016296,-2.5521652698516847,-1.4372351169586182,-1.8002771139144898,0.3943113684654236,-2.086343288421631,-0.3007338047027588],[10.404936790466309,-2.4211761951446535,-0.6679232120513916,-2.7030136585235597,-1.702601671218872,-1.7996376752853394,0.34499865770339968,-1.9718276262283326,0.5988494753837586],[0.04125037044286728,-2.1516361236572267,-0.6540370583534241,-2.9895687103271486,-0.02026049792766571,-2.607717990875244,0.7579661011695862,-2.63039493560791,5.886351585388184],[0.41723012924194338,-1.6220901012420655,-0.6683351397514343,-3.1531808376312258,-0.7741800546646118,-2.0307929515838625,1.8664846420288087,-2.6268656253814699,4.985666275024414],[1.1279197931289673,-2.7916486263275148,-1.4165221452713013,-3.2945871353149416,-0.24917204678058625,-2.006779432296753,2.0364503860473635,-1.8233060836791993,4.838916301727295],[10.858016967773438,-2.210500478744507,-0.7385379672050476,-2.6571877002716066,-1.7322543859481812,-1.5595816373825074,0.38503119349479678,-2.160937547683716,-0.2505747079849243],[11.213296890258789,-2.2317054271698,-0.7556230425834656,-2.6072943210601808,-1.5168147087097169,-1.735519289970398,-0.2140665054321289,-1.9741536378860474,-0.3461124897003174],[11.128496170043946,-2.286959409713745,-0.6008748412132263,-2.680781841278076,-1.637223482131958,-1.8050802946090699,0.04603640362620354,-2.0937740802764894,-0.39641517400741579],[5.765664100646973,-2.3779571056365969,0.5095824599266052,-3.020923137664795,-0.835684597492218,-2.7159736156463625,1.3586626052856446,-2.0834007263183595,1.0498584508895875],[10.664918899536133,-2.2987592220306398,-0.45505377650260928,-2.8677072525024416,-1.7201433181762696,-1.8049241304397584,0.18862102925777436,-2.447751760482788,-0.1371004730463028],[10.836210250854493,-1.9615545272827149,-0.39490482211112978,-2.8520052433013918,-1.5771089792251588,-1.796066403388977,-0.2166888266801834,-2.1943840980529787,-0.14454032480716706],[11.126060485839844,-2.092292308807373,-0.7535718679428101,-2.7322869300842287,-1.737945318222046,-1.915206789970398,-0.05707954242825508,-2.2171850204467775,0.04100846126675606],[-0.9580754637718201,-1.4011685848236085,0.3523239195346832,-2.691606044769287,-0.3759573698043823,-1.531064510345459,0.32374972105026247,-2.8274593353271486,5.610057830810547],[-0.3176293671131134,-1.7536091804504395,0.2829216420650482,-2.8517212867736818,-0.298392117023468,-2.5902841091156008,0.05927262827754021,-2.667252779006958,6.318171977996826],[9.445697784423829,-2.5185813903808595,-1.629878282546997,-2.0444087982177736,-2.233858108520508,-1.7153252363204957,-0.4314804971218109,-2.0465519428253176,1.1865817308425904],[9.445714950561524,-2.5186026096343996,-1.6298726797103882,-2.0444114208221437,-2.2338526248931886,-1.7153226137161255,-0.431472510099411,-2.0465657711029054,1.1865825653076172]]} + }, + { + "source_text": "Jim bought 300 shares of Acme Corp. 
in 2006", + "input": { + "request_id": "two", + "tokens": [101, 3104, 3306, 3127, 6117, 1104, 138, 1665, 3263, 13619, 119, 1107, 1386, 102], + "arg_1": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "arg_2": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + "arg_3": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + }, + "expected_output": {"request_id": "two", "inference": [[9.265243530273438,-2.515533685684204,-1.3969738483428956,-2.351750135421753,-1.803475260734558,-1.8310259580612183,-0.09857147932052612,-1.8631094694137574,0.39207878708839419],[1.2384196519851685,-1.9405957460403443,-1.775153398513794,-3.3273582458496095,6.286568641662598,-2.179285764694214,2.0314316749572756,-2.7113730907440187,-0.45493167638778689],[11.42744255065918,-2.162276268005371,-0.843875527381897,-2.559823513031006,-1.0762289762496949,-1.8819851875305176,-0.05885419622063637,-2.0086324214935304,-1.1396600008010865],[11.531952857971192,-2.2821600437164308,-0.5459890961647034,-2.5652623176574709,-1.4509466886520386,-2.0525619983673097,-0.25311681628227236,-2.0138297080993654,-1.1384060382843018],[11.342474937438965,-2.217491388320923,-0.8007739186286926,-2.817258834838867,-1.4931764602661133,-2.072322368621826,0.4753860533237457,-2.071307897567749,-1.1917411088943482],[11.184587478637696,-2.2941224575042726,-1.1137279272079468,-2.842801570892334,-1.3545138835906983,-2.021228075027466,0.747635006904602,-2.145432472229004,-0.8665369749069214],[0.21825885772705079,-2.547029733657837,-0.33246660232543948,-2.996917724609375,-1.280576229095459,-1.7315617799758912,9.094033241271973,-2.1498544216156008,-0.9704827666282654],[1.5527149438858033,-2.109332323074341,0.9170875549316406,-2.693459987640381,0.07074408233165741,-1.8306814432144166,4.762491226196289,-1.9375808238983155,-0.8089945316314697],[0.49461689591407778,-2.785339593887329,-0.35393503308296206,-3.1697468757629396,-1.1621716022491456,-1.5924493074417115,8.098134994506836,-2.1945927143096926,-0.2739206850528717],[0.8808980584144592,-2.962925910949707,-0.7409814596176148,-3.4761435985565187,-1.8204796314239503,-1.6338881254196168,8.950613975524903,-1.90260910987854,0.25867289304733279],[3.3241844177246095,-2.2221860885620119,0.010276625864207745,-3.562873125076294,-0.6271736025810242,-1.7994204759597779,3.548079252243042,-2.3455047607421877,0.8453909754753113],[11.597189903259278,-2.2594125270843508,-0.805946409702301,-2.462068796157837,-1.2340333461761475,-1.948832392692566,-0.4201577305793762,-1.6598224639892579,-1.0664936304092408],[11.329129219055176,-1.987001657485962,-0.5631923079490662,-2.7865068912506105,-1.0649499893188477,-2.278918981552124,-0.09693071246147156,-2.0265250205993654,-0.8837844729423523],[9.265239715576172,-2.5155346393585207,-1.3969736099243165,-2.351749897003174,-1.8034745454788209,-1.8310260772705079,-0.09857088327407837,-1.863108515739441,0.39208000898361208]]} + } +] \ No newline at end of file diff --git a/bin/pytorch_inference/examples/sentiment_analysis/test_run.json b/bin/pytorch_inference/examples/sentiment_analysis/test_run.json new file mode 100644 index 0000000000..76aaa0ac51 --- /dev/null +++ b/bin/pytorch_inference/examples/sentiment_analysis/test_run.json @@ -0,0 +1,20 @@ +[ + { + "source_text": "The cat was sick on the bed", + "input": { + "request_id": "one", + "tokens": [101, 1996, 4937, 2001, 5305, 2006, 1996, 2793, 102], + "arg_1": [1, 1, 1, 1, 1, 1, 1, 1, 1] + }, + "expected_output": {"request_id": "one", "inference": [[3.9489, -3.2416]]} + }, + { + "source_text": "The movie was awesome!!", + "input": { + "request_id": 
"two", + "tokens": [101, 1996, 3185, 2001, 12476, 999, 999, 102], + "arg_1": [1, 1, 1, 1, 1, 1, 1, 1] + }, + "expected_output": {"request_id": "two", "inference": [[-4.2720, 4.6515]]} + } +] \ No newline at end of file diff --git a/bin/pytorch_inference/unittest/.gitignore b/bin/pytorch_inference/unittest/.gitignore new file mode 100644 index 0000000000..2eb0aff1ca --- /dev/null +++ b/bin/pytorch_inference/unittest/.gitignore @@ -0,0 +1,2 @@ +# Ignore temporary test files +test.log diff --git a/bin/pytorch_inference/unittest/.objs/.gitignore b/bin/pytorch_inference/unittest/.objs/.gitignore new file mode 100644 index 0000000000..e69de29bb2 diff --git a/bin/pytorch_inference/unittest/CCommandParserTest.cc b/bin/pytorch_inference/unittest/CCommandParserTest.cc new file mode 100644 index 0000000000..e9fee56a63 --- /dev/null +++ b/bin/pytorch_inference/unittest/CCommandParserTest.cc @@ -0,0 +1,213 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +#include "../CCommandParser.h" + +#include + +BOOST_AUTO_TEST_SUITE(CCommandParserTest) + +BOOST_AUTO_TEST_CASE(testParsingStream) { + + std::vector parsed; + + std::string command{"{\"request_id\": \"foo\", \"tokens\": [1, 2, 3]}" + "{\"request_id\": \"bar\", \"tokens\": [4, 5]}"}; + std::istringstream commandStream{command}; + + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE(processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) { + parsed.push_back(request); + return true; + })); + + BOOST_REQUIRE_EQUAL(2, parsed.size()); + { + BOOST_REQUIRE_EQUAL("foo", parsed[0].s_RequestId); + ml::torch::CCommandParser::TUint64Vec expected{1, 2, 3}; + BOOST_REQUIRE_EQUAL_COLLECTIONS(parsed[0].s_Tokens.begin(), + parsed[0].s_Tokens.end(), + expected.begin(), expected.end()); + } + { + BOOST_REQUIRE_EQUAL("bar", parsed[1].s_RequestId); + ml::torch::CCommandParser::TUint64Vec expected{4, 5}; + BOOST_REQUIRE_EQUAL_COLLECTIONS(parsed[1].s_Tokens.begin(), + parsed[1].s_Tokens.end(), + expected.begin(), expected.end()); + } +} + +BOOST_AUTO_TEST_CASE(testParsingInvalidDoc) { + + std::vector parsed; + + std::string command{"{\"foo\": 1, }"}; + + std::istringstream commandStream{command}; + + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE(processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) { + parsed.push_back(request); + return true; + }) == false); + + BOOST_REQUIRE_EQUAL(0, parsed.size()); +} + +BOOST_AUTO_TEST_CASE(testParsingInvalidRequestId) { + + std::vector parsed; + + std::string command{"{\"request_id\": 1}"}; + + std::istringstream commandStream{command}; + + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE(processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) { + parsed.push_back(request); + return true; + })); + + BOOST_REQUIRE_EQUAL(0, parsed.size()); +} + +BOOST_AUTO_TEST_CASE(testParsingTokenArrayNotInts) { + + std::vector parsed; + + std::string command{"{\"request_id\": \"tokens_should_be_uints\", \"tokens\": [\"a\", \"b\", \"c\"]}"}; + + std::istringstream commandStream{command}; + + ml::torch::CCommandParser processor{commandStream}; + BOOST_TEST_REQUIRE(processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) { + parsed.push_back(request); + return true; + })); + + 
+    BOOST_REQUIRE_EQUAL(0, parsed.size());
+}
+
+BOOST_AUTO_TEST_CASE(testParsingTokenVarArgsNotInts) {
+
+    std::vector<ml::torch::CCommandParser::SRequest> parsed;
+
+    std::string command{"{\"request_id\": \"bad\", \"tokens\": [1, 2], \"arg_1\": [\"a\", \"b\"]}"};
+
+    std::istringstream commandStream{command};
+
+    ml::torch::CCommandParser processor{commandStream};
+    BOOST_TEST_REQUIRE(processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) {
+        parsed.push_back(request);
+        return true;
+    }));
+
+    BOOST_REQUIRE_EQUAL(0, parsed.size());
+}
+
+BOOST_AUTO_TEST_CASE(testParsingWhitespaceSeparatedDocs) {
+
+    std::vector<ml::torch::CCommandParser::SRequest> parsed;
+
+    std::string command{"{\"request_id\": \"foo\", \"tokens\": [1, 2, 3]}\t"
+                        "{\"request_id\": \"bar\", \"tokens\": [1, 2, 3]}\n"
+                        "{\"request_id\": \"foo2\", \"tokens\": [1, 2, 3]} "
+                        "{\"request_id\": \"bar2\", \"tokens\": [1, 2, 3]}"};
+    std::istringstream commandStream{command};
+
+    ml::torch::CCommandParser processor{commandStream};
+    BOOST_TEST_REQUIRE(processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) {
+        parsed.push_back(request);
+        return true;
+    }));
+
+    BOOST_REQUIRE_EQUAL(4, parsed.size());
+    BOOST_REQUIRE_EQUAL("foo", parsed[0].s_RequestId);
+    BOOST_REQUIRE_EQUAL("bar", parsed[1].s_RequestId);
+    BOOST_REQUIRE_EQUAL("foo2", parsed[2].s_RequestId);
+    BOOST_REQUIRE_EQUAL("bar2", parsed[3].s_RequestId);
+}
+
+BOOST_AUTO_TEST_CASE(testParsingVariableArguments) {
+
+    std::vector<ml::torch::CCommandParser::SRequest> parsed;
+
+    std::string command{
+        "{\"request_id\": \"foo\", \"tokens\": [1, 2], \"arg_1\": [0, 0], \"arg_2\": [0, 1]}"
+        "{\"request_id\": \"bar\", \"tokens\": [3, 4], \"arg_1\": [1, 0], \"arg_2\": [1, 1]}"};
+    std::istringstream commandStream{command};
+
+    ml::torch::CCommandParser processor{commandStream};
+    BOOST_TEST_REQUIRE(processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) {
+        parsed.push_back(request);
+        return true;
+    }));
+
+    BOOST_REQUIRE_EQUAL(2, parsed.size());
+    {
+        ml::torch::CCommandParser::TUint64Vec expectedArg1{0, 0};
+        ml::torch::CCommandParser::TUint64Vec expectedArg2{0, 1};
+
+        ml::torch::CCommandParser::TUint64VecVec extraArgs = parsed[0].s_SecondaryArguments;
+        BOOST_REQUIRE_EQUAL(2, extraArgs.size());
+
+        BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[0].begin(), extraArgs[0].end(),
+                                        expectedArg1.begin(), expectedArg1.end());
+        BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[1].begin(), extraArgs[1].end(),
+                                        expectedArg2.begin(), expectedArg2.end());
+    }
+    {
+        ml::torch::CCommandParser::TUint64Vec expectedArg1{1, 0};
+        ml::torch::CCommandParser::TUint64Vec expectedArg2{1, 1};
+
+        ml::torch::CCommandParser::TUint64VecVec extraArgs = parsed[1].s_SecondaryArguments;
+        BOOST_REQUIRE_EQUAL(2, extraArgs.size());
+
+        BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[0].begin(), extraArgs[0].end(),
+                                        expectedArg1.begin(), expectedArg1.end());
+        BOOST_REQUIRE_EQUAL_COLLECTIONS(extraArgs[1].begin(), extraArgs[1].end(),
+                                        expectedArg2.begin(), expectedArg2.end());
+    }
+}
+
+BOOST_AUTO_TEST_CASE(testParsingInvalidVarArg) {
+
+    std::vector<ml::torch::CCommandParser::SRequest> parsed;
+
+    std::string command{"{\"request_id\": \"foo\", \"tokens\": [1, 2], \"arg_1\": \"not_an_array\"}"};
+    std::istringstream commandStream{command};
+
+    ml::torch::CCommandParser processor{commandStream};
+    BOOST_TEST_REQUIRE(processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) {
+        parsed.push_back(request);
+        return true;
+    }));
+
+    BOOST_REQUIRE_EQUAL(0, parsed.size());
+}
+
+BOOST_AUTO_TEST_CASE(testRequestHandlerExitsLoop) {
+
+    std::vector<ml::torch::CCommandParser::SRequest> parsed;
+
command{"{\"request_id\": \"foo\", \"tokens\": [1, 2, 3]}" + "{\"request_id\": \"bar\", \"tokens\": [4, 5]}"}; + std::istringstream commandStream{command}; + + ml::torch::CCommandParser processor{commandStream}; + // handler returns false + BOOST_TEST_REQUIRE( + false == processor.ioLoop([&parsed](const ml::torch::CCommandParser::SRequest& request) { + parsed.push_back(request); + return false; + })); + + // ioloop should exit after the first call to the handler + BOOST_REQUIRE_EQUAL(1, parsed.size()); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/bin/pytorch_inference/unittest/Main.cc b/bin/pytorch_inference/unittest/Main.cc new file mode 100644 index 0000000000..ed5004a50e --- /dev/null +++ b/bin/pytorch_inference/unittest/Main.cc @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +#define BOOST_TEST_MODULE bin.pytorch +// Defining BOOST_TEST_MODULE usually auto-generates main(), but we don't want +// this as we need custom initialisation to allow for output in both console and +// Boost.Test XML formats +#define BOOST_TEST_NO_MAIN + +#include +#include + +#include + +int main(int argc, char** argv) { + ml::test::CTestObserver observer; + boost::unit_test::framework::register_observer(observer); + return boost::unit_test::unit_test_main(&ml::test::CBoostTestXmlOutput::init, argc, argv); + boost::unit_test::framework::deregister_observer(observer); +} diff --git a/bin/pytorch_inference/unittest/Makefile b/bin/pytorch_inference/unittest/Makefile new file mode 100644 index 0000000000..4916872652 --- /dev/null +++ b/bin/pytorch_inference/unittest/Makefile @@ -0,0 +1,25 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License; +# you may not use this file except in compliance with the Elastic License. +# +include $(CPP_SRC_HOME)/mk/defines.mk + +TARGET=ml_test$(EXE_EXT) + +USE_BOOST=1 +USE_BOOST_PROGRAMOPTIONS_LIBS=1 +USE_BOOST_TEST_LIBS=1 +USE_RAPIDJSON=1 +USE_TORCH=1 + +all: build + +LIBS=\ + ../.objs/C*$(OBJECT_FILE_EXT) \ + +SRCS=\ + Main.cc \ + CCommandParserTest.cc \ + +include $(CPP_SRC_HOME)/mk/stdboosttest.mk \ No newline at end of file