diff --git a/test/test_runtime.py b/test/test_runtime.py index 29407b4e..a404efda 100644 --- a/test/test_runtime.py +++ b/test/test_runtime.py @@ -4,6 +4,7 @@ import pytest import torch from torch import Tensor +from yolort.runtime import YOLOGraphSurgeon from yolort.runtime.trt_helper import YOLOTRTModule from yolort.v5 import attempt_download @@ -43,14 +44,57 @@ def test_yolo_trt_module(arch, version, upstream_version, hash_prefix): ("yolov5n6", "r6.0", "v6.0", "beecbbae"), ], ) -def test_trt_model_onnx_saves(arch, version, upstream_version, hash_prefix): +def test_yolo_trt_module_to_onnx(arch, version, upstream_version, hash_prefix): base_url = "https://github.com/ultralytics/yolov5/releases/download/" model_url = f"{base_url}/{upstream_version}/{arch}.pt" checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix) model = YOLOTRTModule(checkpoint_path, version=version) model.eval() - onnx_file_path = f"trt_model_onnx_saves_{arch}_{hash_prefix}.onnx" + onnx_file_path = f"yolo_trt_module_to_onnx_{arch}_{hash_prefix}.onnx" assert not Path(onnx_file_path).exists() model.to_onnx(onnx_file_path) assert Path(onnx_file_path).exists() + + +@pytest.mark.parametrize( + "arch, version, upstream_version, hash_prefix", + [ + ("yolov5s", "r4.0", "v4.0", "9ca9a642"), + ("yolov5n", "r6.0", "v6.0", "649e089f"), + ("yolov5s", "r6.0", "v6.0", "c3b140f3"), + ("yolov5n6", "r6.0", "v6.0", "beecbbae"), + ], +) +def test_yolo_graphsurgeon_wo_nms(arch, version, upstream_version, hash_prefix): + base_url = "https://github.com/ultralytics/yolov5/releases/download/" + model_url = f"{base_url}/{upstream_version}/{arch}.pt" + checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix) + + yolo_gs = YOLOGraphSurgeon(checkpoint_path, version=version, enable_dynamic=False) + onnx_file_path = f"yolo_graphsurgeon_wo_nms_{arch}_{hash_prefix}.onnx" + assert not Path(onnx_file_path).exists() + yolo_gs.save(onnx_file_path) + assert Path(onnx_file_path).exists() + + +@pytest.mark.parametrize( + "arch, version, upstream_version, hash_prefix", + [ + ("yolov5s", "r4.0", "v4.0", "9ca9a642"), + ("yolov5n", "r6.0", "v6.0", "649e089f"), + ("yolov5s", "r6.0", "v6.0", "c3b140f3"), + ("yolov5n6", "r6.0", "v6.0", "beecbbae"), + ], +) +def test_yolo_graphsurgeon_register_nms(arch, version, upstream_version, hash_prefix): + base_url = "https://github.com/ultralytics/yolov5/releases/download/" + model_url = f"{base_url}/{upstream_version}/{arch}.pt" + checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix) + + yolo_gs = YOLOGraphSurgeon(checkpoint_path, version=version, enable_dynamic=False) + yolo_gs.register_nms() + onnx_file_path = f"yolo_graphsurgeon_register_nms{arch}_{hash_prefix}.onnx" + assert not Path(onnx_file_path).exists() + yolo_gs.save(onnx_file_path) + assert Path(onnx_file_path).exists()