From d9f009a560fbec1f1f2394fdbbafbe7d43a92768 Mon Sep 17 00:00:00 2001
From: Haibin Lin
Date: Tue, 7 Jul 2020 22:37:10 -0700
Subject: [PATCH] [Frontend][MXNet] MXNet frontend support for AMP cast op
 (#5976)

* amp_cast

* fix test

* more tests

* test more ctxs

* fix doc

* fix typo

* address CR comment

* fix lint

* revert doc change

* Revert "revert doc change"

This reverts commit a410dd5569730ac81af67ddb333c3afbe97eddd7.

* fix doc

* Update relay_pass_infra.rst

Co-authored-by: Ubuntu
---
 docs/dev/relay_add_pass.rst                 |  6 +--
 docs/dev/virtual_machine.rst                |  4 +-
 python/tvm/relay/frontend/mxnet.py          | 19 ++++++++-
 src/relay/transforms/fold_constant.cc       |  2 +-
 src/relay/transforms/partial_eval.cc        |  2 +-
 tests/python/frontend/mxnet/test_forward.py | 47 +++++++++++++++++++++
 6 files changed, 72 insertions(+), 8 deletions(-)

diff --git a/docs/dev/relay_add_pass.rst b/docs/dev/relay_add_pass.rst
index a82ae4ff717a..fc265592d08f 100644
--- a/docs/dev/relay_add_pass.rst
+++ b/docs/dev/relay_add_pass.rst
@@ -181,7 +181,7 @@ Example: Constant Folding
 -------------------------
 
 In order to better understand the process of writing a pass, we will look at
-the constant folding pass (found in `src/relay/pass/fold_constant.cc`_)
+the constant folding pass (found in `src/relay/transforms/fold_constant.cc`_)
 as a guide, because it is a relatively simple pass that incorporates both types
 of traversals.
 
@@ -329,7 +329,7 @@ Now, we construct a more convenient interface ``FoldConstant`` for our constant
 folder. ``FoldConstant`` is a standalone function outside of the
 ``ConstantFolder`` class that takes an expression and internally creates and
 uses a ``ConstantFolder`` instance (the full definition can be found in
-`src/relay/pass/fold_constant.cc`_).
+`src/relay/transforms/fold_constant.cc`_).
 
 
 Registering a Pass with the Pass Manager
@@ -403,4 +403,4 @@ in `src/relay/pass/`_.
 
 .. _src/relay/pass/: https://github.com/apache/incubator-tvm/tree/master/src/relay/pass
 
-.. _src/relay/pass/fold_constant.cc: https://github.com/apache/incubator-tvm/blob/master/src/relay/pass/fold_constant.cc
+.. _src/relay/transforms/fold_constant.cc: https://github.com/apache/incubator-tvm/blob/master/src/relay/transforms/fold_constant.cc
diff --git a/docs/dev/virtual_machine.rst b/docs/dev/virtual_machine.rst
index 5bb5adee5459..58780031811d 100644
--- a/docs/dev/virtual_machine.rst
+++ b/docs/dev/virtual_machine.rst
@@ -38,8 +38,8 @@ them on the runtime. Graph runtime provides a fast execution experience but
 only subset of Relay programs.
 
 An alternative but not-standard approach is Relay's ahead-of-time compiler,
-which compiles a Relay program into a shared library containing an ahead-
-of-time implementation. The ahead-of-time compiler provides compelling performance
+which compiles a Relay program into a shared library containing an ahead-of-time
+implementation. The ahead-of-time compiler provides compelling performance
 but is difficult to extend and instrument, which can only be done by modifying the
 code generation and optimization mechanisms.
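The relay_add_pass.rst passage above presents ``FoldConstant`` as the convenient standalone interface to the constant folder. As a minimal sketch (not part of the patch), assuming a TVM build of this era where the pass is exposed as ``relay.transform.FoldConstant``, it can be driven from Python like this; the expression is illustrative only:

    import tvm
    from tvm import relay

    # A function whose right operand is a compile-time constant subexpression.
    x = relay.var('x', shape=(2, 2), dtype='float32')
    c = relay.add(relay.const(1.0), relay.const(2.0))  # foldable
    func = relay.Function([x], relay.add(x, c))

    mod = tvm.IRModule.from_expr(func)
    # The pass rewrites add(1.0, 2.0) into the single constant 3.0.
    mod = relay.transform.FoldConstant()(mod)
    print(mod)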
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 135756ba851e..97b9d7a44997 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -903,6 +903,21 @@ def _mx_resize(inputs, attrs):
     return _op.image.resize(inputs[0], size,
                             coordinate_transformation_mode="align_corners")
 
+def _mx_amp_multicast(inputs, attrs):
+    cast_narrow = attrs.get_bool("cast_narrow", False)
+    dtypes = [_infer_type(x).checked_type.dtype for x in inputs]
+    supported_dtypes = ['float16', 'float32']
+    assert all([x in supported_dtypes for x in dtypes]), \
+        "amp_multicast support is limited to float16 and float32 inputs only."
+    has_float16 = any(x == "float16" for x in dtypes)
+    has_float32 = any(x == "float32" for x in dtypes)
+    dtype = dtypes[0]
+    if cast_narrow and has_float16:
+        dtype = 'float16'
+    if not cast_narrow and has_float32:
+        dtype = 'float32'
+    return [_op.cast(x, dtype) for x in inputs]
+
 def _mx_grid_generator(inputs, attrs):
     transform_type = attrs.get_str("transform_type")
     if transform_type == 'affine':
@@ -1481,7 +1496,7 @@ def _qnn_contrib_concat(inputs, attrs):
     # Get all dtypes. Find input and output scales, call concatenate.
     dtypes = [_infer_type(x).checked_type.dtype for x in input_exprs]
     assert all([x == 'uint8' for x in dtypes]), \
-        "Current suppor is limited to uint8 inputs only."
+        "Current support is limited to uint8 inputs only."
     new_min = min(mins)
     new_max = max(maxs)
     assert new_min == 0
@@ -2184,6 +2199,8 @@ def impl(inputs, input_types):
     "Reshape"       : _reshape,
     "reshape"       : _reshape,
     "Cast"          : _cast,
+    "amp_cast"      : _cast,
+    "amp_multicast" : _mx_amp_multicast,
     "clip"          : _clip,
     "transpose"     : _transpose,
     "UpSampling"    : _upsampling,
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 50de8711a4c1..d66d6bccdea1 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -194,7 +194,7 @@ class ConstantFolder : public ExprMutator {
       return Expr();
     }
   }
-  // Constant evaluate a expression.
+  // Constant evaluate an expression.
   Expr ConstEvaluate(Expr expr) {
     std::vector<transform::Pass> passes = {transform::FuseOps(0), transform::ToANormalForm(),
                                            transform::InferType()};
diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc
index 371142ad76a2..63bd04d526de 100644
--- a/src/relay/transforms/partial_eval.cc
+++ b/src/relay/transforms/partial_eval.cc
@@ -901,7 +901,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)
     }
   }
 
-  // Constant evaluate a expression.
+  // Constant evaluate an expression.
   PStatic ConstEvaluate(const Expr& expr, LetList* ll) {
     std::vector<transform::Pass> passes = {transform::FuseOps(0), transform::InferType()};
     auto mod = IRModule::FromExpr(expr);
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 4d8b1e98f950..c8bbf88c96ef 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -1131,6 +1131,51 @@ def verify(a_np, b_np):
     verify(np.asarray([1.0], 'float32'), np.asarray([2.0],'float32'))
     verify(np.asarray([4.0], 'float32'), np.asarray([3.0],'float32'))
 
+def test_forward_amp_cast():
+    def verify(from_dtype, to_dtype):
+        from_np = np.random.uniform(size=(1,3,18)).astype(from_dtype)
+        x_var = mx.sym.var('x', dtype=from_dtype)
+        mx_sym = mx.sym.amp_cast(x_var, dtype=to_dtype)
+        shape_dict = {'x': (1,3,18)}
+        dtype_dict = {'x': from_dtype}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "vm", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(from_np)
+                assert op_res.dtype == to_dtype, op_res.dtype
+                tvm.testing.assert_allclose(op_res.asnumpy(), from_np.astype(to_dtype))
+
+    verify('float32', 'float16')
+    verify('float16', 'float32')
+
+def test_forward_amp_multicast():
+    def verify(dtypes, cast_narrow, expected_dtype):
+        x_nps = [np.random.uniform(size=(1,3,18)).astype(dtype) for dtype in dtypes]
+        x_vars = [mx.sym.var(str(i), dtype=dtype) for i, dtype in enumerate(dtypes)]
+        mx_sym = mx.sym.amp_multicast(*x_vars, cast_narrow=cast_narrow,
+                                      num_outputs=len(dtypes))
+        shape_dict = {}
+        dtype_dict = {}
+        for i, dtype in enumerate(dtypes):
+            shape_dict[str(i)] = (1,3,18)
+            dtype_dict[str(i)] = dtype
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "vm", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(*x_nps)
+                for i, res in enumerate(op_res):
+                    assert res.dtype == expected_dtype, res.dtype
+                    tvm.testing.assert_allclose(res.asnumpy(), x_nps[i].astype(expected_dtype))
+
+    verify(['float32', 'float16'], False, 'float32')
+    verify(['float32', 'float16'], True, 'float16')
+    verify(['float32', 'float32'], False, 'float32')
+    verify(['float32', 'float32'], True, 'float32')
+    verify(['float16', 'float16'], False, 'float16')
+    verify(['float16', 'float16'], True, 'float16')
+
 
 def test_forward_unravel_index():
     def verify(x, shape, dtype):
@@ -1402,3 +1447,5 @@ def verify(data_shape, anchor_shape, stds=[1, 1, 1, 1], clip=-1, in_format="corn
     test_forward_interleaved_matmul_selfatt_qk()
     test_forward_interleaved_matmul_selfatt_valatt()
     test_forward_box_decode()
+    test_forward_amp_multicast()
+    test_forward_amp_cast()
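The tests above double as usage documentation for the new converters. As a minimal standalone sketch (not part of the patch), assuming an MXNet build that provides the AMP ops and an LLVM-enabled TVM of this era, the amp_cast import path looks like this; the shape and values are illustrative only:

    import mxnet as mx
    import numpy as np
    import tvm
    from tvm import relay

    # An MXNet symbol that downcasts its input, as AMP would insert it.
    x = mx.sym.var('x', dtype='float32')
    sym = mx.sym.amp_cast(x, dtype='float16')

    # amp_cast is translated by the existing _cast converter registered above.
    mod, _ = relay.frontend.from_mxnet(sym, {'x': (1, 3, 18)}, {'x': 'float32'})

    intrp = relay.create_executor('graph', mod=mod, ctx=tvm.cpu(), target='llvm')
    out = intrp.evaluate()(np.ones((1, 3, 18), 'float32'))
    assert out.dtype == 'float16'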