Skip to content

Commit

Permalink
Add MobileNetV3 architecture for Detection (#3253)
Browse files Browse the repository at this point in the history
* Minor refactoring of a private method to make it reusuable.

* Adding a FasterRCNN + MobileNetV3 with & w/o FPN models.

* Reducing Resolution to 320-640 and anchor sizes to 16-256.

* Increase anchor sizes.

* Adding rpn score threshold param on the train script.

* Adding trainable_backbone_layers param on the train script.

* Adding rpn_score_thresh param directly in fasterrcnn_mobilenet_v3_large_fpn.

* Remove fasterrcnn_mobilenet_v3_large prototype and update expected file.

* Update documentation and adding weights.

* Use buildin Identity.

* Fix spelling.
  • Loading branch information
datumbox authored Jan 18, 2021
1 parent 0985533 commit bf211da
Show file tree
Hide file tree
Showing 13 changed files with 168 additions and 70 deletions.
33 changes: 18 additions & 15 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -358,13 +358,14 @@ models return the predictions of the following classes:
Here are the summary of the accuracies for the models trained on
the instances set of COCO train2017 and evaluated on COCO val2017.

================================ ======= ======== ===========
Network box AP mask AP keypoint AP
================================ ======= ======== ===========
Faster R-CNN ResNet-50 FPN 37.0 - -
RetinaNet ResNet-50 FPN 36.4 - -
Mask R-CNN ResNet-50 FPN 37.9 34.6 -
================================ ======= ======== ===========
================================== ======= ======== ===========
Network box AP mask AP keypoint AP
================================== ======= ======== ===========
Faster R-CNN ResNet-50 FPN 37.0 - -
Faster R-CNN MobileNetV3-Large FPN 23.0 - -
RetinaNet ResNet-50 FPN 36.4 - -
Mask R-CNN ResNet-50 FPN 37.9 34.6 -
================================== ======= ======== ===========

For person keypoint detection, the accuracies for the pre-trained
models are as follows
Expand Down Expand Up @@ -414,20 +415,22 @@ For test time, we report the time for the model evaluation and postprocessing
(including mask pasting in image), but not the time for computing the
precision-recall.

============================== =================== ================== ===========
Network train time (s / it) test time (s / it) memory (GB)
============================== =================== ================== ===========
Faster R-CNN ResNet-50 FPN 0.2288 0.0590 5.2
RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1
Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4
Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8
============================== =================== ================== ===========
================================== =================== ================== ===========
Network train time (s / it) test time (s / it) memory (GB)
================================== =================== ================== ===========
Faster R-CNN ResNet-50 FPN 0.2288 0.0590 5.2
Faster R-CNN MobileNetV3-Large FPN 0.0978 0.0376 0.6
RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1
Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4
Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8
================================== =================== ================== ===========


Faster R-CNN
------------

.. autofunction:: torchvision.models.detection.fasterrcnn_resnet50_fpn
.. autofunction:: torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn


RetinaNet
Expand Down
9 changes: 8 additions & 1 deletion references/detection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,20 @@ You must modify the following flags:

Except otherwise noted, all models have been trained on 8x V100 GPUs.

### Faster R-CNN
### Faster R-CNN ResNet-50 FPN
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3
```

### Faster R-CNN MobileNetV3-Large FPN
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3
```

### RetinaNet
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
Expand Down
10 changes: 8 additions & 2 deletions references/detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,12 @@ def main(args):
collate_fn=utils.collate_fn)

