diff --git a/.gitignore b/.gitignore index 3a7a15e..139c8f7 100644 --- a/.gitignore +++ b/.gitignore @@ -117,5 +117,7 @@ pretrained/* dist_train.sh openvino/build/* openvino/output* +*.onnx +tis/cpp_client/build/* tvm/ diff --git a/README.md b/README.md index 57b042c..7ce5b87 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,9 @@ You can go to [ncnn](./ncnn) for details. 3. openvino You can go to [openvino](./openvino) for details. +4. tis +Triton Inference Server(TIS) provides a service solution of deployment. You can go to [tis](./tis) for details. + ## platform @@ -163,3 +166,4 @@ $ python tools/evaluate.py --config configs/bisenetv1_city.py --weight-path /pat ### Be aware that this is the refactored version of the original codebase. You can go to the `old` directory for original implementation if you need, though I believe you will not need it. +#let me see see dev branch \ No newline at end of file diff --git a/openvino/README.md b/openvino/README.md index ca9123e..331ce8d 100644 --- a/openvino/README.md +++ b/openvino/README.md @@ -12,7 +12,7 @@ My cpu is Intel(R) Xeon(R) Gold 6240 CPU @ 2.60GHz. 1.Train the model and export it to onnx ``` $ cd BiSeNet/ -$ python tools/export_onnx.py --aux-mode eval --config configs/bisenetv2_city.py --weight-path /path/to/your/model.pth --outpath ./model_v2.onnx +$ python tools/export_onnx.py --config configs/bisenetv2_city.py --weight-path /path/to/your/model.pth --outpath ./model_v2.onnx ``` (Optional) 2.Install 'onnx-simplifier' to simplify the generated onnx model: ``` diff --git a/tis/README.md b/tis/README.md new file mode 100644 index 0000000..6461811 --- /dev/null +++ b/tis/README.md @@ -0,0 +1,95 @@ + + +## A simple demo of using trition-inference-serving + +### Platform + +* ubuntu 18.04 +* cmake-3.22.0 +* 8 Tesla T4 gpu + + +### Serving Model + +#### 1. prepare model repository + +We need to export our model to onnx and copy it to model repository: +``` +$ cd BiSeNet +$ python tools/export_onnx.py --config configs/bisenetv1_city.py --weight-path /path/to/your/model.pth --outpath ./model.onnx +$ cp -riv ./model.onnx tis/models/bisenetv1/1 + +$ python tools/export_onnx.py --config configs/bisenetv2_city.py --weight-path /path/to/your/model.pth --outpath ./model.onnx +$ cp -riv ./model.onnx tis/models/bisenetv2/1 +``` + +#### 2. start service +We start serving with docker: +``` +$ docker pull nvcr.io/nvidia/tritonserver:21.10-py3 +$ docker run --gpus all --rm -p8000:8000 -p8001:8001 -p8002:8002 -v /path/to/BiSeNet/tis/models:/models nvcr.io/nvidia/tritonserver:21.10-py3 tritonserver --model-repository=/models +``` + +In general, the service would start now. You can check whether service has started by: +``` +$ curl -v localhost:8000/v2/health/ready +``` + +By default, we use gpu 0 and gpu 1, you can change configurations in the `config.pbtxt` file. + + +### Client + +We call the model service with both python and c++ method. + + +#### 1. python method + +Firstly, we need to install dependency package: +``` +$ python -m pip install tritonclient[all]==2.15.0 +``` + +Then we can run the script: +``` +$ cd BiSeNet/tis +$ python client.py +``` + +This would generate a result file named `res.jpg` in `BiSeNet/tis` directory. + + +#### 2. c++ method + +We need to compile c++ client library from source: +``` +$ apt install rapidjson-dev +$ mkdir -p /data/ $$ cd /data/ +$ git clone https://github.com/triton-inference-server/client.git +$ cd client && git reset --hard da04158bc094925a56b +$ mkdir -p build && cd build +$ cmake -DCMAKE_INSTALL_PREFIX=/opt/triton_client -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_CC_GRPC=ON -DTRITON_ENABLE_PERF_ANALYZER=OFF -DTRITON_ENABLE_PYTHON_HTTP=OFF -DTRITON_ENABLE_PYTHON_GRPC=OFF -DTRITON_ENABLE_JAVA_HTTP=OFF -DTRITON_ENABLE_GPU=ON -DTRITON_ENABLE_EXAMPLES=OFF -DTRITON_ENABLE_TESTS=ON .. +$ make cc-clients +``` +The above commands are exactly what I used to compile the library. I learned these commands from the official document. + +Also, We need to install `cmake` with version `3.22`. + +Optionally, I compiled opencv from source and install it to `/opt/opencv`. You can first skip this and see whether you meet problems. If you have problems about opencv in the following steps, you can compile opencv as what I do. + +After installing the dependencies, we can compile our c++ client: +``` +$ cd BiSeNet/tis/cpp_client +$ mkdir -p build && cd build +$ cmake .. && make +``` + +Finally, we run the client and see a result file named `res.jpg` generated: +``` + ./client +``` + + +### In the end + +This is a simple demo with only basic function. There are many other features that is useful, such as shared memory and model pipeline. If you have interest on this, you can learn more in the official document. diff --git a/tis/client.py b/tis/client.py new file mode 100644 index 0000000..33cc925 --- /dev/null +++ b/tis/client.py @@ -0,0 +1,88 @@ + +import numpy as np +import cv2 + +import grpc + +from tritonclient.grpc import service_pb2, service_pb2_grpc +import tritonclient.grpc.model_config_pb2 as mc + + +np.random.seed(123) +palette = np.random.randint(0, 256, (100, 3)) + + + +# url = '10.128.61.7:8001' +url = '127.0.0.1:8001' +model_name = 'bisenetv2' +model_version = '1' +inp_name = 'input_image' +outp_name = 'preds' +inp_dtype = 'FP32' +outp_dtype = np.int64 +inp_shape = [1, 3, 1024, 2048] +outp_shape = [1024, 2048] +impth = '../example.png' +mean = [0.3257, 0.3690, 0.3223] # city, rgb +std = [0.2112, 0.2148, 0.2115] + + +option = [ + ('grpc.max_receive_message_length', 1073741824), + ('grpc.max_send_message_length', 1073741824), + ] +channel = grpc.insecure_channel(url, options=option) +grpc_stub = service_pb2_grpc.GRPCInferenceServiceStub(channel) + + +metadata_request = service_pb2.ModelMetadataRequest( + name=model_name, version=model_version) +metadata_response = grpc_stub.ModelMetadata(metadata_request) +print(metadata_response) + +config_request = service_pb2.ModelConfigRequest( + name=model_name, + version=model_version) +config_response = grpc_stub.ModelConfig(config_request) +print(config_response) + + +request = service_pb2.ModelInferRequest() +request.model_name = model_name +request.model_version = model_version + +inp = service_pb2.ModelInferRequest().InferInputTensor() +inp.name = inp_name +inp.datatype = inp_dtype +inp.shape.extend(inp_shape) + + +mean = np.array(mean).reshape(1, 1, 3) +std = np.array(std).reshape(1, 1, 3) +im = cv2.imread(impth)[:, :, ::-1] +im = cv2.resize(im, dsize=tuple(inp_shape[-1:-3:-1])) +im = ((im / 255.) - mean) / std +im = im[None, ...].transpose(0, 3, 1, 2) +inp_bytes = im.astype(np.float32).tobytes() + +request.ClearField("inputs") +request.ClearField("raw_input_contents") +request.inputs.extend([inp,]) +request.raw_input_contents.extend([inp_bytes,]) + + +outp = service_pb2.ModelInferRequest().InferRequestedOutputTensor() +outp.name = outp_name +request.outputs.extend([outp,]) + +# sync +# resp = grpc_stub.ModelInfer(request).raw_output_contents[0] +# async +resp = grpc_stub.ModelInfer.future(request) +resp = resp.result().raw_output_contents[0] + +out = np.frombuffer(resp, dtype=outp_dtype).reshape(*outp_shape) + +out = palette[out] +cv2.imwrite('res.png', out) diff --git a/tis/cpp_client/CMakeLists.txt b/tis/cpp_client/CMakeLists.txt new file mode 100644 index 0000000..92fbd6c --- /dev/null +++ b/tis/cpp_client/CMakeLists.txt @@ -0,0 +1,29 @@ +cmake_minimum_required (VERSION 3.18) + +project(Samples) + +set(CMAKE_CXX_FLAGS "-std=c++14 -O1") +set(CMAKE_BUILD_TYPE Release) + +set(CMAKE_PREFIX_PATH + /opt/triton_client/ + /opt/opencv/lib/cmake/opencv4) +find_package(OpenCV REQUIRED) + +include_directories( + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_BINARY_DIR} + ${OpenCV_INCLUDE_DIRS} + /opt/triton_client/include +) +link_directories( + /opt/triton_client/lib + ) + + +add_executable(client main.cpp) +target_link_libraries(client PRIVATE + grpcclient + ${OpenCV_LIBS} + -lpthread + ) diff --git a/tis/cpp_client/main.cpp b/tis/cpp_client/main.cpp new file mode 100644 index 0000000..56f5ea3 --- /dev/null +++ b/tis/cpp_client/main.cpp @@ -0,0 +1,330 @@ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "grpc_client.h" +#include "common.h" + +#include + + +namespace tc = triton::client; + + +#define FAIL_IF_ERR(X, MSG) \ + { \ + tc::Error err = (X); \ + if (!err.IsOk()) { \ + std::cerr << "error: " << (MSG) << ": " << err << std::endl; \ + exit(1); \ + } \ + } + + +// std::string url("10.128.61.8:8001"); +std::string url("127.0.0.1:8001"); +std::string model_name("bisenetv1"); +std::string model_version("1"); +uint32_t client_timeout{5000000}; +bool verbose = false; + +std::string impth("../../../example.png"); +std::string savepth("./res.jpg"); +std::vector inp_shape{1, 3, 1024, 2048}; +std::vector outp_shape{1, 1024, 2048}; +std::string inp_name("input_image"); +std::string outp_name("preds"); +std::string inp_type("FP32"); + + + +std::vector> get_color_map(); +std::vector get_image(std::string, std::vector&); +void save_predict(std::string, int64_t*, + std::vector, std::vector); +void do_inference(); +void print_infos(); +void test_speed(); + + +int main() { + // print_infos(); + do_inference(); + // test_speed(); + return 0; +} + + +void do_inference() { + + // define client + std::unique_ptr client; + FAIL_IF_ERR( + tc::InferenceServerGrpcClient::Create(&client, url, verbose), + "unable to create grpc client"); + std::cout << "create client\n"; + + // create input + std::vector input_data = get_image(impth, inp_shape); + std::cout << "read image: " << impth << std::endl; + + tc::InferInput* input; + FAIL_IF_ERR( + tc::InferInput::Create(&input, inp_name, inp_shape, inp_type), + "unable to get input"); + std::shared_ptr input_ptr; + input_ptr.reset(input); + FAIL_IF_ERR(input_ptr->Reset(), // reset input + "unable to reset input data"); + FAIL_IF_ERR( + input_ptr->AppendRaw( + reinterpret_cast(&input_data[0]), + input_data.size() * sizeof(float)), // NOTE: float can be others according to input type + "unable to set data for input"); + std::cout << "set input\n"; + + + // create output + tc::InferRequestedOutput* output; + FAIL_IF_ERR( + tc::InferRequestedOutput::Create(&output, outp_name), + "unable to get output"); + std::shared_ptr output_ptr; + output_ptr.reset(output); + std::cout << "set output\n"; + + // infer options + tc::InferOptions options(model_name); + options.model_version_ = model_version; + options.client_timeout_ = client_timeout; + tc::Headers http_headers; + grpc_compression_algorithm compression_algorithm = + grpc_compression_algorithm::GRPC_COMPRESS_NONE; + std::cout << "set options\n"; + + // inference + std::vector inputs = {input_ptr.get()}; + std::vector outputs = {output_ptr.get()}; + tc::InferResult* results; + FAIL_IF_ERR( + client->Infer( + &results, options, inputs, outputs, http_headers, + compression_algorithm), + "failed sending synchronous infer request"); + std::shared_ptr results_ptr; + results_ptr.reset(results); + FAIL_IF_ERR( + results_ptr->RequestStatus(), + "inference failed"); + std::cout << "send request and do inference\n"; + + // parse output + int64_t* raw_oup{nullptr}; // NOTE: int64_t is used according to model + size_t n_bytes{0}; + FAIL_IF_ERR( + results_ptr->RawData( + outp_name, (const uint8_t**)(&raw_oup), &n_bytes), + "fetch output failed"); + if (n_bytes != outp_shape[1] * outp_shape[2] * sizeof(int64_t)) { + std::cerr << "output shape is not set correctly\n"; + exit(1); + } + std::cout << "fetch output\n"; + + // save colorful result + save_predict(savepth, raw_oup, inp_shape, outp_shape); + std::cout << "save inference result to:" << savepth << std::endl; +} + + +std::vector get_image(std::string impth, std::vector& shape) { + int64_t iH = shape[2]; + int64_t iW = shape[3]; + cv::Mat im = cv::imread(impth); + if (im.empty()) { + std::cerr << "cv::imread failed: " << impth << std::endl; + exit(1); + } + int64_t orgH{im.rows}, orgW{im.cols}; + if ((orgH != iH) || orgW != iW) { + std::cout << "resize orignal image of (" << orgH << "," << orgW + << ") to (" << iH << ", " << iW << ") according to model requirement\n"; + cv::resize(im, im, cv::Size(iW, iH), cv::INTER_CUBIC); + } + + std::vector data(iH * iW * 3); + float mean[3] = {0.3257f, 0.3690f, 0.3223f}; + float var[3] = {0.2112f, 0.2148f, 0.2115f}; + float scale = 1.f / 255.f; + for (float &el : var) el = 1.f / el; + for (int h{0}; h < iH; ++h) { + cv::Vec3b *p = im.ptr(h); + for (int w{0}; w < iW; ++w) { + for (int c{0}; c < 3; ++c) { + int idx = (2 - c) * iH * iW + h * iW + w; // to rgb order + data[idx] = (p[w][c] * scale - mean[c]) * var[c]; + } + } + } + return data; +} + + +std::vector> get_color_map() { + std::vector> color_map(256, + std::vector(3)); + std::minstd_rand rand_eng(123); + std::uniform_int_distribution u(0, 255); + for (int i{0}; i < 256; ++i) { + for (int j{0}; j < 3; ++j) { + color_map[i][j] = u(rand_eng); + } + } + return color_map; +} + + +void save_predict(std::string savename, int64_t* data, + std::vector insize, + std::vector outsize) { + + std::vector> color_map = get_color_map(); + int64_t oH = outsize[1]; + int64_t oW = outsize[2]; + cv::Mat pred(cv::Size(oW, oH), CV_8UC3); + int idx{0}; + for (int i{0}; i < oH; ++i) { + uint8_t *ptr = pred.ptr(i); + for (int j{0}; j < oW; ++j) { + ptr[0] = color_map[data[idx]][0]; + ptr[1] = color_map[data[idx]][1]; + ptr[2] = color_map[data[idx]][2]; + ptr += 3; + ++idx; + } + } + cv::imwrite(savename, pred); +} + + +void print_infos() { + // define client + std::unique_ptr client; + FAIL_IF_ERR( + tc::InferenceServerGrpcClient::Create(&client, url, verbose), + "unable to create grpc client"); + + tc::Headers http_headers; + inference::ModelConfigResponse model_config; + FAIL_IF_ERR( + client->ModelConfig( + &model_config, model_name, model_version, http_headers), + "unable to get config"); + + inference::ModelMetadataResponse model_metadata; + FAIL_IF_ERR( + client->ModelMetadata( + &model_metadata, model_name, model_version, http_headers), + "unable to get meta data"); + + std::cout << "---- model info ----" << std::endl; + auto input = model_metadata.inputs(0); + auto output = model_metadata.outputs(0); + std::cout << "name: " << model_metadata.name() << std::endl; + std::cout << "platform: " << model_metadata.platform() << std::endl; + std::cout << "max_batch_size: " << model_config.config().max_batch_size() << std::endl; + + int size; + size = input.shape().size(); + std::cout << input.name() << ": \n size: ("; + for (int i{0}; i < size; ++i) { + std::cout << input.shape()[i] << ", "; + } + std::cout << ")\n data_type: " << input.datatype() << std::endl;; + size = output.shape().size(); + std::cout << output.name() << ": \n size: ("; + for (int i{0}; i < size; ++i) { + std::cout << output.shape()[i] << ", "; + } + std::cout << ")\n data_type: " << output.datatype() << std::endl;; + std::cout << "--------------------" << std::endl; +} + + +void test_speed() { + // define client + std::unique_ptr client; + FAIL_IF_ERR( + tc::InferenceServerGrpcClient::Create(&client, url, verbose), + "unable to create grpc client"); + + // create input + std::vector input_data(std::accumulate( + inp_shape.begin(), inp_shape.end(), + 1, std::multiplies()) + ); + tc::InferInput* input; + FAIL_IF_ERR( + tc::InferInput::Create(&input, inp_name, inp_shape, inp_type), + "unable to get input"); + std::shared_ptr input_ptr; + input_ptr.reset(input); + FAIL_IF_ERR(input_ptr->Reset(), // reset input + "unable to reset input data"); + FAIL_IF_ERR( + input_ptr->AppendRaw( + reinterpret_cast(&input_data[0]), + input_data.size() * sizeof(float)), // NOTE: float can be others according to input type + "unable to set data for input"); + + // create output + tc::InferRequestedOutput* output; + FAIL_IF_ERR( + tc::InferRequestedOutput::Create(&output, outp_name), + "unable to get output"); + std::shared_ptr output_ptr; + output_ptr.reset(output); + + // infer options + tc::InferOptions options(model_name); + options.model_version_ = model_version; + options.client_timeout_ = client_timeout; + tc::Headers http_headers; + grpc_compression_algorithm compression_algorithm = + grpc_compression_algorithm::GRPC_COMPRESS_NONE; + + // inference + std::vector inputs = {input_ptr.get()}; + std::vector outputs = {output_ptr.get()}; + tc::InferResult* results; + + std::cout << "test speed ... \n"; + const int n_loops{500}; + auto start = std::chrono::steady_clock::now(); + for (int i{0}; i < n_loops; ++i) { + FAIL_IF_ERR( + client->Infer( + &results, options, inputs, outputs, http_headers, + compression_algorithm), + "failed sending synchronous infer request"); + } + auto end = std::chrono::steady_clock::now(); + + std::shared_ptr results_ptr; + results_ptr.reset(results); + FAIL_IF_ERR( + results_ptr->RequestStatus(), + "inference failed"); + + double duration = std::chrono::duration(end - start).count(); + duration /= 1000.; + std::cout << "running " << n_loops << " times, use time: " + << duration << "s" << std::endl; + std::cout << "fps is: " << static_cast(n_loops) / duration << std::endl; +} diff --git a/tis/models/bisenetv1/config.pbtxt b/tis/models/bisenetv1/config.pbtxt new file mode 100644 index 0000000..a1d94fb --- /dev/null +++ b/tis/models/bisenetv1/config.pbtxt @@ -0,0 +1,30 @@ +platform: "onnxruntime_onnx" +max_batch_size: 0 +input [ +{ + name: "input_image" + data_type: TYPE_FP32 + dims: [ 1, 3, 1024, 2048 ] +} +] +output [ +{ + name: "preds" + data_type: TYPE_INT64 + dims: [ 1, 1024, 2048 ] +} +] +optimization { execution_accelerators { + gpu_execution_accelerator : [ { + name : "tensorrt" + parameters { key: "precision_mode" value: "FP16" } + parameters { key: "max_workspace_size_bytes" value: "4294967296" } + }] +}} +instance_group [ +{ + count: 2 + kind: KIND_GPU + gpus: [ 0, 1 ] +} +] diff --git a/tis/models/bisenetv2/config.pbtxt b/tis/models/bisenetv2/config.pbtxt new file mode 100644 index 0000000..d5501d3 --- /dev/null +++ b/tis/models/bisenetv2/config.pbtxt @@ -0,0 +1,30 @@ +platform: "onnxruntime_onnx" +max_batch_size: 0 +input [ +{ + name: "input_image" + data_type: TYPE_FP32 + dims: [1, 3, 1024, 2048 ] +} +] +output [ +{ + name: "preds" + data_type: TYPE_INT64 + dims: [1, 1024, 2048 ] +} +] +optimization { execution_accelerators { + gpu_execution_accelerator : [ { + name : "tensorrt" + parameters { key: "precision_mode" value: "FP16" } + parameters { key: "max_workspace_size_bytes" value: "4294967296" } + }] +}} +instance_group [ +{ + count: 2 + kind: KIND_GPU + gpus: [ 0, 1 ] +} +] diff --git a/tools/demo_video.py b/tools/demo_video.py index 6946862..59cc12d 100644 --- a/tools/demo_video.py +++ b/tools/demo_video.py @@ -5,10 +5,11 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.multiprocessing import Process, Queue +import time from PIL import Image import numpy as np import cv2 -from torch.multiprocessing import Process, Queue import lib.transform_cv2 as T from lib.models import model_factory @@ -60,6 +61,7 @@ def get_func(inpth, in_q): in_q.put('quit') while not in_q.empty(): continue cap.release() + time.sleep(1) print('input queue done')