From 77f295a9695f8b124a177fd221cc89e55d56f106 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Mon, 26 Jul 2021 12:52:07 -0400 Subject: [PATCH] Adding onnx export tools --- deployment/ncnn/main.cpp | 6 +- tools/__init__.py | 0 tools/export_onnx.py | 49 +++++++++++++++ tools/yolort_deploy_friendly.py | 105 ++++++++++++++++++++++++++++++++ 4 files changed, 157 insertions(+), 3 deletions(-) create mode 100644 tools/__init__.py create mode 100644 tools/export_onnx.py create mode 100644 tools/yolort_deploy_friendly.py diff --git a/deployment/ncnn/main.cpp b/deployment/ncnn/main.cpp index 0c2f6cc5..cac8f031 100644 --- a/deployment/ncnn/main.cpp +++ b/deployment/ncnn/main.cpp @@ -1,6 +1,7 @@ -// This file is wirtten base on the following file: -// https://github.com/Tencent/ncnn/blob/master/examples/yolov5.cpp +// Tencent is pleased to support the open source community by making ncnn available. +// // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. +// // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at // @@ -10,7 +11,6 @@ // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -// ------------------------------------------------------------------------------ #include "layer.h" #include "net.h" diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tools/export_onnx.py b/tools/export_onnx.py new file mode 100644 index 00000000..3c9a9719 --- /dev/null +++ b/tools/export_onnx.py @@ -0,0 +1,49 @@ +# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved. +import argparse +import torch +from .yolort_deploy_friendly import yolov5_deploy_friendly + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('--weights', type=str, default='./yolov5s.pt', + help='weights path') + parser.add_argument('--img_size', nargs='+', type=int, default=[640, 640], + help='image (height, width)') + parser.add_argument('--num_classes', type=int, default=80, + help='number of classes') + parser.add_argument('--batch_size', type=int, default=1, + help='batch size') + parser.add_argument('--device', default='cpu', + help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--half', action='store_true', + help='FP16 half-precision export') + parser.add_argument('--dynamic', action='store_true', + help='ONNX: dynamic axes') + parser.add_argument('--simplify', action='store_true', + help='ONNX: simplify model') + parser.add_argument('--opset', type=int, default=11, + help='ONNX: opset version') + return parser + + +def cli_main(): + parser = get_parser() + args = parser.parse_args() + print(args) + export_onnx(args) + + +def export_onnx(args): + + model = yolov5_deploy_friendly( + pretrained=True, + num_classes=args.num_classes, + ) + inputs = torch.rand(args.batch_size, 3, 320, 320) + outputs = model(inputs) + print(outputs.shape) + + +if __name__ == "__main__": + cli_main() diff --git a/tools/yolort_deploy_friendly.py b/tools/yolort_deploy_friendly.py new file mode 100644 index 00000000..1c379e24 --- /dev/null +++ b/tools/yolort_deploy_friendly.py @@ -0,0 +1,105 @@ +# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved. +import torch +from torch import nn, Tensor + +from torchvision.models.utils import load_state_dict_from_url + +from yolort.models.backbone_utils import darknet_pan_backbone +from yolort.models.anchor_utils import AnchorGenerator +from yolort.models.box_head import YOLOHead + +from typing import Any, List, Optional + + +def yolov5_deploy_friendly( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 80, + **kwargs: Any, +): + r"""yolov5 small release 4.0 model from + `"ultralytics/yolov5" `_. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + backbone_name = 'darknet_s_r4_0' + depth_multiple = 0.33 + width_multiple = 0.5 + version = 'r4.0' + backbone = darknet_pan_backbone(backbone_name, depth_multiple, width_multiple, version=version) + + model = YOLODeployFriendly(backbone, num_classes, **kwargs) + + if pretrained: + model_urls_root = 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0' + model_url = f'{model_urls_root}/yolov5_darknet_pan_s_r40_coco-e3fd213d.pt' + state_dict = load_state_dict_from_url(model_url, progress=progress) + model.load_state_dict(state_dict) + + return model + + +class YOLODeployFriendly(nn.Module): + """ + Deployment Friendly Wrapper of YOLO. + """ + def __init__( + self, + backbone: nn.Module, + num_classes: int, + # Anchor parameters + anchor_grids: Optional[List[List[float]]] = None, + anchor_generator: Optional[nn.Module] = None, + head: Optional[nn.Module] = None, + ): + super().__init__() + if not hasattr(backbone, "out_channels"): + raise ValueError( + "backbone should contain an attribute out_channels " + "specifying the number of output channels (assumed to be the " + "same for all the levels)") + self.backbone = backbone + + strides = [8, 16, 32] + + if anchor_grids is None: + anchor_grids = [ + [10, 13, 16, 30, 33, 23], + [30, 61, 62, 45, 59, 119], + [116, 90, 156, 198, 373, 326], + ] + + if anchor_generator is None: + anchor_generator = AnchorGenerator(strides, anchor_grids) + self.anchor_generator = anchor_generator + + if head is None: + head = YOLOHead( + backbone.out_channels, + anchor_generator.num_anchors, + anchor_generator.strides, + num_classes, + ) + self.head = head + + def forward(self, samples: Tensor): + """ + Arguments: + samples (Tensor): batched images, of shape [batch_size x 3 x H x W] + """ + # get the features from the backbone + features = self.backbone(samples) + + # compute the yolo heads outputs using the features + head_outputs = self.head(features) + + all_pred_logits = [] + batch_size, _, _, _, K = head_outputs[0].shape + + for pred_logits in head_outputs: + pred_logits = pred_logits.reshape(batch_size, -1, K) # Size=(NN, HWA, K) + all_pred_logits.append(pred_logits) + + all_pred_logits = torch.cat(all_pred_logits, dim=1) + return all_pred_logits