Skip to content

Commit

Permalink
Add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
icemelon committed Jun 16, 2020
1 parent cf0df00 commit 87ddf70
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 7 deletions.
1 change: 0 additions & 1 deletion python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def _mx_fully_connected(inputs, attrs):
if len(data_shape) > 2:
new_shape = data_shape[:-1]
new_shape.append(units)
new_shape = [int(v) for v in new_shape]
res = _op.reshape(res, new_shape)
return res

Expand Down
70 changes: 64 additions & 6 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ def test_forward_prelu():
mx_sym = mx.sym.LeakyReLU(data, act_type='prelu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))

def test_forward_gelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.LeakyReLU(data, act_type='gelu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))

def test_forward_softrelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
Expand Down Expand Up @@ -1204,18 +1210,67 @@ def verify(data_shape, start=None, step=None, axis=None):
attrs['step'] = step
if axis is not None:
attrs['axis'] = axis
print(attrs)
data = mx.sym.var('data')
data_np = np.random.uniform(size=data_shape).astype("float32")
#ref_res = mx.nd.contrib.arange_like(mx.nd.array(data_np), **attrs)
#print(ref_res)
ref_res = mx.nd.contrib.arange_like(mx.nd.array(data_np), **attrs)

mx_sym = mx.sym.contrib.arange_like(data, **attrs)
mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape})
print(mod)
#verify_mxnet_frontend_impl(mx_sym, data_shape=data_shape, out_shape=data_shape)
for target, ctx in ctx_list():
for kind in ["graph"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()()
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

verify(data_shape=(3,), start=0., step=1.)
verify(data_shape=(3, 4, 5), start=0., step=1.)
verify(data_shape=(3, 4, 5), start=0., step=1., axis=-1)
verify(data_shape=(3, 4, 5), start=2., step=3., axis=1)


def test_forward_interleaved_matmul_selfatt_qk():
def verify(batch, seq_length, num_heads, head_dim):
data_shape = (seq_length, batch, num_heads * head_dim * 3)
data = mx.sym.var('data')
data_np = np.random.uniform(size=data_shape).astype('float32')
ref_res = mx.nd.contrib.interleaved_matmul_selfatt_qk(
mx.nd.array(data_np), heads=num_heads)

mx_sym = mx.sym.contrib.interleaved_matmul_selfatt_qk(data, heads=num_heads)
mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape})
for target, ctx in ctx_list():
for kind in ["graph"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(data_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)

verify(1, 10, 3, 16)
verify(3, 10, 6, 8)


def test_forward_interleaved_matmul_selfatt_valatt():
def verify(batch, seq_length, num_heads, head_dim):
data_shape = (seq_length, batch, num_heads * head_dim * 3)
weight_shape = (batch * num_heads, seq_length, seq_length)
data = mx.sym.var('data')
weight = mx.sym.var('weight')
data_np = np.random.uniform(size=data_shape).astype('float32')
weight_np = np.random.uniform(size=weight_shape).astype('float32')
ref_res = mx.nd.contrib.interleaved_matmul_selfatt_valatt(
mx.nd.array(data_np), mx.nd.array(weight_np), heads=num_heads)

mx_sym = mx.sym.contrib.interleaved_matmul_selfatt_valatt(
data, weight, heads=num_heads)
mod, _ = relay.frontend.from_mxnet(
mx_sym, {"data": data_shape, "weight": weight_shape})
for target, ctx in ctx_list():
for kind in ["graph"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(data=data_np, weight=weight_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)

verify(data_shape=(3, 4, 5), start=0., step=1., axis=-1)
verify(1, 10, 4, 16)
verify(3, 10, 6, 8)


if __name__ == '__main__':
Expand All @@ -1226,6 +1281,7 @@ def verify(data_shape, start=None, step=None, axis=None):
test_forward_elu()
test_forward_rrelu()
test_forward_prelu()
test_forward_gelu()
test_forward_softrelu()
test_forward_softmin()
test_forward_fc_flatten()
Expand Down Expand Up @@ -1287,3 +1343,5 @@ def verify(data_shape, start=None, step=None, axis=None):
test_forward_grid_generator()
test_forward_bilinear_sampler()
test_forward_arange_like()
test_forward_interleaved_matmul_selfatt_qk()
test_forward_interleaved_matmul_selfatt_valatt()

0 comments on commit 87ddf70

Please sign in to comment.