Skip to content

Commit

Permalink
address CR comment
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Jul 3, 2020
1 parent 9089d04 commit 4912bfe
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 19 deletions.
17 changes: 8 additions & 9 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,15 +909,14 @@ def _mx_amp_multicast(inputs, attrs):
supported_dtypes = ['float16', 'float32']
assert all([x in supported_dtypes for x in dtypes]), \
"amp_multicast support is limited to float16 and float32 inputs only."
dtype = 'float32' if cast_narrow else dtypes[0]
for t in dtypes:
if cast_narrow and t == 'float16':
dtype = 'float16'
break
elif not cast_narrow and t == 'float32':
dtype = 'float32'
break
return [relay.cast(x, dtype) for x in inputs]
has_float16 = any(x == "float16" for x in dtypes)
has_float32 = any(x == "float32" for x in dtypes)
dtype = dtypes[0]
if cast_narrow and has_float16:
dtype = 'float16'
if not cast_narrow and has_float32:
dtype = 'float32'
return [_op.cast(x, dtype) for x in inputs]

def _mx_grid_generator(inputs, attrs):
transform_type = attrs.get_str("transform_type")
Expand Down
19 changes: 9 additions & 10 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,42 +1133,41 @@ def verify(a_np, b_np):

def test_forward_amp_cast():
def verify(from_dtype, to_dtype):
from_nd = mx.nd.ones((2,2), dtype=from_dtype)
from_np = from_nd.asnumpy()
from_np = np.random.uniform(size=(1,3,18)).astype(from_dtype)
x_var = mx.sym.var('x', dtype=from_dtype)
mx_sym = mx.sym.amp_cast(x_var, dtype=to_dtype)
shape_dict = {'x': (2,2)}
shape_dict = {'x': (1,3,18)}
dtype_dict = {'x': from_dtype}
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict)
for target, ctx in ctx_list():
for kind in ["graph", "vm", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(from_np)
assert op_res.dtype == to_dtype, op_res.dtype
tvm.testing.assert_allclose(op_res.asnumpy(), 1.)
tvm.testing.assert_allclose(op_res.asnumpy(), from_np.astype(to_dtype))

verify('float32', 'float16')
verify('float16', 'float32')

def test_forward_amp_multicast():
def verify(dtypes, cast_narrow, expected_dtype):
x_nps = [np.ones((2,2), dtype=dtype) for dtype in dtypes]
x_nps = [np.random.uniform(size=(1,3,18)).astype(dtype) for dtype in dtypes]
x_vars = [mx.sym.var(str(i), dtype=dtype) for i, dtype in enumerate(dtypes)]
mx_sym = mx.sym.amp_multicast(*x_vars, cast_narrow=cast_narrow,
num_outputs=len(dtypes))
shape_dict = {}
dtype_dict = {}
for i, dtype in enumerate(dtypes):
shape_dict[str(i)] = (2,2)
shape_dict[str(i)] = (1,3,18)
dtype_dict[str(i)] = dtype
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict)
for target, ctx in ctx_list():
for kind in ["graph", "vm", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(*x_nps)
for res in op_res:
for i, res in enumerate(op_res):
assert res.dtype == expected_dtype, res.dtype
tvm.testing.assert_allclose(res.asnumpy(), 1)
tvm.testing.assert_allclose(res.asnumpy(), x_nps[i].astype(expected_dtype))

verify(['float32', 'float16'], False, 'float32')
verify(['float32', 'float16'], True, 'float16')
Expand Down Expand Up @@ -1375,8 +1374,6 @@ def verify(data_shape, anchor_shape, stds=[1, 1, 1, 1], clip=-1, in_format="corn


if __name__ == '__main__':
test_forward_amp_multicast()
test_forward_amp_cast()
test_forward_mlp()
test_forward_vgg()
test_forward_resnet()
Expand Down Expand Up @@ -1450,3 +1447,5 @@ def verify(data_shape, anchor_shape, stds=[1, 1, 1, 1], clip=-1, in_format="corn
test_forward_interleaved_matmul_selfatt_qk()
test_forward_interleaved_matmul_selfatt_valatt()
test_forward_box_decode()
test_forward_amp_multicast()
test_forward_amp_cast()

0 comments on commit 4912bfe

Please sign in to comment.