From 5dd25c3dcf389c1eeca27ea7324313cc27f2a98a Mon Sep 17 00:00:00 2001
From: Zhiqiang Wang <zhiqwang@foxmail.com>
Date: Tue, 21 Dec 2021 05:25:54 +0800
Subject: [PATCH] Cleanup AnchorGenerator and PostProcess (#203)

* Clean up the anchor configuration mechanism

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix the latest compatibility issues

* Fix missing outputs

* Fix docstrings

* Fix anchors in AnchorGenerator._generate_shifts

* Fix TestAnchorGenerator

* Fix test_anchor_generator

* Fix test_postprocessors

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Fix YOLOTRTModule

* Fix pylint

* Minor fix

* Fix tensor.stride in AnchorGenerator._generate_shifts

* Clean up code

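The AnchorGenerator now returns per-level (grids, shifts) lists instead
of the flattened (anchors, wh_weights, xy_weights) triple, and
PostProcess/LogitsDecoder take the strides at construction time. A
minimal sketch of the new calling convention (the anchor values mirror
the YOLOv5 defaults; the dummy shapes and thresholds are illustrative
only):

    import torch
    from yolort.models.anchor_utils import AnchorGenerator
    from yolort.models.box_head import PostProcess

    strides = [8, 16, 32]
    anchor_grids = [
        [10, 13, 16, 30, 33, 23],
        [30, 61, 62, 45, 59, 119],
        [116, 90, 156, 198, 373, 326],
    ]
    anchor_generator = AnchorGenerator(strides, anchor_grids)
    post_process = PostProcess(strides, 0.5, 0.45, 100)

    # Dummy inputs for a 640x640 image: feature maps are (N, C, H, W),
    # raw head outputs are (N, A, H, W, K) with A = 3 and K = 5 + 80
    feature_maps = [torch.rand(2, 128, 640 // s, 640 // s) for s in strides]
    head_outputs = [torch.rand(2, 3, 640 // s, 640 // s, 85) for s in strides]

    # grids[i] and shifts[i] both have shape (1, 3, H_i, W_i, 2)
    grids, shifts = anchor_generator(feature_maps)
    detections = post_process(head_outputs, grids, shifts)
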
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 test/test_models.py                   |  43 ++++++-----
 test/test_models_anchor_utils.py      |  18 ++---
 yolort/models/_utils.py               |  44 ++++++-----
 yolort/models/anchor_utils.py         | 105 ++++++++++----------------
 yolort/models/box_head.py             |  82 ++++++++++++--------
 yolort/models/yolo.py                 |  11 ++-
 yolort/runtime/yolo_tensorrt_model.py |  26 ++++++-
 yolort/v5/models/common.py            |  22 +-----
 yolort/v5/utils/general.py            |   2 +-
 9 files changed, 181 insertions(+), 172 deletions(-)

diff --git a/test/test_models.py b/test/test_models.py
index 3dc03415..4cfb6657 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -125,12 +125,13 @@ def _get_anchor_grids(use_p6: bool):
             ]
         return anchor_grids
 
-    def _compute_num_anchors(self, height, width, use_p6: bool):
+    def _compute_anchors(self, height, width, use_p6: bool):
         strides = self._get_strides(use_p6)
-        num_anchors = 0
+        anchors_num = len(strides)
+        anchors_shape = []
         for s in strides:
-            num_anchors += (height // s) * (width // s)
-        return num_anchors * 3
+            anchors_shape.append((height // s, width // s))
+        return anchors_num, anchors_shape
 
     def _get_feature_shapes(self, height, width, width_multiple=0.5, use_p6=False):
         in_channels = self._get_in_channels(width_multiple, use_p6)
@@ -238,12 +239,14 @@ def test_anchor_generator(self, width_multiple, use_p6, batch_size, height, widt
         )
         model = self._init_test_anchor_generator(use_p6)
         anchors = model(feature_maps)
-        expected_num_anchors = self._compute_num_anchors(height, width, use_p6)
+        expected_anchors_num, expected_anchors_shape = self._compute_anchors(height, width, use_p6)
+
+        assert len(anchors) == 2
+        assert len(anchors[0]) == len(anchors[1]) == expected_anchors_num
+        for i in range(expected_anchors_num):
+            assert tuple(anchors[0][i].shape) == (1, 3, *(expected_anchors_shape[i]), 2)
+            assert tuple(anchors[1][i].shape) == (1, 3, *(expected_anchors_shape[i]), 2)
 
-        assert len(anchors) == 3
-        assert tuple(anchors[0].shape) == (expected_num_anchors, 2)
-        assert tuple(anchors[1].shape) == (expected_num_anchors, 1)
-        assert tuple(anchors[2].shape) == (expected_num_anchors, 2)
         _check_jit_scriptable(model, (feature_maps,))
 
     def _init_test_yolo_head(self, width_multiple=0.5, use_p6=False):
@@ -269,29 +272,31 @@ def test_yolo_head(self):
         assert head_outputs[2].shape == target_head_outputs[2].shape
         _check_jit_scriptable(model, (feature_maps,))
 
-    def _init_test_postprocessors(self):
+    def _init_test_postprocessors(self, strides):
         score_thresh = 0.5
         nms_thresh = 0.45
         detections_per_img = 100
-        postprocessors = PostProcess(score_thresh, nms_thresh, detections_per_img)
+        postprocessors = PostProcess(strides, score_thresh, nms_thresh, detections_per_img)
         return postprocessors
 
-    def test_postprocessors(self):
+    @pytest.mark.parametrize("use_p6", [False, True])
+    def test_postprocessors(self, use_p6):
         N, H, W = 4, 416, 352
-        feature_maps = self._get_feature_maps(N, H, W)
-        head_outputs = self._get_head_outputs(N, H, W)
+        strides = self._get_strides(use_p6)
+        feature_maps = self._get_feature_maps(N, H, W, use_p6=use_p6)
+        head_outputs = self._get_head_outputs(N, H, W, use_p6=use_p6)
 
-        anchor_generator = self._init_test_anchor_generator()
-        anchors_tuple = anchor_generator(feature_maps)
-        model = self._init_test_postprocessors()
-        out = model(head_outputs, anchors_tuple)
+        anchor_generator = self._init_test_anchor_generator(use_p6=use_p6)
+        grids, shifts = anchor_generator(feature_maps)
+        model = self._init_test_postprocessors(strides)
+        out = model(head_outputs, grids, shifts)
 
         assert len(out) == N
         assert isinstance(out[0], dict)
         assert isinstance(out[0]["boxes"], Tensor)
         assert isinstance(out[0]["labels"], Tensor)
         assert isinstance(out[0]["scores"], Tensor)
-        _check_jit_scriptable(model, (head_outputs, anchors_tuple))
+        _check_jit_scriptable(model, (head_outputs, grids, shifts))
 
     def test_criterion(self, use_p6=False):
         N, H, W = 4, 640, 640
diff --git a/test/test_models_anchor_utils.py b/test/test_models_anchor_utils.py
index 39e4609f..803de860 100644
--- a/test/test_models_anchor_utils.py
+++ b/test/test_models_anchor_utils.py
@@ -18,15 +18,13 @@ def test_anchor_generator(self):
         model.eval()
         anchors = model(features)
 
-        expected_anchor_output = torch.tensor([[-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5], [0.5, 0.5]])
-        expected_wh_output = torch.tensor([[4.0], [4.0], [4.0], [4.0]])
-        expected_xy_output = torch.tensor([[6.0, 14.0], [6.0, 14.0], [6.0, 14.0], [6.0, 14.0]])
+        expected_grids = torch.tensor([[[[[0.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]]]])
+        expected_shifts = torch.tensor([[[[[6.0, 14.0], [6.0, 14.0]], [[6.0, 14.0], [6.0, 14.0]]]]])
 
-        assert len(anchors) == 3
-        assert tuple(anchors[0].shape) == (4, 2)
-        assert tuple(anchors[1].shape) == (4, 1)
-        assert tuple(anchors[2].shape) == (4, 2)
+        assert len(anchors) == 2
+        assert len(anchors[0]) == len(anchors[1]) == 1
+        assert tuple(anchors[0][0].shape) == (1, 1, 2, 2, 2)
+        assert tuple(anchors[1][0].shape) == (1, 1, 2, 2, 2)
 
-        torch.testing.assert_close(anchors[0], expected_anchor_output, rtol=0, atol=0)
-        torch.testing.assert_close(anchors[1], expected_wh_output, rtol=0, atol=0)
-        torch.testing.assert_close(anchors[2], expected_xy_output, rtol=0, atol=0)
+        torch.testing.assert_close(anchors[0][0], expected_grids)
+        torch.testing.assert_close(anchors[1][0], expected_shifts)
diff --git a/yolort/models/_utils.py b/yolort/models/_utils.py
index a53c06e0..77a97544 100644
--- a/yolort/models/_utils.py
+++ b/yolort/models/_utils.py
@@ -3,12 +3,13 @@
 
 import torch
 from torch import nn, Tensor
-from torchvision.ops import box_convert, box_iou
+from torchvision.ops import box_iou
 
 
 def _evaluate_iou(target, pred):
     """
-    Evaluate intersection over union (IOU) for target from dataset and output prediction from model
+    Evaluate intersection over union (IoU) between a target from the
+    dataset and a prediction from the model.
     """
     if pred["boxes"].shape[0] == 0:
         # no box detected, 0 IOU
@@ -34,8 +35,7 @@ def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) ->
 
 def encode_single(reference_boxes: Tensor, anchors: Tensor) -> Tensor:
     """
-    Encode a set of anchors with respect to some
-    reference boxes
+    Encode a set of anchors with respect to some reference boxes
 
     Args:
         reference_boxes (Tensor): reference boxes
@@ -52,23 +52,25 @@ def encode_single(reference_boxes: Tensor, anchors: Tensor) -> Tensor:
 
 def decode_single(
     rel_codes: Tensor,
-    anchors_tuple: Tuple[Tensor, Tensor, Tensor],
-) -> Tensor:
+    grid: Tensor,
+    shift: Tensor,
+    stride: int,
+) -> Tuple[Tensor, Tensor]:
     """
     From a set of original boxes and encoded relative box offsets,
     get the decoded boxes.
 
-    Arguments:
-        rel_codes (Tensor): encoded boxes
-        anchors_tupe (Tensor, Tensor, Tensor): reference boxes.
+    Args:
+        rel_codes (Tensor): Encoded boxes
+        grid (Tensor): Anchor grids
+        shift (Tensor): Anchor shifts
+        stride (int): Stride of the current feature level
     """
+    pred_xy = (rel_codes[..., 0:2] * 2.0 - 0.5 + grid) * stride
+    pred_wh = (rel_codes[..., 2:4] * 2.0) ** 2 * shift
 
-    pred_wh = (rel_codes[..., 0:2] * 2.0 + anchors_tuple[0]) * anchors_tuple[1]  # wh
-    pred_xy = (rel_codes[..., 2:4] * 2) ** 2 * anchors_tuple[2]  # xy
-    pred_boxes = torch.cat([pred_wh, pred_xy], dim=1)
-    pred_boxes = box_convert(pred_boxes, in_fmt="cxcywh", out_fmt="xyxy")
-
-    return pred_boxes
+    return pred_xy, pred_wh
 
 
 def bbox_iou(box1: Tensor, box2: Tensor, x1y1x2y2: bool = True, eps: float = 1e-7):
@@ -99,8 +102,10 @@ def bbox_iou(box1: Tensor, box2: Tensor, x1y1x2y2: bool = True, eps: float = 1e-
 
     iou = inter / union
 
-    cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
-    ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
+    # convex (smallest enclosing box) width
+    cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)
+    # convex height
+    ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)
     # Complete IoU https://arxiv.org/abs/1911.08287v1
     c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
     rho2 = (
@@ -149,7 +154,7 @@ def forward(self, pred, logit):
 
         if self.reduction == "mean":
             return loss.mean()
-        elif self.reduction == "sum":
+        if self.reduction == "sum":
             return loss.sum()
-        else:  # 'none'
-            return loss
+        # 'none'
+        return loss
diff --git a/yolort/models/anchor_utils.py b/yolort/models/anchor_utils.py
index ca049498..23dafa4e 100644
--- a/yolort/models/anchor_utils.py
+++ b/yolort/models/anchor_utils.py
@@ -6,89 +6,62 @@
 
 
 class AnchorGenerator(nn.Module):
-    def __init__(
-        self,
-        strides: List[int],
-        anchor_grids: List[List[float]],
-    ):
+    def __init__(self, strides: List[int], anchor_grids: List[List[float]]):
+
         super().__init__()
         assert len(strides) == len(anchor_grids)
-        self.num_anchors = len(anchor_grids[0]) // 2
         self.strides = strides
         self.anchor_grids = anchor_grids
+        self.num_layers = len(anchor_grids)
+        self.num_anchors = len(anchor_grids[0]) // 2
 
-    def set_wh_weights(
+    def _generate_grids(
         self,
         grid_sizes: List[List[int]],
         dtype: torch.dtype = torch.float32,
         device: torch.device = torch.device("cpu"),
-    ) -> Tensor:
-
-        wh_weights = []
-
-        for size, stride in zip(grid_sizes, self.strides):
-            grid_height, grid_width = size
-            stride = torch.as_tensor([stride], dtype=dtype, device=device)
-            stride = stride.view(-1, 1)
-            stride = stride.repeat(1, grid_height * grid_width * self.num_anchors)
-            stride = stride.reshape(-1, 1)
-            wh_weights.append(stride)
+    ) -> List[Tensor]:
 
-        return torch.cat(wh_weights)
-
-    def set_xy_weights(
-        self,
-        grid_sizes: List[List[int]],
-        dtype: torch.dtype = torch.float32,
-        device: torch.device = torch.device("cpu"),
-    ) -> Tensor:
+        grids = []
+        for height, width in grid_sizes:
+            # For each output location, compute its (x, y) grid cell coordinates
+            widths = torch.arange(width, dtype=torch.int32, device=device).to(dtype=dtype)
+            heights = torch.arange(height, dtype=torch.int32, device=device).to(dtype=dtype)
 
-        xy_weights = []
+            shift_y, shift_x = torch.meshgrid(heights, widths)
 
-        for size, anchor_grid in zip(grid_sizes, self.anchor_grids):
-            grid_height, grid_width = size
-            anchor_grid = torch.as_tensor(anchor_grid, dtype=dtype, device=device)
-            anchor_grid = anchor_grid.view(-1, 2)
-            anchor_grid = anchor_grid.repeat(1, grid_height * grid_width)
-            anchor_grid = anchor_grid.reshape(-1, 2)
-            xy_weights.append(anchor_grid)
+            grid = torch.stack((shift_x, shift_y), 2).expand((1, self.num_anchors, height, width, 2))
+            grids.append(grid)
 
-        return torch.cat(xy_weights)
+        return grids
 
-    def grid_anchors(
+    def _generate_shifts(
         self,
         grid_sizes: List[List[int]],
         dtype: torch.dtype = torch.float32,
         device: torch.device = torch.device("cpu"),
-    ) -> Tensor:
-
-        anchors = []
-
-        for size in grid_sizes:
-            grid_height, grid_width = size
-
-            # For output anchor, compute [x_center, y_center, x_center, y_center]
-            shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device).to(dtype=dtype)
-            shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device).to(dtype=dtype)
-
-            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
-
-            shifts = torch.stack((shift_x, shift_y), dim=2)
-            shifts = shifts.view(1, grid_height, grid_width, 2)
-            shifts = shifts.repeat(self.num_anchors, 1, 1, 1)
-            shifts = shifts - torch.tensor(0.5, dtype=shifts.dtype, device=device)
-            shifts = shifts.reshape(-1, 2)
-
-            anchors.append(shifts)
-
-        return torch.cat(anchors)
-
-    def forward(self, feature_maps: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
+    ) -> List[Tensor]:
+
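+        # anchor_grids are given in input-image pixels: divide by stride here, then scale back per level below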
+        anchors = torch.tensor(self.anchor_grids, dtype=dtype, device=device)
+        strides = torch.tensor(self.strides, dtype=dtype, device=device)
+        anchors = anchors.view(self.num_layers, -1, 2) / strides.view(-1, 1, 1)
+
+        shifts = []
+        for i, (height, width) in enumerate(grid_sizes):
+            shift = (
+                (anchors[i].clone() * self.strides[i])
+                .view((1, self.num_anchors, 1, 1, 2))
+                .expand((1, self.num_anchors, height, width, 2))
+                .contiguous()
+                .to(dtype=dtype)
+            )
+            shifts.append(shift)
+        return shifts
+
+    def forward(self, feature_maps: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
         grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
-
-        wh_weights = self.set_wh_weights(grid_sizes, dtype, device)
-        xy_weights = self.set_xy_weights(grid_sizes, dtype, device)
-        anchors = self.grid_anchors(grid_sizes, dtype, device)
-
-        return anchors, wh_weights, xy_weights
+        grids = self._generate_grids(grid_sizes, dtype=dtype, device=device)
+        shifts = self._generate_shifts(grid_sizes, dtype=dtype, device=device)
+        return grids, shifts
diff --git a/yolort/models/box_head.py b/yolort/models/box_head.py
index 7da0cc7c..b044f616 100644
--- a/yolort/models/box_head.py
+++ b/yolort/models/box_head.py
@@ -5,7 +5,7 @@
 import torch
 import torch.nn.functional as F
 from torch import nn, Tensor
-from torchvision.ops import boxes as box_ops
+from torchvision.ops import box_convert, boxes as box_ops
 
 from . import _utils as det_utils
 
@@ -70,7 +70,8 @@ def forward(self, x: List[Tensor]) -> List[Tensor]:
             # Permute output from (N, A * K, H, W) to (N, A, H, W, K)
             N, _, H, W = pred_logits.shape
             pred_logits = pred_logits.view(N, self.num_anchors, -1, H, W)
-            pred_logits = pred_logits.permute(0, 1, 3, 4, 2)  # Size=(N, A, H, W, K)
+            # Size=(N, A, H, W, K)
+            pred_logits = pred_logits.permute(0, 1, 3, 4, 2).contiguous()
 
             all_pred_logits.append(pred_logits)
 
@@ -322,45 +323,58 @@ class LogitsDecoder(nn.Module):
     This is a simplified version of PostProcess to remove the ``torchvision::nms`` module.
     """
 
-    def __init__(self) -> None:
+    def __init__(self, strides: List[int]) -> None:
+        """
+        Args:
+            strides (List[int]): Strides of the AnchorGenerator.
+        """
+
         super().__init__()
+        self.strides = strides
 
-    @staticmethod
-    def _concat_pred_logits(head_outputs: List[Tensor]) -> Tensor:
+    def _concat_pred_logits(
+        self,
+        head_outputs: List[Tensor],
+        grids: List[Tensor],
+        shifts: List[Tensor],
+    ) -> Tensor:
         # Concat all pred logits
         batch_size, _, _, _, K = head_outputs[0].shape
 
+        # Decode bounding box with the shifts and grids
         all_pred_logits = []
-        for pred_logits in head_outputs:
-            pred_logits = pred_logits.reshape(batch_size, -1, K)  # Size=(N, HWA, K)
-            all_pred_logits.append(pred_logits)
+
+        for i, head_output in enumerate(head_outputs):
+            head_feature = torch.sigmoid(head_output)
+            pred_xy, pred_wh = det_utils.decode_single(
+                head_feature[..., :4],
+                grids[i],
+                shifts[i],
+                self.strides[i],
+            )
+            pred_logits = torch.cat((pred_xy, pred_wh, head_feature[..., 4:]), dim=-1)
+            all_pred_logits.append(pred_logits.view(batch_size, -1, K))
 
         all_pred_logits = torch.cat(all_pred_logits, dim=1)
+
         return all_pred_logits
 
-    @staticmethod
-    def _decode_pred_logits(
-        pred_logits: Tensor,
-        idx: int,
-        anchors_tuple: Tuple[Tensor, Tensor, Tensor],
-    ):
+    def _decode_pred_logits(self, pred_logits: Tensor):
         """
-        Decode the prediction logit from the Post_precess
+        Decode the prediction logits into boxes and scores.
         """
-        pred_logits = torch.sigmoid(pred_logits[idx])
-
         # Compute conf
         # box_conf x class_conf, w/ shape: num_anchors x num_classes
         scores = pred_logits[:, 5:] * pred_logits[:, 4:5]
-
-        boxes = det_utils.decode_single(pred_logits[:, :4], anchors_tuple)
+        boxes = box_convert(pred_logits[:, :4], in_fmt="cxcywh", out_fmt="xyxy")
 
         return boxes, scores
 
     def forward(
         self,
         head_outputs: List[Tensor],
-        anchors_tuple: Tuple[Tensor, Tensor, Tensor],
+        grids: List[Tensor],
+        shifts: List[Tensor],
     ) -> Tuple[Tensor, Tensor]:
         """
         Just concat the predict logits, ignore the original ``torchvision::nms`` module
@@ -370,16 +384,18 @@ def forward(
             head_outputs (List[Tensor]): The predicted locations and class/object confidence,
                 shape of the element is (N, A, H, W, K).
-            anchors_tuple (Tuple[Tensor, Tensor, Tensor]):
+            grids (List[Tensor]): Anchor grids.
+            shifts (List[Tensor]): Anchor shifts.
         """
         batch_size = len(head_outputs[0])
 
-        all_pred_logits = self._concat_pred_logits(head_outputs)
+        all_pred_logits = self._concat_pred_logits(head_outputs, grids, shifts)
 
         bbox_regression = []
         pred_scores = []
 
         for idx in range(batch_size):  # image idx, image inference
-            boxes, scores = self._decode_pred_logits(all_pred_logits, idx, anchors_tuple)
+            pred_logits = all_pred_logits[idx]
+            boxes, scores = self._decode_pred_logits(pred_logits)
             bbox_regression.append(boxes)
             pred_scores.append(scores)
 
@@ -393,17 +410,19 @@ class PostProcess(LogitsDecoder):
 
     def __init__(
         self,
+        strides: List[int],
         score_thresh: float,
         nms_thresh: float,
         detections_per_img: int,
     ) -> None:
         """
         Args:
+            strides (List[int]): Strides of the AnchorGenerator.
             score_thresh (float): Score threshold used for postprocessing the detections.
             nms_thresh (float): NMS threshold used for postprocessing the detections.
             detections_per_img (int): Number of best detections to keep after NMS.
         """
-        super().__init__()
+        super().__init__(strides)
         self.score_thresh = score_thresh
         self.nms_thresh = nms_thresh
         self.detections_per_img = detections_per_img
@@ -411,7 +429,8 @@ def __init__(
     def forward(
         self,
         head_outputs: List[Tensor],
-        anchors_tuple: Tuple[Tensor, Tensor, Tensor],
+        grids: List[Tensor],
+        shifts: List[Tensor],
     ) -> List[Dict[str, Tensor]]:
         """
         Perform the computation. At test time, postprocess_detections is the final layer of YOLO.
@@ -422,25 +441,24 @@ def forward(
         Args:
             head_outputs (List[Tensor]): The predicted locations and class/object confidence,
                 shape of the element is (N, A, H, W, K).
-            anchors_tuple (Tuple[Tensor, Tensor, Tensor]):
+            grids (List[Tensor]): Anchor grids.
+            shifts (List[Tensor]): Anchor shifts.
         """
         batch_size = len(head_outputs[0])
 
-        all_pred_logits = self._concat_pred_logits(head_outputs)
-
+        all_pred_logits = self._concat_pred_logits(head_outputs, grids, shifts)
         detections: List[Dict[str, Tensor]] = []
 
         for idx in range(batch_size):  # image idx, image inference
-            # Decode the predict logits
-            boxes, scores = self._decode_pred_logits(all_pred_logits, idx, anchors_tuple)
-
+            pred_logits = all_pred_logits[idx]
+            boxes, scores = self._decode_pred_logits(pred_logits)
             # remove low scoring boxes
             inds, labels = torch.where(scores > self.score_thresh)
             boxes, scores = boxes[inds], scores[inds, labels]
 
             # non-maximum suppression, independently done per level
             keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
-            # Keep only topk scoring head_outputs
+            # Keep only the top-k scoring detections
             keep = keep[: self.detections_per_img]
             boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
 
diff --git a/yolort/models/yolo.py b/yolort/models/yolo.py
index 34ba6ef8..80cbc965 100644
--- a/yolort/models/yolo.py
+++ b/yolort/models/yolo.py
@@ -122,7 +122,12 @@ def __init__(
         self.head = head
 
         if post_process is None:
-            post_process = PostProcess(score_thresh, nms_thresh, detections_per_img)
+            post_process = PostProcess(
+                anchor_generator.strides,
+                score_thresh,
+                nms_thresh,
+                detections_per_img,
+            )
         self.post_process = post_process
 
         # used only on torchscript mode
@@ -163,7 +168,7 @@ def forward(
         head_outputs = self.head(features)
 
         # create the set of anchors
-        anchors_tuple = self.anchor_generator(features)
+        grids, shifts = self.anchor_generator(features)
         losses = {}
         detections: List[Dict[str, Tensor]] = []
 
@@ -173,7 +178,7 @@ def forward(
             losses = self.compute_loss(targets, head_outputs)
         else:
             # compute the detections
-            detections = self.post_process(head_outputs, anchors_tuple)
+            detections = self.post_process(head_outputs, grids, shifts)
 
         if torch.jit.is_scripting():
             if not self._has_warned:
diff --git a/yolort/runtime/yolo_tensorrt_model.py b/yolort/runtime/yolo_tensorrt_model.py
index 87c69153..c7238b3e 100644
--- a/yolort/runtime/yolo_tensorrt_model.py
+++ b/yolort/runtime/yolo_tensorrt_model.py
@@ -5,7 +5,9 @@
 import torch
 from torch import nn, Tensor
 from yolort.models import YOLO
+from yolort.models.backbone_utils import darknet_pan_backbone
 from yolort.models.box_head import LogitsDecoder
+from yolort.utils import load_from_ultralytics
 
 __all__ = ["YOLOTRTModule"]
 
@@ -24,14 +26,32 @@ def __init__(
         version: str = "r6.0",
     ):
         super().__init__()
-        post_process = LogitsDecoder()
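+        # Parse the checkpoint to recover the model configuration and trained weights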
+        model_info = load_from_ultralytics(checkpoint_path, version=version)
 
-        self.model = YOLO.load_from_yolov5(
-            checkpoint_path,
+        backbone_name = f"darknet_{model_info['size']}_{version.replace('.', '_')}"
+        depth_multiple = model_info["depth_multiple"]
+        width_multiple = model_info["width_multiple"]
+        use_p6 = model_info["use_p6"]
+        backbone = darknet_pan_backbone(
+            backbone_name,
+            depth_multiple,
+            width_multiple,
             version=version,
+            use_p6=use_p6,
+        )
+        post_process = LogitsDecoder(model_info["strides"])
+        model = YOLO(
+            backbone,
+            model_info["num_classes"],
+            strides=model_info["strides"],
+            anchor_grids=model_info["anchor_grids"],
             post_process=post_process,
         )
 
+        model.load_state_dict(model_info["state_dict"])
+        self.model = model
+
     @torch.no_grad()
     def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor]:
         """
diff --git a/yolort/v5/models/common.py b/yolort/v5/models/common.py
index 5021d14a..1d981b07 100644
--- a/yolort/v5/models/common.py
+++ b/yolort/v5/models/common.py
@@ -620,24 +620,10 @@ def render(self):
     def pandas(self):
         # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
         new = copy(self)  # return copy
-        ca = (
-            "xmin",
-            "ymin",
-            "xmax",
-            "ymax",
-            "confidence",
-            "class",
-            "name",
-        )  # xyxy columns
-        cb = (
-            "xcenter",
-            "ycenter",
-            "width",
-            "height",
-            "confidence",
-            "class",
-            "name",
-        )  # xywh columns
+        # xyxy columns
+        ca = ("xmin", "ymin", "xmax", "ymax", "confidence", "class", "name")
+        # xywh columns
+        cb = ("xcenter", "ycenter", "width", "height", "confidence", "class", "name")
         for k, c in zip(["xyxy", "xyxyn", "xywh", "xywhn"], [ca, ca, cb, cb]):
             # update
             a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]
diff --git a/yolort/v5/utils/general.py b/yolort/v5/utils/general.py
index 3e68799f..6f23d613 100644
--- a/yolort/v5/utils/general.py
+++ b/yolort/v5/utils/general.py
@@ -511,7 +511,7 @@ def non_max_suppression(
     Runs Non-Maximum Suppression (NMS) on inference results
 
     Returns:
-         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
+        list of detections, on (n,6) tensor per image [xyxy, conf, cls]
     """
 
     nc = prediction.shape[2] - 5  # number of classes