Skip to content

Commit

Permalink
support MXNet _minimum and _maximum (#2709)
Browse files Browse the repository at this point in the history
haojin2 authored and icemelon committed Mar 1, 2019

Verified

This commit was signed with the committer’s verified signature.
mikutas Takumi Sue
1 parent c8259e3 commit 8f5c27b
Showing 2 changed files with 72 additions and 0 deletions.
8 changes: 8 additions & 0 deletions nnvm/python/nnvm/frontend/mxnet.py
Original file line number Diff line number Diff line change
@@ -286,6 +286,12 @@ def _lrn(inputs, attrs):
new_attrs['size'] = _required_attr(attrs, 'nsize')
return _get_nnvm_op(op_name)(*inputs, **new_attrs)

def _minimum(inputs, attrs):
return _get_nnvm_op('broadcast_min')(*inputs, **attrs)

def _maximum(inputs, attrs):
return _get_nnvm_op('broadcast_max')(*inputs, **attrs)

def _ones(_, attrs):
op_name = 'ones'
return _get_nnvm_op(op_name)(**attrs)
@@ -330,6 +336,8 @@ def _argmin(inputs, attrs):
'_rminus_scalar': _rename('__rsub_scalar__'),
'_contrib_MultiBoxPrior' : _rename('multibox_prior'),
'_contrib_MultiBoxDetection' : _contrib_multibox_detection,
'_minimum' : _minimum,
'_maximum' : _maximum,
'_ones' : _ones,
'_zeros' : _zeros,
'argmax' : _argmax,
64 changes: 64 additions & 0 deletions nnvm/tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
@@ -227,6 +227,68 @@ def test_forward_slice():
mx_sym = mx.sym.slice(data, begin=(-1, 1), end=(-3, 4), step=(-1, 2))
verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 2))

def test_forward_maximum():
a = mx.sym.var('a')
b = mx.sym.var('b')
dshape = (10, 20)
dtype = 'float32'
mx_sym = mx.sym._internal._maximum(a, b)
np_a = np.random.uniform(size=dshape).astype(dtype)
np_b = np.random.uniform(size=dshape).astype(dtype)
mx_a = mx.nd.array(np_a)
mx_b = mx.nd.array(np_b)
mod = mx.mod.Module(mx_sym, label_names=None, data_names=['a', 'b'])
mod.bind(data_shapes=[('a', dshape), ('b', dshape)], for_training=False)
mod.init_params()
args, auxs = mod.get_params()
mx_out = mx.nd._internal._maximum(mx_a, mx_b).asnumpy()
out_shape = dshape
new_sym, params = frontend.from_mxnet(mx_sym, args, auxs)
shape_dict = {'a': dshape, 'b': dshape}
for target, ctx in ctx_list():
with nnvm.compiler.build_config(opt_level=3):
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input("a", tvm.nd.array(np_a))
m.set_input("b", tvm.nd.array(np_b))
m.set_input(**params)
m.run()
# get outputs
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)

def test_forward_minimum():
a = mx.sym.var('a')
b = mx.sym.var('b')
dshape = (10, 20)
dtype = 'float32'
mx_sym = mx.sym._internal._minimum(a, b)
np_a = np.random.uniform(size=dshape).astype(dtype)
np_b = np.random.uniform(size=dshape).astype(dtype)
mx_a = mx.nd.array(np_a)
mx_b = mx.nd.array(np_b)
mod = mx.mod.Module(mx_sym, label_names=None, data_names=['a', 'b'])
mod.bind(data_shapes=[('a', dshape), ('b', dshape)], for_training=False)
mod.init_params()
args, auxs = mod.get_params()
mx_out = mx.nd._internal._minimum(mx_a, mx_b).asnumpy()
out_shape = dshape
new_sym, params = frontend.from_mxnet(mx_sym, args, auxs)
shape_dict = {'a': dshape, 'b': dshape}
for target, ctx in ctx_list():
with nnvm.compiler.build_config(opt_level=3):
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input("a", tvm.nd.array(np_a))
m.set_input("b", tvm.nd.array(np_b))
m.set_input(**params)
m.run()
# get outputs
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)


if __name__ == '__main__':
test_forward_mlp()
@@ -251,4 +313,6 @@ def test_forward_slice():
test_forward_argmin()
test_forward_where()
test_forward_slice()
test_forward_maximum()
test_forward_minimum()

0 comments on commit 8f5c27b

Please sign in to comment.