
Commit

Fix the latest compatibility issues
zhiqwang committed Dec 20, 2021
1 parent aea2b0d commit 595d758
Showing 3 changed files with 76 additions and 46 deletions.
yolort/models/anchor_utils.py (21 additions & 10 deletions)
@@ -6,11 +6,8 @@
 
 
 class AnchorGenerator(nn.Module):
-    def __init__(
-        self,
-        strides: List[int],
-        anchor_grids: List[List[float]],
-    ):
+    def __init__(self, strides: List[int], anchor_grids: List[List[float]]):
+
         super().__init__()
         assert len(strides) == len(anchor_grids)
         self.strides = strides
@@ -26,7 +23,7 @@ def _generate_grids(
     ) -> List[Tensor]:
 
         grids = []
-        for (height, width) in grid_sizes:
+        for height, width in grid_sizes:
             # For output anchor, compute [x_center, y_center, x_center, y_center]
             widths = torch.arange(width, dtype=torch.int32, device=device).to(dtype=dtype)
             heights = torch.arange(height, dtype=torch.int32, device=device).to(dtype=dtype)
@@ -38,12 +35,26 @@ def _generate_grids(
 
         return grids
 
-    def _generate_shifts(self) -> List[Tensor]:
-        return self.anchors.clone().view(self.num_layers, 1, -1, 1, 1, 2)
+    def _generate_shifts(
+        self,
+        grid_sizes: List[List[int]],
+        dtype: torch.dtype = torch.float32,
+    ) -> List[Tensor]:
+
+        shifts = []
+        for i, (height, width) in enumerate(grid_sizes):
+            shift = (
+                (self.anchors[i].clone() * self.strides[i])
+                .view((1, self.num_anchors, 1, 1, 2))
+                .expand((1, self.num_anchors, height, width, 2))
+                .to(dtype=dtype)
+            )
+            shifts.append(shift)
+        return shifts
 
     def forward(self, feature_maps: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
         grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
-        grids = self._generate_grids(grid_sizes, dtype, device)
-        shifts = self._generate_shifts()
+        grids = self._generate_grids(grid_sizes, dtype=dtype, device=device)
+        shifts = self._generate_shifts(grid_sizes, dtype=dtype)
         return grids, shifts
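
With this change, _generate_shifts needs the per-level grid sizes, so forward now passes them through and returns the grids together with the stride-scaled shifts. Below is a minimal sketch of how the updated AnchorGenerator could be exercised; the strides and anchor grids are the usual YOLOv5 P3/P4/P5 defaults, the import path is inferred from the file layout above, and the feature-map shapes are purely illustrative; none of these values comes from the commit itself.

import torch

from yolort.models.anchor_utils import AnchorGenerator

# Assumed YOLOv5 defaults, for illustration only
strides = [8, 16, 32]
anchor_grids = [
    [10, 13, 16, 30, 33, 23],
    [30, 61, 62, 45, 59, 119],
    [116, 90, 156, 198, 373, 326],
]
anchor_gen = AnchorGenerator(strides, anchor_grids)

# Dummy feature maps for a 640x640 input; forward only reads shape[-2:] of each map
feature_maps = [torch.randn(1, 128, 640 // s, 640 // s) for s in strides]
grids, shifts = anchor_gen(feature_maps)

# Per the .expand(...) in the diff, shifts[i] has shape (1, num_anchors, H_i, W_i, 2)
# and already carries the per-level anchors multiplied by the matching stride.
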
yolort/models/box_head.py (46 additions & 32 deletions)
@@ -5,7 +5,7 @@
 import torch
 import torch.nn.functional as F
 from torch import nn, Tensor
-from torchvision.ops import boxes as box_ops
+from torchvision.ops import box_convert, boxes as box_ops
 
 from . import _utils as det_utils
 
@@ -322,45 +322,58 @@ class LogitsDecoder(nn.Module):
     This is a simplified version of PostProcess to remove the ``torchvision::nms`` module.
     """
 
-    def __init__(self) -> None:
+    def __init__(self, strides: List[int]) -> None:
+        """
+        Args:
+            strides (List[int]): Strides of the AnchorGenerator.
+        """
+
         super().__init__()
+        self.strides = strides
 
-    @staticmethod
-    def _concat_pred_logits(head_outputs: List[Tensor]) -> Tensor:
+    def _concat_pred_logits(
+        self,
+        head_outputs: List[Tensor],
+        grids: List[Tensor],
+        shifts: List[Tensor],
+    ) -> Tensor:
         # Concat all pred logits
         batch_size, _, _, _, K = head_outputs[0].shape
 
+        # Decode bounding box with the shifts and grids
         all_pred_logits = []
-        for pred_logits in head_outputs:
-            pred_logits = pred_logits.reshape(batch_size, -1, K)  # Size=(N, HWA, K)
-            all_pred_logits.append(pred_logits)
-
+        for i, head_output in enumerate(head_outputs):
+            head_feature = head_output.sigmoid()
+            pred_xy, pred_wh = det_utils.decode_single(
+                head_feature[..., :4],
+                grids[i],
+                shifts[i],
+                self.strides[i],
+            )
+            pred_logits = torch.cat((pred_xy, pred_wh, head_feature[..., 4:]), dim=-1)
+            all_pred_logits.append(pred_logits.reshape(batch_size, -1, K))
+
         all_pred_logits = torch.cat(all_pred_logits, dim=1)
 
         return all_pred_logits
 
-    @staticmethod
-    def _decode_pred_logits(
-        pred_logits: Tensor,
-        idx: int,
-        anchors_tuple: Tuple[Tensor, Tensor, Tensor],
-    ):
+    def _decode_pred_logits(self, pred_logits: Tensor):
         """
-        Decode the prediction logit from the Post_precess
+        Decode the prediction logit from the PostProcess.
         """
-        pred_logits = torch.sigmoid(pred_logits[idx])
 
         # Compute conf
         # box_conf x class_conf, w/ shape: num_anchors x num_classes
         scores = pred_logits[:, 5:] * pred_logits[:, 4:5]
 
-        boxes = det_utils.decode_single(pred_logits[:, :4], anchors_tuple)
+        boxes = box_convert(pred_logits[:, :4], in_fmt="cxcywh", out_fmt="xyxy")
 
         return boxes, scores
 
     def forward(
         self,
         head_outputs: List[Tensor],
-        anchors_tuple: Tuple[Tensor, Tensor, Tensor],
+        grids: List[Tensor],
+        shifts: List[Tensor],
     ) -> Tuple[Tensor, Tensor]:
         """
         Just concat the predict logits, ignore the original ``torchvision::nms`` module
@@ -370,16 +383,18 @@ def forward(
             head_outputs (List[Tensor]): The predicted locations and class/object confidence,
                 shape of the element is (N, A, H, W, K).
-            anchors_tuple (Tuple[Tensor, Tensor, Tensor]):
+            grids (List[Tensor]): Anchor grids.
+            shifts (List[Tensor]): Anchor shifts.
         """
         batch_size = len(head_outputs[0])
 
-        all_pred_logits = self._concat_pred_logits(head_outputs)
+        all_pred_logits = self._concat_pred_logits(head_outputs, grids, shifts)
 
         bbox_regression = []
         pred_scores = []
 
         for idx in range(batch_size):  # image idx, image inference
-            boxes, scores = self._decode_pred_logits(all_pred_logits, idx, anchors_tuple)
+            boxes, scores = self._decode_pred_logits(all_pred_logits[idx])
             bbox_regression.append(boxes)
             pred_scores.append(scores)
 
@@ -393,25 +408,28 @@ class PostProcess(LogitsDecoder):
 
     def __init__(
         self,
+        strides: List[int],
         score_thresh: float,
         nms_thresh: float,
         detections_per_img: int,
     ) -> None:
         """
         Args:
+            strides (List[int]): Strides of the AnchorGenerator.
            score_thresh (float): Score threshold used for postprocessing the detections.
            nms_thresh (float): NMS threshold used for postprocessing the detections.
            detections_per_img (int): Number of best detections to keep after NMS.
         """
-        super().__init__()
+        super().__init__(strides)
         self.score_thresh = score_thresh
         self.nms_thresh = nms_thresh
         self.detections_per_img = detections_per_img
 
     def forward(
         self,
         head_outputs: List[Tensor],
-        anchors_tuple: Tuple[Tensor, Tensor, Tensor],
+        grids: List[Tensor],
+        shifts: List[Tensor],
     ) -> List[Dict[str, Tensor]]:
         """
         Perform the computation. At test time, postprocess_detections is the final layer of YOLO.
@@ -422,28 +440,24 @@ def forward(
         Args:
             head_outputs (List[Tensor]): The predicted locations and class/object confidence,
                 shape of the element is (N, A, H, W, K).
-            anchors_tuple (Tuple[Tensor, Tensor, Tensor]):
+            grids (List[Tensor]): Anchor grids.
+            shifts (List[Tensor]): Anchor shifts.
         """
         batch_size = len(head_outputs[0])
 
-        all_pred_logits = self._concat_pred_logits(head_outputs)
-
+        all_pred_logits = self._concat_pred_logits(head_outputs, grids, shifts)
         detections: List[Dict[str, Tensor]] = []
 
         for idx in range(batch_size):  # image idx, image inference
             # Decode the predict logits
-            boxes, scores = self._decode_pred_logits(all_pred_logits, idx, anchors_tuple)
-
+            boxes, scores = self._decode_pred_logits(all_pred_logits[idx])
             # remove low scoring boxes
             inds, labels = torch.where(scores > self.score_thresh)
             boxes, scores = boxes[inds], scores[inds, labels]
 
             # non-maximum suppression, independently done per level
             keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
-            # Keep only topk scoring head_outputs
+            # keep only topk scoring head_outputs
             keep = keep[: self.detections_per_img]
             boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
 
             detections.append({"scores": scores, "labels": labels, "boxes": boxes})
 
         return detections
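
Downstream, PostProcess now receives the strides at construction time and consumes the grids and shifts in place of the old anchors_tuple. The following is a hedged, self-contained sketch of the new call pattern; the thresholds, channel counts and tensor shapes are illustrative assumptions; only the argument names and their order come from this diff.

import torch

from yolort.models.anchor_utils import AnchorGenerator
from yolort.models.box_head import PostProcess

strides = [8, 16, 32]
anchor_grids = [
    [10, 13, 16, 30, 33, 23],
    [30, 61, 62, 45, 59, 119],
    [116, 90, 156, 198, 373, 326],
]
num_anchors, num_classes = 3, 80

anchor_gen = AnchorGenerator(strides, anchor_grids)
post_process = PostProcess(strides, score_thresh=0.25, nms_thresh=0.45, detections_per_img=300)

# Dummy backbone features (N, C, H, W) and raw head outputs (N, A, H, W, 5 + num_classes)
features = [torch.randn(1, 128, 640 // s, 640 // s) for s in strides]
head_outputs = [torch.randn(1, num_anchors, 640 // s, 640 // s, 5 + num_classes) for s in strides]

grids, shifts = anchor_gen(features)
detections = post_process(head_outputs, grids, shifts)
# Each element of detections is a dict with "boxes" (xyxy, via box_convert), "scores" and "labels".
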
yolort/models/yolo.py (9 additions & 4 deletions)
@@ -107,7 +107,7 @@ def __init__(
             criterion = SetCriterion(
                 anchor_generator.num_anchors,
                 anchor_generator.strides,
-                anchor_generator.anchor_grids,
+                anchor_generator.anchors,
                 num_classes,
             )
         self.compute_loss = criterion
@@ -122,7 +122,12 @@ def __init__(
         self.head = head
 
         if post_process is None:
-            post_process = PostProcess(score_thresh, nms_thresh, detections_per_img)
+            post_process = PostProcess(
+                anchor_generator.strides,
+                score_thresh,
+                nms_thresh,
+                detections_per_img,
+            )
         self.post_process = post_process
 
         # used only on torchscript mode
@@ -163,7 +168,7 @@ def forward(
         head_outputs = self.head(features)
 
         # create the set of anchors
-        anchors_tuple = self.anchor_generator(features)
+        grids, shifts = self.anchor_generator(features)
         losses = {}
         detections: List[Dict[str, Tensor]] = []
 
@@ -173,7 +178,7 @@ def forward(
             losses = self.compute_loss(targets, head_outputs)
         else:
             # compute the detections
-            detections = self.post_process(head_outputs, anchors_tuple)
+            detections = self.post_process(head_outputs, grids, shifts)
 
         if torch.jit.is_scripting():
             if not self._has_warned:

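At the top level, YOLO now threads the (grids, shifts) pair from the anchor generator into post-processing, so callers of the high-level constructors should see no API change. A hedged sketch of such a call, assuming yolort's usual entry point; the model name, the pretrained flag and the input size are not part of this commit.

import torch

from yolort.models import yolov5s

# Assumed high-level constructor; it wires AnchorGenerator, the head and PostProcess internally
model = yolov5s(pretrained=True)
model = model.eval()

# One dummy image; the detections come out of the updated PostProcess
images = [torch.rand(3, 640, 640)]
predictions = model(images)
print(predictions[0]["boxes"].shape, predictions[0]["scores"].shape)
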