forked from PaddlePaddle/PaddleDetection

Commit: Add TensorRT nms plugin for end2end ppyoloe detection (PaddlePaddle#6348)

Showing 4 changed files with 488 additions and 0 deletions.
@@ -0,0 +1,99 @@
# Export ONNX Model
## Download pretrained Paddle models

* [ppyoloe-s](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_300e_coco.pdparams)
* [ppyoloe-m](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_m_300e_coco.pdparams)
* [ppyoloe-l](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams)
* [ppyoloe-x](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_x_300e_coco.pdparams)
* [ppyoloe-s-400e](https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_400e_coco.pdparams)

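The weights can also be fetched from Python if you prefer; a minimal sketch using requests (the same library the inference scripts below use), with the ppyoloe-s URL taken from the list above:

```python
# Sketch: download the ppyoloe-s pretrained weights listed above.
import requests

url = "https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_300e_coco.pdparams"
with open("ppyoloe_crn_s_300e_coco.pdparams", "wb") as f:
    f.write(requests.get(url).content)
```
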
## Export the Paddle model for deployment

```shell
python ./tools/export_model.py \
    -c configs/ppyoloe/ppyoloe_crn_s_300e_coco.yml \
    -o weights=ppyoloe_crn_s_300e_coco.pdparams \
    trt=True \
    exclude_nms=True \
    TestReader.inputs_def.image_shape=[3,640,640] \
    --output_dir ./

# if you want to try ppyoloe-s-400e model
python ./tools/export_model.py \
    -c configs/ppyoloe/ppyoloe_crn_s_400e_coco.yml \
    -o weights=ppyoloe_crn_s_400e_coco.pdparams \
    trt=True \
    exclude_nms=True \
    TestReader.inputs_def.image_shape=[3,640,640] \
    --output_dir ./
```

## Install requirements

```shell
pip install "onnx>=1.10.0"
pip install paddle2onnx
pip install onnx-simplifier
pip install onnx-graphsurgeon --index-url https://pypi.ngc.nvidia.com
# install cuda-python if you want to run the cuda-python inference script
pip install cuda-python
# install cupy if you want to run the cupy inference script
pip install cupy-cuda117  # cuda110-cuda117 are all available
```
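
A quick sanity check that the export dependencies are importable before running the export script (a small sketch; module names follow the packages installed above):

```python
# Sanity check: the export pipeline needs these modules to be importable.
import onnx
import paddle2onnx
import onnxsim                  # provided by the onnx-simplifier package
import onnx_graphsurgeon as gs  # installed from the NVIDIA PyPI index

print("onnx", onnx.__version__)
```
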

## Export script

```shell
python ./deploy/end2end_ppyoloe/end2end.py \
    --model-dir ppyoloe_crn_s_300e_coco \
    --save-file ppyoloe_crn_s_300e_coco.onnx \
    --opset 11 \
    --batch-size 1 \
    --topk-all 100 \
    --iou-thres 0.6 \
    --conf-thres 0.4

# if you want to try ppyoloe-s-400e model
python ./deploy/end2end_ppyoloe/end2end.py \
    --model-dir ppyoloe_crn_s_400e_coco \
    --save-file ppyoloe_crn_s_400e_coco.onnx \
    --opset 11 \
    --batch-size 1 \
    --topk-all 100 \
    --iou-thres 0.6 \
    --conf-thres 0.4
```
#### Description of all arguments

- `--model-dir` : path of the exported ppyoloe model directory.
- `--save-file` : path of the ONNX file to write.
- `--opset` : ONNX opset version.
- `--img-size` : image size used for the export.
- `--batch-size` : batch size used for the export.
- `--topk-all` : maximum number of detections kept per image.
- `--iou-thres` : IoU threshold used by the NMS plugin.
- `--conf-thres` : confidence threshold used by the NMS plugin.

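`end2end.py` is the script in this commit that attaches the TensorRT NMS plugin to the exported graph so the engine returns final detections directly. The sketch below only illustrates the general technique with onnx-graphsurgeon, not the script's exact code: the ONNX path inside the export directory, the order of the graph outputs, and the attribute values are assumptions, and the real script may also need to reshape or transpose the box/score tensors into the layout the plugin expects.

```python
# Illustrative sketch only (not the code of end2end.py): append an
# EfficientNMS_TRT node to the exported ONNX graph with onnx-graphsurgeon.
import onnx
import numpy as np
import onnx_graphsurgeon as gs

batch_size, topk_all, iou_thres, conf_thres = 1, 100, 0.6, 0.4

# Assumed path of the ONNX file produced from the export directory.
graph = gs.import_onnx(onnx.load("ppyoloe_crn_s_300e_coco/ppyoloe.onnx"))
boxes, scores = graph.outputs[0], graph.outputs[1]  # assumed output order

# Output tensors produced by the TensorRT batched-NMS plugin.
num_dets = gs.Variable("num_dets", np.int32, [batch_size, 1])
det_boxes = gs.Variable("det_boxes", np.float32, [batch_size, topk_all, 4])
det_scores = gs.Variable("det_scores", np.float32, [batch_size, topk_all])
det_classes = gs.Variable("det_classes", np.int32, [batch_size, topk_all])

graph.layer(
    op="EfficientNMS_TRT",
    name="batched_nms",
    inputs=[boxes, scores],
    outputs=[num_dets, det_boxes, det_scores, det_classes],
    attrs={
        "plugin_version": "1",
        "background_class": -1,         # no background class
        "max_output_boxes": topk_all,   # --topk-all
        "score_threshold": conf_thres,  # --conf-thres
        "iou_threshold": iou_thres,     # --iou-thres
        "score_activation": 0,
        "box_coding": 0,                # boxes already decoded as x1,y1,x2,y2
    },
)
graph.outputs = [num_dets, det_boxes, det_scores, det_classes]
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "ppyoloe_crn_s_300e_coco.onnx")
```

The four output names match the `num_dets` / `det_boxes` / `det_scores` / `det_classes` bindings that the inference scripts below read back.
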
### TensorRT backend (TensorRT version >= 8.0.0)
#### TensorRT engine export

```shell
/path/to/trtexec \
    --onnx=ppyoloe_crn_s_300e_coco.onnx \
    --saveEngine=ppyoloe_crn_s_300e_coco.engine \
    --fp16  # add --fp16 to build a TensorRT fp16 engine

# if you want to try ppyoloe-s-400e model
/path/to/trtexec \
    --onnx=ppyoloe_crn_s_400e_coco.onnx \
    --saveEngine=ppyoloe_crn_s_400e_coco.engine \
    --fp16  # add --fp16 to build a TensorRT fp16 engine
```
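
If you prefer to stay in Python rather than call trtexec, roughly the following builds an equivalent engine with the TensorRT 8.x Python API (a sketch, not part of this commit; on older 8.x releases you may also need to set `config.max_workspace_size`):

```python
# Sketch: build the TensorRT engine from the exported ONNX model in Python.
import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(logger, namespace="")  # registers the NMS plugin

builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open("ppyoloe_crn_s_300e_coco.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError("failed to parse the ONNX model")

config = builder.create_builder_config()
if builder.platform_has_fast_fp16:
    config.set_flag(trt.BuilderFlag.FP16)  # same effect as trtexec --fp16

serialized_engine = builder.build_serialized_network(network, config)
with open("ppyoloe_crn_s_300e_coco.engine", "wb") as f:
    f.write(serialized_engine)
```
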
#### TensorRT image inference

```shell
# cuda-python infer script
python ./deploy/end2end_ppyoloe/cuda-python.py ppyoloe_crn_s_300e_coco.engine
# cupy infer script
python ./deploy/end2end_ppyoloe/cupy-python.py ppyoloe_crn_s_300e_coco.engine

# if you want to try ppyoloe-s-400e model
python ./deploy/end2end_ppyoloe/cuda-python.py ppyoloe_crn_s_400e_coco.engine
# or
python ./deploy/end2end_ppyoloe/cupy-python.py ppyoloe_crn_s_400e_coco.engine
```
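
Both scripts assume the engine exposes an `image` input plus the four NMS outputs (`num_dets`, `det_boxes`, `det_scores`, `det_classes`). If inference fails, a quick check of the binding names can help; a small sketch using the same TensorRT calls as the scripts:

```python
# Sketch: print the engine bindings to verify the expected input/output names.
import tensorrt as trt

logger = trt.Logger(trt.Logger.ERROR)
trt.init_libnvinfer_plugins(logger, namespace="")
with open("ppyoloe_crn_s_300e_coco.engine", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

for i in range(engine.num_bindings):
    kind = "input " if engine.binding_is_input(i) else "output"
    print(kind, engine.get_binding_name(i),
          engine.get_binding_shape(i), engine.get_binding_dtype(i))
```
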
@@ -0,0 +1,161 @@
import sys
import requests
import cv2
import random
import time
import numpy as np
import tensorrt as trt
from cuda import cudart
from pathlib import Path
from collections import OrderedDict, namedtuple


def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleup=True, stride=32):
    # Resize and pad image while meeting stride-multiple constraints
    shape = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # Compute padding
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding

    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return im, r, (dw, dh)
|
||
|
||
w = Path(sys.argv[1]) | ||
|
||
assert w.exists() and w.suffix in ('.engine', '.plan'), 'Wrong engine path' | ||
|
||
names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', | ||
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', | ||
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', | ||
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', | ||
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', | ||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', | ||
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', | ||
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', | ||
'hair drier', 'toothbrush'] | ||
colors = {name: [random.randint(0, 255) for _ in range(3)] for i, name in enumerate(names)} | ||
|
||
url = 'https://oneflow-static.oss-cn-beijing.aliyuncs.com/tripleMu/image1.jpg' | ||
file = requests.get(url) | ||
img = cv2.imdecode(np.frombuffer(file.content, np.uint8), 1) | ||
|
||
_, stream = cudart.cudaStreamCreate() | ||
|
||
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1) | ||
std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1) | ||
|
||
# Infer TensorRT Engine | ||
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr')) | ||
logger = trt.Logger(trt.Logger.ERROR) | ||
trt.init_libnvinfer_plugins(logger, namespace="") | ||
with open(w, 'rb') as f, trt.Runtime(logger) as runtime: | ||
model = runtime.deserialize_cuda_engine(f.read()) | ||
bindings = OrderedDict() | ||
fp16 = False # default updated below | ||
for index in range(model.num_bindings): | ||
name = model.get_binding_name(index) | ||
dtype = trt.nptype(model.get_binding_dtype(index)) | ||
shape = tuple(model.get_binding_shape(index)) | ||
data = np.empty(shape, dtype=np.dtype(dtype)) | ||
_, data_ptr = cudart.cudaMallocAsync(data.nbytes, stream) | ||
bindings[name] = Binding(name, dtype, shape, data, data_ptr) | ||
if model.binding_is_input(index) and dtype == np.float16: | ||
fp16 = True | ||
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) | ||
context = model.create_execution_context() | ||

image = img.copy()
image, ratio, dwdh = letterbox(image, auto=False)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

image_copy = image.copy()

image = image.transpose((2, 0, 1))
image = np.expand_dims(image, 0)
image = np.ascontiguousarray(image)

im = image.astype(np.float32)
im /= 255
im -= mean
im /= std

_, image_ptr = cudart.cudaMallocAsync(im.nbytes, stream)
cudart.cudaMemcpyAsync(image_ptr, im.ctypes.data, im.nbytes,
                       cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)

# warmup for 10 times
for _ in range(10):
    tmp = np.random.randn(1, 3, 640, 640).astype(np.float32)
    _, tmp_ptr = cudart.cudaMallocAsync(tmp.nbytes, stream)
    binding_addrs['image'] = tmp_ptr
    context.execute_v2(list(binding_addrs.values()))

start = time.perf_counter()
binding_addrs['image'] = image_ptr
context.execute_v2(list(binding_addrs.values()))
print(f'Cost {(time.perf_counter() - start) * 1000}ms')

nums = bindings['num_dets'].data
boxes = bindings['det_boxes'].data
scores = bindings['det_scores'].data
classes = bindings['det_classes'].data

cudart.cudaMemcpyAsync(nums.ctypes.data,
                       bindings['num_dets'].ptr,
                       nums.nbytes,
                       cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
                       stream)
cudart.cudaMemcpyAsync(boxes.ctypes.data,
                       bindings['det_boxes'].ptr,
                       boxes.nbytes,
                       cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
                       stream)
cudart.cudaMemcpyAsync(scores.ctypes.data,
                       bindings['det_scores'].ptr,
                       scores.nbytes,
                       cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
                       stream)
cudart.cudaMemcpyAsync(classes.ctypes.data,
                       bindings['det_classes'].ptr,
                       classes.nbytes,
                       cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
                       stream)

cudart.cudaStreamSynchronize(stream)
cudart.cudaStreamDestroy(stream)

for i in binding_addrs.values():
    cudart.cudaFree(i)

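# num_dets holds the number of valid detections for this image; the boxes are
# in the 640x640 letterboxed coordinate space, which is why they are drawn on
# image_copy (undo ratio/dwdh to map them back onto the original image).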
num = int(nums[0][0])
box_img = boxes[0, :num].round().astype(np.int32)
score_img = scores[0, :num]
clss_img = classes[0, :num]
for i, (box, score, clss) in enumerate(zip(box_img, score_img, clss_img)):
    name = names[int(clss)]
    color = colors[name]
    cv2.rectangle(image_copy, box[:2].tolist(), box[2:].tolist(), color, 2)
    cv2.putText(image_copy, name, (int(box[0]), int(box[1]) - 2), cv2.FONT_HERSHEY_SIMPLEX,
                0.75, [225, 255, 255], thickness=2)

cv2.imshow('Result', cv2.cvtColor(image_copy, cv2.COLOR_RGB2BGR))
cv2.waitKey(0)
@@ -0,0 +1,131 @@
import sys
import requests
import cv2
import random
import time
import numpy as np
import cupy as cp
import tensorrt as trt
from PIL import Image
from collections import OrderedDict, namedtuple
from pathlib import Path


def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleup=True, stride=32):
    # Resize and pad image while meeting stride-multiple constraints
    shape = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # Compute padding
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding

    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return im, r, (dw, dh)

names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
         'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
         'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
         'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
         'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
         'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
         'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
         'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
         'hair drier', 'toothbrush']
colors = {name: [random.randint(0, 255) for _ in range(3)] for i, name in enumerate(names)}

url = 'https://oneflow-static.oss-cn-beijing.aliyuncs.com/tripleMu/image1.jpg'
file = requests.get(url)
img = cv2.imdecode(np.frombuffer(file.content, np.uint8), 1)

w = Path(sys.argv[1])

assert w.exists() and w.suffix in ('.engine', '.plan'), 'Wrong engine path'

mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1)

mean = cp.asarray(mean)
std = cp.asarray(std)

# Infer TensorRT Engine
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
logger = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(logger, namespace="")
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
    model = runtime.deserialize_cuda_engine(f.read())
bindings = OrderedDict()
fp16 = False  # default updated below
for index in range(model.num_bindings):
    name = model.get_binding_name(index)
    dtype = trt.nptype(model.get_binding_dtype(index))
    shape = tuple(model.get_binding_shape(index))
    data = cp.empty(shape, dtype=cp.dtype(dtype))
    bindings[name] = Binding(name, dtype, shape, data, int(data.data.ptr))
    if model.binding_is_input(index) and dtype == np.float16:
        fp16 = True
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
context = model.create_execution_context()

image = img.copy()
image, ratio, dwdh = letterbox(image, auto=False)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

image_copy = image.copy()

image = image.transpose((2, 0, 1))
image = np.expand_dims(image, 0)
image = np.ascontiguousarray(image)

im = cp.asarray(image)
im = im.astype(cp.float32)
im /= 255
im -= mean
im /= std

# warmup for 10 times
for _ in range(10):
    tmp = cp.random.randn(1, 3, 640, 640).astype(cp.float32)
    binding_addrs['image'] = int(tmp.data.ptr)
    context.execute_v2(list(binding_addrs.values()))

start = time.perf_counter()
binding_addrs['image'] = int(im.data.ptr)
context.execute_v2(list(binding_addrs.values()))
print(f'Cost {(time.perf_counter() - start) * 1000}ms')

nums = bindings['num_dets'].data
boxes = bindings['det_boxes'].data
scores = bindings['det_scores'].data
classes = bindings['det_classes'].data

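# The output bindings above are CuPy arrays that live on the GPU; the int(...)
# and .tolist() conversions below implicitly copy the values back to the host.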
num = int(nums[0][0])
box_img = boxes[0, :num].round().astype(cp.int32)
score_img = scores[0, :num]
clss_img = classes[0, :num]
for i, (box, score, clss) in enumerate(zip(box_img, score_img, clss_img)):
    name = names[int(clss)]
    color = colors[name]
    cv2.rectangle(image_copy, box[:2].tolist(), box[2:].tolist(), color, 2)
    cv2.putText(image_copy, name, (int(box[0]), int(box[1]) - 2), cv2.FONT_HERSHEY_SIMPLEX,
                0.75, [225, 255, 255], thickness=2)

cv2.imshow('Result', cv2.cvtColor(image_copy, cv2.COLOR_RGB2BGR))
cv2.waitKey(0)