[Frontend][MXNet] MXNet frontend support for AMP cast op (#5976)
* amp_cast

* fix test

* more tests

* test more ctxs

* fix doc

* fix typo

* address CR comment

* fix lint

* revert doc change

* Revert "revert doc change"

This reverts commit a410dd5.

* fix doc

* Update relay_pass_infra.rst

Co-authored-by: Ubuntu <[email protected]>
eric-haibin-lin and Ubuntu authored Jul 8, 2020
1 parent a119a6d commit d9f009a
Showing 6 changed files with 72 additions and 8 deletions.
6 changes: 3 additions & 3 deletions docs/dev/relay_add_pass.rst
@@ -181,7 +181,7 @@ Example: Constant Folding
-------------------------

In order to better understand the process of writing a pass, we will look at
-the constant folding pass (found in `src/relay/pass/fold_constant.cc`_)
+the constant folding pass (found in `src/relay/transforms/fold_constant.cc`_)
as a guide, because it is a relatively simple pass that incorporates
both types of traversals.

@@ -329,7 +329,7 @@ Now, we construct a more convenient interface ``FoldConstant`` for our constant
folder. ``FoldConstant`` is a standalone function outside of the ``ConstantFolder``
class that takes an expression and internally creates and uses a
``ConstantFolder`` instance (the full definition can be found in
-`src/relay/pass/fold_constant.cc`_).
+`src/relay/transforms/fold_constant.cc`_).


Registering a Pass with the Pass Manager
@@ -403,4 +403,4 @@ in `src/relay/pass/`_.

.. _src/relay/pass/: https://github.com/apache/incubator-tvm/tree/master/src/relay/pass

-.. _src/relay/pass/fold_constant.cc: https://github.com/apache/incubator-tvm/blob/master/src/relay/pass/fold_constant.cc
+.. _src/relay/transforms/fold_constant.cc: https://github.com/apache/incubator-tvm/blob/master/src/relay/transforms/fold_constant.cc
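
For context (not part of this diff), a minimal sketch of how the FoldConstant pass described in this document is typically applied through the pass infrastructure; the trivial constant expression is only an illustration:

import tvm
from tvm import relay

# Build a tiny module whose body is a product of two constants.
expr = relay.const(2.0) * relay.const(3.0)
mod = tvm.IRModule.from_expr(relay.Function([], expr))

# Apply the registered constant-folding pass; the multiplication is
# replaced by a single constant in the resulting module.
mod = relay.transform.FoldConstant()(mod)
print(mod)
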
4 changes: 2 additions & 2 deletions docs/dev/virtual_machine.rst
@@ -38,8 +38,8 @@ them on the runtime. Graph runtime provides a fast execution experience but only
subset of Relay programs.

An alternative but not-standard approach is Relay's ahead-of-time compiler,
-which compiles a Relay program into a shared library containing an ahead-
-of-time implementation. The ahead-of-time compiler provides compelling performance
+which compiles a Relay program into a shared library containing an ahead-of-time
+implementation. The ahead-of-time compiler provides compelling performance
but is difficult to extend and instrument, which can only be done by modifying the
code generation and optimization mechanisms.

19 changes: 18 additions & 1 deletion python/tvm/relay/frontend/mxnet.py
@@ -903,6 +903,21 @@ def _mx_resize(inputs, attrs):
    return _op.image.resize(inputs[0], size,
                            coordinate_transformation_mode="align_corners")

def _mx_amp_multicast(inputs, attrs):
    # Cast all inputs to a common floating-point dtype: with cast_narrow=True
    # the result narrows to float16 if any input is float16; otherwise it
    # widens to float32 if any input is float32.
    cast_narrow = attrs.get_bool("cast_narrow", False)
    dtypes = [_infer_type(x).checked_type.dtype for x in inputs]
    supported_dtypes = ['float16', 'float32']
    assert all([x in supported_dtypes for x in dtypes]), \
        "amp_multicast support is limited to float16 and float32 inputs only."
    has_float16 = any(x == "float16" for x in dtypes)
    has_float32 = any(x == "float32" for x in dtypes)
    dtype = dtypes[0]
    if cast_narrow and has_float16:
        dtype = 'float16'
    if not cast_narrow and has_float32:
        dtype = 'float32'
    return [_op.cast(x, dtype) for x in inputs]

def _mx_grid_generator(inputs, attrs):
transform_type = attrs.get_str("transform_type")
if transform_type == 'affine':
@@ -1481,7 +1496,7 @@ def _qnn_contrib_concat(inputs, attrs):
    # Get all dtypes. Find input and output scales, call concatenate.
    dtypes = [_infer_type(x).checked_type.dtype for x in input_exprs]
    assert all([x == 'uint8' for x in dtypes]), \
-        "Current suppor is limited to uint8 inputs only."
+        "Current support is limited to uint8 inputs only."
    new_min = min(mins)
    new_max = max(maxs)
    assert new_min == 0
@@ -2184,6 +2199,8 @@ def impl(inputs, input_types):
"Reshape" : _reshape,
"reshape" : _reshape,
"Cast" : _cast,
"amp_cast" : _cast,
"amp_multicast" : _mx_amp_multicast,
"clip" : _clip,
"transpose" : _transpose,
"UpSampling" : _upsampling,
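
For illustration (not part of the patch), a minimal sketch of what the new mapping enables: an MXNet symbol containing amp_cast can now be imported, with the op lowered to an ordinary Relay cast. The variable name and shape below are arbitrary placeholders:

import mxnet as mx
from tvm import relay

# Hypothetical toy graph: a single amp_cast from float32 to float16.
x = mx.sym.var('x', dtype='float32')
y = mx.sym.amp_cast(x, dtype='float16')

# Import through the MXNet frontend; amp_cast is handled by the _cast converter.
mod, params = relay.frontend.from_mxnet(y, shape={'x': (1, 3)}, dtype={'x': 'float32'})
print(mod)  # the body contains cast(%x, dtype="float16")
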
2 changes: 1 addition & 1 deletion src/relay/transforms/fold_constant.cc
@@ -194,7 +194,7 @@ class ConstantFolder : public ExprMutator {
return Expr();
}
}
-// Constant evaluate a expression.
+// Constant evaluate an expression.
Expr ConstEvaluate(Expr expr) {
std::vector<transform::Pass> passes = {transform::FuseOps(0), transform::ToANormalForm(),
transform::InferType()};
2 changes: 1 addition & 1 deletion src/relay/transforms/partial_eval.cc
@@ -901,7 +901,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}
}

-// Constant evaluate a expression.
+// Constant evaluate an expression.
PStatic ConstEvaluate(const Expr& expr, LetList* ll) {
std::vector<transform::Pass> passes = {transform::FuseOps(0), transform::InferType()};
auto mod = IRModule::FromExpr(expr);
47 changes: 47 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
@@ -1131,6 +1131,51 @@ def verify(a_np, b_np):
    verify(np.asarray([1.0], 'float32'), np.asarray([2.0],'float32'))
    verify(np.asarray([4.0], 'float32'), np.asarray([3.0],'float32'))

def test_forward_amp_cast():
    def verify(from_dtype, to_dtype):
        from_np = np.random.uniform(size=(1,3,18)).astype(from_dtype)
        x_var = mx.sym.var('x', dtype=from_dtype)
        mx_sym = mx.sym.amp_cast(x_var, dtype=to_dtype)
        shape_dict = {'x': (1,3,18)}
        dtype_dict = {'x': from_dtype}
        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict)
        for target, ctx in ctx_list():
            for kind in ["graph", "vm", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(from_np)
                assert op_res.dtype == to_dtype, op_res.dtype
                tvm.testing.assert_allclose(op_res.asnumpy(), from_np.astype(to_dtype))

    verify('float32', 'float16')
    verify('float16', 'float32')

def test_forward_amp_multicast():
    def verify(dtypes, cast_narrow, expected_dtype):
        x_nps = [np.random.uniform(size=(1,3,18)).astype(dtype) for dtype in dtypes]
        x_vars = [mx.sym.var(str(i), dtype=dtype) for i, dtype in enumerate(dtypes)]
        mx_sym = mx.sym.amp_multicast(*x_vars, cast_narrow=cast_narrow,
                                      num_outputs=len(dtypes))
        shape_dict = {}
        dtype_dict = {}
        for i, dtype in enumerate(dtypes):
            shape_dict[str(i)] = (1,3,18)
            dtype_dict[str(i)] = dtype
        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict)
        for target, ctx in ctx_list():
            for kind in ["graph", "vm", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(*x_nps)
                for i, res in enumerate(op_res):
                    assert res.dtype == expected_dtype, res.dtype
                    tvm.testing.assert_allclose(res.asnumpy(), x_nps[i].astype(expected_dtype))

    verify(['float32', 'float16'], False, 'float32')
    verify(['float32', 'float16'], True, 'float16')
    verify(['float32', 'float32'], False, 'float32')
    verify(['float32', 'float32'], True, 'float32')
    verify(['float16', 'float16'], False, 'float16')
    verify(['float16', 'float16'], True, 'float16')


def test_forward_unravel_index():
    def verify(x, shape, dtype):
@@ -1402,3 +1447,5 @@ def verify(data_shape, anchor_shape, stds=[1, 1, 1, 1], clip=-1, in_format="corn
    test_forward_interleaved_matmul_selfatt_qk()
    test_forward_interleaved_matmul_selfatt_valatt()
    test_forward_box_decode()
    test_forward_amp_multicast()
    test_forward_amp_cast()
