Skip to content

Commit

Permalink
MobileNetV2 (apache#9614)
Browse files Browse the repository at this point in the history
* MobileNetV2

* add mobilenetv2 to gluon vision model zoo.

* code reformat use autopep8 and isort.
fix pylint style check error.
add docstring, disable some check.
delete duplicate code in example.

* change ver to version.

* use exist helper.

* model code refactor.

* fix line too long.

* remove invalid name option.

* merge output operations.

* change variables name

* fix line too long.

* s -> stride

* add output name_scope

* remove relu in first conv2d.

* change block name from BottleNeck to LinearBottleneck.

* resolve conflict

* fix parameter name.

* correct strides

* add mobilenetv2 to unittest.

* use autopep8 to reformat code.

* add mobilenetv2 symbols to gluon.vision

* add relu in 1st conv2d.

* code refactor by using helpers.

* split mobilenet v1 and v2 apis.
  • Loading branch information
dwSun authored and zheng-da committed Jun 28, 2018
1 parent 6a63cbe commit 69fc6d9
Show file tree
Hide file tree
Showing 5 changed files with 303 additions and 14 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,4 @@ List of Contributors
* [Meghna Baijal](https://github.com/mbaijal)
* [Tao Hu](https://github.com/dongzhuoyao)
* [Sorokin Evgeniy](https://github.com/TheTweak)
* [dwSun](https://github.com/dwSun/)
76 changes: 76 additions & 0 deletions example/image-classification/symbols/mobilenetv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# -*- coding:utf-8 -*-
'''
MobileNetV2, implemented in Gluon.
Reference:
Inverted Residuals and Linear Bottlenecks:
Mobile Networks for Classification, Detection and Segmentation
https://arxiv.org/abs/1801.04381
'''
__author__ = 'dwSun'
__date__ = '18/1/31'

import mxnet as mx

from mxnet.gluon.model_zoo.vision.mobilenet import MobileNetV2


__all__ = ['MobileNetV2', 'get_symbol']


def get_symbol(num_classes=1000, multiplier=1.0, ctx=mx.cpu(), **kwargs):
r"""MobileNetV2 model from the
`"Inverted Residuals and Linear Bottlenecks:
Mobile Networks for Classification, Detection and Segmentation"
<https://arxiv.org/abs/1801.04381>`_ paper.
Parameters
----------
num_classes : int, default 1000
Number of classes for the output layer.
multiplier : float, default 1.0
The width multiplier for controling the model size. The actual number of channels
is equal to the original channel size multiplied by this multiplier.
ctx : Context, default CPU
The context in which to initialize the model weights.
"""
net = MobileNetV2(multiplier=multiplier, classes=num_classes, **kwargs)
net.initialize(ctx=ctx, init=mx.init.Xavier())
net.hybridize()

data = mx.sym.var('data')
out = net(data)
sym = mx.sym.SoftmaxOutput(out, name='softmax')
return sym


def plot_net():
"""
Visualize the network.
"""
sym = get_symbol(1000, prefix='mob_')

# plot network graph
mx.viz.plot_network(sym, shape={'data': (8, 3, 224, 224)},
node_attrs={'shape': 'oval', 'fixedsize': 'fasl==false'}).view()


if __name__ == '__main__':
plot_net()
11 changes: 9 additions & 2 deletions python/mxnet/gluon/model_zoo/vision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
- `SqueezeNet`_
- `VGG`_
- `MobileNet`_
- `MobileNetV2`_
You can construct a model with random weights by calling its constructor:
Expand Down Expand Up @@ -69,6 +70,7 @@
.. _SqueezeNet: https://arxiv.org/abs/1602.07360
.. _VGG: https://arxiv.org/abs/1409.1556
.. _MobileNet: https://arxiv.org/abs/1704.04861
.. _MobileNetV2: https://arxiv.org/abs/1801.04381
"""

from .alexnet import *
Expand All @@ -85,6 +87,7 @@

from .mobilenet import *


def get_model(name, **kwargs):
"""Returns a pre-defined model by name
Expand Down Expand Up @@ -135,11 +138,15 @@ def get_model(name, **kwargs):
'mobilenet1.0': mobilenet1_0,
'mobilenet0.75': mobilenet0_75,
'mobilenet0.5': mobilenet0_5,
'mobilenet0.25': mobilenet0_25
'mobilenet0.25': mobilenet0_25,
'mobilenetv2_1.0': mobilenet_v2_1_0,
'mobilenetv2_0.75': mobilenet_v2_0_75,
'mobilenetv2_0.5': mobilenet_v2_0_5,
'mobilenetv2_0.25': mobilenet_v2_0_25
}
name = name.lower()
if name not in models:
raise ValueError(
'Model %s is not supported. Available options are\n\t%s'%(
'Model %s is not supported. Available options are\n\t%s' % (
name, '\n\t'.join(sorted(models.keys()))))
return models[name](**kwargs)
Loading

0 comments on commit 69fc6d9

Please sign in to comment.