Skip to content

Commit

Permalink
Add torchscript export and model unittest (#8)
Browse files Browse the repository at this point in the history
* init commit

* Add type annotations and candidate models

* Simplify model unittest

* Add type annotations

* Modify default loading ways of the models

* Add torchscript export unittest

* Modify the model unittest names

* Set pretrained to be Ture when test torchscript exporting

* Fix AnchorGenerator unittest
zhiqwang authored Dec 2, 2020
1 parent 703d69e commit ff8f506
Showing 8 changed files with 578 additions and 16 deletions.
2 changes: 1 addition & 1 deletion hubconf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Optional list of dependencies required by the package
dependencies = ['yaml', 'torch', 'torchvision']

from models import yolov5
from models import yolov5, yolov5s, yolov5m, yolov5l
from models import yolov5_onnx
2 changes: 1 addition & 1 deletion models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from torch import nn

from .common import Conv
from .yolo import yolov5
from .yolo import yolov5, yolov5s, yolov5m, yolov5l # noqa

from utils.activations import Hardswish

2 changes: 1 addition & 1 deletion models/box_head.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@ def _sum(x: List[Tensor]) -> Tensor:


class YoloHead(nn.Module):
def __init__(self, in_channels, num_anchors, num_classes): # detection layer
def __init__(self, in_channels: List[int], num_anchors: int, num_classes: int):
super().__init__()
self.num_anchors = num_anchors # anchors
self.num_outputs = num_classes + 5 # number of outputs per anchor
61 changes: 50 additions & 11 deletions models/yolo.py
Original file line number Diff line number Diff line change
@@ -9,12 +9,22 @@
from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.detection.transform import GeneralizedRCNNTransform

from torch.jit.annotations import Tuple, List, Dict, Optional

from .backbone import darknet
from .box_head import YoloHead, SetCriterion, PostProcess
from .anchor_utils import AnchorGenerator

from typing import Tuple, Any, List, Dict, Optional


__all__ = ['yolov5', 'yolov5s', 'yolov5m', 'yolov5l']


model_urls = {
'yolov5s': 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.2.1/yolov5s.pt',
'yolov5m': 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.2.4/yolov5m.pt',
'yolov5l': 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.2.4/yolov5l.pt',
}


class YOLO(nn.Module):
def __init__(
@@ -151,15 +161,14 @@ def forward(
return self.eager_outputs(losses, detections)


model_urls = {
'yolov5s': 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.2.1/yolov5s.pt',
'yolov5m': 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.2.4/yolov5m.pt',
'yolov5l': 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.2.4/yolov5l.pt',
}


def yolov5(cfg_path='yolov5s.yaml', pretrained=False, progress=True,
num_classes=80, pretrained_backbone=True, **kwargs):
def yolov5(
cfg_path: str = 'yolov5s.yaml',
pretrained: bool = False,
progress: bool = True,
num_classes: int = 80,
pretrained_backbone: bool = True,
**kwargs: Any,
) -> YOLO:
"""
Constructs a YOLO model.
@@ -206,3 +215,33 @@ def yolov5(cfg_path='yolov5s.yaml', pretrained=False, progress=True,
state_dict = load_state_dict_from_url(model_urls[Path(cfg_path).stem], progress=progress)
model.load_state_dict(state_dict)
return model


def yolov5s(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> YOLO:
r"""yolov5s model from
`"ultralytics/yolov5" <https://zenodo.org/badge/latestdoi/264818686>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return yolov5('yolov5s.yaml', pretrained, progress, **kwargs)


def yolov5m(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> YOLO:
r"""yolov5m model from
`"ultralytics/yolov5" <https://zenodo.org/badge/latestdoi/264818686>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return yolov5('yolov5m.yaml', pretrained, progress, **kwargs)


def yolov5l(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> YOLO:
r"""yolov5l model from
`"ultralytics/yolov5" <https://zenodo.org/badge/latestdoi/264818686>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return yolov5('yolov5l.yaml', pretrained, progress, **kwargs)
372 changes: 372 additions & 0 deletions test/common_utils.py

Large diffs are not rendered by default.

38 changes: 38 additions & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch

from models.anchor_utils import AnchorGenerator
from .common_utils import TestCase


class ModelTester(TestCase):
def _init_test_anchor_generator(self):
strides = [4]
anchor_grids = [[6, 14]]
anchor_generator = AnchorGenerator(strides, anchor_grids)

return anchor_generator

def get_features(self, images):
s0, s1 = images.shape[-2:]
features = [torch.rand(2, 8, s0 // 5, s1 // 5)]
return features

def test_anchor_generator(self):
images = torch.randn(2, 3, 10, 10)
features = self.get_features(images)

model = self._init_test_anchor_generator()
model.eval()
anchors = model(features)

anchor_output = torch.tensor([[-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5], [0.5, 0.5]])
wh_output = torch.tensor([[4.], [4.], [4.], [4.]])
xy_output = torch.tensor([[6., 14.], [6., 14.], [6., 14.], [6., 14.]])

self.assertEqual(len(anchors), 3)
self.assertEqual(tuple(anchors[0].shape), (4, 2))
self.assertEqual(tuple(anchors[1].shape), (4, 1))
self.assertEqual(tuple(anchors[2].shape), (4, 2))
self.assertEqual(anchors[0], anchor_output)
self.assertEqual(anchors[1], wh_output)
self.assertEqual(anchors[2], xy_output)
113 changes: 113 additions & 0 deletions test/test_torchscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import unittest

import torch

from models.backbone import darknet
from models.anchor_utils import AnchorGenerator
from models.box_head import YoloHead, PostProcess, SetCriterion
from models import yolov5s, yolov5m, yolov5l


class TorchScriptTester(unittest.TestCase):

def _init_test_backbone(self):
backbone = darknet()
return backbone

def test_yolo_backbone_script(self):
model, _ = self._init_test_backbone()
torch.jit.script(model)

def _init_test_anchor_generator(self):
strides = [8, 16, 32]
anchor_grids = [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]]
anchor_generator = AnchorGenerator(strides, anchor_grids)
return anchor_generator

def test_anchor_generator_script(self):
model = self._init_test_anchor_generator()
scripted_model = torch.jit.script(model) # noqa

def _init_test_yolo_head(self):
in_channels = [128, 256, 512]
num_anchors = 3
num_classes = 80
box_head = YoloHead(in_channels, num_anchors, num_classes)
return box_head

def test_yolo_head_script(self):
model = self._init_test_yolo_head()
scripted_model = torch.jit.script(model) # noqa

def _init_test_postprocessors(self):
score_thresh = 0.5
nms_thresh = 0.45
detections_per_img = 100
postprocessors = PostProcess(score_thresh, nms_thresh, detections_per_img)
return postprocessors

def test_postprocessors_script(self):
model = self._init_test_postprocessors()
scripted_model = torch.jit.script(model) # noqa

def _init_test_criterion(self):
weights = (1.0, 1.0, 1.0, 1.0)
fg_iou_thresh = 0.5
bg_iou_thresh = 0.4
allow_low_quality_matches = True
criterion = SetCriterion(weights, fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches)
return criterion

@unittest.skip("Current it isn't well implemented")
def test_criterion_script(self):
model = self._init_test_criterion()
scripted_model = torch.jit.script(model) # noqa

def test_yolov5s_script(self):
model = yolov5s(pretrained=True)
model.eval()

scripted_model = torch.jit.script(model)
scripted_model.eval()

x = [torch.rand(3, 416, 320), torch.rand(3, 480, 352)]

out = model(x)
out_script = scripted_model(x)[1]
self.assertTrue(out[0]["scores"].equal(out_script[0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"]))

def test_yolov5m_script(self):
model = yolov5m(pretrained=True)
model.eval()

scripted_model = torch.jit.script(model)
scripted_model.eval()

x = [torch.rand(3, 416, 320), torch.rand(3, 480, 352)]

out = model(x)
out_script = scripted_model(x)[1]
self.assertTrue(out[0]["scores"].equal(out_script[0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"]))

def test_yolov5l_script(self):
model = yolov5l(pretrained=True)
model.eval()

scripted_model = torch.jit.script(model)
scripted_model.eval()

x = [torch.rand(3, 416, 320), torch.rand(3, 480, 352)]

out = model(x)
out_script = scripted_model(x)[1]
self.assertTrue(out[0]["scores"].equal(out_script[0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"]))


if __name__ == "__main__":
unittest.main()
4 changes: 2 additions & 2 deletions test/tracing/trace_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import torch

from hubconf import yolov5
from hubconf import yolov5s


if __name__ == "__main__":

model = yolov5(pretrained=True)
model = yolov5s(pretrained=True)
model.eval()

traced_model = torch.jit.script(model)

0 comments on commit ff8f506

Please sign in to comment.