diff --git a/README.md b/README.md index 80903ad1..5f643fd5 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,21 @@ On the `ONNX Runtime` front you can use the [C++ example](deployment/onnxruntime ### Inference on TensorRT backend +The pipeline for TensorRT deployment is also very easy to use. + +```python +import torch +from yolort.runtime import PredictorTRT + +# Load the exported TensorRT engine +engine_path = "yolov5n6.engine" +device = torch.device("cuda") +y_runtime = PredictorTRT(engine_path, device=device) + +# Perform inference on an image file +predictions = y_runtime.predict("bus.jpg") +``` + On the `TensorRT` front you can use the [C++ example](deployment/tensorrt), and we also provide a [tutorial](https://zhiqwang.com/yolov5-rt-stack/notebooks/onnx-graphsurgeon-inference-tensorrt.html) for using the `TensorRT`. ## 🎨 Model Graph Visualization diff --git a/notebooks/onnx-graphsurgeon-inference-tensorrt.ipynb b/notebooks/onnx-graphsurgeon-inference-tensorrt.ipynb index d8752656..1bdb9307 100644 --- a/notebooks/onnx-graphsurgeon-inference-tensorrt.ipynb +++ b/notebooks/onnx-graphsurgeon-inference-tensorrt.ipynb @@ -96,7 +96,8 @@ "from yolort.utils import cv2_imshow, get_image_from_url, read_image_to_tensor\n", "from yolort.utils.image_utils import plot_one_box, color_list\n", "from yolort.v5 import letterbox, non_max_suppression, scale_coords, attempt_download\n", - "from yolort.v5.utils.torch_utils import select_device, time_sync" + "from yolort.v5.utils.torch_utils import select_device, time_sync\n", + "from yolort.v5.utils.downloads import safe_download" ] }, { @@ -115,8 +116,9 @@ "outputs": [], "source": [ "# Define some parameters\n", + "batch_size = 1\n", "img_size = 640\n", - "stride = 64\n", + "size_divisible = 64\n", "fixed_shape = True\n", "score_thresh = 0.35\n", "iou_thresh = 0.45\n", @@ -127,94 +129,74 @@ { "cell_type": "code", "execution_count": 4, - "id": "db0a7686", - "metadata": {}, - "outputs": [], - "source": [ - "# yolov5s6.pt is downloaded from 'https://github.com/ultralytics/yolov5/releases/download/v6.0/yolov5n6.pt'\n", - "model_path = \"yolov5n6.pt\"\n", - "\n", - "checkpoint_path = attempt_download(model_path)\n", - "onnx_path = \"yolov5n6.onnx\"\n", - "engine_path = \"yolov5n6.engine\"" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "56c0f910-3179-4e36-92a4-8daa3a234533", - "metadata": {}, - "outputs": [], - "source": [ - "img_source = \"https://huggingface.co/spaces/zhiqwang/assets/resolve/main/bus.jpg\"\n", - "# img_source = \"https://huggingface.co/spaces/zhiqwang/assets/resolve/main/zidane.jpg\"\n", - "img_raw = get_image_from_url(img_source)" - ] - }, - { - "cell_type": "markdown", - "id": "b7cf3c4b-be02-4b1e-812a-2fdd025d231b", - "metadata": {}, - "source": [ - "### Pre Processing" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "ff97c859-ef72-41ec-a54d-d29b800bc4fe", - "metadata": {}, - "outputs": [], - "source": [ - "# Preprocess\n", - "auto = not fixed_shape\n", - "image = letterbox(img_raw, new_shape=(img_size, img_size), stride=stride, auto=auto)[0]\n", - "image = read_image_to_tensor(image)\n", - "image = image[None]\n", - "image = image.to(device)\n", - "image = image.contiguous()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "2b14d332-c967-4574-8744-3c408874a5fc", + "id": "3a0554d4-3cb6-4c0e-8c2a-239167ef2ce7", "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading https://huggingface.co/spaces/zhiqwang/assets/resolve/main/bus.jpg to bus.jpg...\n" + ] + }, { "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e9dce379b40c4edca86772300f78fb4e", + "version_major": 2, + "version_minor": 0 + }, "text/plain": [ - "torch.Size([1, 3, 640, 640])" + " 0%| | 0.00/476k [00:00" + "" ] }, "metadata": {}, @@ -800,9 +790,6 @@ } ], "source": [ - "for box, label in zip(boxes.tolist(), labels.to(dtype=torch.int64).tolist()):\n", - " img_raw = plot_one_box(box, img_raw, color=COLORS[label % len(COLORS)], label=LABELS[label])\n", - "\n", "cv2_imshow(img_raw, imshow_scale=0.5)" ] } diff --git a/yolort/models/yolo_module.py b/yolort/models/yolo_module.py index 96dbaf7c..dc884212 100644 --- a/yolort/models/yolo_module.py +++ b/yolort/models/yolo_module.py @@ -2,7 +2,7 @@ import argparse import warnings from pathlib import PosixPath -from typing import Any, List, Dict, Tuple, Optional, Union, Callable +from typing import Any, Dict, List, Callable, Optional, Tuple, Union import torch import torchvision @@ -218,11 +218,7 @@ def configure_optimizers(self): ) @torch.no_grad() - def predict( - self, - x: Any, - image_loader: Optional[Callable] = None, - ) -> List[Dict[str, Tensor]]: + def predict(self, x: Any, image_loader: Optional[Callable] = None) -> List[Dict[str, Tensor]]: """ Predict function for raw data or processed data Args: @@ -234,8 +230,7 @@ def predict( """ image_loader = image_loader or self.default_loader images = self.collate_images(x, image_loader) - outputs = self.forward(images) - return outputs + return self.forward(images) def default_loader(self, img_path: str) -> Tensor: """ diff --git a/yolort/runtime/y_tensorrt.py b/yolort/runtime/y_tensorrt.py index 4dfd0142..b2d55749 100644 --- a/yolort/runtime/y_tensorrt.py +++ b/yolort/runtime/y_tensorrt.py @@ -1,16 +1,15 @@ # Copyright (c) 2021, yolort team. All rights reserved. -# -# This source code is licensed under the GPL-3.0 license found in the -# LICENSE file in the root directory of this source tree. -# import logging from collections import OrderedDict, namedtuple -from typing import Dict, List +from typing import Any, Dict, List, Callable, Optional, Tuple import numpy as np import torch from torch import Tensor +from torchvision.io import read_image +from yolort.data import contains_any_tensor +from yolort.models.transform import YOLOTransform try: import tensorrt as trt @@ -31,25 +30,31 @@ class PredictorTRT: engine_path (string): Path of the ONNX checkpoint. device (torch.device): The CUDA device to be used for inferencing. precision (string): The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'. + enable_dynamic (bool): Whether to specify axes of tensors as dynamic. Default: False. + size: (Tuple[int, int]): the minimum and maximum size of the image to be rescaled. + Default: (640, 640) + size_divisible (int): stride of the models. Default: 32 + fixed_shape (Tuple[int, int], optional): Padding mode for letterboxing. If set to `True`, + the image will be padded to shape `fixed_shape` if specified. Instead the image will + be padded to a minimum rectangle to match `min_size / max_size` and each of its edges + is divisible by `size_divisible` if it is not specified. Default: None + fill_color (int): fill value for padding. Default: 114 - Examples: - >>> import cv2 - >>> import numpy as np - >>> import torch - >>> from yolort.runtime import PredictorTRT - >>> - >>> engine_path = 'yolov5n6.engine' - >>> device = torch.device('cuda') - >>> runtime = PredictorTRT(engine_path, device) - >>> - >>> img_path = 'bus.jpg' - >>> image = cv2.imread(img_path) - >>> image = cv2.resize(image, (320, 320)) - >>> image = image.transpose((2, 0, 1))[::-1] # Convert HWC to CHW, BGR to RGB - >>> image = np.ascontiguousarray(image) - >>> - >>> image = runtime.preprocessing(image) - >>> detections = runtime.run_on_image(image) + Example: + + Demo pipeline for deploying TensorRT. + + .. code-block:: python + import torch + from yolort.runtime import PredictorTRT + + # Load the exported TensorRT engine + engine_path = 'yolov5n6.engine' + device = torch.device('cuda') + y_runtime = PredictorTRT(engine_path, device=device) + + # Perform inference on an image file + predictions = y_runtime.predict('bus.jpg') """ def __init__( @@ -57,51 +62,92 @@ def __init__( engine_path: str, device: torch.device = torch.device("cuda"), precision: str = "fp32", + enable_dynamic: bool = False, + size: Tuple[int, int] = (640, 640), + size_divisible: int = 32, + fixed_shape: Optional[Tuple[int, int]] = None, + fill_color: int = 114, ) -> None: - self.engine_path = engine_path - self.device = device - self.named_binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr")) - self.stride = 32 - self.names = [f"class{i}" for i in range(1000)] # assign defaults + self._engine_path = engine_path + self._device = device + + # Build the inference engine + self.named_binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr")) self.engine = self._build_engine() + self.bindings = OrderedDict() + self.binding_addrs = None + self.context = None self._set_context() if precision == "fp32": - self.half = False + self._half = False elif precision == "fp16": - self.half = True + self._half = True else: raise NotImplementedError(f"Currently not supports precision: {precision}") + self._dtype = torch.float16 if self._half else torch.float32 + + # Set pre-processing transform for TensorRT inference + self._enable_dynamic = enable_dynamic + self._size = size + self._size_divisible = size_divisible + self._fixed_shape = fixed_shape + self._fill_color = fill_color + self._img_size = None + self.transform = None + self._set_preprocessing() + + # Visualization + self._names = [f"class{i}" for i in range(1000)] # assign defaults + def _build_engine(self): - logger.info(f"Loading {self.engine_path} for TensorRT inference...") - trt_logger = trt.Logger(trt.Logger.INFO) + logger.info(f"Loading {self._engine_path} for TensorRT inference...") + if trt is not None: + trt_logger = trt.Logger(trt.Logger.INFO) + else: + trt_logger = None + raise ImportError("TensorRT is not installed, please install trt firstly.") + trt.init_libnvinfer_plugins(trt_logger, namespace="") - with open(self.engine_path, "rb") as f, trt.Runtime(trt_logger) as runtime: + with open(self._engine_path, "rb") as f, trt.Runtime(trt_logger) as runtime: engine = runtime.deserialize_cuda_engine(f.read()) return engine def _set_context(self): - self.bindings = OrderedDict() for index in range(self.engine.num_bindings): name = self.engine.get_binding_name(index) dtype = trt.nptype(self.engine.get_binding_dtype(index)) shape = tuple(self.engine.get_binding_shape(index)) - data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(self.device) + data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(self._device) self.bindings[name] = self.named_binding(name, dtype, shape, data, int(data.data_ptr())) self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items()) self.context = self.engine.create_execution_context() - def preprocessing(self, image): - image = torch.from_numpy(image).to(self.device) - image = image.half() if self.half else image.float() # uint8 to fp16/32 - image /= 255 # 0 - 255 to 0.0 - 1.0 - if len(image.shape) == 3: - image = image[None] # expand for batch dim - return image + def _set_preprocessing(self): + if self._enable_dynamic: + raise NotImplementedError("Currently only supports static shape inference in TensorRT.") + + export_onnx_shape = self.bindings["images"].shape + self._img_size = export_onnx_shape + + size = export_onnx_shape[-2:] + self.transform = YOLOTransform( + size[0], + size[1], + size_divisible=self._size_divisible, + fixed_shape=size, + fill_color=self._fill_color, + ) + + def warmup(self): + # Warmup model by running inference once and only warmup GPU models + if isinstance(self._device, torch.device) and self._device.type != "cpu": + image = torch.zeros(*self._img_size).to(dtype=self._dtype, device=self._device) + self(image) def __call__(self, image: Tensor): """ @@ -112,7 +158,6 @@ def __call__(self, image: Tensor): predictions (Tuple[Tensor, Tensor, Tensor, Tensor]): stands for boxes, scores, labels and number of boxes respectively. """ - assert image.shape == self.bindings["images"].shape, (image.shape, self.bindings["images"].shape) self.binding_addrs["images"] = int(image.data_ptr()) self.context.execute_v2(list(self.binding_addrs.values())) num_dets = self.bindings["num_detections"].data @@ -121,20 +166,8 @@ def __call__(self, image: Tensor): labels = self.bindings["detection_classes"].data return boxes, scores, labels, num_dets - def run_on_image(self, image: Tensor): - """ - Run the TensorRT engine for one image only. - - Args: - image (Tensor): an image of shape (N, C, H, W). - """ - boxes, scores, labels, num_dets = self(image) - - detections = self.postprocessing(boxes, scores, labels, num_dets) - return detections - @staticmethod - def postprocessing(all_boxes, all_scores, all_labels, all_num_dets): + def parse_output(all_boxes, all_scores, all_labels, all_num_dets): detections: List[Dict[str, Tensor]] = [] for boxes, scores, labels, num_dets in zip(all_boxes, all_scores, all_labels, all_num_dets): @@ -144,23 +177,88 @@ def postprocessing(all_boxes, all_scores, all_labels, all_num_dets): return detections - def warmup(self, img_size=(1, 3, 320, 320)): - # Warmup model by running inference once - # only warmup GPU models - if isinstance(self.device, torch.device) and self.device.type != "cpu": - image = torch.zeros(*img_size).to(self.device).type(torch.half if self.half else torch.float) - self(image) + def forward(self, inputs: List[Tensor]): + """ + Wrapper the TensorRT inference engine with Pre-Processing Module. + + Args: + inputs (list[Tensor]): images to be processed + """ + # get the original image sizes + original_image_sizes: List[Tuple[int, int]] = [] + + for img in inputs: + val = img.shape[-2:] + assert len(val) == 2 + original_image_sizes.append((val[0], val[1])) + + # Pre-Processing + samples, _ = self.transform(inputs) + # Inference on TensorRT + boxes, scores, labels, num_dets = self(samples.tensors) + results = self.parse_output(boxes, scores, labels, num_dets) - def run_wo_postprocessing(self, image: Tensor): + # Rescale coordinate + im_shape = torch.tensor(samples.tensors.shape[-2:]) + detections = self.transform.postprocess(results, im_shape, original_image_sizes) + + return detections + + def predict(self, x: Any, image_loader: Optional[Callable] = None) -> List[Dict[str, Tensor]]: """ - Run the TensorRT engine for one image only. + Predict function for raw data or processed data + Args: + x: Input to predict. Can be raw data or processed data. + image_loader: Utility function to convert raw data to Tensor. + + Returns: + The post-processed model predictions. + """ + image_loader = image_loader or self.default_loader + images = self.collate_images(x, image_loader) + return self.forward(images) + + def default_loader(self, img_path: str) -> Tensor: + """ + Default loader of read a image path. Args: - image (Tensor): an image of shape (N, C, H, W). + img_path (str): a image path + + Returns: + Tensor, processed tensor for prediction. """ - assert image.shape == self.bindings["images"].shape, (image.shape, self.bindings["images"].shape) - self.binding_addrs["images"] = int(image.data_ptr()) - self.context.execute_v2(list(self.binding_addrs.values())) - boxes = self.bindings["boxes"].data - scores = self.bindings["scores"].data - return boxes, scores + return read_image(img_path) / 255.0 + + def collate_images(self, samples: Any, image_loader: Callable) -> List[Tensor]: + """ + Prepare source samples for inference. + + Args: + samples (Any): samples source, support the following various types: + - str or List[str]: a image path or list of image paths. + - Tensor or List[Tensor]: a tensor or list of tensors. + + Returns: + List[Tensor], The processed image samples. + """ + if isinstance(samples, Tensor): + return [samples.to(dtype=self._dtype, device=self._device)] + + if contains_any_tensor(samples): + return [sample.to(dtype=self._dtype, device=self._device) for sample in samples] + + if isinstance(samples, str): + samples = [samples] + + if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples): + outputs = [] + for sample in samples: + output = image_loader(sample).to(dtype=self._dtype, device=self._device) + outputs.append(output) + return outputs + + raise NotImplementedError( + f"The type of the sample is {type(samples)}, we currently don't support it now, the " + "samples should be either a tensor, list of tensors, a image path or list of image paths." + ) diff --git a/yolort/runtime/yolo_graphsurgeon.py b/yolort/runtime/yolo_graphsurgeon.py index 8f4fdedb..24440fed 100644 --- a/yolort/runtime/yolo_graphsurgeon.py +++ b/yolort/runtime/yolo_graphsurgeon.py @@ -43,11 +43,11 @@ class YOLOGraphSurgeon: Args: checkpoint_path (string): The path pointing to the PyTorch saved model to load. + version (str): upstream version released by the ultralytics/yolov5, Possible + values are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0". input_sample (Tensor, optional): Specify the input shape to export ONNX, and the default shape for the sample is (1, 3, 640, 640). score_thresh (float): Score threshold used for postprocessing the detections. - version (str): upstream version released by the ultralytics/yolov5, Possible - values are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0". enable_dynamic (bool): Whether to specify axes of tensors as dynamic. Default: False. device (torch.device): The device to be used for importing ONNX. Default: torch.device("cpu"). precision (string): The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'. @@ -57,8 +57,8 @@ def __init__( self, checkpoint_path: str, *, - input_sample: Optional[Tensor] = None, version: str = "r6.0", + input_sample: Optional[Tensor] = None, enable_dynamic: bool = False, device: torch.device = torch.device("cpu"), precision: str = "fp32",