diff --git a/test/test_models.py b/test/test_models.py index 1b55ecb2..0ef183d5 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -12,6 +12,7 @@ from yolort.models.backbone_utils import darknet_pan_backbone from yolort.models.box_head import YOLOHead, PostProcess, SetCriterion from yolort.models.transformer import darknet_tan_backbone +from yolort.models.yolo_lite import yolov5_mobilenet_v3_small_fpn from yolort.v5 import get_yolov5_size, attempt_download @@ -420,3 +421,18 @@ def test_load_from_yolov5_torchscript(arch, size_divisible, version, upstream_ve torch.testing.assert_close(out[0]["scores"], out_script[1][0]["scores"], rtol=0, atol=0) torch.testing.assert_close(out[0]["labels"], out_script[1][0]["labels"], rtol=0, atol=0) torch.testing.assert_close(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0, atol=0) + + +def test_yolov5_mobilenet_v3_small_fpn(): + + model = yolov5_mobilenet_v3_small_fpn() + model = model.eval() + + images = torch.rand(4, 3, 320, 320) + out = model(images) + assert isinstance(out, list) + assert len(out) == 4 + assert isinstance(out[0], dict) + assert isinstance(out[0]["boxes"], Tensor) + assert isinstance(out[0]["labels"], Tensor) + assert isinstance(out[0]["scores"], Tensor)