print("Creating model")
kwargs = {}
kwargs = {
"trainable_backbone_layers": args.trainable_backbone_layers
}
if "rcnn" in args.model:
kwargs["rpn_score_thresh"] = 0.0
if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained,
**kwargs)
model.to(device)
Expand Down Expand Up @@ -177,6 +180,9 @@ def main(args):
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
parser.add_argument('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn')
parser.add_argument('--trainable-backbone-layers', default=None, type=int,
help='number of trainable layers of backbone')
parser.add_argument(
"--test-only",
dest="test_only",
Expand Down
Binary file not shown.
3 changes: 3 additions & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def get_available_video_models():
'googlenet': lambda x: x.logits,
'inception_v3': lambda x: x.logits,
"fasterrcnn_resnet50_fpn": lambda x: x[1],
"fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
"maskrcnn_resnet50_fpn": lambda x: x[1],
"keypointrcnn_resnet50_fpn": lambda x: x[1],
"retinanet_resnet50_fpn": lambda x: x[1],
Expand Down Expand Up @@ -105,6 +106,8 @@ def _test_detection_model(self, name, dev):
if "retinanet" in name:
# Reduce the default threshold to ensure the returned boxes are not empty.
kwargs["score_thresh"] = 0.01
elif "fasterrcnn_mobilenet" in name:
kwargs["box_score_thresh"] = 0.02076
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
model.eval().to(device=dev)
input_shape = (3, 300, 300)
Expand Down
15 changes: 8 additions & 7 deletions test/test_models_detection_negative_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,15 @@ def test_assign_targets_to_proposals(self):
self.assertEqual(labels[0].dtype, torch.int64)

def test_forward_negative_sample_frcnn(self):
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
num_classes=2, min_size=100, max_size=100)
for name in ["fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn"]:
model = torchvision.models.detection.__dict__[name](
num_classes=2, min_size=100, max_size=100)

images, targets = self._make_empty_sample()
loss_dict = model(images, targets)
images, targets = self._make_empty_sample()
loss_dict = model(images, targets)

self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))

def test_forward_negative_sample_mrcnn(self):
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
Expand All @@ -130,7 +131,7 @@ def test_forward_negative_sample_krcnn(self):

def test_forward_negative_sample_retinanet(self):
model = torchvision.models.detection.retinanet_resnet50_fpn(
num_classes=2, min_size=100, max_size=100)
num_classes=2, min_size=100, max_size=100, pretrained_backbone=False)

images, targets = self._make_empty_sample()
loss_dict = model(images, targets)
Expand Down
12 changes: 6 additions & 6 deletions test/test_models_detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,17 @@ def test_resnet_fpn_backbone_frozen_layers(self):

def test_validate_resnet_inputs_detection(self):
# default number of backbone layers to train
ret = backbone_utils._validate_resnet_trainable_layers(
pretrained=True, trainable_backbone_layers=None)
ret = backbone_utils._validate_trainable_layers(
pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3)
self.assertEqual(ret, 3)
# can't go beyond 5
with self.assertRaises(AssertionError):
ret = backbone_utils._validate_resnet_trainable_layers(
pretrained=True, trainable_backbone_layers=6)
ret = backbone_utils._validate_trainable_layers(
pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3)
# if not pretrained, should use all trainable layers and warn
with self.assertWarns(UserWarning):
ret = backbone_utils._validate_resnet_trainable_layers(
pretrained=False, trainable_backbone_layers=0)
ret = backbone_utils._validate_trainable_layers(
pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3)
self.assertEqual(ret, 5)

def test_transform_copy_targets(self):
Expand Down
62 changes: 55 additions & 7 deletions torchvision/models/detection/backbone_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import warnings
from collections import OrderedDict
from torch import nn
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool

from torchvision.ops import misc as misc_nn_ops
from .._utils import IntermediateLayerGetter
from .. import mobilenet
from .. import resnet


Expand Down Expand Up @@ -108,17 +108,65 @@ def resnet_fpn_backbone(
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)


def _validate_resnet_trainable_layers(pretrained, trainable_backbone_layers):
def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value):
# dont freeze any layers if pretrained model or backbone is not used
if not pretrained:
if trainable_backbone_layers is not None:
warnings.warn(
"Changing trainable_backbone_layers has not effect if "
"neither pretrained nor pretrained_backbone have been set to True, "
"falling back to trainable_backbone_layers=5 so that all layers are trainable")
trainable_backbone_layers = 5
# by default, freeze first 2 blocks following Faster R-CNN
"falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value))
trainable_backbone_layers = max_value

