diff --git a/test/test_models.py b/test/test_models.py index 1b55ecb2..0ef183d5 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -12,6 +12,7 @@ from yolort.models.backbone_utils import darknet_pan_backbone from yolort.models.box_head import YOLOHead, PostProcess, SetCriterion from yolort.models.transformer import darknet_tan_backbone +from yolort.models.yolo_lite import yolov5_mobilenet_v3_small_fpn from yolort.v5 import get_yolov5_size, attempt_download @@ -420,3 +421,18 @@ def test_load_from_yolov5_torchscript(arch, size_divisible, version, upstream_ve torch.testing.assert_close(out[0]["scores"], out_script[1][0]["scores"], rtol=0, atol=0) torch.testing.assert_close(out[0]["labels"], out_script[1][0]["labels"], rtol=0, atol=0) torch.testing.assert_close(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0, atol=0) + + +def test_yolov5_mobilenet_v3_small_fpn(): + + model = yolov5_mobilenet_v3_small_fpn() + model = model.eval() + + images = torch.rand(4, 3, 320, 320) + out = model(images) + assert isinstance(out, list) + assert len(out) == 4 + assert isinstance(out[0], dict) + assert isinstance(out[0]["boxes"], Tensor) + assert isinstance(out[0]["labels"], Tensor) + assert isinstance(out[0]["scores"], Tensor) diff --git a/yolort/models/box_head.py b/yolort/models/box_head.py index 4eb6f73b..4210f8ff 100644 --- a/yolort/models/box_head.py +++ b/yolort/models/box_head.py @@ -20,7 +20,7 @@ def __init__( ): super().__init__() if not isinstance(in_channels, list): - in_channels = [in_channels] * num_anchors + in_channels = [in_channels] * len(strides) self.num_anchors = num_anchors # anchors self.num_classes = num_classes self.num_outputs = num_classes + 5 # number of outputs per anchor diff --git a/yolort/models/yolo_lite.py b/yolort/models/yolo_lite.py index 0d3b15bf..627f0863 100644 --- a/yolort/models/yolo_lite.py +++ b/yolort/models/yolo_lite.py @@ -1,10 +1,15 @@ +from typing import Dict, List, Callable, Optional + from torch import nn from torchvision.models import mobilenet from torchvision.models._utils import IntermediateLayerGetter from torchvision.models.detection.backbone_utils import _validate_trainable_layers from torchvision.ops import misc as misc_nn_ops -from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool +from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool +from yolort.utils import load_state_dict_from_url +from .anchor_utils import AnchorGenerator +from .box_head import YOLOHead from .yolo import YOLO __all__ = ["yolov5_mobilenet_v3_small_fpn"] @@ -31,7 +36,14 @@ class BackboneWithFPN(nn.Module): out_channels (int): the number of channels in the FPN """ - def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None): + def __init__( + self, + backbone: nn.Module, + return_layers: Dict[str, str], + in_channels_list: List[int], + out_channels: int, + extra_blocks: Optional[ExtraFPNBlock] = None, + ) -> None: super().__init__() if extra_blocks is None: @@ -53,12 +65,12 @@ def forward(self, x): def mobilenet_backbone( - backbone_name, - pretrained, - norm_layer=misc_nn_ops.FrozenBatchNorm2d, - trainable_layers=2, - returned_layers=None, -): + backbone_name: str, + pretrained: bool, + norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d, + trainable_layers: int = 2, + returned_layers: Optional[List[int]] = None, +) -> nn.Module: backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. @@ -79,7 +91,7 @@ def mobilenet_backbone( out_channels = 256 if returned_layers is None: - returned_layers = [num_stages - 2, num_stages - 1] + returned_layers = [num_stages - 3, num_stages - 2, num_stages - 1] assert min(returned_layers) >= 0 and max(returned_layers) < num_stages return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)} @@ -94,12 +106,12 @@ def mobilenet_backbone( def _yolov5_mobilenet_v3_small_fpn( - weights_name, - pretrained=False, - progress=True, - num_classes=80, - pretrained_backbone=True, - trainable_backbone_layers=None, + weights_name: str, + pretrained: bool = False, + progress: bool = True, + num_classes: int = 80, + pretrained_backbone: bool = True, + trainable_backbone_layers: Optional[int] = None, **kwargs, ): trainable_backbone_layers = _validate_trainable_layers( @@ -113,18 +125,40 @@ def _yolov5_mobilenet_v3_small_fpn( pretrained_backbone, trainable_layers=trainable_backbone_layers, ) + strides = [8, 16, 32, 64] + anchor_grids = [ + [19, 27, 44, 40, 38, 94], + [96, 68, 86, 152, 180, 137], + [140, 301, 303, 264, 238, 542], + [436, 615, 739, 380, 925, 792], + ] + anchor_generator = AnchorGenerator(strides, anchor_grids) + head = YOLOHead( + backbone.out_channels, + anchor_generator.num_anchors, + anchor_generator.strides, + num_classes, + ) + model = YOLO(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs) + if pretrained: + if model_urls.get(weights_name, None) is None: + raise ValueError(f"No checkpoint is available for model {weights_name}") + state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) + model.load_state_dict(state_dict) + return model - model = YOLO(backbone, num_classes, **kwargs) - return model +model_urls = { + "yolov5_mobilenet_v3_small_fpn_coco": None, +} def yolov5_mobilenet_v3_small_fpn( - pretrained=False, - progress=True, - num_classes=80, - pretrained_backbone=True, - trainable_backbone_layers=None, + pretrained: bool = False, + progress: bool = True, + num_classes: int = 80, + pretrained_backbone: bool = True, + trainable_backbone_layers: Optional[int] = None, **kwargs, ): """