[Frontend][MxNet] support mxnet cond op #4311

Merged 1 commit on Nov 12, 2019
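This change adds a converter for MXNet's `_cond` control-flow op to the Relay frontend and extends `_from_mxnet_impl` to pass subgraphs through. For context, a minimal sketch of the operator being converted; it mirrors the new test below and assumes an MXNet 1.x environment (illustrative, not part of the diff):

import mxnet as mx

# mx.sym.contrib.cond records the predicate and both branch functions as
# three subgraphs attached to a single "_cond" node in the symbol JSON.
a = mx.sym.var("a")
b = mx.sym.var("b")
pred = a * b < 5
out = mx.sym.contrib.cond(pred,
                          then_func=lambda: (a + 5) * (b + 5),
                          else_func=lambda: (a - 5) * (b - 5))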
Changes from all commits
68 changes: 65 additions & 3 deletions python/tvm/relay/frontend/mxnet.py
@@ -20,10 +20,12 @@

 import json
 import tvm
+from topi.util import get_const_tuple
 from .. import analysis
 from .. import expr as _expr
 from .. import op as _op
 from .. import module as _module
+from .. import scope_builder as _scope_builder
 from ... import nd as _nd

 from .common import StrAttrsDict
@@ -1037,6 +1039,47 @@ def _mx_contrib_fifo_buffer(inputs, attrs):
     new_attrs['axis'] = attrs.get_int('axis')
     return _op.nn.fifo_buffer(*inputs, **new_attrs)

+def _mx_cond(inputs, attrs, subgraphs):
+    assert len(subgraphs) == 3
+    cond_input_locs = json.loads(attrs.get_str("cond_input_locs"))
+    then_input_locs = json.loads(attrs.get_str("then_input_locs"))
+    else_input_locs = json.loads(attrs.get_str("else_input_locs"))
+    num_outputs = attrs.get_int("num_outputs")
+
+    input_args = []
+    for i, arg in enumerate(inputs):
+        var = _expr.var("arg%s" % i, _infer_type(arg).checked_type)
+        input_args.append(var)
+    cond_args = [input_args[i] for i in cond_input_locs]
+    then_args = [input_args[i] for i in then_input_locs]
+    else_args = [input_args[i] for i in else_input_locs]
+
+    cond_arg_shapes = [arg.type_annotation.shape for arg in cond_args]
+    cond_arg_dtype_info = [arg.type_annotation.dtype for arg in cond_args]
+    cond_func = _from_mxnet_impl(subgraphs[0], cond_arg_shapes, cond_arg_dtype_info)
+    cond = _expr.Call(cond_func, cond_args).astype("bool")
+    cond_shape = get_const_tuple(_infer_type(cond).checked_type.shape)
+    if len(cond_shape) > 0:
+        assert len(cond_shape) == 1 and cond_shape[0] == 1, "Condition is not scalar"
+        cond = _op.take(cond, _expr.const(1, "int"))
+
+    sb = _scope_builder.ScopeBuilder()
+    with sb.if_scope(cond):
+        then_arg_shapes = [arg.type_annotation.shape for arg in then_args]
+        then_arg_dtype_info = [arg.type_annotation.dtype for arg in then_args]
+        then_func = _from_mxnet_impl(subgraphs[1], then_arg_shapes, then_arg_dtype_info)
+        sb.ret(_expr.Call(then_func, then_args))
+    with sb.else_scope():
+        else_arg_shapes = [arg.type_annotation.shape for arg in else_args]
+        else_arg_dtype_info = [arg.type_annotation.dtype for arg in else_args]
+        else_func = _from_mxnet_impl(subgraphs[2], else_arg_shapes, else_arg_dtype_info)
+        sb.ret(_expr.Call(else_func, else_args))
+    func = _expr.Function(input_args, sb.get())
+    ret = _expr.Call(func, inputs)
+    if num_outputs > 1:
+        ret = _expr.TupleWrapper(ret, num_outputs)
+    return ret
+
+
 # Note: due to attribute conversion constraint
 # ops in the identity set must be attribute free
@@ -1204,6 +1247,8 @@ def _mx_contrib_fifo_buffer(inputs, attrs):
     # NLP
     "RNN"               : _mx_rnn_layer,
     "_rnn_param_concat" : _mx_rnn_param_concat,
+    # control flow
+    "_cond"             : _mx_cond,
     # Deprecated:
     "Crop"              : _mx_crop_like,
     # List of missing operators that are present in NNVMv1
@@ -1245,24 +1290,41 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, mod=None):
         Converted relay Function
     """
     assert symbol is not None
-    jgraph = json.loads(symbol.tojson())
+    if isinstance(symbol, dict):
+        jgraph = symbol
+    else:
+        jgraph = json.loads(symbol.tojson())
     jnodes = jgraph["nodes"]
     node_map = {}
+    shape_idx = 0

     for nid, node in enumerate(jnodes):
         children = [node_map[e[0]][e[1]] for e in node["inputs"]]
         attrs = StrAttrsDict(node.get("attrs", {}))
         node_name = node["name"]
         op_name = node["op"]
         if op_name == "null":
-            shape = shape_dict[node_name] if node_name in shape_dict else None
+            if isinstance(shape_dict, dict):
+                shape = shape_dict[node_name] if node_name in shape_dict else None
+            elif isinstance(shape_dict, (list, tuple)):
+                shape = shape_dict[shape_idx]
+            else:
+                raise ValueError("Unknown type of shape_dict: %s" % type(shape_dict))
             if isinstance(dtype_info, dict):
                 dtype = dtype_info[node_name] if node_name in dtype_info else "float32"
+            elif isinstance(dtype_info, (list, tuple)):
+                dtype = dtype_info[shape_idx]
             else:
                 dtype = dtype_info
+            if isinstance(shape_dict, (list, tuple)):
+                shape_idx += 1
             node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)]
         elif op_name in _convert_map:
-            res = _convert_map[op_name](children, attrs)
+            if op_name in ['_cond', '_foreach', '_while_loop']:
+                subgraphs = node['subgraphs']
+                res = _convert_map[op_name](children, attrs, subgraphs)
+            else:
+                res = _convert_map[op_name](children, attrs)
             if res is None:
                 # defer conversion, used in RNN state initialization
                 res = [node]
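The heart of `_mx_cond` is the `ScopeBuilder` pattern: each of the three subgraphs is converted to a Relay function, and the then/else calls become the two arms of a Relay `If`. A minimal standalone sketch of that pattern, using the public `relay.ScopeBuilder` alias and hypothetical variables rather than the converter's internals:

import tvm
from tvm import relay

# A scalar boolean condition and one input, standing in for the converted
# cond/then/else subgraph calls in _mx_cond.
cond = relay.var("cond", shape=(), dtype="bool")
x = relay.var("x", shape=(1,), dtype="float32")

sb = relay.ScopeBuilder()
with sb.if_scope(cond):              # true arm of the resulting relay If
    sb.ret(x + relay.const(5.0))
with sb.else_scope():                # false arm
    sb.ret(x - relay.const(5.0))
func = relay.Function([cond, x], sb.get())  # sb.get() yields the If expression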
26 changes: 26 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
@@ -909,6 +909,31 @@ def verify(data_shape, kernel_size, stride, pad, num_filter):
     verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
     verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)

+def test_forward_cond():
+    def verify(a_np, b_np):
+        a_nd, b_nd = mx.nd.array(a_np), mx.nd.array(b_np)
+        pred = a_nd * b_nd < 5
+        then_func = lambda: (a_nd + 5) * (b_nd + 5)
+        else_func = lambda: (a_nd - 5) * (b_nd - 5)
+        ref_res = mx.nd.contrib.cond(pred, then_func, else_func)
+
+        a_sym, b_sym = mx.sym.var("a"), mx.sym.var("b")
+        pred = a_sym * b_sym < 5
+        then_func = lambda: (a_sym + 5) * (b_sym + 5)
+        else_func = lambda: (a_sym - 5) * (b_sym - 5)
+        mx_sym = mx.sym.contrib.cond(pred, then_func, else_func)
+
+        shape_dict = {"a": a_np.shape, "b": b_np.shape}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
+        for target, ctx in ctx_list():
+            for kind in ["debug", "vm"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(a_np, b_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
+
+    verify(np.asarray([1.0], 'float32'), np.asarray([2.0], 'float32'))
+    verify(np.asarray([4.0], 'float32'), np.asarray([3.0], 'float32'))
+

 if __name__ == '__main__':
     test_forward_mlp()
@@ -963,3 +988,4 @@ def verify(data_shape, kernel_size, stride, pad, num_filter):
     test_forward_one_hot()
     test_forward_convolution()
     test_forward_deconvolution()
+    test_forward_cond()