Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use C3 following the model specification of YOLOv5 #343

Merged
merged 4 commits into from
Mar 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.models.box_head import YOLOHead, PostProcess, SetCriterion
from yolort.models.transformer import darknet_tan_backbone
from yolort.models.yolo_lite import yolov5_mobilenet_v3_small_fpn
from yolort.v5 import get_yolov5_size, attempt_download


Expand Down Expand Up @@ -420,3 +421,18 @@ def test_load_from_yolov5_torchscript(arch, size_divisible, version, upstream_ve
torch.testing.assert_close(out[0]["scores"], out_script[1][0]["scores"], rtol=0, atol=0)
torch.testing.assert_close(out[0]["labels"], out_script[1][0]["labels"], rtol=0, atol=0)
torch.testing.assert_close(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0, atol=0)


def test_yolov5_mobilenet_v3_small_fpn():

model = yolov5_mobilenet_v3_small_fpn()
model = model.eval()

images = torch.rand(4, 3, 320, 320)
out = model(images)
assert isinstance(out, list)
assert len(out) == 4
assert isinstance(out[0], dict)
assert isinstance(out[0]["boxes"], Tensor)
assert isinstance(out[0]["labels"], Tensor)
assert isinstance(out[0]["scores"], Tensor)
2 changes: 1 addition & 1 deletion yolort/models/box_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(
):
super().__init__()
if not isinstance(in_channels, list):
in_channels = [in_channels] * num_anchors
in_channels = [in_channels] * len(strides)
self.num_anchors = num_anchors # anchors
self.num_classes = num_classes
self.num_outputs = num_classes + 5 # number of outputs per anchor
Expand Down
78 changes: 56 additions & 22 deletions yolort/models/yolo_lite.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from typing import Dict, List, Callable, Optional

from torch import nn
from torchvision.models import mobilenet
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection.backbone_utils import _validate_trainable_layers
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
from yolort.utils import load_state_dict_from_url

from .anchor_utils import AnchorGenerator
from .box_head import YOLOHead
from .yolo import YOLO

__all__ = ["yolov5_mobilenet_v3_small_fpn"]
Expand All @@ -31,7 +36,14 @@ class BackboneWithFPN(nn.Module):
out_channels (int): the number of channels in the FPN
"""

def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
def __init__(
self,
backbone: nn.Module,
return_layers: Dict[str, str],
in_channels_list: List[int],
out_channels: int,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> None:
super().__init__()

if extra_blocks is None:
Expand All @@ -53,12 +65,12 @@ def forward(self, x):


def mobilenet_backbone(
backbone_name,
pretrained,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=2,
returned_layers=None,
):
backbone_name: str,
pretrained: bool,
norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
trainable_layers: int = 2,
returned_layers: Optional[List[int]] = None,
) -> nn.Module:
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features

# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
Expand All @@ -79,7 +91,7 @@ def mobilenet_backbone(
out_channels = 256

if returned_layers is None:
returned_layers = [num_stages - 2, num_stages - 1]
returned_layers = [num_stages - 3, num_stages - 2, num_stages - 1]
assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}

Expand All @@ -94,12 +106,12 @@ def mobilenet_backbone(


def _yolov5_mobilenet_v3_small_fpn(
weights_name,
pretrained=False,
progress=True,
num_classes=80,
pretrained_backbone=True,
trainable_backbone_layers=None,
weights_name: str,
pretrained: bool = False,
progress: bool = True,
num_classes: int = 80,
pretrained_backbone: bool = True,
trainable_backbone_layers: Optional[int] = None,
**kwargs,
):
trainable_backbone_layers = _validate_trainable_layers(
Expand All @@ -113,18 +125,40 @@ def _yolov5_mobilenet_v3_small_fpn(
pretrained_backbone,
trainable_layers=trainable_backbone_layers,
)
strides = [8, 16, 32, 64]
anchor_grids = [
[19, 27, 44, 40, 38, 94],
[96, 68, 86, 152, 180, 137],
[140, 301, 303, 264, 238, 542],
[436, 615, 739, 380, 925, 792],
]
anchor_generator = AnchorGenerator(strides, anchor_grids)
head = YOLOHead(
backbone.out_channels,
anchor_generator.num_anchors,
anchor_generator.strides,
num_classes,
)
model = YOLO(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)
if pretrained:
if model_urls.get(weights_name, None) is None:
raise ValueError(f"No checkpoint is available for model {weights_name}")
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
model.load_state_dict(state_dict)
return model

model = YOLO(backbone, num_classes, **kwargs)

return model
model_urls = {
"yolov5_mobilenet_v3_small_fpn_coco": None,
}


def yolov5_mobilenet_v3_small_fpn(
pretrained=False,
progress=True,
num_classes=80,
pretrained_backbone=True,
trainable_backbone_layers=None,
pretrained: bool = False,
progress: bool = True,
num_classes: int = 80,
pretrained_backbone: bool = True,
trainable_backbone_layers: Optional[int] = None,
**kwargs,
):
"""
Expand Down