From 6ee9f36389637e6e91bd162b6d1f69f98198c7ff Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Sun, 6 Mar 2022 12:11:39 +0800 Subject: [PATCH] Fix format --- yolort/models/yolo_lite.py | 60 +++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 21 deletions(-) diff --git a/yolort/models/yolo_lite.py b/yolort/models/yolo_lite.py index 430da93a..4ad5bb06 100644 --- a/yolort/models/yolo_lite.py +++ b/yolort/models/yolo_lite.py @@ -1,9 +1,12 @@ +from typing import Dict, List, Callable, Optional + from torch import nn from torchvision.models import mobilenet from torchvision.models._utils import IntermediateLayerGetter from torchvision.models.detection.backbone_utils import _validate_trainable_layers from torchvision.ops import misc as misc_nn_ops -from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool +from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool +from yolort.utils import load_state_dict_from_url from .anchor_utils import AnchorGenerator from .box_head import YOLOHead @@ -33,7 +36,14 @@ class BackboneWithFPN(nn.Module): out_channels (int): the number of channels in the FPN """ - def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None): + def __init__( + self, + backbone: nn.Module, + return_layers: Dict[str, str], + in_channels_list: List[int], + out_channels: int, + extra_blocks: Optional[ExtraFPNBlock] = None, + ) -> None: super().__init__() if extra_blocks is None: @@ -55,12 +65,12 @@ def forward(self, x): def mobilenet_backbone( - backbone_name, - pretrained, - norm_layer=misc_nn_ops.FrozenBatchNorm2d, - trainable_layers=2, - returned_layers=None, -): + backbone_name: str, + pretrained: bool, + norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d, + trainable_layers: int = 2, + returned_layers: Optional[List[int]] = None, +) -> nn.Module: backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. @@ -96,12 +106,12 @@ def mobilenet_backbone( def _yolov5_mobilenet_v3_small_fpn( - weights_name, - pretrained=False, - progress=True, - num_classes=80, - pretrained_backbone=True, - trainable_backbone_layers=None, + weights_name: str, + pretrained: bool = False, + progress: bool = True, + num_classes: int = 80, + pretrained_backbone: bool = True, + trainable_backbone_layers: Optional[int] = None, **kwargs, ): trainable_backbone_layers = _validate_trainable_layers( @@ -123,7 +133,6 @@ def _yolov5_mobilenet_v3_small_fpn( [436, 615, 739, 380, 925, 792], ] anchor_generator = AnchorGenerator(strides, anchor_grids) - head = YOLOHead( backbone.out_channels, anchor_generator.num_anchors, @@ -131,16 +140,25 @@ def _yolov5_mobilenet_v3_small_fpn( num_classes, ) model = YOLO(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs) - + if pretrained: + if model_urls.get(weights_name, None) is None: + raise ValueError(f"No checkpoint is available for model {weights_name}") + state_dict = load_state_dict_from_url(model_urls["retinanet_resnet50_fpn_coco"], progress=progress) + model.load_state_dict(state_dict) return model +model_urls = { + "yolov5_mobilenet_v3_small_fpn_coco": None, +} + + def yolov5_mobilenet_v3_small_fpn( - pretrained=False, - progress=True, - num_classes=80, - pretrained_backbone=True, - trainable_backbone_layers=None, + pretrained: bool = False, + progress: bool = True, + num_classes: int = 80, + pretrained_backbone: bool = True, + trainable_backbone_layers: Optional[int] = None, **kwargs, ): """