Commit

This is a combination of 5 commits.
New PR for "ultralytics#7736"

Remove unused code

Format onnxruntime and tensorrt ONNX outputs

Fix

Unify outputs
triple-Mu committed Dec 27, 2022
1 parent 8ca1826 commit b2234c4
Showing 2 changed files with 257 additions and 5 deletions.
83 changes: 78 additions & 5 deletions export.py
@@ -185,6 +185,74 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX
return f, model_onnx


@try_export
def export_onnx_with_nms(model, im, file, opset, nms_cfg, dynamic, simplify, prefix=colorstr('ONNX:')):
# YOLOv5 ONNX export
check_requirements('onnx>=1.12.0')
import onnx

LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
f = file.with_suffix('.onnx')

from models.common import End2End
model = End2End(model, *nms_cfg, device=im.device)
b, topk, backend = 'batch', nms_cfg[0], nms_cfg[-1]
output_names = ['num_dets', 'boxes', 'scores', 'labels']
output_shapes = {n: {0: b} for n in output_names}
if dynamic == 'batch':
dynamic_cfg = {'images': {0: b}, **output_shapes}
elif dynamic == 'all':
dynamic_cfg = {'images': {0: b, 2: 'height', 3: 'width'}, **output_shapes}
else:
dynamic_cfg, b = {}, im.shape[0]

torch.onnx.export(
model.cpu() if dynamic else model, # --dynamic only compatible with cpu
im.cpu() if dynamic else im,
f,
verbose=False,
opset_version=opset,
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
input_names=['images'],
output_names=output_names,
dynamic_axes=dynamic_cfg)

# Checks
model_onnx = onnx.load(f) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model

# Metadata
d = {'stride': int(max(model.stride)), 'names': model.names}
for k, v in d.items():
meta = model_onnx.metadata_props.add()
meta.key, meta.value = k, str(v)

    # Fix output shape info in the exported graph: TensorRT needs concrete topk dims, ONNX Runtime keeps them symbolic
if backend == 'trt':
shapes = [b, 1, b, topk, 4, b, topk, b, topk]
else:
shapes = [b, 1, b, 'topk', 4, b, 'topk', b, 'topk']
for i in model_onnx.graph.output:
for j in i.type.tensor_type.shape.dim:
j.dim_param = str(shapes.pop(0))
onnx.save(model_onnx, f)

# Simplify
if simplify:
try:
cuda = torch.cuda.is_available()
check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
import onnxsim

LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(model_onnx)
            assert check, 'Simplified ONNX model could not be validated'
onnx.save(model_onnx, f)
except Exception as e:
LOGGER.info(f'{prefix} simplifier failure: {e}')
return f, model_onnx


@try_export
def export_openvino(file, metadata, half, prefix=colorstr('OpenVINO:')):
# YOLOv5 OpenVINO export
@@ -505,7 +573,7 @@ def run(
opset=12, # ONNX: opset version
verbose=False, # TensorRT: verbose log
workspace=4, # TensorRT: workspace size (GB)
-        nms=False,  # TF: add NMS to model
+        nms=False,  # ONNX/TF/TensorRT: NMS config for model
agnostic_nms=False, # TF: add agnostic NMS to model
topk_per_class=100, # TF.js NMS: topk per class to keep
topk_all=100, # TF.js NMS: topk for all classes to keep
@@ -560,9 +628,9 @@ def run(
f[0], _ = export_torchscript(model, im, file, optimize)
if engine: # TensorRT required before ONNX
f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
-    if onnx or xml:  # OpenVINO requires ONNX
+    if not nms and (onnx or xml):  # OpenVINO requires ONNX
f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
-    if xml:  # OpenVINO
+    if not nms and xml:  # OpenVINO
f[3], _ = export_openvino(file, metadata, half)
if coreml: # CoreML
f[4], _ = export_coreml(model, im, file, int8, half)
@@ -592,6 +660,11 @@ def run(
if paddle: # PaddlePaddle
f[10], _ = export_paddle(model, im, file, metadata)

if nms and (onnx or xml):
        nms_cfg = [topk_all, iou_thres, conf_thres, 'ort' if nms is True else nms]  # bare --nms defaults to the 'ort' backend
f.append(export_onnx_with_nms(model, im, file, opset, nms_cfg, dynamic, simplify)[0])
if xml:
f.append(export_openvino(file.with_suffix('.pt'), metadata, half)[0])
# Finish
f = [str(x) for x in f if x] # filter out '' and None
if any(f):
@@ -622,12 +695,12 @@ def parse_opt():
parser.add_argument('--keras', action='store_true', help='TF: use Keras')
parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
-    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
+    parser.add_argument('--dynamic', nargs='?', const='all', default=False, help='ONNX/TF/TensorRT: dynamic axes')
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
parser.add_argument('--opset', type=int, default=17, help='ONNX: opset version')
parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
-    parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
+    parser.add_argument('--nms', nargs='?', const=True, default=False, help='ONNX/TF/TensorRT: NMS config for model')
parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
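For reference, a usage sketch of the new flags (command lines are illustrative; the weight path is an assumption):

python export.py --weights yolov5s.pt --include onnx --nms ort               # NMS via the ONNX NonMaxSuppression op (ONNX Runtime)
python export.py --weights yolov5s.pt --include onnx --nms trt               # NMS via the TensorRT EfficientNMS_TRT plugin
python export.py --weights yolov5s.pt --include onnx --nms ort --dynamic batch   # dynamic batch axis only
python export.py --weights yolov5s.pt --include onnx --nms ort --dynamic     # bare --dynamic means 'all': batch, height and width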
179 changes: 179 additions & 0 deletions models/common.py
@@ -8,6 +8,7 @@
import json
import math
import platform
import random
import warnings
import zipfile
from collections import OrderedDict, namedtuple
@@ -858,3 +859,181 @@ def forward(self, x):
if isinstance(x, list):
x = torch.cat(x, 1)
return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))


class ORT_NMS(torch.autograd.Function):
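    # Dummy forward for tracing only: fabricates randomly-sized selected_indices
    # with the correct shape and dtype; at runtime ONNX Runtime executes the real
    # NonMaxSuppression op emitted by symbolic() below.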

@staticmethod
def forward(ctx,
boxes,
scores,
max_output_boxes_per_class=torch.tensor([100]),
iou_threshold=torch.tensor([0.45]),
score_threshold=torch.tensor([0.25])):
device = boxes.device
batch = scores.shape[0]
num_det = random.randint(0, 100)
batches = torch.randint(0, batch, (num_det,)).sort()[0].to(device)
idxs = torch.arange(100, 100 + num_det).to(device)
zeros = torch.zeros((num_det,), dtype=torch.int64).to(device)
selected_indices = torch.cat([batches[None], zeros[None], idxs[None]], 0).T.contiguous()
selected_indices = selected_indices.to(torch.int64)
return selected_indices

@staticmethod
def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)


class TRT_NMS(torch.autograd.Function):
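    # Dummy forward for tracing only: returns placeholder tensors with the
    # shapes/dtypes that EfficientNMS_TRT produces; symbolic() emits the
    # TensorRT plugin node that performs the real NMS at runtime.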

@staticmethod
def forward(
ctx,
boxes,
scores,
background_class=-1,
box_coding=1,
iou_threshold=0.45,
max_output_boxes=100,
plugin_version="1",
score_activation=0,
score_threshold=0.25,
):
batch_size, num_boxes, num_classes = scores.shape
num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
det_boxes = torch.randn(batch_size, max_output_boxes, 4)
det_scores = torch.randn(batch_size, max_output_boxes)
det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)

return num_det, det_boxes, det_scores, det_classes

@staticmethod
def symbolic(g,
boxes,
scores,
background_class=-1,
box_coding=1,
iou_threshold=0.45,
max_output_boxes=100,
plugin_version="1",
score_activation=0,
score_threshold=0.25):
out = g.op("TRT::EfficientNMS_TRT",
boxes,
scores,
background_class_i=background_class,
box_coding_i=box_coding,
iou_threshold_f=iou_threshold,
max_output_boxes_i=max_output_boxes,
plugin_version_s=plugin_version,
score_activation_i=score_activation,
score_threshold_f=score_threshold,
outputs=4)
nums, boxes, scores, classes = out
return nums, boxes, scores, classes


class ONNX_ORT(nn.Module):
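    # ONNX Runtime post-processing head: converts xywh boxes to xyxy, applies
    # NonMaxSuppression, then scatters the selected detections back into
    # fixed per-image (num_dets, boxes, scores, labels) outputs.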

def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, device=None):
super().__init__()
self.device = device if device else torch.device("cpu")
        self.max_obj = torch.tensor([max_obj]).to(self.device)
        self.iou_threshold = torch.tensor([iou_thres]).to(self.device)
        self.score_threshold = torch.tensor([score_thres]).to(self.device)
self.max_wh = 7680
self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
dtype=torch.float32,
device=self.device)

def forward(self, x):
batch, anchors, _ = x.shape
boxes = x[:, :, :4]
conf = x[:, :, 4:5]
scores = x[:, :, 5:]
scores *= conf

nms_box = boxes @ self.convert_matrix
nms_score = scores.transpose(1, 2).contiguous()

selected_indices = ORT_NMS.apply(nms_box, nms_score, self.max_obj, self.iou_threshold, self.score_threshold)
batch_inds, cls_inds, box_inds = selected_indices.unbind(1)
selected_score = nms_score[batch_inds, cls_inds, box_inds].unsqueeze(1)
selected_box = nms_box[batch_inds, box_inds, ...]

dets = torch.cat([selected_box, selected_score], dim=1)

batched_dets = dets.unsqueeze(0).repeat(batch, 1, 1)
batch_template = torch.arange(0, batch, dtype=batch_inds.dtype, device=batch_inds.device)
batched_dets = batched_dets.where((batch_inds == batch_template.unsqueeze(1)).unsqueeze(-1),
batched_dets.new_zeros(1))

batched_labels = cls_inds.unsqueeze(0).repeat(batch, 1)
batched_labels = batched_labels.where((batch_inds == batch_template.unsqueeze(1)),
batched_labels.new_ones(1) * -1)

N = batched_dets.shape[0]

batched_dets = torch.cat((batched_dets, batched_dets.new_zeros((N, 1, 5))), 1)
batched_labels = torch.cat((batched_labels, -batched_labels.new_ones((N, 1))), 1)

_, topk_inds = batched_dets[:, :, -1].sort(dim=1, descending=True)

topk_batch_inds = torch.arange(batch, dtype=topk_inds.dtype, device=topk_inds.device).view(-1, 1)
batched_dets = batched_dets[topk_batch_inds, topk_inds, ...]
labels = batched_labels[topk_batch_inds, topk_inds, ...]
boxes, scores = batched_dets.split((4, 1), -1)
scores = scores.squeeze(-1)
num_dets = (scores > 0).sum(1, keepdim=True)
return num_dets, boxes, scores, labels


class ONNX_TRT(nn.Module):
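    # TensorRT post-processing head: forwards raw boxes/scores to the
    # EfficientNMS_TRT plugin, which returns batched NMS outputs directly.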

def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, device=None):
super().__init__()
self.device = device if device else torch.device('cpu')
        self.background_class = -1
        self.box_coding = 1
self.iou_threshold = iou_thres
self.max_obj = max_obj
self.plugin_version = '1'
self.score_activation = 0
self.score_threshold = score_thres

def forward(self, x):
boxes = x[:, :, :4]
conf = x[:, :, 4:5]
scores = x[:, :, 5:]
scores *= conf
num_dets, boxes, scores, labels = TRT_NMS.apply(boxes, scores, self.background_class, self.box_coding,
self.iou_threshold, self.max_obj, self.plugin_version,
self.score_activation, self.score_threshold)
return num_dets, boxes, scores, labels


class End2End(nn.Module):
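    # Wraps a detection model with a backend-specific NMS head ('ort', 'trt'
    # or 'ovo') so the exported graph returns final detections directly.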

def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, backend='ort', device=None):
super().__init__()
device = device if device else torch.device('cpu')
self.model = model.to(device)

if backend == 'trt':
self.patch_model = ONNX_TRT
elif backend == 'ort':
self.patch_model = ONNX_ORT
elif backend == 'ovo':
self.patch_model = ONNX_ORT
else:
raise NotImplementedError
self.end2end = self.patch_model(max_obj, iou_thres, score_thres, device)
self.end2end.eval()
self.stride = self.model.stride
self.names = self.model.names

def forward(self, x):
x = self.model(x)[0]
x = self.end2end(x)
return x

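A minimal sketch (not part of this diff) of consuming the unified outputs with onnxruntime; the file name and dummy input are assumptions:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession('yolov5s.onnx', providers=['CPUExecutionProvider'])
im = np.zeros((1, 3, 640, 640), dtype=np.float32)  # dummy BCHW input
num_dets, boxes, scores, labels = session.run(['num_dets', 'boxes', 'scores', 'labels'], {'images': im})
n = int(num_dets[0][0])  # valid detections for image 0
for box, score, label in zip(boxes[0][:n], scores[0][:n], labels[0][:n]):
    print(f'class={int(label)} conf={score:.2f} xyxy={box.tolist()}')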