Skip to content

Commit

Permalink
[FRONTEND] Composed operators (apache#175)
Browse files Browse the repository at this point in the history
* fix for composed symbol

* fix

* clean up

* fix exception type
  • Loading branch information
zhreshold authored and tqchen committed May 26, 2018
1 parent c4fe8c5 commit 2cdc0d4
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 59 deletions.
133 changes: 75 additions & 58 deletions nnvm/python/nnvm/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

__all__ = ['from_mxnet']

def _get_nnvm_op(op_name):
op = getattr(_sym, op_name)
if not op:
raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
return op

def _get_mxnet_version():
try:
import mxnet as mx
Expand Down Expand Up @@ -39,14 +45,11 @@ def _parse_bool_str(attr, key, default='False'):
return attr.get(key, default).strip().lower() in ['true', '1', 't', 'y', 'yes']

def _rename(new_name):
def impl(attr):
return new_name, attr
def impl(inputs, attrs):
return _get_nnvm_op(new_name)(*inputs, **attrs)
return impl

def _variable(attrs):
return "Variable", attrs

def _pooling(attrs):
def _pooling(inputs, attrs):
kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
if len(kernel) != 2:
_raise_not_supported('non-2d kernel', 'pool_2d')
Expand All @@ -61,9 +64,9 @@ def _pooling(attrs):
new_attrs['strides'] = attrs.get('stride', (1, 1))
new_attrs['padding'] = attrs.get('pad', (0, 0))
new_attrs['ceil_mode'] = (attrs.get('pooling_convention', 'valid') == 'full')
return op_name, new_attrs
return _get_nnvm_op(op_name)(*inputs, **new_attrs)

def _batch_norm(attrs):
def _batch_norm(inputs, attrs):
if _parse_bool_str(attrs, 'output_mean_var'):
_raise_not_supported('output_mean_var', 'batch_norm')
# if _parse_bool_str(attrs, 'fix_gamma'):
Expand All @@ -77,14 +80,14 @@ def _batch_norm(attrs):
new_attrs['epsilon'] = attrs.get('eps', 0.001)
new_attrs['center'] = True
new_attrs['scale'] = True
return op_name, new_attrs
return _get_nnvm_op(op_name)(*inputs, **new_attrs)

def _concat(attrs):
def _concat(inputs, attrs):
op_name = 'concatenate'
new_attrs = {'axis': attrs.get('dim', 1)}
return op_name, new_attrs
return _get_nnvm_op(op_name)(*inputs, **new_attrs)

def _conv2d(attrs):
def _conv2d(inputs, attrs):
kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
if len(kernel) != 2:
_raise_not_supported('non 2d kernel', 'conv2d')
Expand All @@ -100,9 +103,9 @@ def _conv2d(attrs):
new_attrs['groups'] = attrs.get('num_group', 1)
new_attrs['layout'] = layout
new_attrs['use_bias'] = attrs.get('no_bias', 'False').strip() == 'False'
return op_name, new_attrs
return _get_nnvm_op(op_name)(*inputs, **new_attrs)

def _conv2d_transpose(attrs):
def _conv2d_transpose(inputs, attrs):
if 'target_shape' in attrs:
_raise_not_supported('target_shape', 'conv2d_transpose')
kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
Expand All @@ -121,51 +124,68 @@ def _conv2d_transpose(attrs):
new_attrs['groups'] = attrs.get('num_group', 1)
new_attrs['layout'] = layout
new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
return op_name, new_attrs
return _get_nnvm_op(op_name)(*inputs, **new_attrs)

def _dense(attrs):
def _dense(inputs, attrs):
op_name, new_attrs = 'dense', {}
new_attrs['units'] = _required_attr(attrs, 'num_hidden')
new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
major, minor, micro = _get_mxnet_version()
if major >= 0 and minor >= 11 and micro >= 1:
new_attrs['flatten'] = _parse_bool_str(attrs, 'flatten', 'True')
return op_name, new_attrs
use_flatten = _parse_bool_str(attrs, 'flatten', 'True')
if use_flatten:
inputs[0] = _sym.flatten(inputs[0])
return _get_nnvm_op(op_name)(*inputs, **new_attrs)

def _dropout(attrs):
def _dropout(inputs, attrs):
op_name, new_attrs = 'dropout', {}
new_attrs['rate'] = attrs.get('p', 0.5)
return op_name, new_attrs
return _get_nnvm_op(op_name)(*inputs, **new_attrs)

def _leaky_relu(attrs):
def _leaky_relu(inputs, attrs):
act_type = _required_attr(attrs, 'act_type')
if act_type not in ['leaky']:
if act_type in ['leaky']:
op_name, new_attrs = 'leaky_relu', {}
new_attrs['alpha'] = attrs.get('slope', 0.25)
sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
elif act_type == 'elu':
slope = attrs.get('slope', 0.25)
sym = -slope * _sym.relu(1 - _sym.exp(*inputs)) + _sym.relu(*inputs)
elif act_type == 'rrelu':
lower_bound = float(_required_attr(attrs, 'lower_bound'))
upper_bound = float(_required_attr(attrs, 'upper_bound'))
slope = (lower_bound + upper_bound) / 2.0
op_name, new_attrs = 'leaky_relu', {'alpha': str(slope)}
sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
else:
_raise_not_supported('act_type: ' + act_type)
op_name, new_attrs = 'leaky_relu', {}
new_attrs['alpha'] = attrs.get('slope', 0.25)
return op_name, new_attrs
return sym

def _activations(attrs):
def _activations(inputs, attrs):
act_type = _required_attr(attrs, 'act_type')
if act_type not in ['relu', 'sigmoid', 'tanh']:
if act_type in ['relu', 'sigmoid', 'tanh']:
op_name, new_attrs = act_type, {}
sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
elif act_type == 'softrelu':
sym = _sym.log((1 + _sym.exp(*inputs)))
else:
_raise_not_supported('act_type: ' + act_type)
op_name, new_attrs = act_type, {}
return op_name, new_attrs
return sym

def _reshape(attrs):
def _reshape(inputs, attrs):
if _parse_bool_str(attrs, 'reverse'):
_raise_not_supported('reverse', 'reshape')
op_name, new_attrs = 'reshape', {}
new_attrs['shape'] = _required_attr(attrs, 'shape')
return op_name, new_attrs
return _get_nnvm_op(op_name)(*inputs, **new_attrs)

def _split(attrs):
def _split(inputs, attrs):
if _parse_bool_str(attrs, 'squeeze_axis'):
_raise_not_supported('squeeze_axis', 'split')
op_name, new_attrs = 'split', {}
new_attrs['indices_or_sections'] = _required_attr(attrs, 'num_outputs')
new_attrs['axis'] = attrs.get('axis', 1)
return op_name, new_attrs
return _get_nnvm_op(op_name)(*inputs, **new_attrs)

_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
Expand All @@ -178,7 +198,12 @@ def _split(attrs):
'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']

