Skip to content

Commit

Permalink
Fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Mar 6, 2022
1 parent cb1e293 commit 6ee9f36
Showing 1 changed file with 39 additions and 21 deletions.
60 changes: 39 additions & 21 deletions yolort/models/yolo_lite.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Dict, List, Callable, Optional

from torch import nn
from torchvision.models import mobilenet
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection.backbone_utils import _validate_trainable_layers
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
from yolort.utils import load_state_dict_from_url

from .anchor_utils import AnchorGenerator
from .box_head import YOLOHead
Expand Down Expand Up @@ -33,7 +36,14 @@ class BackboneWithFPN(nn.Module):
out_channels (int): the number of channels in the FPN
"""

def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
def __init__(
self,
backbone: nn.Module,
return_layers: Dict[str, str],
in_channels_list: List[int],
out_channels: int,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> None:
super().__init__()

if extra_blocks is None:
Expand All @@ -55,12 +65,12 @@ def forward(self, x):


def mobilenet_backbone(
backbone_name,
pretrained,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=2,
returned_layers=None,
):
backbone_name: str,
pretrained: bool,
norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
trainable_layers: int = 2,
returned_layers: Optional[List[int]] = None,
) -> nn.Module:
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features

# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
Expand Down Expand Up @@ -96,12 +106,12 @@ def mobilenet_backbone(


def _yolov5_mobilenet_v3_small_fpn(
weights_name,
pretrained=False,
progress=True,
num_classes=80,
pretrained_backbone=True,
trainable_backbone_layers=None,
weights_name: str,
pretrained: bool = False,
progress: bool = True,
num_classes: int = 80,
pretrained_backbone: bool = True,
trainable_backbone_layers: Optional[int] = None,
**kwargs,
):
trainable_backbone_layers = _validate_trainable_layers(
Expand All @@ -123,24 +133,32 @@ def _yolov5_mobilenet_v3_small_fpn(
[436, 615, 739, 380, 925, 792],
]
anchor_generator = AnchorGenerator(strides, anchor_grids)

head = YOLOHead(
backbone.out_channels,
anchor_generator.num_anchors,
anchor_generator.strides,
num_classes,
)
model = YOLO(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)

if pretrained:
if model_urls.get(weights_name, None) is None:
raise ValueError(f"No checkpoint is available for model {weights_name}")
state_dict = load_state_dict_from_url(model_urls["retinanet_resnet50_fpn_coco"], progress=progress)
model.load_state_dict(state_dict)
return model


model_urls = {
"yolov5_mobilenet_v3_small_fpn_coco": None,
}


def yolov5_mobilenet_v3_small_fpn(
pretrained=False,
progress=True,
num_classes=80,
pretrained_backbone=True,
trainable_backbone_layers=None,
pretrained: bool = False,
progress: bool = True,
num_classes: int = 80,
pretrained_backbone: bool = True,
trainable_backbone_layers: Optional[int] = None,
**kwargs,
):
"""
Expand Down

0 comments on commit 6ee9f36

Please sign in to comment.