Skip to content

Commit

Permalink
Separate out module LogitsDecoder
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Jan 24, 2022
1 parent 7b015ed commit 4a0c599
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 56 deletions.
59 changes: 59 additions & 0 deletions yolort/runtime/logits_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) 2021, yolort team. All rights reserved.
from typing import List, Tuple
import torch
from torch import nn, Tensor
from yolort.models.box_head import _concat_pred_logits, _decode_pred_logits


class LogitsDecoder(nn.Module):
    """
    Reduced post-processing module without ``torchvision::ops::nms``.

    The NMS step is deliberately left out; the module is used later in the
    procedure for exporting the ONNX Graph to YOLOTRTModule or others.
    """

    def __init__(self, strides: List[int]) -> None:
        """
        Args:
            strides (List[int]): Strides of the AnchorGenerator.
        """
        super().__init__()
        self.strides = strides

    def forward(
        self,
        head_outputs: List[Tensor],
        grids: List[Tensor],
        shifts: List[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        """
        Concatenate the predicted logits and decode them image by image,
        skipping the ``torchvision::nms`` step of the original
        ``yolort.models.box_head.PostProcess``.

        Args:
            head_outputs (List[Tensor]): The predicted locations and class/object
                confidence, shape of the element is (N, A, H, W, K).
            grids (List[Tensor]): Anchor grids.
            shifts (List[Tensor]): Anchor shifts.

        Returns:
            Tuple[Tensor, Tensor]: The stacked per-image boxes and the stacked
            per-image scores, in that order.
        """
        # The first head output supplies the batch size, device and dtype.
        first_head = head_outputs[0]
        num_images = len(first_head)

        # Build the strides as float32 on the target device, then cast them
        # to the dtype of the head outputs.
        stride_tensor = torch.as_tensor(
            self.strides,
            dtype=torch.float32,
            device=first_head.device,
        )
        stride_tensor = stride_tensor.to(dtype=first_head.dtype)

        merged_logits = _concat_pred_logits(head_outputs, grids, shifts, stride_tensor)

        per_image_boxes = []
        per_image_scores = []
        # Decode one image at a time, selecting it by index.
        for image_idx in range(num_images):
            image_boxes, image_scores = _decode_pred_logits(merged_logits[image_idx])
            per_image_boxes.append(image_boxes)
            per_image_scores.append(image_scores)

        # The default boxes tensor has shape [batch_size, number_boxes, 4].
        return torch.stack(per_image_boxes), torch.stack(per_image_scores)
59 changes: 3 additions & 56 deletions yolort/runtime/trt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import logging
from pathlib import Path
from typing import Optional, List, Tuple, Union
from typing import Optional, Tuple, Union

try:
import tensorrt as trt
Expand All @@ -23,9 +23,10 @@
from yolort.models import YOLO
from yolort.models.anchor_utils import AnchorGenerator
from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.models.box_head import _concat_pred_logits, _decode_pred_logits
from yolort.utils import load_from_ultralytics

from .logits_decoder import LogitsDecoder

# Configure root logging at INFO, then bind the module's "TRTHelper" logger
# once and set its level through that binding.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("TRTHelper")
logger.setLevel(logging.INFO)
Expand All @@ -34,60 +35,6 @@
# Names exported by a star-import of this module.
__all__ = ["YOLOTRTModule", "EngineBuilder"]


class LogitsDecoder(nn.Module):
    """
    Simplified post-processing module with ``torchvision::ops::nms`` removed
    by hand, to be used later in the procedure for exporting the ONNX Graph
    to YOLOTRTModule or others.
    """

    def __init__(self, strides: List[int]) -> None:
        """
        Args:
            strides (List[int]): Strides of the AnchorGenerator.
        """
        super().__init__()
        self.strides = strides

    def forward(
        self,
        head_outputs: List[Tensor],
        grids: List[Tensor],
        shifts: List[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        """
        Concatenate the predicted logits and decode each image, without the
        ``torchvision::nms`` module used by the original
        ``yolort.models.box_head.PostProcess``.

        Args:
            head_outputs (List[Tensor]): The predicted locations and class/object
                confidence, shape of the element is (N, A, H, W, K).
            grids (List[Tensor]): Anchor grids.
            shifts (List[Tensor]): Anchor shifts.

        Returns:
            Tuple[Tensor, Tensor]: ``(boxes, scores)``, each stacked over the
            images of the batch.
        """
        reference = head_outputs[0]
        target_device = reference.device
        target_dtype = reference.dtype
        image_count = len(reference)

        # Strides are created as float32 and only afterwards cast to the dtype
        # of the head outputs.
        strides_fp32 = torch.as_tensor(self.strides, dtype=torch.float32, device=target_device)
        strides = strides_fp32.to(dtype=target_dtype)

        logits_per_batch = _concat_pred_logits(head_outputs, grids, shifts, strides)

        decoded_boxes = []
        decoded_scores = []
        for i in range(image_count):  # one image per iteration
            logits_i = logits_per_batch[i]
            boxes_i, scores_i = _decode_pred_logits(logits_i)
            decoded_boxes.append(boxes_i)
            decoded_scores.append(scores_i)

        # The default boxes tensor has shape [batch_size, number_boxes, 4].
        stacked_boxes = torch.stack(decoded_boxes)
        stacked_scores = torch.stack(decoded_scores)
        return stacked_boxes, stacked_scores


class YOLOTRTModule(nn.Module):
"""
TensorRT deployment friendly wrapper for YOLO.
Expand Down

0 comments on commit 4a0c599

Please sign in to comment.