[Relay][Frontend] Add MXNet converter for RNN layer ops #3125

Merged: 1 commit, merged on May 2, 2019.
10 changes: 9 additions & 1 deletion python/tvm/relay/build_module.py
@@ -26,6 +26,7 @@
from ..contrib import graph_runtime as _graph_rt
from . import ir_pass
from . import expr as _expr
from . import ty as _ty
from .backend import interpreter as _interpreter
from .backend import graph_runtime_codegen as _graph_gen

@@ -427,6 +428,8 @@ def __init__(self, mod, ctx, target):
self.target = target

def _make_executor(self, func):
ret_type = ir_pass.infer_type(func).ret_type
num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
graph_json, mod, params = build(func, target=self.target)
gmodule = _graph_rt.create(graph_json, mod, self.ctx)
if params:
@@ -440,7 +443,12 @@ def _graph_wrapper(*args, **kwargs):
# Run the module, and fetch the output.
gmodule.run()
# make a copy so multiple invocations won't hurt perf.
return gmodule.get_output(0).copyto(_nd.cpu(0))
if num_outputs == 1:
return gmodule.get_output(0).copyto(_nd.cpu(0))
outputs = []
for i in range(num_outputs):
outputs.append(gmodule.get_output(i).copyto(_nd.cpu(0)))
return outputs

return _graph_wrapper

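The change above lets a graph-runtime executor return every field of a tuple-typed result instead of only output 0. A minimal sketch of the resulting user-facing behaviour; the two-output function below is illustrative and not part of this diff:

```python
import numpy as np
import tvm
from tvm import relay

# Hypothetical function whose result is a tuple of two tensors.
x = relay.var("x", shape=(2, 2))
f = relay.Function([x], relay.Tuple([relay.add(x, x), relay.multiply(x, x)]))

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
res = intrp.evaluate(f)(np.ones((2, 2), dtype="float32"))
# With this patch, res should be a list with one NDArray per tuple field
# rather than just the first output.
print([r.asnumpy() for r in res])
```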
129 changes: 115 additions & 14 deletions python/tvm/relay/frontend/mxnet.py
@@ -34,6 +34,12 @@

__all__ = ['from_mxnet']

_activation_map = {
"sigmoid": _op.sigmoid,
"tanh" : _op.tanh,
"relu" : _op.nn.relu
}

def _mx_fully_connected(inputs, attrs):
import mxnet as mx
units = attrs.get_int("num_hidden")
@@ -66,12 +72,6 @@ def _get_channel_axis(layout, op_name):
def _mx_activations(inputs, attrs):
act_type = attrs.get_str("act_type")
assert len(inputs) == 1
if act_type == "sigmoid":
return _op.sigmoid(inputs[0])
if act_type == "tanh":
return _op.tanh(inputs[0])
if act_type == "relu":
return _op.nn.relu(inputs[0])
if act_type == "softrelu":
def _stable_softrelu(x):
# log(1 + exp(-abs(x))) + relu(x)
@@ -80,8 +80,10 @@ def _stable_softrelu(x):
return _op.add(_op.log(_op.add(one, exp_neg_abs_x)),
_op.nn.relu(x))
return _stable_softrelu(inputs[0])
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend MXNet.'.format(act_type))
if act_type not in _activation_map:
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend MXNet.'.format(act_type))
return _activation_map[act_type](inputs[0])


def _mx_compare(new_op, wrapper):
@@ -189,7 +191,8 @@ def _pool2d(new_op, is_avg):
def _mx_adaptive_avg_pooling(inputs, attrs):
output_size = attrs.get_int_tuple("output_size", [])
if output_size != (1,):
raise RuntimeError("AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
raise tvm.error.OpAttributeUnimplemented(
"AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
return _op.nn.global_avg_pool2d(inputs[0])


@@ -471,7 +474,7 @@ def _mx_take(inputs, attrs):
assert len(inputs) == 2
mode = attrs.get_str("mode", "clip")
if mode == "raise":
raise RuntimeError("take doesn't support raise mode")
raise tvm.error.OpAttributeUnimplemented("take with raise mode is not supported yet")
axis = attrs.get_int("axis", 0)
return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)

@@ -571,13 +574,13 @@ def _mx_l2_normalize(inputs, attrs):
def _mx_shape_array(inputs, attrs):
assert len(inputs) == 1
if attrs.get_int("lhs_begin", None) is not None:
raise RuntimeError("shape_array doesn't support lhs_begin")
raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support lhs_begin")
if attrs.get_int("lhs_end", None) is not None:
raise RuntimeError("shape_array doesn't support lhs_end")
raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support lhs_end")
if attrs.get_int("rhs_begin", None) is not None:
raise RuntimeError("shape_array doesn't support rhs_begin")
raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support rhs_begin")
if attrs.get_int("rhs_end", None) is not None:
raise RuntimeError("shape_array doesn't support rhs_end")
raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support rhs_end")
return _op.shape_of(inputs[0], dtype='int64')


@@ -657,6 +660,101 @@ def _mx_argsort(inputs, attrs):
return _op.argsort(inputs[0], **new_attrs)


def _mx_rnn_param_concat(inputs, _):
# We don't need to concatenate RNN params because we will unravel the RNN op
return [inputs]


def _mx_rnn_layer(inputs, attrs):
def _rnn_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias, activation):
i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
out = _activation_map[activation](i2h + h2h)
return out, [out]

def _gru_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
dtype = ir_pass.infer_type(data).checked_type.dtype
i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
i2h_r, i2h_z, i2h = _op.split(i2h, indices_or_sections=3, axis=1)
h2h_r, h2h_z, h2h = _op.split(h2h, indices_or_sections=3, axis=1)
reset_gate = _activation_map["sigmoid"](i2h_r + h2h_r)
update_gate = _activation_map["sigmoid"](i2h_z + h2h_z)
next_h_tmp = _activation_map["tanh"](reset_gate * h2h + i2h)
next_h = (_expr.const(1, dtype) - update_gate) * next_h_tmp + update_gate * states[0]
return next_h, [next_h]

def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
gates = i2h + h2h
slice_gates = _op.split(gates, indices_or_sections=4, axis=1)
in_gate = _activation_map["sigmoid"](slice_gates[0])
forget_gate = _activation_map["sigmoid"](slice_gates[1])
in_transform = _activation_map["tanh"](slice_gates[2])
out_gate = _activation_map["sigmoid"](slice_gates[3])
next_c = forget_gate * states[1] + in_gate * in_transform
next_h = out_gate * _activation_map["tanh"](next_c)
return next_h, [next_h, next_c]

num_layers = attrs.get_int("num_layers", 1)
mode = attrs.get_str("mode")
if mode.startswith("rnn"):
mode, activation = mode.split('_')
assert mode in ["rnn", "gru", "lstm"]
bidirectional = attrs.get_bool("bidirectional", False)
if bidirectional:
raise tvm.error.OpAttributeUnimplemented(
"Bidirectional RNN op is not supported yet")
layout = attrs.get_str("layout", "TNC")
if layout != "TNC":
raise tvm.error.OpAttributeUnimplemented(
"RNN with layout other than TNC is not supported yet")
num_states = 2 if mode == 'lstm' else 1
assert len(inputs) == num_states + 2

seq_data = inputs[0]
concat_weight = inputs[1]
concat_states = inputs[2:]
seq_len = int(ir_pass.infer_type(seq_data).checked_type.shape[0])
assert len(concat_weight) == num_layers * 4

weights = []
bias = []
states = []
for i in range(num_layers):
w = []
b = []
s = []
for j in range(2):
w.append(concat_weight[i*2 + j].args[0])
b.append(concat_weight[num_layers*2 + i*2 + j].args[0])
for state in concat_states:
s.append(_op.take(state, _expr.const(i, "int32"), axis=0))
weights.append(w)
bias.append(b)
states.append(s)

seq_output = []
for t in range(seq_len):
data = _op.take(seq_data, _expr.const(t, "int32"), axis=0)
for l in range(num_layers):
if mode == "rnn":
out, new_states = _rnn_cell(data, states[l], *weights[l], *bias[l], activation)
elif mode == "gru":
out, new_states = _gru_cell(data, states[l], *weights[l], *bias[l])
else: # mode == "lstm"
out, new_states = _lstm_cell(data, states[l], *weights[l], *bias[l])
states[l] = new_states
data = out
seq_output.append(out)

outputs = [_op.stack(seq_output, axis=0)]
for i in range(num_states):
outputs.append(_op.stack([s[i] for s in states], axis=0))
return outputs


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
@@ -807,6 +905,9 @@ def _mx_argsort(inputs, attrs):
"_contrib_box_nms" : _mx_box_nms,
"_contrib_DeformableConvolution" : _mx_deformable_convolution,
"_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_avg_pooling,
# NLP
"RNN" : _mx_rnn_layer,
"_rnn_param_concat" : _mx_rnn_param_concat,
# List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators.
#
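The converter above unrolls MXNet's fused RNN op explicitly over time steps and layers: `_mx_rnn_param_concat` keeps the per-layer parameters as a list instead of concatenating them, and `_mx_rnn_layer` assumes that list holds the i2h/h2h weights for every layer first and the i2h/h2h biases after them (that is what the `concat_weight[i*2 + j]` and `concat_weight[num_layers*2 + i*2 + j]` indexing encodes), with the LSTM gates packed so that a four-way split along axis 1 recovers them in in/forget/cell/out order. As a reference for that gate layout, here is a NumPy sketch of one LSTM step equivalent to `_lstm_cell`; it is an illustration only, not code from this patch:

```python
import numpy as np

def lstm_cell_ref(x, h, c, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
    """One LSTM step mirroring _lstm_cell.

    i2h_weight is (4*hidden, input) and h2h_weight is (4*hidden, hidden),
    with the four gates stacked in the order the converter assumes:
    in, forget, cell (tanh), out.
    """
    sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))
    # dense(x, W) in Relay computes x @ W.T; bias_add adds the bias on the last axis.
    gates = x.dot(i2h_weight.T) + i2h_bias + h.dot(h2h_weight.T) + h2h_bias
    in_g, forget_g, cell_g, out_g = np.split(gates, 4, axis=1)
    next_c = sigmoid(forget_g) * c + sigmoid(in_g) * np.tanh(cell_g)
    next_h = sigmoid(out_g) * np.tanh(next_c)
    return next_h, next_c
```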
49 changes: 49 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
@@ -527,6 +527,54 @@ def test_forward_bilinear_resize():
mx_sym = mx.sym.contrib.BilinearResize2D(data, height=5, width=10)
verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10))

def test_forward_rnn_layer():
def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
if mode == "rnn":
layer = gluon.rnn.RNN(hidden_size, num_layers)
elif mode == "gru":
layer = gluon.rnn.GRU(hidden_size, num_layers)
else: # mode == "lstm"
layer = gluon.rnn.LSTM(hidden_size, num_layers)
num_states = 2 if mode == "lstm" else 1
layer.initialize()

dtype = "float32"
data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype(dtype)
states_np = []
states_mx = []
shape_dict = {'data0': data_np.shape}
inputs = {'data0': data_np}
for i in range(num_states):
s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
states_np.append(s)
states_mx.append(mx.nd.array(s))
shape_dict['data%s' % (i+1)] = s.shape
inputs['data%s' % (i+1)] = s

layer.hybridize()
mx_out, mx_states = layer(mx.nd.array(data_np), states_mx)
mx_res = [mx_out] + mx_states
mx_sym = layer._cached_graph[1]
mx_params = {}
for name, param in layer.collect_params().items():
mx_params[name] = param._reduce()

new_sym, params = relay.frontend.from_mxnet(
mx_sym, shape=shape_dict, arg_params=mx_params)
for target, ctx in ctx_list():
# only test graph runtime because debug runtime is too slow
for kind in ["graph"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(**inputs, **params)
assert len(op_res) == len(mx_res)
for i, val in enumerate(op_res):
tvm.testing.assert_allclose(val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)

for mode in ["rnn", "gru", "lstm"]:
verify(mode, 64, 10, 64, 1)
verify(mode, 64, 10, 64, 2)
verify(mode, 64, 10, 32, 2)


if __name__ == '__main__':
test_forward_mlp()
@@ -566,3 +614,4 @@ def test_forward_bilinear_resize():
test_forward_take()
test_forward_gather_nd()
test_forward_bilinear_resize()
test_forward_rnn_layer()
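Condensed from the test above, the end-to-end flow for converting a Gluon RNN layer looks roughly like the sketch below. It relies on the same internals as the test (`hybridize()` plus the private `_cached_graph` attribute to obtain the symbol, and the `data0`/`data1`/`data2` input names of that cached graph), so treat it as an illustration rather than a stable recipe:

```python
import mxnet as mx
import tvm
from mxnet import gluon
from tvm import relay

seq_len, batch, input_size, hidden_size, num_layers = 10, 1, 64, 64, 1
layer = gluon.rnn.LSTM(hidden_size, num_layers)
layer.initialize()
layer.hybridize()

data = mx.nd.random.uniform(shape=(seq_len, batch, input_size))
states = [mx.nd.zeros((num_layers, batch, hidden_size)) for _ in range(2)]  # h and c
layer(data, states)  # one forward pass populates _cached_graph

mx_sym = layer._cached_graph[1]
mx_params = {name: p._reduce() for name, p in layer.collect_params().items()}

shape_dict = {"data0": data.shape,
              "data1": states[0].shape,
              "data2": states[1].shape}
sym, params = relay.frontend.from_mxnet(mx_sym, shape=shape_dict, arg_params=mx_params)

intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
inputs = {"data0": data.asnumpy(),
          "data1": states[0].asnumpy(),
          "data2": states[1].asnumpy()}
outputs = intrp.evaluate(sym)(**inputs, **params)
# outputs[0] stacks the per-step hidden states; outputs[1:] are the final h and c.
```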