From b85c2397356c543b13be4894bf06900a41a13450 Mon Sep 17 00:00:00 2001 From: lixiaoquan Date: Fri, 17 Jul 2020 06:22:28 +0800 Subject: [PATCH] Refine LSTMBlockCell to support dynamic rnn (#5963) 1. Refine conversion of `LSTMBlockCell` 1) Make its output follows definition in TensorFlow 2) Avoid introducing variables which doesn't match any placeholder nodes in TensorFlow graph 2. About change in test_forward_ptb States nodes of LSTMBlockCell in this PB file are actually Constant node. TF can feed data to those Constant nodes but relay can't do that, so current conversion of LSTMBockCell introduces extra variables to solve this issue. But this causes that relay IR doesn't match original TF graph. This PR solves this issue by convert those states node into placeholders. --- python/tvm/relay/frontend/tensorflow.py | 374 ++++-------------- python/tvm/relay/testing/tf.py | 39 +- .../frontend/tensorflow/test_forward.py | 163 ++++++-- 3 files changed, 231 insertions(+), 345 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index a1a407287d20..5f52553cfd77 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1990,6 +1990,65 @@ def _impl(inputs, attr, params, mod): return _res return _impl +def _LSTMBlockCell(): + def _impl(inputs, attr, params, mod): + """LSTM Block cell. + Calculations and return values are described in: + https://github.com/tensorflow/tensorflow/blob/ + r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114 + + Parameters + ---------- + inputs : relay.Expr + Input data + in_state_c: list of relay.Expr + Cell state input values for all the layers + in_state_h: list of relay.Expr + Hidden state input values for all the layers + attrs : dict + Dict of operator attributes + params : dict + List of pretrained weights and bias + + Returns + ------- + relay.Expr.TupleWapper + [i, cs, f, o, ci, co, h] + """ + in_data = inputs[0] + in_state_c = inputs[1] + in_state_h = inputs[2] + in_weight = inputs[3] + in_bias = inputs[7] + forget_bias = attr.pop('forget_bias') + input_shape = _infer_shape(inputs[0], mod) + weight_shape = _infer_shape(inputs[3], mod) + batch_size, input_size = input_shape[0], input_shape[1] + num_hidden_layers = weight_shape[1] + + in_data = _op.reshape(in_data, + newshape=(batch_size, input_size)) + ixh = _op.concatenate([in_data, in_state_h], axis=1) + in_weight = _op.transpose(in_weight, axes=None) + gates = _op.nn.dense(ixh, in_weight, + units=num_hidden_layers) + gates_bias = _op.add(gates, in_bias) + gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1) + in_gate = _op.sigmoid(gate_list[0]) + in_transform = _op.tanh(gate_list[1]) + forget_gate = _op.add(gate_list[2], tvm.relay.const(forget_bias, attr['T'].name)) + forget_gate = _op.sigmoid(forget_gate) + out_gate = _op.sigmoid(gate_list[3]) + next_c = _op.add(_op.multiply(forget_gate, in_state_c), + _op.multiply(in_gate, in_transform)) + co = _op.tanh(next_c) + next_h = out_gate * co + + return tvm.relay.TupleWrapper( + tvm.relay.Tuple([in_gate, next_c, forget_gate, out_gate, in_transform, co, next_h]), 7) + + return _impl + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -2078,6 +2137,7 @@ def _impl(inputs, attr, params, mod): 'LogicalOr' : _logical('logical_or'), 'LogSoftmax' : AttrCvt('log_softmax'), 'LRN' : _lrn(), + 'LSTMBlockCell' : _LSTMBlockCell(), 'MatMul' : _matmul(), 'Max' : _reduce('max'), 'Maximum' : _elemwise('maximum'), @@ -2154,253 +2214,8 @@ def _impl(inputs, attr, params, mod): 'UnravelIndex' : _unravel_index(), 'Where' : _where(), 'ZerosLike' : AttrCvt('zeros_like'), - -} - -def _LSTMBlockCell(): - def _impl(inputs, in_state_c, in_state_h, attr, params, mod): - """LSTM Block cell. - Calculations are described in: https://github.com/tensorflow/tensorflow/blob/ - r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114 - - Parameters - ---------- - inputs : relay.Expr - Input data - in_state_c: list of relay.Expr - Cell state input values for all the layers - in_state_h: list of relay.Expr - Hidden state input values for all the layers - attrs : dict - Dict of operator attributes - params : dict - List of pretrained weights and bias - - Returns - ------- - sym : relay.Expr - Converted relay.Expr - output: relay.Expr - Output state value. - """ - in_data = inputs[0] - in_weight = inputs[3] - in_bias = inputs[7] - forget_bias = attr.pop('forget_bias') - input_shape = _infer_shape(inputs[0], mod) - weight_shape = _infer_shape(inputs[3], mod) - batch_size, input_size = input_shape[0], input_shape[1] - num_hidden_layers = weight_shape[1] - num_hidden = num_hidden_layers // 4 - - in_data = _op.reshape(in_data, - newshape=(batch_size, input_size)) - ixh = _op.concatenate([in_data, in_state_h], axis=1) - in_weight = _op.transpose(in_weight, axes=None) - gates = _op.nn.dense(ixh, in_weight, - units=num_hidden_layers) - gates_bias = _op.add(gates, in_bias) - gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1) - in_gate = _op.sigmoid(gate_list[0]) - in_transform = _op.tanh(gate_list[1]) - forget_gate = _op.add(gate_list[2], tvm.relay.const(forget_bias, attr['T'].name)) - forget_gate = _op.sigmoid(forget_gate) - out_gate = _op.sigmoid(gate_list[3]) - next_c = _op.add(_op.multiply(forget_gate, in_state_c), - _op.multiply(in_gate, in_transform)) - next_h = out_gate * _op.tanh(next_c) - out_state = _op.concatenate([next_c, next_h], axis=1) - out_state = _op.reshape(out_state, - newshape=(2, batch_size, num_hidden)) - return next_h, out_state - return _impl - -# _convert_map_rnn defines maps of rnn operator name to -# converter functor(callable) for 1 to 1 mapping. -_convert_map_rnn = { - 'LSTMBlockCell' : _LSTMBlockCell(), } -class RecurrentNetworks(object): - """Recurrent network layer handlers. - - Handle Layer operations. - ToDo: Operators like RNN/GRU layer concepts also can be handled here - - Parameters - ---------- - nodes : list - list of graph nodes used for tensorflow parsing. - - out_rnn : list - List of RecurrentNetwork outputs. This output will be appended to the - 'head' nodes of the graph. - - graph : tensorflow graph definition object - The loaded tensorflow GraphDef - - convert_map : dict - Dict of name : callable, where name is the op's name that - require conversion to relay, callable are functions which - take attrs and return (new_op_name, new_attrs) - """ - def __init__(self, nodes, out_rnn, graph, convert_map): - self._graph = graph - self._convert_map = convert_map - self._nodes = nodes - self._out_rnn = out_rnn - self._cur_lstm_layer = 0 - self._layer_name_list = [] - self._recurrent_ops_layer_map = { - 'LSTMBlockCell' : self._LSTMBlockCellLayer(), - } - - def _LSTMBlockCellLayer(self): - """LSTMBlockCell layer handler. - - Parameters - ---------- - op_name : str - Operator name, eg:LSTMBlockCell - - layer_name : str list - Layer name is used for creating the state input placeholder. - - inputs : relay.Expr - Input data - - attrs : dict - Dict of operator attributes - - params : dict - List of pretrained weights and bias - - num_layers : int - Total number of LSTM layer presented in the graph - - Returns - ------- - sym : relay.Expr - The returned relay Expr - """ - def _impl(op_name, layer_name, inputs, attrs, params, num_layers, mod): - in_state_c_name = layer_name+'_c' - in_state_h_name = layer_name+'_h' - - def _init_state(num_layers, batch_size, num_hidden): - """Create the initial states for the first layer in the graph.""" - in_state_c = [_expr.var(in_state_c_name, - shape=(num_layers, batch_size, num_hidden), - dtype='float32')] - - in_state_h = [_expr.var(in_state_h_name, - shape=(num_layers, batch_size, num_hidden), - dtype='float32')] - return in_state_c, in_state_h - - def _get_cur_input_state(in_state_c, in_state_h, num_layers, - layer, batch_size, num_hidden): - """Select the appropriate states for the current layer""" - in_state_c_tup = _op.split(in_state_c[0], - indices_or_sections=num_layers, axis=0) - in_state_h_tup = _op.split(in_state_h[0], - indices_or_sections=num_layers, axis=0) - cur_in_state_c = _op.reshape(in_state_c_tup[layer], - newshape=(batch_size, num_hidden)) - cur_in_state_h = _op.reshape(in_state_h_tup[layer], - newshape=(batch_size, num_hidden)) - return cur_in_state_c, cur_in_state_h - - def _LSTMBlockCellWrapper(inputs, attr, params, - num_layers, layer): - """LSTM cell warapper to prepare the inputs""" - input_shape = _infer_shape(inputs[0], mod) - weight_shape = _infer_shape(inputs[3], mod) - - batch_size = input_shape[0] - num_hidden = weight_shape[1] // 4 - - if layer == 0: - #Create initial states placeholder in case of first layer - in_state_c, in_state_h = _init_state(num_layers, - batch_size, num_hidden) - else: - in_state_c = self._nodes[in_state_c_name] - in_state_h = self._nodes[in_state_h_name] - - cur_in_state_c, cur_in_state_h = _get_cur_input_state( - in_state_c, in_state_h, - num_layers, layer, - batch_size, num_hidden) - output, out_state = self._convert_map[op_name](inputs, cur_in_state_c, - cur_in_state_h, - attr, params, mod) - return output, out_state, in_state_c, in_state_h - - sym, cur_out_state, in_state_c, in_state_h = \ - _LSTMBlockCellWrapper(inputs, attrs, params, - num_layers, self._cur_lstm_layer) - self._nodes[in_state_c_name] = in_state_c - self._nodes[in_state_h_name] = in_state_h - cur_out_state = _op.expand_dims(cur_out_state, axis=0, num_newaxis=1) - self._out_rnn.append(cur_out_state) - self._cur_lstm_layer += 1 - return sym - return _impl - - def process_op(self, op_name, inputs, attrs, params, mod): - """Process recurrent layer operators. - - List '_recurrent_ops_layer_map' map each Layer based operators with its - layer handlers. Total number of layers are calculated to form the input - data shapes. - - Parameters - ---------- - op_name : str - Operator name, such as LSTMBlockCell - - inputs : relay.Expr - Input data - - attrs : dict - Dict of operator attributes - - params : dict - List of pretrained weights and bias - - Returns - ------- - sym : relay.Expr - Returns relay.Expr - """ - def _get_abs_layer_name(node): - """Identify the layer name is already handled. Return the absolute name - """ - if not self._layer_name_list: - self._layer_name_list.append(node.name) - return node.name - - for _name in self._layer_name_list: - if _name in node.name: - abs_name = _name - else: - self._layer_name_list.append(node.name) - abs_name = node.name - return abs_name - - #Find number of layers of this same operator node in the graph - #and also read the inputs name for the current op. - num_layers = 0 - for _, node in enumerate(self._graph.node): - if node.op == op_name: - layer_name = _get_abs_layer_name(node) - num_layers += 1 - - sym = self._recurrent_ops_layer_map[op_name](op_name, layer_name, inputs, attrs, - params, num_layers, mod) - return sym - # An internal list to contain all the control flow primitives used in Tensorflow # 1.x. _control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond'] @@ -2889,7 +2704,9 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None): elif cnode.name != wnode.name: if is_tensor_array_constuctor(cnode): inode = self._tf_node_map[wnode.input[inode_idx].split(":")[0]] - self._tensor_array_shape_nodes[cnode.name] = (inode, wnode.op) + tn = wnode.input[inode_idx].split(":") + output_index = int(tn[1]) if len(tn) > 1 else 0 + self._tensor_array_shape_nodes[cnode.name] = (inode, wnode.op, output_index) break # First, parse all control flow nodes. @@ -2942,15 +2759,10 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None): else: out.append(self._nodes[out_name][0]) - #Add the RNN outputs also with 'head' nodes of the relay graph - if self._num_rnn_layer: - if len(self._out_rnn) == 1: - out.append(self._out_rnn[0]) - else: - out_rnn = _op.concatenate(self._out_rnn, axis=0) - out.append(out_rnn) - - out = out[0] if len(out) == 1 else _expr.Tuple(out) + if isinstance(out, _expr.TupleWrapper): + out = out.tuple_value + else: + out = out[0] if len(out) == 1 else _expr.Tuple(out) fvars = analysis.free_vars(out) func = _function.Function(fvars, out) final_params = {} @@ -2988,7 +2800,6 @@ def _parse_import_prerequisites(self, graph): pass else: if any([node.op in t for t in [_identity_list, _convert_map, - _convert_map_rnn, _control_flow_nodes]]): pass elif op_def is not None and op_def.is_stateful: @@ -3082,42 +2893,6 @@ def _parse_attr(self, attr_proto): return attrs - def _convert_rnn_operator(self, op_name, inputs, - attrs, params, graph, convert_map): - """Convert RNN and its variant operators to Relay operators. - This converter read the input states of each layers and - also maintain the output states of each layer in a list. - - Parameters - ---------- - op_name : str - Operator name, such as LSTMBlockCell - inputs : list of relay.Expr - List of input symbols. - attrs : dict - Dict of operator attributes - params : dict - List of pretrained weights and bias - graph : Tensorflow graph object - Graph is to find the number of upcoming same operator to - calculate the number of layers. - convert_map : dict - Dict of name : callable, where name is the op's name that - require conversion to relay, callable are functions which - take attrs and return (new_op_name, new_attrs) - - Returns - ------- - sym : relay.Expr - Converted relay.Expr - """ - if not self._num_rnn_layer: - self._out_rnn = [] - self.rnn = RecurrentNetworks(self._nodes, self._out_rnn, graph, convert_map) - self._num_rnn_layer = True - sym = self.rnn.process_op(op_name, inputs, attrs, params, self._mod) - return sym - def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_map): """ Convert the Relay control flow primitive into corresponding component @@ -3355,7 +3130,6 @@ def _convert_operator(self, op_name, inputs, attrs, """ identity_list = identity_list if identity_list else _identity_list convert_map = convert_map if convert_map else _convert_map - convert_map_rnn = _convert_map_rnn if op_name in identity_list: sym = get_relay_op(op_name)(*inputs, **attrs) elif op_name in convert_map: @@ -3363,12 +3137,6 @@ def _convert_operator(self, op_name, inputs, attrs, sym = convert_map[op_name](inputs, attrs, self._params, self._prelude) else: sym = convert_map[op_name](inputs, attrs, self._params, self._mod) - - elif op_name in convert_map_rnn: - sym = self._convert_rnn_operator(op_name, inputs, attrs, - self._params, graph, - convert_map_rnn) - elif op_name in ["PartitionedCall", "StatefulPartitionedCall"]: sym = self._partition_call_operator(inputs, attrs) else: @@ -3482,8 +3250,12 @@ def _backtrack_construct(self, node_name): if elem_shape: attr["shape"] = elem_shape if attr['identical_element_shapes'] or elem_shape: - shape_node, wnode_op = self._tensor_array_shape_nodes[node.name] - converted = self._backtrack_construct(shape_node.name) + shape_node, wnode_op, output_index = \ + self._tensor_array_shape_nodes[node.name] + name = shape_node.name + if output_index > 0: + name += ":" + str(output_index) + converted = self._backtrack_construct(name) shape = _infer_shape(converted, self._mod) if wnode_op.startswith("TensorArraySplit"): shape = (Any(),) + shape[1:] diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py index 988faa85dd08..9715bd7a44e1 100644 --- a/python/tvm/relay/testing/tf.py +++ b/python/tvm/relay/testing/tf.py @@ -197,7 +197,7 @@ def get_workload_official(model_url, model_sub_path): raise RuntimeError('Could not decompress the file: ' + model_path) return os.path.join(dir_path, model_sub_path) -def get_workload(model_path, model_sub_path=None): +def get_workload(model_path, model_sub_path=None, inputs_dict=None, output=None): """ Import workload from frozen protobuf Parameters @@ -226,8 +226,14 @@ def get_workload(model_path, model_sub_path=None): with tf_compat_v1.gfile.FastGFile(path_model, 'rb') as f: graph_def = tf_compat_v1.GraphDef() graph_def.ParseFromString(f.read()) - graph = tf_compat_v1.import_graph_def(graph_def, name='') - return graph_def + graph = tf_compat_v1.import_graph_def(graph_def, name='', input_map=inputs_dict) + + if inputs_dict is not None: + # graph is changed so generate graph_def again + with tf_compat_v1.Session(graph=graph) as sess: + graph_def = AddShapesToGraphDef(sess, output) + + return graph_def ####################################################################### # PTB LSTMBlockCell Model @@ -266,7 +272,7 @@ def do_tf_sample(session, data, in_states, num_samples): 'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1:0', 'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros:0', 'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1:0'] - state = session.run(state_input_name) + state = in_states #Graph nodes to be fetched as run output. Tensorflow LSTMBlockCell create internal #nodes for intermediate operations (gates) in the cell during run. @@ -364,4 +370,27 @@ def get_workload_ptb(): t.extractall(dir_path) word_to_id, id_to_word = _create_ptb_vocabulary(dir_path) - return word_to_id, id_to_word, get_workload(ptb_model_file) + dtype = 'float32' + shape = (1, 200) + + # Convert states of LSTMBlockCell to placeholder, so TVM can feed data + state_name = [ + 'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros:0', + 'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1:0', + 'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros:0', + 'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1:0', + ] + + inputs_dict = { + state_name[0]: + tf_compat_v1.placeholder(dtype, shape, state_name[0].split(':')[0]), + state_name[1]: + tf_compat_v1.placeholder(dtype, shape, state_name[1].split(':')[0]), + state_name[2]: + tf_compat_v1.placeholder(dtype, shape, state_name[2].split(':')[0]), + state_name[3]: + tf_compat_v1.placeholder(dtype, shape, state_name[3].split(':')[0]), + } + return word_to_id, id_to_word, get_workload(ptb_model_file, + inputs_dict=inputs_dict, + output='Model/Softmax') diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 182c2d72447a..3b9d4d4fb440 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2051,46 +2051,41 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype): input_size = num_hidden input_data = np.full((batch_size, input_size), 1., dtype=dtype) in_state_c = np.full( - (num_layers, batch_size, num_hidden), 0.1, dtype=dtype) + (batch_size, num_hidden), 0.1, dtype=dtype) in_state_h = np.full( - (num_layers, batch_size, num_hidden), 0.1, dtype=dtype) + (batch_size, num_hidden), 0.1, dtype=dtype) def _get_tensorflow_output(): with tf.Session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): - m0 = array_ops.zeros([batch_size, num_hidden]) - m1 = array_ops.zeros([batch_size, num_hidden]) - x = tf.placeholder(shape=(batch_size, input_size), dtype=dtype) + m0 = tf.placeholder(dtype, [batch_size, num_hidden], name="m0") + m1 = tf.placeholder(dtype, [batch_size, num_hidden], name="m1") + x = tf.placeholder(shape=(batch_size, input_size), dtype=dtype, name="input") g, ((out_m0, out_m1)) = \ tensorflow.contrib.rnn.LSTMBlockCell(num_hidden, forget_bias=forget_bias)(x, (m0, m1)) sess.run([variables.global_variables_initializer()]) res = sess.run([g, out_m0, out_m1], { x.name: np.array([[1., 1.]]), - m0.name: 0.1 * np.ones([batch_size, num_hidden]), - m1.name: 0.1 * np.ones([batch_size, num_hidden]), + m0.name: in_state_c, + m1.name: in_state_h, }) graph_def = sess.graph.as_graph_def(add_shapes=True) final_graph_def = graph_util.convert_variables_to_constants( sess, graph_def, ['root/lstm_cell/LSTMBlockCell']) + return final_graph_def, res graph_def, tf_out = _get_tensorflow_output() tvm_output = run_tvm_graph(graph_def, [input_data, in_state_c, in_state_h], - ['root/Placeholder', 'root/lstm_cell/LSTMBlockCell_c', - 'root/lstm_cell/LSTMBlockCell_h'], num_output=2) + ['root/input', "root/m0", "root/m1"], num_output=7) assert isinstance(tvm_output, list) - out = tvm_output[0] - out_state = tvm_output[1] - out_state_tup = np.split(out_state, indices_or_sections=2, axis=1) - out_state_c = np.reshape(out_state_tup[0], (batch_size, num_hidden)) - out_state_h = np.reshape(out_state_tup[1], (batch_size, num_hidden)) - tvm_out = [out, out_state_c, out_state_h] - tvm.testing.assert_allclose(tf_out[0], tvm_out[0], rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(tf_out[0], tvm_output[6], rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(tf_out[1], tvm_output[1], rtol=1e-3, atol=1e-3) def test_forward_lstm(): @@ -2477,7 +2472,7 @@ def test_forward_ptb(): batch_size = config.batch_size vocab_size = config.vocab_size out_sample_shape = (batch_size, vocab_size) - out_state_shape = (num_layers, 2, batch_size, num_hidden) + out_state_shape = (batch_size, num_hidden) # Sample input inpt = "we have no useful information on" cnt_sample = 20 @@ -2490,18 +2485,17 @@ def _pretty_print(items, is_char_model, id2word): def _get_tvm_graph_module(graph_def): # Cell inputs 'c and 'h' consist of all layers values - shape_dict = {'Model/Placeholder': (batch_size, num_steps), - 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c': - (num_layers, batch_size, num_hidden), - 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h': - (num_layers, batch_size, num_hidden)} + shape_dict = {'Model/Placeholder': (batch_size, num_steps)} mod, params = relay.frontend.from_tensorflow( - graph_def, shape=shape_dict) + graph_def, shape=shape_dict, + outputs=['Model/Softmax:0', + 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:1', + 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:6', + 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:1', + 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:6', + ]) - dtype_dict = {'Model/Placeholder': 'int32', - 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c': 'float32', - 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h': 'float32'} target = 'llvm' with tvm.transform.PassContext(opt_level=0): graph, lib, params = relay.build(mod, @@ -2519,24 +2513,26 @@ def _do_tvm_sample(model, data, in_states, params, num_samples): def _get_sample(data, state): input_data = np.full((batch_size, num_steps), data, dtype="int32") - in_state_tup = np.split(state, indices_or_sections=2, axis=1) - in_state_c = np.reshape( - in_state_tup[0], (num_layers, batch_size, num_hidden)) - in_state_h = np.reshape( - in_state_tup[1], (num_layers, batch_size, num_hidden)) model.set_input('Model/Placeholder', tvm.nd.array(input_data.astype("int32"))) - model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c', - tvm.nd.array(in_state_c.astype("float32"))) - model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h', - tvm.nd.array(in_state_h.astype("float32"))) + model.set_input('Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros', + tvm.nd.array(state[0].astype("float32"))) + model.set_input('Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1', + tvm.nd.array(state[1].astype("float32"))) + model.set_input('Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros', + tvm.nd.array(state[2].astype("float32"))) + model.set_input('Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1', + tvm.nd.array(state[3].astype("float32"))) model.set_input(**params) model.run() tvm_output = model.get_output(0, tvm.nd.empty(out_sample_shape, "float32")).asnumpy() - state_output = model.get_output(1, tvm.nd.empty(out_state_shape, - "float32")).asnumpy() + + state_output = [] + for i in range(4): + state_output.append(model.get_output(i+1, tvm.nd.empty(out_state_shape, + "float32")).asnumpy()) sample = tf_testing.pick_from_weight(tvm_output[0]) return sample, state_output @@ -2570,8 +2566,7 @@ def _get_sample(data, state): cnt_stm = 0 while cnt_stm < 10: cnt_stm += 1 - in_state = np.full( - (num_layers, 2, batch_size, num_hidden), 0, dtype="float32") + in_state = [np.full((batch_size, num_hidden), 0, dtype="float32")] * 2 * num_layers seed_for_sample = inpt.split() tvm_samples, tvm_state = _do_tvm_sample(m, [word_to_id[word] for word in seed_for_sample], @@ -3748,6 +3743,94 @@ def test_forward_dynamic_input_shape(): tvm.testing.assert_allclose(tvm_output[0], tf_output[0], rtol=1e-5, atol=1e-5) +def test_forward_dynmaic_rnn_lstmblockcell(): + if package_version.parse(tf.VERSION) >= package_version.parse('2.0.0'): + return + + total_series_length = 50000 + truncated_backprop_length = 15 + state_size = 4 + echo_step = 3 + batch_size = 5 + num_layers = 5 + + def generateData(): + x = np.array(np.random.choice(2, total_series_length, p=[0.5, 0.5])) + y = np.roll(x, echo_step) + y[0:echo_step] = 0 + + x = x.reshape((batch_size, -1)) # The first index changing slowest, subseries as rows + y = y.reshape((batch_size, -1)) + + return (x, y) + + batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length]) + + init_state = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size]) + + state_per_layer_list = tf.unstack(init_state, axis=0) + rnn_tuple_state = tuple( + [tf.nn.rnn_cell.LSTMStateTuple(state_per_layer_list[idx][0], state_per_layer_list[idx][1]) + for idx in range(num_layers)] + ) + + # Forward passes + def lstm_cell(): + return tensorflow.contrib.rnn.LSTMBlockCell(state_size) + cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell() for _ in range(num_layers)], state_is_tuple=True) + states_series, current_state = tf.nn.dynamic_rnn(cell, tf.expand_dims(batchX_placeholder, -1), + initial_state=rnn_tuple_state) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + x, y = generateData() + _current_state = np.zeros((num_layers, 2, batch_size, state_size)) + + start_idx = 0 + end_idx = start_idx + truncated_backprop_length + + batchX = x[:, start_idx:end_idx] + + # Save current state for TVM + current_state_tvm = _current_state + + _current_state, _states_series = sess.run( + [current_state, states_series], + feed_dict={ + batchX_placeholder: batchX, + init_state: _current_state + }) + + # Organize results and corresponding names + tf_output = [_states_series] + + for c in _current_state: + tf_output.append(c.c) + tf_output.append(c.h) + + name = [states_series.name.split(':')[0]] + + for t in current_state: + name.append(t.c.name.split(':')[0]) + name.append(t.h.name.split(':')[0]) + + graph_def = sess.graph.as_graph_def(add_shapes=True) + + final_graph_def = graph_util.convert_variables_to_constants( + sess, + graph_def, + name) + + tvm_output = run_tvm_graph(final_graph_def, + [batchX.astype('float32'), current_state_tvm.astype('float32')], + ["Placeholder", "Placeholder_1"], out_names=name, + num_output=len(name), mode='vm', disabled_pass=["FoldScaleAxis"]) + + # Compare result + for i in range(len(tf_output)): + tvm.testing.assert_allclose( + tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) + ####################################################################### # Main @@ -3887,3 +3970,5 @@ def test_forward_dynamic_input_shape(): # Test dynamic input shape test_forward_dynamic_input_shape() + + test_forward_dynmaic_rnn_lstmblockcell()