Skip to content

Commit

Permalink
[Relay][Frontend] Add a few mxnet ops in relay frontend (apache#2704)
Browse files Browse the repository at this point in the history
  • Loading branch information
icemelon authored and wweic committed Mar 12, 2019
1 parent 215aedb commit 76e83df
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 26 deletions.
79 changes: 53 additions & 26 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ def _stable_softrelu(x):
raise RuntimeError("Do not support act_type: {}".format(act_type))


def _mx_compare(new_op, wrapper):
def impl(inputs, attrs):
dtype = ir_pass.infer_type(inputs[0]).checked_type.dtype
return wrapper(new_op)(inputs, attrs).astype(dtype)
return impl


def _mx_conv2d(inputs, attrs):
kernel_size = attrs.get_int_tuple("kernel")
if len(kernel_size) != 2:
Expand Down Expand Up @@ -333,32 +340,52 @@ def _mx_roi_align(inputs, attrs):
]

_convert_map = {
"_copy" : _rename(_op.copy),
"relu" : _rename(_op.nn.relu),
"broadcast_add" : _rename(_op.add),
"broadcast_sub" : _rename(_op.subtract),
"broadcast_mul" : _rename(_op.multiply),
"broadcast_div" : _rename(_op.divide),
"elemwise_add" : _rename(_op.add),
"elemwise_sub" : _rename(_op.subtract),
"elemwise_mul" : _rename(_op.multiply),
"elemwise_div" : _rename(_op.divide),
"flatten" : _rename(_op.nn.batch_flatten),
"Flatten" : _rename(_op.nn.batch_flatten),
"_plus_scalar" : _binop_scalar(_op.add),
"__add_scalar__": _binop_scalar(_op.add),
"__sub_scalar__": _binop_scalar(_op.subtract),
"_minus_scalar" : _binop_scalar(_op.subtract),
"__mul_scalar__": _binop_scalar(_op.multiply),
"_mul_scalar" : _binop_scalar(_op.multiply),
"__div_scalar__": _binop_scalar(_op.divide),
"_div_scalar" : _binop_scalar(_op.divide),
"__pow_scalar__": _binop_scalar(_op.power),
"_rminus_scalar": _rbinop_scalar(_op.subtract),
"__rsub_scalar__": _rbinop_scalar(_op.subtract),
"_rdiv_scalar" : _rbinop_scalar(_op.divide),
"__rdiv_scalar__" : _rbinop_scalar(_op.divide),
"__rpow_scalar__": _rbinop_scalar(_op.power),
"_copy" : _rename(_op.copy),
"relu" : _rename(_op.nn.relu),
"broadcast_add" : _rename(_op.add),
"broadcast_sub" : _rename(_op.subtract),
"broadcast_mul" : _rename(_op.multiply),
"broadcast_div" : _rename(_op.divide),
"broadcast_mod" : _rename(_op.mod),
"broadcast_maximum" : _rename(_op.maximum),
"broadcast_minimum" : _rename(_op.minimum),
"broadcast_equal" : _mx_compare(_op.equal, _rename),
"broadcast_not_equal" : _mx_compare(_op.not_equal, _rename),
"broadcast_greater" : _mx_compare(_op.greater, _rename),
"broadcast_greater_equal": _mx_compare(_op.greater_equal, _rename),
"broadcast_lesser" : _mx_compare(_op.less, _rename),
"broadcast_lesser_equal" : _mx_compare(_op.less_equal, _rename),
"elemwise_add" : _rename(_op.add),
"elemwise_sub" : _rename(_op.subtract),
"elemwise_mul" : _rename(_op.multiply),
"elemwise_div" : _rename(_op.divide),
"_maximum" : _rename(_op.maximum),
"_minimum" : _rename(_op.minimum),
"flatten" : _rename(_op.nn.batch_flatten),
"Flatten" : _rename(_op.nn.batch_flatten),
"__add_scalar__" : _binop_scalar(_op.add),
"_plus_scalar" : _binop_scalar(_op.add),
"__sub_scalar__" : _binop_scalar(_op.subtract),
"_minus_scalar" : _binop_scalar(_op.subtract),
"__mul_scalar__" : _binop_scalar(_op.multiply),
"_mul_scalar" : _binop_scalar(_op.multiply),
"__div_scalar__" : _binop_scalar(_op.divide),
"_div_scalar" : _binop_scalar(_op.divide),
"__pow_scalar__" : _binop_scalar(_op.power),
"_power_scalar" : _binop_scalar(_op.power),
"__rsub_scalar__" : _rbinop_scalar(_op.subtract),
"_rminus_scalar" : _rbinop_scalar(_op.subtract),
"__rdiv_scalar__" : _rbinop_scalar(_op.divide),
"_rdiv_scalar" : _rbinop_scalar(_op.divide),
"__rpow_scalar__" : _rbinop_scalar(_op.power),
"_equal_scalar" : _mx_compare(_op.equal, _binop_scalar),
"_not_equal_scalar" : _mx_compare(_op.not_equal, _binop_scalar),
"_greater_scalar" : _mx_compare(_op.greater, _binop_scalar),
"_greater_equal_scalar" : _mx_compare(_op.greater_equal, _binop_scalar),
"_lesser_scalar" : _mx_compare(_op.less, _binop_scalar),
"_lesser_equal_scalar" : _mx_compare(_op.less_equal, _binop_scalar),
"_maximum_scalar" : _binop_scalar(_op.maximum),
"_minimum_scalar" : _binop_scalar(_op.minimum),
# reduction ops
"max" : _reduce(_op.max),
"min" : _reduce(_op.min),
Expand Down
83 changes: 83 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import operator

import tvm
from tvm.contrib import graph_runtime
Expand Down Expand Up @@ -256,6 +257,85 @@ def verify(start, stop, step):
verify(20, 1, -1)
verify(20, 1, -1.5)

def _mx_symbol(F, op_name, inputs):
op = getattr(F, op_name)
return op(*inputs)

def test_forward_broadcast_ops():
for op in ["broadcast_add", "broadcast_sub", "broadcast_mul",
"broadcast_div", "broadcast_mod", "broadcast_maximum",
"broadcast_minimum", "broadcast_equal", "broadcast_not_equal",
"broadcast_greater", "broadcast_greater_equal",
"broadcast_lesser", "broadcast_lesser_equal"]:
a_shape = (3, 4, 5)
b_shape = (4, 5)
if op == "broadcast_mod":
dtype = 'int32'
a_np = np.random.randint(1, 100, size=a_shape).astype(dtype)
b_np = np.random.randint(1, 100, size=b_shape).astype(dtype)
else:
dtype = 'float32'
a_np = np.random.uniform(size=a_shape).astype(dtype)
b_np = np.random.uniform(size=b_shape).astype(dtype)
mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')])
ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)])
shapes = {'a': a_shape, 'b': b_shape}
new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(a_np, b_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

def test_forward_elemwise_ops():
for op in ["elemwise_add", "elemwise_sub", "elemwise_mul",
"elemwise_div", "maximum", "minimum"]:
shape = (3, 4, 5)
dtype = 'float32'
a_np = np.random.uniform(size=shape).astype(dtype)
b_np = np.random.uniform(size=shape).astype(dtype)
mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')])
ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)])
shapes = {'a': shape, 'b': shape}
new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(a_np, b_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

def test_forward_scalar_ops():
for op in [operator.add, operator.sub, operator.mul, operator.truediv,
operator.pow, operator.lt, operator.le, operator.eq,
operator.ne, operator.gt, operator.ge]:
dtype='float32'
a_shape = (3, 4, 5)
a_np = np.random.uniform(size=a_shape).astype(dtype)
b_scalar = 2.3
mx_sym = op(mx.sym.var('a'), b_scalar)
ref_res = op(mx.nd.array(a_np), b_scalar)
shapes = {'a': a_shape}
new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(a_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
for op in ["maximum", "minimum"]:
dtype='float32'
a_shape = (3, 4, 5)
a_np = np.random.uniform(size=a_shape).astype(dtype)
b_scalar = 2.3
mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), b_scalar])
ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), b_scalar])
shapes = {'a': a_shape}
new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(a_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())


if __name__ == '__main__':
test_forward_mlp()
Expand All @@ -280,3 +360,6 @@ def verify(start, stop, step):
test_forward_argmin()
test_forward_where()
test_forward_arange()
test_forward_broadcast_ops()
test_forward_elemwise_ops()
test_forward_scalar_ops()

0 comments on commit 76e83df

Please sign in to comment.