Skip to content

Commit

Permalink
Fixing unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Aug 20, 2021
1 parent aae4fef commit 62a2fa5
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 51 deletions.
11 changes: 5 additions & 6 deletions test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from pathlib import Path
import io
import pytest
from PIL import Image

import torch
from torch import Tensor
from torchvision import transforms
from torchvision.ops._register_onnx_ops import _onnx_opset_version

from yolort.models import yolov5s, yolov5m, yolotr
Expand Down Expand Up @@ -87,19 +89,16 @@ def to_numpy(tensor):
else:
raise

def get_image(self, rel_path, size) -> Tensor:
from PIL import Image
from torchvision import transforms
data_path = Path(__file__).parent.resolve() / "assets"
def get_image(self, img_name, size) -> Tensor:

img_path = data_path / rel_path
img_path = Path(__file__).parent.resolve() / "assets" / img_name
image = Image.open(img_path).convert("RGB").resize(size, Image.BILINEAR)

return transforms.ToTensor()(image)

def get_test_images(self) -> Tuple[List[Tensor], List[Tensor]]:
return ([self.get_image("bus.jpg", (416, 320))],
[self.get_image("zidane.png", (352, 480))])
[self.get_image("zidane.jpg", (352, 480))])

def test_yolov5s_r31(self):
images_one, images_two = self.get_test_images()
Expand Down
92 changes: 47 additions & 45 deletions test/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,64 +4,66 @@
from yolort.models import yolov5s, yolov5m, yolov5l, yolotr


class TestTorchScript:
def test_yolov5s_script(self):
model = yolov5s(pretrained=True, size=(640, 640), score_thresh=0.45)
model.eval()
def test_yolov5s_script():
model = yolov5s(pretrained=True, size=(320, 320), score_thresh=0.45)
model.eval()

scripted_model = torch.jit.script(model)
scripted_model.eval()
scripted_model = torch.jit.script(model)
scripted_model.eval()

x = [torch.rand(3, 416, 320), torch.rand(3, 480, 352)]
x = [torch.rand(3, 288, 320), torch.rand(3, 300, 256)]

out = model(x)
out_script = scripted_model(x)
out = model(x)
out_script = scripted_model(x)

torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.)

def test_yolov5m_script(self):
model = yolov5m(pretrained=True, size=(640, 640), score_thresh=0.45)
model.eval()

scripted_model = torch.jit.script(model)
scripted_model.eval()
def test_yolov5m_script():
model = yolov5m(pretrained=True, size=(320, 320), score_thresh=0.45)
model.eval()

x = [torch.rand(3, 416, 320), torch.rand(3, 480, 352)]
scripted_model = torch.jit.script(model)
scripted_model.eval()

out = model(x)
out_script = scripted_model(x)
torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.)
x = [torch.rand(3, 288, 320), torch.rand(3, 300, 256)]

def test_yolov5l_script(self):
model = yolov5l(pretrained=True, size=(640, 640), score_thresh=0.45)
model.eval()
out = model(x)
out_script = scripted_model(x)
torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.)

scripted_model = torch.jit.script(model)
scripted_model.eval()

x = [torch.rand(3, 416, 320), torch.rand(3, 480, 352)]
def test_yolov5l_script():
model = yolov5l(pretrained=True, size=(320, 320), score_thresh=0.45)
model.eval()

out = model(x)
out_script = scripted_model(x)
torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.)
scripted_model = torch.jit.script(model)
scripted_model.eval()

def test_yolotr_script(self):
model = yolotr(pretrained=True, size=(640, 640), score_thresh=0.45)
model.eval()
x = [torch.rand(3, 288, 320), torch.rand(3, 300, 256)]

scripted_model = torch.jit.script(model)
scripted_model.eval()
out = model(x)
out_script = scripted_model(x)
torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.)

x = [torch.rand(3, 416, 320), torch.rand(3, 480, 352)]

out = model(x)
out_script = scripted_model(x)
torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.)
def test_yolotr_script():
model = yolotr(pretrained=True, size=(320, 320), score_thresh=0.45)
model.eval()

scripted_model = torch.jit.script(model)
scripted_model.eval()

x = [torch.rand(3, 288, 320), torch.rand(3, 300, 256)]

out = model(x)
out_script = scripted_model(x)
torch.testing.assert_allclose(out[0]["scores"], out_script[1][0]["scores"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["labels"], out_script[1][0]["labels"], rtol=0., atol=0.)
torch.testing.assert_allclose(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0., atol=0.)

0 comments on commit 62a2fa5

Please sign in to comment.