# by default freeze first blocks
if trainable_backbone_layers is None:
trainable_backbone_layers = 3
assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0
trainable_backbone_layers = default_value
assert 0 <= trainable_backbone_layers <= max_value
return trainable_backbone_layers


def mobilenet_backbone(
backbone_name,
pretrained,
fpn,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=2,
returned_layers=None,
extra_blocks=None
):
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features

# Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
num_stages = len(stage_indices)

# find the index of the layer from which we wont freeze
assert 0 <= trainable_layers <= num_stages
freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

# freeze layers only if pretrained backbone is used
for b in backbone[:freeze_before]:
for parameter in b.parameters():
parameter.requires_grad_(False)

out_channels = 256
if fpn:
if extra_blocks is None:
extra_blocks = LastLevelMaxPool()

if returned_layers is None:
returned_layers = [num_stages - 2, num_stages - 1]
assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
return_layers = {f'{stage_indices[k]}': str(v) for v, k in enumerate(returned_layers)}

in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
else:
m = nn.Sequential(
backbone,
# depthwise linear combination of channels to reduce their size
nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
)
m.out_channels = out_channels
return m
59 changes: 51 additions & 8 deletions torchvision/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from collections import OrderedDict

import torch
from torch import nn
import torch.nn.functional as F

from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign

from ._utils import overwrite_eps
Expand All @@ -15,11 +12,11 @@
from .rpn import RPNHead, RegionProposalNetwork
from .roi_heads import RoIHeads
from .transform import GeneralizedRCNNTransform
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone


__all__ = [
"FasterRCNN", "fasterrcnn_resnet50_fpn",
"FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn"
]


Expand Down Expand Up @@ -291,6 +288,8 @@ def forward(self, x):
model_urls = {
'fasterrcnn_resnet50_fpn_coco':
'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth',
'fasterrcnn_mobilenet_v3_large_fpn_coco':
'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-907ea3f9.pth',
}


Expand Down Expand Up @@ -353,9 +352,8 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
# check default parameters and by default set it to 3 if possible
trainable_backbone_layers = _validate_resnet_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers)
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)

if pretrained:
# no need to download the backbone if pretrained is set
Expand All @@ -368,3 +366,48 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model


def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
trainable_backbone_layers=None, min_size=320, max_size=640, rpn_score_thresh=0.05,
**kwargs):
"""
Constructs a Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly
to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
Example::
>>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
>>> model.eval()
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
rpn_score_thresh (float): during inference, only return proposals with a classification score
greater than rpn_score_thresh
"""
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)

if pretrained:
pretrained_backbone = False
backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True,
trainable_layers=trainable_backbone_layers)

anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)

model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
min_size=min_size, max_size=max_size, rpn_score_thresh=rpn_score_thresh, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['fasterrcnn_mobilenet_v3_large_fpn_coco'], progress=progress)
model.load_state_dict(state_dict)
return model
7 changes: 3 additions & 4 deletions torchvision/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ..utils import load_state_dict_from_url

from .faster_rcnn import FasterRCNN
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers


__all__ = [
Expand Down Expand Up @@ -322,9 +322,8 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
# check default parameters and by default set it to 3 if possible
trainable_backbone_layers = _validate_resnet_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers)
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)

if pretrained:
# no need to download the backbone if pretrained is set
Expand Down
7 changes: 3 additions & 4 deletions torchvision/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..utils import load_state_dict_from_url

from .faster_rcnn import FasterRCNN
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers

__all__ = [
"MaskRCNN", "maskrcnn_resnet50_fpn",
Expand Down Expand Up @@ -317,9 +317,8 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
# check default parameters and by default set it to 3 if possible
trainable_backbone_layers = _validate_resnet_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers)
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)

if pretrained:
# no need to download the backbone if pretrained is set
Expand Down
Loading

0 comments on commit bf211da

Please sign in to comment.