Skip to content

Commit

Permalink
Complete TensorRT C++ example! (#257)
Browse files Browse the repository at this point in the history
* Update TRT inference examples

* Save image instead of show image

* Add readme
  • Loading branch information
ShiquanYu authored Dec 27, 2021
1 parent 6d10759 commit 7d08c9c
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 41 deletions.
33 changes: 33 additions & 0 deletions deployment/tensorrt/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# TensorRT Inference

TensorRT inference for `yolort`. Only GPU inference is supported.

## Dependencies

- TensorRT 8.x

## Usage

1. Create the build directory and run the CMake configuration.

```bash
mkdir -p build/ && cd build/
cmake .. -DTENSORRT_DIR=${your_trt_install_directory}
```

1. Build project

```bash
make
```

1. Export your custom model to ONNX (see [onnx-graphsurgeon-inference-tensorrt](https://github.com/zhiqwang/yolov5-rt-stack/blob/main/notebooks/onnx-graphsurgeon-inference-tensorrt.ipynb)).

1. Now, you can infer your own images.

```bash
./yolort_trt [--image ../../../test/assets/zidane.jpg]
[--model_path ../../../notebooks/yolov5s.onnx]
[--class_names ../../../notebooks/assets/coco.names]
[--fp16] # Enable it if your GPU supports fp16 inference
```
144 changes: 103 additions & 41 deletions deployment/tensorrt/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,47 @@ inline size_t getElementSize(nvinfer1::DataType t) noexcept {
return 0;
}

/**
 * Draw every detection onto `image` in place.
 *
 * For each detection this draws the bounding box outline, a filled strip
 * directly above the box origin, and a text label of the form
 * "<class name> <confidence>" (confidence printed with two decimals,
 * e.g. "person 0.87").
 *
 * @param image       Image that is drawn on (modified in place).
 * @param detections  Detections to render; only `box`, `conf` and `classId`
 *                    are read.
 * @param classNames  Lookup table from class id to display name. When an id
 *                    has no entry (e.g. the names file could not be loaded),
 *                    the numeric id is shown instead of indexing out of range.
 */
void visualizeDetection(
    cv::Mat& image,
    std::vector<Detection>& detections,
    const std::vector<std::string>& classNames) {
  const cv::Scalar boxColor(229, 160, 21);
  const cv::Scalar textColor(255, 255, 255);

  for (const Detection& detection : detections) {
    cv::rectangle(image, detection.box, boxColor, 2);

    const int x = detection.box.x;
    const int y = detection.box.y;

    // Confidence as a whole percentage, clamped so the label is always
    // between "0.00" and "1.00".
    int percent = static_cast<int>(detection.conf * 100);
    if (percent < 0) {
      percent = 0;
    } else if (percent > 100) {
      percent = 100;
    }
    // Zero-pad the fractional part: 5% must read "0.05", not "0.5".
    const int fraction = percent % 100;
    const std::string confText =
        std::to_string(percent / 100) + "." + (fraction < 10 ? "0" : "") + std::to_string(fraction);

    // Fall back to the numeric id when the class table has no entry for it.
    const int classId = detection.classId;
    const bool hasName = classId >= 0 && static_cast<size_t>(classId) < classNames.size();
    const std::string label =
        (hasName ? classNames[classId] : std::to_string(classId)) + " " + confText;

    int baseline = 0;
    const cv::Size size = cv::getTextSize(label, cv::FONT_ITALIC, 0.8, 2, &baseline);

    // Filled background strip (fixed 25 px tall) sitting on top of the box.
    cv::rectangle(image, cv::Point(x, y - 25), cv::Point(x + size.width, y), boxColor, -1);

    cv::putText(image, label, cv::Point(x, y - 3), cv::FONT_ITALIC, 0.8, textColor, 2);
  }
}

/**
 * Load class names from a text file, one name per line.
 *
 * Lines are kept in file order (including empty lines), so the index of a
 * name in the returned vector equals its zero-based line number. A trailing
 * carriage return is stripped from each line so that files saved with
 * Windows (CRLF) line endings yield clean names.
 *
 * @param path  Path of the class-name file.
 * @return The names in file order; empty if the file could not be opened
 *         (an error message is written to std::cerr in that case).
 */
std::vector<std::string> loadNames(const std::string& path) {
  std::vector<std::string> classNames;

  std::ifstream infile(path);
  if (!infile.good()) {
    std::cerr << "ERROR: Failed to access class name path: " << path << std::endl;
    return classNames;
  }

  std::string line;
  while (std::getline(infile, line)) {
    // std::getline only removes '\n'; drop the '\r' left behind by CRLF files.
    if (!line.empty() && line.back() == '\r') {
      line.pop_back();
    }
    classNames.emplace_back(line);
  }

  // The stream is closed by its destructor.
  return classNames;
}

ICudaEngine* CreateCudaEngineFromOnnx(
MyLogger& logger,
const char* onnx_path,
Expand Down Expand Up @@ -183,7 +224,7 @@ class YOLOv5Detector {
int max_batch_size = 1,
bool enable_int8 = false,
bool enable_fp16 = false);
virtual ~YOLOv5Detector() = default;
virtual ~YOLOv5Detector();
YOLOv5Detector(const YOLOv5Detector&) = delete;
YOLOv5Detector& operator=(const YOLOv5Detector&) = delete;

Expand All @@ -194,6 +235,7 @@ class YOLOv5Detector {
MyLogger logger;
std::unique_ptr<ICudaEngine> engine;
std::unique_ptr<IExecutionContext> context;
cudaStream_t stream;
}; /* class YOLOv5Detector */

YOLOv5Detector::YOLOv5Detector(
Expand All @@ -203,7 +245,15 @@ YOLOv5Detector::YOLOv5Detector(
bool enable_fp16)
: engine(
{CreateCudaEngineFromOnnx(logger, model_path, max_batch_size, enable_int8, enable_fp16)}),
context({engine->createExecutionContext()}) {}
context({engine->createExecutionContext()}) {
CHECK(cudaStreamCreate(&stream));
}

/* Destroy the detector's CUDA stream, provided the handle is non-null. */
YOLOv5Detector::~YOLOv5Detector() {
  if (!stream) {
    return; /* nothing to release */
  }
  CHECK(cudaStreamDestroy(stream));
}

std::vector<Detection> YOLOv5Detector::detect(cv::Mat& image) {
std::vector<Detection> result;
Expand All @@ -212,12 +262,16 @@ std::vector<Detection> YOLOv5Detector::detect(cv::Mat& image) {
int num_detections_index = engine->getBindingIndex("num_detections");
int detection_boxes_index = engine->getBindingIndex("detection_boxes");
int detection_scores_index = engine->getBindingIndex("detection_scores");
int detection_labels_index = engine->getBindingIndex("detection_labels");
int detection_labels_index = engine->getBindingIndex("detection_classes");

int32_t num_detections = 0;
float* detection_boxes = nullptr;
float* detection_scores = nullptr;
int32_t* detection_labels = nullptr;
std::vector<float> detection_boxes;
std::vector<float> detection_scores;
std::vector<float> detection_labels;

Dims dim = engine->getBindingDimensions(0);
dim.d[0] = batch_size;
context->setBindingDimensions(0, dim);

for (int32_t i = 0; i < engine->getNbBindings(); i++) {
{
Expand All @@ -231,90 +285,92 @@ std::vector<Detection> YOLOv5Detector::detect(cv::Mat& image) {
std::cout << ")"
<< "}" << std::endl;
}
Dims dim = engine->getBindingDimensions(i);
size_t buffer_size = batch_size;
// FIXME: 此处如果为 dynamic input,部分形状为 -1
for (int j = 1; j < engine->getBindingDimensions(i).nbDims; j++) {

size_t buffer_size = 1;
for (int j = 0; j < engine->getBindingDimensions(i).nbDims; j++) {
buffer_size *= engine->getBindingDimensions(i).d[j];
}
CHECK(cudaMalloc(&buffers[i], buffer_size * getElementSize(engine->getBindingDataType(i))));
if (i == detection_boxes_index) {
detection_boxes = new float[buffer_size];
detection_boxes.resize(buffer_size);
} else if (i == detection_scores_index) {
detection_scores = new float[buffer_size];
detection_scores.resize(buffer_size);
} else if (i == detection_labels_index) {
detection_labels = new int32_t[buffer_size];
detection_labels.resize(buffer_size);
}
}

/* Dims == > NCHW */
int32_t input_h = engine->getBindingDimensions(0).d[2];
int32_t input_w = engine->getBindingDimensions(0).d[3];
cudaStream_t stream; /* XXX: 此处应该可以直接声明为类成员变量? */
CHECK(cudaStreamCreate(&stream));
cv::resize(image, image, cv::Size(input_w, input_h));
image.convertTo(image, CV_32FC3);
CHECK(cudaMemcpyAsync(buffers[0], image.data, image.total(), cudaMemcpyHostToDevice, stream));
cv::Mat tmp;
cv::resize(image, tmp, cv::Size(input_w, input_h));
tmp.convertTo(tmp, CV_32FC3, 1 / 255.0);
{
  /* HWC ==> CHW: split the interleaved image into one plane per channel and
   * upload the planes back to back into the input binding. */
  size_t offset = 0; /* write position in the device buffer, in floats */
  std::vector<cv::Mat> split_images;
  cv::split(tmp, split_images);
  for (const cv::Mat& split_image : split_images) {
    CHECK(cudaMemcpyAsync(
        static_cast<float*>(buffers[0]) + offset,
        split_image.data,
        split_image.total() * sizeof(float),
        cudaMemcpyHostToDevice,
        stream));
    /* Advance past the plane just written. Assigning instead of accumulating
     * would place the third channel on top of the second one. */
    offset += split_image.total();
  }
}
context->enqueueV2(buffers, stream, nullptr);

for (int32_t i = 1; i < engine->getNbBindings(); i++) {
size_t buffer_size = batch_size;
// FIXME: 此处如果为 dynamic input,部分形状为 -1
for (int j = 1; j < engine->getBindingDimensions(i).nbDims; j++) {
buffer_size *= engine->getBindingDimensions(i).d[j];
}
if (i == detection_boxes_index) {
CHECK(cudaMemcpyAsync(
detection_boxes,
detection_boxes.data(),
buffers[detection_boxes_index],
buffer_size * getElementSize(engine->getBindingDataType(i)),
detection_boxes.size() * getElementSize(engine->getBindingDataType(i)),
cudaMemcpyDeviceToHost,
stream));
} else if (i == detection_scores_index) {
CHECK(cudaMemcpyAsync(
detection_scores,
detection_scores.data(),
buffers[detection_scores_index],
buffer_size * getElementSize(engine->getBindingDataType(i)),
detection_scores.size() * getElementSize(engine->getBindingDataType(i)),
cudaMemcpyDeviceToHost,
stream));
} else if (i == detection_labels_index) {
CHECK(cudaMemcpyAsync(
detection_labels,
detection_labels.data(),
buffers[detection_labels_index],
buffer_size * getElementSize(engine->getBindingDataType(i)),
detection_labels.size() * getElementSize(engine->getBindingDataType(i)),
cudaMemcpyDeviceToHost,
stream));
} else if (i == num_detections_index) {
CHECK(cudaMemcpyAsync(
&num_detections,
buffers[num_detections_index],
buffer_size * getElementSize(engine->getBindingDataType(i)),
getElementSize(engine->getBindingDataType(i)),
cudaMemcpyDeviceToHost,
stream));
}
}

cudaStreamDestroy(stream);

for (int i = 0; i < engine->getNbBindings(); ++i) {
CHECK(cudaFree(buffers[i]));
}

/* Convert box fromat from LTRB to LTWH */
for (int32_t i = 0; i < num_detections; i++) {
Detection detection;
detection.box.x = detection_boxes[4 * i];
detection.box.y = detection_boxes[4 * i + 1];
detection.box.width = detection_boxes[4 * i + 2];
detection.box.height = detection_boxes[4 * i + 3];
detection.box.x = detection_boxes[4 * i] * image.cols / input_w;
detection.box.y = detection_boxes[4 * i + 1] * image.rows / input_h;
detection.box.width = detection_boxes[4 * i + 2] * image.cols / input_w - detection.box.x;
detection.box.height = detection_boxes[4 * i + 3] * image.rows / input_h - detection.box.y;
detection.classId = detection_labels[i];
detection.conf = detection_scores[i];
result.push_back(detection);
}

delete[] detection_boxes;
delete[] detection_scores;
delete[] detection_labels;

return result;
}

Expand All @@ -329,12 +385,18 @@ int main(int argc, char* argv[]) {
cmd.parse_check(argc, argv);
std::string imagePath = cmd.get<std::string>("image");
std::string modelPath = cmd.get<std::string>("model_path");
std::string classNamesPath = cmd.get<std::string>("class_names");
std::vector<std::string> classNames = loadNames(classNamesPath);

cv::Mat image = cv::imread(modelPath);
cv::Mat image = cv::imread(imagePath);
YOLOv5Detector yolo_detector(modelPath.c_str(), cmd.exist("int8"), cmd.exist("fp16"));
std::vector<Detection> result = yolo_detector.detect(image);

std::cout << "Detected " << result.size() << " objects." << std::endl;

visualizeDetection(image, result, classNames);

cv::imwrite("result.jpg", image);

return 0;
}

0 comments on commit 7d08c9c

Please sign in to comment.