-
Notifications
You must be signed in to change notification settings - Fork 152
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add torchscript export and model unittest (#8)
* init commit * Add type annotations and candidate models * Simplify model unittest * Add type annotations * Modify default loading ways of the models * Add torchscript export unittest * Modify the model unittest names * Set pretrained to be Ture when test torchscript exporting * Fix AnchorGenerator unittest
Showing
8 changed files
with
578 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
# Optional list of dependencies required by the package | ||
dependencies = ['yaml', 'torch', 'torchvision'] | ||
|
||
from models import yolov5 | ||
from models import yolov5, yolov5s, yolov5m, yolov5l | ||
from models import yolov5_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import torch | ||
|
||
from models.anchor_utils import AnchorGenerator | ||
from .common_utils import TestCase | ||
|
||
|
||
class ModelTester(TestCase): | ||
def _init_test_anchor_generator(self): | ||
strides = [4] | ||
anchor_grids = [[6, 14]] | ||
anchor_generator = AnchorGenerator(strides, anchor_grids) | ||
|
||
return anchor_generator | ||
|
||
def get_features(self, images): | ||
s0, s1 = images.shape[-2:] | ||
features = [torch.rand(2, 8, s0 // 5, s1 // 5)] | ||
return features | ||
|
||
def test_anchor_generator(self): | ||
images = torch.randn(2, 3, 10, 10) | ||
features = self.get_features(images) | ||
|
||
model = self._init_test_anchor_generator() | ||
model.eval() | ||
anchors = model(features) | ||
|
||
anchor_output = torch.tensor([[-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5], [0.5, 0.5]]) | ||
wh_output = torch.tensor([[4.], [4.], [4.], [4.]]) | ||
xy_output = torch.tensor([[6., 14.], [6., 14.], [6., 14.], [6., 14.]]) | ||
|
||
self.assertEqual(len(anchors), 3) | ||
self.assertEqual(tuple(anchors[0].shape), (4, 2)) | ||
self.assertEqual(tuple(anchors[1].shape), (4, 1)) | ||
self.assertEqual(tuple(anchors[2].shape), (4, 2)) | ||
self.assertEqual(anchors[0], anchor_output) | ||
self.assertEqual(anchors[1], wh_output) | ||
self.assertEqual(anchors[2], xy_output) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import unittest | ||
|
||
import torch | ||
|
||
from models.backbone import darknet | ||
from models.anchor_utils import AnchorGenerator | ||
from models.box_head import YoloHead, PostProcess, SetCriterion | ||
from models import yolov5s, yolov5m, yolov5l | ||
|
||
|
||
class TorchScriptTester(unittest.TestCase): | ||
|
||
def _init_test_backbone(self): | ||
backbone = darknet() | ||
return backbone | ||
|
||
def test_yolo_backbone_script(self): | ||
model, _ = self._init_test_backbone() | ||
torch.jit.script(model) | ||
|
||
def _init_test_anchor_generator(self): | ||
strides = [8, 16, 32] | ||
anchor_grids = [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]] | ||
anchor_generator = AnchorGenerator(strides, anchor_grids) | ||
return anchor_generator | ||
|
||
def test_anchor_generator_script(self): | ||
model = self._init_test_anchor_generator() | ||
scripted_model = torch.jit.script(model) # noqa | ||
|
||
def _init_test_yolo_head(self): | ||
in_channels = [128, 256, 512] | ||
num_anchors = 3 | ||
num_classes = 80 | ||
box_head = YoloHead(in_channels, num_anchors, num_classes) | ||
return box_head | ||
|
||
def test_yolo_head_script(self): | ||
model = self._init_test_yolo_head() | ||
scripted_model = torch.jit.script(model) # noqa | ||
|
||
def _init_test_postprocessors(self): | ||
score_thresh = 0.5 | ||
nms_thresh = 0.45 | ||
detections_per_img = 100 | ||
postprocessors = PostProcess(score_thresh, nms_thresh, detections_per_img) | ||
return postprocessors | ||
|
||
def test_postprocessors_script(self): | ||
model = self._init_test_postprocessors() | ||
scripted_model = torch.jit.script(model) # noqa | ||
|
||
def _init_test_criterion(self): | ||
weights = (1.0, 1.0, 1.0, 1.0) | ||
fg_iou_thresh = 0.5 | ||
bg_iou_thresh = 0.4 | ||
allow_low_quality_matches = True | ||
criterion = SetCriterion(weights, fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches) | ||
return criterion | ||
|
||
@unittest.skip("Current it isn't well implemented") | ||
def test_criterion_script(self): | ||
model = self._init_test_criterion() | ||
scripted_model = torch.jit.script(model) # noqa | ||
|
||
def test_yolov5s_script(self): | ||
model = yolov5s(pretrained=True) | ||
model.eval() | ||
|
||
scripted_model = torch.jit.script(model) | ||
scripted_model.eval() | ||
|
||
x = [torch.rand(3, 416, 320), torch.rand(3, 480, 352)] | ||
|
||
out = model(x) | ||
out_script = scripted_model(x)[1] | ||
self.assertTrue(out[0]["scores"].equal(out_script[0]["scores"])) | ||
self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"])) | ||
self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"])) | ||
|
||
def test_yolov5m_script(self): | ||
model = yolov5m(pretrained=True) | ||
model.eval() | ||
|
||
scripted_model = torch.jit.script(model) | ||
scripted_model.eval() | ||
|
||
x = [torch.rand(3, 416, 320), torch.rand(3, 480, 352)] | ||
|
||
out = model(x) | ||
out_script = scripted_model(x)[1] | ||
self.assertTrue(out[0]["scores"].equal(out_script[0]["scores"])) | ||
self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"])) | ||
self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"])) | ||
|
||
def test_yolov5l_script(self): | ||
model = yolov5l(pretrained=True) | ||
model.eval() | ||
|
||
scripted_model = torch.jit.script(model) | ||
scripted_model.eval() | ||
|
||
x = [torch.rand(3, 416, 320), torch.rand(3, 480, 352)] | ||
|
||
out = model(x) | ||
out_script = scripted_model(x)[1] | ||
self.assertTrue(out[0]["scores"].equal(out_script[0]["scores"])) | ||
self.assertTrue(out[0]["labels"].equal(out_script[0]["labels"])) | ||
self.assertTrue(out[0]["boxes"].equal(out_script[0]["boxes"])) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters