def _mx_cond(inputs, attrs, subgraphs):
    """Convert the MXNet ``_cond`` control-flow op into a Relay if-expression.

    Registered in ``_convert_map`` under ``"_cond"`` and invoked by
    ``_from_mxnet_impl`` with the node's ``subgraphs`` as the extra argument.

    Parameters
    ----------
    inputs : list of relay.Expr
        Flattened inputs of the cond node. The three ``*_input_locs``
        attributes select which of them feed each subgraph.
    attrs : StrAttrsDict
        Must carry ``cond_input_locs`` / ``then_input_locs`` /
        ``else_input_locs`` (JSON-encoded index lists) and ``num_outputs``.
    subgraphs : list
        Exactly three MXNet subgraph JSON dicts: condition, then-branch,
        else-branch.

    Returns
    -------
    relay.Expr or relay.TupleWrapper
        A call of the generated if-function over ``inputs``; wrapped in a
        ``TupleWrapper`` when the branches produce more than one output.
    """
    assert len(subgraphs) == 3
    cond_input_locs = json.loads(attrs.get_str("cond_input_locs"))
    then_input_locs = json.loads(attrs.get_str("then_input_locs"))
    else_input_locs = json.loads(attrs.get_str("else_input_locs"))
    num_outputs = attrs.get_int("num_outputs")

    # Re-bind every outer input to a fresh typed var so the branch functions
    # are built against fully-typed arguments.
    input_args = []
    for i, arg in enumerate(inputs):
        var = _expr.var("arg%s" % i, _infer_type(arg).checked_type)
        input_args.append(var)
    cond_args = [input_args[i] for i in cond_input_locs]
    then_args = [input_args[i] for i in then_input_locs]
    else_args = [input_args[i] for i in else_input_locs]

    # Build the condition subgraph and force it down to a boolean scalar.
    cond_arg_shapes = [arg.type_annotation.shape for arg in cond_args]
    cond_arg_dtype_info = [arg.type_annotation.dtype for arg in cond_args]
    cond_func = _from_mxnet_impl(subgraphs[0], cond_arg_shapes, cond_arg_dtype_info)
    cond = _expr.Call(cond_func, cond_args).astype("bool")
    cond_shape = get_const_tuple(_infer_type(cond).checked_type.shape)
    if len(cond_shape) > 0:
        assert len(cond_shape) == 1 and cond_shape[0] == 1, "Condition is not scalar"
        # BUGFIX: extract element 0, not 1 — index 1 is out of bounds for a
        # shape-(1,) tensor and only "worked" because relay.take defaults to
        # mode="clip", which silently clamped it back to 0.
        cond = _op.take(cond, _expr.const(0, "int"))

    sb = _scope_builder.ScopeBuilder()
    with sb.if_scope(cond):
        then_arg_shapes = [arg.type_annotation.shape for arg in then_args]
        then_arg_dtype_info = [arg.type_annotation.dtype for arg in then_args]
        then_func = _from_mxnet_impl(subgraphs[1], then_arg_shapes, then_arg_dtype_info)
        sb.ret(_expr.Call(then_func, then_args))
    with sb.else_scope():
        else_arg_shapes = [arg.type_annotation.shape for arg in else_args]
        else_arg_dtype_info = [arg.type_annotation.dtype for arg in else_args]
        else_func = _from_mxnet_impl(subgraphs[2], else_arg_shapes, else_arg_dtype_info)
        sb.ret(_expr.Call(else_func, else_args))
    func = _expr.Function(input_args, sb.get())
    ret = _expr.Call(func, inputs)
    if num_outputs > 1:
        ret = _expr.TupleWrapper(ret, num_outputs)
    return ret
def test_forward_cond():
    """mx.sym.contrib.cond round-trips through the Relay MXNet frontend."""
    def check(a_np, b_np):
        # Reference result computed eagerly with NDArray.
        a_nd = mx.nd.array(a_np)
        b_nd = mx.nd.array(b_np)
        ref_res = mx.nd.contrib.cond(
            a_nd * b_nd < 5,
            lambda: (a_nd + 5) * (b_nd + 5),
            lambda: (a_nd - 5) * (b_nd - 5))

        # The same computation expressed as a symbolic graph.
        a_sym = mx.sym.var("a")
        b_sym = mx.sym.var("b")
        mx_sym = mx.sym.contrib.cond(
            a_sym * b_sym < 5,
            lambda: (a_sym + 5) * (b_sym + 5),
            lambda: (a_sym - 5) * (b_sym - 5))

        mod, _ = relay.frontend.from_mxnet(
            mx_sym, {"a": a_np.shape, "b": b_np.shape})
        for target, ctx in ctx_list():
            for kind in ["debug", "vm"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(a_np, b_np)
                tvm.testing.assert_allclose(
                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)

    # One input pair exercises each branch of the conditional.
    check(np.asarray([1.0], 'float32'), np.asarray([2.0], 'float32'))
    check(np.asarray([4.0], 'float32'), np.asarray([3.0], 'float32'))