
[Relay][Frontend] Add ops in mxnet converter #2844

Merged · 2 commits · Mar 20, 2019
53 changes: 49 additions & 4 deletions python/tvm/relay/frontend/mxnet.py
@@ -213,7 +213,7 @@ def _mx_slice_axis(inputs, attrs):
    ax_end = attrs.get_str("end")
    if axis < 0:
        axis += len(shape)
-   assert axis >= 0 and axis < len(shape)
+   assert 0 <= axis < len(shape)
    if ax_end == "None":
        ax_end = int(shape[axis])
    else:
@@ -222,8 +222,8 @@ def _mx_slice_axis(inputs, attrs):
        ax_beg += int(shape[axis])
    if ax_end < 0:
        ax_end += int(shape[axis])
-   assert ax_beg >= 0 and ax_beg < int(shape[axis])
-   assert ax_end > ax_beg and ax_end <= int(shape[axis])
+   assert 0 <= ax_beg < int(shape[axis])
+   assert ax_beg < ax_end <= int(shape[axis])
    begin = []
    end = []
    for i, dim in enumerate(shape):
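
The tightened assertions rely on Python chained comparisons, which are exact equivalents of the "and" forms they replace (each operand is evaluated once, left to right). A minimal sanity check of the equivalence:

    # `0 <= axis < len(shape)` desugars to `0 <= axis and axis < len(shape)`,
    # with `axis` evaluated only once.
    shape = (1, 3, 1)
    for axis in range(-2, 5):
        assert (0 <= axis < len(shape)) == (0 <= axis and axis < len(shape))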
@@ -516,11 +516,53 @@ def _mx_shape_array(inputs, attrs):
    return _op.shape_of(inputs[0], dtype='int64')


def _mx_full(inputs, attrs):
    assert len(inputs) == 0
    val = attrs.get_float("value")
    shape = attrs.get_int_tuple("shape")
    dtype = attrs.get_str("dtype", "float32")
    return _op.full(_expr.const(val, dtype), shape, dtype)
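
The converter materializes the scalar as a Relay constant and fills the requested shape with it; the result matches numpy's full semantics. A reference sketch, with illustrative values mirroring the test case further down:

    import numpy as np

    # What _full with value=2.0, shape=(3, 4), dtype="float32" should evaluate to:
    ref = np.full((3, 4), 2.0, dtype="float32")
    assert ref.shape == (3, 4) and float(ref[0, 0]) == 2.0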


def _mx_squeeze(inputs, attrs):
    assert len(inputs) == 1
    axis = attrs.get_int_tuple("axis", None)
    return _op.squeeze(inputs[0], axis)
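
squeeze drops size-1 dimensions: with axis=None every unit dimension goes, otherwise only the listed ones. A numpy sketch of the expected shapes, using the same cases as the test below:

    import numpy as np

    x = np.zeros((1, 3, 1), dtype="float32")
    assert np.squeeze(x).shape == (3,)            # axis=None: drop all unit dims
    assert np.squeeze(x, axis=0).shape == (3, 1)  # drop only axis 0
    assert np.squeeze(x, axis=(0, 2)).shape == (3,)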


def _mx_broadcast_axis(inputs, attrs):
    assert len(inputs) == 1
    axis = attrs.get_int_tuple("axis", [])
Contributor: Could axis be negative here?

Contributor: NVM, axis should have only non-negative values.
    size = attrs.get_int_tuple("size", [])
    assert len(axis) == len(size)
    if len(axis) == 0:
        return inputs[0]
    src_shape = ir_pass.infer_type(inputs[0])._checked_type_.shape
    tgt_shape = []
    for i, dim in enumerate(src_shape):
        if i not in axis:
            tgt_shape.append(dim)
        else:
            assert int(dim) == 1
            idx = axis.index(i)
            tgt_shape.append(size[idx])
    return _op.broadcast_to(inputs[0], tgt_shape)
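
A dry run of the shape computation above, using the same values as the test case further down (source shape (1, 2, 1) with axis=(0, 2), size=(2, 3)):

    src_shape, axis, size = (1, 2, 1), (0, 2), (2, 3)
    tgt_shape = []
    for i, dim in enumerate(src_shape):
        if i not in axis:
            tgt_shape.append(dim)                  # keep untouched dimensions
        else:
            assert int(dim) == 1                   # only size-1 axes can broadcast
            tgt_shape.append(size[axis.index(i)])
    assert tgt_shape == [2, 2, 3]                  # handed to _op.broadcast_to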


def _mx_embedding(inputs, _):
    assert len(inputs) == 2
    indices, weight = inputs
    return _op.take(weight, indices.astype('int32'), axis=0)
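
Embedding is lowered to a plain row gather: take along axis 0 of the weight matrix, after casting the (float) indices to int32. A numpy equivalent, with sizes mirroring the (2, 2) / (4, 5) test case below:

    import numpy as np

    weight = np.random.uniform(size=(4, 5)).astype("float32")
    indices = np.array([[0, 2], [3, 1]], dtype="float32")    # MXNet feeds float indices
    out = np.take(weight, indices.astype("int32"), axis=0)   # one weight row per index
    assert out.shape == (2, 2, 5)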


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
    "log",
    "exp",
    "sqrt",
    "floor",
    "ceil",
    "sigmoid",
    "tanh",
    "negative",
@@ -556,7 +598,6 @@ def _mx_shape_array(inputs, attrs):
    "Flatten" : _rename(_op.nn.batch_flatten),
    # scalar power
    "square" : _mx_make_power(2),
-   "sqrt" : _mx_make_power(1/2),
    "rsqrt" : _mx_make_power(-1/2),
    "cbrt" : _mx_make_power(1/3),
    "rcbrt" : _mx_make_power(-1/3),
@@ -638,11 +679,15 @@ def _mx_shape_array(inputs, attrs):
    "batch_dot" : _mx_batch_dot,
    "LeakyReLU" : _mx_leaky_relu,
    "_arange" : _mx_arange,
+   "_full" : _mx_full,
    "repeat" : _mx_repeat,
    "tile" : _mx_tile,
    "reverse" : _mx_reverse,
+   "squeeze" : _mx_squeeze,
+   "broadcast_axis": _mx_broadcast_axis,
    "BlockGrad" : _mx_BlockGrad,
    "shape_array" : _mx_shape_array,
+   "Embedding" : _mx_embedding,
    "SoftmaxOutput" : _mx_softmax_output,
    "SoftmaxActivation" : _mx_softmax_activation,
    # vision
74 changes: 73 additions & 1 deletion tests/python/frontend/mxnet/test_forward.py
@@ -379,7 +379,6 @@ def test_forward_l2_normalize():
    mx_sym = mx.sym.L2Normalization(data, mode="channel")
    verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5))

def test_forward_shape_array():
    def verify(shape):
        x_np = np.random.uniform(size=shape).astype("float32")
@@ -395,6 +394,75 @@ def verify(shape):
    verify((3, 4, 5))
    verify((3, 4, 5, 6))

def test_forward_squeeze():
    def verify(shape, axis):
        x_np = np.random.uniform(size=shape).astype("float32")
        if axis is None:
            ref_res = mx.nd.squeeze(mx.nd.array(x_np))
            mx_sym = mx.sym.squeeze(mx.sym.var("x"))
        else:
            ref_res = mx.nd.squeeze(mx.nd.array(x_np), axis=axis)
            mx_sym = mx.sym.squeeze(mx.sym.var("x"), axis=axis)
        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(new_sym)(x_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
    verify((1, 3, 1), None)
    verify((1, 3, 1), 0)
    verify((1, 3, 1), 2)
    verify((1, 3, 1), (0, 2))

def test_forward_broadcast_axis():
    def verify(shape, axis, size):
        x_np = np.random.uniform(size=shape).astype("float32")
        ref_res = mx.nd.broadcast_axis(mx.nd.array(x_np), axis=axis, size=size)
        mx_sym = mx.sym.broadcast_axis(mx.sym.var("x"), axis=axis, size=size)
        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(new_sym)(x_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
    verify((1, 2, 1), 2, 3)
    verify((1, 2, 1), (0, 2), (2, 3))

def test_forward_full():
    def verify(val, shape, dtype):
        ctx = mx.cpu()
        ref_res = mx.nd.full(shape, val, dtype=dtype)
        mx_sym = mx.sym.full(shape, val, dtype=dtype)
        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {})
        for target, ctx in ctx_list():
            # Skip testing graph runtime because this op will be optimized out
            # by constant folding.
            for kind in ["debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(new_sym)()
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
    verify(2, (3, 4), "float32")
    verify(2, (3, 4), "int32")
    verify(3.5, (1, 3, 4), "float32")

def test_forward_embedding():
    def verify(data_shape, weight_shape):
        in_dim, out_dim = weight_shape
        x_np = np.random.randint(0, weight_shape[0], size=data_shape).astype("float32")
        w_np = np.random.uniform(size=weight_shape).astype("float32")
        ref_res = mx.nd.Embedding(mx.nd.array(x_np), mx.nd.array(w_np),
                                  input_dim=in_dim, output_dim=out_dim)
        mx_sym = mx.sym.Embedding(mx.sym.var("x"), mx.sym.var("w"),
                                  input_dim=in_dim, output_dim=out_dim)
        new_sym, _ = relay.frontend.from_mxnet(
            mx_sym, {"x": data_shape, "w": weight_shape})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(new_sym)(x=x_np, w=w_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
    verify((2, 2), (4, 5))
    verify((2, 3, 4), (4, 5))

if __name__ == '__main__':
    test_forward_mlp()
@@ -426,3 +494,7 @@ def verify(shape):
    test_forward_slice_axis()
    test_forward_l2_normalize()
    test_forward_shape_array()
+   test_forward_squeeze()
+   test_forward_broadcast_axis()
+   test_forward_full()
+   test_forward_embedding()