
[Frontend][MXNet] Support a few contrib ops in mxnet (apache#5819)
* Support for BERT in MXNet 1.6 and GluonNLP 0.9

* Fix converter

* Add test cases

* Add a TODO
icemelon authored and Trevor Morris committed Jun 30, 2020
1 parent 8541c3e commit d2b9f92
Showing 3 changed files with 183 additions and 1 deletion.
99 changes: 99 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
@@ -17,6 +17,7 @@
# pylint: disable=invalid-name, import-self, len-as-condition, no-else-return, too-many-lines
"""MXNet symbol frontend."""
import json
import math
import numpy as np
import tvm
from tvm.ir import IRModule
@@ -655,6 +656,15 @@ def _mx_leaky_relu(inputs, attrs):
upper_bound = attrs.get_float("upper_bound")
alpha = (lower_bound + upper_bound) / 2.0
return _op.nn.leaky_relu(inputs[0], alpha=alpha)
if act_type == "gelu":
# 0.5 * x * (1 + erf(x / sqrt(2)))
sqrt2 = _expr.const(math.sqrt(2), dtype="float32")
erf = _op.erf(_op.divide(inputs[0], sqrt2))
one = _expr.const(1, dtype="float32")
erf_plus_one = _op.add(one, erf)
half = _expr.const(0.5, dtype="float32")
half_x = _op.multiply(inputs[0], half)
return _op.multiply(half_x, erf_plus_one)
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend MXNet.'.format(act_type))
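For reference, the erf-based lowering above matches the formula in the comment. A minimal numerical cross-check against MXNet's own GELU (not part of this diff; assumes mxnet and scipy are installed) might look like:

# Illustrative sketch: compare 0.5 * x * (1 + erf(x / sqrt(2)))
# with MXNet's LeakyReLU(act_type='gelu').
import numpy as np
import mxnet as mx
from scipy.special import erf

x = np.linspace(-3, 3, 7).astype("float32")
gelu_ref = 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))
gelu_mx = mx.nd.LeakyReLU(mx.nd.array(x), act_type="gelu").asnumpy()
np.testing.assert_allclose(gelu_mx, gelu_ref, rtol=1e-5)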

@@ -784,6 +794,42 @@ def _mx_make_loss(inputs, attrs):
return inputs[0]


def _mx_contrib_arange_like(inputs, attrs):
assert len(inputs) == 1
if attrs.get_int("repeat", 1) != 1:
raise tvm.error.OpAttributeUnimplemented(
'Attribute "repeat" is not supported in operator arange_like.')
ty = _infer_type(inputs[0]).checked_type
assert ty
shape, dtype = get_const_tuple(ty.shape), ty.dtype
axis = attrs.get_int("axis", None)
if axis is None:
n_elems = 1
for dim in shape:
if not isinstance(dim, int):
raise tvm.error.OpError("Don't support arange_like with symbolic input shape.")
n_elems *= dim
else:
axis = axis + len(shape) if axis < 0 else axis
assert 0 <= axis < len(shape)
n_elems = shape[axis]
if not isinstance(n_elems, int):
raise tvm.error.OpError("Don't support arange_like with symbolic input shape.")
shape = (n_elems,)
start = attrs.get_float("start", 0.)
step = attrs.get_float("step", 1.)
stop = start + step * n_elems
new_attrs = {}
new_attrs["start"] = _expr.const(start, dtype=dtype)
new_attrs["stop"] = _expr.const(stop, dtype=dtype)
new_attrs["step"] = _expr.const(step, dtype=dtype)
new_attrs["dtype"] = dtype
ret = _op.arange(**new_attrs)
if len(shape) > 1:
ret = _op.reshape(ret, shape)
return ret
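The converter above needs a static input shape: without an axis it enumerates every element and reshapes to the input shape, and with an axis it produces a 1-D sequence whose length is that dimension. A short usage sketch of the MXNet op being converted (illustrative only; assumes the MXNet 1.6 contrib API):

# arange_like fills a tensor shaped like `data` (or 1-D along `axis`)
# with start, start + step, start + 2*step, ...
import mxnet as mx

data = mx.nd.zeros((3, 4))
print(mx.nd.contrib.arange_like(data, start=0., step=1.))          # shape (3, 4), values 0..11
print(mx.nd.contrib.arange_like(data, start=2., step=3., axis=1))  # shape (4,), values 2, 5, 8, 11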


def _mx_repeat(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
@@ -1278,6 +1324,56 @@ def _mx_contrib_fifo_buffer(inputs, attrs):
return _op.nn.fifo_buffer(*inputs, **new_attrs)


def _mx_contrib_interleaved_matmul_selfatt_qk(inputs, attrs):
"""
tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
q_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3))
q_proj = mx.nd.reshape(q_proj, shape=(-1, 0, 0), reverse=True)
q_proj = mx.nd.contrib.div_sqrt_dim(q_proj)
k_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3))
k_proj = mx.nd.reshape(k_proj, shape=(-1, 0, 0), reverse=True)
output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True)
"""
assert len(inputs) == 1
qkv = inputs[0]
num_heads = attrs.get_int('heads')
qkv = _op.reshape(qkv, newshape=(0, 0, num_heads, 3, -1))
q_proj = _op.take(qkv, _expr.const(0, "int32"), axis=3)
q_proj = _op.transpose(q_proj, axes=[1, 2, 0, 3])
q_proj = _op.reverse_reshape(q_proj, newshape=(-1, 0, 0))
q_proj = _mx_contrib_div_sqrt_dim([q_proj], None)
k_proj = _op.take(qkv, _expr.const(1, "int32"), axis=3)
k_proj = _op.transpose(k_proj, axes=[1, 2, 0, 3])
k_proj = _op.reverse_reshape(k_proj, newshape=(-1, 0, 0))
ret = _op.nn.batch_matmul(q_proj, k_proj)
return ret
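The docstring gives the MXNet reference computation; since Relay's nn.batch_matmul already computes A·Bᵀ, the converter needs no explicit transpose of k_proj. A NumPy sketch of the same semantics (illustrative, not the frontend code):

import numpy as np

def selfatt_qk_ref(qkv, num_heads):
    # qkv: (seq_len, batch, num_heads * 3 * head_dim), interleaved per head as [q, k, v]
    seq_len, batch, _ = qkv.shape
    tmp = qkv.reshape(seq_len, batch, num_heads, 3, -1)
    q = tmp[:, :, :, 0, :].transpose(1, 2, 0, 3).reshape(batch * num_heads, seq_len, -1)
    k = tmp[:, :, :, 1, :].transpose(1, 2, 0, 3).reshape(batch * num_heads, seq_len, -1)
    q = q / np.sqrt(q.shape[-1])        # div_sqrt_dim
    return q @ k.transpose(0, 2, 1)     # (batch * num_heads, seq_len, seq_len)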


def _mx_contrib_interleaved_matmul_selfatt_valatt(inputs, attrs):
"""
tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1))
v_proj = mx.nd.transpose(tmp[:,:,:,2,:], axes=(1, 2, 0, 3))
v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True)
output = mx.nd.batch_dot(attention, v_proj)
output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True)
output = mx.nd.transpose(output, axes=(2, 0, 1, 3))
output = mx.nd.reshape(output, shape=(0, 0, -1))
"""
assert len(inputs) == 2
qkv, att = inputs
num_heads = attrs.get_int("heads")
qkv = _op.reshape(qkv, newshape=(0, 0, num_heads, 3, -1))
v_proj = _op.take(qkv, _expr.const(2, "int32"), axis=3)
v_proj = _op.transpose(v_proj, axes=(1, 2, 0, 3))
v_proj = _op.reverse_reshape(v_proj, newshape=(-1, 0, 0))
v_proj = _op.transpose(v_proj, axes=[0, 2, 1])
out = _op.nn.batch_matmul(att, v_proj)
out = _op.reverse_reshape(out, newshape=(-1, num_heads, 0, 0))
out = _op.transpose(out, axes=(2, 0, 1, 3))
out = _op.reshape(out, newshape=(0, 0, -1))
return out
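Here Relay's nn.batch_matmul again computes A·Bᵀ, so the converter pre-transposes v_proj to obtain att·v_proj. A NumPy sketch of the reference semantics from the docstring (illustrative only):

import numpy as np

def selfatt_valatt_ref(qkv, att, num_heads):
    # qkv: (seq_len, batch, num_heads * 3 * head_dim); att: (batch * num_heads, seq_len, seq_len)
    seq_len, batch, _ = qkv.shape
    tmp = qkv.reshape(seq_len, batch, num_heads, 3, -1)
    v = tmp[:, :, :, 2, :].transpose(1, 2, 0, 3).reshape(batch * num_heads, seq_len, -1)
    out = att @ v                                            # (batch * num_heads, seq_len, head_dim)
    out = out.reshape(batch, num_heads, seq_len, -1)
    out = out.transpose(2, 0, 1, 3).reshape(seq_len, batch, -1)
    return out                                               # (seq_len, batch, num_heads * head_dim)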


def _mx_cond(inputs, attrs, subgraphs):
assert len(subgraphs) == 3
cond_input_locs = json.loads(attrs.get_str("cond_input_locs"))
@@ -2110,6 +2206,7 @@ def impl(inputs, input_types):
"smooth_l1" : _mx_smooth_l1,
"make_loss" : _mx_make_loss,
"_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim,
"_contrib_arange_like": _mx_contrib_arange_like,
"one_hot" : _mx_one_hot,
"depth_to_space" : _mx_depth_to_space,
"space_to_depth" : _mx_space_to_depth,
@@ -2130,6 +2227,8 @@ def impl(inputs, input_types):
# NLP
"RNN" : _mx_rnn_layer,
"_rnn_param_concat" : _mx_rnn_param_concat,
"_contrib_interleaved_matmul_selfatt_qk" : _mx_contrib_interleaved_matmul_selfatt_qk,
"_contrib_interleaved_matmul_selfatt_valatt" : _mx_contrib_interleaved_matmul_selfatt_valatt,
# control flow
"_cond" : _mx_cond,
# Depricated:
3 changes: 2 additions & 1 deletion python/tvm/relay/frontend/nnvm_common.py
@@ -57,7 +57,8 @@ def _impl(inputs, attrs):
def _softmax_op(new_op):
"""softmax/log_softmax"""
def _impl(inputs, attrs, _dtype='float32'):
assert len(inputs) == 1
# TODO(@icemelon9): currently ignore the 2nd input to softmax for mxnet 1.6
# assert len(inputs) == 1
axis = attrs.get_int("axis", -1)
return new_op(inputs[0], axis=axis)
return _impl
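The relaxed assertion exists because in MXNet 1.6 softmax can receive an optional second input, a length array used for masking padded positions, which the frontend currently drops. A hedged sketch of the MXNet-side call that produces that extra input (assumes the MXNet 1.6 softmax API with use_length/length):

# Illustrative only: `length` masks positions beyond the valid length along
# the softmax axis; the Relay frontend ignores it for now.
import mxnet as mx

data = mx.nd.random.uniform(shape=(2, 4))
length = mx.nd.array([2, 3], dtype="int32")
out = mx.nd.softmax(data, length=length, use_length=True, axis=-1)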
82 changes: 82 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
@@ -133,6 +133,12 @@ def test_forward_prelu():
mx_sym = mx.sym.LeakyReLU(data, act_type='prelu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))

def test_forward_gelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.LeakyReLU(data, act_type='gelu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))

def test_forward_softrelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
@@ -1228,6 +1234,78 @@ def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size
verify((5, 1, 11, 11), kernel_size = 5, max_displacement = 1, stride1 = 1, stride2 = 1, pad_size = 2, is_multiply = False)


def test_forward_arange_like():
def verify(data_shape, start=None, step=None, axis=None):
attrs = {}
if start is not None:
attrs['start'] = start
if step is not None:
attrs['step'] = step
if axis is not None:
attrs['axis'] = axis
data = mx.sym.var('data')
data_np = np.random.uniform(size=data_shape).astype("float32")
ref_res = mx.nd.contrib.arange_like(mx.nd.array(data_np), **attrs)

mx_sym = mx.sym.contrib.arange_like(data, **attrs)
mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape})
for target, ctx in ctx_list():
for kind in ["graph"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()()
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

verify(data_shape=(3,), start=0., step=1.)
verify(data_shape=(3, 4, 5), start=0., step=1.)
verify(data_shape=(3, 4, 5), start=0., step=1., axis=-1)
verify(data_shape=(3, 4, 5), start=2., step=3., axis=1)


def test_forward_interleaved_matmul_selfatt_qk():
def verify(batch, seq_length, num_heads, head_dim):
data_shape = (seq_length, batch, num_heads * head_dim * 3)
data = mx.sym.var('data')
data_np = np.random.uniform(size=data_shape).astype('float32')
ref_res = mx.nd.contrib.interleaved_matmul_selfatt_qk(
mx.nd.array(data_np), heads=num_heads)

mx_sym = mx.sym.contrib.interleaved_matmul_selfatt_qk(data, heads=num_heads)
mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape})
for target, ctx in ctx_list():
for kind in ["graph"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(data_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)

verify(1, 10, 3, 16)
verify(3, 10, 6, 8)


def test_forward_interleaved_matmul_selfatt_valatt():
def verify(batch, seq_length, num_heads, head_dim):
data_shape = (seq_length, batch, num_heads * head_dim * 3)
weight_shape = (batch * num_heads, seq_length, seq_length)
data = mx.sym.var('data')
weight = mx.sym.var('weight')
data_np = np.random.uniform(size=data_shape).astype('float32')
weight_np = np.random.uniform(size=weight_shape).astype('float32')
ref_res = mx.nd.contrib.interleaved_matmul_selfatt_valatt(
mx.nd.array(data_np), mx.nd.array(weight_np), heads=num_heads)

mx_sym = mx.sym.contrib.interleaved_matmul_selfatt_valatt(
data, weight, heads=num_heads)
mod, _ = relay.frontend.from_mxnet(
mx_sym, {"data": data_shape, "weight": weight_shape})
for target, ctx in ctx_list():
for kind in ["graph"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(data=data_np, weight=weight_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)

verify(1, 10, 4, 16)
verify(3, 10, 6, 8)


if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
@@ -1236,6 +1314,7 @@ def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size
test_forward_elu()
test_forward_rrelu()
test_forward_prelu()
test_forward_gelu()
test_forward_softrelu()
test_forward_softmin()
test_forward_fc_flatten()
@@ -1297,3 +1376,6 @@ def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size
test_forward_correlation()
test_forward_grid_generator()
test_forward_bilinear_sampler()
test_forward_arange_like()
test_forward_interleaved_matmul_selfatt_qk()
test_forward_interleaved_matmul_selfatt_valatt()
