Skip to content

Commit

Permalink
Add unit-test for #170
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Oct 20, 2021
1 parent fab48d7 commit 1b9047e
Showing 1 changed file with 69 additions and 1 deletion.
70 changes: 69 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import cv2
import numpy as np
import pytest
import torch
from torch import Tensor
from yolort import models
from yolort.models import YOLOv5
from yolort.utils import (
FeatureExtractor,
get_image_from_url,
load_from_ultralytics,
read_image_to_tensor,
)
from yolort.utils.image_utils import box_cxcywh_to_xyxy
from yolort.v5 import letterbox, scale_coords
from yolort.v5 import (
letterbox,
load_yolov5_model,
scale_coords,
non_max_suppression,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -46,6 +53,67 @@ def test_load_from_ultralytics(
assert len(model_info["strides"]) == 4 if use_p6 else 3


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix",
[("yolov5s-VOC", "r4.0", "v5.0", "23818cff")],
)
def test_load_from_ultralytics_voc(
arch: str,
version: str,
upstream_version: str,
hash_prefix: str,
):
img_path = "test/assets/bus.jpg"
checkpoint_path = f"{arch}_{upstream_version}_{hash_prefix}"

base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"

torch.hub.download_url_to_file(
model_url,
checkpoint_path,
hash_prefix=hash_prefix,
)

# Preprocess
img_raw = cv2.imread(img_path)
img = letterbox(img_raw, new_shape=(320, 320))[0]
img = read_image_to_tensor(img)

conf = 0.25
iou = 0.45

# Define YOLOv5 model
model_yolov5 = load_yolov5_model(checkpoint_path)
model_yolov5.conf = conf # confidence threshold (0-1)
model_yolov5.iou = iou # NMS IoU threshold (0-1)
model_yolov5.eval()
with torch.no_grad():
outs = model_yolov5(img[None])[0]
outs = non_max_suppression(outs, conf, iou, agnostic=True)
out_from_yolov5 = outs[0]

# Define yolort model
model_yolort = YOLOv5.load_from_yolov5(
checkpoint_path,
score_thresh=conf,
version=version,
)
model_yolort.eval()
with torch.no_grad():
out_from_yolort = model_yolort(img[None])

torch.testing.assert_allclose(
out_from_yolort[0]['boxes'], out_from_yolov5[:, :4]
)
torch.testing.assert_allclose(
out_from_yolort[0]['scores'], out_from_yolov5[:, 4]
)
torch.testing.assert_allclose(
out_from_yolort[0]['labels'], out_from_yolov5[:, 5].to(dtype=torch.int64)
)


def test_read_image_to_tensor():
N, H, W = 3, 720, 360
img = np.random.randint(0, 255, (H, W, N), dtype="uint8") # As a dummy image
Expand Down

0 comments on commit 1b9047e

Please sign in to comment.