From 791d9384a80cd15cab488e666008384e202154ea Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 16 Nov 2021 19:49:48 +0800 Subject: [PATCH 01/60] add fcos --- torchvision/models/detection/fcos.py | 111 +++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 torchvision/models/detection/fcos.py diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py new file mode 100644 index 00000000000..4e6021a410f --- /dev/null +++ b/torchvision/models/detection/fcos.py @@ -0,0 +1,111 @@ +import math +import warnings +from collections import OrderedDict +from typing import Dict, List, Tuple, Optional + +import torch +from torch import nn, Tensor + +from ..._internally_replaced_utils import load_state_dict_from_url +from ...ops import sigmoid_focal_loss +from ...ops import boxes as box_ops +from ...ops import misc as misc_nn_ops +from ...ops.feature_pyramid_network import LastLevelP6P7 +from ...utils import _log_api_usage_once +from ..resnet import resnet50 +from . import _utils as det_utils +from ._utils import overwrite_eps +from .anchor_utils import AnchorGenerator +from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers +from .transform import GeneralizedRCNNTransform + + + +class FCOSClassificationHead(nn.Module): + """ + A classification head for use in FCOS. + Args: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + num_classes (int): number of classes to be predicted + num_convs (int): number of conv layer + """ + + def __init__(self, in_channels, num_anchors, num_classes, num_convs=4, prior_probability=0.01, norm_layer=None): + super().__init__() + + if norm_layer is None: + norm_layer = lambda channels: nn.GroupNorm(32, channels) + conv = [] + for _ in range(num_convs): + conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) + conv.append(norm_layer(in_channels) + conv.append(nn.ReLU()) + self.conv = nn.Sequential(*conv) + + for layer in self.conv.children(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.normal_(layer.weight, std=0.01) + torch.nn.init.constant_(layer.bias, 0) + + self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1) + torch.nn.init.normal_(self.cls_logits.weight, std=0.01) + torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability)) + + self.num_classes = num_classes + self.num_anchors = num_anchors + + def compute_loss(self, targets, head_outputs, matched_idxs): + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor + losses = [] + + cls_logits = head_outputs["cls_logits"] + + for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs): + # determine only the foreground + foreground_idxs_per_image = matched_idxs_per_image >= 0 + num_foreground = foreground_idxs_per_image.sum() + + # create the target classification + gt_classes_target = torch.zeros_like(cls_logits_per_image) + gt_classes_target[ + foreground_idxs_per_image, + targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]], + ] = 1.0 + + # find indices for which anchors should be ignored + valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS + + # compute the classification loss + losses.append( + sigmoid_focal_loss( + cls_logits_per_image[valid_idxs_per_image], + gt_classes_target[valid_idxs_per_image], + reduction="sum", + ) + / 
max(1, num_foreground) + ) + + return _sum(losses) / len(targets) + + def forward(self, x): + # type: (List[Tensor]) -> Tensor + all_cls_logits = [] + + for features in x: + cls_logits = self.conv(features) + cls_logits = self.cls_logits(cls_logits) + + # Permute classification output from (N, A * K, H, W) to (N, HWA, K). + N, _, H, W = cls_logits.shape + cls_logits = cls_logits.view(N, -1, self.num_classes, H, W) + cls_logits = cls_logits.permute(0, 3, 4, 1, 2) + cls_logits = cls_logits.reshape(N, -1, self.num_classes) # Size=(N, HWA, 4) + + all_cls_logits.append(cls_logits) + + return torch.cat(all_cls_logits, dim=1) + + +class FCOS(nn.Module): + pass From c9a00c14471b22178babb6be53c9f342966f2607 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 16 Nov 2021 21:34:19 +0800 Subject: [PATCH 02/60] update fcos --- torchvision/models/detection/fcos.py | 558 ++++++++++++++++++++++++++- 1 file changed, 543 insertions(+), 15 deletions(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 4e6021a410f..90449d527c1 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -1,7 +1,7 @@ import math import warnings from collections import OrderedDict -from typing import Dict, List, Tuple, Optional +from typing import Dict, List, Tuple, Optional, Any import torch from torch import nn, Tensor @@ -18,8 +18,51 @@ from .anchor_utils import AnchorGenerator from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .transform import GeneralizedRCNNTransform +from .retinanet import _sum +def pairwise_point_box_distance(points: torch.Tensor, boxes: torch.Tensor): + """ + Pairwise distance between N points and M boxes. The distance between a + point and a box is represented by the distance from the point to 4 edges + of the box. Distances are all positive when the point is inside the box. + Args: + points: Nx2 coordinates. Each row is (x, y) + boxes: M boxes + Returns: + Tensor: distances of size (N, M, 4). The 4 values are distances from + the point to the left, top, right, bottom of the box. + """ + x, y = points.unsqueeze(dim=2).unbind(dim=1) # (N, 1) + x0, y0, x1, y1 = boxes.unsqueeze(dim=0).unbind(dim=2) # (1, M) + return torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2) + + +class FCOSHead(nn.Module): + """ + A regression and classification head for use in RetinaNet. 
+ Args: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + num_classes (int): number of classes to be predicted + """ + + def __init__(self, in_channels, num_anchors, num_classes): + super().__init__() + self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes) + self.regression_head = FCOSRegressionHead(in_channels, num_anchors) + + def compute_loss(self, targets, head_outputs, anchors, matched_idxs): + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor] + return { + "classification": self.classification_head.compute_loss(targets, head_outputs, matched_idxs), + "bbox_regression": self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs), + } + + def forward(self, x): + # type: (List[Tensor]) -> Dict[str, Tensor] + return {"cls_logits": self.classification_head(x), "bbox_regression": self.regression_head(x)} + class FCOSClassificationHead(nn.Module): """ @@ -33,13 +76,17 @@ class FCOSClassificationHead(nn.Module): def __init__(self, in_channels, num_anchors, num_classes, num_convs=4, prior_probability=0.01, norm_layer=None): super().__init__() + + self.num_classes = num_classes + self.num_anchors = num_anchors + assert self.num_anchors == 1 # FCOS is anchor-free if norm_layer is None: norm_layer = lambda channels: nn.GroupNorm(32, channels) conv = [] for _ in range(num_convs): conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) - conv.append(norm_layer(in_channels) + conv.append(norm_layer(in_channels)) conv.append(nn.ReLU()) self.conv = nn.Sequential(*conv) @@ -52,9 +99,6 @@ def __init__(self, in_channels, num_anchors, num_classes, num_convs=4, prior_pro torch.nn.init.normal_(self.cls_logits.weight, std=0.01) torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability)) - self.num_classes = num_classes - self.num_anchors = num_anchors - def compute_loss(self, targets, head_outputs, matched_idxs): # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor losses = [] @@ -73,17 +117,10 @@ def compute_loss(self, targets, head_outputs, matched_idxs): targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]], ] = 1.0 - # find indices for which anchors should be ignored - valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS - # compute the classification loss losses.append( - sigmoid_focal_loss( - cls_logits_per_image[valid_idxs_per_image], - gt_classes_target[valid_idxs_per_image], - reduction="sum", - ) - / max(1, num_foreground) + sigmoid_focal_loss(cls_logits_per_image, gt_classes_target, + reduction="sum") / max(1, num_foreground) ) return _sum(losses) / len(targets) @@ -107,5 +144,496 @@ def forward(self, x): return torch.cat(all_cls_logits, dim=1) +class FCOSRegressionHead(nn.Module): + """ + A regression head for use in FCOS. 
+ Args: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + """ + + __annotations__ = { + "box_coder": det_utils.BoxCoder, + } + + def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None): + super().__init__() + + if norm_layer is None: + norm_layer = lambda channels: nn.GroupNorm(32, channels) + conv = [] + for _ in range(num_convs): + conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) + conv.append(norm_layer(in_channels)) + conv.append(nn.ReLU()) + self.conv = nn.Sequential(*conv) + + self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1) + self.ctrness = nn.Conv2d(in_channels, num_anchors * 1, kernel_size=3, stride=1, padding=1) + for layer in [self.bbox_reg, self.ctrness]: + torch.nn.init.normal_(layer.weight, std=0.01) + torch.nn.init.zeros_(layer.bias) + + for layer in self.conv.children(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.normal_(layer.weight, std=0.01) + torch.nn.init.zeros_(layer.bias) + + self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + + def compute_loss(self, targets, head_outputs, anchors, matched_idxs): + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor + losses = [] + + bbox_regression = head_outputs["bbox_regression"] + + for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip( + targets, bbox_regression, anchors, matched_idxs + ): + # determine only the foreground indices, ignore the rest + foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0] + num_foreground = foreground_idxs_per_image.numel() + + # select only the foreground boxes + matched_gt_boxes_per_image = targets_per_image["boxes"][matched_idxs_per_image[foreground_idxs_per_image]] + bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :] + anchors_per_image = anchors_per_image[foreground_idxs_per_image, :] + + # compute the regression targets + target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) + + # compute the loss + losses.append( + torch.nn.functional.l1_loss(bbox_regression_per_image, target_regression, reduction="sum") + / max(1, num_foreground) + ) + + return _sum(losses) / max(1, len(targets)) + + def forward(self, x): + # type: (List[Tensor]) -> Tensor + all_bbox_regression = [] + + for features in x: + bbox_regression = self.conv(features) + bbox_regression = self.bbox_reg(bbox_regression) + + # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4). + N, _, H, W = bbox_regression.shape + bbox_regression = bbox_regression.view(N, -1, 4, H, W) + bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2) + bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4) + + all_bbox_regression.append(bbox_regression) + + return torch.cat(all_bbox_regression, dim=1) + + class FCOS(nn.Module): - pass + """ + Implements FCOS. + The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each + image, and should be in 0-1 range. Different images can have different sizes. + The behavior of the model changes depending if it is in training or evaluation mode. 
+ During training, the model expects both the input tensors, as well as a targets (list of dictionary), + containing: + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the class label for each ground-truth box + The model returns a Dict[Tensor] during training, containing the classification and regression + losses. + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as + follows: + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the predicted labels for each image + - scores (Tensor[N]): the scores for each prediction + Args: + backbone (nn.Module): the network used to compute the features for the model. + It should contain an out_channels attribute, which indicates the number of output + channels that each feature map has (and it should be the same for all feature maps). + The backbone should return a single Tensor or an OrderedDict[Tensor]. + num_classes (int): number of output classes of the model (including the background). + min_size (int): minimum size of the image to be rescaled before feeding it to the backbone + max_size (int): maximum size of the image to be rescaled before feeding it to the backbone + image_mean (Tuple[float, float, float]): mean values used for input normalization. + They are generally the mean values of the dataset on which the backbone has been trained + on + image_std (Tuple[float, float, float]): std values used for input normalization. + They are generally the std values of the dataset on which the backbone has been trained on + anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature + maps. + head (nn.Module): Module run on top of the feature pyramid. + Defaults to a module containing a classification and regression module. + center_sampling_radius (int): radius of the "center" of a groundtruth box, + within which all anchor points are labeled positive. + score_thresh (float): Score threshold used for postprocessing the detections. + nms_thresh (float): NMS threshold used for postprocessing the detections. + detections_per_img (int): Number of best detections to keep after NMS. + topk_candidates (int): Number of best detections to keep before NMS. + Example: + >>> import torch + >>> import torchvision + >>> from torchvision.models.detection import RetinaNet + >>> from torchvision.models.detection.anchor_utils import AnchorGenerator + >>> # load a pre-trained model for classification and return + >>> # only the features + >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> # RetinaNet needs to know the number of + >>> # output channels in a backbone. For mobilenet_v2, it's 1280 + >>> # so we need to add it here + >>> backbone.out_channels = 1280 + >>> + >>> # let's make the network generate 5 x 3 anchors per spatial + >>> # location, with 5 different sizes and 3 different aspect + >>> # ratios. 
We have a Tuple[Tuple[int]] because each feature + >>> # map could potentially have different sizes and + >>> # aspect ratios + >>> anchor_generator = AnchorGenerator( + >>> sizes=((32, 64, 128, 256, 512),), + >>> aspect_ratios=((0.5, 1.0, 2.0),) + >>> ) + >>> + >>> # put the pieces together inside a RetinaNet model + >>> model = RetinaNet(backbone, + >>> num_classes=2, + >>> anchor_generator=anchor_generator) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + """ + + __annotations__ = { + "box_coder": det_utils.BoxCoder + } + + def __init__( + self, + backbone, + num_classes, + # transform parameters + min_size=800, + max_size=1333, + image_mean=None, + image_std=None, + # Anchor parameters + anchor_generator=None, + head=None, + center_sampling_radius=1.5, + score_thresh=0.05, + nms_thresh=0.5, + detections_per_img=300, + topk_candidates=1000, + ): + super().__init__() + _log_api_usage_once(self) + + if not hasattr(backbone, "out_channels"): + raise ValueError( + "backbone should contain an attribute out_channels " + "specifying the number of output channels (assumed to be the " + "same for all the levels)" + ) + self.backbone = backbone + + assert isinstance(anchor_generator, (AnchorGenerator, type(None))) + + if anchor_generator is None: + anchor_sizes = ((8,), (16,), (32,), (64,), (128,)) + aspect_ratios = ((1.0),) * len(anchor_sizes) + anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios) + self.anchor_generator = anchor_generator + + if head is None: + head = FCOSHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes) + self.head = head + + self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + + if image_mean is None: + image_mean = [0.485, 0.456, 0.406] + if image_std is None: + image_std = [0.229, 0.224, 0.225] + self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) + + self.center_sampling_radius = center_sampling_radius + self.score_thresh = score_thresh + self.nms_thresh = nms_thresh + self.detections_per_img = detections_per_img + self.topk_candidates = topk_candidates + + # used only on torchscript mode + self._has_warned = False + + @torch.jit.unused + def eager_outputs(self, losses, detections): + # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + if self.training: + return losses + + return detections + + def compute_loss(self, targets, head_outputs, anchors, num_anchors_per_level): + # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[int]) -> Dict[str, Tensor] + matched_idxs = [] + for anchors_per_image, targets_per_image in zip(anchors, targets): + if targets_per_image["boxes"].numel() == 0: + matched_idxs.append( + torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device) + ) + continue + + gt_boxes = targets_per_image["boxes"] + gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2 # Nx2 + anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2 + anchor_sizes = anchors_per_image.tensor[:, 2] - anchors_per_image.tensor[:, 0] + # center sampling: anchor point must be close enough to gt center. 
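+ # i.e. max(|anchor_cx - gt_cx|, |anchor_cy - gt_cy|) must be smaller than
+ # center_sampling_radius * anchor_size, where the anchor size equals the stride of its feature level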
+ pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max( + dim=2 + ).values < self.center_sampling_radius * anchor_sizes[:, None] + pairwise_dist = pairwise_point_box_distance(anchor_centers, gt_boxes) + + # anchor point must be inside gt + pairwise_match &= pairwise_dist.min(dim=2).values > 0 + + # each anchor is only responsible for certain scale range. + lower_bound = anchor_sizes * 4 + lower_bound[: num_anchors_per_level[0]] = 0 + upper_bound = anchor_sizes * 8 + upper_bound[-num_anchors_per_level[-1] :] = float("inf") + pairwise_dist = pairwise_dist.max(dim=2).values + pairwise_match &= (pairwise_dist > lower_bound[:, None]) & ( + pairwise_dist < upper_bound[:, None] + ) + + # Match the GT box with minimum area, if there are multiple GT matches + gt_areas = (gt_boxes[:, 1] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1]) # N + pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :]) + min_values, matched_idx = pairwise_match.max(dim=1) # R, per-anchor match + matched_idx[min_values < 1e-5] = -1 # Unmatched anchors are assigned -1 + + matched_idxs.append(matched_idx) + + + return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) + + def postprocess_detections(self, head_outputs, anchors, image_shapes): + # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]] + class_logits = head_outputs["cls_logits"] + box_regression = head_outputs["bbox_regression"] + + num_images = len(image_shapes) + + detections: List[Dict[str, Tensor]] = [] + + for index in range(num_images): + box_regression_per_image = [br[index] for br in box_regression] + logits_per_image = [cl[index] for cl in class_logits] + anchors_per_image, image_shape = anchors[index], image_shapes[index] + + image_boxes = [] + image_scores = [] + image_labels = [] + + for box_regression_per_level, logits_per_level, anchors_per_level in zip( + box_regression_per_image, logits_per_image, anchors_per_image + ): + num_classes = logits_per_level.shape[-1] + + # remove low scoring boxes + scores_per_level = torch.sigmoid(logits_per_level).flatten() + keep_idxs = scores_per_level > self.score_thresh + scores_per_level = scores_per_level[keep_idxs] + topk_idxs = torch.where(keep_idxs)[0] + + # keep only topk scoring predictions + num_topk = min(self.topk_candidates, topk_idxs.size(0)) + scores_per_level, idxs = scores_per_level.topk(num_topk) + topk_idxs = topk_idxs[idxs] + + anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor") + labels_per_level = topk_idxs % num_classes + + boxes_per_level = self.box_coder.decode_single( + box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs] + ) + boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape) + + image_boxes.append(boxes_per_level) + image_scores.append(scores_per_level) + image_labels.append(labels_per_level) + + image_boxes = torch.cat(image_boxes, dim=0) + image_scores = torch.cat(image_scores, dim=0) + image_labels = torch.cat(image_labels, dim=0) + + # non-maximum suppression + keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh) + keep = keep[: self.detections_per_img] + + detections.append( + { + "boxes": image_boxes[keep], + "scores": image_scores[keep], + "labels": image_labels[keep], + } + ) + + return detections + + def forward(self, images, targets=None): + # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + """ + Args: + 
images (list[Tensor]): images to be processed + targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional) + Returns: + result (list[BoxList] or dict[Tensor]): the output from the model. + During training, it returns a dict[Tensor] which contains the losses. + During testing, it returns list[BoxList] contains additional fields + like `scores`, `labels` and `mask` (for Mask R-CNN models). + """ + if self.training and targets is None: + raise ValueError("In training mode, targets should be passed") + + if self.training: + assert targets is not None + for target in targets: + boxes = target["boxes"] + if isinstance(boxes, torch.Tensor): + if len(boxes.shape) != 2 or boxes.shape[-1] != 4: + raise ValueError(f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.") + else: + raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.") + + # get the original image sizes + original_image_sizes: List[Tuple[int, int]] = [] + for img in images: + val = img.shape[-2:] + assert len(val) == 2 + original_image_sizes.append((val[0], val[1])) + + # transform the input + images, targets = self.transform(images, targets) + + # Check for degenerate boxes + # TODO: Move this to a function + if targets is not None: + for target_idx, target in enumerate(targets): + boxes = target["boxes"] + degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] + if degenerate_boxes.any(): + # print the first degenerate box + bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] + degen_bb: List[float] = boxes[bb_idx].tolist() + raise ValueError( + "All bounding boxes should have positive height and width." + f" Found invalid box {degen_bb} for target at index {target_idx}." + ) + + # get the features from the backbone + features = self.backbone(images.tensors) + if isinstance(features, torch.Tensor): + features = OrderedDict([("0", features)]) + + # TODO: Do we want a list or a dict? + features = list(features.values()) + + # compute the retinanet heads outputs using the features + head_outputs = self.head(features) + + # create the set of anchors + anchors = self.anchor_generator(images, features) + # recover level sizes + num_anchors_per_level = [x.size(2) * x.size(3) for x in features] + + losses = {} + detections: List[Dict[str, Tensor]] = [] + if self.training: + assert targets is not None + + # compute the losses + losses = self.compute_loss(targets, head_outputs, anchors, num_anchors_per_level) + else: + # split outputs per level + split_head_outputs: Dict[str, List[Tensor]] = {} + for k in head_outputs: + split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1)) + split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors] + + # compute the detections + detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes) + detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) + + if torch.jit.is_scripting(): + if not self._has_warned: + warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting") + self._has_warned = True + return losses, detections + return self.eager_outputs(losses, detections) + + +model_urls = { + "fcos_resnet50_fpn_coco": "", +} + + +def fcos_resnet50_fpn( + pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs +): + """ + Constructs a FCOS model with a ResNet-50-FPN backbone. + Reference: `"Focal Loss for Dense Object Detection" `_. 
+ The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each + image, and should be in ``0-1`` range. Different images can have different sizes. + The behavior of the model changes depending if it is in training or evaluation mode. + During training, the model expects both the input tensors, as well as a targets (list of dictionary), + containing: + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (``Int64Tensor[N]``): the class label for each ground-truth box + The model returns a ``Dict[Tensor]`` during training, containing the classification and regression + losses. + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as + follows, where ``N`` is the number of detections: + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (``Int64Tensor[N]``): the predicted labels for each detection + - scores (``Tensor[N]``): the scores of each detection + For more details on the output, you may refer to :ref:`instance_seg_output`. + Example:: + >>> model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=True) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + Args: + pretrained (bool): If True, returns a model pre-trained on COCO train2017 + progress (bool): If True, displays a progress bar of the download to stderr + num_classes (int): number of output classes of the model (including the background) + pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet + trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is + passed (the default) this value is set to 3. 
+ """ + trainable_backbone_layers = _validate_trainable_layers( + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 + ) + + if pretrained: + # no need to download the backbone if pretrained is set + pretrained_backbone = False + + backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + # skip P2 because it generates too many anchors (according to their paper) + backbone = _resnet_fpn_extractor( + backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) + ) + model = FCOS(backbone, num_classes, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls["retinanet_resnet50_fpn_coco"], progress=progress) + model.load_state_dict(state_dict) + overwrite_eps(model, 0.0) + return model + From e82f825e0caee32306ec1447a66e96b4bdbbed57 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 18 Nov 2021 19:53:23 +0800 Subject: [PATCH 03/60] add giou_loss --- torchvision/ops/giou_loss.py | 59 ++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 torchvision/ops/giou_loss.py diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py new file mode 100644 index 00000000000..04f8e7b7b48 --- /dev/null +++ b/torchvision/ops/giou_loss.py @@ -0,0 +1,59 @@ +import torch + +# copy from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/giou_loss.py +def giou_loss( + boxes1: torch.Tensor, + boxes2: torch.Tensor, + reduction: str = "none", + eps: float = 1e-7, +) -> torch.Tensor: + """ + Generalized Intersection over Union Loss (Hamid Rezatofighi et. al) + https://arxiv.org/abs/1902.09630 + Gradient-friendly IoU loss with an additional penalty that is non-zero when the + boxes do not overlap and scales with the size of their smallest enclosing box. + This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable. + Args: + boxes1, boxes2 (Tensor): box locations in XYXY format, shape (N, 4) or (4,). + reduction: 'none' | 'mean' | 'sum' + 'none': No reduction will be applied to the output. + 'mean': The output will be averaged. + 'sum': The output will be summed. 
+ eps (float): small number to prevent division by zero + """ + + x1, y1, x2, y2 = boxes1.unbind(dim=-1) + x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) + + assert (x2 >= x1).all(), "bad box: x1 larger than x2" + assert (y2 >= y1).all(), "bad box: y1 larger than y2" + + # Intersection keypoints + xkis1 = torch.max(x1, x1g) + ykis1 = torch.max(y1, y1g) + xkis2 = torch.min(x2, x2g) + ykis2 = torch.min(y2, y2g) + + intsctk = torch.zeros_like(x1) + mask = (ykis2 > ykis1) & (xkis2 > xkis1) + intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) + unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk + iouk = intsctk / (unionk + eps) + + # smallest enclosing box + xc1 = torch.min(x1, x1g) + yc1 = torch.min(y1, y1g) + xc2 = torch.max(x2, x2g) + yc2 = torch.max(y2, y2g) + + area_c = (xc2 - xc1) * (yc2 - yc1) + miouk = iouk - ((area_c - unionk) / (area_c + eps)) + + loss = 1 - miouk + + if reduction == "mean": + loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() + elif reduction == "sum": + loss = loss.sum() + + return loss From 391b8c9ecbb16a0c344d6f03986b42286bb9d2df Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 18 Nov 2021 19:55:02 +0800 Subject: [PATCH 04/60] add BoxLinearCoder for FCOS --- torchvision/models/detection/_utils.py | 89 ++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index b870e6a2456..5c9e988fd68 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -216,6 +216,95 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1) return pred_boxes + +class BoxLinearCoder: + """ + The linear box-to-box transform defined in FCOS. The transformation is parameterized + by the distance from the center of (square) src box to 4 edges of the target box. + """ + + def __init__(self, normalize_by_size=True) -> None: + """ + Args: + normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes. 
+ """ + self.normalize_by_size = normalize_by_size + + def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]: + boxes_per_image = [len(b) for b in reference_boxes] + reference_boxes = torch.cat(reference_boxes, dim=0) + proposals = torch.cat(proposals, dim=0) + targets = self.encode_single(reference_boxes, proposals) + return targets.split(boxes_per_image, 0) + + def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: + """ + Encode a set of proposals with respect to some + reference boxes + Args: + reference_boxes (Tensor): reference boxes + proposals (Tensor): boxes to be encoded + """ + # get the center of reference_boxes + reference_boxes_ctr_x = 0.5 * (reference_boxes[:, 0] + reference_boxes[:, 2]) + reference_boxes_ctr_y = 0.5 * (reference_boxes[:, 1] + reference_boxes[:, 3]) + + # get box regression transformation deltas + target_l = reference_boxes_ctr_x - proposals[:, 0] + target_t = reference_boxes_ctr_y - proposals[:, 1] + target_r = proposals[:, 2] - reference_boxes_ctr_x + target_b = proposals[:, 3] - reference_boxes_ctr_y + + targets = torch.stack((target_l, target_t, target_r, target_b), dim=1) + if self.normalize_by_size: + stride_w = reference_boxes[:, 2] - reference_boxes[:, 0] + stride_h = reference_boxes[:, 3] - reference_boxes[:, 1] + strides = torch.stack([stride_w, stride_h, stride_w, stride_h], axis=1) + targets = targets / strides + + return targets + + def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor: + assert isinstance(boxes, (list, tuple)) + assert isinstance(rel_codes, torch.Tensor) + boxes_per_image = [b.size(0) for b in boxes] + concat_boxes = torch.cat(boxes, dim=0) + box_sum = 0 + for val in boxes_per_image: + box_sum += val + if box_sum > 0: + rel_codes = rel_codes.reshape(box_sum, -1) + pred_boxes = self.decode_single(rel_codes, concat_boxes) + if box_sum > 0: + pred_boxes = pred_boxes.reshape(box_sum, -1, 4) + return pred_boxes + + def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: + """ + From a set of original boxes and encoded relative box offsets, + get the decoded boxes. + Args: + rel_codes (Tensor): encoded boxes + boxes (Tensor): reference boxes. 
+ """ + + boxes = boxes.to(rel_codes.dtype) + + ctr_x = 0.5 * (boxes[:, 0] + boxes[:, 2]) + ctr_y = 0.5 * (boxes[:, 1] + boxes[:, 3]) + if self.normalize_by_size: + stride_w = boxes[:, 2] - boxes[:, 0] + stride_h = boxes[:, 3] - boxes[:, 1] + strides = torch.stack([stride_w, stride_h, stride_w, stride_h], axis=1) + rel_codes = rel_codes * strides + + pred_boxes1 = ctr_x - rel_codes[:, 0] + pred_boxes2 = ctr_y - rel_codes[:, 1] + pred_boxes3 = ctr_x + rel_codes[:, 2] + pred_boxes4 = ctr_y + rel_codes[:, 3] + pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=1) + return pred_boxes + class Matcher: """ From b01f145930190b21f0b9af68bb28c71c5d4d3a47 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 18 Nov 2021 20:07:14 +0800 Subject: [PATCH 05/60] add full code for FCOS --- torchvision/models/detection/fcos.py | 237 ++++++++++++--------------- 1 file changed, 109 insertions(+), 128 deletions(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 90449d527c1..1428ae169b0 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ..._internally_replaced_utils import load_state_dict_from_url -from ...ops import sigmoid_focal_loss +from ...ops import sigmoid_focal_loss, giou_loss from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 @@ -18,50 +18,82 @@ from .anchor_utils import AnchorGenerator from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .transform import GeneralizedRCNNTransform -from .retinanet import _sum - - -def pairwise_point_box_distance(points: torch.Tensor, boxes: torch.Tensor): - """ - Pairwise distance between N points and M boxes. The distance between a - point and a box is represented by the distance from the point to 4 edges - of the box. Distances are all positive when the point is inside the box. - Args: - points: Nx2 coordinates. Each row is (x, y) - boxes: M boxes - Returns: - Tensor: distances of size (N, M, 4). The 4 values are distances from - the point to the left, top, right, bottom of the box. - """ - x, y = points.unsqueeze(dim=2).unbind(dim=1) # (N, 1) - x0, y0, x1, y1 = boxes.unsqueeze(dim=0).unbind(dim=2) # (1, M) - return torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2) class FCOSHead(nn.Module): """ - A regression and classification head for use in RetinaNet. + A regression and classification head for use in FCOS. 
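+ For every spatial location it predicts classification logits, (l, t, r, b) box regression deltas and a centerness score.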
Args: in_channels (int): number of channels of the input feature num_anchors (int): number of anchors to be predicted num_classes (int): number of classes to be predicted + num_convs (int): number of conv layer of head """ - def __init__(self, in_channels, num_anchors, num_classes): + def __init__(self, in_channels, num_anchors, num_classes, num_convs=4): super().__init__() - self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes) - self.regression_head = FCOSRegressionHead(in_channels, num_anchors) + self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs) + self.regression_head = FCOSRegressionHead(in_channels, num_anchors, num_convs) + + def compute_loss(self, targets, head_outputs, anchors, matched_idxs, box_coder): + + cls_logits = head_outputs["cls_logits"] # [N, K, C] + bbox_regression = head_outputs["bbox_regression"] # [N, K, 4] + bbox_ctrness = head_outputs["bbox_ctrness"] # [N, K, 1] + + all_gt_classes_targets = [] + all_gt_boxes_targets = [] + for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs): + gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)] + gt_classes_targets[matched_idxs_per_image < 0] = -1 # backgroud + gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)] + all_gt_classes_targets.append(gt_classes_targets) + all_gt_boxes_targets.append(gt_boxes_targets) + + all_gt_classes_targets = torch.stack(all_gt_classes_targets) + # compute foregroud + foregroud_mask = all_gt_classes_targets >= 0 + num_foreground = foregroud_mask.sum().item() + + # classification loss + gt_classes_targets = torch.zeros_like(cls_logits) + gt_classes_targets[foregroud_mask, all_gt_classes_targets[foregroud_mask]] = 1.0 + loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum") + + # regression loss: GIoU loss + pred_boxes = [box_coder.decode_single(bbox_regression_per_image, anchors_per_image) \ + for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression)] + # amp issue: pred_boxes need to convert float + loss_bbox_reg = giou_loss(torch.stack(pred_boxes)[foregroud_mask].float(), + torch.stack(all_gt_boxes_targets)[foregroud_mask], reduction='sum') + + # ctrness loss + bbox_reg_targets = [box_coder.encode_single(anchors_per_image, boxes_targets_per_image) \ + for anchors_per_image, boxes_targets_per_image in zip(anchors, all_gt_boxes_targets)] + bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0) + if len(bbox_reg_targets) == 0: + bbox_reg_targets.new_zeros(len(bbox_reg_targets)) + left_right = bbox_reg_targets[:, :, [0, 2]] + top_bottom = bbox_reg_targets[:, :, [1, 3]] + gt_ctrness_targets = torch.sqrt((left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * ( + top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0] + )) + pred_centerness = bbox_ctrness.squeeze(dim=2) + loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits( + pred_centerness[foregroud_mask], gt_ctrness_targets[foregroud_mask], reduction="sum" + ) - def compute_loss(self, targets, head_outputs, anchors, matched_idxs): - # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor] return { - "classification": self.classification_head.compute_loss(targets, head_outputs, matched_idxs), - "bbox_regression": self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs), + "classification": loss_cls / max(1, num_foreground), + "bbox_regression": loss_bbox_reg 
/ max(1, num_foreground), + "bbox_ctrness": loss_bbox_ctrness / max(1, num_foreground) } def forward(self, x): # type: (List[Tensor]) -> Dict[str, Tensor] - return {"cls_logits": self.classification_head(x), "bbox_regression": self.regression_head(x)} + cls_logits = self.classification_head(x) + bbox_regression, bbox_ctrness = self.regression_head(x) + return {"cls_logits": cls_logits, "bbox_regression": bbox_regression, "bbox_ctrness": bbox_ctrness} class FCOSClassificationHead(nn.Module): @@ -79,7 +111,6 @@ def __init__(self, in_channels, num_anchors, num_classes, num_convs=4, prior_pro self.num_classes = num_classes self.num_anchors = num_anchors - assert self.num_anchors == 1 # FCOS is anchor-free if norm_layer is None: norm_layer = lambda channels: nn.GroupNorm(32, channels) @@ -99,32 +130,6 @@ def __init__(self, in_channels, num_anchors, num_classes, num_convs=4, prior_pro torch.nn.init.normal_(self.cls_logits.weight, std=0.01) torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability)) - def compute_loss(self, targets, head_outputs, matched_idxs): - # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor - losses = [] - - cls_logits = head_outputs["cls_logits"] - - for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs): - # determine only the foreground - foreground_idxs_per_image = matched_idxs_per_image >= 0 - num_foreground = foreground_idxs_per_image.sum() - - # create the target classification - gt_classes_target = torch.zeros_like(cls_logits_per_image) - gt_classes_target[ - foreground_idxs_per_image, - targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]], - ] = 1.0 - - # compute the classification loss - losses.append( - sigmoid_focal_loss(cls_logits_per_image, gt_classes_target, - reduction="sum") / max(1, num_foreground) - ) - - return _sum(losses) / len(targets) - def forward(self, x): # type: (List[Tensor]) -> Tensor all_cls_logits = [] @@ -150,12 +155,9 @@ class FCOSRegressionHead(nn.Module): Args: in_channels (int): number of channels of the input feature num_anchors (int): number of anchors to be predicted + num_convs (int): number of conv layer """ - __annotations__ = { - "box_coder": det_utils.BoxCoder, - } - def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None): super().__init__() @@ -169,8 +171,8 @@ def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None): self.conv = nn.Sequential(*conv) self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1) - self.ctrness = nn.Conv2d(in_channels, num_anchors * 1, kernel_size=3, stride=1, padding=1) - for layer in [self.bbox_reg, self.ctrness]: + self.bbox_ctrness = nn.Conv2d(in_channels, num_anchors * 1, kernel_size=3, stride=1, padding=1) + for layer in [self.bbox_reg, self.bbox_ctrness]: torch.nn.init.normal_(layer.weight, std=0.01) torch.nn.init.zeros_(layer.bias) @@ -179,54 +181,30 @@ def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None): torch.nn.init.normal_(layer.weight, std=0.01) torch.nn.init.zeros_(layer.bias) - self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) - - def compute_loss(self, targets, head_outputs, anchors, matched_idxs): - # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor - losses = [] - - bbox_regression = head_outputs["bbox_regression"] - - for targets_per_image, bbox_regression_per_image, anchors_per_image, 
matched_idxs_per_image in zip( - targets, bbox_regression, anchors, matched_idxs - ): - # determine only the foreground indices, ignore the rest - foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0] - num_foreground = foreground_idxs_per_image.numel() - - # select only the foreground boxes - matched_gt_boxes_per_image = targets_per_image["boxes"][matched_idxs_per_image[foreground_idxs_per_image]] - bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :] - anchors_per_image = anchors_per_image[foreground_idxs_per_image, :] - - # compute the regression targets - target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) - - # compute the loss - losses.append( - torch.nn.functional.l1_loss(bbox_regression_per_image, target_regression, reduction="sum") - / max(1, num_foreground) - ) - - return _sum(losses) / max(1, len(targets)) - def forward(self, x): # type: (List[Tensor]) -> Tensor all_bbox_regression = [] + all_bbox_ctrness = [] for features in x: - bbox_regression = self.conv(features) - bbox_regression = self.bbox_reg(bbox_regression) + bbox_feature = self.conv(features) + bbox_regression = nn.functional.relu(self.bbox_reg(bbox_feature)) + bbox_ctrness = self.bbox_ctrness(bbox_feature) - # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4). + # permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4). N, _, H, W = bbox_regression.shape bbox_regression = bbox_regression.view(N, -1, 4, H, W) bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2) bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4) - all_bbox_regression.append(bbox_regression) - return torch.cat(all_bbox_regression, dim=1) + # permute bbox ctrness output from (N, 1 * A, H, W) to (N, HWA, 1). + bbox_ctrness = bbox_ctrness.view(N, -1, 1, H, W) + bbox_ctrness = bbox_ctrness.permute(0, 3, 4, 1, 2) + bbox_ctrness = bbox_ctrness.reshape(N, -1, 1) + all_bbox_ctrness.append(bbox_ctrness) + + return torch.cat(all_bbox_regression, dim=1), torch.cat(all_bbox_ctrness, dim=1) class FCOS(nn.Module): @@ -263,7 +241,9 @@ class FCOS(nn.Module): image_std (Tuple[float, float, float]): std values used for input normalization. They are generally the std values of the dataset on which the backbone has been trained on anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature - maps. + maps. For FCOS, only set one anchor for per position of each level, the width and height equal to + the stride of feature map, and set aspect ratio = 1.0, so the center of anchor is equivalent to the point + in FCOS paper. head (nn.Module): Module run on top of the feature pyramid. Defaults to a module containing a classification and regression module. center_sampling_radius (int): radius of the "center" of a groundtruth box, @@ -275,12 +255,12 @@ class FCOS(nn.Module): Example: >>> import torch >>> import torchvision - >>> from torchvision.models.detection import RetinaNet + >>> from torchvision.models.detection import FCOS >>> from torchvision.models.detection.anchor_utils import AnchorGenerator >>> # load a pre-trained model for classification and return >>> # only the features >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features - >>> # RetinaNet needs to know the number of + >>> # FCOS needs to know the number of >>> # output channels in a backbone. 
For mobilenet_v2, it's 1280 >>> # so we need to add it here >>> backbone.out_channels = 1280 @@ -291,23 +271,19 @@ class FCOS(nn.Module): >>> # map could potentially have different sizes and >>> # aspect ratios >>> anchor_generator = AnchorGenerator( - >>> sizes=((32, 64, 128, 256, 512),), - >>> aspect_ratios=((0.5, 1.0, 2.0),) + >>> sizes=((8,), (16,), (32,), (64,), (128,)), + >>> aspect_ratios=((1.0,),) >>> ) >>> >>> # put the pieces together inside a RetinaNet model - >>> model = RetinaNet(backbone, - >>> num_classes=2, + >>> model = FCOS(backbone, + >>> num_classes=80, >>> anchor_generator=anchor_generator) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) """ - __annotations__ = { - "box_coder": det_utils.BoxCoder - } - def __init__( self, backbone, @@ -321,7 +297,7 @@ def __init__( anchor_generator=None, head=None, center_sampling_radius=1.5, - score_thresh=0.05, + score_thresh=0.2, nms_thresh=0.5, detections_per_img=300, topk_candidates=1000, @@ -340,16 +316,17 @@ def __init__( assert isinstance(anchor_generator, (AnchorGenerator, type(None))) if anchor_generator is None: - anchor_sizes = ((8,), (16,), (32,), (64,), (128,)) - aspect_ratios = ((1.0),) * len(anchor_sizes) + anchor_sizes = ((8,), (16,), (32,), (64,), (128,)) # equal to strides of multi-level feature map + aspect_ratios = ((1.0,),) * len(anchor_sizes) # set only one anchor anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios) self.anchor_generator = anchor_generator + assert self.anchor_generator.num_anchors_per_location()[0] == 1 if head is None: head = FCOSHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes) self.head = head - self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True) if image_mean is None: image_mean = [0.485, 0.456, 0.406] @@ -386,13 +363,16 @@ def compute_loss(self, targets, head_outputs, anchors, num_anchors_per_level): gt_boxes = targets_per_image["boxes"] gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2 # Nx2 - anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2 - anchor_sizes = anchors_per_image.tensor[:, 2] - anchors_per_image.tensor[:, 0] + anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2 # N + anchor_sizes = anchors_per_image[:, 2] - anchors_per_image[:, 0] # center sampling: anchor point must be close enough to gt center. 
pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max( dim=2 ).values < self.center_sampling_radius * anchor_sizes[:, None] - pairwise_dist = pairwise_point_box_distance(anchor_centers, gt_boxes) + # compute pairwise distance between N points and M boxes + x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1) # (N, 1) + x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2) # (1, M) + pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2) # (N, M) # anchor point must be inside gt pairwise_match &= pairwise_dist.min(dim=2).values > 0 @@ -407,21 +387,21 @@ def compute_loss(self, targets, head_outputs, anchors, num_anchors_per_level): pairwise_dist < upper_bound[:, None] ) - # Match the GT box with minimum area, if there are multiple GT matches + # match the GT box with minimum area, if there are multiple GT matches gt_areas = (gt_boxes[:, 1] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1]) # N pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :]) min_values, matched_idx = pairwise_match.max(dim=1) # R, per-anchor match - matched_idx[min_values < 1e-5] = -1 # Unmatched anchors are assigned -1 + matched_idx[min_values < 1e-5] = -1 # unmatched anchors are assigned -1 matched_idxs.append(matched_idx) - - return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) + return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs, self.box_coder) def postprocess_detections(self, head_outputs, anchors, image_shapes): # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]] class_logits = head_outputs["cls_logits"] box_regression = head_outputs["bbox_regression"] + box_ctrness = head_outputs["bbox_ctrness"] num_images = len(image_shapes) @@ -430,19 +410,22 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): for index in range(num_images): box_regression_per_image = [br[index] for br in box_regression] logits_per_image = [cl[index] for cl in class_logits] + box_ctrness_per_image = [bc[index] for bc in box_ctrness] anchors_per_image, image_shape = anchors[index], image_shapes[index] image_boxes = [] image_scores = [] image_labels = [] - for box_regression_per_level, logits_per_level, anchors_per_level in zip( - box_regression_per_image, logits_per_image, anchors_per_image + for box_regression_per_level, logits_per_level, box_ctrness_per_level, anchors_per_level in zip( + box_regression_per_image, logits_per_image, box_ctrness_per_image, anchors_per_image ): num_classes = logits_per_level.shape[-1] # remove low scoring boxes - scores_per_level = torch.sigmoid(logits_per_level).flatten() + scores_per_level = torch.sqrt(torch.sigmoid(logits_per_level) * \ + torch.sigmoid(box_ctrness_per_level) + ).flatten() keep_idxs = scores_per_level > self.score_thresh scores_per_level = scores_per_level[keep_idxs] topk_idxs = torch.where(keep_idxs)[0] @@ -584,7 +567,7 @@ def fcos_resnet50_fpn( ): """ Constructs a FCOS model with a ResNet-50-FPN backbone. - Reference: `"Focal Loss for Dense Object Detection" `_. + Reference: `"FCOS: Fully Convolutional One-Stage Object Detection" `_. The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each image, and should be in ``0-1`` range. Different images can have different sizes. The behavior of the model changes depending if it is in training or evaluation mode. 
@@ -626,14 +609,12 @@ def fcos_resnet50_fpn( pretrained_backbone = False backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) - # skip P2 because it generates too many anchors (according to their paper) - backbone = _resnet_fpn_extractor( - backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) + backbone = resnet_fpn_extractor( + backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) # use P5 ) model = FCOS(backbone, num_classes, **kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls["retinanet_resnet50_fpn_coco"], progress=progress) + state_dict = load_state_dict_from_url(model_urls["fcos_resnet50_fpn_coco"], progress=progress) model.load_state_dict(state_dict) overwrite_eps(model, 0.0) return model - From a4eeb591b6e1542d1f1bba91aa0c735b04e97839 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 18 Nov 2021 20:08:39 +0800 Subject: [PATCH 06/60] add giou loss --- torchvision/ops/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index d5cdf39d20f..ae87a977f29 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -20,6 +20,7 @@ from .roi_align import roi_align, RoIAlign from .roi_pool import roi_pool, RoIPool from .stochastic_depth import stochastic_depth, StochasticDepth +from .giou_loss import giou_loss _register_custom_op() @@ -52,4 +53,5 @@ "FrozenBatchNorm2d", "ConvNormActivation", "SqueezeExcitation", + 'giou_loss' ] From 61e946d26cb7c394502b7657a87d7fc9eb2b7986 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 18 Nov 2021 20:09:16 +0800 Subject: [PATCH 07/60] add fcos --- torchvision/models/detection/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/models/detection/__init__.py b/torchvision/models/detection/__init__.py index 4772415b3b1..be46f950a61 100644 --- a/torchvision/models/detection/__init__.py +++ b/torchvision/models/detection/__init__.py @@ -4,3 +4,4 @@ from .retinanet import * from .ssd import * from .ssdlite import * +from .fcos import * From 997f1897c3feb18dbb31daa4dfb7204c27bdc58b Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 18 Nov 2021 20:11:50 +0800 Subject: [PATCH 08/60] add __all__ --- torchvision/models/detection/fcos.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 1428ae169b0..a850e50b8fd 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -20,6 +20,9 @@ from .transform import GeneralizedRCNNTransform +__all__ = ["FCOS", "fcos_resnet50_fpn"] + + class FCOSHead(nn.Module): """ A regression and classification head for use in FCOS. 
From 793b1db0acb9d8942912375744c80f6865ea83ad Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Thu, 18 Nov 2021 21:33:41 +0800 Subject: [PATCH 09/60] Fixing lint --- torchvision/models/detection/fcos.py | 66 +++++++++++++++++----------- 1 file changed, 41 insertions(+), 25 deletions(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index a850e50b8fd..95ff7b1f390 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -1,7 +1,7 @@ import math import warnings from collections import OrderedDict -from typing import Dict, List, Tuple, Optional, Any +from typing import Dict, List, Tuple, Optional import torch from torch import nn, Tensor @@ -40,15 +40,15 @@ def __init__(self, in_channels, num_anchors, num_classes, num_convs=4): def compute_loss(self, targets, head_outputs, anchors, matched_idxs, box_coder): - cls_logits = head_outputs["cls_logits"] # [N, K, C] - bbox_regression = head_outputs["bbox_regression"] # [N, K, 4] - bbox_ctrness = head_outputs["bbox_ctrness"] # [N, K, 1] + cls_logits = head_outputs["cls_logits"] # [N, K, C] + bbox_regression = head_outputs["bbox_regression"] # [N, K, 4] + bbox_ctrness = head_outputs["bbox_ctrness"] # [N, K, 1] all_gt_classes_targets = [] all_gt_boxes_targets = [] for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs): gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)] - gt_classes_targets[matched_idxs_per_image < 0] = -1 # backgroud + gt_classes_targets[matched_idxs_per_image < 0] = -1 # backgroud gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)] all_gt_classes_targets.append(gt_classes_targets) all_gt_boxes_targets.append(gt_boxes_targets) @@ -64,15 +64,26 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs, box_coder): loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum") # regression loss: GIoU loss - pred_boxes = [box_coder.decode_single(bbox_regression_per_image, anchors_per_image) \ - for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression)] + pred_boxes = [ + box_coder.decode_single( + bbox_regression_per_image, anchors_per_image + ) + for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression) + ] # amp issue: pred_boxes need to convert float - loss_bbox_reg = giou_loss(torch.stack(pred_boxes)[foregroud_mask].float(), - torch.stack(all_gt_boxes_targets)[foregroud_mask], reduction='sum') + loss_bbox_reg = giou_loss( + torch.stack(pred_boxes)[foregroud_mask].float(), + torch.stack(all_gt_boxes_targets)[foregroud_mask], + reduction='sum', + ) # ctrness loss - bbox_reg_targets = [box_coder.encode_single(anchors_per_image, boxes_targets_per_image) \ - for anchors_per_image, boxes_targets_per_image in zip(anchors, all_gt_boxes_targets)] + bbox_reg_targets = [ + box_coder.encode_single( + anchors_per_image, boxes_targets_per_image + ) + for anchors_per_image, boxes_targets_per_image in zip(anchors, all_gt_boxes_targets) + ] bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0) if len(bbox_reg_targets) == 0: bbox_reg_targets.new_zeros(len(bbox_reg_targets)) @@ -96,7 +107,11 @@ def forward(self, x): # type: (List[Tensor]) -> Dict[str, Tensor] cls_logits = self.classification_head(x) bbox_regression, bbox_ctrness = self.regression_head(x) - return {"cls_logits": cls_logits, "bbox_regression": bbox_regression, "bbox_ctrness": bbox_ctrness} + return { + "cls_logits": cls_logits, + "bbox_regression": 
bbox_regression, + "bbox_ctrness": bbox_ctrness, + } class FCOSClassificationHead(nn.Module): @@ -114,9 +129,10 @@ def __init__(self, in_channels, num_anchors, num_classes, num_convs=4, prior_pro self.num_classes = num_classes self.num_anchors = num_anchors - + if norm_layer is None: norm_layer = lambda channels: nn.GroupNorm(32, channels) + conv = [] for _ in range(num_convs): conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) @@ -166,6 +182,7 @@ def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None): if norm_layer is None: norm_layer = lambda channels: nn.GroupNorm(32, channels) + conv = [] for _ in range(num_convs): conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) @@ -178,7 +195,7 @@ def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None): for layer in [self.bbox_reg, self.bbox_ctrness]: torch.nn.init.normal_(layer.weight, std=0.01) torch.nn.init.zeros_(layer.bias) - + for layer in self.conv.children(): if isinstance(layer, nn.Conv2d): torch.nn.init.normal_(layer.weight, std=0.01) @@ -319,8 +336,8 @@ def __init__( assert isinstance(anchor_generator, (AnchorGenerator, type(None))) if anchor_generator is None: - anchor_sizes = ((8,), (16,), (32,), (64,), (128,)) # equal to strides of multi-level feature map - aspect_ratios = ((1.0,),) * len(anchor_sizes) # set only one anchor + anchor_sizes = ((8,), (16,), (32,), (64,), (128,)) # equal to strides of multi-level feature map + aspect_ratios = ((1.0,),) * len(anchor_sizes) # set only one anchor anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios) self.anchor_generator = anchor_generator assert self.anchor_generator.num_anchors_per_location()[0] == 1 @@ -366,7 +383,7 @@ def compute_loss(self, targets, head_outputs, anchors, num_anchors_per_level): gt_boxes = targets_per_image["boxes"] gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2 # Nx2 - anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2 # N + anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2 # N anchor_sizes = anchors_per_image[:, 2] - anchors_per_image[:, 0] # center sampling: anchor point must be close enough to gt center. 
pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max( @@ -375,7 +392,7 @@ def compute_loss(self, targets, head_outputs, anchors, num_anchors_per_level): # compute pairwise distance between N points and M boxes x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1) # (N, 1) x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2) # (1, M) - pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2) # (N, M) + pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2) # (N, M) # anchor point must be inside gt pairwise_match &= pairwise_dist.min(dim=2).values > 0 @@ -384,14 +401,14 @@ def compute_loss(self, targets, head_outputs, anchors, num_anchors_per_level): lower_bound = anchor_sizes * 4 lower_bound[: num_anchors_per_level[0]] = 0 upper_bound = anchor_sizes * 8 - upper_bound[-num_anchors_per_level[-1] :] = float("inf") + upper_bound[-num_anchors_per_level[-1]:] = float("inf") pairwise_dist = pairwise_dist.max(dim=2).values pairwise_match &= (pairwise_dist > lower_bound[:, None]) & ( pairwise_dist < upper_bound[:, None] ) # match the GT box with minimum area, if there are multiple GT matches - gt_areas = (gt_boxes[:, 1] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1]) # N + gt_areas = (gt_boxes[:, 1] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1]) # N pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :]) min_values, matched_idx = pairwise_match.max(dim=1) # R, per-anchor match matched_idx[min_values < 1e-5] = -1 # unmatched anchors are assigned -1 @@ -426,9 +443,8 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): num_classes = logits_per_level.shape[-1] # remove low scoring boxes - scores_per_level = torch.sqrt(torch.sigmoid(logits_per_level) * \ - torch.sigmoid(box_ctrness_per_level) - ).flatten() + scores_per_level = torch.sqrt( + torch.sigmoid(logits_per_level) * torch.sigmoid(box_ctrness_per_level)).flatten() keep_idxs = scores_per_level > self.score_thresh scores_per_level = scores_per_level[keep_idxs] topk_idxs = torch.where(keep_idxs)[0] @@ -612,8 +628,8 @@ def fcos_resnet50_fpn( pretrained_backbone = False backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) - backbone = resnet_fpn_extractor( - backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) # use P5 + backbone = _resnet_fpn_extractor( + backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) # use P5 ) model = FCOS(backbone, num_classes, **kwargs) if pretrained: From 3cf91e3e06bd3a2bf6b9e2ec9b5d10f80d31b042 Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Thu, 18 Nov 2021 23:22:22 +0800 Subject: [PATCH 10/60] Fixing lint in giou_loss.py --- torchvision/ops/giou_loss.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index 04f8e7b7b48..16d82c34c15 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -1,6 +1,6 @@ import torch -# copy from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/giou_loss.py + def giou_loss( boxes1: torch.Tensor, boxes2: torch.Tensor, @@ -8,13 +8,22 @@ def giou_loss( eps: float = 1e-7, ) -> torch.Tensor: """ + Original implementation from + https://github.com/facebookresearch/fvcore/blob/bfff2ef/fvcore/nn/giou_loss.py + Generalized Intersection over Union Loss (Hamid Rezatofighi et. 
al) https://arxiv.org/abs/1902.09630 Gradient-friendly IoU loss with an additional penalty that is non-zero when the boxes do not overlap and scales with the size of their smallest enclosing box. This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable. + + Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``, and The two boxes should have the + same dimensions. + Args: - boxes1, boxes2 (Tensor): box locations in XYXY format, shape (N, 4) or (4,). + boxes1 (Tensor[N, 4] or Tensor[4]): first set of boxes + boxes2 (Tensor[N, 4] or Tensor[4]): second set of boxes reduction: 'none' | 'mean' | 'sum' 'none': No reduction will be applied to the output. 'mean': The output will be averaged. From c7a7d52b96cc31078c984d651644991fdd0514b7 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Fri, 19 Nov 2021 10:28:17 +0800 Subject: [PATCH 11/60] Add typing annotation to fcos --- torchvision/models/detection/fcos.py | 40 +++++++++++++++++++--------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 95ff7b1f390..4159f78d4ee 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -149,8 +149,7 @@ def __init__(self, in_channels, num_anchors, num_classes, num_convs=4, prior_pro torch.nn.init.normal_(self.cls_logits.weight, std=0.01) torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability)) - def forward(self, x): - # type: (List[Tensor]) -> Tensor + def forward(self, x: List[Tensor]) -> Tensor: all_cls_logits = [] for features in x: @@ -201,8 +200,7 @@ def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None): torch.nn.init.normal_(layer.weight, std=0.01) torch.nn.init.zeros_(layer.bias) - def forward(self, x): - # type: (List[Tensor]) -> Tensor + def forward(self, x: List[Tensor]) -> Tensor: all_bbox_regression = [] all_bbox_ctrness = [] @@ -364,15 +362,23 @@ def __init__( self._has_warned = False @torch.jit.unused - def eager_outputs(self, losses, detections): - # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + def eager_outputs( + self, + losses: Dict[str, Tensor], + detections: List[Dict[str, Tensor]] + ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: if self.training: return losses return detections - def compute_loss(self, targets, head_outputs, anchors, num_anchors_per_level): - # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[int]) -> Dict[str, Tensor] + def compute_loss( + self, + targets: List[Dict[str, Tensor]], + head_outputs: Dict[str, Tensor], + anchors: List[Tensor], + num_anchors_per_level: List[int], + ) -> Dict[str, Tensor]: matched_idxs = [] for anchors_per_image, targets_per_image in zip(anchors, targets): if targets_per_image["boxes"].numel() == 0: @@ -417,8 +423,12 @@ def compute_loss(self, targets, head_outputs, anchors, num_anchors_per_level): return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs, self.box_coder) - def postprocess_detections(self, head_outputs, anchors, image_shapes): - # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]] + def postprocess_detections( + self, + head_outputs: Dict[str, List[Tensor]], + anchors: List[List[Tensor]], + image_shapes: List[Tuple[int, int]] + ) -> List[Dict[str, Tensor]]: class_logits = 
head_outputs["cls_logits"] box_regression = head_outputs["bbox_regression"] box_ctrness = head_outputs["bbox_ctrness"] @@ -484,12 +494,16 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): return detections - def forward(self, images, targets=None): - # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + def forward( + self, + images: List[Tensor], + targets: Optional[List[Dict[str, Tensor]]] = None, + ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: """ Args: images (list[Tensor]): images to be processed targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional) + Returns: result (list[BoxList] or dict[Tensor]): the output from the model. During training, it returns a dict[Tensor] which contains the losses. @@ -570,7 +584,7 @@ def forward(self, images, targets=None): if torch.jit.is_scripting(): if not self._has_warned: - warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting") + warnings.warn("FCOS always returns a (Losses, Detections) tuple in scripting") self._has_warned = True return losses, detections return self.eager_outputs(losses, detections) From 42b147df516b51c029700029d4bf0bd7c496f468 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Fri, 19 Nov 2021 10:30:20 +0800 Subject: [PATCH 12/60] Add trained checkpoints --- torchvision/models/detection/fcos.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 4159f78d4ee..505c7aa43fa 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -591,7 +591,8 @@ def forward( model_urls = { - "fcos_resnet50_fpn_coco": "", + "fcos_resnet50_fpn_coco": + "https://github.com/o295/checkpoints/releases/download/coco/fcos_resnet50_fpn_coco-46080c1a.pth", } From ff2d78a29bb63a99f335f817d2309c6df989c6cf Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Fri, 19 Nov 2021 11:04:02 +0800 Subject: [PATCH 13/60] Use partial to replace lambda --- torchvision/models/detection/fcos.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 505c7aa43fa..b640933f7dc 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -1,6 +1,7 @@ import math import warnings from collections import OrderedDict +from functools import partial from typing import Dict, List, Tuple, Optional import torch @@ -131,7 +132,7 @@ def __init__(self, in_channels, num_anchors, num_classes, num_convs=4, prior_pro self.num_anchors = num_anchors if norm_layer is None: - norm_layer = lambda channels: nn.GroupNorm(32, channels) + norm_layer = partial(nn.GroupNorm, 32) conv = [] for _ in range(num_convs): @@ -180,7 +181,7 @@ def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None): super().__init__() if norm_layer is None: - norm_layer = lambda channels: nn.GroupNorm(32, channels) + norm_layer = partial(nn.GroupNorm, 32) conv = [] for _ in range(num_convs): From d4c08d31eb29e05bd0564a42a5091926c85be9d9 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Fri, 19 Nov 2021 11:22:54 +0800 Subject: [PATCH 14/60] Minor fixes to docstrings --- torchvision/models/detection/fcos.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index b640933f7dc..27d34b24f15 100644 
--- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -27,6 +27,7 @@ class FCOSHead(nn.Module): """ A regression and classification head for use in FCOS. + Args: in_channels (int): number of channels of the input feature num_anchors (int): number of anchors to be predicted @@ -118,6 +119,7 @@ def forward(self, x): class FCOSClassificationHead(nn.Module): """ A classification head for use in FCOS. + Args: in_channels (int): number of channels of the input feature num_anchors (int): number of anchors to be predicted @@ -171,6 +173,7 @@ def forward(self, x: List[Tensor]) -> Tensor: class FCOSRegressionHead(nn.Module): """ A regression head for use in FCOS. + Args: in_channels (int): number of channels of the input feature num_anchors (int): number of anchors to be predicted @@ -229,16 +232,21 @@ def forward(self, x: List[Tensor]) -> Tensor: class FCOS(nn.Module): """ Implements FCOS. + The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each image, and should be in 0-1 range. Different images can have different sizes. + The behavior of the model changes depending if it is in training or evaluation mode. + During training, the model expects both the input tensors, as well as a targets (list of dictionary), containing: - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. - labels (Int64Tensor[N]): the class label for each ground-truth box + The model returns a Dict[Tensor] during training, containing the classification and regression losses. + During inference, the model requires only the input tensors, and returns the post-processed predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as follows: @@ -246,6 +254,7 @@ class FCOS(nn.Module): ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. - labels (Int64Tensor[N]): the predicted labels for each image - scores (Tensor[N]): the scores for each prediction + Args: backbone (nn.Module): the network used to compute the features for the model. It should contain an out_channels attribute, which indicates the number of output @@ -271,7 +280,9 @@ class FCOS(nn.Module): nms_thresh (float): NMS threshold used for postprocessing the detections. detections_per_img (int): Number of best detections to keep after NMS. topk_candidates (int): Number of best detections to keep before NMS. + Example: + >>> import torch >>> import torchvision >>> from torchvision.models.detection import FCOS @@ -603,9 +614,12 @@ def fcos_resnet50_fpn( """ Constructs a FCOS model with a ResNet-50-FPN backbone. Reference: `"FCOS: Fully Convolutional One-Stage Object Detection" `_. + The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each image, and should be in ``0-1`` range. Different images can have different sizes. + The behavior of the model changes depending if it is in training or evaluation mode. + During training, the model expects both the input tensors, as well as a targets (list of dictionary), containing: - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with @@ -613,6 +627,7 @@ def fcos_resnet50_fpn( - labels (``Int64Tensor[N]``): the class label for each ground-truth box The model returns a ``Dict[Tensor]`` during training, containing the classification and regression losses. 
+ During inference, the model requires only the input tensors, and returns the post-processed predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as follows, where ``N`` is the number of detections: @@ -621,11 +636,14 @@ def fcos_resnet50_fpn( - labels (``Int64Tensor[N]``): the predicted labels for each detection - scores (``Tensor[N]``): the scores of each detection For more details on the output, you may refer to :ref:`instance_seg_output`. - Example:: + + Example: + >>> model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=True) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) + Args: pretrained (bool): If True, returns a model pre-trained on COCO train2017 progress (bool): If True, displays a progress bar of the download to stderr From c464249c9661c6a06e72487c7f0766ccc345b1b0 Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Fri, 19 Nov 2021 12:25:44 +0800 Subject: [PATCH 15/60] Apply ufmt format --- torchvision/models/detection/_utils.py | 2 +- torchvision/models/detection/fcos.py | 40 ++++++++++---------------- torchvision/ops/__init__.py | 4 +-- 3 files changed, 18 insertions(+), 28 deletions(-) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index 5c9e988fd68..78a6b2c09c0 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -216,7 +216,7 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1) return pred_boxes - + class BoxLinearCoder: """ The linear box-to-box transform defined in FCOS. The transformation is parameterized diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 27d34b24f15..423478961ed 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -67,23 +67,19 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs, box_coder): # regression loss: GIoU loss pred_boxes = [ - box_coder.decode_single( - bbox_regression_per_image, anchors_per_image - ) + box_coder.decode_single(bbox_regression_per_image, anchors_per_image) for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression) ] # amp issue: pred_boxes need to convert float loss_bbox_reg = giou_loss( torch.stack(pred_boxes)[foregroud_mask].float(), torch.stack(all_gt_boxes_targets)[foregroud_mask], - reduction='sum', + reduction="sum", ) # ctrness loss bbox_reg_targets = [ - box_coder.encode_single( - anchors_per_image, boxes_targets_per_image - ) + box_coder.encode_single(anchors_per_image, boxes_targets_per_image) for anchors_per_image, boxes_targets_per_image in zip(anchors, all_gt_boxes_targets) ] bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0) @@ -91,9 +87,10 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs, box_coder): bbox_reg_targets.new_zeros(len(bbox_reg_targets)) left_right = bbox_reg_targets[:, :, [0, 2]] top_bottom = bbox_reg_targets[:, :, [1, 3]] - gt_ctrness_targets = torch.sqrt((left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * ( - top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0] - )) + gt_ctrness_targets = torch.sqrt( + (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) + * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]) + ) pred_centerness = bbox_ctrness.squeeze(dim=2) loss_bbox_ctrness = 
nn.functional.binary_cross_entropy_with_logits( pred_centerness[foregroud_mask], gt_ctrness_targets[foregroud_mask], reduction="sum" @@ -102,7 +99,7 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs, box_coder): return { "classification": loss_cls / max(1, num_foreground), "bbox_regression": loss_bbox_reg / max(1, num_foreground), - "bbox_ctrness": loss_bbox_ctrness / max(1, num_foreground) + "bbox_ctrness": loss_bbox_ctrness / max(1, num_foreground), } def forward(self, x): @@ -375,9 +372,7 @@ def __init__( @torch.jit.unused def eager_outputs( - self, - losses: Dict[str, Tensor], - detections: List[Dict[str, Tensor]] + self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]] ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: if self.training: return losses @@ -419,11 +414,9 @@ def compute_loss( lower_bound = anchor_sizes * 4 lower_bound[: num_anchors_per_level[0]] = 0 upper_bound = anchor_sizes * 8 - upper_bound[-num_anchors_per_level[-1]:] = float("inf") + upper_bound[-num_anchors_per_level[-1] :] = float("inf") pairwise_dist = pairwise_dist.max(dim=2).values - pairwise_match &= (pairwise_dist > lower_bound[:, None]) & ( - pairwise_dist < upper_bound[:, None] - ) + pairwise_match &= (pairwise_dist > lower_bound[:, None]) & (pairwise_dist < upper_bound[:, None]) # match the GT box with minimum area, if there are multiple GT matches gt_areas = (gt_boxes[:, 1] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1]) # N @@ -436,10 +429,7 @@ def compute_loss( return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs, self.box_coder) def postprocess_detections( - self, - head_outputs: Dict[str, List[Tensor]], - anchors: List[List[Tensor]], - image_shapes: List[Tuple[int, int]] + self, head_outputs: Dict[str, List[Tensor]], anchors: List[List[Tensor]], image_shapes: List[Tuple[int, int]] ) -> List[Dict[str, Tensor]]: class_logits = head_outputs["cls_logits"] box_regression = head_outputs["bbox_regression"] @@ -466,7 +456,8 @@ def postprocess_detections( # remove low scoring boxes scores_per_level = torch.sqrt( - torch.sigmoid(logits_per_level) * torch.sigmoid(box_ctrness_per_level)).flatten() + torch.sigmoid(logits_per_level) * torch.sigmoid(box_ctrness_per_level) + ).flatten() keep_idxs = scores_per_level > self.score_thresh scores_per_level = scores_per_level[keep_idxs] topk_idxs = torch.where(keep_idxs)[0] @@ -603,8 +594,7 @@ def forward( model_urls = { - "fcos_resnet50_fpn_coco": - "https://github.com/o295/checkpoints/releases/download/coco/fcos_resnet50_fpn_coco-46080c1a.pth", + "fcos_resnet50_fpn_coco": "https://github.com/o295/checkpoints/releases/download/coco/fcos_resnet50_fpn_coco-46080c1a.pth", } diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index ae87a977f29..b27e19c9381 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -13,6 +13,7 @@ from .deform_conv import deform_conv2d, DeformConv2d from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss +from .giou_loss import giou_loss from .misc import FrozenBatchNorm2d, ConvNormActivation, SqueezeExcitation from .poolers import MultiScaleRoIAlign from .ps_roi_align import ps_roi_align, PSRoIAlign @@ -20,7 +21,6 @@ from .roi_align import roi_align, RoIAlign from .roi_pool import roi_pool, RoIPool from .stochastic_depth import stochastic_depth, StochasticDepth -from .giou_loss import giou_loss _register_custom_op() @@ -53,5 +53,5 @@ "FrozenBatchNorm2d", "ConvNormActivation", 
"SqueezeExcitation", - 'giou_loss' + "giou_loss", ] From 293e6b907a29ba67c4851b29e3ab06ebc4840d6c Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Fri, 19 Nov 2021 14:21:16 +0800 Subject: [PATCH 16/60] Fixing docstrings --- torchvision/models/detection/fcos.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 423478961ed..0cd23856bb1 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -302,10 +302,12 @@ class FCOS(nn.Module): >>> aspect_ratios=((1.0,),) >>> ) >>> - >>> # put the pieces together inside a RetinaNet model - >>> model = FCOS(backbone, - >>> num_classes=80, - >>> anchor_generator=anchor_generator) + >>> # put the pieces together inside a FCOS model + >>> model = FCOS( + >>> backbone, + >>> num_classes=80, + >>> anchor_generator=anchor_generator, + >>> ) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) From 4cbce224e136b8f5593d03cbb4bb8291eb930938 Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Fri, 19 Nov 2021 15:01:07 +0800 Subject: [PATCH 17/60] Fixing jit scripting --- torchvision/models/detection/fcos.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 0cd23856bb1..9844cb434d9 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -40,7 +40,14 @@ def __init__(self, in_channels, num_anchors, num_classes, num_convs=4): self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs) self.regression_head = FCOSRegressionHead(in_channels, num_anchors, num_convs) - def compute_loss(self, targets, head_outputs, anchors, matched_idxs, box_coder): + def compute_loss( + self, + targets: List[Dict[str, Tensor]], + head_outputs: Dict[str, Tensor], + anchors: List[Tensor], + matched_idxs: List[Tensor], + box_coder, + ): cls_logits = head_outputs["cls_logits"] # [N, K, C] bbox_regression = head_outputs["bbox_regression"] # [N, K, 4] @@ -102,8 +109,7 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs, box_coder): "bbox_ctrness": loss_bbox_ctrness / max(1, num_foreground), } - def forward(self, x): - # type: (List[Tensor]) -> Dict[str, Tensor] + def forward(self, x: List[Tensor]) -> Dict[str, Tensor]: cls_logits = self.classification_head(x) bbox_regression, bbox_ctrness = self.regression_head(x) return { @@ -201,7 +207,7 @@ def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None): torch.nn.init.normal_(layer.weight, std=0.01) torch.nn.init.zeros_(layer.bias) - def forward(self, x: List[Tensor]) -> Tensor: + def forward(self, x: List[Tensor]) -> Tuple[Tensor, Tensor]: all_bbox_regression = [] all_bbox_ctrness = [] From b444d21e69cadd7718acf5d34a94ae73b6e1af7d Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Fri, 19 Nov 2021 15:08:15 +0800 Subject: [PATCH 18/60] Minor fixes to docstrings --- torchvision/models/detection/_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index 78a6b2c09c0..b788c16680b 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -241,6 +241,7 @@ def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: """ Encode a set of proposals with respect to some reference boxes + Args: 
reference_boxes (Tensor): reference boxes proposals (Tensor): boxes to be encoded @@ -283,6 +284,7 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: """ From a set of original boxes and encoded relative box offsets, get the decoded boxes. + Args: rel_codes (Tensor): encoded boxes boxes (Tensor): reference boxes. From ac8d062d27f66661447fce412176759abf4d893b Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Fri, 19 Nov 2021 15:29:37 +0800 Subject: [PATCH 19/60] Fixing jit scripting --- torchvision/models/detection/_utils.py | 6 +++--- torchvision/models/detection/fcos.py | 16 ++++++++++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index b788c16680b..b442eeae33b 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -223,7 +223,7 @@ class BoxLinearCoder: by the distance from the center of (square) src box to 4 edges of the target box. """ - def __init__(self, normalize_by_size=True) -> None: + def __init__(self, normalize_by_size: bool = True) -> None: """ Args: normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes. @@ -260,7 +260,7 @@ def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: if self.normalize_by_size: stride_w = reference_boxes[:, 2] - reference_boxes[:, 0] stride_h = reference_boxes[:, 3] - reference_boxes[:, 1] - strides = torch.stack([stride_w, stride_h, stride_w, stride_h], axis=1) + strides = torch.stack((stride_w, stride_h, stride_w, stride_h), dim=1) targets = targets / strides return targets @@ -297,7 +297,7 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: if self.normalize_by_size: stride_w = boxes[:, 2] - boxes[:, 0] stride_h = boxes[:, 3] - boxes[:, 1] - strides = torch.stack([stride_w, stride_h, stride_w, stride_h], axis=1) + strides = torch.stack((stride_w, stride_h, stride_w, stride_h), dim=1) rel_codes = rel_codes * strides pred_boxes1 = ctr_x - rel_codes[:, 0] diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 9844cb434d9..4dd2b9fa7b2 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -35,8 +35,13 @@ class FCOSHead(nn.Module): num_convs (int): number of conv layer of head """ + __annotations__ = { + "box_coder": det_utils.BoxLinearCoder, + } + def __init__(self, in_channels, num_anchors, num_classes, num_convs=4): super().__init__() + self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True) self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs) self.regression_head = FCOSRegressionHead(in_channels, num_anchors, num_convs) @@ -46,7 +51,6 @@ def compute_loss( head_outputs: Dict[str, Tensor], anchors: List[Tensor], matched_idxs: List[Tensor], - box_coder, ): cls_logits = head_outputs["cls_logits"] # [N, K, C] @@ -74,7 +78,7 @@ def compute_loss( # regression loss: GIoU loss pred_boxes = [ - box_coder.decode_single(bbox_regression_per_image, anchors_per_image) + self.box_coder.decode_single(bbox_regression_per_image, anchors_per_image) for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression) ] # amp issue: pred_boxes need to convert float @@ -86,7 +90,7 @@ def compute_loss( # ctrness loss bbox_reg_targets = [ - box_coder.encode_single(anchors_per_image, boxes_targets_per_image) + self.box_coder.encode_single(anchors_per_image, boxes_targets_per_image) for 
anchors_per_image, boxes_targets_per_image in zip(anchors, all_gt_boxes_targets) ] bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0) @@ -319,6 +323,10 @@ class FCOS(nn.Module): >>> predictions = model(x) """ + __annotations__ = { + "box_coder": det_utils.BoxLinearCoder, + } + def __init__( self, backbone, @@ -434,7 +442,7 @@ def compute_loss( matched_idxs.append(matched_idx) - return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs, self.box_coder) + return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) def postprocess_detections( self, head_outputs: Dict[str, List[Tensor]], anchors: List[List[Tensor]], image_shapes: List[Tuple[int, int]] From 7b1c73fc759e60ed76f800e6efac5d1574ee84d4 Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Fri, 19 Nov 2021 16:13:15 +0800 Subject: [PATCH 20/60] Ignore mypy in fcos --- mypy.ini | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mypy.ini b/mypy.ini index a6000f8a9d5..799298d2168 100644 --- a/mypy.ini +++ b/mypy.ini @@ -70,6 +70,10 @@ ignore_errors = True ignore_errors = True +[mypy-torchvision.models.detection.fcos] + +ignore_errors = True + [mypy-torchvision.ops.*] ignore_errors = True From 06e7ea8d8ce2cfb5364e43639fdb9235d5cf5f1f Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Fri, 19 Nov 2021 18:19:20 +0800 Subject: [PATCH 21/60] Fixing trained checkpoints --- torchvision/models/detection/fcos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 4dd2b9fa7b2..17dbde7e1ba 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -610,7 +610,7 @@ def forward( model_urls = { - "fcos_resnet50_fpn_coco": "https://github.com/o295/checkpoints/releases/download/coco/fcos_resnet50_fpn_coco-46080c1a.pth", + "fcos_resnet50_fpn_coco": "https://github.com/o295/checkpoints/releases/download/coco/fcos_resnet50_fpn_coco-7c2e6686.pth", } From d348da3e355b25c921ab1bbbed4df84a6f68dedd Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Fri, 19 Nov 2021 20:29:15 +0800 Subject: [PATCH 22/60] Fixing unit-test of jit script --- test/test_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_models.py b/test/test_models.py index 150b813b0cb..3da2ab0ec96 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -220,6 +220,7 @@ def _check_input_backprop(model, inputs): "retinanet_resnet50_fpn": lambda x: x[1], "ssd300_vgg16": lambda x: x[1], "ssdlite320_mobilenet_v3_large": lambda x: x[1], + "fcos_resnet50_fpn": lambda x: x[1], } From 1802ca2f197491651ce536de3015e21b57d0f293 Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Fri, 19 Nov 2021 21:49:44 +0800 Subject: [PATCH 23/60] Fixing docstrings --- torchvision/models/detection/fcos.py | 44 +++++++++++++++++++++------- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 17dbde7e1ba..151730d8387 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -2,7 +2,7 @@ import warnings from collections import OrderedDict from functools import partial -from typing import Dict, List, Tuple, Optional +from typing import Callable, Dict, List, Tuple, Optional import torch from torch import nn, Tensor @@ -32,14 +32,14 @@ class FCOSHead(nn.Module): in_channels (int): number of channels of the input feature num_anchors (int): number of anchors to be predicted num_classes (int): number of classes to be predicted - 
num_convs (int): number of conv layer of head + num_convs (Optional[int]): number of conv layer of head. Default: 4. """ __annotations__ = { "box_coder": det_utils.BoxLinearCoder, } - def __init__(self, in_channels, num_anchors, num_classes, num_convs=4): + def __init__(self, in_channels: int, num_anchors: int, num_classes: int, num_convs: Optional[int] = 4): super().__init__() self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True) self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs) @@ -128,13 +128,23 @@ class FCOSClassificationHead(nn.Module): A classification head for use in FCOS. Args: - in_channels (int): number of channels of the input feature - num_anchors (int): number of anchors to be predicted - num_classes (int): number of classes to be predicted - num_convs (int): number of conv layer + in_channels (int): number of channels of the input feature. + num_anchors (int): number of anchors to be predicted. + num_classes (int): number of classes to be predicted. + num_convs (Optional[int]): number of conv layer. Default: 4. + prior_probability (Optional[float]): probability of prior. Default: 0.01. + norm_layer: Module specifying the normalization layer to use. """ - def __init__(self, in_channels, num_anchors, num_classes, num_convs=4, prior_probability=0.01, norm_layer=None): + def __init__( + self, + in_channels: int, + num_anchors: int, + num_classes: int, + num_convs: int = 4, + prior_probability: float = 0.01, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ): super().__init__() self.num_classes = num_classes @@ -184,10 +194,17 @@ class FCOSRegressionHead(nn.Module): Args: in_channels (int): number of channels of the input feature num_anchors (int): number of anchors to be predicted - num_convs (int): number of conv layer + num_convs (Optional[int]): number of conv layer. Default: 4. + norm_layer: Module specifying the normalization layer to use. """ - def __init__(self, in_channels, num_anchors, num_convs=4, norm_layer=None): + def __init__( + self, + in_channels: int, + num_anchors: int, + num_convs: int = 4, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ): super().__init__() if norm_layer is None: @@ -615,7 +632,12 @@ def forward( def fcos_resnet50_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs + pretrained: bool = False, + progress: bool = True, + num_classes: int = 91, + pretrained_backbone: bool = True, + trainable_backbone_layers: Optional[int] = None, + **kwargs, ): """ Constructs a FCOS model with a ResNet-50-FPN backbone. 
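[Editor's note] The typing and docstring fixes above pin down the head contracts: each head consumes the list of FPN feature maps and returns predictions flattened to one entry per spatial location (FCOS uses a single anchor per location). The following shape check is a sketch only; it assumes the internal classes can be imported from the torchvision.models.detection.fcos module added by this series (they are deliberately not part of __all__).

import torch
from torchvision.models.detection.fcos import FCOSClassificationHead, FCOSRegressionHead

in_channels, num_anchors, num_classes = 256, 1, 91
cls_head = FCOSClassificationHead(in_channels, num_anchors, num_classes)
reg_head = FCOSRegressionHead(in_channels, num_anchors)

# Two fake FPN levels for a batch of two images (256 channels, as the FPN produces).
features = [torch.rand(2, in_channels, 32, 32), torch.rand(2, in_channels, 16, 16)]

cls_logits = cls_head(features)                     # (2, HWA, 91), HWA = 32*32 + 16*16 = 1280
bbox_regression, bbox_ctrness = reg_head(features)  # (2, HWA, 4) and (2, HWA, 1)
print(cls_logits.shape, bbox_regression.shape, bbox_ctrness.shape)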
From 9f5034d384395604d15ff6ddaffdac29187baa47 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Fri, 26 Nov 2021 13:47:27 +0800 Subject: [PATCH 24/60] Add test/expect/ModelTester.test_fcos_resnet50_fpn_expect.pkl --- ...ModelTester.test_fcos_resnet50_fpn_expect.pkl | Bin 0 -> 1123 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/expect/ModelTester.test_fcos_resnet50_fpn_expect.pkl diff --git a/test/expect/ModelTester.test_fcos_resnet50_fpn_expect.pkl b/test/expect/ModelTester.test_fcos_resnet50_fpn_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..9cddc763a8faa5015bf3d679b1592ff5e310bec4 GIT binary patch literal 1123 zcmb7E&2G~`5MC#alNv}uTUve#y`@r{KW;stsPbizFhQsh5{evq6r4qYkRkq2j(009vXgHR%>gQ*^ z(J8Yp8Jp(50=odE)+->xAowOC$k{SLV7)(nwgcZ6 zw)BEgq<;t<5enzTv!xCX4!ZDA!lN#fW1`3oLQJ#bVivCICQ({YO-Ff5Q8wc!GuYzF z47RUaB6=cWrwhBWuWt>6AHmavXnutE!DzTV!L#@TZFrvMenH$X=iIM|d+&S9{aV7C zF6@u3y2*>B>%Tt+rzij|GcFTbQ>HRI%9N?mG&%<>Maonx`|vN~?JRN2=u~eJ{|CN( zH-26JRoK=t{zQZk*45}#Rug!ulT^Qd4nOCxMQE8^p>5gIX&I~ZA!Dl@zCf>+It~4W zWl5tO&HH01(G0eAf9$(eYT|{ cVsS}kjRu%0`1^==yqUZ}gNm2}Bmczr26AQA_W%F@ literal 0 HcmV?d00001 From cfe224c309f21ab48a39406f71a12ee3f067648e Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Fri, 26 Nov 2021 13:59:55 +0800 Subject: [PATCH 25/60] Fixing test_detection_model_trainable_backbone_layers --- test/test_models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_models.py b/test/test_models.py index 98de895cdb4..6dbf73bf744 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -321,6 +321,10 @@ def _check_input_backprop(model, inputs): "max_trainable": 6, "n_trn_params_per_layer": [96, 99, 138, 200, 239, 257, 266], }, + "fcos_resnet50_fpn": { + "max_trainable": 5, + "n_trn_params_per_layer": [54, 64, 83, 96, 106, 107], + }, } From a729e57f69c6c9be4d139e85f56c092b74929577 Mon Sep 17 00:00:00 2001 From: zhiqiang Date: Fri, 26 Nov 2021 19:46:42 +0800 Subject: [PATCH 26/60] Update test_fcos_resnet50_fpn_expect.pkl --- ...elTester.test_fcos_resnet50_fpn_expect.pkl | Bin 1123 -> 9571 bytes test/test_models.py | 7 +++++++ 2 files changed, 7 insertions(+) diff --git a/test/expect/ModelTester.test_fcos_resnet50_fpn_expect.pkl b/test/expect/ModelTester.test_fcos_resnet50_fpn_expect.pkl index 9cddc763a8faa5015bf3d679b1592ff5e310bec4..0657261d96cefe0d09b93efcd81c21b4cb56b3da 100644 GIT binary patch literal 9571 zcmeG?XH-Z5HFa`aWOVH8n91{e9CXc8s8##1PSl1WICM#gcOgYcK==N4MRv0DdWc*SiCT~Ann zrX7PxR)O~tk+4B5X}nMK1_Fvua>ej}G2a!tFJrzE-?M!}JSB~pA4X#{L74EYdn889 zXa7vIf5rwe+P@Qx4K!BbZhTwD#FufpbV}TJYiZ#t2&>SvpZ%#Yr!)IA3+1prm>kNs z)VOj-AU>Jhp5}uHTnX$5*q0q{YBhwg5`VtH`bvy`CJ2X-p6Uf|qviNG7S=Hc<#iu% zbW)=*w?>+iaPomb?4I68W415h=vjdn8{$M`rWKf&;f0@f{ek#uqLT?5MEBa>=rb)v zo^x2%2Ztm!h?w8JW<VdgFEf3Q@5C9ill0vrzaylrh^c+Z~7#TG$eg?c1TE*b77d zkuBSDo>L&YP8fvW?qz;cj!r)X;%7H*5})Pk#apL+ggU{Q`83D%lSI)qPZ@KJR>!^Y zy2A&1;@Pid?tkZ|i2XM8q{ z1b{c#YL)dMa>;xexBx;e-QC%x8OM(Gbh=nZ@jQ#EjiviROQF!k7+wD2659wM46N@w2;PZNnJt`IUW$7q2z6_rav}-gt4RJ8tU# zHqOZ99Joh#?UIFc5%;5nbA$TIm~E`NK>PFL!51=(zi{kaYu3?=(&RwQT3SUk*Uv11 z?>Xa7+gjqFZk!8)gd=WqJ`z5#;r^E}`6;iDB#&7yxw0FLSzm<}v)MlH8@A8*FJErW z5VaF~;6*R)7ZsKaG_W19>8nZ^vwx1c#K`+ujkAhrKM%SuDwlCTu|IpR&*f@l&%NcsBA&S$_9#df^)4g^>@KR^#e} zyv7D`z@=QqP5C5>WgEGEiQ=&x_p2{1wg(yCJ)uR1ltg%}AJa{GL7d=V(|AIk4Y@QhUi!3j8AuQzckq&lyFBozeg*gB{_$P zMq8}UX1*Okll`PIjpB>w6~vclE?!JZ>W#7K>{l(ztiJu5xHXMs4K(*MY{G{T(<471j*m^`O3J?T2Ts`eL7( ze+YHQ0Gg9<->-{A2Uk81*dLu)C%tf{mitYO*ROH@uK4|Qj*s`M3df!C!l4`Zyj5e2 z&l`k&v9gl)gD) z^Q-;r2}`)U*(zPzcXIZYlFHWkxK|Eg^?s!+mw?_Ub 
z<3G=n*!_Iwvp-(Hc(Ey?O0>0Se`c|HTMrD2;(DrZ>peBfY zD}2#eQ6;qYyq;$9Ac1Y^MNKHjrWZ5U`-`~3(;}sw+0Ph++rD2$H0KvD3Slt5xw;39nO33pfeWsyE7$`6+Wr zk8_C^X)A59XCkjZ%c}^~*n1bptH$-A98bOIF^uJZzYHN46C60YgZ`uuR*lGM4yH8J_>4UTahjA_dg9p>AU@mHh>{2wyn2A|OT7N<(;%jP#qsl4PoTz@3wZy~y>tubtU$M4 zD`ebM7hg;Y-bJ)8eyr1r`mXhw@`t=0lGe7UCvN%P8H4)q`;2qfi=hU0JT{Z}3-^OW zpu}bpuRs0XvEmpM=$Ogx1C}?5(h^6lXX+!OWTMq*Jot&2c#7|5@kFzaAaq;0kj6psT)Rtw zc zeld$akA1Pr!bUOt2KTFk8{gudvBM;f7Bc4X(Gis*rg1#k6(R?0q{K?|EZl^#Gt_SDBIx6fx zj`yny9pACTpU?4mq#;@&K4S%TZ^_bA;{CW5;?KH|=o+e$;r9XkJ@`Sbp=KW>_bVN7CN?*c*qt4aOs$ z@_B+CX+H-{s@J@={)B8VBrTP&5)!m_U<^32>I=l|9 z=-Z0>jq7bD9E2yk@}3AnA2ZiTiEkHmqj`=;t`~*MdXluZKle|f_&t;RDUoF8{=M54 z7x+z}F~=T8n0{BiQEORWNz}F+FHU-My_u&(`;WL+*|t)~R&n{FY!<)YxfY3H=e~N) z-csIgk~ZA)bKzciTMYSx`+@UK79S1if&FveBYlo3h(Is0?{j_i7Es{VhXBKtwcy&9c6C{%!GLUw%KazeF*qI@yv7*O&W` z-VX`PmuSp2kg(wZ_XpGT`;q>9Hlf7%ilH@81pn>NXErbY`OG_jtU&!8=(^z!to-y2 zjC!LU-g#69eXrI*v&D6g_g)=zRn>vxAJ?I`={h)mcpb*OT!)sw)q+hxEvTRt8a!&j zI`4Ogthoxu)?S5Q8?Hdei7Vi{?Fy9JUV)nrE<@F|%iw$NGBmz*8H!6U!SofEAYF9{ zPCmW}wrekfe_0LK=heWh=o+}>Q3DCiH8B4dGaT7xhAu755OS*;x))S~w4@pydRN1b zPStQgs)poyRdBer3OZh?f_~?#z$T>%j3=uAS69Iik1ELMR0$@RO0aUMgqD{o;O&wM z*ppuYNA(q8zpDZ+e^CL^t14jA!V0LHSph{kdDt9q9(JBAfxHh&VD0-Q;OJdKwu<5N z%3>IsT@0N!7sHZG#jt8sF{mSo;nR`DFze|#xcTrL^sYSz6S_J(F#vLjG*Q6r&?!zLO{B9ArsEXiJbs@wZDuhu4B|jE|_p(ChuPKD0E_8Fe zSO8lK3Lxj30w{z8cdrR-SDC;#(F9SxCiunP1fMqM!ym=@ zus$gtT;}D2`#2v_4O?omVOe%Ibp9b5e%p`@%hzVZ(nZ;DY(X}71ZG2Fc@~6Z zWkJK?EI9a87O1|=g6z#%@OxwyxP6ob#xYqibVL^H^Us2qmRT^sDhqt8&cLapGtfNY z3V8fK>+jN`CLta6%uk2VN$Jp9nGW|mro*Gh zY0&>n8hoCb20Mr!Fel2lmuaViw{PlbP} zQz5)-D(rMig|+QdVO+aZ7}z=$TpNuLaMuX!>x>XmYlQSFBgCW`A@{Tq^1nC2@~@3x z-f0AeTX-5sf=iSI-bCxC3|JpG8xAU&; z%jca>fKdF~e;3d>`ucfYf!7syU4hpXcwK?l6?k2N*A@7GtN`EFdL0hl6=qMrehL2b zea%_w^5VWWFES>UzGMjy?*DdSd*!2!w;k9+!B^A9SX0ytp}|&+UVPK@_m4{y%y{W} z^$VX@wg3Lx+UX_DjaJf&c=`JvFZHXsf4`@`tZQofimvyox___P%esa)U(xk>Rrl}J zaDFM~IbHtJwbVE?xJ8RT79qE0eP4+4=S=+kiM=RYehSG1KcvSWFiBS3_+P>Vp~21h O3nS8LF0p)*?EeDnO=}PU literal 1123 zcmb7E&2G~`5MC#alNv}uTUve#y`@r{KW;stsPbizFhQsh5{evq6r4qYkRkq2j(009vXgHR%>gQ*^ z(J8Yp8Jp(50=odE)+->xAowOC$k{SLV7)(nwgcZ6 zw)BEgq<;t<5enzTv!xCX4!ZDA!lN#fW1`3oLQJ#bVivCICQ({YO-Ff5Q8wc!GuYzF z47RUaB6=cWrwhBWuWt>6AHmavXnutE!DzTV!L#@TZFrvMenH$X=iIM|d+&S9{aV7C zF6@u3y2*>B>%Tt+rzij|GcFTbQ>HRI%9N?mG&%<>Maonx`|vN~?JRN2=u~eJ{|CN( zH-26JRoK=t{zQZk*45}#Rug!ulT^Qd4nOCxMQE8^p>5gIX&I~ZA!Dl@zCf>+It~4W zWl5tO&HH01(G0eAf9$(eYT|{ cVsS}kjRu%0`1^==yqUZ}gNm2}Bmczr26AQA_W%F@ diff --git a/test/test_models.py b/test/test_models.py index 6dbf73bf744..7237bd9a11b 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -270,6 +270,13 @@ def _check_input_backprop(model, inputs): "max_size": 224, "input_shape": (3, 224, 224), }, + "fcos_resnet50_fpn": { + "num_classes": 2, + "score_thresh": 0.05, + "min_size": 224, + "max_size": 224, + "input_shape": (3, 224, 224), + }, "maskrcnn_resnet50_fpn": { "num_classes": 10, "min_size": 224, From c2c7a7e077f3086f617f9c09d85e92674df91ce7 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Sun, 9 Jan 2022 18:08:02 +0800 Subject: [PATCH 27/60] rename stride to box size --- torchvision/models/detection/_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index b442eeae33b..0f4ba5aa0d3 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -258,10 +258,10 @@ def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: targets = torch.stack((target_l, 
target_t, target_r, target_b), dim=1) if self.normalize_by_size: - stride_w = reference_boxes[:, 2] - reference_boxes[:, 0] - stride_h = reference_boxes[:, 3] - reference_boxes[:, 1] - strides = torch.stack((stride_w, stride_h, stride_w, stride_h), dim=1) - targets = targets / strides + reference_boxes_w = reference_boxes[:, 2] - reference_boxes[:, 0] + reference_boxes_h = reference_boxes[:, 3] - reference_boxes[:, 1] + reference_boxes_size = torch.stack((reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=1) + targets = targets / reference_boxes_size return targets @@ -295,10 +295,10 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: ctr_x = 0.5 * (boxes[:, 0] + boxes[:, 2]) ctr_y = 0.5 * (boxes[:, 1] + boxes[:, 3]) if self.normalize_by_size: - stride_w = boxes[:, 2] - boxes[:, 0] - stride_h = boxes[:, 3] - boxes[:, 1] - strides = torch.stack((stride_w, stride_h, stride_w, stride_h), dim=1) - rel_codes = rel_codes * strides + boxes_w = boxes[:, 2] - boxes[:, 0] + boxes_h = boxes[:, 3] - boxes[:, 1] + boxes_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=1) + rel_codes = rel_codes * boxes_size pred_boxes1 = ctr_x - rel_codes[:, 0] pred_boxes2 = ctr_y - rel_codes[:, 1] From 7a830e16f2552426855bce7f0e8eb44eac5bae7f Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 10 Jan 2022 19:03:39 +0800 Subject: [PATCH 28/60] remove TODO and fix some typo --- torchvision/models/detection/fcos.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 151730d8387..e2442f6ebbd 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -570,7 +570,6 @@ def forward( images, targets = self.transform(images, targets) # Check for degenerate boxes - # TODO: Move this to a function if targets is not None: for target_idx, target in enumerate(targets): boxes = target["boxes"] @@ -589,10 +588,9 @@ def forward( if isinstance(features, torch.Tensor): features = OrderedDict([("0", features)]) - # TODO: Do we want a list or a dict? features = list(features.values()) - # compute the retinanet heads outputs using the features + # compute the fcos heads outputs using the features head_outputs = self.head(features) # create the set of anchors From 90efd29bd2f1ba5a7baafc3e6d7eede281f48bb6 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 10 Jan 2022 19:25:30 +0800 Subject: [PATCH 29/60] merge some code for better --- torchvision/models/detection/fcos.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index e2442f6ebbd..977312c8ab2 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -546,11 +546,9 @@ def forward( During testing, it returns list[BoxList] contains additional fields like `scores`, `labels` and `mask` (for Mask R-CNN models). 
""" - if self.training and targets is None: - raise ValueError("In training mode, targets should be passed") - if self.training: - assert targets is not None + if targets is None: + raise ValueError("In training mode, targets should be passed") for target in targets: boxes = target["boxes"] if isinstance(boxes, torch.Tensor): From 0e6039bf785c24e22d589352bda00d23c7100747 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Mon, 10 Jan 2022 20:15:01 +0800 Subject: [PATCH 30/60] impove the comments --- torchvision/models/detection/fcos.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 977312c8ab2..60097c2db21 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -53,9 +53,9 @@ def compute_loss( matched_idxs: List[Tensor], ): - cls_logits = head_outputs["cls_logits"] # [N, K, C] - bbox_regression = head_outputs["bbox_regression"] # [N, K, 4] - bbox_ctrness = head_outputs["bbox_ctrness"] # [N, K, 1] + cls_logits = head_outputs["cls_logits"] # [N, HWA, C] + bbox_regression = head_outputs["bbox_regression"] # [N, HWA, 4] + bbox_ctrness = head_outputs["bbox_ctrness"] # [N, HWA, 1] all_gt_classes_targets = [] all_gt_boxes_targets = [] From 4f9f3eff1e9b3ec9c30ba4acc0e58a74235ab7f0 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 11 Jan 2022 14:50:43 +0800 Subject: [PATCH 31/60] remove decode and encode of BoxLinearCoder --- torchvision/models/detection/_utils.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index 0f4ba5aa0d3..2c6f9a2e937 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -230,13 +230,6 @@ def __init__(self, normalize_by_size: bool = True) -> None: """ self.normalize_by_size = normalize_by_size - def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]: - boxes_per_image = [len(b) for b in reference_boxes] - reference_boxes = torch.cat(reference_boxes, dim=0) - proposals = torch.cat(proposals, dim=0) - targets = self.encode_single(reference_boxes, proposals) - return targets.split(boxes_per_image, 0) - def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: """ Encode a set of proposals with respect to some @@ -260,26 +253,12 @@ def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: if self.normalize_by_size: reference_boxes_w = reference_boxes[:, 2] - reference_boxes[:, 0] reference_boxes_h = reference_boxes[:, 3] - reference_boxes[:, 1] - reference_boxes_size = torch.stack((reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=1) + reference_boxes_size = torch.stack((reference_boxes_w, reference_boxes_h, + reference_boxes_w, reference_boxes_h), dim=1) targets = targets / reference_boxes_size return targets - def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor: - assert isinstance(boxes, (list, tuple)) - assert isinstance(rel_codes, torch.Tensor) - boxes_per_image = [b.size(0) for b in boxes] - concat_boxes = torch.cat(boxes, dim=0) - box_sum = 0 - for val in boxes_per_image: - box_sum += val - if box_sum > 0: - rel_codes = rel_codes.reshape(box_sum, -1) - pred_boxes = self.decode_single(rel_codes, concat_boxes) - if box_sum > 0: - pred_boxes = pred_boxes.reshape(box_sum, -1, 4) - return pred_boxes - def decode_single(self, rel_codes: Tensor, boxes: 
Tensor) -> Tensor: """ From a set of original boxes and encoded relative box offsets, From db2e89bfaa48ba66484d62dd16d3f01f670a686f Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 11 Jan 2022 14:56:39 +0800 Subject: [PATCH 32/60] remove some unnecessary hints --- torchvision/models/detection/fcos.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 60097c2db21..e1732136237 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -557,7 +557,6 @@ def forward( else: raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.") - # get the original image sizes original_image_sizes: List[Tuple[int, int]] = [] for img in images: val = img.shape[-2:] From d03f03aa6857d89ba774467f408fadca001d5b87 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Wed, 12 Jan 2022 13:51:40 +0800 Subject: [PATCH 33/60] use the default value in detectron2. --- torchvision/models/detection/fcos.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index e1732136237..8db36dad9ec 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -358,8 +358,8 @@ def __init__( head=None, center_sampling_radius=1.5, score_thresh=0.2, - nms_thresh=0.5, - detections_per_img=300, + nms_thresh=0.6, + detections_per_img=100, topk_candidates=1000, ): super().__init__() From 56896e57b9f9551bb5787a98f942c3f6582f4783 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Wed, 12 Jan 2022 13:59:44 +0800 Subject: [PATCH 34/60] update doc --- torchvision/models/detection/fcos.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 8db36dad9ec..875555ce827 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -268,8 +268,8 @@ class FCOS(nn.Module): ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. - labels (Int64Tensor[N]): the class label for each ground-truth box - The model returns a Dict[Tensor] during training, containing the classification and regression - losses. + The model returns a Dict[Tensor] during training, containing the classification, regression + and centerness losses. During inference, the model requires only the input tensors, and returns the post-processed predictions as a List[Dict[Tensor]], one for each input image. 
The fields of the Dict are as From fab274874b789892dbe96af05162334edc015792 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Thu, 13 Jan 2022 16:00:42 +0800 Subject: [PATCH 35/60] Add unittest for BoxLinearCoder --- test/test_models_detection_utils.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py index 8d686023b1d..b19f4368925 100644 --- a/test/test_models_detection_utils.py +++ b/test/test_models_detection_utils.py @@ -22,6 +22,19 @@ def test_balanced_positive_negative_sampler(self): assert neg[0].sum() == 3 assert neg[0][0:6].sum() == 3 + def test_box_linear_coder(self): + box_coder = _utils.BoxLinearCoder(normalize_by_size=True) + # Generate the random boxes for testing + boxes = torch.rand(10, 4) * 50 + boxes.clamp_(min=1.0) # tiny boxes cause numerical instability in box regression + boxes[:, 2:] += boxes[:, :2] + + proposals = torch.tensor([0, 0, 101, 101] * 10).reshape(10, 4).float() + + rel_codes = box_coder.encode_single(boxes, proposals) + pred_boxes = box_coder.decode_single(rel_codes, boxes) + torch.allclose(proposals, pred_boxes) + @pytest.mark.parametrize("train_layers, exp_froz_params", [(0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0)]) def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params): # we know how many initial layers and parameters of the network should From 80c3e64aca0106f25f483b9784b299ef37b78f26 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Thu, 13 Jan 2022 16:06:58 +0800 Subject: [PATCH 36/60] Add types in FCOS --- torchvision/models/detection/fcos.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 875555ce827..136ccedbe0d 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -346,21 +346,21 @@ class FCOS(nn.Module): def __init__( self, - backbone, - num_classes, + backbone: nn.Module, + num_classes: int, # transform parameters - min_size=800, - max_size=1333, - image_mean=None, - image_std=None, + min_size: int = 800, + max_size: int = 1333, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, # Anchor parameters - anchor_generator=None, - head=None, - center_sampling_radius=1.5, - score_thresh=0.2, - nms_thresh=0.6, - detections_per_img=100, - topk_candidates=1000, + anchor_generator: Optional[AnchorGenerator] = None, + head: Optional[nn.Module] = None, + center_sampling_radius: float = 1.5, + score_thresh: float = 0.2, + nms_thresh: float = 0.6, + detections_per_img: int = 100, + topk_candidates: int = 1000, ): super().__init__() _log_api_usage_once(self) From cb1f8e2f5f6db1c686ef12c51338daec7a2f3ebc Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Thu, 13 Jan 2022 16:23:42 +0800 Subject: [PATCH 37/60] Add docstring for BoxLinearCoder --- torchvision/models/detection/_utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index 2c6f9a2e937..ef4f6550eef 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -232,12 +232,15 @@ def __init__(self, normalize_by_size: bool = True) -> None: def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: """ - Encode a set of proposals with respect to some - reference boxes + Encode a set of proposals with respect 
to some reference boxes Args: reference_boxes (Tensor): reference boxes proposals (Tensor): boxes to be encoded + + Returns: + Tensor: the encoded relative box offsets that can be used to + decode the boxes. """ # get the center of reference_boxes reference_boxes_ctr_x = 0.5 * (reference_boxes[:, 0] + reference_boxes[:, 2]) @@ -253,8 +256,9 @@ def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: if self.normalize_by_size: reference_boxes_w = reference_boxes[:, 2] - reference_boxes[:, 0] reference_boxes_h = reference_boxes[:, 3] - reference_boxes[:, 1] - reference_boxes_size = torch.stack((reference_boxes_w, reference_boxes_h, - reference_boxes_w, reference_boxes_h), dim=1) + reference_boxes_size = torch.stack( + (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=1 + ) targets = targets / reference_boxes_size return targets @@ -267,6 +271,9 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: Args: rel_codes (Tensor): encoded boxes boxes (Tensor): reference boxes. + + Returns: + Tensor: the predicted boxes with the encoded relative box offsets. """ boxes = boxes.to(rel_codes.dtype) From 56b131f974dd853af95b5046b97f784d7192e3da Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Thu, 13 Jan 2022 16:24:10 +0800 Subject: [PATCH 38/60] Minor fix for the docstring --- test/test_models_detection_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py index b19f4368925..6551a1a759f 100644 --- a/test/test_models_detection_utils.py +++ b/test/test_models_detection_utils.py @@ -24,7 +24,7 @@ def test_balanced_positive_negative_sampler(self): def test_box_linear_coder(self): box_coder = _utils.BoxLinearCoder(normalize_by_size=True) - # Generate the random boxes for testing + # Generate a random 10x4 boxes tensor, with coordinates < 50. boxes = torch.rand(10, 4) * 50 boxes.clamp_(min=1.0) # tiny boxes cause numerical instability in box regression boxes[:, 2:] += boxes[:, :2] From 11cc5d9e64ac1f4c66a89d1d2deee707b20a7cd5 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 13 Jan 2022 16:59:28 +0800 Subject: [PATCH 39/60] update doc --- torchvision/ops/giou_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index 16d82c34c15..7871d3f4543 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -24,7 +24,7 @@ def giou_loss( Args: boxes1 (Tensor[N, 4] or Tensor[4]): first set of boxes boxes2 (Tensor[N, 4] or Tensor[4]): second set of boxes - reduction: 'none' | 'mean' | 'sum' + reduction (str): 'none' | 'mean' | 'sum' 'none': No reduction will be applied to the output. 'mean': The output will be averaged. 'sum': The output will be summed. 
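[Editor's note] At this point in the series BoxLinearCoder exposes only encode_single and decode_single, and the new unit test checks that the two invert each other. The sketch below mirrors test_box_linear_coder from the patch above, with random boxes standing in for real anchors; it assumes the private torchvision.models.detection._utils import path used by the test suite.

import torch
from torchvision.models.detection import _utils

box_coder = _utils.BoxLinearCoder(normalize_by_size=True)

# Reference (anchor-like) boxes in (x1, y1, x2, y2) format; keep them away from zero size.
boxes = torch.rand(10, 4) * 50
boxes.clamp_(min=1.0)
boxes[:, 2:] += boxes[:, :2]

# Target boxes to encode against the reference boxes.
proposals = torch.tensor([0, 0, 101, 101] * 10).reshape(10, 4).float()

# encode_single produces (left, top, right, bottom) distances from each reference-box
# center, normalized by the reference-box size; decode_single inverts that transform.
rel_codes = box_coder.encode_single(boxes, proposals)
pred_boxes = box_coder.decode_single(rel_codes, boxes)
assert torch.allclose(proposals, pred_boxes, atol=1e-4)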
From 78e27da7f8ae70e920edf15058db0b8807c62c24 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Mon, 17 Jan 2022 17:55:18 +0000 Subject: [PATCH 40/60] Update fcos_resnet50_fpn_coco pretained weights url --- torchvision/models/detection/fcos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 136ccedbe0d..81b08112812 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -622,7 +622,7 @@ def forward( model_urls = { - "fcos_resnet50_fpn_coco": "https://github.com/o295/checkpoints/releases/download/coco/fcos_resnet50_fpn_coco-7c2e6686.pth", + "fcos_resnet50_fpn_coco": "https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", } From 32ec6babe5797363d2f0491d470d738c780d592d Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 18 Jan 2022 09:58:30 +0800 Subject: [PATCH 41/60] Update torchvision/models/detection/fcos.py Co-authored-by: Vasilis Vryniotis --- torchvision/models/detection/fcos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 81b08112812..0b2f82d8b2c 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -151,7 +151,7 @@ def __init__( self.num_anchors = num_anchors if norm_layer is None: - norm_layer = partial(nn.GroupNorm, 32) + norm_layer = partial(nn.GroupNorm, num_groups=32) conv = [] for _ in range(num_convs): From 0e9bd78757d9b146c3f2faa5c0f3fcf4e05c179e Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 18 Jan 2022 09:58:38 +0800 Subject: [PATCH 42/60] Update torchvision/models/detection/fcos.py Co-authored-by: Vasilis Vryniotis --- torchvision/models/detection/fcos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 0b2f82d8b2c..15e30d718d9 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -51,7 +51,7 @@ def compute_loss( head_outputs: Dict[str, Tensor], anchors: List[Tensor], matched_idxs: List[Tensor], - ): + ) -> Dict[str, Tensor]: cls_logits = head_outputs["cls_logits"] # [N, HWA, C] bbox_regression = head_outputs["bbox_regression"] # [N, HWA, 4] From 4faf24ae68dd794516fb9279b1c6e1ea0e64feb1 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 18 Jan 2022 09:58:54 +0800 Subject: [PATCH 43/60] Update torchvision/models/detection/fcos.py Co-authored-by: Vasilis Vryniotis --- torchvision/models/detection/fcos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 15e30d718d9..1e8e6bd3486 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -39,7 +39,7 @@ class FCOSHead(nn.Module): "box_coder": det_utils.BoxLinearCoder, } - def __init__(self, in_channels: int, num_anchors: int, num_classes: int, num_convs: Optional[int] = 4): + def __init__(self, in_channels: int, num_anchors: int, num_classes: int, num_convs: Optional[int] = 4) -> None: super().__init__() self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True) self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs) From 868fe543ea70f03674e8c68fe391a28486a4ffb4 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Tue, 18 Jan 2022 09:59:04 +0800 Subject: [PATCH 44/60] Update torchvision/models/detection/fcos.py 
Co-authored-by: Vasilis Vryniotis --- torchvision/models/detection/fcos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 1e8e6bd3486..191641c7443 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -144,7 +144,7 @@ def __init__( num_convs: int = 4, prior_probability: float = 0.01, norm_layer: Optional[Callable[..., nn.Module]] = None, - ): + ) -> None: super().__init__() self.num_classes = num_classes From 54b12c87cd6aa675e2c613c0ae636c539ea01ca5 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Tue, 18 Jan 2022 10:49:56 +0000 Subject: [PATCH 45/60] Add FCOS model documentation --- docs/source/models.rst | 12 ++++++++++++ references/detection/README.md | 7 +++++++ 2 files changed, 19 insertions(+) diff --git a/docs/source/models.rst b/docs/source/models.rst index 62c104cf927..071ec350240 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -592,6 +592,7 @@ The models subpackage contains definitions for the following model architectures for detection: - `Faster R-CNN `_ +- `FCOS `_ - `Mask R-CNN `_ - `RetinaNet `_ - `SSD `_ @@ -637,6 +638,7 @@ Network box AP mask AP keypoint AP Faster R-CNN ResNet-50 FPN 37.0 - - Faster R-CNN MobileNetV3-Large FPN 32.8 - - Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - - +FCOS ResNet-50 FNP 39.1 - - RetinaNet ResNet-50 FPN 36.4 - - SSD300 VGG16 25.1 - - SSDlite320 MobileNetV3-Large 21.3 - - @@ -697,6 +699,7 @@ Network train time (s / it) test time (s / it) Faster R-CNN ResNet-50 FPN 0.2288 0.0590 5.2 Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415 1.0 Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6 +FCOS ResNet-50 FPN 0.1450 0.0539 3.3 RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1 SSD300 VGG16 0.2093 0.0744 1.5 SSDlite320 MobileNetV3-Large 0.1773 0.0906 1.5 @@ -716,6 +719,15 @@ Faster R-CNN torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn +FCOS +------------ + +.. 
autosummary:: + :toctree: generated/ + :template: function.rst + + torchvision.models.detection.fcos_resnet50_fpn + RetinaNet --------- diff --git a/references/detection/README.md b/references/detection/README.md index 4d44f67b4c0..3695644138b 100644 --- a/references/detection/README.md +++ b/references/detection/README.md @@ -41,6 +41,13 @@ torchrun --nproc_per_node=8 train.py\ --lr-steps 16 22 --aspect-ratio-group-factor 3 ``` +### FCOS ResNet-50 FPN +``` +torchrun --nproc_per_node=8 train.py\ + --dataset coco --model fcos_resnet50_fpn --epochs 26\ + --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp +``` + ### RetinaNet ``` torchrun --nproc_per_node=8 train.py\ From 589a3b1f0efc391b44fa19758286e7ea31207b5d Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Tue, 18 Jan 2022 11:22:04 +0000 Subject: [PATCH 46/60] Fix typo in FCOS documentation --- docs/source/models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 071ec350240..8f6e242e930 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -638,7 +638,7 @@ Network box AP mask AP keypoint AP Faster R-CNN ResNet-50 FPN 37.0 - - Faster R-CNN MobileNetV3-Large FPN 32.8 - - Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - - -FCOS ResNet-50 FNP 39.1 - - +FCOS ResNet-50 FPN 39.1 - - RetinaNet ResNet-50 FPN 36.4 - - SSD300 VGG16 25.1 - - SSDlite320 MobileNetV3-Large 21.3 - - From a4db0dd43e7325782ebc1dddc275b517af0bae79 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Thu, 20 Jan 2022 03:08:19 +0800 Subject: [PATCH 47/60] Add fcos to the prototype builder --- .../prototype/models/detection/__init__.py | 1 + .../prototype/models/detection/fcos.py | 82 +++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 torchvision/prototype/models/detection/fcos.py diff --git a/torchvision/prototype/models/detection/__init__.py b/torchvision/prototype/models/detection/__init__.py index 13edbf75575..4146651c737 100644 --- a/torchvision/prototype/models/detection/__init__.py +++ b/torchvision/prototype/models/detection/__init__.py @@ -1,4 +1,5 @@ from .faster_rcnn import * +from .fcos import * from .keypoint_rcnn import * from .mask_rcnn import * from .retinanet import * diff --git a/torchvision/prototype/models/detection/fcos.py b/torchvision/prototype/models/detection/fcos.py new file mode 100644 index 00000000000..7080bec3bad --- /dev/null +++ b/torchvision/prototype/models/detection/fcos.py @@ -0,0 +1,82 @@ +from typing import Any, Optional + +from torchvision.prototype.transforms import CocoEval +from torchvision.transforms.functional import InterpolationMode + +from ....models.detection.fcos import ( + _resnet_fpn_extractor, + _validate_trainable_layers, + FCOS, + LastLevelP6P7, + misc_nn_ops, + overwrite_eps, +) +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet50_Weights, resnet50 + + +__all__ = [ + "FCOS", + "FCOS_ResNet50_FPN_Weights", + "fcos_resnet50_fpn", +] + + +class FCOS_ResNet50_FPN_Weights(WeightsEnum): + Coco_V1 = Weights( + url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", + transforms=CocoEval, + meta={ + "task": "image_object_detection", + "architecture": "FCOS", + "publication_year": 2019, + "num_params": 34014999, + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": 
"https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn", + "map": 39.1, + }, + ) + default = Coco_V1 + + +@handle_legacy_interface( + weights=("pretrained", FCOS_ResNet50_FPN_Weights.Coco_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1), +) +def fcos_resnet50_fpn( + *, + weights: Optional[FCOS_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FCOS: + weights = FCOS_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + trainable_backbone_layers = _validate_trainable_layers( + weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 + ) + + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = _resnet_fpn_extractor( + backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) + ) + model = FCOS(backbone, num_classes, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + if weights == FCOS_ResNet50_FPN_Weights.Coco_V1: + overwrite_eps(model, 0.0) + + return model From 50bf19bc112a545c95a1c78a3596a81d6b50fa47 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Thu, 20 Jan 2022 03:24:48 +0800 Subject: [PATCH 48/60] Capitalize COCO_V1 --- torchvision/prototype/models/detection/fcos.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/models/detection/fcos.py b/torchvision/prototype/models/detection/fcos.py index 7080bec3bad..9a32c1ca2d1 100644 --- a/torchvision/prototype/models/detection/fcos.py +++ b/torchvision/prototype/models/detection/fcos.py @@ -25,7 +25,7 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum): - Coco_V1 = Weights( + COCO_V1 = Weights( url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", transforms=CocoEval, meta={ @@ -39,11 +39,11 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum): "map": 39.1, }, ) - default = Coco_V1 + default = COCO_V1 @handle_legacy_interface( - weights=("pretrained", FCOS_ResNet50_FPN_Weights.Coco_V1), + weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1), weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1), ) def fcos_resnet50_fpn( @@ -76,7 +76,7 @@ def fcos_resnet50_fpn( if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == FCOS_ResNet50_FPN_Weights.Coco_V1: + if weights == FCOS_ResNet50_FPN_Weights.COCO_V1: overwrite_eps(model, 0.0) return model From 688f4d274474da87b9e62d98af228b77891f7997 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Thu, 20 Jan 2022 10:11:06 +0800 Subject: [PATCH 49/60] Fix params of fcos --- torchvision/prototype/models/detection/fcos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/models/detection/fcos.py b/torchvision/prototype/models/detection/fcos.py index 9a32c1ca2d1..ff2d2d4ea1d 100644 --- a/torchvision/prototype/models/detection/fcos.py +++ b/torchvision/prototype/models/detection/fcos.py @@ -32,7 +32,7 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum): "task": "image_object_detection", 
"architecture": "FCOS", "publication_year": 2019, - "num_params": 34014999, + "num_params": 32269600, "categories": _COCO_CATEGORIES, "interpolation": InterpolationMode.BILINEAR, "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn", From f01dcf65ff0ecea69b969f982a15a38a3081abe2 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 20 Jan 2022 14:15:10 +0800 Subject: [PATCH 50/60] fix bug for partial --- torchvision/models/detection/fcos.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 191641c7443..daaae39b4aa 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -77,6 +77,7 @@ def compute_loss( loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum") # regression loss: GIoU loss + # TODO: vectorize this instead of using a for loop pred_boxes = [ self.box_coder.decode_single(bbox_regression_per_image, anchors_per_image) for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression) @@ -151,7 +152,7 @@ def __init__( self.num_anchors = num_anchors if norm_layer is None: - norm_layer = partial(nn.GroupNorm, num_groups=32) + norm_layer = partial(nn.GroupNorm, 32) conv = [] for _ in range(num_convs): From 9085461b87f1c5979f4b72fa989246ede49e8166 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Thu, 20 Jan 2022 14:44:03 +0800 Subject: [PATCH 51/60] Fixing docs indentation --- docs/source/models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index e037c8b32e1..299a95fb459 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -722,7 +722,7 @@ Faster R-CNN torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn FCOS ------------- +---- .. autosummary:: :toctree: generated/ From ebbd7722baa368819ec9fa7fe317719eaf45e4bb Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Thu, 20 Jan 2022 15:59:29 +0800 Subject: [PATCH 52/60] Fixing docs format in giou_loss --- torchvision/ops/giou_loss.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index 7871d3f4543..7540740a0b6 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -24,11 +24,11 @@ def giou_loss( Args: boxes1 (Tensor[N, 4] or Tensor[4]): first set of boxes boxes2 (Tensor[N, 4] or Tensor[4]): second set of boxes - reduction (str): 'none' | 'mean' | 'sum' - 'none': No reduction will be applied to the output. - 'mean': The output will be averaged. - 'sum': The output will be summed. - eps (float): small number to prevent division by zero + reduction (string, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be + applied to the output. ``'mean'``: The output will be averaged. + ``'sum'``: The output will be summed. Default: ``'none'`` + eps (float, optional): small number to prevent division by zero. 
Default: 1e-7 """ x1, y1, x2, y2 = boxes1.unbind(dim=-1) From d8ed195aef47053dac3374fb8d7bf9e102635bf7 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Thu, 20 Jan 2022 16:04:23 +0800 Subject: [PATCH 53/60] Adopt Reference for GIoU Loss --- torchvision/ops/giou_loss.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index 7540740a0b6..c906faa399b 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -11,8 +11,6 @@ def giou_loss( Original implementation from https://github.com/facebookresearch/fvcore/blob/bfff2ef/fvcore/nn/giou_loss.py - Generalized Intersection over Union Loss (Hamid Rezatofighi et. al) - https://arxiv.org/abs/1902.09630 Gradient-friendly IoU loss with an additional penalty that is non-zero when the boxes do not overlap and scales with the size of their smallest enclosing box. This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable. @@ -29,6 +27,11 @@ def giou_loss( applied to the output. ``'mean'``: The output will be averaged. ``'sum'``: The output will be summed. Default: ``'none'`` eps (float, optional): small number to prevent division by zero. Default: 1e-7 + + Reference: + Hamid Rezatofighi et. al: Generalized Intersection over Union: + A Metric and A Loss for Bounding Box Regression: + https://arxiv.org/abs/1902.09630 """ x1, y1, x2, y2 = boxes1.unbind(dim=-1) From b0b08929eb759eaed3b1ca30b7fc69f3bbaafd4b Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Thu, 20 Jan 2022 11:43:02 +0000 Subject: [PATCH 54/60] Rename giou_loss to generalized_box_iou_loss --- torchvision/models/detection/fcos.py | 4 ++-- torchvision/ops/__init__.py | 4 ++-- torchvision/ops/{giou_loss.py => generalized_box_iou_loss.py} | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) rename torchvision/ops/{giou_loss.py => generalized_box_iou_loss.py} (98%) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index daaae39b4aa..f8f9ad6fec7 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -8,7 +8,7 @@ from torch import nn, Tensor from ..._internally_replaced_utils import load_state_dict_from_url -from ...ops import sigmoid_focal_loss, giou_loss +from ...ops import sigmoid_focal_loss, generalized_box_iou_loss from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 @@ -83,7 +83,7 @@ def compute_loss( for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression) ] # amp issue: pred_boxes need to convert float - loss_bbox_reg = giou_loss( + loss_bbox_reg = generalized_box_iou_loss( torch.stack(pred_boxes)[foregroud_mask].float(), torch.stack(all_gt_boxes_targets)[foregroud_mask], reduction="sum", diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index b27e19c9381..33a48995869 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -13,7 +13,7 @@ from .deform_conv import deform_conv2d, DeformConv2d from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss -from .giou_loss import giou_loss +from .generalized_box_iou_loss import generalized_box_iou_loss from .misc import FrozenBatchNorm2d, ConvNormActivation, SqueezeExcitation from .poolers import MultiScaleRoIAlign from .ps_roi_align import ps_roi_align, PSRoIAlign @@ -53,5 +53,5 @@ "FrozenBatchNorm2d", "ConvNormActivation", "SqueezeExcitation", - "giou_loss", 
+ "generalized_box_iou_loss", ] diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/generalized_box_iou_loss.py similarity index 98% rename from torchvision/ops/giou_loss.py rename to torchvision/ops/generalized_box_iou_loss.py index c906faa399b..1ac9433250d 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/generalized_box_iou_loss.py @@ -1,7 +1,7 @@ import torch -def giou_loss( +def generalized_box_iou_loss( boxes1: torch.Tensor, boxes2: torch.Tensor, reduction: str = "none", From 8c24ac3824e5a3102106ead715542d580b7067b4 Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Thu, 20 Jan 2022 11:46:39 +0000 Subject: [PATCH 55/60] remove overwrite_eps --- torchvision/models/detection/fcos.py | 1 - torchvision/prototype/models/detection/fcos.py | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index f8f9ad6fec7..ae112802ac0 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -693,5 +693,4 @@ def fcos_resnet50_fpn( if pretrained: state_dict = load_state_dict_from_url(model_urls["fcos_resnet50_fpn_coco"], progress=progress) model.load_state_dict(state_dict) - overwrite_eps(model, 0.0) return model diff --git a/torchvision/prototype/models/detection/fcos.py b/torchvision/prototype/models/detection/fcos.py index ff2d2d4ea1d..a3de0947971 100644 --- a/torchvision/prototype/models/detection/fcos.py +++ b/torchvision/prototype/models/detection/fcos.py @@ -36,7 +36,7 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum): "categories": _COCO_CATEGORIES, "interpolation": InterpolationMode.BILINEAR, "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn", - "map": 39.1, + "map": 39.2, }, ) default = COCO_V1 @@ -76,7 +76,5 @@ def fcos_resnet50_fpn( if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == FCOS_ResNet50_FPN_Weights.COCO_V1: - overwrite_eps(model, 0.0) return model From bd3e26280ba6fa0006edd32e00e1ce234664fc3f Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Thu, 20 Jan 2022 11:48:25 +0000 Subject: [PATCH 56/60] Update AP test values --- docs/source/models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 299a95fb459..7f4925acc63 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -640,7 +640,7 @@ Network box AP mask AP keypoint AP Faster R-CNN ResNet-50 FPN 37.0 - - Faster R-CNN MobileNetV3-Large FPN 32.8 - - Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - - -FCOS ResNet-50 FPN 39.1 - - +FCOS ResNet-50 FPN 39.2 - - RetinaNet ResNet-50 FPN 36.4 - - SSD300 VGG16 25.1 - - SSDlite320 MobileNetV3-Large 21.3 - - From e452cfec2c43dee8cdc424e44fdeab34d681e1de Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Thu, 20 Jan 2022 18:55:48 +0800 Subject: [PATCH 57/60] Minor fixes for the docs --- torchvision/models/detection/fcos.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index ae112802ac0..82723c808ad 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -673,9 +673,9 @@ def fcos_resnet50_fpn( progress (bool): If True, displays a progress bar of the download to stderr num_classes (int): number of output classes of the model (including the background) pretrained_backbone (bool): If True, returns a model with backbone pre-trained on 
Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. - Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is - passed (the default) this value is set to 3. + trainable_backbone_layers (int, optional): number of trainable (not frozen) resnet layers starting + from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are + trainable. If ``None`` is passed (the default) this value is set to 3. Default: None """ trainable_backbone_layers = _validate_trainable_layers( pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 @@ -687,7 +687,7 @@ def fcos_resnet50_fpn( backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) backbone = _resnet_fpn_extractor( - backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) # use P5 + backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) ) model = FCOS(backbone, num_classes, **kwargs) if pretrained: From e0af32913efdd788953a403ad38088f552bc3e6c Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Thu, 20 Jan 2022 19:11:17 +0800 Subject: [PATCH 58/60] Minor fixes for the docs --- torchvision/models/detection/fcos.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 82723c808ad..2e117484aef 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -637,6 +637,7 @@ def fcos_resnet50_fpn( ): """ Constructs a FCOS model with a ResNet-50-FPN backbone. + Reference: `"FCOS: Fully Convolutional One-Stage Object Detection" `_. The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each @@ -646,19 +647,23 @@ def fcos_resnet50_fpn( During training, the model expects both the input tensors, as well as a targets (list of dictionary), containing: + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. - labels (``Int64Tensor[N]``): the class label for each ground-truth box + The model returns a ``Dict[Tensor]`` during training, containing the classification and regression losses. During inference, the model requires only the input tensors, and returns the post-processed predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as follows, where ``N`` is the number of detections: + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. - labels (``Int64Tensor[N]``): the predicted labels for each detection - scores (``Tensor[N]``): the scores of each detection + For more details on the output, you may refer to :ref:`instance_seg_output`. 
Example: From 10be35d583132ca22dd10ca5e26ccd481a3e1ac1 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 20 Jan 2022 20:07:27 +0800 Subject: [PATCH 59/60] Update torchvision/models/detection/fcos.py Co-authored-by: Zhiqiang Wang --- torchvision/models/detection/fcos.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 2e117484aef..71a6306e7e1 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -15,7 +15,6 @@ from ...utils import _log_api_usage_once from ..resnet import resnet50 from . import _utils as det_utils -from ._utils import overwrite_eps from .anchor_utils import AnchorGenerator from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .transform import GeneralizedRCNNTransform From 49226354abe84eecded3a1c4f300263fb9269696 Mon Sep 17 00:00:00 2001 From: Hu Ye Date: Thu, 20 Jan 2022 20:07:43 +0800 Subject: [PATCH 60/60] Update torchvision/prototype/models/detection/fcos.py Co-authored-by: Zhiqiang Wang --- torchvision/prototype/models/detection/fcos.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/prototype/models/detection/fcos.py b/torchvision/prototype/models/detection/fcos.py index a3de0947971..d1f7f9ba361 100644 --- a/torchvision/prototype/models/detection/fcos.py +++ b/torchvision/prototype/models/detection/fcos.py @@ -9,7 +9,6 @@ FCOS, LastLevelP6P7, misc_nn_ops, - overwrite_eps, ) from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES
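With all 60 patches applied, the new builder can be exercised end to end. The sketch below is a usage example rather than part of the patch series: it follows the inference contract spelled out in the `fcos_resnet50_fpn` docstring (a list of `3xHxW` tensors in, one `Dict` with `boxes`/`labels`/`scores` per image out) and assumes the builder is exported from `torchvision.models.detection`, as the docs page added in PATCH 45 indicates. `pretrained=True` would download `fcos_resnet50_fpn_coco-99b0c9b7.pth` (the URL set in PATCH 40); it is left off here so the snippet runs without a network connection.

```python
import torch
from torchvision.models.detection import fcos_resnet50_fpn  # assumed export, per docs/source/models.rst

# Randomly initialised model with the default 91 COCO classes; pass pretrained=True for the COCO weights.
model = fcos_resnet50_fpn(pretrained=False, num_classes=91)
model.eval()

# Inference takes a list of 3xHxW tensors; the images may have different sizes.
images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
with torch.no_grad():
    predictions = model(images)

# One Dict per input image: "boxes" (FloatTensor[N, 4]), "labels" (Int64Tensor[N]), "scores" (Tensor[N]).
print(predictions[0]["boxes"].shape, predictions[0]["labels"].dtype, predictions[0]["scores"].shape)
```

In training mode the same call additionally takes per-image target dicts with `boxes` and `labels` and returns the classification and regression losses, as described in the docstring updated by PATCH 58.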