From 389a00f604b436cb51674c3e183d0f283aca5a36 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Fri, 22 Sep 2017 22:50:07 -0700 Subject: [PATCH] init mxnet converter (#27) graph backup update finish mxnet converter fix fix various add tests fix add multi networks uses model_zoo fix tests minor fix fix graph fix --- nnvm/python/nnvm/__init__.py | 1 + nnvm/python/nnvm/frontend/__init__.py | 3 + nnvm/python/nnvm/frontend/mxnet.py | 301 ++++++++++++++++ .../frontend/mxnet/model_zoo/__init__.py | 22 ++ .../python/frontend/mxnet/model_zoo/mlp.py | 44 +++ .../python/frontend/mxnet/model_zoo/resnet.py | 322 ++++++++++++++++++ .../python/frontend/mxnet/model_zoo/vgg.py | 128 +++++++ .../python/frontend/mxnet/test_forward.py | 88 +++++ .../tests/python/frontend/mxnet/test_graph.py | 38 +++ 9 files changed, 947 insertions(+) create mode 100644 nnvm/python/nnvm/frontend/__init__.py create mode 100644 nnvm/python/nnvm/frontend/mxnet.py create mode 100644 nnvm/tests/python/frontend/mxnet/model_zoo/__init__.py create mode 100644 nnvm/tests/python/frontend/mxnet/model_zoo/mlp.py create mode 100644 nnvm/tests/python/frontend/mxnet/model_zoo/resnet.py create mode 100644 nnvm/tests/python/frontend/mxnet/model_zoo/vgg.py create mode 100644 nnvm/tests/python/frontend/mxnet/test_forward.py create mode 100644 nnvm/tests/python/frontend/mxnet/test_graph.py diff --git a/nnvm/python/nnvm/__init__.py b/nnvm/python/nnvm/__init__.py index d30b1c152e2b..31b88587764d 100644 --- a/nnvm/python/nnvm/__init__.py +++ b/nnvm/python/nnvm/__init__.py @@ -7,5 +7,6 @@ from . import symbol as sym from . import symbol from ._base import NNVMError +from . import frontend __version__ = _base.__version__ diff --git a/nnvm/python/nnvm/frontend/__init__.py b/nnvm/python/nnvm/frontend/__init__.py new file mode 100644 index 000000000000..de6c9ee3f7a1 --- /dev/null +++ b/nnvm/python/nnvm/frontend/__init__.py @@ -0,0 +1,3 @@ +"""Frontend package.""" +from __future__ import absolute_import +from .mxnet import from_mxnet diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py new file mode 100644 index 000000000000..5e1ea0457d60 --- /dev/null +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -0,0 +1,301 @@ +"""MXNet symbol frontend.""" +from __future__ import absolute_import as _abs +import json +from .. 
import symbol as _sym
+
+__all__ = ['from_mxnet']
+
+def _required_attr(attr, key):
+    assert isinstance(attr, dict)
+    if key not in attr:
+        raise AttributeError("Required attribute {} not found.".format(key))
+    return attr[key]
+
+def _raise_not_supported(attr, op='nnvm'):
+    err = "{} is not supported in {}.".format(attr, op)
+    raise NotImplementedError(err)
+
+def _warn_not_used(attr, op='nnvm'):
+    import warnings
+    err = "{} is ignored in {}.".format(attr, op)
+    warnings.warn(err)
+
+def _parse_tshape(tshape):
+    """Parse a shape string such as '(2, 2)' into a list of ints."""
+    return [int(x.strip()) for x in tshape.strip('()').split(',')]
+
+def _parse_bool_str(attr, key, default='False'):
+    """Parse a boolean-like attribute string into a Python bool."""
+    return attr.get(key, default).strip().lower() in ['true', '1', 't', 'y', 'yes']
+
+def _rename(new_name):
+    def impl(attr):
+        return new_name, attr
+    return impl
+
+def _variable(attrs):
+    return "Variable", attrs
+
+def _pooling(attrs):
+    kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
+    if len(kernel) != 2:
+        _raise_not_supported('non-2d kernel', 'pool2d')
+    global_pool = 'global' if _parse_bool_str(attrs, 'global_pool') else ''
+    pool_type = _required_attr(attrs, 'pool_type')
+    if pool_type not in ['avg', 'max']:
+        _raise_not_supported('non-avg/max', 'pool2d')
+    op_name, new_attrs = '_'.join([global_pool, pool_type, 'pool2d']).strip('_'), {}
+    # new_attrs['layout'] = 'NCHW'
+    if not global_pool:
+        new_attrs['pool_size'] = kernel
+        new_attrs['strides'] = attrs.get('stride', (1, 1))
+        new_attrs['padding'] = attrs.get('pad', (0, 0))
+        new_attrs['ceil_mode'] = (attrs.get('pooling_convention', 'valid') == 'full')
+    return op_name, new_attrs
+
+def _batch_norm(attrs):
+    if _parse_bool_str(attrs, 'output_mean_var'):
+        _raise_not_supported('output_mean_var', 'batch_norm')
+    if _parse_bool_str(attrs, 'fix_gamma'):
+        _warn_not_used('fix_gamma', 'batch_norm')
+    if _parse_bool_str(attrs, 'use_global_stats'):
+        _warn_not_used('use_global_stats', 'batch_norm')
+    # momentum is a float string such as '0.9', so test for its presence
+    # rather than parsing it as a bool
+    if 'momentum' in attrs:
+        _warn_not_used('momentum', 'batch_norm')
+    op_name, new_attrs = 'batch_norm', {}
+    new_attrs['axis'] = attrs.get('axis', 1)
+    new_attrs['epsilon'] = attrs.get('eps', 0.001)
+    new_attrs['center'] = True
+    new_attrs['scale'] = True
+    return op_name, new_attrs
+
+def _concat(attrs):
+    op_name = 'concatenate'
+    new_attrs = {'axis': attrs.get('dim', 1)}
+    return op_name, new_attrs
+
+def _conv2d(attrs):
+    kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
+    if len(kernel) != 2:
+        _raise_not_supported('non-2d kernel', 'conv2d')
+    layout = attrs.get('layout', 'NCHW')
+    if layout not in ['NCHW', 'NHWC']:
+        _raise_not_supported('layout: ' + layout, 'conv2d')
+    op_name, new_attrs = 'conv2d', {}
+    new_attrs['channels'] = _required_attr(attrs, 'num_filter')
+    new_attrs['kernel_size'] = kernel
+    new_attrs['strides'] = attrs.get('stride', (1, 1))
+    new_attrs['padding'] = attrs.get('pad', (0, 0))
+    new_attrs['dilation'] = attrs.get('dilate', (1, 1))
+    new_attrs['groups'] = attrs.get('num_group', 1)
+    new_attrs['layout'] = layout
+    new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
+    return op_name, new_attrs
+
+def _conv2d_transpose(attrs):
+    if 'target_shape' in attrs:
+        _raise_not_supported('target_shape', 'conv2d_transpose')
+    kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
+    if len(kernel) != 2:
+        _raise_not_supported('non-2d kernel', 'conv2d_transpose')
+    layout = attrs.get('layout', 'NCHW')
+    if layout not in ['NCHW', 'NHWC']:
+        _raise_not_supported('layout: ' + layout,
'conv2d_transpose')
+    op_name, new_attrs = 'conv2d_transpose', {}
+    new_attrs['channels'] = _required_attr(attrs, 'num_filter')
+    new_attrs['kernel_size'] = kernel
+    new_attrs['strides'] = attrs.get('stride', (1, 1))
+    new_attrs['output_padding'] = attrs.get('adj', (0, 0))
+    new_attrs['padding'] = attrs.get('pad', (0, 0))
+    new_attrs['dilation'] = attrs.get('dilate', (1, 1))
+    new_attrs['groups'] = attrs.get('num_group', 1)
+    new_attrs['layout'] = layout
+    new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
+    return op_name, new_attrs
+
+def _dense(attrs):
+    op_name, new_attrs = 'dense', {}
+    new_attrs['units'] = _required_attr(attrs, 'num_hidden')
+    new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
+    return op_name, new_attrs
+
+def _dropout(attrs):
+    op_name, new_attrs = 'dropout', {}
+    new_attrs['rate'] = attrs.get('p', 0.5)
+    return op_name, new_attrs
+
+def _leaky_relu(attrs):
+    act_type = _required_attr(attrs, 'act_type')
+    if act_type not in ['leaky']:
+        _raise_not_supported('act_type: ' + act_type)
+    op_name, new_attrs = 'leaky_relu', {}
+    new_attrs['alpha'] = attrs.get('slope', 0.25)
+    return op_name, new_attrs
+
+def _activations(attrs):
+    act_type = _required_attr(attrs, 'act_type')
+    if act_type not in ['relu', 'sigmoid', 'tanh']:
+        _raise_not_supported('act_type: ' + act_type)
+    op_name, new_attrs = act_type, {}
+    return op_name, new_attrs
+
+def _reshape(attrs):
+    if _parse_bool_str(attrs, 'reverse'):
+        _raise_not_supported('reverse', 'reshape')
+    op_name, new_attrs = 'reshape', {}
+    new_attrs['shape'] = _required_attr(attrs, 'shape')
+    return op_name, new_attrs
+
+def _split(attrs):
+    if _parse_bool_str(attrs, 'squeeze_axis'):
+        _raise_not_supported('squeeze_axis', 'split')
+    op_name, new_attrs = 'split', {}
+    new_attrs['indices_or_sections'] = _required_attr(attrs, 'num_outputs')
+    new_attrs['axis'] = attrs.get('axis', 1)
+    return op_name, new_attrs
+
+_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
+                  '__div_symbol__', '__mul_scalar__', '__mul_symbol__',
+                  '__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
+                  '__rsub_scalar__', '__sub_scalar__', '__sub_symbol__',
+                  'broadcast_add', 'broadcast_div', 'broadcast_mul',
+                  'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add',
+                  'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
+                  'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
+                  'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']
+
+_convert_map = {
+    'null'          : _variable,
+    'Activation'    : _activations,
+    'BatchNorm'     : _batch_norm,
+    'BatchNorm_v1'  : _batch_norm,
+    'Cast'          : _rename('cast'),
+    'Concat'        : _concat,
+    'Convolution'   : _conv2d,
+    'Convolution_v1': _conv2d,
+    'Deconvolution' : _conv2d_transpose,
+    'Dropout'       : _dropout,
+    'Flatten'       : _rename('flatten'),
+    'FullyConnected': _dense,
+    'LeakyReLU'     : _leaky_relu,
+    'Pooling'       : _pooling,
+    'Pooling_v1'    : _pooling,
+    'Reshape'       : _reshape,
+    'Softmax'       : _rename('softmax'),
+    'concat'        : _concat,
+    'max_axis'      : _rename('max'),
+    'min_axis'      : _rename('min'),
+    'reshape'       : _reshape,
+    'sum_axis'      : _rename('sum'),
+}
+
+def _convert_symbol(op_name, attrs,
+                    identity_list=_identity_list,
+                    convert_map=_convert_map):
+    """Convert from an mxnet op to an nnvm op.
+    Ops whose name or attributes differ between mxnet and nnvm
+    (e.g. Convolution vs. conv2d) must be converted explicitly
+    via convert_map; ops shared by both frameworks go in identity_list.
+
+    Parameters
+    ----------
+    op_name : str
+        Operator name, such as Convolution, FullyConnected
+    attrs : dict
+        Dict of operator attributes
+    identity_list : list
+        List of operators that don't require conversion
+    convert_map : dict
+        Dict of name : callable, where name is an op that requires
+        conversion to nnvm and callable is a function that takes
+        attrs and returns (new_op_name, new_attrs)
+
+    Returns
+    -------
+    (op, attrs)
+        The nnvm op handle and the converted attributes.
+    """
+    if op_name in identity_list:
+        pass
+    elif op_name in convert_map:
+        op_name, attrs = convert_map[op_name](attrs)
+    else:
+        _raise_not_supported('Operator: ' + op_name)
+    op = getattr(_sym, op_name, None)
+    if not op:
+        raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
+    return op, attrs
+
+def _is_mxnet_group_symbol(symbol):
+    """Internal check for mxnet group symbol."""
+    return len(symbol.list_outputs()) > 1
+
+def _as_list(arr):
+    """Force the argument into a list; no-op if it already is one."""
+    if isinstance(arr, list):
+        return arr
+    return [arr]
+
+def _from_mxnet_impl(symbol, graph):
+    """Convert an mxnet symbol to an nnvm symbol.
+    Reconstructs the nnvm symbol by traversing the mxnet symbol graph.
+
+    Parameters
+    ----------
+    symbol : mxnet.sym.Symbol
+        Symbol from mxnet with a similar graph structure; the op names
+        and attrs inside are not always compatible with nnvm.
+    graph : dict
+        Cache of already-converted nodes, keyed by name, so shared
+        subgraphs are reused.
+
+    Returns
+    -------
+    nnvm.sym.Symbol
+        Converted symbol
+    """
+    try:
+        from mxnet import sym as mx_sym
+    except ImportError as e:
+        raise ImportError('{}. MXNet is required to parse symbols.'.format(e))
+
+    if not isinstance(symbol, mx_sym.Symbol):
+        raise ValueError("Expected an MXNet symbol, got {}".format(type(symbol)))
+
+    if _is_mxnet_group_symbol(symbol):
+        return [_from_mxnet_impl(s, graph) for s in symbol]
+
+    name = symbol.attr('name')
+    node = graph.get(name, None)
+    if node:
+        return node
+    if symbol.get_children():
+        op_name = symbol.attr('op_name')
+    else:
+        op_name = json.loads(symbol.tojson())['nodes'][0]['op']
+    attr = symbol.list_attr()
+    new_op, new_attr = _convert_symbol(op_name, attr)
+    if new_op == _sym.Variable:
+        node = new_op(name=name, **new_attr)
+    else:
+        childs = symbol.get_children()
+        childs = [_from_mxnet_impl(c, graph) for c in _as_list(childs)]
+        childs = [x for y in childs for x in _as_list(y)]  # expand group symbol
+        node = new_op(name=name, *childs, **new_attr)
+    graph[name] = node
+    return node
+
+
+def from_mxnet(symbol):
+    """Convert an mxnet.Symbol to a compatible nnvm.Symbol
+
+    Parameters
+    ----------
+    symbol : mxnet.Symbol
+        MXNet symbol
+
+    Returns
+    -------
+    nnvm.Symbol
+        Compatible nnvm symbol
+    """
+    return _from_mxnet_impl(symbol, {})
diff --git a/nnvm/tests/python/frontend/mxnet/model_zoo/__init__.py b/nnvm/tests/python/frontend/mxnet/model_zoo/__init__.py
new file mode 100644
index 000000000000..07be9bba02c0
--- /dev/null
+++ b/nnvm/tests/python/frontend/mxnet/model_zoo/__init__.py
@@ -0,0 +1,22 @@
+from __future__ import absolute_import
+from . 
import mlp, resnet, vgg
+
+_num_class = 1000
+
+# mlp fc
+mx_mlp = mlp.get_symbol(_num_class)
+nnvm_mlp = mlp.get_symbol_nnvm(_num_class)
+
+# resnet fc
+mx_resnet = {}
+nnvm_resnet = {}
+for num_layer in [18, 34, 50, 101, 152, 200, 269]:
+    mx_resnet[num_layer] = resnet.get_symbol(_num_class, num_layer, '3,224,224')
+    nnvm_resnet[num_layer] = resnet.get_symbol(_num_class, num_layer, '3,224,224', lib='nnvm')
+
+# vgg fc
+mx_vgg = {}
+nnvm_vgg = {}
+for num_layer in [11, 13, 16, 19]:
+    mx_vgg[num_layer] = vgg.get_symbol(_num_class, num_layer)
+    nnvm_vgg[num_layer] = vgg.get_symbol_nnvm(_num_class, num_layer)
diff --git a/nnvm/tests/python/frontend/mxnet/model_zoo/mlp.py b/nnvm/tests/python/frontend/mxnet/model_zoo/mlp.py
new file mode 100644
index 000000000000..f69d4a2d1d52
--- /dev/null
+++ b/nnvm/tests/python/frontend/mxnet/model_zoo/mlp.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+A simple multilayer perceptron.
+"""
+import mxnet as mx
+import nnvm
+
+def get_symbol(num_classes=10, **kwargs):
+    data = mx.symbol.Variable('data')
+    data = mx.sym.Flatten(data=data)
+    fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)
+    act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
+    fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=64)
+    act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type="relu")
+    fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=num_classes)
+    mlp = mx.symbol.softmax(data=fc3, name='softmax')
+    return mlp
+
+def get_symbol_nnvm(num_classes=10, **kwargs):
+    data = nnvm.symbol.Variable('data')
+    data = nnvm.sym.flatten(data=data)
+    fc1 = nnvm.symbol.dense(data=data, name='fc1', units=128)
+    act1 = nnvm.symbol.relu(data=fc1, name='relu1')
+    fc2 = nnvm.symbol.dense(data=act1, name='fc2', units=64)
+    act2 = nnvm.symbol.relu(data=fc2, name='relu2')
+    fc3 = nnvm.symbol.dense(data=act2, name='fc3', units=num_classes)
+    mlp = nnvm.symbol.softmax(data=fc3, name='softmax')
+    return mlp
diff --git a/nnvm/tests/python/frontend/mxnet/model_zoo/resnet.py b/nnvm/tests/python/frontend/mxnet/model_zoo/resnet.py
new file mode 100644
index 000000000000..2d4a5c9179d1
--- /dev/null
+++ b/nnvm/tests/python/frontend/mxnet/model_zoo/resnet.py
@@ -0,0 +1,322 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. 
You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+'''
+Adapted from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py
+Original author Wei Wu
+
+Implemented the following paper:
+
+Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
+'''
+import mxnet as mx
+import numpy as np
+import nnvm
+
+def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False):
+    """Return a ResNet unit symbol for building a ResNet
+    Parameters
+    ----------
+    data : mx.sym.Symbol
+        Input data
+    num_filter : int
+        Number of output channels
+    stride : tuple
+        Stride used in convolution
+    dim_match : bool
+        True means input and output have the same number of channels, so the
+        shortcut is an identity; otherwise a 1x1 convolution is used
+    name : str
+        Base name of the operators
+    bottle_neck : bool
+        Whether to use the bottleneck unit
+    workspace : int
+        Workspace used in convolution operator
+    """
+    if bottle_neck:
+        # same as https://github.com/facebook/fb.resnet.torch#notes; differs slightly from the original paper
+        bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
+        act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
+        conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter*0.25), kernel=(1,1), stride=(1,1), pad=(0,0),
+                                   no_bias=True, workspace=workspace, name=name + '_conv1')
+        bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')
+        act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
+        conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=stride, pad=(1,1),
+                                   no_bias=True, workspace=workspace, name=name + '_conv2')
+        bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
+        act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3')
+        conv3 = mx.sym.Convolution(data=act3, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True,
+                                   workspace=workspace, name=name + '_conv3')
+        if dim_match:
+            shortcut = data
+        else:
+            shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
+                                          workspace=workspace, name=name+'_sc')
+        if memonger:
+            shortcut._set_attr(mirror_stage='True')
+        return conv3 + shortcut
+    else:
+        bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn1')
+        act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
+        conv1 = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(3,3), stride=stride, pad=(1,1),
+                                   no_bias=True, workspace=workspace, name=name + '_conv1')
+        bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn2')
+        act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
+        conv2 = mx.sym.Convolution(data=act2, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1,1),
+                                   no_bias=True, workspace=workspace, name=name + '_conv2')
+        if dim_match:
+            shortcut = data
+        else:
+            shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, 
kernel=(1,1), stride=stride, no_bias=True,
+                                          workspace=workspace, name=name+'_sc')
+        if memonger:
+            shortcut._set_attr(mirror_stage='True')
+        return conv2 + shortcut
+
+def residual_unit_nnvm(data, num_filter, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False):
+    """Return a ResNet unit symbol for building a ResNet
+    Parameters
+    ----------
+    data : nnvm.sym.Symbol
+        Input data
+    num_filter : int
+        Number of output channels
+    stride : tuple
+        Stride used in convolution
+    dim_match : bool
+        True means input and output have the same number of channels, so the
+        shortcut is an identity; otherwise a 1x1 convolution is used
+    name : str
+        Base name of the operators
+    bottle_neck : bool
+        Whether to use the bottleneck unit
+    workspace : int
+        Workspace used in convolution operator
+    """
+    if bottle_neck:
+        # same as https://github.com/facebook/fb.resnet.torch#notes; differs slightly from the original paper
+        bn1 = nnvm.sym.batch_norm(data=data, epsilon=2e-5, name=name + '_bn1')
+        act1 = nnvm.sym.relu(data=bn1, name=name + '_relu1')
+        conv1 = nnvm.sym.conv2d(data=act1, channels=int(num_filter*0.25), kernel_size=(1,1), strides=(1,1), padding=(0,0),
+                                use_bias=False, name=name + '_conv1')
+        bn2 = nnvm.sym.batch_norm(data=conv1, epsilon=2e-5, name=name + '_bn2')
+        act2 = nnvm.sym.relu(data=bn2, name=name + '_relu2')
+        conv2 = nnvm.sym.conv2d(data=act2, channels=int(num_filter*0.25), kernel_size=(3,3), strides=stride, padding=(1,1),
+                                use_bias=False, name=name + '_conv2')
+        bn3 = nnvm.sym.batch_norm(data=conv2, epsilon=2e-5, name=name + '_bn3')
+        act3 = nnvm.sym.relu(data=bn3, name=name + '_relu3')
+        conv3 = nnvm.sym.conv2d(data=act3, channels=num_filter, kernel_size=(1,1), strides=(1,1), padding=(0,0), use_bias=False,
+                                name=name + '_conv3')
+        if dim_match:
+            shortcut = data
+        else:
+            shortcut = nnvm.sym.conv2d(data=act1, channels=num_filter, kernel_size=(1,1), strides=stride, use_bias=False,
+                                       name=name+'_sc')
+        if memonger:
+            shortcut._set_attr(mirror_stage='True')
+        return nnvm.sym.elemwise_add(conv3, shortcut)
+    else:
+        bn1 = nnvm.sym.batch_norm(data=data, epsilon=2e-5, name=name + '_bn1')
+        act1 = nnvm.sym.relu(data=bn1, name=name + '_relu1')
+        conv1 = nnvm.sym.conv2d(data=act1, channels=num_filter, kernel_size=(3,3), strides=stride, padding=(1,1),
+                                use_bias=False, name=name + '_conv1')
+        bn2 = nnvm.sym.batch_norm(data=conv1, epsilon=2e-5, name=name + '_bn2')
+        act2 = nnvm.sym.relu(data=bn2, name=name + '_relu2')
+        conv2 = nnvm.sym.conv2d(data=act2, channels=num_filter, kernel_size=(3,3), strides=(1,1), padding=(1,1),
+                                use_bias=False, name=name + '_conv2')
+        if dim_match:
+            shortcut = data
+        else:
+            shortcut = nnvm.sym.conv2d(data=act1, channels=num_filter, kernel_size=(1,1), strides=stride, use_bias=False,
+                                       name=name+'_sc')
+        if memonger:
+            shortcut._set_attr(mirror_stage='True')
+        return nnvm.sym.elemwise_add(conv2, shortcut)
+
+def resnet(units, num_stages, filter_list, num_classes, image_shape, bottle_neck=True, bn_mom=0.9, workspace=256, dtype='float32', memonger=False):
+    """Return a ResNet symbol
+    Parameters
+    ----------
+    units : list
+        Number of units in each stage
+    num_stages : int
+        Number of stages
+    filter_list : list
+        Channel size of each stage
+    num_classes : int
+        Output size of symbol
+    image_shape : tuple
+        Input shape as (channels, height, width)
+    workspace : int
+        Workspace used in convolution operator
+    dtype : str
+        Precision (float32 or float16)
+    """
+    num_unit = len(units)
+    assert(num_unit == num_stages)
+    data = mx.sym.Variable(name='data')
+    if dtype == 'float16':
+        data 
= mx.sym.Cast(data=data, dtype=np.float16)
+    data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data')
+    (nchannel, height, width) = image_shape
+    if height <= 32:            # such as cifar10
+        body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(3, 3), stride=(1,1), pad=(1, 1),
+                                  no_bias=True, name="conv0", workspace=workspace)
+    else:                       # often expected to be 224 such as imagenet
+        body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2,2), pad=(3, 3),
+                                  no_bias=True, name="conv0", workspace=workspace)
+        body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
+        body = mx.sym.Activation(data=body, act_type='relu', name='relu0')
+        body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max')
+
+    for i in range(num_stages):
+        body = residual_unit(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False,
+                             name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, workspace=workspace,
+                             memonger=memonger)
+        for j in range(units[i]-1):
+            body = residual_unit(body, filter_list[i+1], (1,1), True, name='stage%d_unit%d' % (i + 1, j + 2),
+                                 bottle_neck=bottle_neck, workspace=workspace, memonger=memonger)
+    bn1 = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1')
+    relu1 = mx.sym.Activation(data=bn1, act_type='relu', name='relu1')
+    # The kernel is not used when global_pool=True, but Pooling still requires one
+    pool1 = mx.sym.Pooling(data=relu1, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
+    flat = mx.sym.Flatten(data=pool1)
+    fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1')
+    if dtype == 'float16':
+        fc1 = mx.sym.Cast(data=fc1, dtype=np.float32)
+    return mx.sym.softmax(data=fc1, name='softmax')
+
+def resnet_nnvm(units, num_stages, filter_list, num_classes, image_shape, bottle_neck=True, bn_mom=0.9, workspace=256, dtype='float32', memonger=False):
+    """Return a ResNet symbol
+    Parameters
+    ----------
+    units : list
+        Number of units in each stage
+    num_stages : int
+        Number of stages
+    filter_list : list
+        Channel size of each stage
+    num_classes : int
+        Output size of symbol
+    image_shape : tuple
+        Input shape as (channels, height, width)
+    workspace : int
+        Workspace used in convolution operator
+    dtype : str
+        Precision (float32 or float16)
+    """
+    num_unit = len(units)
+    assert(num_unit == num_stages)
+    data = nnvm.sym.Variable(name='data')
+    if dtype == 'float16':
+        data = nnvm.sym.cast(data=data, dtype=np.float16)
+    data = nnvm.sym.batch_norm(data=data, epsilon=2e-5, name='bn_data')
+    (nchannel, height, width) = image_shape
+    if height <= 32:            # such as cifar10
+        body = nnvm.sym.conv2d(data=data, channels=filter_list[0], kernel_size=(3, 3), strides=(1,1), padding=(1, 1),
+                               use_bias=False, name="conv0")
+    else:                       # often expected to be 224 such as imagenet
+        body = nnvm.sym.conv2d(data=data, channels=filter_list[0], kernel_size=(7, 7), strides=(2,2), padding=(3, 3),
+                               use_bias=False, name="conv0")
+        body = nnvm.sym.batch_norm(data=body, epsilon=2e-5, name='bn0')
+        body = nnvm.sym.relu(data=body, name='relu0')
+        body = nnvm.sym.max_pool2d(data=body, pool_size=(3, 3), strides=(2,2), padding=(1,1))
+
+    for i in range(num_stages):
+        body = residual_unit_nnvm(body, filter_list[i+1], (1 
if i==0 else 2, 1 if i==0 else 2), False,
+                                  name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck,
+                                  memonger=memonger)
+        for j in range(units[i]-1):
+            body = residual_unit_nnvm(body, filter_list[i+1], (1,1), True, name='stage%d_unit%d' % (i + 1, j + 2),
+                                      bottle_neck=bottle_neck, memonger=memonger)
+    bn1 = nnvm.sym.batch_norm(data=body, epsilon=2e-5, name='bn1')
+    relu1 = nnvm.sym.relu(data=bn1, name='relu1')
+    pool1 = nnvm.sym.global_avg_pool2d(data=relu1, name='pool1')
+    flat = nnvm.sym.flatten(data=pool1)
+    fc1 = nnvm.sym.dense(data=flat, units=num_classes, name='fc1')
+    if dtype == 'float16':
+        fc1 = nnvm.sym.cast(data=fc1, dtype=np.float32)
+    return nnvm.sym.softmax(data=fc1, name='softmax')
+
+def get_symbol(num_classes, num_layers, image_shape, conv_workspace=256, dtype='float32', lib='mxnet', **kwargs):
+    """
+    Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py
+    Original author Wei Wu
+    """
+    image_shape = [int(l) for l in image_shape.split(',')]
+    (nchannel, height, width) = image_shape
+    if height <= 28:
+        num_stages = 3
+        if (num_layers-2) % 9 == 0 and num_layers >= 164:
+            per_unit = [(num_layers-2)//9]
+            filter_list = [16, 64, 128, 256]
+            bottle_neck = True
+        elif (num_layers-2) % 6 == 0 and num_layers < 164:
+            per_unit = [(num_layers-2)//6]
+            filter_list = [16, 16, 32, 64]
+            bottle_neck = False
+        else:
+            raise ValueError("no experiments done on num_layers {}; you can do it yourself".format(num_layers))
+        units = per_unit * num_stages
+    else:
+        if num_layers >= 50:
+            filter_list = [64, 256, 512, 1024, 2048]
+            bottle_neck = True
+        else:
+            filter_list = [64, 64, 128, 256, 512]
+            bottle_neck = False
+        num_stages = 4
+        if num_layers == 18:
+            units = [2, 2, 2, 2]
+        elif num_layers == 34:
+            units = [3, 4, 6, 3]
+        elif num_layers == 50:
+            units = [3, 4, 6, 3]
+        elif num_layers == 101:
+            units = [3, 4, 23, 3]
+        elif num_layers == 152:
+            units = [3, 8, 36, 3]
+        elif num_layers == 200:
+            units = [3, 24, 36, 3]
+        elif num_layers == 269:
+            units = [3, 30, 48, 8]
+        else:
+            raise ValueError("no experiments done on num_layers {}; you can do it yourself".format(num_layers))
+
+    if lib == 'mxnet':
+        return resnet(units=units,
+                      num_stages=num_stages,
+                      filter_list=filter_list,
+                      num_classes=num_classes,
+                      image_shape=image_shape,
+                      bottle_neck=bottle_neck,
+                      workspace=conv_workspace,
+                      dtype=dtype)
+    return resnet_nnvm(units=units,
+                       num_stages=num_stages,
+                       filter_list=filter_list,
+                       num_classes=num_classes,
+                       image_shape=image_shape,
+                       bottle_neck=bottle_neck,
+                       workspace=conv_workspace,
+                       dtype=dtype)
diff --git a/nnvm/tests/python/frontend/mxnet/model_zoo/vgg.py b/nnvm/tests/python/frontend/mxnet/model_zoo/vgg.py
new file mode 100644
index 000000000000..1992aa0ccc1c
--- /dev/null
+++ b/nnvm/tests/python/frontend/mxnet/model_zoo/vgg.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. 
You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""References:
+
+Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for
+large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
+"""
+
+import mxnet as mx
+import nnvm
+import numpy as np
+
+def get_feature(internal_layer, layers, filters, batch_norm=False, **kwargs):
+    for i, num in enumerate(layers):
+        for j in range(num):
+            internal_layer = mx.sym.Convolution(data=internal_layer, kernel=(3, 3), pad=(1, 1), num_filter=filters[i], name="conv%s_%s" % (i + 1, j + 1))
+            if batch_norm:
+                internal_layer = mx.symbol.BatchNorm(data=internal_layer, name="bn%s_%s" % (i + 1, j + 1))
+            internal_layer = mx.sym.Activation(data=internal_layer, act_type="relu", name="relu%s_%s" % (i + 1, j + 1))
+        internal_layer = mx.sym.Pooling(data=internal_layer, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool%s" % (i + 1))
+    return internal_layer
+
+def get_feature_nnvm(internal_layer, layers, filters, batch_norm=False, **kwargs):
+    for i, num in enumerate(layers):
+        for j in range(num):
+            internal_layer = nnvm.sym.conv2d(data=internal_layer, kernel_size=(3, 3), padding=(1, 1), channels=filters[i], name="conv%s_%s" % (i + 1, j + 1))
+            if batch_norm:
+                internal_layer = nnvm.symbol.batch_norm(data=internal_layer, name="bn%s_%s" % (i + 1, j + 1))
+            internal_layer = nnvm.sym.relu(data=internal_layer, name="relu%s_%s" % (i + 1, j + 1))
+        internal_layer = nnvm.sym.max_pool2d(data=internal_layer, pool_size=(2, 2), strides=(2,2), name="pool%s" % (i + 1))
+    return internal_layer
+
+def get_classifier(input_data, num_classes, **kwargs):
+    flatten = mx.sym.Flatten(data=input_data, name="flatten")
+    fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
+    relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
+    drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
+    fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7")
+    relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
+    drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
+    fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
+    return fc8
+
+def get_classifier_nnvm(input_data, num_classes, **kwargs):
+    flatten = nnvm.sym.flatten(data=input_data, name="flatten")
+    fc6 = nnvm.sym.dense(data=flatten, units=4096, name="fc6")
+    relu6 = nnvm.sym.relu(data=fc6, name="relu6")
+    drop6 = nnvm.sym.dropout(data=relu6, rate=0.5, name="drop6")
+    fc7 = nnvm.sym.dense(data=drop6, units=4096, name="fc7")
+    relu7 = nnvm.sym.relu(data=fc7, name="relu7")
+    drop7 = nnvm.sym.dropout(data=relu7, rate=0.5, name="drop7")
+    fc8 = nnvm.sym.dense(data=drop7, units=num_classes, name="fc8")
+    return fc8
+
+def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
+    """
+    Parameters
+    ----------
+    num_classes : int, default 1000
+        Number of classification classes.
+    num_layers : int
+        Number of layers for the VGG variant. Options are 11, 13, 16, 19.
+    batch_norm : bool, default False
+        Use batch normalization.
+    dtype : str, float32 or float16
+        Data precision. 
+    """
+    vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
+                13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
+                16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
+                19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
+    if num_layers not in vgg_spec:
+        raise ValueError("Invalid num_layers {}. Possible choices are 11, 13, 16, 19.".format(num_layers))
+    layers, filters = vgg_spec[num_layers]
+    data = mx.sym.Variable(name="data")
+    if dtype == 'float16':
+        data = mx.sym.Cast(data=data, dtype=np.float16)
+    feature = get_feature(data, layers, filters, batch_norm)
+    classifier = get_classifier(feature, num_classes)
+    if dtype == 'float16':
+        classifier = mx.sym.Cast(data=classifier, dtype=np.float32)
+    symbol = mx.sym.softmax(data=classifier, name='softmax')
+    return symbol
+
+def get_symbol_nnvm(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
+    """
+    Parameters
+    ----------
+    num_classes : int, default 1000
+        Number of classification classes.
+    num_layers : int
+        Number of layers for the VGG variant. Options are 11, 13, 16, 19.
+    batch_norm : bool, default False
+        Use batch normalization.
+    dtype : str, float32 or float16
+        Data precision.
+    """
+    vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
+                13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
+                16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
+                19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
+    if num_layers not in vgg_spec:
+        raise ValueError("Invalid num_layers {}. Possible choices are 11, 13, 16, 19.".format(num_layers))
+    layers, filters = vgg_spec[num_layers]
+    data = nnvm.sym.Variable(name="data")
+    if dtype == 'float16':
+        data = nnvm.sym.cast(data=data, dtype=np.float16)
+    feature = get_feature_nnvm(data, layers, filters, batch_norm)
+    classifier = get_classifier_nnvm(feature, num_classes)
+    if dtype == 'float16':
+        classifier = nnvm.sym.cast(data=classifier, dtype=np.float32)
+    symbol = nnvm.sym.softmax(data=classifier, name='softmax')
+    return symbol
diff --git a/nnvm/tests/python/frontend/mxnet/test_forward.py b/nnvm/tests/python/frontend/mxnet/test_forward.py
new file mode 100644
index 000000000000..8fcc106e4db2
--- /dev/null
+++ b/nnvm/tests/python/frontend/mxnet/test_forward.py
@@ -0,0 +1,88 @@
+import numpy as np
+
+import topi
+import tvm
+import nnvm.symbol as sym
+import nnvm.compiler
+import nnvm.runtime
+from nnvm import frontend
+import mxnet as mx
+import model_zoo
+
+USE_GPU = True
+
+def default_target():
+    if USE_GPU:
+        return 'cuda'
+    return 'llvm'
+
+def default_ctx():
+    if USE_GPU:
+        return tvm.gpu(0)
+    return tvm.cpu(0)
+
+def test_mxnet_frontend_impl(mx_symbol, data_shape=(2, 3, 224, 224), out_shape=(2, 1000)):
+    def get_mxnet_output(symbol, x, dtype='float32'):
+        from collections import namedtuple
+        Batch = namedtuple('Batch', ['data'])
+        mod = mx.mod.Module(symbol, label_names=None)
+        mod.bind(data_shapes=[('data', x.shape)], for_training=False)
+        mod.init_params()
+        mod.forward(Batch([mx.nd.array(x.astype(dtype))]))
+        out = mod.get_outputs()[0].asnumpy()
+        args, auxs = mod.get_params()
+        return out, args, auxs
+
+    def get_tvm_output(symbol, x, args, auxs, dtype='float32'):
+        dshape = x.shape
+        shape_dict = {'data': dshape}
+        for k, v in args.items():
+            shape_dict[k] = v.shape
+        for k, v in auxs.items():
+            shape_dict[k] = v.shape
+        graph, lib, _ = nnvm.compiler.build(symbol, default_target(), shape_dict)
+        m = nnvm.runtime.create(graph, lib, default_ctx())
+        # get member functions
+        set_input, run, get_output = m['set_input'], m['run'], 
m['get_output'] + # set inputs + set_input('data', tvm.nd.array(x.astype(dtype))) + for k, v in args.items(): + set_input(k, tvm.nd.array(v.asnumpy().astype(dtype))) + for k, v in auxs.items(): + set_input(k, tvm.nd.array(v.asnumpy().astype(dtype))) + # execute + run() + # get outputs + out = tvm.nd.empty(out_shape, dtype) + get_output(0, out) + return out.asnumpy() + + # random input + dtype = 'float32' + x = np.random.uniform(size=data_shape) + mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype) + new_sym = frontend.from_mxnet(mx_symbol) + tvm_out = get_tvm_output(new_sym, x, args, auxs, dtype) + np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5) + +def test_forward_mlp(): + mlp = model_zoo.mx_mlp + test_mxnet_frontend_impl(mlp) + +def test_forward_vgg(): + for n in [11]: + mx_sym = model_zoo.mx_vgg[n] + test_mxnet_frontend_impl(mx_sym) + +def test_forward_resnet(): + for n in [18]: + mx_sym = model_zoo.mx_resnet[n] + test_mxnet_frontend_impl(mx_sym) + +if __name__ == '__main__': + test_forward_mlp() + # waiting for max_pool2d + # test_forward_vgg() + # test_forward_resnet() diff --git a/nnvm/tests/python/frontend/mxnet/test_graph.py b/nnvm/tests/python/frontend/mxnet/test_graph.py new file mode 100644 index 000000000000..7701aa86ae58 --- /dev/null +++ b/nnvm/tests/python/frontend/mxnet/test_graph.py @@ -0,0 +1,38 @@ +import mxnet as mx +import nnvm +from nnvm.compiler import graph_util, graph_attr +import model_zoo + +def compare_graph(sym1, sym2, ishape=(2, 3, 224, 224)): + g1 = nnvm.graph.create(sym1) + g2 = nnvm.graph.create(sym2) + graph_attr.set_shape_inputs(g1, {'data':ishape}) + graph_attr.set_shape_inputs(g2, {'data':ishape}) + g1 = g1.apply("InferShape").apply("SimplifyInference") + g2 = g2.apply("InferShape").apply("SimplifyInference") + graph_util.check_graph_equal(g1, g2) + +def test_mlp(): + mx_sym = model_zoo.mx_mlp + from_mx_sym = nnvm.frontend.from_mxnet(mx_sym) + nnvm_sym = model_zoo.nnvm_mlp + compare_graph(from_mx_sym, nnvm_sym) + +def test_vgg(): + for n in [11, 13, 16, 19]: + mx_sym = model_zoo.mx_vgg[n] + from_mx_sym = nnvm.frontend.from_mxnet(mx_sym) + nnvm_sym = model_zoo.nnvm_vgg[n] + compare_graph(from_mx_sym, nnvm_sym) + +def test_resnet(): + for n in [18, 34, 50, 101]: + mx_sym = model_zoo.mx_resnet[n] + from_mx_sym = nnvm.frontend.from_mxnet(mx_sym) + nnvm_sym = model_zoo.nnvm_resnet[n] + compare_graph(from_mx_sym, nnvm_sym) + +if __name__ == '__main__': + test_mlp() + test_vgg() + test_resnet()
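
Usage sketch: the converter entry point added by this patch is nnvm.frontend.from_mxnet, and test_forward.py above exercises it end to end through the MXNet module API. Below is a minimal standalone sketch of the same flow, reusing only APIs that appear in this patch (from_mxnet, nnvm.compiler.build, nnvm.runtime.create and the module's set_input/run/get_output functions). The network, the 'llvm' CPU target, and the parameter names/shapes (fc1_weight, fc1_bias, following MXNet's "<op name>_weight"/"<op name>_bias" convention) are illustrative assumptions, not part of the patch:

    import numpy as np
    import mxnet as mx
    import tvm
    import nnvm.compiler
    import nnvm.runtime
    from nnvm import frontend

    # Build a tiny mxnet symbol, mirroring model_zoo/mlp.py
    data = mx.sym.Variable('data')
    fc = mx.sym.FullyConnected(data=mx.sym.Flatten(data=data), name='fc1', num_hidden=10)
    net = mx.sym.softmax(data=fc, name='softmax')

    # Convert to an nnvm symbol and compile; every input, including
    # parameters, needs an entry in the shape dict
    nnvm_sym = frontend.from_mxnet(net)
    shape_dict = {'data': (1, 1, 28, 28),
                  'fc1_weight': (10, 784),   # (num_hidden, flattened input)
                  'fc1_bias': (10,)}
    graph, lib, _ = nnvm.compiler.build(nnvm_sym, 'llvm', shape_dict)

    # Run through the graph runtime with random parameters
    m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
    m['set_input']('data', tvm.nd.array(np.random.uniform(size=(1, 1, 28, 28)).astype('float32')))
    m['set_input']('fc1_weight', tvm.nd.array(np.random.uniform(size=(10, 784)).astype('float32')))
    m['set_input']('fc1_bias', tvm.nd.array(np.zeros((10,), dtype='float32')))
    m['run']()
    out = tvm.nd.empty((1, 10), 'float32')
    m['get_output'](0, out)
    print(out.asnumpy())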