_convert_map = {
'null' : _variable,
'_div_scalar' : _rename('__div_scalar__'),
'_minus_scalar' : _rename('__sub_scalar__'),
'_mul_scalar' : _rename('__mul_scalar__'),
'_plus_scalar' : _rename('__add_scalar__'),
'_rdiv_scalar' : _rename('__rdiv_scalar__'),
'_rminus_scalar': _rename('__rsub_scalar__'),
'Activation' : _activations,
'BatchNorm' : _batch_norm,
'BatchNorm_v1' : _batch_norm,
Expand All @@ -202,7 +227,7 @@ def _split(attrs):
'sum_axis' : _rename('sum'),
}

def _convert_symbol(op_name, attrs,
def _convert_symbol(op_name, inputs, attrs,
identity_list=None,
convert_map=None):
"""Convert from mxnet op to nnvm op.
Expand All @@ -213,6 +238,8 @@ def _convert_symbol(op_name, attrs,
----------
op_name : str
Operator name, such as Convolution, FullyConnected
inputs : list of nnvm.Symbol
List of input symbols.
attrs : dict
Dict of operator attributes
identity_list : list
Expand All @@ -224,21 +251,19 @@ def _convert_symbol(op_name, attrs,
Returns
-------
(op_name, attrs)
Converted (op_name, attrs) for nnvm.
sym : nnvm.Symbol
Converted nnvm Symbol
"""
identity_list = identity_list if identity_list else _identity_list
convert_map = convert_map if convert_map else _convert_map
if op_name in identity_list:
pass
op = _get_nnvm_op(op_name)
sym = op(*inputs, **attrs)
elif op_name in convert_map:
op_name, attrs = convert_map[op_name](attrs)
sym = convert_map[op_name](inputs, attrs)
else:
_raise_not_supported('Operator: ' + op_name)
op = getattr(_sym, op_name, None)
if not op:
raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
return op, attrs
return sym

def _is_mxnet_group_symbol(symbol):
"""Internal check for mxnet group symbol."""
Expand Down Expand Up @@ -274,28 +299,20 @@ def _from_mxnet_impl(symbol, graph):
node = graph.get(name, None)
if node:
return node
attr = symbol.list_attr()
# op_name = symbol.attr('op_name')
if symbol.get_children():
childs = symbol.get_children()
if childs:
op_name = symbol.attr('op_name')
else:
op_name = json.loads(symbol.tojson())['nodes'][0]['op']
attr = symbol.list_attr()
new_op, new_attr = _convert_symbol(op_name, attr)
if new_op == _sym.Variable:
node = new_op(name=name, **new_attr)
else:
childs = symbol.get_children()
childs = [_from_mxnet_impl(c, graph) for c in _as_list(childs)]
childs = [x for y in childs for x in _as_list(y)] # expand group symbol
if new_op == _sym.dense and 'flatten' in new_attr:
if new_attr['flatten']:
childs[0] = _sym.flatten(childs[0])
new_attr.pop('flatten')
node = new_op(name=name, *childs, **new_attr)
node = _convert_symbol(op_name, childs, attr)
else:
op_name = json.loads(symbol.tojson())['nodes'][0]['op']
node = _sym.Variable(name=name, **attr)
graph[name] = node
return node


def from_mxnet(symbol, arg_params=None, aux_params=None):
"""Convert from MXNet's model into compatible NNVM format.
Expand Down
35 changes: 34 additions & 1 deletion nnvm/tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
assert "data" not in args
for target, ctx in ctx_list():
tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5)
np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)

def test_forward_mlp():
mlp = model_zoo.mx_mlp
Expand All @@ -62,7 +62,40 @@ def test_forward_resnet():
mx_sym = model_zoo.mx_resnet[n]
verify_mxnet_frontend_impl(mx_sym)

def test_forward_elu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.LeakyReLU(data, act_type='elu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))

def test_forward_rrelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))

def test_forward_softrelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.Activation(data, act_type='softrelu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))

def test_forward_fc_flatten():
# test flatten=True option in mxnet 0.11.1
data = mx.sym.var('data')
try:
mx_sym = mx.sym.FullyConnected(data, num_hidden=100, flatten=True)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
mx_sym = mx.sym.FullyConnected(mx.sym.Flatten(data), num_hidden=100, flatten=False)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
except:
pass

if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
test_forward_resnet()
test_forward_elu()
test_forward_rrelu()
test_forward_softrelu()
test_forward_fc_flatten()

0 comments on commit 2cdc0d4

Please sign in to comment.