Skip to content

Commit

Permalink
Fixing unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Aug 20, 2021
1 parent 941a9b0 commit 191f056
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions test/test_torchscript.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved.
import unittest

import torch

from yolort.models import yolov5s, yolov5m, yolov5l, yolotr


class TorchScriptTester(unittest.TestCase):
class TestTorchScript:
def test_yolov5s_script(self):
model = yolov5s(pretrained=True)
model = yolov5s(pretrained=True, score_thresh=0.45)
model.eval()

scripted_model = torch.jit.script(model)
Expand All @@ -18,12 +16,13 @@ def test_yolov5s_script(self):

out = model(x)
out_script = scripted_model(x)
self.assertTrue(out[0]["scores"].equal(out_script[1][0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[1][0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[1][0]["boxes"]))

torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.)

def test_yolov5m_script(self):
model = yolov5m(pretrained=True)
model = yolov5m(pretrained=True, score_thresh=0.45)
model.eval()

scripted_model = torch.jit.script(model)
Expand All @@ -33,12 +32,12 @@ def test_yolov5m_script(self):

out = model(x)
out_script = scripted_model(x)
self.assertTrue(out[0]["scores"].equal(out_script[1][0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[1][0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[1][0]["boxes"]))
torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.)

def test_yolov5l_script(self):
model = yolov5l(pretrained=True)
model = yolov5l(pretrained=True, score_thresh=0.45)
model.eval()

scripted_model = torch.jit.script(model)
Expand All @@ -48,12 +47,12 @@ def test_yolov5l_script(self):

out = model(x)
out_script = scripted_model(x)
self.assertTrue(out[0]["scores"].equal(out_script[1][0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[1][0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[1][0]["boxes"]))
torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.)

def test_yolotr_script(self):
model = yolotr(pretrained=True)
model = yolotr(pretrained=True, score_thresh=0.45)
model.eval()

scripted_model = torch.jit.script(model)
Expand All @@ -63,6 +62,6 @@ def test_yolotr_script(self):

out = model(x)
out_script = scripted_model(x)
self.assertTrue(out[0]["scores"].equal(out_script[1][0]["scores"]))
self.assertTrue(out[0]["labels"].equal(out_script[1][0]["labels"]))
self.assertTrue(out[0]["boxes"].equal(out_script[1][0]["boxes"]))
torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.)

0 comments on commit 191f056

Please sign in to comment.