Skip to content

Commit

Permalink
Use test images from repo rather than internet
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Aug 20, 2021
1 parent c4ba869 commit aae4fef
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,24 @@
"""
Test for exporting model to ONNX and inference with ONNXRuntime
"""
import io
import unittest
from typing import List, Tuple

try:
# This import should be before that of torch if you are using PyTorch lower than 1.5.0
# see <https://github.com/onnx/onnx/issues/2394#issuecomment-581638840>
import onnxruntime
except ImportError:
onnxruntime = None
from pathlib import Path
import io
import pytest

import torch
from torch import Tensor
from torchvision.ops._register_onnx_ops import _onnx_opset_version

from yolort.models import yolov5s, yolov5m, yolotr
from yolort.utils import get_image_from_url, read_image_to_tensor

# In environments without onnxruntime we prefer to
# invoke all tests in the repo and have this one skipped rather than fail.
onnxruntime = pytest.importorskip("onnxruntime")

@unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable')
class ONNXExporterTester(unittest.TestCase):

class TestONNXExporter:
@classmethod
def setUpClass(cls):
torch.manual_seed(123)
Expand Down Expand Up @@ -53,10 +52,10 @@ def run_model(self, model, inputs_list, tolerate_small_mismatch=False,
# validate the exported model with onnx runtime
for test_inputs in inputs_list:
with torch.no_grad():
if isinstance(test_inputs, torch.Tensor) or isinstance(test_inputs, list):
if isinstance(test_inputs, Tensor) or isinstance(test_inputs, list):
test_inputs = (test_inputs,)
test_ouputs = model(*test_inputs)
if isinstance(test_ouputs, torch.Tensor):
if isinstance(test_ouputs, Tensor):
test_ouputs = (test_ouputs,)
self.ort_validate(onnx_io, test_inputs, test_ouputs, tolerate_small_mismatch)

Expand Down Expand Up @@ -88,18 +87,19 @@ def to_numpy(tensor):
else:
raise

def get_test_images(self):
image_url = "https://github.com/ultralytics/yolov5/raw/master/data/images/bus.jpg"
image = get_image_from_url(image_url)
image = read_image_to_tensor(image, is_half=False)
def get_image(self, rel_path, size) -> Tensor:
from PIL import Image
from torchvision import transforms
data_path = Path(__file__).parent.resolve() / "assets"

img_path = data_path / rel_path
image = Image.open(img_path).convert("RGB").resize(size, Image.BILINEAR)

image_url2 = "https://github.com/ultralytics/yolov5/raw/master/data/images/zidane.jpg"
image2 = get_image_from_url(image_url2)
image2 = read_image_to_tensor(image2, is_half=False)
return transforms.ToTensor()(image)

images_one = [image]
images_two = [image2]
return images_one, images_two
def get_test_images(self) -> Tuple[List[Tensor], List[Tensor]]:
return ([self.get_image("bus.jpg", (416, 320))],
[self.get_image("zidane.png", (352, 480))])

def test_yolov5s_r31(self):
images_one, images_two = self.get_test_images()
Expand Down

0 comments on commit aae4fef

Please sign in to comment.