This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

First attempt at adding a mobilenet backbone #242

Closed
wants to merge 1 commit

Conversation

@t-vi commented Dec 3, 2018

As discussed, this is a very rough attempt at putting in a mobilenet backbone.

  • It uses @tonylins' Pytorch-Mobilenet (so the licensing / CLA needs to be checked with him).
  • Errors in the modifications to it are, of course, all my own.
  • I just grab the outputs of the last 5 blocks of model.features (which I made into a ModuleList instead of a Sequential; you could probably leave it as a Sequential and it would work too, but since I don't intend to call it as a whole, this seems cleaner).
  • I attempted to address one of the TODOs in the FPN, namely not making a size-doubling assumption for the layers.
  • I'm using the pretrained weights @tonylins links to, but I'm not sure how to link them directly into the file, so I download them to the working directory.
  • Similar to the feature extraction, the layer freezing is pretty arbitrary: I freeze 14 of the 19 blocks in the model.features list (a rough sketch of this follows below).
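
Something like this, as a hypothetical sketch of the idea (the helper name is made up, and the block counts assume the stock MobileNetV2 with 19 feature blocks; this is not the exact diff):

    from torch import nn

    # hypothetical helper: keep the feature blocks as a ModuleList, mark the
    # outputs of the last 5 blocks as feature planes, freeze the first 14 of 19
    def select_and_freeze(mobilenet):
        features = nn.ModuleList(mobilenet.features)
        return_idx = set(range(len(features) - 5, len(features)))
        for block in list(features)[:14]:
            for p in block.parameters():
                p.requires_grad = False
        return features, return_idx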

I must say that I absolutely loved how the training works out of the box with stock MSCOCO and how it is set up to be extensible!

@facebook-github-bot added the CLA Signed label Dec 3, 2018
@tonylins commented Dec 3, 2018

Just received an email about this thread. I put an Apache License 2.0 on my repo. Feel free to use it, and contact me if you run into any problems :).

@YellowKyu

Hello there,

Thank you for your work, guys! Since the review is taking some time, I tried your implementation. I was able to run the code, but my 'loss_classifier' is always stuck at 0.5–0.6 and never decreases. Does it behave like that for you too? Any idea on a possible reason?

Maybe if we debug it we can make things progress!

@t-vi (Author) commented Jan 22, 2019

I think the net architecture isn't good yet (e.g. which backbone layers are used as feature planes). Unless someone provides a better starting point, one could take a look at published MobileNet R-CNNs, but I haven't gotten around to doing so.

@fmassa (Contributor) commented Jan 22, 2019

@wat3rBro do you have any tips or advice for a better mobile architecture?

@kimnik6 commented Jan 29, 2019

I was able to get the net to train properly by using the last layers with 24, 32, 64 and 1280 channels as feature planes.
In the e2e_faster_rcnn_mobilenet.yaml file, I changed the "FREEZE_CONV_BODY_AT" value to 3, and changed the "OUT_CHANNELS" under "MOBILENET" to (24, 32, 64, 1280).
In the mobilenet.py file, I changed the forward method of the MobileNetV2 class to

    def forward(self, x):
        res = []
        for idx, m in enumerate(self.features):
            x = m(x)
            if idx in [3, 6, 10, 18]:
                res.append(x)
        return res
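
For reference, the corresponding entries in the config would look roughly like this (only the changed keys are shown, sketched from the description above; not a verbatim file):

MODEL:
  BACKBONE:
    FREEZE_CONV_BODY_AT: 3
  MOBILENET:
    OUT_CHANNELS: (24, 32, 64, 1280)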

@fmassa (Contributor) commented Jan 29, 2019

@kimnik6 awesome! What accuracy did you obtain with that configuration?

@kimnik6 commented Jan 29, 2019

It's still training, right now it's at a box AP of 23.8. However, on a GTX1070 it still takes 0.095s/image for inference (compared to 0.132s/image for the Res50_FPN).

@fmassa (Contributor) commented Jan 29, 2019

Maybe the benefits will be more noticeable when running it on the CPU? But anyway, very interesting findings!
Also, maybe removing the FPN might make things faster.

@kimnik6 commented Feb 1, 2019

Finished training, box AP is at 26.5. I'll try it without the FPN to see if it gets any faster. The RPN seems to be dominating the runtime of the whole net right now.

@t-vi (Author) commented Feb 1, 2019

I think dropping the FPN and amending the Backbone for the feature layers would be good next steps.

@kimnik6 commented Feb 6, 2019

@t-vi: do you already have an idea how you would do it?

For a first attempt based on the ResNetC4 structure, I took the last of the MobileNet layers with 96 channels as my output. I haven't completely understood the differences/reasons for choosing the ROI_BOX_HEAD, so I just continued using the FPN2MLPFeatureExtractor. My config file looks like this:

MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "./mobilenet_v2.pth.tar"
  BACKBONE:
    CONV_BODY: "MobileNetV2-FPN"
    OUT_CHANNELS: 96
    FREEZE_CONV_BODY_AT: 3
  RPN:
    PRE_NMS_TOP_N_TEST: 6000
    POST_NMS_TOP_N_TEST: 1000
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 14
    POOLER_SCALES: (0.0625, )
    POOLER_SAMPLING_RATIO: 0
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
SOLVER:
  BASE_LR: 0.02
  CHECKPOINT_PERIOD: 10000
  WEIGHT_DECAY: 0.0001
  STEPS: (180000, 270000)
  MAX_ITER: 360000
  IMS_PER_BATCH: 12

It's still training, currently at a box mAP of 19.9 at 0.049s/image. The classifier and box_reg losses look good; however, the objectness and rpn_box_reg losses are rather high. If anyone has ideas for a better structure, let me know.

Also, I retrained the MobileNet with FPN with a larger batch size and a slightly different schedule, and got it up to a box mAP of 29.8 (still 0.095s/image).

@fmassa (Contributor) commented Feb 6, 2019

@kimnik6 I believe that you might be using only the first layer of the MobileNetV2-FPN model to do the training, which has very small boxes. This could explain why you are seeing high objectness and rpn_box_reg losses.

I'd recommend taking the config from R50-C4 and slightly modifying the backbone to return only the last feature map instead of all 4 feature levels.
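
A minimal sketch of that change, based on the forward method posted earlier in this thread (hypothetical, untested):

    def forward(self, x):
        # run all feature blocks, but return only the last feature map (C4-style)
        for m in self.features:
            x = m(x)
        return [x]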

@t-vi (Author) commented Feb 6, 2019

I advise strictly preferring @fmassa's advice over anything I say, as he has a lot more insight into this.

With that in mind, I looked a bit at TF's mobile detection model. What it does is amend the MobileNet (v1 in their case) with four blocks that each halve the number of channels and the resolution. So there you start with 768 channels, and the first Resblock has 768 input channels, 384 output channels, a bottleneck of 192 channels, and halves the input resolution.
Then the last output of the MobileNet and the additional blocks' outputs are used as the outputs of the backbone.
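
For illustration, one of those extra blocks might look roughly like this (channel numbers taken from the description above; hypothetical PyTorch code, not the TF implementation; assumes `from torch import nn`):

    def extra_block(in_ch):                # e.g. 768 for MobileNetV1 x0.75
        mid, out = in_ch // 4, in_ch // 2  # bottleneck 192, output 384
        return nn.Sequential(
            nn.Conv2d(in_ch, mid, 1, bias=False),  # 1x1 pointwise bottleneck
            nn.BatchNorm2d(mid),
            nn.ReLU6(inplace=True),
            nn.Conv2d(mid, out, 3, stride=2, padding=1, bias=False),  # stride 2 halves the resolution
            nn.BatchNorm2d(out),
            nn.ReLU6(inplace=True),
        )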

@kimnik6 commented Feb 8, 2019

@fmassa That was my initial guess as well, but now I think it is simply higher than in the other models; I checked the model multiple times and the results aren't actually too bad (right now at 26.8 mAP with 0.049s/image, waiting for the learning rate to be dropped one last time).

@t-vi Could you post the link to the model/script you're referring to? At what point are the extra blocks inserted/added?

@t-vi (Author) commented Feb 11, 2019

I looked at the ssd_mobilenet_v1_0.75_depth_coco model linked from here: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md

@kimnik6 commented Feb 13, 2019

Alright, training finished at 27.8 mAP (without FPN). I will leave it at that for now.
So overall, my (or rather mostly @t-vi's) changes were:

  • maskrcnn_benchmark/config/defaults.py: same as @t-vi
  • maskrcnn_benchmark/modeling/backbone/backbone.py: add this after line 10
from . import mobilenet


@registry.BACKBONES.register("MobileNetV2-FPN")
def build_mobilenet_backbone(cfg):
    body = mobilenet.MobileNetV2(cfg)
    in_channels_stage2 = cfg.MODEL.MOBILENET.OUT_CHANNELS
    out_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
    if cfg.MODEL.RPN.USE_FPN:
        fpn = fpn_module.FPN(
            in_channels_list=in_channels_stage2,
            out_channels=out_channels,
            conv_block=conv_with_kaiming_uniform(cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU),
            top_blocks=fpn_module.LastLevelMaxPool()
        )
        model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)]))
    else:
        model = nn.Sequential(OrderedDict([("body", body)]))
    return model
  • new file maskrcnn_benchmark/modeling/backbone/mobilenet.py (some slight changes from @t-vi's commit)
# taken from https://github.com/tonylins/pytorch-mobilenet-v2/
# Published by Ji Lin, tonylins
# licensed under the Apache License, Version 2.0, January 2004

from torch import nn

def conv_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, cfg, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        if cfg.MODEL.RPN.USE_FPN:
            last_channel = 1280
            interverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]
        else:
            last_channel = 96
            interverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
            ]

        # building first layer
        assert input_size % 32 == 0
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = nn.ModuleList([conv_bn(3, input_channel, 2)])
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # indices of the feature blocks whose outputs are returned by forward()
        self.rpn_layers = [3, 6, 10, 18] if cfg.MODEL.RPN.USE_FPN else [13]
        # note: self.features is deliberately kept as an nn.ModuleList;
        # the upstream code wrapped it here: self.features = nn.Sequential(*self.features)

        self._initialize_weights()
        self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT)

    def _freeze_backbone(self, freeze_at):
        for layer_index in range(freeze_at):
            for p in self.features[layer_index].parameters():
                p.requires_grad = False

    def forward(self, x):
        res = []
        for idx, m in enumerate(self.features):
            x = m(x)
            if idx in self.rpn_layers:
                res.append(x)
        return res

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, (2. / n) ** 0.5)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
  • new config files:
    with FPN
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "./mobilenet_v2.pth.tar"
  BACKBONE:
    CONV_BODY: "MobileNetV2-FPN"
    OUT_CHANNELS: 256
    FREEZE_CONV_BODY_AT: 3
  MOBILENET:
    # OUT_CHANNELS: (160, 160, 160, 320, 1280)
    OUT_CHANNELS: (24, 32, 64, 1280)
  RPN:
    USE_FPN: True
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  BASE_LR: 0.02
  CHECKPOINT_PERIOD: 10000
  WEIGHT_DECAY: 0.0001
  STEPS: (180000, 270000)
  MAX_ITER: 360000
  IMS_PER_BATCH: 4

without FPN:

MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "./mobilenet_v2.pth.tar"
  BACKBONE:
    CONV_BODY: "MobileNetV2-FPN"
    OUT_CHANNELS: 96
    FREEZE_CONV_BODY_AT: 3
  RPN:
    PRE_NMS_TOP_N_TEST: 6000
    POST_NMS_TOP_N_TEST: 1000
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 14
    POOLER_SCALES: (0.0625,)
    POOLER_SAMPLING_RATIO: 0
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
DATALOADER:
  SIZE_DIVISIBILITY: 32
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
SOLVER:
  BASE_LR: 0.02
  CHECKPOINT_PERIOD: 10000
  WEIGHT_DECAY: 0.0001
  STEPS: (180000, 270000)
  MAX_ITER: 360000
  IMS_PER_BATCH: 12

I'll try to upload the weights soon.

@t-vi (Author) commented Feb 13, 2019

Awesome work @kimnik6! Do you have a github branch I can update this PR with or should I copy the files?
I tend to use the GitHub release mechanism for releasing weights (e.g. https://github.com/t-vi/pytorch-tvmisc/releases).

@kimnik6 commented Feb 13, 2019

You can just copy the code. Thanks for the suggestion @t-vi, I uploaded them under https://github.com/kimnik6/maskrcnn-benchmark-mobilenet/releases.

@zimenglan-sysu-512 (Contributor) commented Feb 14, 2019

hi @kimnik6,
how many GPUs did you use to train the model without FPN?
Btw, can you share the pretrained MobileNet-v2 model?

@kimnik6 commented Feb 14, 2019

Without the FPN, multi-GPU didn't work, so I only trained it on one GPU. The initial weights I used are from https://github.com/tonylins/pytorch-mobilenet-v2; you can find my trained models at https://github.com/kimnik6/maskrcnn-benchmark-mobilenet/releases.

@zimenglan-sysu-512 (Contributor) commented Feb 25, 2019

hi @kimnik6,
I trained the MobileNet with FPN and got 0.306 with these configurations:

  • 2x LR
  • fixed BN
  • 4 GPUs with batch size = 16

and trained the MobileNet with RetinaNet and got 0.330 with these configurations:

  • 2x LR
  • fixed BN
  • 4 GPUs with batch size = 16

I also trained it from scratch using GN and FPN and got 0.35 with these configurations:

  • 3x LR
  • GN
  • 4 GPUs with batch size = 8

but the speed is not fast.

@kimnik6 commented Feb 25, 2019

Concerning the speed, the rest of the network really slows down the process. The fastest I got the MobileNet to work was ~20 FPS without the FPN (without changing the image dimensions).

@fmassa (Contributor) commented Feb 26, 2019

Note that the FBNet PR (#463) has just been merged, and it seems to be faster than MobileNet. @newstzpz will be uploading pre-trained models and numbers soon.

@zimenglan-sysu-512 (Contributor) commented Feb 28, 2019

hi @fmassa
I ran configs/e2e_faster_rcnn_fbnet.yaml and got 24.6, with an inference speed of ~0.049s/img using a single GeForce GTX 1080 (11GB) GPU. It's a great speed improvement.

@fmassa (Contributor) commented Feb 28, 2019

@zimenglan-sysu-512 did you change anything in the training config?
I believe results are expected to be better, and to run faster as well.
cc @newstzpz who will be uploading trained models and statistics

@zimenglan-sysu-512 (Contributor) commented

hi @fmassa,
I didn't change any settings in the training config.

@fmassa (Contributor) commented Feb 28, 2019

Ok. I'll let @newstzpz comment on the expected numbers for those models

@newstzpz (Contributor) commented Mar 1, 2019

Hi @zimenglan-sysu-512, thanks for experimenting with our model. The number you got is expected, as the model was targeted at mobile and is not GPU friendly. We have models that, although still not very GPU friendly, run at 0.188ms with mAP 33.5, measured in Caffe2 on a V100. Please see here for more details.

Our model supports various efficient building blocks and architectures, including the one in your PR. In addition, we also support using efficient building blocks in the RPN and the RoI heads (bbox, mask, etc.), which is important for fast speed on mobile. Here is an example of how to define the architecture:

    "cham_v1a": {
        "block_op_type": [
            # stage 0
            ["ir_k3"],
            # stage 1
            ["ir_k7"] * 2,
            # stage 2
            ["ir_k3"] * 5,
            # stage 3
            ["ir_k5"] * 7 + ["ir_k3"] * 5,
            # stage 4, bbox head
            ["ir_k3"] * 5,
            # stage 5, rpn
            ["ir_k3"] * 3,
        ],
        "block_cfg": {
            "first": [32, 2],
            "stages": [
                # [t, c, n, s]
                # stage 0
                [[1, 24, 1, 1]],
                # stage 1
                [[4, 48, 2, 2]],
                # stage 2
                [[7, 64, 5, 2]],
                # stage 3
                [[12, 56, 7, 2], [8, 88, 5, 1]],
                # stage 4, bbox head
                [[7, 152, 4, 2], [10, 104, 1, 1]],
                # stage 5, rpn head
                [[8, 88, 3, 1]],
            ],
            # [c, channel_scale]
            "last": [0, 0.0],
            "backbone": [0, 1, 2, 3],
            "rpn": [5],
            "bbox": [4],
        },
    },

We provide a few baseline architectures; please feel free to check them out here, along with all the supported building blocks here.

@t-vi (Author) commented Apr 30, 2019

This is obsolete now.

@t-vi closed this Apr 30, 2019
@adizhol commented Feb 18, 2020

Note that the FBNet PR (#463) has just been merged, and it seems to be faster than MobileNet. @newstzpz will be uploading pre-trained models and numbers soon.

FBNet is a searched architecture, and in my experience such architectures perform poorly when changing domains.
I would still merge this MobileNet PR.
