Refactoring TensorRT Python interface #307

Merged · 2 commits · Feb 10, 2022
142 changes: 71 additions & 71 deletions notebooks/onnx-graphsurgeon-inference-tensorrt.ipynb

Large diffs are not rendered by default.

82 changes: 50 additions & 32 deletions yolort/runtime/trt_helper.py
@@ -10,7 +10,7 @@
#

import logging
-from pathlib import Path
+from pathlib import Path, PosixPath
from typing import Optional, Tuple, Union

try:
@@ -41,13 +41,13 @@ class YOLOTRTModule(nn.Module):

Remove the ``torchvision::nms`` in this wrapper, because some third-party
inference frameworks currently do not support this operator very well.

Args:
checkpoint_path (string): Path of the trained YOLOv5 checkpoint.
version (string): Upstream YOLOv5 version. Default: 'r6.0'
"""

-def __init__(
-    self,
-    checkpoint_path: str,
-    version: str = "r6.0",
-):
+def __init__(self, checkpoint_path: str, version: str = "r6.0"):
super().__init__()
model_info = load_from_ultralytics(checkpoint_path, version=version)

@@ -90,7 +90,7 @@ def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor]:
@torch.no_grad()
def to_onnx(
self,
-    file_path: Union[str, Path],
+    file_path: Union[str, PosixPath],
input_sample: Optional[Tensor] = None,
opset_version: int = 11,
enable_dynamic: bool = True,
@@ -100,10 +100,11 @@ def to_onnx(
Saves the model in ONNX format.

Args:
-    file_path: The path of the file the onnx model should be saved to.
-    input_sample: An input for tracing. Default: None.
-    opset_version: Opset version we export the model to the onnx submodule. Default: 11.
-    enable_dynamic: Whether to specify axes of tensors as dynamic. Default: True.
+    file_path (Union[string, PosixPath]): The path of the file the ONNX model should
+        be saved to.
+    input_sample (Tensor, optional): An input for tracing. Default: None.
+    opset_version (int): Opset version we export the model to the onnx submodule. Default: 11.
+    enable_dynamic (bool): Whether to specify axes of tensors as dynamic. Default: True.
**kwargs: Will be passed to torch.onnx.export function.
"""
if input_sample is None:
@@ -140,12 +141,33 @@ class EngineBuilder:
Parses an ONNX graph and builds a TensorRT engine from it.
"""

-def __init__(self, verbose=False, workspace=4):
+def __init__(
+    self,
+    verbose: bool = False,
+    workspace: int = 4,
+    precision: str = "fp32",
+    enable_dynamic: bool = False,
+    max_batch_size: int = 16,
+    calib_input: Optional[str] = None,
+    calib_cache: Optional[str] = None,
+    calib_num_images: int = 5000,
+    calib_batch_size: int = 8,
+):
"""
Args:
-    verbose: If enabled, a higher verbosity level will be
-        set on the TensorRT logger.
-    workspace: Max memory workspace to allow, in Gb.
+    verbose (bool): If enabled, a higher verbosity level will be
+        set on the TensorRT logger. Default: False
+    workspace (int): Max memory workspace to allow, in GB.
+    precision (string): The datatype to use for the engine inference, either 'fp32',
+        'fp16' or 'int8'. Default: 'fp32'
+    enable_dynamic (bool): Whether to enable dynamic shapes. Default: False
+    max_batch_size (int): Maximum batch size reserved for dynamic shape inference. Default: 16
+    calib_input (string, optional): The path to a directory holding the calibration images.
+        Default: None
+    calib_cache (string, optional): The path to write the calibration cache to, or to load
+        it from if it already exists. Default: None
+    calib_num_images (int): The maximum number of images to use for calibration. Default: 5000
+    calib_batch_size (int): The batch size to use for the calibration process. Default: 8
"""
self.logger = trt.Logger(trt.Logger.INFO)
if verbose:
@@ -161,6 +183,16 @@ def __init__(self, verbose=False, workspace=4):
self.network = None
self.parser = None

+    # Reserve these interfaces and parameters for subsequent use; the
+    # functionality below has not been implemented yet.
+    self.precision = precision
+    self.enable_dynamic = enable_dynamic
+    self.max_batch_size = max_batch_size
+    self.calib_input = calib_input
+    self.calib_cache = calib_cache
+    self.calib_num_images = calib_num_images
+    self.calib_batch_size = calib_batch_size

def create_network(self, onnx_path: str):
"""
Parse the ONNX graph and create the corresponding TensorRT network definition.
@@ -185,31 +217,17 @@ def create_network(self, onnx_path: str):
for output in outputs:
logger.info(f"Output '{output.name}' with shape {output.shape} and dtype {output.dtype}")

-def create_engine(
-    self,
-    engine_path: str,
-    *,
-    precision: str = "fp32",
-    max_batch_size: int = 32,
-    calib_input: Optional[str] = None,
-    calib_cache: Optional[str] = None,
-    calib_num_images: int = 5000,
-    calib_batch_size: int = 8,
-):
+def create_engine(self, engine_path: str):
"""
Build the TensorRT engine and serialize it to disk.

Args:
engine_path: The path where to serialize the engine to.
-    precision: The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'.
-    calib_input: The path to a directory holding the calibration images.
-    calib_cache: The path where to write the calibration cache to, or if it already
-        exists, load it from.
-    calib_num_images: The maximum number of images to use for calibration.
-    calib_batch_size: The batch size to use for the calibration process.
"""
engine_path = Path(engine_path)
engine_path.parent.mkdir(parents=True, exist_ok=True)

+    precision = self.precision
logger.info(f"Building {precision} Engine in {engine_path}")

# Process the batch size and profile
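To make the new surface easier to review, here is a minimal end-to-end sketch of the refactored trt_helper API: precision (and the not-yet-implemented calibration knobs) now arrive via the EngineBuilder constructor, and create_engine() takes only the output path. The checkpoint and engine file names below are placeholders, not files from this repo:

from yolort.runtime.trt_helper import EngineBuilder, YOLOTRTModule

# Wrap a trained YOLOv5 checkpoint and export it to ONNX (placeholder paths).
model = YOLOTRTModule("yolov5s.pt", version="r6.0")
model.to_onnx("yolov5s.onnx", opset_version=11, enable_dynamic=True)

# Engine options are constructor state after this PR; the calibration
# arguments are accepted but, per the comment above, not implemented yet.
builder = EngineBuilder(verbose=False, workspace=4, precision="fp32")
builder.create_network("yolov5s.onnx")
builder.create_engine("yolov5s.engine")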
16 changes: 12 additions & 4 deletions yolort/runtime/y_tensorrt.py
@@ -28,8 +28,9 @@ class PredictorTRT:
single device for a single input image.

Args:
-    engine_path (str): Path of the ONNX checkpoint.
+    engine_path (string): Path of the serialized TensorRT engine.
     device (torch.device): The CUDA device to be used for inferencing.
+    precision (string): The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'.

Examples:
>>> import cv2
@@ -55,6 +56,7 @@ def __init__(
self,
engine_path: str,
device: torch.device = torch.device("cuda"),
precision: str = "fp32",
) -> None:
self.engine_path = engine_path
self.device = device
@@ -64,7 +66,13 @@ def __init__(

self.engine = self._build_engine()
self._set_context()
-    self.half = False
+
+    if precision == "fp32":
+        self.half = False
+    elif precision == "fp16":
+        self.half = True
+    else:
+        raise NotImplementedError(f"Currently unsupported precision: {precision}")

def _build_engine(self):
logger.info(f"Loading {self.engine_path} for TensorRT inference...")
@@ -136,11 +144,11 @@ def postprocessing(all_boxes, all_scores, all_labels, all_num_dets):

return detections

-def warmup(self, img_size=(1, 3, 320, 320), half=False):
+def warmup(self, img_size=(1, 3, 320, 320)):
# Warmup model by running inference once
# only warmup GPU models
if isinstance(self.device, torch.device) and self.device.type != "cpu":
-    image = torch.zeros(*img_size).to(self.device).type(torch.half if half else torch.float)
+    image = torch.zeros(*img_size).to(self.device).type(torch.half if self.half else torch.float)
self(image)

def run_wo_postprocessing(self, image: Tensor):
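A corresponding sketch for the runtime side, assuming PredictorTRT is importable from yolort.runtime.y_tensorrt as in this diff; the engine path is a placeholder. The point is that half is now derived once from the precision argument, so warmup() no longer needs its own half flag:

import torch

from yolort.runtime.y_tensorrt import PredictorTRT

# precision="fp16" sets self.half at construction time (placeholder path).
predictor = PredictorTRT("yolov5s.engine", device=torch.device("cuda"), precision="fp16")
predictor.warmup(img_size=(1, 3, 320, 320))  # the zeros tensor is cast via self.half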
14 changes: 12 additions & 2 deletions yolort/runtime/yolo_graphsurgeon.py
@@ -50,6 +50,7 @@ class YOLOGraphSurgeon:
values are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0".
enable_dynamic (bool): Whether to specify axes of tensors as dynamic. Default: False.
device (torch.device): The device to be used for importing ONNX. Default: torch.device("cpu").
+    precision (string): The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'.
"""

def __init__(
@@ -60,6 +61,7 @@ def __init__(
version: str = "r6.0",
enable_dynamic: bool = False,
device: torch.device = torch.device("cpu"),
+    precision: str = "fp32",
):
checkpoint_path = Path(checkpoint_path)
assert checkpoint_path.exists()
@@ -82,6 +84,7 @@ def __init__(
self.graph.fold_constants()
self.num_classes = model.num_classes
self.batch_size = 1
+    self.precision = precision

def infer(self):
"""
@@ -165,6 +168,13 @@ def register_nms(
"box_coding": 0,
}

+    if self.precision == "fp32":
+        dtype_output = np.float32
+    elif self.precision == "fp16":
+        dtype_output = np.float16
+    else:
+        raise NotImplementedError(f"Currently unsupported precision: {self.precision}")
+
# NMS Outputs
output_num_detections = gs.Variable(
name="num_detections",
Expand All @@ -173,12 +183,12 @@ def register_nms(
) # A scalar indicating the number of valid detections per batch image.
output_boxes = gs.Variable(
name="detection_boxes",
-    dtype=np.float32,
+    dtype=dtype_output,
shape=[self.batch_size, detections_per_img, 4],
)
output_scores = gs.Variable(
name="detection_scores",
-    dtype=np.float32,
+    dtype=dtype_output,
shape=[self.batch_size, detections_per_img],
)
output_labels = gs.Variable(
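Finally, a hedged sketch of the graphsurgeon path: the dtype of the registered NMS outputs now follows the precision chosen at construction. The checkpoint path is a placeholder, and since the full register_nms() signature is not visible in this diff, its arguments are elided below:

import torch

from yolort.runtime.yolo_graphsurgeon import YOLOGraphSurgeon

yolo_gs = YOLOGraphSurgeon(
    "yolov5s.pt",
    version="r6.0",
    enable_dynamic=False,
    device=torch.device("cpu"),
    precision="fp16",  # detection_boxes / detection_scores registered as np.float16
)
yolo_gs.infer()
yolo_gs.register_nms()  # threshold keywords elided; not shown in this diff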