[Frontend][MXNet] Support a few contrib ops in mxnet #5819

Merged · 4 commits · Jun 17, 2020
99 changes: 99 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
@@ -17,6 +17,7 @@
# pylint: disable=invalid-name, import-self, len-as-condition, no-else-return, too-many-lines
"""MXNet symbol frontend."""
import json
import math
import numpy as np
import tvm
from tvm.ir import IRModule
@@ -655,6 +656,15 @@ def _mx_leaky_relu(inputs, attrs):
        upper_bound = attrs.get_float("upper_bound")
        alpha = (lower_bound + upper_bound) / 2.0
        return _op.nn.leaky_relu(inputs[0], alpha=alpha)
    if act_type == "gelu":
        # 0.5 * x * (1 + erf(x / sqrt(2)))
        sqrt2 = _expr.const(math.sqrt(2), dtype="float32")
        erf = _op.erf(_op.divide(inputs[0], sqrt2))
        one = _expr.const(1, dtype="float32")
        erf_plus_one = _op.add(one, erf)
        half = _expr.const(0.5, dtype="float32")
        half_x = _op.multiply(inputs[0], half)
        return _op.multiply(half_x, erf_plus_one)
    raise tvm.error.OpNotImplemented(
        'Operator {} is not supported for frontend MXNet.'.format(act_type))
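
For reference, the erf form of GELU lowered above can be cross-checked against the MXNet op with a small NumPy/SciPy sketch. This is an illustration only, assuming scipy and mxnet are importable; gelu_ref is a hypothetical helper, not code from this PR.

import numpy as np
from scipy.special import erf
import mxnet as mx

def gelu_ref(x):
    # 0.5 * x * (1 + erf(x / sqrt(2))), the same formula the gelu branch lowers to
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

x = np.random.uniform(-2, 2, size=(8,)).astype("float32")
ref = mx.nd.LeakyReLU(mx.nd.array(x), act_type="gelu").asnumpy()
np.testing.assert_allclose(gelu_ref(x), ref, rtol=1e-4, atol=1e-6)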

@@ -784,6 +794,42 @@ def _mx_make_loss(inputs, attrs):
    return inputs[0]


def _mx_contrib_arange_like(inputs, attrs):
    assert len(inputs) == 1
    if attrs.get_int("repeat", 1) != 1:
        raise tvm.error.OpAttributeUnimplemented(
            'Attribute "repeat" is not supported in operator arange_like.')
    ty = _infer_type(inputs[0]).checked_type
    assert ty
    shape, dtype = get_const_tuple(ty.shape), ty.dtype
    axis = attrs.get_int("axis", None)
    if axis is None:
        n_elems = 1
        for dim in shape:
            if not isinstance(dim, int):
                raise tvm.error.OpError("Don't support arange_like with symbolic input shape.")
            n_elems *= dim
    else:
        axis = axis + len(shape) if axis < 0 else axis
        assert 0 <= axis < len(shape)
        n_elems = shape[axis]
        if not isinstance(n_elems, int):
            raise tvm.error.OpError("Don't support arange_like with symbolic input shape.")
        shape = (n_elems,)
    start = attrs.get_float("start", 0.)
    step = attrs.get_float("step", 1.)
    stop = start + step * n_elems
    new_attrs = {}
    new_attrs["start"] = _expr.const(start, dtype=dtype)
    new_attrs["stop"] = _expr.const(stop, dtype=dtype)
    new_attrs["step"] = _expr.const(step, dtype=dtype)
    new_attrs["dtype"] = dtype
    ret = _op.arange(**new_attrs)
    if len(shape) > 1:
        ret = _op.reshape(ret, shape)
    return ret
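
For reference, a minimal NumPy sketch of the semantics this converter reproduces. This is an illustration only; arange_like_ref is a hypothetical helper name, not code from this PR.

import numpy as np

def arange_like_ref(data, start=0.0, step=1.0, axis=None):
    # No axis: the range covers every element and keeps the input's shape.
    # With an axis: a 1-D range whose length is the size of that axis.
    n = data.size if axis is None else data.shape[axis]
    vals = (start + step * np.arange(n)).astype(data.dtype)
    return vals.reshape(data.shape) if axis is None else vals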


def _mx_repeat(inputs, attrs):
    assert len(inputs) == 1
    new_attrs = {}
@@ -1263,6 +1309,56 @@ def _mx_contrib_fifo_buffer(inputs, attrs):
    return _op.nn.fifo_buffer(*inputs, **new_attrs)


def _mx_contrib_interleaved_matmul_selfatt_qk(inputs, attrs):
    """
    tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
    q_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3))
    q_proj = mx.nd.reshape(q_proj, shape=(-1, 0, 0), reverse=True)
    q_proj = mx.nd.contrib.div_sqrt_dim(q_proj)
    k_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3))
    k_proj = mx.nd.reshape(k_proj, shape=(-1, 0, 0), reverse=True)
    output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True)
    """
    assert len(inputs) == 1
    qkv = inputs[0]
    num_heads = attrs.get_int('heads')
    qkv = _op.reshape(qkv, newshape=(0, 0, num_heads, 3, -1))
    q_proj = _op.take(qkv, _expr.const(0, "int32"), axis=3)
    q_proj = _op.transpose(q_proj, axes=[1, 2, 0, 3])
    q_proj = _op.reverse_reshape(q_proj, newshape=(-1, 0, 0))
    q_proj = _mx_contrib_div_sqrt_dim([q_proj], None)
    k_proj = _op.take(qkv, _expr.const(1, "int32"), axis=3)
    k_proj = _op.transpose(k_proj, axes=[1, 2, 0, 3])
    k_proj = _op.reverse_reshape(k_proj, newshape=(-1, 0, 0))
    # Relay's nn.batch_matmul computes A * B^T, which matches batch_dot(..., transpose_b=True).
    ret = _op.nn.batch_matmul(q_proj, k_proj)
    return ret


def _mx_contrib_interleaved_matmul_selfatt_valatt(inputs, attrs):
    """
    tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
    v_proj = mx.nd.transpose(tmp[:,:,:,2,:], axes=(1, 2, 0, 3))
    v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True)
    output = mx.nd.batch_dot(attention, v_proj)
    output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True)
    output = mx.nd.transpose(output, axes=(2, 0, 1, 3))
    output = mx.nd.reshape(output, shape=(0, 0, -1))
    """
    assert len(inputs) == 2
    qkv, att = inputs
    num_heads = attrs.get_int("heads")
    qkv = _op.reshape(qkv, newshape=(0, 0, num_heads, 3, -1))
    v_proj = _op.take(qkv, _expr.const(2, "int32"), axis=3)
    v_proj = _op.transpose(v_proj, axes=(1, 2, 0, 3))
    v_proj = _op.reverse_reshape(v_proj, newshape=(-1, 0, 0))
    # nn.batch_matmul computes A * B^T, so pre-transpose v_proj to obtain att * v.
    v_proj = _op.transpose(v_proj, axes=[0, 2, 1])
    out = _op.nn.batch_matmul(att, v_proj)
    out = _op.reverse_reshape(out, newshape=(-1, num_heads, 0, 0))
    out = _op.transpose(out, axes=(2, 0, 1, 3))
    out = _op.reshape(out, newshape=(0, 0, -1))
    return out
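
For reference, a NumPy sketch of what the two interleaved self-attention ops compute, assuming the (seq_length, batch, num_heads * 3 * head_dim) interleaved layout described in the docstrings. This is an illustration only; the helper names are hypothetical, not code from this PR.

import numpy as np

def selfatt_qk_ref(qkv, heads):
    # qkv: (seq, batch, heads * 3 * head_dim) with q/k/v interleaved per head.
    seq, batch, _ = qkv.shape
    head_dim = qkv.shape[2] // (3 * heads)
    tmp = qkv.reshape(seq, batch, heads, 3, head_dim)
    q = tmp[:, :, :, 0, :].transpose(1, 2, 0, 3).reshape(batch * heads, seq, head_dim)
    k = tmp[:, :, :, 1, :].transpose(1, 2, 0, 3).reshape(batch * heads, seq, head_dim)
    # Scaled attention scores, shape (batch * heads, seq, seq).
    return (q / np.sqrt(head_dim)) @ k.transpose(0, 2, 1)

def selfatt_valatt_ref(qkv, att, heads):
    # att: (batch * heads, seq, seq) attention weights; result: (seq, batch, heads * head_dim).
    seq, batch, _ = qkv.shape
    head_dim = qkv.shape[2] // (3 * heads)
    tmp = qkv.reshape(seq, batch, heads, 3, head_dim)
    v = tmp[:, :, :, 2, :].transpose(1, 2, 0, 3).reshape(batch * heads, seq, head_dim)
    out = att @ v
    out = out.reshape(batch, heads, seq, head_dim).transpose(2, 0, 1, 3)
    return out.reshape(seq, batch, heads * head_dim)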


def _mx_cond(inputs, attrs, subgraphs):
    assert len(subgraphs) == 3
    cond_input_locs = json.loads(attrs.get_str("cond_input_locs"))
@@ -2094,6 +2190,7 @@ def impl(inputs, input_types):
    "smooth_l1" : _mx_smooth_l1,
    "make_loss" : _mx_make_loss,
    "_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim,
    "_contrib_arange_like": _mx_contrib_arange_like,
    "one_hot" : _mx_one_hot,
    "depth_to_space" : _mx_depth_to_space,
    "space_to_depth" : _mx_space_to_depth,
@@ -2114,6 +2211,8 @@ def impl(inputs, input_types):
    # NLP
    "RNN" : _mx_rnn_layer,
    "_rnn_param_concat" : _mx_rnn_param_concat,
    "_contrib_interleaved_matmul_selfatt_qk" : _mx_contrib_interleaved_matmul_selfatt_qk,
    "_contrib_interleaved_matmul_selfatt_valatt" : _mx_contrib_interleaved_matmul_selfatt_valatt,
    # control flow
    "_cond" : _mx_cond,
    # Depricated:
3 changes: 2 additions & 1 deletion python/tvm/relay/frontend/nnvm_common.py
@@ -57,7 +57,8 @@ def _impl(inputs, attrs):
def _softmax_op(new_op):
    """softmax/log_softmax"""
    def _impl(inputs, attrs, _dtype='float32'):
        assert len(inputs) == 1
        # TODO(@icemelon9): currently ignore the 2nd input to softmax for mxnet 1.6
        # assert len(inputs) == 1
        axis = attrs.get_int("axis", -1)
        return new_op(inputs[0], axis=axis)
    return _impl
82 changes: 82 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
@@ -133,6 +133,12 @@ def test_forward_prelu():
    mx_sym = mx.sym.LeakyReLU(data, act_type='prelu')
    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))

def test_forward_gelu():
    data = mx.sym.var('data')
    data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
    mx_sym = mx.sym.LeakyReLU(data, act_type='gelu')
    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))

def test_forward_softrelu():
    data = mx.sym.var('data')
    data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
@@ -1195,6 +1201,78 @@ def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size
    verify((5, 1, 11, 11), kernel_size = 5, max_displacement = 1, stride1 = 1, stride2 = 1, pad_size = 2, is_multiply = False)


def test_forward_arange_like():
    def verify(data_shape, start=None, step=None, axis=None):
        attrs = {}
        if start is not None:
            attrs['start'] = start
        if step is not None:
            attrs['step'] = step
        if axis is not None:
            attrs['axis'] = axis
        data = mx.sym.var('data')
        data_np = np.random.uniform(size=data_shape).astype("float32")
        ref_res = mx.nd.contrib.arange_like(mx.nd.array(data_np), **attrs)

        mx_sym = mx.sym.contrib.arange_like(data, **attrs)
        mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape})
        for target, ctx in ctx_list():
            for kind in ["graph"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()()
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

    verify(data_shape=(3,), start=0., step=1.)
    verify(data_shape=(3, 4, 5), start=0., step=1.)
    verify(data_shape=(3, 4, 5), start=0., step=1., axis=-1)
    verify(data_shape=(3, 4, 5), start=2., step=3., axis=1)


def test_forward_interleaved_matmul_selfatt_qk():
    def verify(batch, seq_length, num_heads, head_dim):
        data_shape = (seq_length, batch, num_heads * head_dim * 3)
        data = mx.sym.var('data')
        data_np = np.random.uniform(size=data_shape).astype('float32')
        ref_res = mx.nd.contrib.interleaved_matmul_selfatt_qk(
            mx.nd.array(data_np), heads=num_heads)

        mx_sym = mx.sym.contrib.interleaved_matmul_selfatt_qk(data, heads=num_heads)
        mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape})
        for target, ctx in ctx_list():
            for kind in ["graph"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(data_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)

    verify(1, 10, 3, 16)
    verify(3, 10, 6, 8)


def test_forward_interleaved_matmul_selfatt_valatt():
    def verify(batch, seq_length, num_heads, head_dim):
        data_shape = (seq_length, batch, num_heads * head_dim * 3)
        weight_shape = (batch * num_heads, seq_length, seq_length)
        data = mx.sym.var('data')
        weight = mx.sym.var('weight')
        data_np = np.random.uniform(size=data_shape).astype('float32')
        weight_np = np.random.uniform(size=weight_shape).astype('float32')
        ref_res = mx.nd.contrib.interleaved_matmul_selfatt_valatt(
            mx.nd.array(data_np), mx.nd.array(weight_np), heads=num_heads)

        mx_sym = mx.sym.contrib.interleaved_matmul_selfatt_valatt(
            data, weight, heads=num_heads)
        mod, _ = relay.frontend.from_mxnet(
            mx_sym, {"data": data_shape, "weight": weight_shape})
        for target, ctx in ctx_list():
            for kind in ["graph"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(data=data_np, weight=weight_np)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)

    verify(1, 10, 4, 16)
    verify(3, 10, 6, 8)


if __name__ == '__main__':
    test_forward_mlp()
    test_forward_vgg()
@@ -1203,6 +1281,7 @@ def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size
    test_forward_elu()
    test_forward_rrelu()
    test_forward_prelu()
    test_forward_gelu()
    test_forward_softrelu()
    test_forward_softmin()
    test_forward_fc_flatten()
@@ -1263,3 +1342,6 @@ def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size
    test_forward_correlation()
    test_forward_grid_generator()
    test_forward_bilinear_sampler()
    test_forward_arange_like()
    test_forward_interleaved_matmul_selfatt_qk()
    test_forward_interleaved_matmul_selfatt_valatt()