From 43bb3e5b390a9d5e7c75733bb9ce79f4b7721f39 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Thu, 18 Feb 2021 02:53:44 +0800 Subject: [PATCH 1/5] Bump version to 0.4.0 (#64) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ed2aede2..816bc983 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ from setuptools import setup, find_packages PATH_ROOT = Path(__file__).parent.resolve() -VERSION = "0.3.0rc1" +VERSION = "0.4.0a0" PACKAGE_NAME = 'yolort' sha = 'Unknown' From 70a4165137378750c81cb9dff91384748c708602 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Fri, 19 Feb 2021 15:56:45 +0800 Subject: [PATCH 2/5] Add slack channel link (#65) --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4876063f..0a9c14a5 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ [![CI testing](https://github.com/zhiqwang/yolov5-rt-stack/workflows/CI%20testing/badge.svg)](https://github.com/zhiqwang/yolov5-rt-stack/actions?query=workflow%3A%22CI+testing%22) [![PyPI version](https://badge.fury.io/py/yolort.svg)](https://badge.fury.io/py/yolort) [![codecov](https://codecov.io/gh/zhiqwang/yolov5-rt-stack/branch/master/graph/badge.svg?token=1GX96EA72Y)](https://codecov.io/gh/zhiqwang/yolov5-rt-stack) +[![Slack](https://img.shields.io/badge/slack-chat-green.svg?logo=slack)](https://join.slack.com/t/yolort/shared_invite/zt-mqwc7235-940aAh8IaKYeWclrJx10SA) **What it is.** Yet another implementation of Ultralytics's [yolov5](https://github.com/ultralytics/yolov5), and with modules refactoring to make it available in deployment backends such as `libtorch`, `onnxruntime`, `tvm` and so on. @@ -41,7 +42,7 @@ There are no extra compiled components in `yolort` and package dependencies are Or from Source ```bash - # clone flash repository locally + # clone yolort repository locally git clone https://github.com/zhiqwang/yolov5-rt-stack.git cd yolov5-rt-stack # install in editable mode From e5e901260de3395036520c3931ff645bf7eb5522 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Mon, 22 Feb 2021 13:27:25 +0800 Subject: [PATCH 3/5] Support ultralytics released v4.0 stacks (#66) * Refactor ultralytics checkpoint updating scripts * Add module state updating scripts and unittest * Skip test for torch 1.8+ * Follow Backbone to update block in PAN * Update docs * Fix activation function in PAN * Adjust the position of the parameter * Cleanup parameters * Fix state_dict obtatining method * Add ultralytics checkpoint updater unittest * Avoid multiple loading of the same model * Cleanup unittest * Add missing packages for ultralytics * Update README.md * Add ultralytics releases v4.0 to yolort.models and hubconf * Adjust the position of the parameter * Fix typo * Upload updated checkpoints --- .github/workflows/ci_test.yml | 2 + README.md | 40 ++----- hubconf.py | 2 +- test/test_utils.py | 12 +++ yolort/models/__init__.py | 15 +++ yolort/models/backbone_utils.py | 14 ++- yolort/models/common.py | 112 +++++++++++++++++--- yolort/models/darknet.py | 57 +++++----- yolort/models/experimental.py | 19 ---- yolort/models/path_aggregation_network.py | 20 ++-- yolort/models/yolo.py | 27 +++-- yolort/utils/__init__.py | 2 + yolort/utils/update_module_state.py | 122 ++++++++++++++++++++++ yolort/utils/updated_checkpoint.py | 77 -------------- 14 files changed, 326 insertions(+), 195 deletions(-) create mode 100644 test/test_utils.py create mode 100644 yolort/utils/update_module_state.py delete mode 100644 yolort/utils/updated_checkpoint.py diff --git a/.github/workflows/ci_test.yml b/.github/workflows/ci_test.yml index 9c63545d..8afbea85 100644 --- a/.github/workflows/ci_test.yml +++ b/.github/workflows/ci_test.yml @@ -42,6 +42,8 @@ jobs: pip install -U opencv-python pip install -U pycocotools>=2.0.2 pip install -U onnxruntime + # required by ultralytics + pip install -U pandas seaborn PyYAML thop tensorboard - name: Install PyTorch ${{ matrix.torch }} Version run: | pip install ${{ matrix.pip_address }} diff --git a/README.md b/README.md index 0a9c14a5..5b4abdb2 100644 --- a/README.md +++ b/README.md @@ -74,40 +74,16 @@ model = torch.hub.load('zhiqwang/yolov5-rt-stack', 'yolov5s', pretrained=True) ### Updating checkpoint from ultralytics/yolov5 -The module state of `yolort` has some differences comparing to `ultralytics/yolov5`. We can load ultralytics's trained model checkpoint with minor changes, and we have converted ultralytics's lastest release [v3.1](https://github.com/ultralytics/yolov5/releases/download/v3.1/yolov5s.pt) checkpoint [here](https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.2.1/yolov5s.pt). +The module state of `yolort` has some differences comparing to `ultralytics/yolov5`. We can load ultralytics's trained model checkpoint with minor changes, and we have converted ultralytics's release [v3.1](https://github.com/ultralytics/yolov5/releases/tag/v3.1) and [v4.0](https://github.com/ultralytics/yolov5/releases/tag/v4.0). For example, if you want to convert a `yolov5s` (release 4.0) model, you can just run the following script: -
Expand to see more information of how to update ultralytics's trained (or your own) model checkpoint.
- -- If you train your model using ultralytics's repo, you should update the model checkpoint first. ultralytics's trained model has a limitation that their model must load in the root path of ultralytics, so a important thing is to desensitize the path dependence as follows: - - ```python - # Noted that current path is the root of ultralytics/yolov5, and the checkpoint is - # downloaded from - ultralytics_weights = 'https://github.com/ultralytics/yolov5/releases/download/v3.1/yolov5s.pt' - checkpoints_ = torch.load(ultralytics_weights, map_location='cpu')['model'] - torch.save(checkpoints_.state_dict(), desensitize_ultralytics_weights) - ``` - -- Load `yolort` model as follows: - - ```python - from hubconf import yolov5s - - model = yolov5s() - model.eval() - ``` - -- Now let's update ultralytics/yolov5 trained checkpoint, see the [conversion script](utils/updated_checkpoint.py) for more information: - - ```python - from utils.updated_checkpoint import update_ultralytics_checkpoints - - model = update_ultralytics_checkpoints(model, desensitize_ultralytics_weights) - # updated checkpint is saved to checkpoint_path_rt_stack - torch.save(model.state_dict(), checkpoint_path_rt_stack) - ``` +```python +from yolort.utils import update_module_state_from_ultralytics -
+# Update module state from ultralytics +model = update_module_state_from_ultralytics(arch='yolov5s', version='v4.0') +# Save updated module +torch.save(model.state_dict(), 'yolov5s_updated.pt') +``` ### Inference on `LibTorch` backend 🚀 diff --git a/hubconf.py b/hubconf.py index d90d4cf6..2c722790 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,4 +1,4 @@ # Optional list of dependencies required by the package dependencies = ['torch', 'torchvision'] -from yolort.models import yolov5s, yolov5m, yolov5l +from yolort.models import yolov5s, yolov5m, yolov5l, yolov5s_r40, yolov5m_r40, yolov5l_r40 diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 00000000..8942dbe8 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,12 @@ +# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved. +import unittest +import torch +from torch import nn + +from yolort.utils import update_module_state_from_ultralytics + + +class UtilsTester(unittest.TestCase): + def test_update_module_state_from_ultralytics(self): + model = update_module_state_from_ultralytics(arch='yolov5s', version='v4.0') + self.assertIsInstance(model, nn.Module) diff --git a/yolort/models/__init__.py b/yolort/models/__init__.py index d57882dd..3a97ecc0 100644 --- a/yolort/models/__init__.py +++ b/yolort/models/__init__.py @@ -22,6 +22,21 @@ def yolov5l(**kwargs): return model +def yolov5s_r40(**kwargs): + model = YOLOModule(arch="yolov5_darknet_pan_s_r40", **kwargs) + return model + + +def yolov5m_r40(**kwargs): + model = YOLOModule(arch="yolov5_darknet_pan_m_r40", **kwargs) + return model + + +def yolov5l_r40(**kwargs): + model = YOLOModule(arch="yolov5_darknet_pan_l_r40", **kwargs) + return model + + def yolov5_onnx(pretrained=False, progress=True, num_classes=80, **kwargs): model = yolov5s(pretrained=pretrained, progress=progress, num_classes=num_classes, **kwargs) diff --git a/yolort/models/backbone_utils.py b/yolort/models/backbone_utils.py index e37b9ccb..4b1b87ea 100644 --- a/yolort/models/backbone_utils.py +++ b/yolort/models/backbone_utils.py @@ -4,6 +4,7 @@ from . import darknet from .path_aggregation_network import PathAggregationNetwork +from .common import BottleneckCSP, C3 from typing import List, Optional @@ -22,14 +23,19 @@ class BackboneWithPAN(nn.Module): of the returned activation (which the user can specify). in_channels_list (List[int]): number of channels for each feature map that is returned, in the order they are present in the OrderedDict + version (str): ultralytics release version: v3.1 or v4.0 Attributes: out_channels (int): the number of channels in the PAN """ - def __init__(self, backbone, return_layers, in_channels_list, depth_multiple): + def __init__(self, backbone, return_layers, in_channels_list, depth_multiple, version): super().__init__() self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) - self.pan = PathAggregationNetwork(in_channels_list, depth_multiple) + self.pan = PathAggregationNetwork( + in_channels_list, + depth_multiple, + version=version, + ) self.out_channels = in_channels_list def forward(self, x): @@ -44,6 +50,7 @@ def darknet_pan_backbone( width_multiple: float, pretrained: Optional[bool] = False, returned_layers: Optional[List[int]] = None, + version: str = 'v4.0', ): """ Constructs a specified ResNet backbone with PAN on top. Freezes the specified number of @@ -71,6 +78,7 @@ def darknet_pan_backbone( pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. + version (str): ultralytics release version: v3.1 or v4.0 """ backbone = darknet.__dict__[backbone_name](pretrained=pretrained).features @@ -81,4 +89,4 @@ def darknet_pan_backbone( in_channels_list = [int(gw * width_multiple) for gw in [256, 512, 1024]] - return BackboneWithPAN(backbone, return_layers, in_channels_list, depth_multiple) + return BackboneWithPAN(backbone, return_layers, in_channels_list, depth_multiple, version) diff --git a/yolort/models/common.py b/yolort/models/common.py index bd4b5b04..fad32ecc 100644 --- a/yolort/models/common.py +++ b/yolort/models/common.py @@ -21,11 +21,27 @@ def DWConv(c1, c2, k=1, s=1, act=True): class Conv(nn.Module): # Standard convolution - def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, version='v4.0'): + """ + Args: + c1 (int): ch_in + c2 (int): ch_out + k (int): kernel + s (int): stride + p (Optional[int]): padding + g (int): groups + act (bool): determine the activation function + version (str): ultralytics release version: v3.1 or v4.0 + """ super().__init__() self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) self.bn = nn.BatchNorm2d(c2) - self.act = nn.Hardswish() if act else nn.Identity() + if version == 'v4.0': + self.act = nn.SiLU() if act else nn.Identity() + elif version == 'v3.1': + self.act = nn.Hardswish() if act else nn.Identity() + else: + raise NotImplementedError("Currently only support version v3.1 and v4.0") def forward(self, x: Tensor) -> Tensor: return self.act(self.bn(self.conv(x))) @@ -36,11 +52,20 @@ def fuseforward(self, x): class Bottleneck(nn.Module): # Standard bottleneck - def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion + def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, version='v4.0'): + """ + Args: + c1 (int): ch_in + c2 (int): ch_out + shortcut (bool): shortcut + g (int): groups + e (float): expansion + version (str): ultralytics release version: v3.1 or v4.0 + """ super().__init__() c_ = int(c2 * e) # hidden channels - self.cv1 = Conv(c1, c_, 1, 1) - self.cv2 = Conv(c_, c2, 3, 1, g=g) + self.cv1 = Conv(c1, c_, 1, 1, version=version) + self.cv2 = Conv(c_, c2, 3, 1, g=g, version=version) self.add = shortcut and c1 == c2 def forward(self, x): @@ -49,16 +74,25 @@ def forward(self, x): class BottleneckCSP(nn.Module): # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks - def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """ + Args: + c1 (int): ch_in + c2 (int): ch_out + n (int): number + shortcut (bool): shortcut + g (int): groups + e (float): expansion + """ super().__init__() c_ = int(c2 * e) # hidden channels - self.cv1 = Conv(c1, c_, 1, 1) + self.cv1 = Conv(c1, c_, 1, 1, version='v3.1') self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False) self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) - self.cv4 = Conv(2 * c_, c2, 1, 1) + self.cv4 = Conv(2 * c_, c2, 1, 1, version='v3.1') self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) self.act = nn.LeakyReLU(0.1, inplace=True) - self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)]) + self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0, version='v3.1') for _ in range(n)]) def forward(self, x): y1 = self.cv3(self.m(self.cv1(x))) @@ -66,13 +100,36 @@ def forward(self, x): return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))) +class C3(nn.Module): + # CSP Bottleneck with 3 convolutions + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """ + Args: + c1 (int): ch_in + c2 (int): ch_out + n (int): number + shortcut (bool): shortcut + g (int): groups + e (float): expansion + """ + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2) + self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)]) + + def forward(self, x): + return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1)) + + class SPP(nn.Module): # Spatial pyramid pooling layer used in YOLOv3-SPP - def __init__(self, c1, c2, k=(5, 9, 13)): + def __init__(self, c1, c2, k=(5, 9, 13), version='v4.0'): super().__init__() c_ = c1 // 2 # hidden channels - self.cv1 = Conv(c1, c_, 1, 1) - self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1) + self.cv1 = Conv(c1, c_, 1, 1, version=version) + self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1, version=version) self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) def forward(self, x): @@ -82,9 +139,20 @@ def forward(self, x): class Focus(nn.Module): # Focus wh information into c-space - def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, version='v4.0'): + """ + Args: + c1 (int): ch_in + c2 (int): ch_out + k (int): kernel + s (int): stride + p (Optional[int]): padding + g (int): groups + act (bool): determine the activation function + version (str): ultralytics release version: v3.1 or v4.0 + """ super().__init__() - self.conv = Conv(c1 * 4, c2, k, s, p, g, act) + self.conv = Conv(c1 * 4, c2, k, s, p, g, act, version=version) def forward(self, x: Tensor) -> Tensor: y = focus_transform(x) @@ -95,7 +163,10 @@ def forward(self, x: Tensor) -> Tensor: def focus_transform(x: Tensor) -> Tensor: '''x(b,c,w,h) -> y(b,4c,w/2,h/2)''' - y = torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1) + y = torch.cat([x[..., ::2, ::2], + x[..., 1::2, ::2], + x[..., ::2, 1::2], + x[..., 1::2, 1::2]], 1) return y @@ -124,7 +195,16 @@ def forward(x): class Classify(nn.Module): # Classification head, i.e. x(b,c1,20,20) to x(b,c2) - def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups + def __init__(self, c1, c2, k=1, s=1, p=None, g=1): + """ + Args: + c1 (int): ch_in + c2 (int): ch_out + k (int): kernel + s (int): stride + p (Optional[int]): padding + g (int): groups + """ super().__init__() self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1) self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) # to x(b,c2,1,1) diff --git a/yolort/models/darknet.py b/yolort/models/darknet.py index 21b828ed..21996277 100644 --- a/yolort/models/darknet.py +++ b/yolort/models/darknet.py @@ -3,8 +3,7 @@ from torch import nn, Tensor from torch.hub import load_state_dict_from_url -from .common import Conv, SPP, Focus, BottleneckCSP -from .experimental import C3 +from .common import Conv, SPP, Focus, BottleneckCSP, C3 from typing import Callable, List, Optional, Any @@ -13,12 +12,12 @@ 'darknet_s_r4_0', 'darknet_m_r4_0', 'darknet_l_r4_0'] model_urls = { - "darknet_s_r3.1": None, - "darknet_m_r3.1": None, - "darknet_l_r3.1": None, - "darknet_s_r4.0": None, - "darknet_m_r4.0": None, - "darknet_l_r4.0": None, + "darknet_s_v3.1": None, + "darknet_m_v3.1": None, + "darknet_l_v3.1": None, + "darknet_s_v4.0": None, + "darknet_m_v4.0": None, + "darknet_l_v4.0": None, } # TODO: add checkpoint weights @@ -47,6 +46,7 @@ def __init__( self, depth_multiple: float, width_multiple: float, + version: str, block: Optional[Callable[..., nn.Module]] = None, stages_repeats: Optional[List[int]] = None, stages_out_channels: Optional[List[int]] = None, @@ -60,6 +60,7 @@ def __init__( num_classes (int): Number of classes depth_multiple (float): Depth multiplier width_multiple (float): Width multiplier - adjusts number of channels in each layer by this amount + version (str): ultralytics release version: v3.1 or v4.0 round_nearest (int): Round the number of channels in each layer to be a multiple of this number Set to 1 to turn off rounding block: Module specifying inverted residual building block for darknet @@ -67,7 +68,7 @@ def __init__( super().__init__() if block is None: - block = BottleneckCSP + block = _block[version] input_channel = 64 last_channel = 1024 @@ -83,21 +84,21 @@ def __init__( # building first layer out_channel = _make_divisible(input_channel * width_multiple, round_nearest) - layers.append(Focus(3, out_channel, k=3)) + layers.append(Focus(3, out_channel, k=3, version=version)) input_channel = out_channel # building CSP blocks for depth_gain, out_channel in zip(stages_repeats, stages_out_channels): depth_gain = max(round(depth_gain * depth_multiple), 1) out_channel = _make_divisible(out_channel * width_multiple, round_nearest) - layers.append(Conv(input_channel, out_channel, k=3, s=2)) + layers.append(Conv(input_channel, out_channel, k=3, s=2, version=version)) layers.append(block(out_channel, out_channel, n=depth_gain)) input_channel = out_channel # building last CSP blocks last_channel = _make_divisible(last_channel * width_multiple, round_nearest) - layers.append(Conv(input_channel, last_channel, k=3, s=2)) - layers.append(SPP(last_channel, last_channel, k=(5, 9, 13))) + layers.append(Conv(input_channel, last_channel, k=3, s=2, version=version)) + layers.append(SPP(last_channel, last_channel, k=(5, 9, 13), version=version)) self.features = nn.Sequential(*layers) self.avgpool = nn.AdaptiveAvgPool2d(1) @@ -131,6 +132,12 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) +_block = { + "v3.1": BottleneckCSP, + "v4.0": C3, +} + + def _darknet(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any) -> DarkNet: """ Constructs a DarkNet architecture from @@ -150,12 +157,6 @@ def _darknet(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: return model -_block = { - "r3.1": BottleneckCSP, - "r4.0": C3, -} - - def darknet_s_r3_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DarkNet: """ Constructs a DarkNet with small channels, as described in release 3.1 @@ -165,8 +166,7 @@ def darknet_s_r3_1(pretrained: bool = False, progress: bool = True, **kwargs: An pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _darknet("darknet_s_r3.1", pretrained, progress, - 0.33, 0.5, block=_block["r3.1"], **kwargs) + return _darknet("darknet_s_r3.1", pretrained, progress, 0.33, 0.5, "v3.1", **kwargs) def darknet_m_r3_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DarkNet: @@ -178,8 +178,7 @@ def darknet_m_r3_1(pretrained: bool = False, progress: bool = True, **kwargs: An pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _darknet("darknet_m_r3.1", pretrained, progress, - 0.67, 0.75, block=_block["r3.1"], **kwargs) + return _darknet("darknet_m_r3.1", pretrained, progress, 0.67, 0.75, "v3.1", **kwargs) def darknet_l_r3_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DarkNet: @@ -191,8 +190,7 @@ def darknet_l_r3_1(pretrained: bool = False, progress: bool = True, **kwargs: An pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _darknet("darknet_l_r3.1", pretrained, progress, - 1.0, 1.0, block=_block["r3.1"], **kwargs) + return _darknet("darknet_l_r3.1", pretrained, progress, 1.0, 1.0, "v3.1", **kwargs) def darknet_s_r4_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DarkNet: @@ -204,8 +202,7 @@ def darknet_s_r4_0(pretrained: bool = False, progress: bool = True, **kwargs: An pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _darknet("darknet_s_r4.0", pretrained, progress, - 0.33, 0.5, block=_block["r4.0"], **kwargs) + return _darknet("darknet_s_r4.0", pretrained, progress, 0.33, 0.5, "v4.0", **kwargs) def darknet_m_r4_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DarkNet: @@ -217,8 +214,7 @@ def darknet_m_r4_0(pretrained: bool = False, progress: bool = True, **kwargs: An pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _darknet("darknet_m_r4.0", pretrained, progress, - 0.67, 0.75, block=_block["r4.0"], **kwargs) + return _darknet("darknet_m_r4.0", pretrained, progress, 0.67, 0.75, "v4.0", **kwargs) def darknet_l_r4_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DarkNet: @@ -230,5 +226,4 @@ def darknet_l_r4_0(pretrained: bool = False, progress: bool = True, **kwargs: An pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _darknet("darknet_l_r4.0", pretrained, progress, - 1.0, 1.0, block=_block["r4.0"], **kwargs) + return _darknet("darknet_l_r4.0", pretrained, progress, 1.0, 1.0, "v4.0", **kwargs) diff --git a/yolort/models/experimental.py b/yolort/models/experimental.py index 3d63391f..4e1d87cd 100644 --- a/yolort/models/experimental.py +++ b/yolort/models/experimental.py @@ -22,25 +22,6 @@ def forward(self, x): return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) -class C3(nn.Module): - # Cross Convolution CSP - def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion - super().__init__() - c_ = int(c2 * e) # hidden channels - self.cv1 = Conv(c1, c_, 1, 1) - self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False) - self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) - self.cv4 = Conv(2 * c_, c2, 1, 1) - self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) - self.act = nn.LeakyReLU(0.1, inplace=True) - self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)]) - - def forward(self, x): - y1 = self.cv3(self.m(self.cv1(x))) - y2 = self.cv2(x) - return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))) - - class Sum(nn.Module): # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070 def __init__(self, n, weight=False): # n: number of inputs diff --git a/yolort/models/path_aggregation_network.py b/yolort/models/path_aggregation_network.py index 5b88e45b..0ca3748d 100644 --- a/yolort/models/path_aggregation_network.py +++ b/yolort/models/path_aggregation_network.py @@ -2,7 +2,7 @@ import torch from torch import nn, Tensor -from .common import Conv, BottleneckCSP +from .common import Conv, BottleneckCSP, C3 from typing import Callable, List, Dict, Optional @@ -22,6 +22,7 @@ class PathAggregationNetwork(nn.Module): in_channels_list (list[int]): number of channels for each feature map that is passed to the module out_channels (int): number of channels of the PAN representation + version (str): ultralytics release version: v3.1 or v4.0 Examples:: @@ -44,22 +45,23 @@ def __init__( self, in_channels_list: List[int], depth_multiple: float, + version: str = 'v4.0', block: Optional[Callable[..., nn.Module]] = None, ): super().__init__() assert len(in_channels_list) == 3, "currently only support length 3." if block is None: - block = BottleneckCSP + block = _block[version] depth_gain = max(round(3 * depth_multiple), 1) inner_blocks = [ block(in_channels_list[2], in_channels_list[2], n=depth_gain, shortcut=False), - Conv(in_channels_list[2], in_channels_list[1], 1, 1), + Conv(in_channels_list[2], in_channels_list[1], 1, 1, version=version), nn.Upsample(scale_factor=2), block(in_channels_list[2], in_channels_list[1], n=depth_gain, shortcut=False), - Conv(in_channels_list[1], in_channels_list[0], 1, 1), + Conv(in_channels_list[1], in_channels_list[0], 1, 1, version=version), nn.Upsample(scale_factor=2), ] @@ -67,9 +69,9 @@ def __init__( layer_blocks = [ block(in_channels_list[1], in_channels_list[0], n=depth_gain, shortcut=False), - Conv(in_channels_list[0], in_channels_list[0], 3, 2), + Conv(in_channels_list[0], in_channels_list[0], 3, 2, version=version), block(in_channels_list[1], in_channels_list[1], n=depth_gain, shortcut=False), - Conv(in_channels_list[1], in_channels_list[1], 3, 2), + Conv(in_channels_list[1], in_channels_list[1], 3, 2, version=version), block(in_channels_list[2], in_channels_list[2], n=depth_gain, shortcut=False), ] self.layer_blocks = nn.ModuleList(layer_blocks) @@ -155,3 +157,9 @@ def forward(self, x: Dict[str, Tensor]) -> List[Tensor]: results.append(last_inner) return results + + +_block = { + "v3.1": BottleneckCSP, + "v4.0": C3, +} diff --git a/yolort/models/yolo.py b/yolort/models/yolo.py index 64df9633..166b86dc 100644 --- a/yolort/models/yolo.py +++ b/yolort/models/yolo.py @@ -129,9 +129,9 @@ def forward( 'yolov5_darknet_pan_s_r31_coco': f'{model_urls_root}/yolov5_darknet_pan_s_r31_coco-eb728698.pt', 'yolov5_darknet_pan_m_r31_coco': f'{model_urls_root}/yolov5_darknet_pan_m_r31_coco-670dc553.pt', 'yolov5_darknet_pan_l_r31_coco': f'{model_urls_root}/yolov5_darknet_pan_l_r31_coco-4dcc8209.pt', - 'yolov5_darknet_pan_s_r40_coco': None, - 'yolov5_darknet_pan_m_r40_coco': None, - 'yolov5_darknet_pan_l_r40_coco': None, + 'yolov5_darknet_pan_s_r40_coco': f'{model_urls_root}/yolov5_darknet_pan_s_r40_coco-e3fd213d.pt', + 'yolov5_darknet_pan_m_r40_coco': f'{model_urls_root}/yolov5_darknet_pan_m_r40_coco-d295cb02.pt', + 'yolov5_darknet_pan_l_r40_coco': f'{model_urls_root}/yolov5_darknet_pan_l_r40_coco-4416841f.pt', } @@ -139,6 +139,7 @@ def _yolov5_darknet_pan( backbone_name: str, depth_multiple: float, width_multiple: float, + version: str, weights_name: str, pretrained: bool = False, progress: bool = True, @@ -181,7 +182,7 @@ def _yolov5_darknet_pan( pretrained (bool): If True, returns a model pre-trained on COCO train2017 progress (bool): If True, displays a progress bar of the download to stderr """ - backbone = darknet_pan_backbone(backbone_name, depth_multiple, width_multiple) + backbone = darknet_pan_backbone(backbone_name, depth_multiple, width_multiple, version=version) anchor_grids = [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], @@ -209,7 +210,8 @@ def yolov5_darknet_pan_s_r31(pretrained: bool = False, progress: bool = True, nu weights_name = 'yolov5_darknet_pan_s_r31_coco' depth_multiple = 0.33 width_multiple = 0.5 - return _yolov5_darknet_pan(backbone_name, depth_multiple, width_multiple, weights_name, + version = 'v3.1' + return _yolov5_darknet_pan(backbone_name, depth_multiple, width_multiple, version, weights_name, pretrained=pretrained, progress=progress, num_classes=num_classes, **kwargs) @@ -225,7 +227,8 @@ def yolov5_darknet_pan_m_r31(pretrained: bool = False, progress: bool = True, nu weights_name = 'yolov5_darknet_pan_m_r31_coco' depth_multiple = 0.67 width_multiple = 0.75 - return _yolov5_darknet_pan(backbone_name, depth_multiple, width_multiple, weights_name, + version = 'v3.1' + return _yolov5_darknet_pan(backbone_name, depth_multiple, width_multiple, version, weights_name, pretrained=pretrained, progress=progress, num_classes=num_classes, **kwargs) @@ -241,7 +244,8 @@ def yolov5_darknet_pan_l_r31(pretrained: bool = False, progress: bool = True, nu weights_name = 'yolov5_darknet_pan_l_r31_coco' depth_multiple = 1.0 width_multiple = 1.0 - return _yolov5_darknet_pan(backbone_name, depth_multiple, width_multiple, weights_name, + version = 'v3.1' + return _yolov5_darknet_pan(backbone_name, depth_multiple, width_multiple, version, weights_name, pretrained=pretrained, progress=progress, num_classes=num_classes, **kwargs) @@ -257,7 +261,8 @@ def yolov5_darknet_pan_s_r40(pretrained: bool = False, progress: bool = True, nu weights_name = 'yolov5_darknet_pan_s_r40_coco' depth_multiple = 0.33 width_multiple = 0.5 - return _yolov5_darknet_pan(backbone_name, depth_multiple, width_multiple, weights_name, + version = 'v4.0' + return _yolov5_darknet_pan(backbone_name, depth_multiple, width_multiple, version, weights_name, pretrained=pretrained, progress=progress, num_classes=num_classes, **kwargs) @@ -273,7 +278,8 @@ def yolov5_darknet_pan_m_r40(pretrained: bool = False, progress: bool = True, nu weights_name = 'yolov5_darknet_pan_m_r40_coco' depth_multiple = 0.67 width_multiple = 0.75 - return _yolov5_darknet_pan(backbone_name, depth_multiple, width_multiple, weights_name, + version = 'v4.0' + return _yolov5_darknet_pan(backbone_name, depth_multiple, width_multiple, version, weights_name, pretrained=pretrained, progress=progress, num_classes=num_classes, **kwargs) @@ -289,5 +295,6 @@ def yolov5_darknet_pan_l_r40(pretrained: bool = False, progress: bool = True, nu weights_name = 'yolov5_darknet_pan_l_r40_coco' depth_multiple = 1.0 width_multiple = 1.0 - return _yolov5_darknet_pan(backbone_name, depth_multiple, width_multiple, weights_name, + version = 'v4.0' + return _yolov5_darknet_pan(backbone_name, depth_multiple, width_multiple, version, weights_name, pretrained=pretrained, progress=progress, num_classes=num_classes, **kwargs) diff --git a/yolort/utils/__init__.py b/yolort/utils/__init__.py index d160b6bf..6bff56f9 100644 --- a/yolort/utils/__init__.py +++ b/yolort/utils/__init__.py @@ -1 +1,3 @@ from .flash_utils import get_callable_dict +from .image_utils import cv2_imshow +from .update_module_state import update_module_state_from_ultralytics diff --git a/yolort/utils/update_module_state.py b/yolort/utils/update_module_state.py new file mode 100644 index 00000000..116a3098 --- /dev/null +++ b/yolort/utils/update_module_state.py @@ -0,0 +1,122 @@ +# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved. +from functools import reduce +import torch +from torch import nn + +from ..models import yolo + +from typing import Any + + +def update_module_state_from_ultralytics( + arch: str = 'yolov5s', + version: str = 'v4.0', + num_classes: int = 80, + **kwargs: Any, +): + architecture_maps = { + 'yolov5s_v3.1': 'yolov5_darknet_pan_s_r31', + 'yolov5m_v3.1': 'yolov5_darknet_pan_m_r31', + 'yolov5l_v3.1': 'yolov5_darknet_pan_l_r31', + 'yolov5s_v4.0': 'yolov5_darknet_pan_s_r40', + 'yolov5m_v4.0': 'yolov5_darknet_pan_m_r40', + 'yolov5l_v4.0': 'yolov5_darknet_pan_l_r40', + } + + model = torch.hub.load(f'ultralytics/yolov5:{version}', arch, pretrained=True) + + module_state_updater = ModuleStateUpdate(arch=architecture_maps[f'{arch}_{version}'], + num_classes=num_classes, **kwargs) + + module_state_updater.updating(model) + + return module_state_updater.model.half() + + +class ModuleStateUpdate: + """ + Update checkpoint from ultralytics yolov5 + """ + def __init__( + self, + arch: str = 'yolov5_darknet_pan_s_r31', + num_classes: int = 80, + inner_block_maps: dict = {'0': '9', '1': '10', '3': '13', '4': '14'}, + layer_block_maps: dict = {'0': '17', '1': '18', '2': '20', '3': '21', '4': '23'}, + head_ind: int = 24, + head_name: str = 'm', + ) -> None: + # Configuration for making the keys consistent + self.inner_block_maps = inner_block_maps + self.layer_block_maps = layer_block_maps + self.head_ind = head_ind + self.head_name = head_name + # Set model + self.model = yolo.__dict__[arch](num_classes=num_classes) + + def updating(self, state_dict): + # Obtain module state + state_dict = obtain_module_sequential(state_dict) + + # Update backbone features + for name, params in self.model.backbone.body.named_parameters(): + params.data.copy_( + self.attach_parameters_block(state_dict, name, None)) + + for name, buffers in self.model.backbone.body.named_buffers(): + buffers.copy_( + self.attach_parameters_block(state_dict, name, None)) + + # Update PAN features + for name, params in self.model.backbone.pan.inner_blocks.named_parameters(): + params.data.copy_( + self.attach_parameters_block(state_dict, name, self.inner_block_maps)) + + for name, buffers in self.model.backbone.pan.inner_blocks.named_buffers(): + buffers.copy_( + self.attach_parameters_block(state_dict, name, self.inner_block_maps)) + + for name, params in self.model.backbone.pan.layer_blocks.named_parameters(): + params.data.copy_( + self.attach_parameters_block(state_dict, name, self.layer_block_maps)) + + for name, buffers in self.model.backbone.pan.layer_blocks.named_buffers(): + buffers.copy_( + self.attach_parameters_block(state_dict, name, self.layer_block_maps)) + + # Update box heads + for name, params in self.model.head.named_parameters(): + params.data.copy_( + self.attach_parameters_heads(state_dict, name)) + + for name, buffers in self.model.head.named_buffers(): + buffers.copy_( + self.attach_parameters_heads(state_dict, name)) + + @staticmethod + def attach_parameters_block(state_dict, name, block_maps=None): + keys = name.split('.') + ind = int(block_maps[keys[0]]) if block_maps else int(keys[0]) + return rgetattr(state_dict[ind], keys[1:]) + + def attach_parameters_heads(self, state_dict, name): + keys = name.split('.') + ind = int(keys[1]) + return rgetattr(getattr(state_dict[self.head_ind], self.head_name)[ind], keys[2:]) + + +def rgetattr(obj, attr, *args): + """ + Nested version of getattr. + See + """ + def _getattr(obj, attr): + return getattr(obj, attr, *args) + return reduce(_getattr, [obj] + attr) + + +def obtain_module_sequential(state_dict): + if isinstance(state_dict, nn.Sequential): + return state_dict + else: + return obtain_module_sequential(state_dict.model) diff --git a/yolort/utils/updated_checkpoint.py b/yolort/utils/updated_checkpoint.py deleted file mode 100644 index 47eb0844..00000000 --- a/yolort/utils/updated_checkpoint.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved. -import argparse -import torch - -from ..models import yolov5m - - -def update_ultralytics_checkpoints(model, checkpoint_path_ultralytics): - """ - It's limited that ultralytics saved model must load in their root path. - So a very important thing is to desensitize the path befor updating - ultralytics's trained model as following: - - >>> checkpoints_ = torch.load(weights, map_location='cpu')['model'] - >>> torch.save(checkpoints_.state_dict(), './checkpoints/yolov5s_ultralytics.pt') - """ - state_dict = torch.load(checkpoint_path_ultralytics, map_location="cpu") - - # Update backbone features - for name, params in model.backbone.body.named_parameters(prefix='model'): - params.data.copy_(state_dict[name]) - - for name, buffers in model.backbone.body.named_buffers(prefix='model'): - buffers.copy_(state_dict[name]) - - inner_block_maps = {'0': '9', '1': '10', '3': '13', '4': '14'} - layer_block_maps = {'0': '17', '1': '18', '2': '20', '3': '21', '4': '23'} - - # Update PAN features - for name, params in model.backbone.pan.inner_blocks.named_parameters(): - state_key = name.split('.') - params.data.copy_(state_dict[f"model.{'.'.join([inner_block_maps[state_key[0]]] + state_key[1:])}"]) - - for name, buffers in model.backbone.pan.inner_blocks.named_buffers(): - state_key = name.split('.') - buffers.copy_(state_dict[f"model.{'.'.join([inner_block_maps[state_key[0]]] + state_key[1:])}"]) - - for name, params in model.backbone.pan.layer_blocks.named_parameters(): - state_key = name.split('.') - params.data.copy_(state_dict[f"model.{'.'.join([layer_block_maps[state_key[0]]] + state_key[1:])}"]) - - for name, buffers in model.backbone.pan.layer_blocks.named_buffers(): - state_key = name.split('.') - buffers.copy_(state_dict[f"model.{'.'.join([layer_block_maps[state_key[0]]] + state_key[1:])}"]) - - # Update box heads - for name, params in model.head.named_parameters(prefix='model.24'): - params.data.copy_(state_dict[name.replace('head', 'm')]) - - for name, buffers in model.head.named_buffers(prefix='model.24'): - buffers.copy_(state_dict[name.replace('head', 'm')]) - - return model - - -def main(args): - model = yolov5m(pretrained=False, score_thresh=0.25) - model = update_ultralytics_checkpoints(model, args.checkpoint_path_ultralytics) - model = model.half() - torch.save(model.state_dict(), args.checkpoint_path_rt_stack) - - -def get_args_parser(): - parser = argparse.ArgumentParser('YOLO checkpoint configures', add_help=False) - parser.add_argument('--checkpoint_path_ultralytics', default='.checkpoints/yolov5s_ultralytics.pt', - help='Path of ultralytics trained yolov5 checkpoint model') - parser.add_argument('--checkpoint_path_rt_stack', default='./checkpoints/yolov5s_rt.pt', - help='Path of updated yolov5 checkpoint model') - - return parser - - -if __name__ == "__main__": - parser = argparse.ArgumentParser('Update checkpoint from ultralytics yolov5', parents=[get_args_parser()]) - args = parser.parse_args() - - main(args) From 5825161bac87e80d5de3f2004b943ffd6ac89ef0 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Mon, 22 Feb 2021 22:26:37 +0800 Subject: [PATCH 4/5] Initialize bias into YoloHead (#67) * Initialize weights and bias in YoloHead * Fix format in Docs * Fix unittest * Minor fixes --- test/test_models.py | 3 ++- yolort/models/box_head.py | 23 +++++++++++++++++++++-- yolort/models/yolo.py | 1 + 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 3ec3ea9c..16379af3 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -108,7 +108,8 @@ def _init_test_yolo_head(self): in_channels = self._get_in_channels() num_anchors = self._get_num_anchors() num_classes = self._get_num_classes() - box_head = YoloHead(in_channels, num_anchors, num_classes) + strides = self._get_strides() + box_head = YoloHead(in_channels, num_anchors, strides, num_classes) return box_head def test_yolo_head(self): diff --git a/yolort/models/box_head.py b/yolort/models/box_head.py index 5742a991..f4d64886 100644 --- a/yolort/models/box_head.py +++ b/yolort/models/box_head.py @@ -1,4 +1,5 @@ # Modified from ultralytics/yolov5 by Zhiqiang Wang +import math import torch from torch import nn, Tensor @@ -10,14 +11,31 @@ class YoloHead(nn.Module): - def __init__(self, in_channels: List[int], num_anchors: int, num_classes: int): + def __init__(self, in_channels: List[int], num_anchors: int, strides: List[int], num_classes: int): super().__init__() self.num_anchors = num_anchors # anchors + self.num_classes = num_classes self.num_outputs = num_classes + 5 # number of outputs per anchor + self.strides = strides self.head = nn.ModuleList( nn.Conv2d(ch, self.num_outputs * self.num_anchors, 1) for ch in in_channels) # output conv + self._initialize_biases() # Init weights, biases + + def _initialize_biases(self, cf=None): + """ + Initialize biases into YoloHead, cf is class frequency + Check section 3.3 in + """ + for mi, s in zip(self.head, self.strides): + b = mi.bias.view(self.num_anchors, -1) # conv.bias(255) to (3,85) + # obj (8 objects per 640 image) + b.data[:, 4] += math.log(8 / (640 / s) ** 2) + # classes + b.data[:, 5:] += torch.log(cf / cf.sum()) if cf else math.log(0.6 / (self.num_classes - 0.99)) + mi.bias = nn.Parameter(b.view(-1), requires_grad=True) + def get_result_from_head(self, features: Tensor, idx: int) -> Tensor: """ This is equivalent to self.head[idx](features), @@ -199,7 +217,8 @@ def assign_targets_to_anchors( # Append a = targets_with_gain[:, 6].long() # anchor indices # image, anchor, grid indices - indices.append((bc[0], a, grid_ij[:, 1].clamp_(0, gain[3] - 1), grid_ij[:, 0].clamp_(0, gain[2] - 1))) + indices.append((bc[0], a, grid_ij[:, 1].clamp_(0, gain[3] - 1), + grid_ij[:, 0].clamp_(0, gain[2] - 1))) targets_box.append(torch.cat((grid_xy - grid_ij, grid_wh), 1)) # box anchors_encode.append(anchors_per_layer[a]) # anchors targets_cls.append(bc[1]) # class diff --git a/yolort/models/yolo.py b/yolort/models/yolo.py index 166b86dc..0f00d5cd 100644 --- a/yolort/models/yolo.py +++ b/yolort/models/yolo.py @@ -56,6 +56,7 @@ def __init__( head = YoloHead( backbone.out_channels, anchor_generator.num_anchors, + anchor_generator.strides, num_classes, ) self.head = head From 76f5a5d33b99ffde162b85cfe443f3142348dd2e Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Wed, 24 Feb 2021 00:28:24 +0800 Subject: [PATCH 5/5] Add export friendly substitutions of SiLU (#69) * Refactor module importing * Add onnx friendly institutions of nn.SiLU * Update notebook for ONNX exporting * Update notebook for TVM exporting * Add docs --- hubconf.py | 2 +- .../export-onnx-inference-onnxruntime.ipynb | 54 ++++++++------ notebooks/export-relay-inference-tvm.ipynb | 34 ++++----- test/test_onnx.py | 55 ++++++++++---- yolort/models/__init__.py | 74 +++++++++++++------ yolort/utils/activations.py | 19 +++-- 6 files changed, 154 insertions(+), 84 deletions(-) diff --git a/hubconf.py b/hubconf.py index 2c722790..d90d4cf6 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,4 +1,4 @@ # Optional list of dependencies required by the package dependencies = ['torch', 'torchvision'] -from yolort.models import yolov5s, yolov5m, yolov5l, yolov5s_r40, yolov5m_r40, yolov5l_r40 +from yolort.models import yolov5s, yolov5m, yolov5l diff --git a/notebooks/export-onnx-inference-onnxruntime.ipynb b/notebooks/export-onnx-inference-onnxruntime.ipynb index 43b70306..25e8c784 100644 --- a/notebooks/export-onnx-inference-onnxruntime.ipynb +++ b/notebooks/export-onnx-inference-onnxruntime.ipynb @@ -12,7 +12,7 @@ "import onnx\n", "import onnxruntime\n", "\n", - "from yolort.models import yolov5_onnx\n", + "from yolort.models import yolov5s\n", "\n", "from yolort.utils.image_utils import read_image" ] @@ -26,7 +26,7 @@ "import os\n", "\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"5\"\n", "\n", "device = torch.device('cuda')" ] @@ -44,7 +44,7 @@ "metadata": {}, "outputs": [], "source": [ - "model = yolov5_onnx(pretrained=True, score_thresh=0.45)\n", + "model = yolov5s(upstream_version='v4.0', export_friendly=True, pretrained=True, score_thresh=0.45)\n", "\n", "model = model.eval()\n", "model = model.to(device)" @@ -100,8 +100,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 20 ms, sys: 0 ns, total: 20 ms\n", - "Wall time: 20.1 ms\n" + "CPU times: user 16 ms, sys: 0 ns, total: 16 ms\n", + "Wall time: 16.8 ms\n" ] } ], @@ -119,10 +119,10 @@ { "data": { "text/plain": [ - "tensor([[ 48.4231, 401.9458, 237.0045, 897.8144],\n", - " [215.4538, 407.8977, 344.6994, 857.3773],\n", - " [ 13.1457, 225.1691, 801.7442, 736.7350],\n", - " [675.6570, 409.5675, 812.7283, 868.2495]], device='cuda:0')" + "tensor([[ 52.1687, 384.9377, 235.4150, 899.1040],\n", + " [223.6789, 406.9857, 346.8747, 862.1425],\n", + " [677.8205, 390.5674, 811.9033, 871.8314],\n", + " [ 9.4887, 227.6140, 799.6432, 766.6011]], device='cuda:0')" ] }, "execution_count": 7, @@ -142,7 +142,7 @@ { "data": { "text/plain": [ - "tensor([0.8941, 0.8636, 0.8621, 0.7490], device='cuda:0')" + "tensor([0.8995, 0.8665, 0.8193, 0.8094], device='cuda:0')" ] }, "execution_count": 8, @@ -162,7 +162,7 @@ { "data": { "text/plain": [ - "tensor([0, 0, 5, 0], device='cuda:0')" + "tensor([0, 0, 0, 5], device='cuda:0')" ] }, "execution_count": 9, @@ -224,17 +224,17 @@ " 'Automatically generated names will be applied to each dynamic axes of input {}'.format(key))\n", "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:3123: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " dtype=torch.float32)).float())) for i in range(dim)]\n", - "/data/wangzq/yolov5-rt-stack/yolort/models/anchor_utils.py:31: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", + "/mnt/yolov5-rt-stack/yolort/models/anchor_utils.py:31: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", " stride = torch.as_tensor([stride], dtype=dtype, device=device)\n", - "/data/wangzq/yolov5-rt-stack/yolort/models/anchor_utils.py:50: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", + "/mnt/yolov5-rt-stack/yolort/models/anchor_utils.py:50: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", " anchor_grid = torch.as_tensor(anchor_grid, dtype=dtype, device=device)\n", - "/data/wangzq/yolov5-rt-stack/yolort/models/anchor_utils.py:77: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", + "/mnt/yolov5-rt-stack/yolort/models/anchor_utils.py:77: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", " shifts = shifts - torch.tensor(0.5, dtype=shifts.dtype, device=device)\n", - "/data/wangzq/yolov5-rt-stack/yolort/models/box_head.py:344: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", + "/mnt/yolov5-rt-stack/yolort/models/box_head.py:363: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", " for idx in range(batch_size): # image idx, image inference\n", - "/data/wangzq/yolov5-rt-stack/yolort/models/transform.py:287: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", + "/mnt/yolov5-rt-stack/yolort/models/transform.py:287: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", " for s, s_orig in zip(new_size, original_size)\n", - "/data/wangzq/yolov5-rt-stack/yolort/models/transform.py:287: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + "/mnt/yolov5-rt-stack/yolort/models/transform.py:287: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " for s, s_orig in zip(new_size, original_size)\n", "/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic_opset9.py:2378: UserWarning: Exporting aten::index operator of advanced indexing in opset 11 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.\n", " \"If indices include negative values, the exported graph will produce incorrect results.\")\n", @@ -264,6 +264,17 @@ "## Simplifier exported `ONNX` model" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Don't forget to install `onnx-simplifier`\n", + "\n", + "```bash\n", + "!pip -U install onnx-simplifier\n", + "```" + ] + }, { "cell_type": "code", "execution_count": 13, @@ -273,7 +284,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Starting simplifing with onnxsim 0.3.1\n" + "Starting simplifing with onnxsim 0.3.2\n" ] } ], @@ -361,6 +372,7 @@ "metadata": {}, "outputs": [], "source": [ + "# ort_session = onnxruntime.InferenceSession(export_onnx_name)\n", "ort_session = onnxruntime.InferenceSession(onnx_simp_name)" ] }, @@ -384,8 +396,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 16 ms, sys: 8 ms, total: 24 ms\n", - "Wall time: 22.4 ms\n" + "CPU times: user 2.33 s, sys: 0 ns, total: 2.33 s\n", + "Wall time: 77.9 ms\n" ] } ], @@ -411,7 +423,7 @@ ], "source": [ "for i in range(0, len(outputs)):\n", - " torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-05, atol=1e-07)\n", + " torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-04, atol=1e-07)\n", "\n", "print(\"Exported model has been tested with ONNXRuntime, and the result looks good!\")" ] diff --git a/notebooks/export-relay-inference-tvm.ipynb b/notebooks/export-relay-inference-tvm.ipynb index b6fd856f..b2a386f9 100644 --- a/notebooks/export-relay-inference-tvm.ipynb +++ b/notebooks/export-relay-inference-tvm.ipynb @@ -4,10 +4,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Compile PyTorch Object Detection Models\n", + "# Compile YOLOv5 Models\n", "\n", - "This article is an introductory tutorial to deploy PyTorch object\n", - "detection models with Relay VM.\n", + "This article is an introductory tutorial to deploy PyTorch YOLOv5 models with Relay VM.\n", "\n", "For us to begin with, PyTorch should be installed.\n", "TorchVision is also required since we will be using it as our model zoo.\n", @@ -75,7 +74,7 @@ }, "outputs": [], "source": [ - "in_size = 300\n", + "in_size = 416\n", "\n", "input_shape = (1, 3, in_size, in_size)\n", "\n", @@ -110,7 +109,7 @@ "source": [ "from yolort.models import yolov5s\n", "\n", - "model_func = yolov5s(pretrained=True)" + "model_func = yolov5s(upstream_version='v4.0', export_friendly=True, pretrained=True)" ] }, { @@ -142,7 +141,7 @@ " anchor_grid = torch.as_tensor(anchor_grid, dtype=dtype, device=device)\n", "/data/wangzq/yolov5-rt-stack/yolort/models/anchor_utils.py:77: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", " shifts = shifts - torch.tensor(0.5, dtype=shifts.dtype, device=device)\n", - "/data/wangzq/yolov5-rt-stack/yolort/models/box_head.py:344: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", + "/data/wangzq/yolov5-rt-stack/yolort/models/box_head.py:363: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", " for idx in range(batch_size): # image idx, image inference\n", "/data/wangzq/yolov5-rt-stack/yolort/models/transform.py:287: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", " for s, s_orig in zip(new_size, original_size)\n", @@ -171,12 +170,12 @@ "data": { "text/plain": [ "graph(%self.1 : __torch__.TraceWrapper,\n", - " %images : Float(1:270000, 3:90000, 300:300, 300:1, requires_grad=0, device=cpu)):\n", - " %4620 : __torch__.yolort.models.yolo_module.YOLOModule = prim::GetAttr[name=\"model\"](%self.1)\n", - " %4999 : (Tensor, Tensor, Tensor) = prim::CallMethod[name=\"forward\"](%4620, %images)\n", - " %4996 : Float(300:4, 4:1, requires_grad=0, device=cpu), %4997 : Float(300:1, requires_grad=0, device=cpu), %4998 : Long(300:1, requires_grad=0, device=cpu) = prim::TupleUnpack(%4999)\n", - " %3728 : (Float(300:4, 4:1, requires_grad=0, device=cpu), Float(300:1, requires_grad=0, device=cpu), Long(300:1, requires_grad=0, device=cpu)) = prim::TupleConstruct(%4996, %4997, %4998)\n", - " return (%3728)" + " %images : Float(1:519168, 3:173056, 416:416, 416:1, requires_grad=0, device=cpu)):\n", + " %4495 : __torch__.yolort.models.yolo_module.YOLOModule = prim::GetAttr[name=\"model\"](%self.1)\n", + " %4874 : (Tensor, Tensor, Tensor) = prim::CallMethod[name=\"forward\"](%4495, %images)\n", + " %4871 : Float(300:4, 4:1, requires_grad=0, device=cpu), %4872 : Float(300:1, requires_grad=0, device=cpu), %4873 : Long(300:1, requires_grad=0, device=cpu) = prim::TupleUnpack(%4874)\n", + " %3611 : (Float(300:4, 4:1, requires_grad=0, device=cpu), Float(300:1, requires_grad=0, device=cpu), Long(300:1, requires_grad=0, device=cpu)) = prim::TupleConstruct(%4871, %4872, %4873)\n", + " return (%3611)" ] }, "execution_count": 6, @@ -201,7 +200,7 @@ "metadata": {}, "outputs": [], "source": [ - "img_path = 'test/assets/bus.jpg'\n", + "img_path = './test/assets/bus.jpg'\n", "\n", "img = cv2.imread(img_path).astype(\"float32\")\n", "img = cv2.resize(img, (in_size, in_size))\n", @@ -360,7 +359,6 @@ }, "outputs": [], "source": [ - "# Dummy run\n", "ctx = tvm.cpu()\n", "vm = VirtualMachine(vm_exec, ctx)\n", "vm.set_input(\"main\", **{input_name: img})\n", @@ -381,15 +379,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 528 ms, sys: 364 ms, total: 892 ms\n", - "Wall time: 22.3 ms\n" + "CPU times: user 684 ms, sys: 832 ms, total: 1.52 s\n", + "Wall time: 39.2 ms\n" ] } ], "source": [ "%%time\n", - "ctx = tvm.cpu()\n", - "vm = VirtualMachine(vm_exec, ctx)\n", "vm.set_input(\"main\", **{input_name: img})\n", "tvm_res = vm.run()" ] @@ -454,4 +450,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +} diff --git a/test/test_onnx.py b/test/test_onnx.py index c6832f3d..f1ebca32 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -10,7 +10,7 @@ import unittest from torchvision.ops._register_onnx_ops import _onnx_opset_version -from yolort.models import yolov5_onnx +from yolort.models import yolov5s, yolov5m @unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable') @@ -19,15 +19,23 @@ class ONNXExporterTester(unittest.TestCase): def setUpClass(cls): torch.manual_seed(123) - def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None, + def run_model(self, model, inputs_list, tolerate_small_mismatch=False, + do_constant_folding=True, dynamic_axes=None, output_names=None, input_names=None): model.eval() onnx_io = io.BytesIO() # export to onnx with the first input - torch.onnx.export(model, inputs_list[0], onnx_io, - do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version, - dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names) + torch.onnx.export( + model, + inputs_list[0], + onnx_io, + do_constant_folding=do_constant_folding, + opset_version=_onnx_opset_version, + dynamic_axes=dynamic_axes, + input_names=input_names, + output_names=output_names, + ) # validate the exported model with onnx runtime for test_inputs in inputs_list: with torch.no_grad(): @@ -89,23 +97,40 @@ def get_test_images(self): image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png" image2 = self.get_image_from_url(url=image_url2, size=(250, 380)) - images = [image] - test_images = [image2] - return images, test_images + images_one = [image] + images_two = [image2] + return images_one, images_two - def test_yolov5s(self): - images, test_images = self.get_test_images() - dummy_image = [torch.ones(3, 100, 100) * 0.3] - model = yolov5_onnx(pretrained=True) + def test_yolov5s_r31(self): + images_one, images_two = self.get_test_images() + images_dummy = [torch.ones(3, 100, 100) * 0.3] + model = yolov5s(upstream_version='v3.1', export_friendly=True, pretrained=True) model.eval() - model(images) + model(images_one) # Test exported model on images of different size, or dummy input - self.run_model(model, [(images,), (test_images,), (dummy_image,)], input_names=["images_tensors"], + self.run_model(model, [(images_one,), (images_two,), (images_dummy,)], input_names=["images_tensors"], output_names=["outputs"], dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, tolerate_small_mismatch=True) # Test exported model for an image with no detections on other images - self.run_model(model, [(dummy_image,), (images,)], input_names=["images_tensors"], + self.run_model(model, [(images_dummy,), (images_one,)], input_names=["images_tensors"], + output_names=["outputs"], + dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, + tolerate_small_mismatch=True) + + def test_yolov5m_r40(self): + images_one, images_two = self.get_test_images() + images_dummy = [torch.ones(3, 100, 100) * 0.3] + model = yolov5m(upstream_version='v4.0', export_friendly=True, pretrained=True) + model.eval() + model(images_one) + # Test exported model on images of different size, or dummy input + self.run_model(model, [(images_one,), (images_two,), (images_dummy,)], input_names=["images_tensors"], + output_names=["outputs"], + dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, + tolerate_small_mismatch=True) + # Test exported model for an image with no detections on other images + self.run_model(model, [(images_dummy,), (images_one,)], input_names=["images_tensors"], output_names=["outputs"], dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, tolerate_small_mismatch=True) diff --git a/yolort/models/__init__.py b/yolort/models/__init__.py index 3a97ecc0..9abc75b9 100644 --- a/yolort/models/__init__.py +++ b/yolort/models/__init__.py @@ -4,45 +4,73 @@ from .common import Conv from .yolo_module import YOLOModule -from ..utils.activations import Hardswish +from ..utils.activations import Hardswish, SiLU +from typing import Any -def yolov5s(**kwargs): - model = YOLOModule(arch="yolov5_darknet_pan_s_r31", **kwargs) - return model +def yolov5s(upstream_version: str ='v3.1', export_friendly: bool = False, **kwargs: Any): + """ + Args: + upstream_version (str): Determine the upstream YOLOv5 version. + export_friendly (bool): Deciding whether to use (ONNX/TVM) export friendly mode. + """ + if upstream_version == 'v3.1': + model = YOLOModule(arch="yolov5_darknet_pan_s_r31", **kwargs) + elif upstream_version == 'v4.0': + model = YOLOModule(arch="yolov5_darknet_pan_s_r40", **kwargs) + else: + raise NotImplementedError("Currently only supports v3.1 and v4.0 versions") + + if export_friendly: + _export_module_friendly(model) -def yolov5m(**kwargs): - model = YOLOModule(arch="yolov5_darknet_pan_m_r31", **kwargs) return model -def yolov5l(**kwargs): - model = YOLOModule(arch="yolov5_darknet_pan_l_r31", **kwargs) - return model +def yolov5m(upstream_version: str ='v3.1', export_friendly: bool = False, **kwargs: Any): + """ + Args: + upstream_version (str): Determine the upstream YOLOv5 version. + export_friendly (bool): Deciding whether to use (ONNX/TVM) export friendly mode. + """ + if upstream_version == 'v3.1': + model = YOLOModule(arch="yolov5_darknet_pan_m_r31", **kwargs) + elif upstream_version == 'v4.0': + model = YOLOModule(arch="yolov5_darknet_pan_m_r40", **kwargs) + else: + raise NotImplementedError("Currently only supports v3.1 and v4.0 versions") + if export_friendly: + _export_module_friendly(model) -def yolov5s_r40(**kwargs): - model = YOLOModule(arch="yolov5_darknet_pan_s_r40", **kwargs) return model -def yolov5m_r40(**kwargs): - model = YOLOModule(arch="yolov5_darknet_pan_m_r40", **kwargs) - return model +def yolov5l(upstream_version: str ='v3.1', export_friendly: bool = False, **kwargs: Any): + """ + Args: + upstream_version (str): Determine the upstream YOLOv5 version. + export_friendly (bool): Deciding whether to use (ONNX/TVM) export friendly mode. + """ + if upstream_version == 'v3.1': + model = YOLOModule(arch="yolov5_darknet_pan_l_r31", **kwargs) + elif upstream_version == 'v4.0': + model = YOLOModule(arch="yolov5_darknet_pan_l_r40", **kwargs) + else: + raise NotImplementedError("Currently only supports v3.1 and v4.0 versions") + if export_friendly: + _export_module_friendly(model) -def yolov5l_r40(**kwargs): - model = YOLOModule(arch="yolov5_darknet_pan_l_r40", **kwargs) return model -def yolov5_onnx(pretrained=False, progress=True, num_classes=80, **kwargs): - - model = yolov5s(pretrained=pretrained, progress=progress, num_classes=num_classes, **kwargs) +def _export_module_friendly(model): for m in model.modules(): m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility - if isinstance(m, Conv) and isinstance(m.act, nn.Hardswish): - m.act = Hardswish() # assign activation - - return model + if isinstance(m, Conv): + if isinstance(m.act, nn.Hardswish): + m.act = Hardswish() # assign activation + if isinstance(m.act, nn.SiLU): + m.act = SiLU() diff --git a/yolort/utils/activations.py b/yolort/utils/activations.py index 9308dcef..9ec9b95b 100644 --- a/yolort/utils/activations.py +++ b/yolort/utils/activations.py @@ -3,14 +3,23 @@ import torch.nn.functional as F -# Swish -class Swish(nn.Module): # - @staticmethod - def forward(x): +class SiLU(nn.Module): + """ + Export-friendly version of nn.SiLU() + + Ref: + """ + def __init__(self) -> None: + super().__init__() + + def forward(self, x): return x * torch.sigmoid(x) -class Hardswish(nn.Module): # export-friendly version of nn.Hardswish() +class Hardswish(nn.Module): + """ + Export-friendly version of nn.Hardswish() + """ def __init__(self): super().__init__()