Skip to content

Commit

Permalink
Use C3 following the model specification of YOLOv5 (#343)
Browse files Browse the repository at this point in the history
* Use C3 following the model specification of YOLOv5

* Add unittest for yolov5_mobilenet_v3_small_fpn

* Add type annotations

* Minor fix for weights_name
  • Loading branch information
zhiqwang authored Mar 6, 2022
1 parent b96c225 commit f229a08
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 23 deletions.
16 changes: 16 additions & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.models.box_head import YOLOHead, PostProcess, SetCriterion
from yolort.models.transformer import darknet_tan_backbone
from yolort.models.yolo_lite import yolov5_mobilenet_v3_small_fpn
from yolort.v5 import get_yolov5_size, attempt_download


Expand Down Expand Up @@ -420,3 +421,18 @@ def test_load_from_yolov5_torchscript(arch, size_divisible, version, upstream_ve
torch.testing.assert_close(out[0]["scores"], out_script[1][0]["scores"], rtol=0, atol=0)
torch.testing.assert_close(out[0]["labels"], out_script[1][0]["labels"], rtol=0, atol=0)
torch.testing.assert_close(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0, atol=0)


def test_yolov5_mobilenet_v3_small_fpn():

model = yolov5_mobilenet_v3_small_fpn()
model = model.eval()

images = torch.rand(4, 3, 320, 320)
out = model(images)
assert isinstance(out, list)
assert len(out) == 4
assert isinstance(out[0], dict)
assert isinstance(out[0]["boxes"], Tensor)
assert isinstance(out[0]["labels"], Tensor)
assert isinstance(out[0]["scores"], Tensor)
2 changes: 1 addition & 1 deletion yolort/models/box_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(
):
super().__init__()
if not isinstance(in_channels, list):
in_channels = [in_channels] * num_anchors
in_channels = [in_channels] * len(strides)
self.num_anchors = num_anchors # anchors
self.num_classes = num_classes
self.num_outputs = num_classes + 5 # number of outputs per anchor
Expand Down
78 changes: 56 additions & 22 deletions yolort/models/yolo_lite.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from typing import Dict, List, Callable, Optional

from torch import nn
from torchvision.models import mobilenet
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection.backbone_utils import _validate_trainable_layers
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
from yolort.utils import load_state_dict_from_url

from .anchor_utils import AnchorGenerator
from .box_head import YOLOHead
from .yolo import YOLO

__all__ = ["yolov5_mobilenet_v3_small_fpn"]
Expand All @@ -31,7 +36,14 @@ class BackboneWithFPN(nn.Module):
out_channels (int): the number of channels in the FPN
"""

def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
def __init__(
self,
backbone: nn.Module,
return_layers: Dict[str, str],
in_channels_list: List[int],
out_channels: int,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> None:
super().__init__()

if extra_blocks is None:
Expand All @@ -53,12 +65,12 @@ def forward(self, x):


def mobilenet_backbone(
backbone_name,
pretrained,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=2,
returned_layers=None,
):
backbone_name: str,
pretrained: bool,
norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
trainable_layers: int = 2,
returned_layers: Optional[List[int]] = None,
) -> nn.Module:
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features

# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
Expand All @@ -79,7 +91,7 @@ def mobilenet_backbone(
out_channels = 256

if returned_layers is None:
returned_layers = [num_stages - 2, num_stages - 1]
returned_layers = [num_stages - 3, num_stages - 2, num_stages - 1]
assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}

Expand All @@ -94,12 +106,12 @@ def mobilenet_backbone(


def _yolov5_mobilenet_v3_small_fpn(
weights_name,
pretrained=False,
progress=True,
num_classes=80,
pretrained_backbone=True,
trainable_backbone_layers=None,
weights_name: str,
pretrained: bool = False,
progress: bool = True,
num_classes: int = 80,
pretrained_backbone: bool = True,
trainable_backbone_layers: Optional[int] = None,
**kwargs,
):
trainable_backbone_layers = _validate_trainable_layers(
Expand All @@ -113,18 +125,40 @@ def _yolov5_mobilenet_v3_small_fpn(
pretrained_backbone,
trainable_layers=trainable_backbone_layers,
)
strides = [8, 16, 32, 64]
anchor_grids = [
[19, 27, 44, 40, 38, 94],
[96, 68, 86, 152, 180, 137],
[140, 301, 303, 264, 238, 542],
[436, 615, 739, 380, 925, 792],
]
anchor_generator = AnchorGenerator(strides, anchor_grids)
head = YOLOHead(
backbone.out_channels,
anchor_generator.num_anchors,
anchor_generator.strides,
num_classes,
)
model = YOLO(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)
if pretrained:
if model_urls.get(weights_name, None) is None:
raise ValueError(f"No checkpoint is available for model {weights_name}")
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
model.load_state_dict(state_dict)
return model

model = YOLO(backbone, num_classes, **kwargs)

return model
model_urls = {
"yolov5_mobilenet_v3_small_fpn_coco": None,
}


def yolov5_mobilenet_v3_small_fpn(
pretrained=False,
progress=True,
num_classes=80,
pretrained_backbone=True,
trainable_backbone_layers=None,
pretrained: bool = False,
progress: bool = True,
num_classes: int = 80,
pretrained_backbone: bool = True,
trainable_backbone_layers: Optional[int] = None,
**kwargs,
):
"""
Expand Down

0 comments on commit f229a08

Please sign in to comment.