Skip to content

Commit

Permalink
Adding onnx export tools
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Jul 26, 2021
1 parent 389b949 commit e333d9e
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 0 deletions.
Empty file added tools/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions tools/export_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import argparse
import torch
from .yolo_deploy_friendly import yolov5_deploy_friendly


def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='./yolov5s.pt',
help='weights path')
parser.add_argument('--img_size', nargs='+', type=int, default=[640, 640],
help='image (height, width)')
parser.add_argument('--num_classes', type=int, default=80,
help='number of classes')
parser.add_argument('--batch_size', type=int, default=1,
help='batch size')
parser.add_argument('--device', default='cpu',
help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--half', action='store_true',
help='FP16 half-precision export')
parser.add_argument('--dynamic', action='store_true',
help='ONNX: dynamic axes')
parser.add_argument('--simplify', action='store_true',
help='ONNX: simplify model')
parser.add_argument('--opset', type=int, default=11,
help='ONNX: opset version')
return parser


def cli_main():
parser = get_parser()
args = parser.parse_args()
print(args)
export_onnx(args)


def export_onnx(args):

model = yolov5_deploy_friendly(
pretrained=True,
num_classes=args.num_classes,
)
inputs = torch.rand(args.batch_size, 3, 320, 320)
outputs = model(inputs)
print(outputs.shape)


if __name__ == "__main__":
cli_main()
105 changes: 105 additions & 0 deletions tools/yolo_deploy_friendly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import torch
from torch import nn, Tensor

from torchvision.models.utils import load_state_dict_from_url

from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.models.anchor_utils import AnchorGenerator
from yolort.models.box_head import YOLOHead

from typing import Any, List, Optional


def yolov5_deploy_friendly(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 80,
**kwargs: Any,
):
r"""yolov5 small release 4.0 model from
`"ultralytics/yolov5" <https://zenodo.org/badge/latestdoi/264818686>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
backbone_name = 'darknet_s_r4_0'
depth_multiple = 0.33
width_multiple = 0.5
version = 'r4.0'
backbone = darknet_pan_backbone(backbone_name, depth_multiple, width_multiple, version=version)

model = YOLODeployFriendly(backbone, num_classes, **kwargs)

if pretrained:
model_urls_root = 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0'
model_url = f'{model_urls_root}/yolov5_darknet_pan_s_r40_coco-e3fd213d.pt'
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)

return model


class YOLODeployFriendly(nn.Module):
"""
Deployment Friendly Wrapper of YOLO.
"""
def __init__(
self,
backbone: nn.Module,
num_classes: int,
# Anchor parameters
anchor_grids: Optional[List[List[float]]] = None,
anchor_generator: Optional[nn.Module] = None,
head: Optional[nn.Module] = None,
):
super().__init__()
if not hasattr(backbone, "out_channels"):
raise ValueError(
"backbone should contain an attribute out_channels "
"specifying the number of output channels (assumed to be the "
"same for all the levels)")
self.backbone = backbone

strides = [8, 16, 32]

if anchor_grids is None:
anchor_grids = [
[10, 13, 16, 30, 33, 23],
[30, 61, 62, 45, 59, 119],
[116, 90, 156, 198, 373, 326],
]

if anchor_generator is None:
anchor_generator = AnchorGenerator(strides, anchor_grids)
self.anchor_generator = anchor_generator

if head is None:
head = YOLOHead(
backbone.out_channels,
anchor_generator.num_anchors,
anchor_generator.strides,
num_classes,
)
self.head = head

def forward(self, samples: Tensor):
"""
Arguments:
samples (Tensor): batched images, of shape [batch_size x 3 x H x W]
"""
# get the features from the backbone
features = self.backbone(samples)

# compute the yolo heads outputs using the features
head_outputs = self.head(features)

all_pred_logits = []
batch_size, _, _, _, K = head_outputs[0].shape

for pred_logits in head_outputs:
pred_logits = pred_logits.reshape(batch_size, -1, K) # Size=(NN, HWA, K)
all_pred_logits.append(pred_logits)

all_pred_logits = torch.cat(all_pred_logits, dim=1)
return all_pred_logits

0 comments on commit e333d9e

Please sign in to comment.