Skip to content

Commit

Permalink
Move LogitsDecoder into trt_helper.py and fix docstrings (#256)
Browse files Browse the repository at this point in the history
* Fix example in PredictorTRT

* Fix docstrings for PredictorTRT

* Move LogitsDecoder into yolort.runtime

* Minor fix
  • Loading branch information
zhiqwang authored Dec 26, 2021
1 parent 902984d commit 6d10759
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 64 deletions.
54 changes: 0 additions & 54 deletions yolort/models/box_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,60 +353,6 @@ def _decode_pred_logits(pred_logits: Tensor):
return boxes, scores


class LogitsDecoder(nn.Module):
"""
This is a simplified version of post-processing module, we manually remove
the ``torchvision::ops::nms``, and it will be used later in the procedure of
exporting the ONNX graph for TensorRT.
"""

def __init__(self, strides: List[int]) -> None:
"""
Args:
strides (List[int]): Strides of the AnchorGenerator.
"""

super().__init__()
self.strides = strides

def forward(
self,
head_outputs: List[Tensor],
grids: List[Tensor],
shifts: List[Tensor],
) -> Tuple[Tensor, Tensor]:
"""
Just concat the predict logits, ignore the original ``torchvision::nms`` module
from original ``yolort.models.box_head.PostProcess``.
Args:
head_outputs (List[Tensor]): The predicted locations and class/object confidence,
shape of the element is (N, A, H, W, K).
anchors_tuple (Tuple[Tensor, Tensor, Tensor]):
grids (List[Tensor]): Anchor grids.
shifts (List[Tensor]): Anchor shifts.
"""
batch_size = len(head_outputs[0])

all_pred_logits = _concat_pred_logits(head_outputs, grids, shifts, self.strides)

bbox_regression = []
pred_scores = []

for idx in range(batch_size): # image idx, image inference
pred_logits = all_pred_logits[idx]
boxes, scores = _decode_pred_logits(pred_logits)
bbox_regression.append(boxes)
pred_scores.append(scores)

# The default boxes tensor has shape [batch_size, number_boxes, 4].
# This will insert a "1" dimension in the second axis, to become
# [batch_size, number_boxes, 1, 4], the shape that plugin/BatchedNMS expects.
boxes = torch.stack(bbox_regression).unsqueeze_(2)
scores = torch.stack(pred_scores)
return boxes, scores


class PostProcess(nn.Module):
"""
Performs Non-Maximum Suppression (NMS) on inference results
Expand Down
57 changes: 55 additions & 2 deletions yolort/runtime/trt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import logging
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import Optional, List, Tuple, Union

try:
import tensorrt as trt
Expand All @@ -23,7 +23,7 @@
from yolort.models import YOLO
from yolort.models.anchor_utils import AnchorGenerator
from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.models.box_head import LogitsDecoder
from yolort.models.box_head import _concat_pred_logits, _decode_pred_logits
from yolort.utils import load_from_ultralytics

logging.basicConfig(level=logging.INFO)
Expand All @@ -34,6 +34,59 @@
__all__ = ["YOLOTRTModule", "EngineBuilder"]


class LogitsDecoder(nn.Module):
"""
This is a simplified version of post-processing module, we manually remove
the ``torchvision::ops::nms``, and it will be used later in the procedure of
exporting the ONNX graph for YOLOTRTModule.
"""

def __init__(self, strides: List[int]) -> None:
"""
Args:
strides (List[int]): Strides of the AnchorGenerator.
"""

super().__init__()
self.strides = strides

def forward(
self,
head_outputs: List[Tensor],
grids: List[Tensor],
shifts: List[Tensor],
) -> Tuple[Tensor, Tensor]:
"""
Just concat the predict logits, ignore the original ``torchvision::nms`` module
from original ``yolort.models.box_head.PostProcess``.
Args:
head_outputs (List[Tensor]): The predicted locations and class/object confidence,
shape of the element is (N, A, H, W, K).
grids (List[Tensor]): Anchor grids.
shifts (List[Tensor]): Anchor shifts.
"""
batch_size = len(head_outputs[0])

all_pred_logits = _concat_pred_logits(head_outputs, grids, shifts, self.strides)

bbox_regression = []
pred_scores = []

for idx in range(batch_size): # image idx, image inference
pred_logits = all_pred_logits[idx]
boxes, scores = _decode_pred_logits(pred_logits)
bbox_regression.append(boxes)
pred_scores.append(scores)

# The default boxes tensor has shape [batch_size, number_boxes, 4].
# This will insert a "1" dimension in the second axis, to become
# [batch_size, number_boxes, 1, 4], the shape that plugin/BatchedNMS expects.
boxes = torch.stack(bbox_regression).unsqueeze_(2)
scores = torch.stack(pred_scores)
return boxes, scores


class YOLOTRTModule(nn.Module):
"""
TensorRT deployment friendly wrapper for YOLO.
Expand Down
26 changes: 18 additions & 8 deletions yolort/runtime/y_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,35 @@ class PredictorTRT:
Args:
engine_path (str): Path of the ONNX checkpoint.
device (torch.device): The CUDA device to be used for inferencing.
score_thresh (float): Score threshold used for postprocessing the detections.
nms_thresh (float): NMS threshold used for postprocessing the detections.
detections_per_img (int): Number of best detections to keep after NMS.
Examples:
>>> import cv2
>>> import numpy as np
>>> import torch
>>> from yolort.runtime import PredictorTRT
>>>
>>> engine_path = 'yolov5s.engine'
>>> device = torch.device("cuda")
>>> detector = PredictorTRT(engine_path, device)
>>> runtime = PredictorTRT(engine_path, device)
>>>
>>> img_path = 'bus.jpg'
>>> detections = detector.run_on_image(img_path)
>>> image = cv2.imread(img_path)
>>> image = cv2.resize(image, (320, 320))
>>> image = image.transpose((2, 0, 1))[::-1] # Convert HWC to CHW, BGR to RGB
>>> image = np.ascontiguousarray(image)
>>>
>>> image = runtime.preprocessing(image)
>>> detections = runtime.run_on_image(image)
"""

def __init__(
self,
engine_path: str,
device: torch.device = torch.device("cuda"),
score_thresh: float = 0.25,
iou_thresh: float = 0.45,
nms_thresh: float = 0.45,
detections_per_img: int = 100,
) -> None:
self.engine_path = engine_path
Expand All @@ -56,7 +66,7 @@ def __init__(
self.stride = 32
self.names = [f"class{i}" for i in range(1000)] # assign defaults
self.score_thresh = score_thresh
self.iou_thresh = iou_thresh
self.nms_thresh = nms_thresh
self.detections_per_img = detections_per_img

self.engine = self._build_engine()
Expand Down Expand Up @@ -98,8 +108,8 @@ def __call__(self, image: Tensor):
image (Tensor): an image of shape (C, N, H, W).
Returns:
predictions (Tuple[List[float], List[int], List[float, float]]):
stands for scores, labels and boxes respectively.
predictions (Tuple[Tensor, Tensor, Tensor, Tensor]):
stands for boxes, scores, labels and number of boxes respectively.
"""
assert image.shape == self.bindings["images"].shape, (image.shape, self.bindings["images"].shape)
self.binding_addrs["images"] = int(image.data_ptr())
Expand Down

0 comments on commit 6d10759

Please sign in to comment.