diff --git a/tis/README.md b/tis/README.md index 6461811..4d7c6fe 100644 --- a/tis/README.md +++ b/tis/README.md @@ -23,13 +23,38 @@ $ python tools/export_onnx.py --config configs/bisenetv2_city.py --weight-path / $ cp -riv ./model.onnx tis/models/bisenetv2/1 ``` -#### 2. start service -We start serving with docker: +#### 2. prepare the preprocessing backend +We can use either python backend or cpp backend for preprocessing in the server side. +Firstly, we pull the docker image, and start a serving container: ``` -$ docker pull nvcr.io/nvidia/tritonserver:21.10-py3 -$ docker run --gpus all --rm -p8000:8000 -p8001:8001 -p8002:8002 -v /path/to/BiSeNet/tis/models:/models nvcr.io/nvidia/tritonserver:21.10-py3 tritonserver --model-repository=/models +$ docker pull nvcr.io/nvidia/tritonserver:22.07-py3 +$ docker run -it --gpus all --rm -p8000:8000 -p8001:8001 -p8002:8002 -v /path/to/BiSeNet/tis/models:/models -v /path/to/BiSeNet/:/BiSeNet nvcr.io/nvidia/tritonserver:21.10-py3 bash ``` +From here on, we are in the container environment. Let's prepare the backends in the container: +``` +# ln -s /usr/local/bin/pip3.8 /usr/bin/pip3.8 +# /usr/bin/python3 -m pip install pillow +# apt update && apt install rapidjson-dev libopencv-dev +``` +Then we download cmake 3.22 and unzip in the container, we use this cmake 3.22 in the following operations. +We compile c++ backends: +``` +# cp -riv /BiSeNet/tis/self_backend /opt/tritonserver/backends +# chmod 777 /opt/tritonserver/backends/self_backend +# cd /opt/tritonserver/backends/self_backend +# mkdir -p build && cd build +# cmake .. && make -j4 +# mv -iuv libtriton_self_backend.so .. +``` +Utils now, we should have backends prepared. + + +#### 3. start service +We start the server in the docker container, following the above steps: +``` +# tritonserver --model-repository=/models +``` In general, the service would start now. You can check whether service has started by: ``` $ curl -v localhost:8000/v2/health/ready @@ -38,10 +63,12 @@ $ curl -v localhost:8000/v2/health/ready By default, we use gpu 0 and gpu 1, you can change configurations in the `config.pbtxt` file. -### Client +### Request with client We call the model service with both python and c++ method. +From here on, we are at the client machine, rather than the server docker container. + #### 1. python method @@ -50,10 +77,11 @@ Firstly, we need to install dependency package: $ python -m pip install tritonclient[all]==2.15.0 ``` -Then we can run the script: +Then we can run the script for both http request and grpc request: ``` $ cd BiSeNet/tis -$ python client.py +$ python client_http.py # if you want to use http client +$ python client_grpc.py # if you want to use grpc client ``` This would generate a result file named `res.jpg` in `BiSeNet/tis` directory. @@ -92,4 +120,4 @@ Finally, we run the client and see a result file named `res.jpg` generated: ### In the end -This is a simple demo with only basic function. There are many other features that is useful, such as shared memory and model pipeline. If you have interest on this, you can learn more in the official document. +This is a simple demo with only basic function. There are many other features that is useful, such as shared memory and dynamic batching. If you have interests on this, you can learn more in the official document. diff --git a/tis/client_backend.py b/tis/client_backend.py new file mode 100644 index 0000000..6784d29 --- /dev/null +++ b/tis/client_backend.py @@ -0,0 +1,81 @@ + + +import argparse +import sys +import numpy as np +import cv2 +import gevent.ssl + +import tritonclient.http as httpclient +from tritonclient.utils import InferenceServerException + + +np.random.seed(123) +palette = np.random.randint(0, 256, (100, 3)) + + +url = '10.128.61.8:8000' +# url = '127.0.0.1:8000' +model_name = 'preprocess_cpp' +model_version = '1' +inp_name = 'raw_img_bytes' +outp_name = 'processed_img' +inp_dtype = 'UINT8' +impth = '../example.png' +mean = [0.3257, 0.3690, 0.3223] # city, rgb +std = [0.2112, 0.2148, 0.2115] + + +## prepare image and mean/std +inp_data = np.fromfile(impth, dtype=np.uint8)[None, ...] +mean = np.array(mean, dtype=np.float32)[None, ...] +std = np.array(std, dtype=np.float32)[None, ...] +inputs = [] +inputs.append(httpclient.InferInput(inp_name, inp_data.shape, inp_dtype)) +inputs.append(httpclient.InferInput('channel_mean', mean.shape, 'FP32')) +inputs.append(httpclient.InferInput('channel_std', std.shape, 'FP32')) +inputs[0].set_data_from_numpy(inp_data, binary_data=True) +inputs[1].set_data_from_numpy(mean, binary_data=True) +inputs[2].set_data_from_numpy(std, binary_data=True) + +## client +triton_client = httpclient.InferenceServerClient( + url=url, verbose=False, concurrency=32) + +## infer +# sync +# results = triton_client.infer(model_name, inputs) + + +# async +# results = triton_client.async_infer( +# model_name, +# inputs, +# outputs=None, +# query_params=None, +# headers=None, +# request_compression_algorithm=None, +# response_compression_algorithm=None) +# results = results.get_result() # async infer only + + +## dynamic batching, this is not allowed, since different pictures has different raw size +results = [] +for i in range(10): + r = triton_client.async_infer( + model_name, + inputs, + outputs=None, + query_params=None, + headers=None, + request_compression_algorithm=None, + response_compression_algorithm=None) + results.append(r) +for i in range(10): + results[i].get_result() +results = results[i] + + +# get output +outp = results.as_numpy(outp_name).squeeze() +print(outp.shape) diff --git a/tis/client.py b/tis/client_grpc.py similarity index 55% rename from tis/client.py rename to tis/client_grpc.py index 33cc925..5109c1e 100644 --- a/tis/client.py +++ b/tis/client_grpc.py @@ -13,21 +13,36 @@ -# url = '10.128.61.7:8001' -url = '127.0.0.1:8001' -model_name = 'bisenetv2' +url = '10.128.61.8:8001' +# url = '127.0.0.1:8001' +model_name = 'bisenetv1' model_version = '1' -inp_name = 'input_image' +inp_name = 'raw_img_bytes' outp_name = 'preds' -inp_dtype = 'FP32' +inp_dtype = 'UINT8' outp_dtype = np.int64 -inp_shape = [1, 3, 1024, 2048] -outp_shape = [1024, 2048] impth = '../example.png' mean = [0.3257, 0.3690, 0.3223] # city, rgb std = [0.2112, 0.2148, 0.2115] +## input data and mean/std +inp_data = np.fromfile(impth, dtype=np.uint8)[None, ...] +mean = np.array(mean, dtype=np.float32)[None, ...] +std = np.array(std, dtype=np.float32)[None, ...] +inputs = [service_pb2.ModelInferRequest().InferInputTensor() for _ in range(3)] +inputs[0].name = inp_name +inputs[0].datatype = inp_dtype +inputs[0].shape.extend(inp_data.shape) +inputs[1].name = 'channel_mean' +inputs[1].datatype = 'FP32' +inputs[1].shape.extend(mean.shape) +inputs[2].name = 'channel_std' +inputs[2].datatype = 'FP32' +inputs[2].shape.extend(std.shape) +inp_bytes = [inp_data.tobytes(), mean.tobytes(), std.tobytes()] + + option = [ ('grpc.max_receive_message_length', 1073741824), ('grpc.max_send_message_length', 1073741824), @@ -52,37 +67,22 @@ request.model_name = model_name request.model_version = model_version -inp = service_pb2.ModelInferRequest().InferInputTensor() -inp.name = inp_name -inp.datatype = inp_dtype -inp.shape.extend(inp_shape) - - -mean = np.array(mean).reshape(1, 1, 3) -std = np.array(std).reshape(1, 1, 3) -im = cv2.imread(impth)[:, :, ::-1] -im = cv2.resize(im, dsize=tuple(inp_shape[-1:-3:-1])) -im = ((im / 255.) - mean) / std -im = im[None, ...].transpose(0, 3, 1, 2) -inp_bytes = im.astype(np.float32).tobytes() - request.ClearField("inputs") request.ClearField("raw_input_contents") -request.inputs.extend([inp,]) -request.raw_input_contents.extend([inp_bytes,]) - +request.inputs.extend(inputs) +request.raw_input_contents.extend(inp_bytes) -outp = service_pb2.ModelInferRequest().InferRequestedOutputTensor() -outp.name = outp_name -request.outputs.extend([outp,]) # sync -# resp = grpc_stub.ModelInfer(request).raw_output_contents[0] +# resp = grpc_stub.ModelInfer(request) # async resp = grpc_stub.ModelInfer.future(request) -resp = resp.result().raw_output_contents[0] +resp = resp.result() + +outp_bytes = resp.raw_output_contents[0] +outp_shape = resp.outputs[0].shape -out = np.frombuffer(resp, dtype=outp_dtype).reshape(*outp_shape) +out = np.frombuffer(outp_bytes, dtype=outp_dtype).reshape(*outp_shape).squeeze() out = palette[out] cv2.imwrite('res.png', out) diff --git a/tis/client_http.py b/tis/client_http.py new file mode 100644 index 0000000..d1fa501 --- /dev/null +++ b/tis/client_http.py @@ -0,0 +1,64 @@ + + +import argparse +import sys +import numpy as np +import cv2 +import gevent.ssl + +import tritonclient.http as httpclient +from tritonclient.utils import InferenceServerException + + +np.random.seed(123) +palette = np.random.randint(0, 256, (100, 3)) + + +url = '10.128.61.8:8000' +# url = '127.0.0.1:8000' +model_name = 'bisenetv2' +model_version = '1' +inp_name = 'raw_img_bytes' +outp_name = 'preds' +inp_dtype = 'UINT8' +impth = '../example.png' +mean = [0.3257, 0.3690, 0.3223] # city, rgb +std = [0.2112, 0.2148, 0.2115] + + +## prepare image and mean/std +inp_data = np.fromfile(impth, dtype=np.uint8)[None, ...] +mean = np.array(mean, dtype=np.float32)[None, ...] +std = np.array(std, dtype=np.float32)[None, ...] +inputs = [] +inputs.append(httpclient.InferInput(inp_name, inp_data.shape, inp_dtype)) +inputs.append(httpclient.InferInput('channel_mean', mean.shape, 'FP32')) +inputs.append(httpclient.InferInput('channel_std', std.shape, 'FP32')) +inputs[0].set_data_from_numpy(inp_data, binary_data=True) +inputs[1].set_data_from_numpy(mean, binary_data=True) +inputs[2].set_data_from_numpy(std, binary_data=True) + + +## client +triton_client = httpclient.InferenceServerClient( + url=url, verbose=False, concurrency=32) + +## infer +# sync +# results = triton_client.infer(model_name, inputs) + +# async +results = triton_client.async_infer( + model_name, + inputs, + outputs=None, + query_params=None, + headers=None, + request_compression_algorithm=None, + response_compression_algorithm=None) +results = results.get_result() # async infer only + +# get output +outp = results.as_numpy(outp_name).squeeze() +out = palette[outp] +cv2.imwrite('res.png', out) diff --git a/tis/cpp_client/CMakeLists.txt b/tis/cpp_client/CMakeLists.txt index 92fbd6c..15eb56f 100644 --- a/tis/cpp_client/CMakeLists.txt +++ b/tis/cpp_client/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required (VERSION 3.18) project(Samples) -set(CMAKE_CXX_FLAGS "-std=c++14 -O1") +set(CMAKE_CXX_FLAGS "-std=c++14 -O2") set(CMAKE_BUILD_TYPE Release) set(CMAKE_PREFIX_PATH diff --git a/tis/cpp_client/main.cpp b/tis/cpp_client/main.cpp index 56f5ea3..39a326b 100644 --- a/tis/cpp_client/main.cpp +++ b/tis/cpp_client/main.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -27,28 +28,28 @@ namespace tc = triton::client; } -// std::string url("10.128.61.8:8001"); -std::string url("127.0.0.1:8001"); +std::string url("10.128.61.8:8001"); +// std::string url("127.0.0.1:8001"); std::string model_name("bisenetv1"); std::string model_version("1"); uint32_t client_timeout{5000000}; -bool verbose = false; std::string impth("../../../example.png"); std::string savepth("./res.jpg"); -std::vector inp_shape{1, 3, 1024, 2048}; -std::vector outp_shape{1, 1024, 2048}; -std::string inp_name("input_image"); +std::vector mean{0.3257, 0.3690, 0.3223}; // city, rgb +std::vector var{0.2112, 0.2148, 0.2115}; +std::string inp_name("raw_img_bytes"); std::string outp_name("preds"); -std::string inp_type("FP32"); +std::string inp_type("UINT8"); std::vector> get_color_map(); std::vector get_image(std::string, std::vector&); -void save_predict(std::string, int64_t*, - std::vector, std::vector); +std::vector get_image_bytes(std::string); +void save_predict(std::string, int64_t*, std::vector); void do_inference(); +void do_inference_with_bytes(std::vector&, bool); void print_infos(); void test_speed(); @@ -62,32 +63,55 @@ int main() { void do_inference() { + // create input + // std::vector inp_data = get_image(impth, inp_shape); + std::vector inp_data = get_image_bytes(impth); + std::cout << "read image: " << impth << std::endl; + do_inference_with_bytes(inp_data, true); +} + + +void do_inference_with_bytes(std::vector& inp_data, bool verbose) { // define client std::unique_ptr client; FAIL_IF_ERR( - tc::InferenceServerGrpcClient::Create(&client, url, verbose), + tc::InferenceServerGrpcClient::Create(&client, url, false), // verbose=false "unable to create grpc client"); - std::cout << "create client\n"; - - // create input - std::vector input_data = get_image(impth, inp_shape); - std::cout << "read image: " << impth << std::endl; + if (verbose) std::cout << "create client\n"; + //// raw image tc::InferInput* input; FAIL_IF_ERR( - tc::InferInput::Create(&input, inp_name, inp_shape, inp_type), - "unable to get input"); + tc::InferInput::Create(&input, inp_name, + {1, static_cast(inp_data.size())}, inp_type), + "unable to get input data"); std::shared_ptr input_ptr; input_ptr.reset(input); - FAIL_IF_ERR(input_ptr->Reset(), // reset input - "unable to reset input data"); + FAIL_IF_ERR(input_ptr->Reset(), "unable to reset input data"); + FAIL_IF_ERR(input_ptr->AppendRaw(inp_data), "unable to set data for input"); + //// mean/std + tc::InferInput *inp_mean, *inp_std; + FAIL_IF_ERR( + tc::InferInput::Create(&inp_mean, "channel_mean", {1, 3}, "FP32"), + "unable to get input mean"); + FAIL_IF_ERR( + tc::InferInput::Create(&inp_std, "channel_std", {1, 3}, "FP32"), + "unable to get input std"); + std::shared_ptr inp_mean_ptr, inp_std_ptr; + inp_mean_ptr.reset(inp_mean); + inp_std_ptr.reset(inp_std); + FAIL_IF_ERR(inp_mean_ptr->Reset(), "unable to reset input mean"); + FAIL_IF_ERR(inp_std_ptr->Reset(), "unable to reset input std"); + FAIL_IF_ERR( + inp_mean_ptr->AppendRaw(reinterpret_cast(&mean[0]), // must be uint8_t data type + mean.size() * sizeof(float)), + "unable to set data for input mean"); FAIL_IF_ERR( - input_ptr->AppendRaw( - reinterpret_cast(&input_data[0]), - input_data.size() * sizeof(float)), // NOTE: float can be others according to input type - "unable to set data for input"); - std::cout << "set input\n"; + inp_std_ptr->AppendRaw(reinterpret_cast(&var[0]), + var.size() * sizeof(float)), + "unable to set data for input std"); + if (verbose) std::cout << "set input\n"; // create output @@ -97,7 +121,7 @@ void do_inference() { "unable to get output"); std::shared_ptr output_ptr; output_ptr.reset(output); - std::cout << "set output\n"; + if (verbose) std::cout << "set output\n"; // infer options tc::InferOptions options(model_name); @@ -106,10 +130,11 @@ void do_inference() { tc::Headers http_headers; grpc_compression_algorithm compression_algorithm = grpc_compression_algorithm::GRPC_COMPRESS_NONE; - std::cout << "set options\n"; + if (verbose) std::cout << "set options\n"; // inference - std::vector inputs = {input_ptr.get()}; + std::vector inputs = {input_ptr.get(), + inp_mean_ptr.get(), inp_std_ptr.get()}; std::vector outputs = {output_ptr.get()}; tc::InferResult* results; FAIL_IF_ERR( @@ -122,24 +147,29 @@ void do_inference() { FAIL_IF_ERR( results_ptr->RequestStatus(), "inference failed"); - std::cout << "send request and do inference\n"; + if (verbose) std::cout << "send request and do inference\n"; // parse output - int64_t* raw_oup{nullptr}; // NOTE: int64_t is used according to model + int64_t* raw_outp{nullptr}; // NOTE: int64_t is used according to model size_t n_bytes{0}; FAIL_IF_ERR( results_ptr->RawData( - outp_name, (const uint8_t**)(&raw_oup), &n_bytes), + outp_name, (const uint8_t**)(&raw_outp), &n_bytes), "fetch output failed"); - if (n_bytes != outp_shape[1] * outp_shape[2] * sizeof(int64_t)) { + std::vector outp_shape; + FAIL_IF_ERR( + results_ptr->Shape(outp_name, &outp_shape), + "get output shape failed"); + if (n_bytes != std::accumulate(outp_shape.begin(), outp_shape.end(), 1, + std::multiplies()) * sizeof(int64_t)) { std::cerr << "output shape is not set correctly\n"; exit(1); } - std::cout << "fetch output\n"; + if (verbose) std::cout << "fetch output\n"; // save colorful result - save_predict(savepth, raw_oup, inp_shape, outp_shape); - std::cout << "save inference result to:" << savepth << std::endl; + save_predict(savepth, raw_outp, outp_shape); + if (verbose) std::cout << "save inference result to:" << savepth << std::endl; } @@ -176,6 +206,24 @@ std::vector get_image(std::string impth, std::vector& shape) { } +std::vector get_image_bytes(std::string impth) { + std::ifstream fin(impth, std::ios::in|std::ios::binary); + fin.seekg(0, fin.end); + int nbytes = fin.tellg(); + if (nbytes == -1) { + std::cerr << "image file read failed: " << impth << std::endl; + exit(1); + } + fin.clear(); + fin.seekg(0); + + std::vector res(nbytes); + fin.read(reinterpret_cast(&res[0]), nbytes); + fin.close(); + + return res; +} + std::vector> get_color_map() { std::vector> color_map(256, std::vector(3)); @@ -191,12 +239,10 @@ std::vector> get_color_map() { void save_predict(std::string savename, int64_t* data, - std::vector insize, std::vector outsize) { - std::vector> color_map = get_color_map(); - int64_t oH = outsize[1]; - int64_t oW = outsize[2]; + int64_t oH = outsize[2]; // outsize is n1hw + int64_t oW = outsize[3]; cv::Mat pred(cv::Size(oW, oH), CV_8UC3); int idx{0}; for (int i{0}; i < oH; ++i) { @@ -213,11 +259,12 @@ void save_predict(std::string savename, int64_t* data, } + void print_infos() { // define client std::unique_ptr client; FAIL_IF_ERR( - tc::InferenceServerGrpcClient::Create(&client, url, verbose), + tc::InferenceServerGrpcClient::Create(&client, url, false), "unable to create grpc client"); tc::Headers http_headers; @@ -258,73 +305,22 @@ void print_infos() { void test_speed() { - // define client - std::unique_ptr client; - FAIL_IF_ERR( - tc::InferenceServerGrpcClient::Create(&client, url, verbose), - "unable to create grpc client"); - - // create input - std::vector input_data(std::accumulate( - inp_shape.begin(), inp_shape.end(), - 1, std::multiplies()) - ); - tc::InferInput* input; - FAIL_IF_ERR( - tc::InferInput::Create(&input, inp_name, inp_shape, inp_type), - "unable to get input"); - std::shared_ptr input_ptr; - input_ptr.reset(input); - FAIL_IF_ERR(input_ptr->Reset(), // reset input - "unable to reset input data"); - FAIL_IF_ERR( - input_ptr->AppendRaw( - reinterpret_cast(&input_data[0]), - input_data.size() * sizeof(float)), // NOTE: float can be others according to input type - "unable to set data for input"); - - // create output - tc::InferRequestedOutput* output; - FAIL_IF_ERR( - tc::InferRequestedOutput::Create(&output, outp_name), - "unable to get output"); - std::shared_ptr output_ptr; - output_ptr.reset(output); - - // infer options - tc::InferOptions options(model_name); - options.model_version_ = model_version; - options.client_timeout_ = client_timeout; - tc::Headers http_headers; - grpc_compression_algorithm compression_algorithm = - grpc_compression_algorithm::GRPC_COMPRESS_NONE; - // inference - std::vector inputs = {input_ptr.get()}; - std::vector outputs = {output_ptr.get()}; - tc::InferResult* results; + std::vector inp_data = get_image_bytes(impth); + // warmup + do_inference_with_bytes(inp_data, false); std::cout << "test speed ... \n"; const int n_loops{500}; auto start = std::chrono::steady_clock::now(); for (int i{0}; i < n_loops; ++i) { - FAIL_IF_ERR( - client->Infer( - &results, options, inputs, outputs, http_headers, - compression_algorithm), - "failed sending synchronous infer request"); + do_inference_with_bytes(inp_data, false); } auto end = std::chrono::steady_clock::now(); - std::shared_ptr results_ptr; - results_ptr.reset(results); - FAIL_IF_ERR( - results_ptr->RequestStatus(), - "inference failed"); - double duration = std::chrono::duration(end - start).count(); duration /= 1000.; std::cout << "running " << n_loops << " times, use time: " - << duration << "s" << std::endl; + << duration << "s" << std::endl; std::cout << "fps is: " << static_cast(n_loops) / duration << std::endl; } diff --git a/tis/models/bisenetv1/config.pbtxt b/tis/models/bisenetv1/config.pbtxt index a1d94fb..c35e2ab 100644 --- a/tis/models/bisenetv1/config.pbtxt +++ b/tis/models/bisenetv1/config.pbtxt @@ -1,30 +1,64 @@ -platform: "onnxruntime_onnx" -max_batch_size: 0 +name: "bisenetv1" +platform: "ensemble" +max_batch_size: 256 input [ -{ - name: "input_image" - data_type: TYPE_FP32 - dims: [ 1, 3, 1024, 2048 ] -} + { + name: "raw_img_bytes" + data_type: TYPE_UINT8 + dims: [ -1 ] + }, + { + name: "channel_mean" + data_type: TYPE_FP32 + dims: [ 3 ] + }, + { + name: "channel_std" + data_type: TYPE_FP32 + dims: [ 3 ] + } ] output [ -{ - name: "preds" - data_type: TYPE_INT64 - dims: [ 1, 1024, 2048 ] -} + { + name: "preds" + data_type: TYPE_INT64 + dims: [1, 1024, 2048 ] + } ] -optimization { execution_accelerators { - gpu_execution_accelerator : [ { - name : "tensorrt" - parameters { key: "precision_mode" value: "FP16" } - parameters { key: "max_workspace_size_bytes" value: "4294967296" } - }] -}} -instance_group [ -{ - count: 2 - kind: KIND_GPU - gpus: [ 0, 1 ] + +ensemble_scheduling { + step [ + { + model_name: "preprocess_py" + model_version: 1 + input_map { + key: "raw_img_bytes" + value: "raw_img_bytes" + } + input_map { + key: "channel_mean" + value: "channel_mean" + } + input_map { + key: "channel_std" + value: "channel_std" + } + output_map { + key: "processed_img" + value: "processed_img" + } + }, + { + model_name: "bisenetv1_model" + model_version: 1 + input_map { + key: "input_image" + value: "processed_img" + } + output_map { + key: "preds" + value: "preds" + } + } + ] } -] diff --git a/tis/models/bisenetv1_model/config.pbtxt b/tis/models/bisenetv1_model/config.pbtxt new file mode 100644 index 0000000..63ea2d3 --- /dev/null +++ b/tis/models/bisenetv1_model/config.pbtxt @@ -0,0 +1,31 @@ +name: "bisenetv1_model" +platform: "onnxruntime_onnx" +max_batch_size: 0 +input [ +{ + name: "input_image" + data_type: TYPE_FP32 + dims: [ 1, 3, 1024, 2048 ] +} +] +output [ +{ + name: "preds" + data_type: TYPE_INT64 + dims: [ 1, 1024, 2048 ] +} +] +optimization { execution_accelerators { # we use tensorrt backend, pure onnxruntime seems to have memory leackage problem + gpu_execution_accelerator : [ { + name : "tensorrt" + parameters { key: "precision_mode" value: "FP16" } + parameters { key: "max_workspace_size_bytes" value: "4294967296" } + }] +}} +instance_group [ +{ + count: 2 + kind: KIND_GPU + gpus: [ 0, 1 ] +} +] diff --git a/tis/models/bisenetv2/config.pbtxt b/tis/models/bisenetv2/config.pbtxt index d5501d3..6db929b 100644 --- a/tis/models/bisenetv2/config.pbtxt +++ b/tis/models/bisenetv2/config.pbtxt @@ -1,30 +1,64 @@ -platform: "onnxruntime_onnx" -max_batch_size: 0 +name: "bisenetv2" +platform: "ensemble" +max_batch_size: 256 input [ -{ - name: "input_image" - data_type: TYPE_FP32 - dims: [1, 3, 1024, 2048 ] -} + { + name: "raw_img_bytes" + data_type: TYPE_UINT8 + dims: [ -1 ] + }, + { + name: "channel_mean" + data_type: TYPE_FP32 + dims: [ 3 ] + }, + { + name: "channel_std" + data_type: TYPE_FP32 + dims: [ 3 ] + } ] output [ -{ - name: "preds" - data_type: TYPE_INT64 - dims: [1, 1024, 2048 ] -} + { + name: "preds" + data_type: TYPE_INT64 + dims: [1, 1024, 2048 ] + } ] -optimization { execution_accelerators { - gpu_execution_accelerator : [ { - name : "tensorrt" - parameters { key: "precision_mode" value: "FP16" } - parameters { key: "max_workspace_size_bytes" value: "4294967296" } - }] -}} -instance_group [ -{ - count: 2 - kind: KIND_GPU - gpus: [ 0, 1 ] + +ensemble_scheduling { + step [ + { + model_name: "preprocess_cpp" + model_version: 1 + input_map { + key: "raw_img_bytes" + value: "raw_img_bytes" + } + input_map { + key: "channel_mean" + value: "channel_mean" + } + input_map { + key: "channel_std" + value: "channel_std" + } + output_map { + key: "processed_img" + value: "processed_img" + } + }, + { + model_name: "bisenetv2_model" + model_version: 1 + input_map { + key: "input_image" + value: "processed_img" + } + output_map { + key: "preds" + value: "preds" + } + } + ] } -] diff --git a/tis/models/bisenetv2_model/config.pbtxt b/tis/models/bisenetv2_model/config.pbtxt new file mode 100644 index 0000000..a9124a9 --- /dev/null +++ b/tis/models/bisenetv2_model/config.pbtxt @@ -0,0 +1,31 @@ +name: "bisenetv2_model" +platform: "onnxruntime_onnx" +max_batch_size: 0 +input [ +{ + name: "input_image" + data_type: TYPE_FP32 + dims: [1, 3, 1024, 2048 ] +} +] +output [ +{ + name: "preds" + data_type: TYPE_INT64 + dims: [1, 1024, 2048 ] +} +] +optimization { execution_accelerators { # we use tensorrt backend, pure onnxruntime seems to have memory leackage problem + gpu_execution_accelerator : [ { + name : "tensorrt" + parameters { key: "precision_mode" value: "FP16" } + parameters { key: "max_workspace_size_bytes" value: "4294967296" } + }] +}} +instance_group [ +{ + count: 2 + kind: KIND_GPU + gpus: [ 0, 1 ] +} +] diff --git a/tis/models/preprocess_cpp/1/.gitkeep b/tis/models/preprocess_cpp/1/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/tis/models/preprocess_cpp/config.pbtxt b/tis/models/preprocess_cpp/config.pbtxt new file mode 100644 index 0000000..6dee44a --- /dev/null +++ b/tis/models/preprocess_cpp/config.pbtxt @@ -0,0 +1,35 @@ +name: "preprocess_cpp" +backend: "self_backend" +max_batch_size: 256 +# dynamic_batching { ## this is not allowed, since we cannot know raw bytes size of each inputs from the server, as they just concat the bytes together +# max_queue_delay_microseconds: 5000000 +# } +input [ +{ + name: "raw_img_bytes" + data_type: TYPE_UINT8 + dims: [ -1 ] +}, +{ + name: "channel_mean" + data_type: TYPE_FP32 + dims: [ 3 ] +}, +{ + name: "channel_std" + data_type: TYPE_FP32 + dims: [ 3 ] +} +] +output [ + { + name: "processed_img" + data_type: TYPE_FP32 + dims: [ 1, 3, 1024, 2048 ] + } +] +instance_group [ + { + kind: KIND_CPU + } +] diff --git a/tis/models/preprocess_py/1/model.py b/tis/models/preprocess_py/1/model.py new file mode 100644 index 0000000..2af1128 --- /dev/null +++ b/tis/models/preprocess_py/1/model.py @@ -0,0 +1,122 @@ +import numpy as np +import sys +import json +import io + +# triton_python_backend_utils is available in every Triton Python model. You +# need to use this module to create inference requests and responses. It also +# contains some utility functions for extracting information from model_config +# and converting Triton input/output types to numpy types. +import triton_python_backend_utils as pb_utils + +from PIL import Image +import os + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to intialize any state associated with this model. + + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + + # You must parse model_config. JSON string is not parsed here + self.model_config = model_config = json.loads(args['model_config']) + + # Get OUTPUT0 configuration + output0_config = pb_utils.get_output_config_by_name( + model_config, "processed_img") + + # Convert Triton types to numpy types + self.output0_dtype = pb_utils.triton_string_to_numpy( + output0_config['data_type']) + + self.output0_shape = output0_config['dims'] + + def execute(self, requests): + """`execute` MUST be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference request is made + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse + + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + output0_dtype = self.output0_dtype + N, C, H, W = self.output0_shape + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + for request in requests: + # Get INPUT0 + im_bytes = pb_utils.get_input_tensor_by_name(request, "raw_img_bytes") + im_bytes = im_bytes.as_numpy().tobytes() + im = Image.open(io.BytesIO(im_bytes)) + im = im.resize((W, H), Image.ANTIALIAS) + im = np.array(im) + + # Get mean/std + mean = pb_utils.get_input_tensor_by_name(request, "channel_mean") + std = pb_utils.get_input_tensor_by_name(request, "channel_std") + mean = mean.as_numpy().reshape(1, 1, 3) + std = std.as_numpy().reshape(1, 1, 3) + + # preprocess + im = ((im / 255.) - mean) / std + im = im[None, ...].transpose(0, 3, 1, 2).astype(np.float32) + + + out_tensor_0 = pb_utils.Tensor("processed_img", im) + + # Create InferenceResponse. You can set an error here in case + # there was a problem with handling this inference request. + # Below is an example of how you can set errors in inference + # response: + # + # pb_utils.InferenceResponse( + # output_tensors=..., TritonError("An error occured")) + inference_response = pb_utils.InferenceResponse( + output_tensors=[out_tensor_0]) + responses.append(inference_response) + + # You should return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is OPTIONAL. This function allows + the model to perform any necessary clean ups before exit. + """ + print('Cleaning up...') + diff --git a/tis/models/preprocess_py/config.pbtxt b/tis/models/preprocess_py/config.pbtxt new file mode 100644 index 0000000..0ebdfd8 --- /dev/null +++ b/tis/models/preprocess_py/config.pbtxt @@ -0,0 +1,30 @@ +name: "preprocess_py" +backend: "python" +max_batch_size: 256 +input [ +{ + name: "raw_img_bytes" + data_type: TYPE_UINT8 + dims: [ -1 ] +}, +{ + name: "channel_mean" + data_type: TYPE_FP32 + dims: [ 3 ] +}, +{ + name: "channel_std" + data_type: TYPE_FP32 + dims: [ 3 ] +} +] + +output [ +{ + name: "processed_img" + data_type: TYPE_FP32 + dims: [1, 3, 1024, 2048 ] +} +] + +instance_group [{ kind: KIND_CPU }] diff --git a/tis/self_backend/CMakeLists.txt b/tis/self_backend/CMakeLists.txt new file mode 100644 index 0000000..051dc9d --- /dev/null +++ b/tis/self_backend/CMakeLists.txt @@ -0,0 +1,185 @@ +# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cmake_minimum_required(VERSION 3.17) + +project(tutorialrecommendedbackend LANGUAGES C CXX) + +# +# Options +# +# Must include options required for this project as well as any +# projects included in this one by FetchContent. +# +# GPU support is disabled by default because recommended backend +# doesn't use GPUs. +# +option(TRITON_ENABLE_GPU "Enable GPU support in backend" OFF) +option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON) + +set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo") +set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo") +set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo") + +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +# +# Dependencies +# +# FetchContent requires us to include the transitive closure of all +# repos that we depend on so that we can override the tags. +# +include(FetchContent) + +FetchContent_Declare( + repo-common + GIT_REPOSITORY https://github.com/triton-inference-server/common.git + GIT_TAG ${TRITON_COMMON_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_Declare( + repo-core + GIT_REPOSITORY https://github.com/triton-inference-server/core.git + GIT_TAG ${TRITON_CORE_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_Declare( + repo-backend + GIT_REPOSITORY https://github.com/triton-inference-server/backend.git + GIT_TAG ${TRITON_BACKEND_REPO_TAG} + GIT_SHALLOW ON +) +FetchContent_MakeAvailable(repo-common repo-core repo-backend) + +find_package (OpenCV REQUIRED) + +# +# The backend must be built into a shared library. Use an ldscript to +# hide all symbols except for the TRITONBACKEND API. +# +configure_file(src/libtriton_recommended.ldscript libtriton_recommended.ldscript COPYONLY) + +add_library( + triton-self_backend-backend SHARED + src/recommended.cc +) + +add_library( + TutorialRecommendedBackend::triton-self_backend-backend ALIAS triton-self_backend-backend +) + +target_include_directories( + triton-self_backend-backend + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${OpenCV_INCLUDE_DIRS} +) + +target_compile_features(triton-self_backend-backend PRIVATE cxx_std_14) +target_compile_options( + triton-self_backend-backend PRIVATE + $<$,$,$>: + -Wall -Wextra -Wno-unused-parameter -Wno-type-limits -Werror> + $<$:/Wall /D_WIN32_WINNT=0x0A00 /EHsc> +) + +target_link_libraries( + triton-self_backend-backend + PRIVATE + triton-core-serverapi # from repo-core + triton-core-backendapi # from repo-core + triton-core-serverstub # from repo-core + triton-backend-utils # from repo-backend + ${OpenCV_LIBS} +) + +if(WIN32) + set_target_properties( + triton-self_backend-backend PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME triton_self_backend + ) +else() + set_target_properties( + triton-self_backend-backend PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME triton_self_backend + LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_recommended.ldscript + LINK_FLAGS "-Wl,--version-script libtriton_recommended.ldscript" + ) +endif() + +# +# Install +# +include(GNUInstallDirs) +set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/TutorialRecommendedBackend) + +install( + TARGETS + triton-self_backend-backend + EXPORT + triton-self_backend-backend-targets + LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/self_backend + RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/self_backend +) + +install( + EXPORT + triton-self_backend-backend-targets + FILE + TutorialRecommendedBackendTargets.cmake + NAMESPACE + TutorialRecommendedBackend:: + DESTINATION + ${INSTALL_CONFIGDIR} +) + +include(CMakePackageConfigHelpers) +configure_package_config_file( + ${CMAKE_CURRENT_LIST_DIR}/cmake/TutorialRecommendedBackendConfig.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/TutorialRecommendedBackendConfig.cmake + INSTALL_DESTINATION ${INSTALL_CONFIGDIR} +) + +install( + FILES + ${CMAKE_CURRENT_BINARY_DIR}/TutorialRecommendedBackendConfig.cmake + DESTINATION ${INSTALL_CONFIGDIR} +) + +# +# Export from build tree +# +export( + EXPORT triton-self_backend-backend-targets + FILE ${CMAKE_CURRENT_BINARY_DIR}/TutorialRecommendedBackendTargets.cmake + NAMESPACE TutorialRecommendedBackend:: +) + +export(PACKAGE TutorialRecommendedBackend) diff --git a/tis/self_backend/cmake/TutorialRecommendedBackendConfig.cmake.in b/tis/self_backend/cmake/TutorialRecommendedBackendConfig.cmake.in new file mode 100644 index 0000000..4007f9f --- /dev/null +++ b/tis/self_backend/cmake/TutorialRecommendedBackendConfig.cmake.in @@ -0,0 +1,39 @@ +# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +include(CMakeFindDependencyMacro) + +get_filename_component( + TUTORIALRECOMMENDEDBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH +) + +list(APPEND CMAKE_MODULE_PATH ${TUTORIALRECOMMENDEDBACKEND_CMAKE_DIR}) + +if(NOT TARGET TutorialRecommendedBackend::triton-recommended-backend) + include("${TUTORIALRECOMMENDEDBACKEND_CMAKE_DIR}/TutorialRecommendedBackendTargets.cmake") +endif() + +set(TUTORIALRECOMMENDEDBACKEND_LIBRARIES TutorialRecommendedBackend::triton-recommended-backend) diff --git a/tis/self_backend/src/libtriton_recommended.ldscript b/tis/self_backend/src/libtriton_recommended.ldscript new file mode 100644 index 0000000..748714d --- /dev/null +++ b/tis/self_backend/src/libtriton_recommended.ldscript @@ -0,0 +1,30 @@ +# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +{ + global: + TRITONBACKEND_*; + local: *; +}; diff --git a/tis/self_backend/src/recommended.cc b/tis/self_backend/src/recommended.cc new file mode 100644 index 0000000..bb757ce --- /dev/null +++ b/tis/self_backend/src/recommended.cc @@ -0,0 +1,916 @@ +// Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "triton/backend/backend_common.h" +#include "triton/backend/backend_input_collector.h" +#include "triton/backend/backend_model.h" +#include "triton/backend/backend_model_instance.h" +#include "triton/backend/backend_output_responder.h" +#include "triton/core/tritonbackend.h" +#include +#include +#include +#include + +using std::cout; +using std::endl; + + +///////////// +// +// Coin: ModelEndPoint +// +// Wraps model end point information + +struct InModelEndPoint { + // Seems tis will take care of input buffer memory, so we just + // use the pointer without caring about memory release + const char *buffer; + size_t buffer_byte_size; + TRITONSERVER_MemoryType buffer_memory_type; + int64_t buffer_memory_type_id; +}; + +struct OutModelEndPoint { + // It is up to user to allocate output buffer memory, therefore + // we should take care of memory release with shared_ptr + std::shared_ptr buffer; + size_t buffer_byte_size; + TRITONSERVER_MemoryType buffer_memory_type; + int64_t buffer_memory_type_id; + std::vector shape; +}; + + +//// call back function +// Input is raw bytes, which will be decoded into image tensors +// It is not allowed to batch raw bytes inputs together to decode and preprocess +// Therefore we should not use the field of 'dynamic_batching' in the config.pbtxt +void callback_func(std::vector& inputs, + std::vector& outputs) { + + // decode image and resize + cv::Mat buffer = cv::Mat(cv::Size(1, inputs[0].buffer_byte_size), + CV_8UC1, const_cast(inputs[0].buffer)); + cv::Mat im = cv::imdecode(buffer, cv::IMREAD_COLOR); + cout << "image size: " << im.size() << endl; + int ndims = outputs[0].shape.size(); + int H = outputs[0].shape[ndims - 2]; // last two dims as hw + int W = outputs[0].shape[ndims - 1]; + cv::Mat imresized = im; + if (im.rows != H or im.cols != W) { + cv::resize(im, imresized, cv::Size(W, H), 0., 0., cv::INTER_CUBIC); + cout << "resize image into: " << imresized.size() << endl; + } + + // obtain mean/std + std::vector channel_mean(3), channel_std(3); + for (int i{0}; i < 3; ++i) { + channel_mean[i] = reinterpret_cast( + const_cast(inputs[1].buffer))[i]; + channel_std[i] = 1. / reinterpret_cast( + const_cast(inputs[2].buffer))[i]; + } + float scale = 1. / 255; + + // allocate output buffer + outputs[0].buffer = std::shared_ptr( + new char[outputs[0].buffer_byte_size], + [](const char* p) {cout<<"release output memory\n"; delete[] p;}); + float* obuf = reinterpret_cast(const_cast(outputs[0].buffer.get())); + + // divide 255 and then normalize with channel mean/std + for (int h{0}; h < H; ++h) { + cv::Vec3b *ptr = imresized.ptr(h); + for (int w{0}; w < W; ++w) { + for (int c{0}; c < 3; ++c) { + int ind = c * H * W + h * W + w; + obuf[ind] = (ptr[w][2 - c] * scale - channel_mean[c]) * channel_std[c]; + } + } + } +} + + +namespace triton { namespace backend { namespace recommended { + +// +// Backend that demonstrates the TRITONBACKEND API. This backend works +// for any model that has 1 input with any datatype and any shape and +// 1 output with the same shape and datatype as the input. The backend +// supports both batching and non-batching models. +// +// For each batch of requests, the backend returns the input tensor +// value in the output tensor. +// + +///////////// + +extern "C" { + +// Triton calls TRITONBACKEND_Initialize when a backend is loaded into +// Triton to allow the backend to create and initialize any state that +// is intended to be shared across all models and model instances that +// use the backend. The backend should also verify version +// compatibility with Triton in this function. +// +TRITONSERVER_Error* +TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) +{ + const char* cname; + RETURN_IF_ERROR(TRITONBACKEND_BackendName(backend, &cname)); + std::string name(cname); + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("TRITONBACKEND_Initialize: ") + name).c_str()); + + // Check the backend API version that Triton supports vs. what this + // backend was compiled against. Make sure that the Triton major + // version is the same and the minor version is >= what this backend + // uses. + uint32_t api_version_major, api_version_minor; + RETURN_IF_ERROR( + TRITONBACKEND_ApiVersion(&api_version_major, &api_version_minor)); + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("Triton TRITONBACKEND API version: ") + + std::to_string(api_version_major) + "." + + std::to_string(api_version_minor)) + .c_str()); + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("'") + name + "' TRITONBACKEND API version: " + + std::to_string(TRITONBACKEND_API_VERSION_MAJOR) + "." + + std::to_string(TRITONBACKEND_API_VERSION_MINOR)) + .c_str()); + + if ((api_version_major != TRITONBACKEND_API_VERSION_MAJOR) || + (api_version_minor < TRITONBACKEND_API_VERSION_MINOR)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "triton backend API version does not support this backend"); + } + + // The backend configuration may contain information needed by the + // backend, such as tritonserver command-line arguments. This + // backend doesn't use any such configuration but for this example + // print whatever is available. + TRITONSERVER_Message* backend_config_message; + RETURN_IF_ERROR( + TRITONBACKEND_BackendConfig(backend, &backend_config_message)); + + const char* buffer; + size_t byte_size; + RETURN_IF_ERROR(TRITONSERVER_MessageSerializeToJson( + backend_config_message, &buffer, &byte_size)); + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("backend configuration:\n") + buffer).c_str()); + + // This backend does not require any "global" state but as an + // example create a string to demonstrate. + std::string* state = new std::string("backend state"); + RETURN_IF_ERROR( + TRITONBACKEND_BackendSetState(backend, reinterpret_cast(state))); + + return nullptr; // success +} + +// Triton calls TRITONBACKEND_Finalize when a backend is no longer +// needed. +// +TRITONSERVER_Error* +TRITONBACKEND_Finalize(TRITONBACKEND_Backend* backend) +{ + // Delete the "global" state associated with the backend. + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_BackendState(backend, &vstate)); + std::string* state = reinterpret_cast(vstate); + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("TRITONBACKEND_Finalize: state is '") + *state + "'") + .c_str()); + + delete state; + + return nullptr; // success +} + +} // extern "C" + +///////////// + +// +// ModelState +// +// State associated with a model that is using this backend. An object +// of this class is created and associated with each +// TRITONBACKEND_Model. ModelState is derived from BackendModel class +// provided in the backend utilities that provides many common +// functions. +// +class ModelState : public BackendModel { + public: + static TRITONSERVER_Error* Create( + TRITONBACKEND_Model* triton_model, ModelState** state); + virtual ~ModelState() = default; + + // Name of the input and output tensor + const std::string& InputTensorName(int ind) const { return input_names_[ind]; } + const std::string& OutputTensorName(int ind) const { return output_names_[ind]; } + + // Datatype of the input and output tensor + /// TRITONSERVER_DataType TensorDataType() const { return datatype_; } + TRITONSERVER_DataType InputTensorDataType(int ind) const { return inp_dtypes_[ind]; } + TRITONSERVER_DataType OutputTensorDataType(int ind) const { return outp_dtypes_[ind]; } + + // Shape of the input and output tensor as given in the model + // configuration file. This shape will not include the batch + // dimension (if the model has one). + const std::vector& InputTensorNonBatchShape(int ind) const { return inp_nb_shapes_[ind]; } + const std::vector& OutputTensorNonBatchShape(int ind) const { return outp_nb_shapes_[ind]; } + + // Shape of the input and output tensor, including the batch + // dimension (if the model has one). This method cannot be called + // until the model is completely loaded and initialized, including + // all instances of the model. In practice, this means that backend + // should only call it in TRITONBACKEND_ModelInstanceExecute. + TRITONSERVER_Error* InputTensorShape(std::vector& shape, int ind); + TRITONSERVER_Error* OutputTensorShape(std::vector& shape, int ind); + + // Number of inputs and outputs + int NumOfInputs() { return n_inps; } + int NumOfOutputs() { return n_outps; } + + // Validate that this model is supported by this backend. + TRITONSERVER_Error* ValidateModelConfig(); + + private: + ModelState(TRITONBACKEND_Model* triton_model); + + std::vector input_names_; + std::vector output_names_; + + std::vector inp_dtypes_; + std::vector outp_dtypes_; + + std::vector inp_shape_initialized_; + std::vector outp_shape_initialized_; + std::vector> inp_nb_shapes_; // no-batch shape + std::vector> inp_shapes_; + std::vector> outp_nb_shapes_; // no-batch shape + std::vector> outp_shapes_; + + int n_inps; + int n_outps; +}; + +ModelState::ModelState(TRITONBACKEND_Model* triton_model) + : BackendModel(triton_model) +{ + // Validate that the model's configuration matches what is supported + // by this backend. + THROW_IF_BACKEND_MODEL_ERROR(ValidateModelConfig()); +} + +TRITONSERVER_Error* +ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state) +{ + try { + *state = new ModelState(triton_model); + } + catch (const BackendModelException& ex) { + RETURN_ERROR_IF_TRUE( + ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelException")); + RETURN_IF_ERROR(ex.err_); + } + + return nullptr; // success +} + + +TRITONSERVER_Error* +ModelState::InputTensorShape(std::vector& shape, int ind) +{ + // This backend supports models that batch along the first dimension + // and those that don't batch. For non-batch models the output shape + // will be the shape from the model configuration. For batch models + // the output shape will be the shape from the model configuration + // prepended with [ -1 ] to represent the batch dimension. The + // backend "responder" utility used below will set the appropriate + // batch dimension value for each response. The shape needs to be + // initialized lazily because the SupportsFirstDimBatching function + // cannot be used until the model is completely loaded. + if (!inp_shape_initialized_[ind]) { + bool supports_first_dim_batching; + RETURN_IF_ERROR(SupportsFirstDimBatching(&supports_first_dim_batching)); + if (supports_first_dim_batching) { + inp_shapes_[ind].push_back(-1); + } + + inp_shapes_[ind].insert( + inp_shapes_[ind].end(), + inp_nb_shapes_[ind].begin(), + inp_nb_shapes_[ind].end()); + inp_shape_initialized_[ind] = true; + } + + shape = inp_shapes_[ind]; + + return nullptr; // success +} + + +TRITONSERVER_Error* +ModelState::OutputTensorShape(std::vector& shape, int ind) +{ + // This backend supports models that batch along the first dimension + // and those that don't batch. For non-batch models the output shape + // will be the shape from the model configuration. For batch models + // the output shape will be the shape from the model configuration + // prepended with [ -1 ] to represent the batch dimension. The + // backend "responder" utility used below will set the appropriate + // batch dimension value for each response. The shape needs to be + // initialized lazily because the SupportsFirstDimBatching function + // cannot be used until the model is completely loaded. + if (!outp_shape_initialized_[ind]) { + bool supports_first_dim_batching; + RETURN_IF_ERROR(SupportsFirstDimBatching(&supports_first_dim_batching)); + if (supports_first_dim_batching) { + outp_shapes_[ind].push_back(-1); + } + + outp_shapes_[ind].insert( + outp_shapes_[ind].end(), + outp_nb_shapes_[ind].begin(), + outp_nb_shapes_[ind].end()); + outp_shape_initialized_[ind] = true; + } + + shape = outp_shapes_[ind]; + + return nullptr; // success +} + + +TRITONSERVER_Error* +ModelState::ValidateModelConfig() +{ + // If verbose logging is enabled, dump the model's configuration as + // JSON into the console output. + if (TRITONSERVER_LogIsEnabled(TRITONSERVER_LOG_VERBOSE)) { + common::TritonJson::WriteBuffer buffer; + RETURN_IF_ERROR(ModelConfig().PrettyWrite(&buffer)); + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("model configuration:\n") + buffer.Contents()).c_str()); + } + + // ModelConfig is the model configuration as a TritonJson + // object. Use the TritonJson utilities to parse the JSON and + // determine if the configuration is supported by this backend. + common::TritonJson::Value inputs, outputs; + RETURN_IF_ERROR(ModelConfig().MemberAsArray("input", &inputs)); + RETURN_IF_ERROR(ModelConfig().MemberAsArray("output", &outputs)); + + + //// input name, data_type and shape from config.pbtxt + n_inps = inputs.ArraySize(); // num_inputs + inp_shapes_.resize(n_inps); + for (int i{0}; i < n_inps; ++i) { + + common::TritonJson::Value input; + RETURN_IF_ERROR(inputs.IndexAsObject(i, &input)); + + // name + const char* input_name; + size_t input_name_len; + RETURN_IF_ERROR(input.MemberAsString("name", &input_name, &input_name_len)); + input_names_.push_back(std::string(input_name)); + + // data type + std::string input_dtype; + RETURN_IF_ERROR(input.MemberAsString("data_type", &input_dtype)); + inp_dtypes_.push_back(ModelConfigDataTypeToTritonServerDataType(input_dtype)); + + // dims and shape + std::vector input_shape; + RETURN_IF_ERROR(backend::ParseShape(input, "dims", &input_shape)); + inp_nb_shapes_.push_back(input_shape); + inp_shape_initialized_.push_back(false); + + } + + //// out name, data_type and shape from config.pbtxt + n_outps = outputs.ArraySize(); // num_outputs + outp_shapes_.resize(n_outps); + for (int i{0}; i < n_outps; ++i) { + + common::TritonJson::Value output; + RETURN_IF_ERROR(outputs.IndexAsObject(i, &output)); + + // name + const char* output_name; + size_t output_name_len; + RETURN_IF_ERROR( + output.MemberAsString("name", &output_name, &output_name_len)); + output_names_.push_back(std::string(output_name)); + + // data type + std::string output_dtype; + RETURN_IF_ERROR(output.MemberAsString("data_type", &output_dtype)); + outp_dtypes_.push_back(ModelConfigDataTypeToTritonServerDataType(output_dtype)); + + // dims and shape + std::vector output_shape; + RETURN_IF_ERROR(backend::ParseShape(output, "dims", &output_shape)); + outp_nb_shapes_.push_back(output_shape); + outp_shape_initialized_.push_back(false); + } + + return nullptr; // success +} + +extern "C" { + +// Triton calls TRITONBACKEND_ModelInitialize when a model is loaded +// to allow the backend to create any state associated with the model, +// and to also examine the model configuration to determine if the +// configuration is suitable for the backend. Any errors reported by +// this function will prevent the model from loading. +// +TRITONSERVER_Error* +TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model) +{ + // Create a ModelState object and associate it with the + // TRITONBACKEND_Model. If anything goes wrong with initialization + // of the model state then an error is returned and Triton will fail + // to load the model. + ModelState* model_state; + RETURN_IF_ERROR(ModelState::Create(model, &model_state)); + RETURN_IF_ERROR( + TRITONBACKEND_ModelSetState(model, reinterpret_cast(model_state))); + + return nullptr; // success +} + +// Triton calls TRITONBACKEND_ModelFinalize when a model is no longer +// needed. The backend should cleanup any state associated with the +// model. This function will not be called until all model instances +// of the model have been finalized. +// +TRITONSERVER_Error* +TRITONBACKEND_ModelFinalize(TRITONBACKEND_Model* model) +{ + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vstate)); + ModelState* model_state = reinterpret_cast(vstate); + delete model_state; + + return nullptr; // success +} + +} // extern "C" + + +///////////// + +// +// ModelInstanceState +// +// State associated with a model instance. An object of this class is +// created and associated with each +// TRITONBACKEND_ModelInstance. ModelInstanceState is derived from +// BackendModelInstance class provided in the backend utilities that +// provides many common functions. +// +class ModelInstanceState : public BackendModelInstance { + public: + static TRITONSERVER_Error* Create( + ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state); + virtual ~ModelInstanceState() = default; + + // Get the state of the model that corresponds to this instance. + ModelState* StateForModel() const { return model_state_; } + + private: + ModelInstanceState( + ModelState* model_state, + TRITONBACKEND_ModelInstance* triton_model_instance) + : BackendModelInstance(model_state, triton_model_instance), + model_state_(model_state) + { + } + + ModelState* model_state_; +}; + +TRITONSERVER_Error* +ModelInstanceState::Create( + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, + ModelInstanceState** state) +{ + try { + *state = new ModelInstanceState(model_state, triton_model_instance); + } + catch (const BackendModelInstanceException& ex) { + RETURN_ERROR_IF_TRUE( + ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr in BackendModelInstanceException")); + RETURN_IF_ERROR(ex.err_); + } + + return nullptr; // success +} + +extern "C" { + +// Triton calls TRITONBACKEND_ModelInstanceInitialize when a model +// instance is created to allow the backend to initialize any state +// associated with the instance. +// +TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance) +{ + // Get the model state associated with this instance's model. + TRITONBACKEND_Model* model; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceModel(instance, &model)); + + void* vmodelstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelState(model, &vmodelstate)); + ModelState* model_state = reinterpret_cast(vmodelstate); + + // Create a ModelInstanceState object and associate it with the + // TRITONBACKEND_ModelInstance. + ModelInstanceState* instance_state; + RETURN_IF_ERROR( + ModelInstanceState::Create(model_state, instance, &instance_state)); + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState( + instance, reinterpret_cast(instance_state))); + + return nullptr; // success +} + +// Triton calls TRITONBACKEND_ModelInstanceFinalize when a model +// instance is no longer needed. The backend should cleanup any state +// associated with the model instance. +// +TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance* instance) +{ + void* vstate; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState(instance, &vstate)); + ModelInstanceState* instance_state = + reinterpret_cast(vstate); + delete instance_state; + + return nullptr; // success +} + +} // extern "C" + +///////////// + +extern "C" { + +// When Triton calls TRITONBACKEND_ModelInstanceExecute it is required +// that a backend create a response for each request in the batch. A +// response may be the output tensors required for that request or may +// be an error that is returned in the response. +// +TRITONSERVER_Error* +TRITONBACKEND_ModelInstanceExecute( + TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests, + const uint32_t request_count) +{ + // Collect various timestamps during the execution of this batch or + // requests. These values are reported below before returning from + // the function. + + uint64_t exec_start_ns = 0; + SET_TIMESTAMP(exec_start_ns); + + // Triton will not call this function simultaneously for the same + // 'instance'. But since this backend could be used by multiple + // instances from multiple models the implementation needs to handle + // multiple calls to this function at the same time (with different + // 'instance' objects). Best practice for a high-performance + // implementation is to avoid introducing mutex/lock and instead use + // only function-local and model-instance-specific state. + ModelInstanceState* instance_state; + RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceState( + instance, reinterpret_cast(&instance_state))); + ModelState* model_state = instance_state->StateForModel(); + + // 'responses' is initialized as a parallel array to 'requests', + // with one TRITONBACKEND_Response object for each + // TRITONBACKEND_Request object. If something goes wrong while + // creating these response objects, the backend simply returns an + // error from TRITONBACKEND_ModelInstanceExecute, indicating to + // Triton that this backend did not create or send any responses and + // so it is up to Triton to create and send an appropriate error + // response for each request. RETURN_IF_ERROR is one of several + // useful macros for error handling that can be found in + // backend_common.h. + + std::vector responses; + responses.reserve(request_count); + for (uint32_t r = 0; r < request_count; ++r) { + TRITONBACKEND_Request* request = requests[r]; + TRITONBACKEND_Response* response; + RETURN_IF_ERROR(TRITONBACKEND_ResponseNew(&response, request)); + responses.push_back(response); + } + + // At this point, the backend takes ownership of 'requests', which + // means that it is responsible for sending a response for every + // request. From here, even if something goes wrong in processing, + // the backend must return 'nullptr' from this function to indicate + // success. Any errors and failures must be communicated via the + // response objects. + // + // To simplify error handling, the backend utilities manage + // 'responses' in a specific way and it is recommended that backends + // follow this same pattern. When an error is detected in the + // processing of a request, an appropriate error response is sent + // and the corresponding TRITONBACKEND_Response object within + // 'responses' is set to nullptr to indicate that the + // request/response has already been handled and no futher processing + // should be performed for that request. Even if all responses fail, + // the backend still allows execution to flow to the end of the + // function so that statistics are correctly reported by the calls + // to TRITONBACKEND_ModelInstanceReportStatistics and + // TRITONBACKEND_ModelInstanceReportBatchStatistics. + // RESPOND_AND_SET_NULL_IF_ERROR, and + // RESPOND_ALL_AND_SET_NULL_IF_ERROR are macros from + // backend_common.h that assist in this management of response + // objects. + + // The backend could iterate over the 'requests' and process each + // one separately. But for performance reasons it is usually + // preferred to create batched input tensors that are processed + // simultaneously. This is especially true for devices like GPUs + // that are capable of exploiting the large amount parallelism + // exposed by larger data sets. + // + // The backend utilities provide a "collector" to facilitate this + // batching process. The 'collector's ProcessTensor function will + // combine a tensor's value from each request in the batch into a + // single contiguous buffer. The buffer can be provided by the + // backend or 'collector' can create and manage it. In this backend, + // there is not a specific buffer into which the batch should be + // created, so use ProcessTensor arguments that cause collector to + // manage it. + + BackendInputCollector collector( + requests, request_count, &responses, model_state->TritonMemoryManager(), + false /* pinned_enabled */, nullptr /* stream*/); + + // To instruct ProcessTensor to "gather" the entire batch of input + // tensors into a single contiguous buffer in CPU memory, set the + // "allowed input types" to be the CPU ones (see tritonserver.h in + // the triton-inference-server/core repo for allowed memory types). + std::vector> allowed_input_types = + {{TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}}; + + /// Coin: Here we get input buffers and meta info, such that we can process + /// them and implement the logic of the backend + int n_inps = model_state->NumOfInputs(); + std::vector inputs(n_inps); + for (int i{0}; i < n_inps; ++i) { + RESPOND_ALL_AND_SET_NULL_IF_ERROR( + responses, request_count, + collector.ProcessTensor( + model_state->InputTensorName(i).c_str(), nullptr /* existing_buffer */, + 0 /* existing_buffer_byte_size */, allowed_input_types, + &inputs[i].buffer, + &inputs[i].buffer_byte_size, + &inputs[i].buffer_memory_type, + &inputs[i].buffer_memory_type_id)); + } + + int n_outps = model_state->NumOfOutputs(); + std::vector outputs(n_outps); + for (int i{0}; i < n_outps; ++i) { + outputs[i].buffer_memory_type = TRITONSERVER_MEMORY_CPU; + outputs[i].buffer_memory_type_id = 0; + model_state->OutputTensorShape(outputs[i].shape, i); + int64_t byte_size = TRITONSERVER_DataTypeByteSize( + model_state->OutputTensorDataType(i)); + int64_t n_pixels = std::accumulate( + outputs[i].shape.begin()+1, outputs[i].shape.end(), + 1, std::multiplies()); + outputs[i].buffer_byte_size = n_pixels * byte_size; + + /// print datatype + /// cout << TRITONSERVER_DataTypeString(model_state->OutputTensorDataType(i)) << endl; + } + + //// Coin: here we implement our logic + callback_func(inputs, outputs); + + + // Finalize the collector. If 'true' is returned, 'input_buffer' + // will not be valid until the backend synchronizes the CUDA + // stream or event that was used when creating the collector. For + // this backend, GPU is not supported and so no CUDA sync should + // be needed; so if 'true' is returned simply log an error. + const bool need_cuda_input_sync = collector.Finalize(); + if (need_cuda_input_sync) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + "'recommended' backend: unexpected CUDA sync required by collector"); + } + + // 'input_buffer' contains the batched input tensor. The backend can + // implement whatever logic is necessary to produce the output + // tensor. This backend simply logs the input tensor value and then + // returns the input tensor value in the output tensor so no actual + // computation is needed. + + uint64_t compute_start_ns = 0; + SET_TIMESTAMP(compute_start_ns); + + /* + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("model ") + model_state->Name() + ": requests in batch " + + std::to_string(request_count)) + .c_str()); + std::string tstr; + IGNORE_ERROR(BufferAsTypedString( + tstr, input_buffer, input_buffer_byte_size, + model_state->TensorDataType())); + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("batched " + model_state->InputTensorName() + " value: ") + + tstr) + .c_str()); + */ + + // const char* output_buffer = input_buffer; + // TRITONSERVER_MemoryType output_buffer_memory_type = input_buffer_memory_type; + // int64_t output_buffer_memory_type_id = input_buffer_memory_type_id; + + uint64_t compute_end_ns = 0; + SET_TIMESTAMP(compute_end_ns); + + bool supports_first_dim_batching; + RESPOND_ALL_AND_SET_NULL_IF_ERROR( + responses, request_count, + model_state->SupportsFirstDimBatching(&supports_first_dim_batching)); + + std::vector> outp_shapes(n_outps); + for (int i{0}; i < n_outps; ++i) { + RESPOND_ALL_AND_SET_NULL_IF_ERROR( + responses, request_count, model_state->OutputTensorShape( + outp_shapes[i], i)); + } + + // Because the output tensor values are concatenated into a single + // contiguous 'output_buffer', the backend must "scatter" them out + // to the individual response output tensors. The backend utilities + // provide a "responder" to facilitate this scattering process. + + // The 'responders's ProcessTensor function will copy the portion of + // 'output_buffer' corresonding to each request's output into the + // response for that request. + + BackendOutputResponder responder( + requests, request_count, &responses, model_state->TritonMemoryManager(), + supports_first_dim_batching, false /* pinned_enabled */, + nullptr /* stream*/); + + for (int i{0}; i < n_outps; ++i) { + const char* output_buffer = outputs[i].buffer.get(); + //const char* output_buffer = new char[16]; + responder.ProcessTensor( + model_state->OutputTensorName(i).c_str(), model_state->OutputTensorDataType(i), + outp_shapes[i], output_buffer, outputs[i].buffer_memory_type, + outputs[i].buffer_memory_type_id); + + } + + // Finalize the responder. If 'true' is returned, the output + // tensors' data will not be valid until the backend synchronizes + // the CUDA stream or event that was used when creating the + // responder. For this backend, GPU is not supported and so no CUDA + // sync should be needed; so if 'true' is returned simply log an + // error. + const bool need_cuda_output_sync = responder.Finalize(); + if (need_cuda_output_sync) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + "'recommended' backend: unexpected CUDA sync required by responder"); + } + + // Send all the responses that haven't already been sent because of + // an earlier error. + for (auto& response : responses) { + if (response != nullptr) { + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend( + response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr), + "failed to send response"); + } + } + + uint64_t exec_end_ns = 0; + SET_TIMESTAMP(exec_end_ns); + +#ifdef TRITON_ENABLE_STATS + // For batch statistics need to know the total batch size of the + // requests. This is not necessarily just the number of requests, + // because if the model supports batching then any request can be a + // batched request itself. + size_t total_batch_size = 0; + if (!supports_first_dim_batching) { + total_batch_size = request_count; + } else { + for (uint32_t r = 0; r < request_count; ++r) { + auto& request = requests[r]; + TRITONBACKEND_Input* input = nullptr; + LOG_IF_ERROR( + TRITONBACKEND_RequestInputByIndex(request, 0 /* index */, &input), + "failed getting request input"); + if (input != nullptr) { + const int64_t* shape = nullptr; + LOG_IF_ERROR( + TRITONBACKEND_InputProperties( + input, nullptr, nullptr, &shape, nullptr, nullptr, nullptr), + "failed getting input properties"); + if (shape != nullptr) { + total_batch_size += shape[0]; + } + } + } + } +#else + (void)exec_start_ns; + (void)exec_end_ns; + (void)compute_start_ns; + (void)compute_end_ns; +#endif // TRITON_ENABLE_STATS + + // Report statistics for each request, and then release the request. + for (uint32_t r = 0; r < request_count; ++r) { + auto& request = requests[r]; + +#ifdef TRITON_ENABLE_STATS + LOG_IF_ERROR( + TRITONBACKEND_ModelInstanceReportStatistics( + instance_state->TritonModelInstance(), request, + (responses[r] != nullptr) /* success */, exec_start_ns, + compute_start_ns, compute_end_ns, exec_end_ns), + "failed reporting request statistics"); +#endif // TRITON_ENABLE_STATS + + LOG_IF_ERROR( + TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL), + "failed releasing request"); + } + +#ifdef TRITON_ENABLE_STATS + // Report batch statistics. + LOG_IF_ERROR( + TRITONBACKEND_ModelInstanceReportBatchStatistics( + instance_state->TritonModelInstance(), total_batch_size, + exec_start_ns, compute_start_ns, compute_end_ns, exec_end_ns), + "failed reporting batch request statistics"); +#endif // TRITON_ENABLE_STATS + + return nullptr; // success +} + +} // extern "C" + +}}} // namespace triton::backend::recommended