Skip to content

Commit

Permalink
Refactor TensorRT utilization tools
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Feb 12, 2022
1 parent 5440c37 commit 4283c96
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 238 deletions.
110 changes: 106 additions & 4 deletions test/test_relaying.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,113 @@
# Copyright (c) 2021, yolort team. All Rights Reserved.
# Copyright (c) 2021, yolort team. All rights reserved.

import pytest
import torch

from pathlib import Path
from torch import Tensor
from torch.jit._trace import TopLevelTracedModule
from yolort.models import yolov5s
from yolort.relaying import get_trace_module
from yolort.relaying import get_trace_module, YOLOTRTGraphSurgeon

from yolort.relaying.yolo_inference import YOLOInference
from yolort.v5 import attempt_download


def test_get_trace_module():
@pytest.mark.parametrize("h", [320, 416, 640])
@pytest.mark.parametrize("w", [320, 416, 640])
def test_get_trace_module(h, w):
model_func = yolov5s(pretrained=True)
script_module = get_trace_module(model_func, input_shape=(416, 320))
script_module = get_trace_module(model_func, input_shape=(h, w))
assert isinstance(script_module, TopLevelTracedModule)
assert script_module.code is not None


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix",
[
("yolov5s", "r4.0", "v4.0", "9ca9a642"),
("yolov5n", "r6.0", "v6.0", "649e089f"),
("yolov5s", "r6.0", "v6.0", "c3b140f3"),
("yolov5n6", "r6.0", "v6.0", "beecbbae"),
],
)
def test_yolo_trt_module(arch, version, upstream_version, hash_prefix):

base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

model = YOLOInference(checkpoint_path, version=version)
model.eval()
samples = torch.rand(1, 3, 320, 320)
outs = model(samples)

assert isinstance(outs, tuple)
assert len(outs) == 2
assert isinstance(outs[0], Tensor)
assert isinstance(outs[1], Tensor)


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix",
[
("yolov5s", "r4.0", "v4.0", "9ca9a642"),
("yolov5n", "r6.0", "v6.0", "649e089f"),
("yolov5s", "r6.0", "v6.0", "c3b140f3"),
("yolov5n6", "r6.0", "v6.0", "beecbbae"),
],
)
def test_yolo_trt_module_to_onnx(arch, version, upstream_version, hash_prefix):
base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

model = YOLOInference(checkpoint_path, version=version)
model.eval()
onnx_file_path = f"yolo_trt_module_to_onnx_{arch}_{hash_prefix}.onnx"
assert not Path(onnx_file_path).exists()
model.to_onnx(onnx_file_path)
assert Path(onnx_file_path).exists()


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix",
[
("yolov5s", "r4.0", "v4.0", "9ca9a642"),
("yolov5n", "r6.0", "v6.0", "649e089f"),
("yolov5s", "r6.0", "v6.0", "c3b140f3"),
("yolov5n6", "r6.0", "v6.0", "beecbbae"),
],
)
def test_yolo_graphsurgeon_wo_nms(arch, version, upstream_version, hash_prefix):
base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

yolo_gs = YOLOTRTGraphSurgeon(checkpoint_path, version=version, enable_dynamic=False)
onnx_file_path = f"yolo_graphsurgeon_wo_nms_{arch}_{hash_prefix}.onnx"
assert not Path(onnx_file_path).exists()
yolo_gs.save(onnx_file_path)
assert Path(onnx_file_path).exists()


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix",
[
("yolov5s", "r4.0", "v4.0", "9ca9a642"),
("yolov5n", "r6.0", "v6.0", "649e089f"),
("yolov5s", "r6.0", "v6.0", "c3b140f3"),
("yolov5n6", "r6.0", "v6.0", "beecbbae"),
],
)
def test_yolo_graphsurgeon_register_nms(arch, version, upstream_version, hash_prefix):
base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

yolo_gs = YOLOTRTGraphSurgeon(checkpoint_path, version=version, enable_dynamic=False)
yolo_gs.register_nms()
onnx_file_path = f"yolo_graphsurgeon_register_nms{arch}_{hash_prefix}.onnx"
assert not Path(onnx_file_path).exists()
yolo_gs.save(onnx_file_path)
assert Path(onnx_file_path).exists()
100 changes: 0 additions & 100 deletions test/test_runtime.py
Original file line number Diff line number Diff line change
@@ -1,100 +0,0 @@
# Copyright (c) 2021, yolort team. All Rights Reserved.
from pathlib import Path

