-
Notifications
You must be signed in to change notification settings - Fork 152
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
242 additions
and
238 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,113 @@ | ||
# Copyright (c) 2021, yolort team. All Rights Reserved. | ||
# Copyright (c) 2021, yolort team. All rights reserved. | ||
|
||
import pytest | ||
import torch | ||
|
||
from pathlib import Path | ||
from torch import Tensor | ||
from torch.jit._trace import TopLevelTracedModule | ||
from yolort.models import yolov5s | ||
from yolort.relaying import get_trace_module | ||
from yolort.relaying import get_trace_module, YOLOTRTGraphSurgeon | ||
|
||
from yolort.relaying.yolo_inference import YOLOInference | ||
from yolort.v5 import attempt_download | ||
|
||
|
||
def test_get_trace_module(): | ||
@pytest.mark.parametrize("h", [320, 416, 640]) | ||
@pytest.mark.parametrize("w", [320, 416, 640]) | ||
def test_get_trace_module(h, w): | ||
model_func = yolov5s(pretrained=True) | ||
script_module = get_trace_module(model_func, input_shape=(416, 320)) | ||
script_module = get_trace_module(model_func, input_shape=(h, w)) | ||
assert isinstance(script_module, TopLevelTracedModule) | ||
assert script_module.code is not None | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"arch, version, upstream_version, hash_prefix", | ||
[ | ||
("yolov5s", "r4.0", "v4.0", "9ca9a642"), | ||
("yolov5n", "r6.0", "v6.0", "649e089f"), | ||
("yolov5s", "r6.0", "v6.0", "c3b140f3"), | ||
("yolov5n6", "r6.0", "v6.0", "beecbbae"), | ||
], | ||
) | ||
def test_yolo_trt_module(arch, version, upstream_version, hash_prefix): | ||
|
||
base_url = "https://github.com/ultralytics/yolov5/releases/download/" | ||
model_url = f"{base_url}/{upstream_version}/{arch}.pt" | ||
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix) | ||
|
||
model = YOLOInference(checkpoint_path, version=version) | ||
model.eval() | ||
samples = torch.rand(1, 3, 320, 320) | ||
outs = model(samples) | ||
|
||
assert isinstance(outs, tuple) | ||
assert len(outs) == 2 | ||
assert isinstance(outs[0], Tensor) | ||
assert isinstance(outs[1], Tensor) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"arch, version, upstream_version, hash_prefix", | ||
[ | ||
("yolov5s", "r4.0", "v4.0", "9ca9a642"), | ||
("yolov5n", "r6.0", "v6.0", "649e089f"), | ||
("yolov5s", "r6.0", "v6.0", "c3b140f3"), | ||
("yolov5n6", "r6.0", "v6.0", "beecbbae"), | ||
], | ||
) | ||
def test_yolo_trt_module_to_onnx(arch, version, upstream_version, hash_prefix): | ||
base_url = "https://github.com/ultralytics/yolov5/releases/download/" | ||
model_url = f"{base_url}/{upstream_version}/{arch}.pt" | ||
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix) | ||
|
||
model = YOLOInference(checkpoint_path, version=version) | ||
model.eval() | ||
onnx_file_path = f"yolo_trt_module_to_onnx_{arch}_{hash_prefix}.onnx" | ||
assert not Path(onnx_file_path).exists() | ||
model.to_onnx(onnx_file_path) | ||
assert Path(onnx_file_path).exists() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"arch, version, upstream_version, hash_prefix", | ||
[ | ||
("yolov5s", "r4.0", "v4.0", "9ca9a642"), | ||
("yolov5n", "r6.0", "v6.0", "649e089f"), | ||
("yolov5s", "r6.0", "v6.0", "c3b140f3"), | ||
("yolov5n6", "r6.0", "v6.0", "beecbbae"), | ||
], | ||
) | ||
def test_yolo_graphsurgeon_wo_nms(arch, version, upstream_version, hash_prefix): | ||
base_url = "https://github.com/ultralytics/yolov5/releases/download/" | ||
model_url = f"{base_url}/{upstream_version}/{arch}.pt" | ||
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix) | ||
|
||
yolo_gs = YOLOTRTGraphSurgeon(checkpoint_path, version=version, enable_dynamic=False) | ||
onnx_file_path = f"yolo_graphsurgeon_wo_nms_{arch}_{hash_prefix}.onnx" | ||
assert not Path(onnx_file_path).exists() | ||
yolo_gs.save(onnx_file_path) | ||
assert Path(onnx_file_path).exists() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"arch, version, upstream_version, hash_prefix", | ||
[ | ||
("yolov5s", "r4.0", "v4.0", "9ca9a642"), | ||
("yolov5n", "r6.0", "v6.0", "649e089f"), | ||
("yolov5s", "r6.0", "v6.0", "c3b140f3"), | ||
("yolov5n6", "r6.0", "v6.0", "beecbbae"), | ||
], | ||
) | ||
def test_yolo_graphsurgeon_register_nms(arch, version, upstream_version, hash_prefix): | ||
base_url = "https://github.com/ultralytics/yolov5/releases/download/" | ||
model_url = f"{base_url}/{upstream_version}/{arch}.pt" | ||
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix) | ||
|
||
yolo_gs = YOLOTRTGraphSurgeon(checkpoint_path, version=version, enable_dynamic=False) | ||
yolo_gs.register_nms() | ||
onnx_file_path = f"yolo_graphsurgeon_register_nms{arch}_{hash_prefix}.onnx" | ||
assert not Path(onnx_file_path).exists() | ||
yolo_gs.save(onnx_file_path) | ||
assert Path(onnx_file_path).exists() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,100 +0,0 @@ | ||
# Copyright (c) 2021, yolort team. All Rights Reserved. | ||
from pathlib import Path | ||
|
||
import pytest | ||
import torch | ||
from torch import Tensor | ||
from yolort.runtime import YOLOGraphSurgeon | ||
from yolort.runtime.trt_helper import YOLOTRTModule | ||
from yolort.v5 import attempt_download | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"arch, version, upstream_version, hash_prefix", | ||
[ | ||
("yolov5s", "r4.0", "v4.0", "9ca9a642"), | ||
("yolov5n", "r6.0", "v6.0", "649e089f"), | ||
("yolov5s", "r6.0", "v6.0", "c3b140f3"), | ||
("yolov5n6", "r6.0", "v6.0", "beecbbae"), | ||
], | ||
) | ||
def test_yolo_trt_module(arch, version, upstream_version, hash_prefix): | ||
|
||
base_url = "https://github.com/ultralytics/yolov5/releases/download/" | ||
model_url = f"{base_url}/{upstream_version}/{arch}.pt" | ||
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix) | ||
|
||
model = YOLOTRTModule(checkpoint_path, version=version) | ||
model.eval() | ||
samples = torch.rand(1, 3, 320, 320) | ||
outs = model(samples) | ||
|
||
assert isinstance(outs, tuple) | ||
assert len(outs) == 2 | ||
assert isinstance(outs[0], Tensor) | ||
assert isinstance(outs[1], Tensor) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"arch, version, upstream_version, hash_prefix", | ||
[ | ||
("yolov5s", "r4.0", "v4.0", "9ca9a642"), | ||
("yolov5n", "r6.0", "v6.0", "649e089f"), | ||
("yolov5s", "r6.0", "v6.0", "c3b140f3"), | ||
("yolov5n6", "r6.0", "v6.0", "beecbbae"), | ||
], | ||
) | ||
def test_yolo_trt_module_to_onnx(arch, version, upstream_version, hash_prefix): | ||
base_url = "https://github.com/ultralytics/yolov5/releases/download/" | ||
model_url = f"{base_url}/{upstream_version}/{arch}.pt" | ||
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix) | ||
|
||
model = YOLOTRTModule(checkpoint_path, version=version) | ||
model.eval() | ||
onnx_file_path = f"yolo_trt_module_to_onnx_{arch}_{hash_prefix}.onnx" | ||
assert not Path(onnx_file_path).exists() | ||
model.to_onnx(onnx_file_path) | ||
assert Path(onnx_file_path).exists() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"arch, version, upstream_version, hash_prefix", | ||
[ | ||
("yolov5s", "r4.0", "v4.0", "9ca9a642"), | ||
("yolov5n", "r6.0", "v6.0", "649e089f"), | ||
("yolov5s", "r6.0", "v6.0", "c3b140f3"), | ||
("yolov5n6", "r6.0", "v6.0", "beecbbae"), | ||
], | ||
) | ||
def test_yolo_graphsurgeon_wo_nms(arch, version, upstream_version, hash_prefix): | ||
base_url = "https://github.com/ultralytics/yolov5/releases/download/" | ||
model_url = f"{base_url}/{upstream_version}/{arch}.pt" | ||
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix) | ||
|
||
yolo_gs = YOLOGraphSurgeon(checkpoint_path, version=version, enable_dynamic=False) | ||
onnx_file_path = f"yolo_graphsurgeon_wo_nms_{arch}_{hash_prefix}.onnx" | ||
assert not Path(onnx_file_path).exists() | ||
yolo_gs.save(onnx_file_path) | ||
assert Path(onnx_file_path).exists() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"arch, version, upstream_version, hash_prefix", | ||
[ | ||
("yolov5s", "r4.0", "v4.0", "9ca9a642"), | ||
("yolov5n", "r6.0", "v6.0", "649e089f"), | ||
("yolov5s", "r6.0", "v6.0", "c3b140f3"), | ||
("yolov5n6", "r6.0", "v6.0", "beecbbae"), | ||
], | ||
) | ||
def test_yolo_graphsurgeon_register_nms(arch, version, upstream_version, hash_prefix): | ||
base_url = "https://github.com/ultralytics/yolov5/releases/download/" | ||
model_url = f"{base_url}/{upstream_version}/{arch}.pt" | ||
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix) | ||
|
||
yolo_gs = YOLOGraphSurgeon(checkpoint_path, version=version, enable_dynamic=False) | ||
yolo_gs.register_nms() | ||
onnx_file_path = f"yolo_graphsurgeon_register_nms{arch}_{hash_prefix}.onnx" | ||
assert not Path(onnx_file_path).exists() | ||
yolo_gs.save(onnx_file_path) | ||
assert Path(onnx_file_path).exists() | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
# Copyright (c) 2021, yolort team. All Rights Reserved. | ||
# Copyright (c) 2021, yolort team. All rights reserved. | ||
|
||
from .trace_wrapper import get_trace_module | ||
from .trt_graphsurgeon import YOLOTRTGraphSurgeon | ||
|
||
__all__ = ["get_trace_module"] | ||
__all__ = ["get_trace_module", "YOLOTRTGraphSurgeon"] |
1 change: 1 addition & 0 deletions
1
yolort/runtime/logits_decoder.py → yolort/relaying/logits_decoder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# Copyright (c) 2021, yolort team. All rights reserved. | ||
|
||
from typing import List, Tuple | ||
|
||
import torch | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# Copyright (c) 2021, yolort team. All rights reserved. | ||
|
||
from pathlib import PosixPath | ||
from typing import Optional, Tuple, Union | ||
|
||
import torch | ||
from torch import nn, Tensor | ||
from yolort.models import YOLO | ||
from yolort.models.anchor_utils import AnchorGenerator | ||
from yolort.models.backbone_utils import darknet_pan_backbone | ||
from yolort.utils import load_from_ultralytics | ||
|
||
from .logits_decoder import LogitsDecoder | ||
|
||
__all__ = ["YOLOInference"] | ||
|
||
|
||
class YOLOInference(nn.Module): | ||
""" | ||
TensorRT deployment friendly wrapper for YOLO. | ||
Remove the ``torchvision::nms`` in this warpper, due to the fact that some third-party | ||
inference frameworks currently do not support this operator very well. | ||
Args: | ||
checkpoint_path (string): Path of the trained YOLOv5 checkpoint. | ||
version (string): Upstream YOLOv5 version. Default: 'r6.0' | ||
""" | ||
|
||
def __init__(self, checkpoint_path: str, version: str = "r6.0"): | ||
super().__init__() | ||
model_info = load_from_ultralytics(checkpoint_path, version=version) | ||
|
||
backbone_name = f"darknet_{model_info['size']}_{version.replace('.', '_')}" | ||
depth_multiple = model_info["depth_multiple"] | ||
width_multiple = model_info["width_multiple"] | ||
use_p6 = model_info["use_p6"] | ||
backbone = darknet_pan_backbone( | ||
backbone_name, | ||
depth_multiple, | ||
width_multiple, | ||
version=version, | ||
use_p6=use_p6, | ||
) | ||
num_classes = model_info["num_classes"] | ||
anchor_generator = AnchorGenerator(model_info["strides"], model_info["anchor_grids"]) | ||
post_process = LogitsDecoder(model_info["strides"]) | ||
model = YOLO( | ||
backbone, | ||
num_classes, | ||
anchor_generator=anchor_generator, | ||
post_process=post_process, | ||
) | ||
|
||
model.load_state_dict(model_info["state_dict"]) | ||
self.model = model | ||
self.num_classes = num_classes | ||
|
||
@torch.no_grad() | ||
def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor]: | ||
""" | ||
Args: | ||
inputs (Tensor): batched images, of shape [batch_size x 3 x H x W] | ||
""" | ||
# Compute the detections | ||
outputs = self.model(inputs) | ||
|
||
return outputs | ||
|
||
@torch.no_grad() | ||
def to_onnx( | ||
self, | ||
file_path: Union[str, PosixPath], | ||
input_sample: Optional[Tensor] = None, | ||
opset_version: int = 11, | ||
enable_dynamic: bool = True, | ||
**kwargs, | ||
): | ||
""" | ||
Saves the model in ONNX format. | ||
Args: | ||
file_path (Union[string, PosixPath]): The path of the file the onnx model should | ||
be saved to. | ||
input_sample (Tensor, Optional): An input for tracing. Default: None. | ||
opset_version (int): Opset version we export the model to the onnx submodule. Default: 11. | ||
enable_dynamic (bool): Whether to specify axes of tensors as dynamic. Default: True. | ||
**kwargs: Will be passed to torch.onnx.export function. | ||
""" | ||
if input_sample is None: | ||
input_sample = torch.rand(1, 3, 640, 640).to(next(self.parameters()).device) | ||
|
||
dynamic_axes = ( | ||
{ | ||
"images": {0: "batch", 2: "height", 3: "width"}, | ||
"boxes": {0: "batch", 1: "num_objects"}, | ||
"scores": {0: "batch", 1: "num_objects"}, | ||
} | ||
if enable_dynamic | ||
else None | ||
) | ||
|
||
input_names = ["images"] | ||
output_names = ["boxes", "scores"] | ||
|
||
torch.onnx.export( | ||
self.model, | ||
input_sample, | ||
file_path, | ||
do_constant_folding=True, | ||
opset_version=opset_version, | ||
input_names=input_names, | ||
output_names=output_names, | ||
dynamic_axes=dynamic_axes, | ||
**kwargs, | ||
) |
Oops, something went wrong.