From 7d08c9caf1d34f931922c23312065ecb3a68f0e2 Mon Sep 17 00:00:00 2001 From: YuShiquan Date: Mon, 27 Dec 2021 14:08:38 +0800 Subject: [PATCH] Complete TensorRT C++ example! (#257) * Update TRT inference examples * Save image instead of show image * Add readme --- deployment/tensorrt/README.md | 33 ++++++++ deployment/tensorrt/main.cpp | 144 ++++++++++++++++++++++++---------- 2 files changed, 136 insertions(+), 41 deletions(-) create mode 100644 deployment/tensorrt/README.md diff --git a/deployment/tensorrt/README.md b/deployment/tensorrt/README.md new file mode 100644 index 00000000..6c40c783 --- /dev/null +++ b/deployment/tensorrt/README.md @@ -0,0 +1,33 @@ +# TensorRT Inference + +The TensorRT inference example for `yolort`; it supports GPU only. + +## Dependencies + +- TensorRT 8.x + +## Usage + +1. Create the build directory and run the CMake configuration. + + ```bash + mkdir -p build/ && cd build/ + cmake .. -DTENSORRT_DIR=${your_trt_install_directory} + ``` + +1. Build the project. + + ```bash + make + ``` + +1. Export your custom model to ONNX (see [onnx-graphsurgeon-inference-tensorrt](https://github.com/zhiqwang/yolov5-rt-stack/blob/main/notebooks/onnx-graphsurgeon-inference-tensorrt.ipynb)). + +1. Now you can run inference on your own images. 
+ + ```bash + ./yolort_trt [--image ../../../test/assets/zidane.jpg] + [--model_path ../../../notebooks/yolov5s.onnx] + [--class_names ../../../notebooks/assets/coco.names] + [--fp16] # Enable it if your GPU support fp16 inference + ``` diff --git a/deployment/tensorrt/main.cpp b/deployment/tensorrt/main.cpp index fb87555e..31cd0080 100644 --- a/deployment/tensorrt/main.cpp +++ b/deployment/tensorrt/main.cpp @@ -58,6 +58,47 @@ inline size_t getElementSize(nvinfer1::DataType t) noexcept { return 0; } +void visualizeDetection( + cv::Mat& image, + std::vector& detections, + const std::vector& classNames) { + for (const Detection& detection : detections) { + cv::rectangle(image, detection.box, cv::Scalar(229, 160, 21), 2); + + int x = detection.box.x; + int y = detection.box.y; + + int conf = (int)(detection.conf * 100); + int classId = detection.classId; + std::string label = classNames[classId] + " 0." + std::to_string(conf); + + int baseline = 0; + cv::Size size = cv::getTextSize(label, cv::FONT_ITALIC, 0.8, 2, &baseline); + cv::rectangle( + image, cv::Point(x, y - 25), cv::Point(x + size.width, y), cv::Scalar(229, 160, 21), -1); + + cv::putText( + image, label, cv::Point(x, y - 3), cv::FONT_ITALIC, 0.8, cv::Scalar(255, 255, 255), 2); + } +} + +std::vector loadNames(const std::string& path) { + // load class names + std::vector classNames; + std::ifstream infile(path); + if (infile.good()) { + std::string line; + while (getline(infile, line)) { + classNames.emplace_back(line); + } + infile.close(); + } else { + std::cerr << "ERROR: Failed to access class name path: " << path << std::endl; + } + + return classNames; +} + ICudaEngine* CreateCudaEngineFromOnnx( MyLogger& logger, const char* onnx_path, @@ -183,7 +224,7 @@ class YOLOv5Detector { int max_batch_size = 1, bool enable_int8 = false, bool enable_fp16 = false); - virtual ~YOLOv5Detector() = default; + virtual ~YOLOv5Detector(); YOLOv5Detector(const YOLOv5Detector&) = delete; YOLOv5Detector& operator=(const 
YOLOv5Detector&) = delete; @@ -194,6 +235,7 @@ class YOLOv5Detector { MyLogger logger; std::unique_ptr engine; std::unique_ptr context; + cudaStream_t stream; }; /* class YOLOv5Detector */ YOLOv5Detector::YOLOv5Detector( @@ -203,7 +245,15 @@ YOLOv5Detector::YOLOv5Detector( bool enable_fp16) : engine( {CreateCudaEngineFromOnnx(logger, model_path, max_batch_size, enable_int8, enable_fp16)}), - context({engine->createExecutionContext()}) {} + context({engine->createExecutionContext()}) { + CHECK(cudaStreamCreate(&stream)); +} + +YOLOv5Detector::~YOLOv5Detector() { + if (stream) { + CHECK(cudaStreamDestroy(stream)); + } +} std::vector YOLOv5Detector::detect(cv::Mat& image) { std::vector result; @@ -212,12 +262,16 @@ std::vector YOLOv5Detector::detect(cv::Mat& image) { int num_detections_index = engine->getBindingIndex("num_detections"); int detection_boxes_index = engine->getBindingIndex("detection_boxes"); int detection_scores_index = engine->getBindingIndex("detection_scores"); - int detection_labels_index = engine->getBindingIndex("detection_labels"); + int detection_labels_index = engine->getBindingIndex("detection_classes"); int32_t num_detections = 0; - float* detection_boxes = nullptr; - float* detection_scores = nullptr; - int32_t* detection_labels = nullptr; + std::vector detection_boxes; + std::vector detection_scores; + std::vector detection_labels; + + Dims dim = engine->getBindingDimensions(0); + dim.d[0] = batch_size; + context->setBindingDimensions(0, dim); for (int32_t i = 0; i < engine->getNbBindings(); i++) { { @@ -231,90 +285,92 @@ std::vector YOLOv5Detector::detect(cv::Mat& image) { std::cout << ")" << "}" << std::endl; } - Dims dim = engine->getBindingDimensions(i); - size_t buffer_size = batch_size; - // FIXME: 此处如果为 dynamic input,部分形状为 -1 - for (int j = 1; j < engine->getBindingDimensions(i).nbDims; j++) { + + size_t buffer_size = 1; + for (int j = 0; j < engine->getBindingDimensions(i).nbDims; j++) { buffer_size *= 
engine->getBindingDimensions(i).d[j]; } CHECK(cudaMalloc(&buffers[i], buffer_size * getElementSize(engine->getBindingDataType(i)))); if (i == detection_boxes_index) { - detection_boxes = new float[buffer_size]; + detection_boxes.resize(buffer_size); } else if (i == detection_scores_index) { - detection_scores = new float[buffer_size]; + detection_scores.resize(buffer_size); } else if (i == detection_labels_index) { - detection_labels = new int32_t[buffer_size]; + detection_labels.resize(buffer_size); } } /* Dims == > NCHW */ int32_t input_h = engine->getBindingDimensions(0).d[2]; int32_t input_w = engine->getBindingDimensions(0).d[3]; - cudaStream_t stream; /* XXX: 此处应该可以直接声明为类成员变量? */ - CHECK(cudaStreamCreate(&stream)); - cv::resize(image, image, cv::Size(input_w, input_h)); - image.convertTo(image, CV_32FC3); - CHECK(cudaMemcpyAsync(buffers[0], image.data, image.total(), cudaMemcpyHostToDevice, stream)); + cv::Mat tmp; + cv::resize(image, tmp, cv::Size(input_w, input_h)); + tmp.convertTo(tmp, CV_32FC3, 1 / 255.0); + { + /* HWC ==> CHW */ + int offset = 0; + std::vector split_images; + cv::split(tmp, split_images); + for (auto split_image : split_images) { + CHECK(cudaMemcpyAsync( + (float*)(buffers[0]) + offset, + split_image.data, + split_image.total() * sizeof(float), + cudaMemcpyHostToDevice, + stream)); + offset = split_image.total(); + } + } context->enqueueV2(buffers, stream, nullptr); for (int32_t i = 1; i < engine->getNbBindings(); i++) { - size_t buffer_size = batch_size; - // FIXME: 此处如果为 dynamic input,部分形状为 -1 - for (int j = 1; j < engine->getBindingDimensions(i).nbDims; j++) { - buffer_size *= engine->getBindingDimensions(i).d[j]; - } if (i == detection_boxes_index) { CHECK(cudaMemcpyAsync( - detection_boxes, + detection_boxes.data(), buffers[detection_boxes_index], - buffer_size * getElementSize(engine->getBindingDataType(i)), + detection_boxes.size() * getElementSize(engine->getBindingDataType(i)), cudaMemcpyDeviceToHost, stream)); } else if (i == 
detection_scores_index) { CHECK(cudaMemcpyAsync( - detection_scores, + detection_scores.data(), buffers[detection_scores_index], - buffer_size * getElementSize(engine->getBindingDataType(i)), + detection_scores.size() * getElementSize(engine->getBindingDataType(i)), cudaMemcpyDeviceToHost, stream)); } else if (i == detection_labels_index) { CHECK(cudaMemcpyAsync( - detection_labels, + detection_labels.data(), buffers[detection_labels_index], - buffer_size * getElementSize(engine->getBindingDataType(i)), + detection_labels.size() * getElementSize(engine->getBindingDataType(i)), cudaMemcpyDeviceToHost, stream)); } else if (i == num_detections_index) { CHECK(cudaMemcpyAsync( &num_detections, buffers[num_detections_index], - buffer_size * getElementSize(engine->getBindingDataType(i)), + getElementSize(engine->getBindingDataType(i)), cudaMemcpyDeviceToHost, stream)); } } - cudaStreamDestroy(stream); - for (int i = 0; i < engine->getNbBindings(); ++i) { CHECK(cudaFree(buffers[i])); } + /* Convert box fromat from LTRB to LTWH */ for (int32_t i = 0; i < num_detections; i++) { Detection detection; - detection.box.x = detection_boxes[4 * i]; - detection.box.y = detection_boxes[4 * i + 1]; - detection.box.width = detection_boxes[4 * i + 2]; - detection.box.height = detection_boxes[4 * i + 3]; + detection.box.x = detection_boxes[4 * i] * image.cols / input_w; + detection.box.y = detection_boxes[4 * i + 1] * image.rows / input_h; + detection.box.width = detection_boxes[4 * i + 2] * image.cols / input_w - detection.box.x; + detection.box.height = detection_boxes[4 * i + 3] * image.rows / input_h - detection.box.y; detection.classId = detection_labels[i]; detection.conf = detection_scores[i]; result.push_back(detection); } - delete[] detection_boxes; - delete[] detection_scores; - delete[] detection_labels; - return result; } @@ -329,12 +385,18 @@ int main(int argc, char* argv[]) { cmd.parse_check(argc, argv); std::string imagePath = cmd.get("image"); std::string modelPath = 
cmd.get("model_path"); + std::string classNamesPath = cmd.get("class_names"); + std::vector classNames = loadNames(classNamesPath); - cv::Mat image = cv::imread(modelPath); + cv::Mat image = cv::imread(imagePath); YOLOv5Detector yolo_detector(modelPath.c_str(), cmd.exist("int8"), cmd.exist("fp16")); std::vector result = yolo_detector.detect(image); std::cout << "Detected " << result.size() << " objects." << std::endl; + visualizeDetection(image, result, classNames); + + cv::imwrite("result.jpg", image); + return 0; }