diff --git a/test/test_torchscript.py b/test/test_torchscript.py index 7fe397863..b03ba5239 100644 --- a/test/test_torchscript.py +++ b/test/test_torchscript.py @@ -1,14 +1,12 @@ # Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved. -import unittest - import torch from yolort.models import yolov5s, yolov5m, yolov5l, yolotr -class TorchScriptTester(unittest.TestCase): +class TestTorchScript: def test_yolov5s_script(self): - model = yolov5s(pretrained=True) + model = yolov5s(pretrained=True, score_thresh=0.45) model.eval() scripted_model = torch.jit.script(model) @@ -18,12 +16,13 @@ def test_yolov5s_script(self): out = model(x) out_script = scripted_model(x) - self.assertTrue(out[0]["scores"].equal(out_script[1][0]["scores"])) - self.assertTrue(out[0]["labels"].equal(out_script[1][0]["labels"])) - self.assertTrue(out[0]["boxes"].equal(out_script[1][0]["boxes"])) + + torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.) + torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.) + torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.) def test_yolov5m_script(self): - model = yolov5m(pretrained=True) + model = yolov5m(pretrained=True, score_thresh=0.45) model.eval() scripted_model = torch.jit.script(model) @@ -33,12 +32,12 @@ def test_yolov5m_script(self): out = model(x) out_script = scripted_model(x) - self.assertTrue(out[0]["scores"].equal(out_script[1][0]["scores"])) - self.assertTrue(out[0]["labels"].equal(out_script[1][0]["labels"])) - self.assertTrue(out[0]["boxes"].equal(out_script[1][0]["boxes"])) + torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.) + torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.) + torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.) def test_yolov5l_script(self): - model = yolov5l(pretrained=True) + model = yolov5l(pretrained=True, score_thresh=0.45) model.eval() scripted_model = torch.jit.script(model) @@ -48,12 +47,12 @@ def test_yolov5l_script(self): out = model(x) out_script = scripted_model(x) - self.assertTrue(out[0]["scores"].equal(out_script[1][0]["scores"])) - self.assertTrue(out[0]["labels"].equal(out_script[1][0]["labels"])) - self.assertTrue(out[0]["boxes"].equal(out_script[1][0]["boxes"])) + torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.) + torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.) + torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.) def test_yolotr_script(self): - model = yolotr(pretrained=True) + model = yolotr(pretrained=True, score_thresh=0.45) model.eval() scripted_model = torch.jit.script(model) @@ -63,6 +62,6 @@ def test_yolotr_script(self): out = model(x) out_script = scripted_model(x) - self.assertTrue(out[0]["scores"].equal(out_script[1][0]["scores"])) - self.assertTrue(out[0]["labels"].equal(out_script[1][0]["labels"])) - self.assertTrue(out[0]["boxes"].equal(out_script[1][0]["boxes"])) + torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.) + torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.) + torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.)