diff --git a/test/test_models.py b/test/test_models.py
index 3dc03415..4cfb6657 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -125,12 +125,13 @@ def _get_anchor_grids(use_p6: bool):
         ]
         return anchor_grids

-    def _compute_num_anchors(self, height, width, use_p6: bool):
+    def _compute_anchors(self, height, width, use_p6: bool):
         strides = self._get_strides(use_p6)
-        num_anchors = 0
+        anchors_num = len(strides)
+        anchors_shape = []
         for s in strides:
-            num_anchors += (height // s) * (width // s)
-        return num_anchors * 3
+            anchors_shape.append((height // s, width // s))
+        return anchors_num, anchors_shape

     def _get_feature_shapes(self, height, width, width_multiple=0.5, use_p6=False):
         in_channels = self._get_in_channels(width_multiple, use_p6)
@@ -238,12 +239,14 @@ def test_anchor_generator(self, width_multiple, use_p6, batch_size, height, widt
         )
         model = self._init_test_anchor_generator(use_p6)
         anchors = model(feature_maps)
-        expected_num_anchors = self._compute_num_anchors(height, width, use_p6)
+        expected_anchors_num, expected_anchors_shape = self._compute_anchors(height, width, use_p6)
+
+        assert len(anchors) == 2
+        assert len(anchors[0]) == len(anchors[1]) == expected_anchors_num
+        for i in range(expected_anchors_num):
+            assert tuple(anchors[0][i].shape) == (1, 3, *(expected_anchors_shape[i]), 2)
+            assert tuple(anchors[1][i].shape) == (1, 3, *(expected_anchors_shape[i]), 2)

-        assert len(anchors) == 3
-        assert tuple(anchors[0].shape) == (expected_num_anchors, 2)
-        assert tuple(anchors[1].shape) == (expected_num_anchors, 1)
-        assert tuple(anchors[2].shape) == (expected_num_anchors, 2)
         _check_jit_scriptable(model, (feature_maps,))

     def _init_test_yolo_head(self, width_multiple=0.5, use_p6=False):
@@ -269,29 +272,31 @@ def test_yolo_head(self):
         assert head_outputs[2].shape == target_head_outputs[2].shape
         _check_jit_scriptable(model, (feature_maps,))

-    def _init_test_postprocessors(self):
+    def _init_test_postprocessors(self, strides):
         score_thresh = 0.5
         nms_thresh = 0.45
         detections_per_img = 100
-        postprocessors = PostProcess(score_thresh, nms_thresh, detections_per_img)
+        postprocessors = PostProcess(strides, score_thresh, nms_thresh, detections_per_img)
         return postprocessors

-    def test_postprocessors(self):
+    @pytest.mark.parametrize("use_p6", [False, True])
+    def test_postprocessors(self, use_p6):
         N, H, W = 4, 416, 352
-        feature_maps = self._get_feature_maps(N, H, W)
-        head_outputs = self._get_head_outputs(N, H, W)
+        strides = self._get_strides(use_p6)
+        feature_maps = self._get_feature_maps(N, H, W, use_p6=use_p6)
+        head_outputs = self._get_head_outputs(N, H, W, use_p6=use_p6)

-        anchor_generator = self._init_test_anchor_generator()
-        anchors_tuple = anchor_generator(feature_maps)
-        model = self._init_test_postprocessors()
-        out = model(head_outputs, anchors_tuple)
+        anchor_generator = self._init_test_anchor_generator(use_p6=use_p6)
+        grids, shifts = anchor_generator(feature_maps)
+        model = self._init_test_postprocessors(strides)
+        out = model(head_outputs, grids, shifts)

         assert len(out) == N
         assert isinstance(out[0], dict)
         assert isinstance(out[0]["boxes"], Tensor)
         assert isinstance(out[0]["labels"], Tensor)
         assert isinstance(out[0]["scores"], Tensor)
-        _check_jit_scriptable(model, (head_outputs, anchors_tuple))
+        _check_jit_scriptable(model, (head_outputs, grids, shifts))

     def test_criterion(self, use_p6=False):
         N, H, W = 4, 640, 640
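The reworked _compute_anchors helper returns the number of feature levels together with each level's grid shape, rather than a flat anchor count. A quick check of the arithmetic, as a minimal sketch assuming the default P5 strides of 8, 16 and 32:

    # Per-level grid shapes for the 416x352 input used by test_postprocessors
    strides = (8, 16, 32)
    height, width = 416, 352
    anchors_num = len(strides)
    anchors_shape = [(height // s, width // s) for s in strides]
    print(anchors_num)    # 3
    print(anchors_shape)  # [(52, 44), (26, 22), (13, 11)]
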
diff --git a/test/test_models_anchor_utils.py b/test/test_models_anchor_utils.py
index 39e4609f..803de860 100644
--- a/test/test_models_anchor_utils.py
+++ b/test/test_models_anchor_utils.py
@@ -18,15 +18,13 @@ def test_anchor_generator(self):
         model.eval()
         anchors = model(features)

-        expected_anchor_output = torch.tensor([[-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5], [0.5, 0.5]])
-        expected_wh_output = torch.tensor([[4.0], [4.0], [4.0], [4.0]])
-        expected_xy_output = torch.tensor([[6.0, 14.0], [6.0, 14.0], [6.0, 14.0], [6.0, 14.0]])
+        expected_grids = torch.tensor([[[[[0.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]]]])
+        expected_shifts = torch.tensor([[[[[6.0, 14.0], [6.0, 14.0]], [[6.0, 14.0], [6.0, 14.0]]]]])

-        assert len(anchors) == 3
-        assert tuple(anchors[0].shape) == (4, 2)
-        assert tuple(anchors[1].shape) == (4, 1)
-        assert tuple(anchors[2].shape) == (4, 2)
+        assert len(anchors) == 2
+        assert len(anchors[0]) == len(anchors[1]) == 1
+        assert tuple(anchors[0][0].shape) == (1, 1, 2, 2, 2)
+        assert tuple(anchors[1][0].shape) == (1, 1, 2, 2, 2)

-        torch.testing.assert_close(anchors[0], expected_anchor_output, rtol=0, atol=0)
-        torch.testing.assert_close(anchors[1], expected_wh_output, rtol=0, atol=0)
-        torch.testing.assert_close(anchors[2], expected_xy_output, rtol=0, atol=0)
+        torch.testing.assert_close(anchors[0][0], expected_grids)
+        torch.testing.assert_close(anchors[1][0], expected_shifts)
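The expected tensors above can be reproduced by hand. A minimal sketch, assuming the test builds its generator with a single stride of 4 and one anchor grid of [6.0, 14.0] over a 2x2 feature map, which is what the expected values imply:

    import torch
    from yolort.models.anchor_utils import AnchorGenerator

    model = AnchorGenerator(strides=[4], anchor_grids=[[6.0, 14.0]])
    features = [torch.rand(1, 8, 2, 2)]  # one 2x2 feature map
    grids, shifts = model(features)
    print(grids[0][0, 0])   # per-cell (x, y) indices: [[0., 0.], [1., 0.]], [[0., 1.], [1., 1.]]
    print(shifts[0][0, 0])  # the (w, h) anchor [6., 14.] broadcast to every cell
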
diff --git a/yolort/models/_utils.py b/yolort/models/_utils.py
index a53c06e0..77a97544 100644
--- a/yolort/models/_utils.py
+++ b/yolort/models/_utils.py
@@ -3,12 +3,13 @@

 import torch
 from torch import nn, Tensor
-from torchvision.ops import box_convert, box_iou
+from torchvision.ops import box_iou


 def _evaluate_iou(target, pred):
     """
-    Evaluate intersection over union (IOU) for target from dataset and output prediction from model
+    Evaluate intersection over union (IOU) for target from dataset and
+    output prediction from model
     """
     if pred["boxes"].shape[0] == 0:
         # no box detected, 0 IOU
@@ -34,8 +35,7 @@ def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) ->

 def encode_single(reference_boxes: Tensor, anchors: Tensor) -> Tensor:
     """
-    Encode a set of anchors with respect to some
-    reference boxes
+    Encode a set of anchors with respect to some reference boxes

     Args:
         reference_boxes (Tensor): reference boxes
@@ -52,23 +52,24 @@ def encode_single(reference_boxes: Tensor, anchors: Tensor) -> Tensor:

 def decode_single(
     rel_codes: Tensor,
-    anchors_tuple: Tuple[Tensor, Tensor, Tensor],
-) -> Tensor:
+    grid: Tensor,
+    shift: Tensor,
+    stride: int,
+) -> Tuple[Tensor, Tensor]:
     """
     From a set of original boxes and encoded relative box offsets,
     get the decoded boxes.

-    Arguments:
-        rel_codes (Tensor): encoded boxes
-        anchors_tupe (Tensor, Tensor, Tensor): reference boxes.
+    Args:
+        rel_codes (Tensor): Encoded boxes
+        grid (Tensor): Anchor grids
+        shift (Tensor): Anchor shifts
+        stride (int): Stride
     """
+    pred_xy = (rel_codes[..., 0:2] * 2.0 - 0.5 + grid) * stride
+    pred_wh = (rel_codes[..., 2:4] * 2.0) ** 2 * shift

-    pred_wh = (rel_codes[..., 0:2] * 2.0 + anchors_tuple[0]) * anchors_tuple[1]  # wh
-    pred_xy = (rel_codes[..., 2:4] * 2) ** 2 * anchors_tuple[2]  # xy
-    pred_boxes = torch.cat([pred_wh, pred_xy], dim=1)
-    pred_boxes = box_convert(pred_boxes, in_fmt="cxcywh", out_fmt="xyxy")
-
-    return pred_boxes
+    return pred_xy, pred_wh


 def bbox_iou(box1: Tensor, box2: Tensor, x1y1x2y2: bool = True, eps: float = 1e-7):
@@ -99,8 +100,10 @@ def bbox_iou(box1: Tensor, box2: Tensor, x1y1x2y2: bool = True, eps: float = 1e-

     iou = inter / union

-    cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
-    ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
+    # convex (smallest enclosing box) width
+    cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)
+    # convex height
+    ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)
     # Complete IoU https://arxiv.org/abs/1911.08287v1
     c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
     rho2 = (
@@ -149,7 +152,7 @@ def forward(self, pred, logit):

         if self.reduction == "mean":
             return loss.mean()
-        elif self.reduction == "sum":
+        if self.reduction == "sum":
             return loss.sum()
-        else:  # 'none'
-            return loss
+        # 'none'
+        return loss
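decode_single now returns center and size predictions separately, leaving the cxcywh-to-xyxy conversion to the caller. A small numeric check of the new rule, with toy values rather than real model outputs:

    import torch
    from yolort.models._utils import decode_single

    rel_codes = torch.full((1, 4), 0.5)  # sigmoid outputs at their midpoint
    grid = torch.tensor([1.0, 1.0])      # the anchor sits in cell (1, 1)
    shift = torch.tensor([6.0, 14.0])    # anchor (w, h) in pixels
    pred_xy, pred_wh = decode_single(rel_codes, grid, shift, stride=8)
    print(pred_xy)  # (0.5 * 2 - 0.5 + 1) * 8 = 12.0 for both coordinates
    print(pred_wh)  # (0.5 * 2) ** 2 * (6, 14) = [6., 14.]
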
diff --git a/yolort/models/anchor_utils.py b/yolort/models/anchor_utils.py
index ca049498..23dafa4e 100644
--- a/yolort/models/anchor_utils.py
+++ b/yolort/models/anchor_utils.py
@@ -6,89 +6,61 @@


 class AnchorGenerator(nn.Module):
-    def __init__(
-        self,
-        strides: List[int],
-        anchor_grids: List[List[float]],
-    ):
+    def __init__(self, strides: List[int], anchor_grids: List[List[float]]):
+
         super().__init__()
         assert len(strides) == len(anchor_grids)
-        self.num_anchors = len(anchor_grids[0]) // 2
         self.strides = strides
         self.anchor_grids = anchor_grids
+        self.num_layers = len(anchor_grids)
+        self.num_anchors = len(anchor_grids[0]) // 2

-    def set_wh_weights(
+    def _generate_grids(
         self,
         grid_sizes: List[List[int]],
         dtype: torch.dtype = torch.float32,
         device: torch.device = torch.device("cpu"),
-    ) -> Tensor:
-
-        wh_weights = []
-
-        for size, stride in zip(grid_sizes, self.strides):
-            grid_height, grid_width = size
-            stride = torch.as_tensor([stride], dtype=dtype, device=device)
-            stride = stride.view(-1, 1)
-            stride = stride.repeat(1, grid_height * grid_width * self.num_anchors)
-            stride = stride.reshape(-1, 1)
-            wh_weights.append(stride)
+    ) -> List[Tensor]:

-        return torch.cat(wh_weights)
-
-    def set_xy_weights(
-        self,
-        grid_sizes: List[List[int]],
-        dtype: torch.dtype = torch.float32,
-        device: torch.device = torch.device("cpu"),
-    ) -> Tensor:
+        grids = []
+        for height, width in grid_sizes:
+            # For output anchor, compute [x_center, y_center, x_center, y_center]
+            widths = torch.arange(width, dtype=torch.int32, device=device).to(dtype=dtype)
+            heights = torch.arange(height, dtype=torch.int32, device=device).to(dtype=dtype)

-        xy_weights = []
+            shift_y, shift_x = torch.meshgrid(heights, widths)

-        for size, anchor_grid in zip(grid_sizes, self.anchor_grids):
-            grid_height, grid_width = size
-            anchor_grid = torch.as_tensor(anchor_grid, dtype=dtype, device=device)
-            anchor_grid = anchor_grid.view(-1, 2)
-            anchor_grid = anchor_grid.repeat(1, grid_height * grid_width)
-            anchor_grid = anchor_grid.reshape(-1, 2)
-            xy_weights.append(anchor_grid)
+            grid = torch.stack((shift_x, shift_y), 2).expand((1, self.num_anchors, height, width, 2))
+            grids.append(grid)

-        return torch.cat(xy_weights)
+        return grids

-    def grid_anchors(
+    def _generate_shifts(
         self,
         grid_sizes: List[List[int]],
         dtype: torch.dtype = torch.float32,
         device: torch.device = torch.device("cpu"),
-    ) -> Tensor:
-
-        anchors = []
-
-        for size in grid_sizes:
-            grid_height, grid_width = size
-
-            # For output anchor, compute [x_center, y_center, x_center, y_center]
-            shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device).to(dtype=dtype)
-            shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device).to(dtype=dtype)
-
-            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
-
-            shifts = torch.stack((shift_x, shift_y), dim=2)
-            shifts = shifts.view(1, grid_height, grid_width, 2)
-            shifts = shifts.repeat(self.num_anchors, 1, 1, 1)
-            shifts = shifts - torch.tensor(0.5, dtype=shifts.dtype, device=device)
-            shifts = shifts.reshape(-1, 2)
-
-            anchors.append(shifts)
-
-        return torch.cat(anchors)
-
-    def forward(self, feature_maps: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
+    ) -> List[Tensor]:
+
+        anchors = torch.tensor(self.anchor_grids, dtype=dtype, device=device)
+        strides = torch.tensor(self.strides, dtype=dtype, device=device)
+        anchors = anchors.view(self.num_layers, -1, 2) / strides.view(-1, 1, 1)
+
+        shifts = []
+        for i, (height, width) in enumerate(grid_sizes):
+            shift = (
+                (anchors[i].clone() * self.strides[i])
+                .view((1, self.num_anchors, 1, 1, 2))
+                .expand((1, self.num_anchors, height, width, 2))
+                .contiguous()
+                .to(dtype=dtype)
+            )
+            shifts.append(shift)
+        return shifts
+
+    def forward(self, feature_maps: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
         grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
-
-        wh_weights = self.set_wh_weights(grid_sizes, dtype, device)
-        xy_weights = self.set_xy_weights(grid_sizes, dtype, device)
-        anchors = self.grid_anchors(grid_sizes, dtype, device)
-
-        return anchors, wh_weights, xy_weights
+        grids = self._generate_grids(grid_sizes, dtype=dtype, device=device)
+        shifts = self._generate_shifts(grid_sizes, dtype=dtype, device=device)
+        return grids, shifts
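The generator now emits one grid/shift pair per feature level, each shaped (1, num_anchors, height, width, 2), instead of three flattened tensors. A usage sketch; the anchor values here are the stock YOLOv5 P5 settings and stand in for whatever the model config actually supplies:

    import torch
    from yolort.models.anchor_utils import AnchorGenerator

    strides = [8, 16, 32]
    anchor_grids = [
        [10, 13, 16, 30, 33, 23],
        [30, 61, 62, 45, 59, 119],
        [116, 90, 156, 198, 373, 326],
    ]
    anchor_generator = AnchorGenerator(strides, anchor_grids)
    # Dummy P3/P4/P5 feature maps for a 640x640 input
    feature_maps = [torch.rand(1, 64, 640 // s, 640 // s) for s in strides]
    grids, shifts = anchor_generator(feature_maps)
    print([tuple(g.shape) for g in grids])
    # [(1, 3, 80, 80, 2), (1, 3, 40, 40, 2), (1, 3, 20, 20, 2)]
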
diff --git a/yolort/models/box_head.py b/yolort/models/box_head.py
index 7da0cc7c..b044f616 100644
--- a/yolort/models/box_head.py
+++ b/yolort/models/box_head.py
@@ -5,7 +5,7 @@
 import torch
 import torch.nn.functional as F
 from torch import nn, Tensor
-from torchvision.ops import boxes as box_ops
+from torchvision.ops import box_convert, boxes as box_ops

 from . import _utils as det_utils

@@ -70,7 +70,8 @@ def forward(self, x: List[Tensor]) -> List[Tensor]:
             # Permute output from (N, A * K, H, W) to (N, A, H, W, K)
             N, _, H, W = pred_logits.shape
             pred_logits = pred_logits.view(N, self.num_anchors, -1, H, W)
-            pred_logits = pred_logits.permute(0, 1, 3, 4, 2)  # Size=(N, A, H, W, K)
+            # Size=(N, A, H, W, K)
+            pred_logits = pred_logits.permute(0, 1, 3, 4, 2).contiguous()

             all_pred_logits.append(pred_logits)

@@ -322,45 +323,58 @@ class LogitsDecoder(nn.Module):
     This is a simplified version of PostProcess to remove the ``torchvision::nms`` module.
     """

-    def __init__(self) -> None:
+    def __init__(self, strides: List[int]) -> None:
+        """
+        Args:
+            strides (List[int]): Strides of the AnchorGenerator.
+        """
+
         super().__init__()
+        self.strides = strides

-    @staticmethod
-    def _concat_pred_logits(head_outputs: List[Tensor]) -> Tensor:
+    def _concat_pred_logits(
+        self,
+        head_outputs: List[Tensor],
+        grids: List[Tensor],
+        shifts: List[Tensor],
+    ) -> Tensor:
         # Concat all pred logits
         batch_size, _, _, _, K = head_outputs[0].shape

+        # Decode bounding box with the shifts and grids
         all_pred_logits = []
-        for pred_logits in head_outputs:
-            pred_logits = pred_logits.reshape(batch_size, -1, K)  # Size=(N, HWA, K)
-            all_pred_logits.append(pred_logits)
+
+        for i, head_output in enumerate(head_outputs):
+            head_feature = torch.sigmoid(head_output)
+            pred_xy, pred_wh = det_utils.decode_single(
+                head_feature[..., :4],
+                grids[i],
+                shifts[i],
+                self.strides[i],
+            )
+            pred_logits = torch.cat((pred_xy, pred_wh, head_feature[..., 4:]), dim=-1)
+            all_pred_logits.append(pred_logits.view(batch_size, -1, K))

         all_pred_logits = torch.cat(all_pred_logits, dim=1)
+
         return all_pred_logits

-    @staticmethod
-    def _decode_pred_logits(
-        pred_logits: Tensor,
-        idx: int,
-        anchors_tuple: Tuple[Tensor, Tensor, Tensor],
-    ):
+    def _decode_pred_logits(self, pred_logits: Tensor):
         """
-        Decode the prediction logit from the Post_precess
+        Decode the prediction logits from the PostProcess.
         """
-        pred_logits = torch.sigmoid(pred_logits[idx])
-
-        # Compute conf
         # box_conf x class_conf, w/ shape: num_anchors x num_classes
         scores = pred_logits[:, 5:] * pred_logits[:, 4:5]
-
-        boxes = det_utils.decode_single(pred_logits[:, :4], anchors_tuple)
+        boxes = box_convert(pred_logits[:, :4], in_fmt="cxcywh", out_fmt="xyxy")

         return boxes, scores

     def forward(
         self,
         head_outputs: List[Tensor],
-        anchors_tuple: Tuple[Tensor, Tensor, Tensor],
+        grids: List[Tensor],
+        shifts: List[Tensor],
     ) -> Tuple[Tensor, Tensor]:
         """
         Just concat the predict logits, ignore the original ``torchvision::nms`` module
@@ -370,16 +384,19 @@ def forward(
             head_outputs (List[Tensor]): The predicted locations and class/object confidence,
                 shape of the element is (N, A, H, W, K).
-            anchors_tuple (Tuple[Tensor, Tensor, Tensor]):
+            grids (List[Tensor]): Anchor grids.
+            shifts (List[Tensor]): Anchor shifts.
         """
         batch_size = len(head_outputs[0])

-        all_pred_logits = self._concat_pred_logits(head_outputs)
+        all_pred_logits = self._concat_pred_logits(head_outputs, grids, shifts)

         bbox_regression = []
         pred_scores = []

         for idx in range(batch_size):  # image idx, image inference
-            boxes, scores = self._decode_pred_logits(all_pred_logits, idx, anchors_tuple)
+            pred_logits = all_pred_logits[idx]
+            boxes, scores = self._decode_pred_logits(pred_logits)
             bbox_regression.append(boxes)
             pred_scores.append(scores)

@@ -393,17 +410,19 @@ class PostProcess(LogitsDecoder):

     def __init__(
         self,
+        strides: List[int],
         score_thresh: float,
         nms_thresh: float,
         detections_per_img: int,
     ) -> None:
         """
         Args:
+            strides (List[int]): Strides of the AnchorGenerator.
             score_thresh (float): Score threshold used for postprocessing the detections.
             nms_thresh (float): NMS threshold used for postprocessing the detections.
             detections_per_img (int): Number of best detections to keep after NMS.
         """
-        super().__init__()
+        super().__init__(strides)
         self.score_thresh = score_thresh
         self.nms_thresh = nms_thresh
         self.detections_per_img = detections_per_img

@@ -411,7 +430,8 @@ def __init__(
     def forward(
         self,
         head_outputs: List[Tensor],
-        anchors_tuple: Tuple[Tensor, Tensor, Tensor],
+        grids: List[Tensor],
+        shifts: List[Tensor],
     ) -> List[Dict[str, Tensor]]:
         """
         Perform the computation. At test time, postprocess_detections is the final layer of YOLO.
@@ -422,25 +442,24 @@ def forward(
         Args:
             head_outputs (List[Tensor]): The predicted locations and class/object confidence,
                 shape of the element is (N, A, H, W, K).
-            anchors_tuple (Tuple[Tensor, Tensor, Tensor]):
+            grids (List[Tensor]): Anchor grids.
+            shifts (List[Tensor]): Anchor shifts.
         """
         batch_size = len(head_outputs[0])

-        all_pred_logits = self._concat_pred_logits(head_outputs)
-
+        all_pred_logits = self._concat_pred_logits(head_outputs, grids, shifts)
         detections: List[Dict[str, Tensor]] = []

         for idx in range(batch_size):  # image idx, image inference
-            # Decode the predict logits
-            boxes, scores = self._decode_pred_logits(all_pred_logits, idx, anchors_tuple)
-
+            pred_logits = all_pred_logits[idx]
+            boxes, scores = self._decode_pred_logits(pred_logits)
             # remove low scoring boxes
             inds, labels = torch.where(scores > self.score_thresh)
             boxes, scores = boxes[inds], scores[inds, labels]

             # non-maximum suppression, independently done per level
             keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)

-            # Keep only topk scoring head_outputs
+            # keep only topk scoring head_outputs
             keep = keep[: self.detections_per_img]
             boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
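Since the decoder now owns the strides, PostProcess is constructed from the same stride list as the AnchorGenerator and consumes its grids and shifts directly. A wiring sketch with dummy tensors; the thresholds and the 85-channel head (5 + 80 classes) are illustrative:

    import torch
    from yolort.models.anchor_utils import AnchorGenerator
    from yolort.models.box_head import PostProcess

    strides = [8, 16, 32]
    anchor_grids = [  # stock YOLOv5 P5 anchors, for illustration
        [10, 13, 16, 30, 33, 23],
        [30, 61, 62, 45, 59, 119],
        [116, 90, 156, 198, 373, 326],
    ]
    anchor_generator = AnchorGenerator(strides, anchor_grids)
    post_process = PostProcess(strides, 0.25, 0.45, 100)

    feature_maps = [torch.rand(1, 64, 640 // s, 640 // s) for s in strides]
    head_outputs = [torch.rand(1, 3, 640 // s, 640 // s, 85) for s in strides]

    grids, shifts = anchor_generator(feature_maps)
    detections = post_process(head_outputs, grids, shifts)
    # one dict per image with "boxes", "labels" and "scores"
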
diff --git a/yolort/models/yolo.py b/yolort/models/yolo.py
index 34ba6ef8..80cbc965 100644
--- a/yolort/models/yolo.py
+++ b/yolort/models/yolo.py
@@ -122,7 +122,12 @@ def __init__(
         self.head = head

         if post_process is None:
-            post_process = PostProcess(score_thresh, nms_thresh, detections_per_img)
+            post_process = PostProcess(
+                anchor_generator.strides,
+                score_thresh,
+                nms_thresh,
+                detections_per_img,
+            )
         self.post_process = post_process

         # used only on torchscript mode
@@ -163,7 +168,7 @@ def forward(
         head_outputs = self.head(features)

         # create the set of anchors
-        anchors_tuple = self.anchor_generator(features)
+        grids, shifts = self.anchor_generator(features)

         losses = {}
         detections: List[Dict[str, Tensor]] = []
@@ -173,7 +178,7 @@ def forward(
             losses = self.compute_loss(targets, head_outputs)
         else:
             # compute the detections
-            detections = self.post_process(head_outputs, anchors_tuple)
+            detections = self.post_process(head_outputs, grids, shifts)

         if torch.jit.is_scripting():
             if not self._has_warned:
diff --git a/yolort/runtime/yolo_tensorrt_model.py b/yolort/runtime/yolo_tensorrt_model.py
index 87c69153..c7238b3e 100644
--- a/yolort/runtime/yolo_tensorrt_model.py
+++ b/yolort/runtime/yolo_tensorrt_model.py
@@ -5,7 +5,9 @@
 import torch
 from torch import nn, Tensor
 from yolort.models import YOLO
+from yolort.models.backbone_utils import darknet_pan_backbone
 from yolort.models.box_head import LogitsDecoder
+from yolort.utils import load_from_ultralytics

 __all__ = ["YOLOTRTModule"]

@@ -24,14 +26,31 @@ def __init__(
         self,
         version: str = "r6.0",
     ):
         super().__init__()
-        post_process = LogitsDecoder()
+        model_info = load_from_ultralytics(checkpoint_path, version=version)

-        self.model = YOLO.load_from_yolov5(
-            checkpoint_path,
+        backbone_name = f"darknet_{model_info['size']}_{version.replace('.', '_')}"
+        depth_multiple = model_info["depth_multiple"]
+        width_multiple = model_info["width_multiple"]
+        use_p6 = model_info["use_p6"]
+        backbone = darknet_pan_backbone(
+            backbone_name,
+            depth_multiple,
+            width_multiple,
             version=version,
+            use_p6=use_p6,
+        )
+        post_process = LogitsDecoder(model_info["strides"])
+        model = YOLO(
+            backbone,
+            model_info["num_classes"],
+            strides=model_info["strides"],
+            anchor_grids=model_info["anchor_grids"],
             post_process=post_process,
         )
+        model.load_state_dict(model_info["state_dict"])
+        self.model = model
+
     @torch.no_grad()
     def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor]:
         """
diff --git a/yolort/v5/models/common.py b/yolort/v5/models/common.py
index 5021d14a..1d981b07 100644
--- a/yolort/v5/models/common.py
+++ b/yolort/v5/models/common.py
@@ -620,24 +620,10 @@ def render(self):

     def pandas(self):
         # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
         new = copy(self)  # return copy
-        ca = (
-            "xmin",
-            "ymin",
-            "xmax",
-            "ymax",
-            "confidence",
-            "class",
-            "name",
-        )  # xyxy columns
-        cb = (
-            "xcenter",
-            "ycenter",
-            "width",
-            "height",
-            "confidence",
-            "class",
-            "name",
-        )  # xywh columns
+        # xyxy columns
+        ca = ("xmin", "ymin", "xmax", "ymax", "confidence", "class", "name")
+        # xywh columns
+        cb = ("xcenter", "ycenter", "width", "height", "confidence", "class", "name")
         for k, c in zip(["xyxy", "xyxyn", "xywh", "xywhn"], [ca, ca, cb, cb]):  # update
             a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]
diff --git a/yolort/v5/utils/general.py b/yolort/v5/utils/general.py
index 3e68799f..6f23d613 100644
--- a/yolort/v5/utils/general.py
+++ b/yolort/v5/utils/general.py
@@ -511,7 +511,7 @@ def non_max_suppression(
     Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
-         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
+        list of detections, one (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
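For completeness, a usage sketch of the rewired TensorRT-friendly module; the checkpoint path is a placeholder for a local YOLOv5 r6.0 weight file, and the constructor arguments follow the diff above:

    import torch
    from yolort.runtime.yolo_tensorrt_model import YOLOTRTModule

    model = YOLOTRTModule("yolov5s.pt", version="r6.0")  # placeholder checkpoint
    model.eval()
    boxes, scores = model(torch.rand(1, 3, 640, 640))
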