import pytest
import torch
from torch import Tensor
from yolort.runtime import YOLOGraphSurgeon
from yolort.runtime.trt_helper import YOLOTRTModule
from yolort.v5 import attempt_download


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix",
[
("yolov5s", "r4.0", "v4.0", "9ca9a642"),
("yolov5n", "r6.0", "v6.0", "649e089f"),
("yolov5s", "r6.0", "v6.0", "c3b140f3"),
("yolov5n6", "r6.0", "v6.0", "beecbbae"),
],
)
def test_yolo_trt_module(arch, version, upstream_version, hash_prefix):

base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

model = YOLOTRTModule(checkpoint_path, version=version)
model.eval()
samples = torch.rand(1, 3, 320, 320)
outs = model(samples)

assert isinstance(outs, tuple)
assert len(outs) == 2
assert isinstance(outs[0], Tensor)
assert isinstance(outs[1], Tensor)


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix",
[
("yolov5s", "r4.0", "v4.0", "9ca9a642"),
("yolov5n", "r6.0", "v6.0", "649e089f"),
("yolov5s", "r6.0", "v6.0", "c3b140f3"),
("yolov5n6", "r6.0", "v6.0", "beecbbae"),
],
)
def test_yolo_trt_module_to_onnx(arch, version, upstream_version, hash_prefix):
base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

model = YOLOTRTModule(checkpoint_path, version=version)
model.eval()
onnx_file_path = f"yolo_trt_module_to_onnx_{arch}_{hash_prefix}.onnx"
assert not Path(onnx_file_path).exists()
model.to_onnx(onnx_file_path)
assert Path(onnx_file_path).exists()


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix",
[
("yolov5s", "r4.0", "v4.0", "9ca9a642"),
("yolov5n", "r6.0", "v6.0", "649e089f"),
("yolov5s", "r6.0", "v6.0", "c3b140f3"),
("yolov5n6", "r6.0", "v6.0", "beecbbae"),
],
)
def test_yolo_graphsurgeon_wo_nms(arch, version, upstream_version, hash_prefix):
base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

yolo_gs = YOLOGraphSurgeon(checkpoint_path, version=version, enable_dynamic=False)
onnx_file_path = f"yolo_graphsurgeon_wo_nms_{arch}_{hash_prefix}.onnx"
assert not Path(onnx_file_path).exists()
yolo_gs.save(onnx_file_path)
assert Path(onnx_file_path).exists()


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix",
[
("yolov5s", "r4.0", "v4.0", "9ca9a642"),
("yolov5n", "r6.0", "v6.0", "649e089f"),
("yolov5s", "r6.0", "v6.0", "c3b140f3"),
("yolov5n6", "r6.0", "v6.0", "beecbbae"),
],
)
def test_yolo_graphsurgeon_register_nms(arch, version, upstream_version, hash_prefix):
base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

yolo_gs = YOLOGraphSurgeon(checkpoint_path, version=version, enable_dynamic=False)
yolo_gs.register_nms()
onnx_file_path = f"yolo_graphsurgeon_register_nms{arch}_{hash_prefix}.onnx"
assert not Path(onnx_file_path).exists()
yolo_gs.save(onnx_file_path)
assert Path(onnx_file_path).exists()
6 changes: 4 additions & 2 deletions yolort/relaying/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) 2021, yolort team. All Rights Reserved.
# Copyright (c) 2021, yolort team. All rights reserved.

from .trace_wrapper import get_trace_module
from .trt_graphsurgeon import YOLOTRTGraphSurgeon

__all__ = ["get_trace_module"]
__all__ = ["get_trace_module", "YOLOTRTGraphSurgeon"]
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2021, yolort team. All rights reserved.

from typing import List, Tuple

import torch
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,4 @@
# Copyright (c) 2021, yolort team. All rights reserved.
#
# This source code is licensed under the GPL-3.0 license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# This source code is licensed under the Apache-2.0 license found in the
# LICENSE file in the root directory of TensorRT source tree.
#

import logging
from pathlib import Path
Expand All @@ -24,16 +15,18 @@
except ImportError:
gs = None

from .trt_helper import YOLOTRTModule
from .yolo_inference import YOLOInference

logging.basicConfig(level=logging.INFO)
logging.getLogger("YOLOGraphSurgeon").setLevel(logging.INFO)
logger = logging.getLogger("YOLOGraphSurgeon")
logging.getLogger("YOLOTRTGraphSurgeon").setLevel(logging.INFO)
logger = logging.getLogger("YOLOTRTGraphSurgeon")

