diff --git a/test/test_onnx.py b/test/test_onnx.py index cacc3265..87122819 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -2,25 +2,24 @@ """ Test for exporting model to ONNX and inference with ONNXRuntime """ -import io -import unittest +from typing import List, Tuple -try: - # This import should be before that of torch if you are using PyTorch lower than 1.5.0 - # see - import onnxruntime -except ImportError: - onnxruntime = None +from pathlib import Path +import io +import pytest import torch +from torch import Tensor from torchvision.ops._register_onnx_ops import _onnx_opset_version from yolort.models import yolov5s, yolov5m, yolotr -from yolort.utils import get_image_from_url, read_image_to_tensor +# In environments without onnxruntime we prefer to +# invoke all tests in the repo and have this one skipped rather than fail. +onnxruntime = pytest.importorskip("onnxruntime") -@unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable') -class ONNXExporterTester(unittest.TestCase): + +class TestONNXExporter: @classmethod def setUpClass(cls): torch.manual_seed(123) @@ -53,10 +52,10 @@ def run_model(self, model, inputs_list, tolerate_small_mismatch=False, # validate the exported model with onnx runtime for test_inputs in inputs_list: with torch.no_grad(): - if isinstance(test_inputs, torch.Tensor) or isinstance(test_inputs, list): + if isinstance(test_inputs, Tensor) or isinstance(test_inputs, list): test_inputs = (test_inputs,) test_ouputs = model(*test_inputs) - if isinstance(test_ouputs, torch.Tensor): + if isinstance(test_ouputs, Tensor): test_ouputs = (test_ouputs,) self.ort_validate(onnx_io, test_inputs, test_ouputs, tolerate_small_mismatch) @@ -88,18 +87,19 @@ def to_numpy(tensor): else: raise - def get_test_images(self): - image_url = "https://github.com/ultralytics/yolov5/raw/master/data/images/bus.jpg" - image = get_image_from_url(image_url) - image = read_image_to_tensor(image, is_half=False) + def get_image(self, rel_path, size) -> Tensor: + from PIL import Image + from torchvision import transforms + data_path = Path(__file__).parent.resolve() / "assets" + + img_path = data_path / rel_path + image = Image.open(img_path).convert("RGB").resize(size, Image.BILINEAR) - image_url2 = "https://github.com/ultralytics/yolov5/raw/master/data/images/zidane.jpg" - image2 = get_image_from_url(image_url2) - image2 = read_image_to_tensor(image2, is_half=False) + return transforms.ToTensor()(image) - images_one = [image] - images_two = [image2] - return images_one, images_two + def get_test_images(self) -> Tuple[List[Tensor], List[Tensor]]: + return ([self.get_image("bus.jpg", (416, 320))], + [self.get_image("zidane.png", (352, 480))]) def test_yolov5s_r31(self): images_one, images_two = self.get_test_images()