
MXNet frontend support for AMP cast op #5976

Merged (14 commits) on Jul 8, 2020
6 changes: 3 additions & 3 deletions docs/dev/relay_add_pass.rst
@@ -181,7 +181,7 @@ Example: Constant Folding
-------------------------

In order to better understand the process of writing a pass, we will look at
-the constant folding pass (found in `src/relay/pass/fold_constant.cc`_)
+the constant folding pass (found in `src/relay/transforms/fold_constant.cc`_)
as a guide, because it is a relatively simple pass that incorporates
both types of traversals.

@@ -329,7 +329,7 @@ Now, we construct a more convenient interface ``FoldConstant`` for our constant
folder. ``FoldConstant`` is a standalone function outside of the ``ConstantFolder``
class that takes an expression and internally creates and uses a
``ConstantFolder`` instance (the full definition can be found in
-`src/relay/pass/fold_constant.cc`_).
+`src/relay/transforms/fold_constant.cc`_).
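
(A quick illustration of the effect, as a minimal sketch using the registered
``relay.transform.FoldConstant`` pass; the toy expression below is ours, not
part of this diff:)

    import tvm
    from tvm import relay

    # Build a module whose body is a compile-time constant expression.
    expr = relay.add(relay.const(1.0), relay.const(2.0))
    mod = tvm.IRModule.from_expr(relay.Function([], expr))

    # Running the pass replaces the addition with the folded constant 3.0.
    mod = relay.transform.FoldConstant()(mod)
    print(mod)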


Registering a Pass with the Pass Manager
@@ -403,4 +403,4 @@ in `src/relay/pass/`_.

.. _src/relay/pass/: https://github.com/apache/incubator-tvm/tree/master/src/relay/pass

-.. _src/relay/pass/fold_constant.cc: https://github.com/apache/incubator-tvm/blob/master/src/relay/pass/fold_constant.cc
+.. _src/relay/transforms/fold_constant.cc: https://github.com/apache/incubator-tvm/blob/master/src/relay/transforms/fold_constant.cc
4 changes: 2 additions & 2 deletions docs/dev/virtual_machine.rst
@@ -38,8 +38,8 @@ them on the runtime. Graph runtime provides a fast execution experience but only
subset of Relay programs.

An alternative but not-standard approach is Relay's ahead-of-time compiler,
-which compiles a Relay program into a shared library containing an ahead-
-of-time implementation. The ahead-of-time compiler provides compelling performance
+which compiles a Relay program into a shared library containing an ahead-of-time
+implementation. The ahead-of-time compiler provides compelling performance
but is difficult to extend and instrument, which can only be done by modifying the
code generation and optimization mechanisms.

19 changes: 18 additions & 1 deletion python/tvm/relay/frontend/mxnet.py
@@ -903,6 +903,21 @@ def _mx_resize(inputs, attrs):
    return _op.image.resize(inputs[0], size,
                            coordinate_transformation_mode="align_corners")

+def _mx_amp_multicast(inputs, attrs):
+    cast_narrow = attrs.get_bool("cast_narrow", False)
+    dtypes = [_infer_type(x).checked_type.dtype for x in inputs]
+    supported_dtypes = ['float16', 'float32']
+    assert all([x in supported_dtypes for x in dtypes]), \
+        "amp_multicast support is limited to float16 and float32 inputs only."
+    has_float16 = any(x == "float16" for x in dtypes)
+    has_float32 = any(x == "float32" for x in dtypes)
+    dtype = dtypes[0]
+    if cast_narrow and has_float16:
+        dtype = 'float16'
+    if not cast_narrow and has_float32:
+        dtype = 'float32'
+    return [_op.cast(x, dtype) for x in inputs]
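
In plain terms: with ``cast_narrow=True`` every input is narrowed to float16 as
soon as any input is float16; with ``cast_narrow=False`` every input is widened
to float32 as soon as any input is float32; if neither rule fires, the first
input's dtype is kept. A standalone sketch of that rule (illustrative only;
``resolve_amp_dtype`` is our own name, not part of this diff):

    def resolve_amp_dtype(dtypes, cast_narrow):
        # Mirrors the dtype-selection logic of _mx_amp_multicast above.
        dtype = dtypes[0]
        if cast_narrow and "float16" in dtypes:
            dtype = "float16"
        if not cast_narrow and "float32" in dtypes:
            dtype = "float32"
        return dtype

    assert resolve_amp_dtype(["float32", "float16"], False) == "float32"
    assert resolve_amp_dtype(["float32", "float16"], True) == "float16"
    assert resolve_amp_dtype(["float16", "float16"], False) == "float16"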

def _mx_grid_generator(inputs, attrs):
transform_type = attrs.get_str("transform_type")
if transform_type == 'affine':
@@ -1481,7 +1496,7 @@ def _qnn_contrib_concat(inputs, attrs):
    # Get all dtypes. Find input and output scales, call concatenate.
    dtypes = [_infer_type(x).checked_type.dtype for x in input_exprs]
    assert all([x == 'uint8' for x in dtypes]), \
-        "Current suppor is limited to uint8 inputs only."
+        "Current support is limited to uint8 inputs only."
    new_min = min(mins)
    new_max = max(maxs)
    assert new_min == 0
@@ -2184,6 +2199,8 @@ def impl(inputs, input_types):
"Reshape" : _reshape,
"reshape" : _reshape,
"Cast" : _cast,
"amp_cast" : _cast,
"amp_multicast" : _mx_amp_multicast,
"clip" : _clip,
"transpose" : _transpose,
"UpSampling" : _upsampling,
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/fold_constant.cc
@@ -194,7 +194,7 @@ class ConstantFolder : public ExprMutator {
      return Expr();
    }
  }
-  // Constant evaluate a expression.
+  // Constant evaluate an expression.
  Expr ConstEvaluate(Expr expr) {
    std::vector<transform::Pass> passes = {transform::FuseOps(0), transform::ToANormalForm(),
                                           transform::InferType()};
2 changes: 1 addition & 1 deletion src/relay/transforms/partial_eval.cc
@@ -901,7 +901,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
    }
  }

-  // Constant evaluate a expression.
+  // Constant evaluate an expression.
  PStatic ConstEvaluate(const Expr& expr, LetList* ll) {
    std::vector<transform::Pass> passes = {transform::FuseOps(0), transform::InferType()};
    auto mod = IRModule::FromExpr(expr);
47 changes: 47 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
@@ -1131,6 +1131,51 @@ def verify(a_np, b_np):
    verify(np.asarray([1.0], 'float32'), np.asarray([2.0],'float32'))
    verify(np.asarray([4.0], 'float32'), np.asarray([3.0],'float32'))

+def test_forward_amp_cast():
+    def verify(from_dtype, to_dtype):
+        from_np = np.random.uniform(size=(1,3,18)).astype(from_dtype)
+        x_var = mx.sym.var('x', dtype=from_dtype)
+        mx_sym = mx.sym.amp_cast(x_var, dtype=to_dtype)
+        shape_dict = {'x': (1,3,18)}
+        dtype_dict = {'x': from_dtype}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "vm", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(from_np)
+                assert op_res.dtype == to_dtype, op_res.dtype
+                tvm.testing.assert_allclose(op_res.asnumpy(), from_np.astype(to_dtype))
+
+    verify('float32', 'float16')
+    verify('float16', 'float32')

+def test_forward_amp_multicast():
+    def verify(dtypes, cast_narrow, expected_dtype):
+        x_nps = [np.random.uniform(size=(1,3,18)).astype(dtype) for dtype in dtypes]
+        x_vars = [mx.sym.var(str(i), dtype=dtype) for i, dtype in enumerate(dtypes)]
+        mx_sym = mx.sym.amp_multicast(*x_vars, cast_narrow=cast_narrow,
+                                      num_outputs=len(dtypes))
+        shape_dict = {}
+        dtype_dict = {}
+        for i, dtype in enumerate(dtypes):
+            shape_dict[str(i)] = (1,3,18)
+            dtype_dict[str(i)] = dtype
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "vm", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(*x_nps)
+                for i, res in enumerate(op_res):
+                    assert res.dtype == expected_dtype, res.dtype
+                    tvm.testing.assert_allclose(res.asnumpy(), x_nps[i].astype(expected_dtype))
+
+    verify(['float32', 'float16'], False, 'float32')
+    verify(['float32', 'float16'], True, 'float16')
+    verify(['float32', 'float32'], False, 'float32')
+    verify(['float32', 'float32'], True, 'float32')
+    verify(['float16', 'float16'], False, 'float16')
+    verify(['float16', 'float16'], True, 'float16')


def test_forward_unravel_index():
    def verify(x, shape, dtype):
@@ -1402,3 +1447,5 @@ def verify(data_shape, anchor_shape, stds=[1, 1, 1, 1], clip=-1, in_format="corn
    test_forward_interleaved_matmul_selfatt_qk()
    test_forward_interleaved_matmul_selfatt_valatt()
    test_forward_box_decode()
+    test_forward_amp_multicast()
+    test_forward_amp_cast()