__all__ = ["YOLOTRTGraphSurgeon"]

class YOLOGraphSurgeon:

class YOLOTRTGraphSurgeon:
"""
Constructor of the YOLOv5 Graph Surgeon object.
YOLOv5 Graph Surgeon for TensorRT inference.
Because TensorRT treat the ``torchvision::ops::nms`` as plugin, we use the a simple post-processing
module named ``LogitsDecoder`` to connect to ``EfficientNMS_TRT`` plugin in TensorRT.
Expand Down Expand Up @@ -66,8 +59,8 @@ def __init__(
checkpoint_path = Path(checkpoint_path)
assert checkpoint_path.exists()

# Use YOLOTRTModule to convert saved model to an initial ONNX graph.
model = YOLOTRTModule(checkpoint_path, version=version)
# Use YOLOInference to convert saved model to an initial ONNX graph.
model = YOLOInference(checkpoint_path, version=version)
model = model.eval()
model = model.to(device=device)
logger.info(f"Loaded saved model from {checkpoint_path}")
Expand Down
116 changes: 116 additions & 0 deletions yolort/relaying/yolo_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright (c) 2021, yolort team. All rights reserved.

from pathlib import PosixPath
from typing import Optional, Tuple, Union

import torch
from torch import nn, Tensor
from yolort.models import YOLO
from yolort.models.anchor_utils import AnchorGenerator
from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.utils import load_from_ultralytics

from .logits_decoder import LogitsDecoder

__all__ = ["YOLOInference"]


class YOLOInference(nn.Module):
"""
TensorRT deployment friendly wrapper for YOLO.
Remove the ``torchvision::nms`` in this warpper, due to the fact that some third-party
inference frameworks currently do not support this operator very well.
Args:
checkpoint_path (string): Path of the trained YOLOv5 checkpoint.
version (string): Upstream YOLOv5 version. Default: 'r6.0'
"""

def __init__(self, checkpoint_path: str, version: str = "r6.0"):
super().__init__()
model_info = load_from_ultralytics(checkpoint_path, version=version)

backbone_name = f"darknet_{model_info['size']}_{version.replace('.', '_')}"
depth_multiple = model_info["depth_multiple"]
width_multiple = model_info["width_multiple"]
use_p6 = model_info["use_p6"]
backbone = darknet_pan_backbone(
backbone_name,
depth_multiple,
width_multiple,
version=version,
use_p6=use_p6,
)
num_classes = model_info["num_classes"]
anchor_generator = AnchorGenerator(model_info["strides"], model_info["anchor_grids"])
post_process = LogitsDecoder(model_info["strides"])
model = YOLO(
backbone,
num_classes,
anchor_generator=anchor_generator,
post_process=post_process,
)

model.load_state_dict(model_info["state_dict"])
self.model = model
self.num_classes = num_classes

@torch.no_grad()
def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
inputs (Tensor): batched images, of shape [batch_size x 3 x H x W]
"""
# Compute the detections
outputs = self.model(inputs)

return outputs

@torch.no_grad()
def to_onnx(
self,
file_path: Union[str, PosixPath],
input_sample: Optional[Tensor] = None,
opset_version: int = 11,
enable_dynamic: bool = True,
**kwargs,
):
"""
Saves the model in ONNX format.
Args:
file_path (Union[string, PosixPath]): The path of the file the onnx model should
be saved to.
input_sample (Tensor, Optional): An input for tracing. Default: None.
opset_version (int): Opset version we export the model to the onnx submodule. Default: 11.
enable_dynamic (bool): Whether to specify axes of tensors as dynamic. Default: True.
**kwargs: Will be passed to torch.onnx.export function.
"""
if input_sample is None:
input_sample = torch.rand(1, 3, 640, 640).to(next(self.parameters()).device)

dynamic_axes = (
{
"images": {0: "batch", 2: "height", 3: "width"},
"boxes": {0: "batch", 1: "num_objects"},
"scores": {0: "batch", 1: "num_objects"},
}
if enable_dynamic
else None
)

input_names = ["images"]
output_names = ["boxes", "scores"]

torch.onnx.export(
self.model,
input_sample,
file_path,
do_constant_folding=True,
opset_version=opset_version,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
**kwargs,
)
Loading

0 comments on commit 4283c96

Please sign in to comment.