diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index a1a407287d205..5f52553cfd77d 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1990,6 +1990,65 @@ def _impl(inputs, attr, params, mod):
         return _res
     return _impl
 
+def _LSTMBlockCell():
+    def _impl(inputs, attr, params, mod):
+        """LSTM Block cell.
+        Calculations and return values are described in:
+        https://github.com/tensorflow/tensorflow/blob/
+        r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114
+
+        Parameters
+        ----------
+        inputs : relay.Expr
+            Input data
+        in_state_c: relay.Expr
+            Cell state input value
+        in_state_h: relay.Expr
+            Hidden state input value
+        attrs : dict
+            Dict of operator attributes
+        params : dict
+            List of pretrained weights and bias
+
+        Returns
+        -------
+        relay.Expr.TupleWrapper
+            [i, cs, f, o, ci, co, h]
+        """
+        in_data = inputs[0]
+        in_state_c = inputs[1]
+        in_state_h = inputs[2]
+        in_weight = inputs[3]
+        in_bias = inputs[7]
+        forget_bias = attr.pop('forget_bias')
+        input_shape = _infer_shape(inputs[0], mod)
+        weight_shape = _infer_shape(inputs[3], mod)
+        batch_size, input_size = input_shape[0], input_shape[1]
+        num_hidden_layers = weight_shape[1]
+
+        in_data = _op.reshape(in_data,
+                              newshape=(batch_size, input_size))
+        ixh = _op.concatenate([in_data, in_state_h], axis=1)
+        in_weight = _op.transpose(in_weight, axes=None)
+        gates = _op.nn.dense(ixh, in_weight,
+                             units=num_hidden_layers)
+        gates_bias = _op.add(gates, in_bias)
+        gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1)
+        in_gate = _op.sigmoid(gate_list[0])
+        in_transform = _op.tanh(gate_list[1])
+        forget_gate = _op.add(gate_list[2], tvm.relay.const(forget_bias, attr['T'].name))
+        forget_gate = _op.sigmoid(forget_gate)
+        out_gate = _op.sigmoid(gate_list[3])
+        next_c = _op.add(_op.multiply(forget_gate, in_state_c),
+                         _op.multiply(in_gate, in_transform))
+        co = _op.tanh(next_c)
+        next_h = out_gate * co
+
+        return tvm.relay.TupleWrapper(
+            tvm.relay.Tuple([in_gate, next_c, forget_gate, out_gate, in_transform, co, next_h]), 7)
+
+    return _impl
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -2078,6 +2137,7 @@ def _impl(inputs, attr, params, mod):
     'LogicalOr'                         : _logical('logical_or'),
     'LogSoftmax'                        : AttrCvt('log_softmax'),
     'LRN'                               : _lrn(),
+    'LSTMBlockCell'                     : _LSTMBlockCell(),
     'MatMul'                            : _matmul(),
     'Max'                               : _reduce('max'),
     'Maximum'                           : _elemwise('maximum'),
@@ -2154,253 +2214,8 @@ def _impl(inputs, attr, params, mod):
     'UnravelIndex'                      : _unravel_index(),
     'Where'                             : _where(),
     'ZerosLike'                         : AttrCvt('zeros_like'),
-
-}
-
-def _LSTMBlockCell():
-    def _impl(inputs, in_state_c, in_state_h, attr, params, mod):
-        """LSTM Block cell.
-        Calculations are described in: https://github.com/tensorflow/tensorflow/blob/
-        r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114
-
-        Parameters
-        ----------
-        inputs : relay.Expr
-            Input data
-        in_state_c: list of relay.Expr
-            Cell state input values for all the layers
-        in_state_h: list of relay.Expr
-            Hidden state input values for all the layers
-        attrs : dict
-            Dict of operator attributes
-        params : dict
-            List of pretrained weights and bias
-
-        Returns
-        -------
-        sym : relay.Expr
-            Converted relay.Expr
-        output: relay.Expr
-            Output state value.
- """ - in_data = inputs[0] - in_weight = inputs[3] - in_bias = inputs[7] - forget_bias = attr.pop('forget_bias') - input_shape = _infer_shape(inputs[0], mod) - weight_shape = _infer_shape(inputs[3], mod) - batch_size, input_size = input_shape[0], input_shape[1] - num_hidden_layers = weight_shape[1] - num_hidden = num_hidden_layers // 4 - - in_data = _op.reshape(in_data, - newshape=(batch_size, input_size)) - ixh = _op.concatenate([in_data, in_state_h], axis=1) - in_weight = _op.transpose(in_weight, axes=None) - gates = _op.nn.dense(ixh, in_weight, - units=num_hidden_layers) - gates_bias = _op.add(gates, in_bias) - gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1) - in_gate = _op.sigmoid(gate_list[0]) - in_transform = _op.tanh(gate_list[1]) - forget_gate = _op.add(gate_list[2], tvm.relay.const(forget_bias, attr['T'].name)) - forget_gate = _op.sigmoid(forget_gate) - out_gate = _op.sigmoid(gate_list[3]) - next_c = _op.add(_op.multiply(forget_gate, in_state_c), - _op.multiply(in_gate, in_transform)) - next_h = out_gate * _op.tanh(next_c) - out_state = _op.concatenate([next_c, next_h], axis=1) - out_state = _op.reshape(out_state, - newshape=(2, batch_size, num_hidden)) - return next_h, out_state - return _impl - -# _convert_map_rnn defines maps of rnn operator name to -# converter functor(callable) for 1 to 1 mapping. -_convert_map_rnn = { - 'LSTMBlockCell' : _LSTMBlockCell(), } -class RecurrentNetworks(object): - """Recurrent network layer handlers. - - Handle Layer operations. - ToDo: Operators like RNN/GRU layer concepts also can be handled here - - Parameters - ---------- - nodes : list - list of graph nodes used for tensorflow parsing. - - out_rnn : list - List of RecurrentNetwork outputs. This output will be appended to the - 'head' nodes of the graph. - - graph : tensorflow graph definition object - The loaded tensorflow GraphDef - - convert_map : dict - Dict of name : callable, where name is the op's name that - require conversion to relay, callable are functions which - take attrs and return (new_op_name, new_attrs) - """ - def __init__(self, nodes, out_rnn, graph, convert_map): - self._graph = graph - self._convert_map = convert_map - self._nodes = nodes - self._out_rnn = out_rnn - self._cur_lstm_layer = 0 - self._layer_name_list = [] - self._recurrent_ops_layer_map = { - 'LSTMBlockCell' : self._LSTMBlockCellLayer(), - } - - def _LSTMBlockCellLayer(self): - """LSTMBlockCell layer handler. - - Parameters - ---------- - op_name : str - Operator name, eg:LSTMBlockCell - - layer_name : str list - Layer name is used for creating the state input placeholder. 
-
-        inputs : relay.Expr
-            Input data
-
-        attrs : dict
-            Dict of operator attributes
-
-        params : dict
-            List of pretrained weights and bias
-
-        num_layers : int
-            Total number of LSTM layer presented in the graph
-
-        Returns
-        -------
-        sym : relay.Expr
-            The returned relay Expr
-        """
-        def _impl(op_name, layer_name, inputs, attrs, params, num_layers, mod):
-            in_state_c_name = layer_name+'_c'
-            in_state_h_name = layer_name+'_h'
-
-            def _init_state(num_layers, batch_size, num_hidden):
-                """Create the initial states for the first layer in the graph."""
-                in_state_c = [_expr.var(in_state_c_name,
-                                        shape=(num_layers, batch_size, num_hidden),
-                                        dtype='float32')]
-
-                in_state_h = [_expr.var(in_state_h_name,
-                                        shape=(num_layers, batch_size, num_hidden),
-                                        dtype='float32')]
-                return in_state_c, in_state_h
-
-            def _get_cur_input_state(in_state_c, in_state_h, num_layers,
-                                     layer, batch_size, num_hidden):
-                """Select the appropriate states for the current layer"""
-                in_state_c_tup = _op.split(in_state_c[0],
-                                           indices_or_sections=num_layers, axis=0)
-                in_state_h_tup = _op.split(in_state_h[0],
-                                           indices_or_sections=num_layers, axis=0)
-                cur_in_state_c = _op.reshape(in_state_c_tup[layer],
-                                             newshape=(batch_size, num_hidden))
-                cur_in_state_h = _op.reshape(in_state_h_tup[layer],
-                                             newshape=(batch_size, num_hidden))
-                return cur_in_state_c, cur_in_state_h
-
-            def _LSTMBlockCellWrapper(inputs, attr, params,
-                                      num_layers, layer):
-                """LSTM cell warapper to prepare the inputs"""
-                input_shape = _infer_shape(inputs[0], mod)
-                weight_shape = _infer_shape(inputs[3], mod)
-
-                batch_size = input_shape[0]
-                num_hidden = weight_shape[1] // 4
-
-                if layer == 0:
-                    #Create initial states placeholder in case of first layer
-                    in_state_c, in_state_h = _init_state(num_layers,
-                                                         batch_size, num_hidden)
-                else:
-                    in_state_c = self._nodes[in_state_c_name]
-                    in_state_h = self._nodes[in_state_h_name]
-
-                cur_in_state_c, cur_in_state_h = _get_cur_input_state(
-                    in_state_c, in_state_h,
-                    num_layers, layer,
-                    batch_size, num_hidden)
-                output, out_state = self._convert_map[op_name](inputs, cur_in_state_c,
-                                                               cur_in_state_h,
-                                                               attr, params, mod)
-                return output, out_state, in_state_c, in_state_h
-
-            sym, cur_out_state, in_state_c, in_state_h = \
-                _LSTMBlockCellWrapper(inputs, attrs, params,
-                                      num_layers, self._cur_lstm_layer)
-            self._nodes[in_state_c_name] = in_state_c
-            self._nodes[in_state_h_name] = in_state_h
-            cur_out_state = _op.expand_dims(cur_out_state, axis=0, num_newaxis=1)
-            self._out_rnn.append(cur_out_state)
-            self._cur_lstm_layer += 1
-            return sym
-        return _impl
-
-    def process_op(self, op_name, inputs, attrs, params, mod):
-        """Process recurrent layer operators.
-
-        List '_recurrent_ops_layer_map' map each Layer based operators with its
-        layer handlers. Total number of layers are calculated to form the input
-        data shapes.
-
-        Parameters
-        ----------
-        op_name : str
-            Operator name, such as LSTMBlockCell
-
-        inputs : relay.Expr
-            Input data
-
-        attrs : dict
-            Dict of operator attributes
-
-        params : dict
-            List of pretrained weights and bias
-
-        Returns
-        -------
-        sym : relay.Expr
-            Returns relay.Expr
-        """
-        def _get_abs_layer_name(node):
-            """Identify the layer name is already handled. Return the absolute name
-            """
-            if not self._layer_name_list:
-                self._layer_name_list.append(node.name)
-                return node.name
-
-            for _name in self._layer_name_list:
-                if _name in node.name:
-                    abs_name = _name
-                else:
-                    self._layer_name_list.append(node.name)
-                    abs_name = node.name
-            return abs_name
-
-        #Find number of layers of this same operator node in the graph
-        #and also read the inputs name for the current op.
-        num_layers = 0
-        for _, node in enumerate(self._graph.node):
-            if node.op == op_name:
-                layer_name = _get_abs_layer_name(node)
-                num_layers += 1
-
-        sym = self._recurrent_ops_layer_map[op_name](op_name, layer_name, inputs, attrs,
-                                                     params, num_layers, mod)
-        return sym
-
 # An internal list to contain all the control flow primitives used in Tensorflow
 # 1.x.
 _control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']
@@ -2889,7 +2704,9 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None):
                     elif cnode.name != wnode.name:
                         if is_tensor_array_constuctor(cnode):
                             inode = self._tf_node_map[wnode.input[inode_idx].split(":")[0]]
-                            self._tensor_array_shape_nodes[cnode.name] = (inode, wnode.op)
+                            tn = wnode.input[inode_idx].split(":")
+                            output_index = int(tn[1]) if len(tn) > 1 else 0
+                            self._tensor_array_shape_nodes[cnode.name] = (inode, wnode.op, output_index)
                         break
 
         # First, parse all control flow nodes.
@@ -2942,15 +2759,10 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None):
             else:
                 out.append(self._nodes[out_name][0])
 
-        #Add the RNN outputs also with 'head' nodes of the relay graph
-        if self._num_rnn_layer:
-            if len(self._out_rnn) == 1:
-                out.append(self._out_rnn[0])
-            else:
-                out_rnn = _op.concatenate(self._out_rnn, axis=0)
-                out.append(out_rnn)
-
-        out = out[0] if len(out) == 1 else _expr.Tuple(out)
+        if isinstance(out, _expr.TupleWrapper):
+            out = out.tuple_value
+        else:
+            out = out[0] if len(out) == 1 else _expr.Tuple(out)
         fvars = analysis.free_vars(out)
         func = _function.Function(fvars, out)
         final_params = {}
@@ -2988,7 +2800,6 @@ def _parse_import_prerequisites(self, graph):
                 pass
             else:
                 if any([node.op in t for t in [_identity_list, _convert_map,
-                                               _convert_map_rnn,
                                                _control_flow_nodes]]):
                     pass
                 elif op_def is not None and op_def.is_stateful:
@@ -3082,42 +2893,6 @@ def _parse_attr(self, attr_proto):
 
         return attrs
 
-    def _convert_rnn_operator(self, op_name, inputs,
-                              attrs, params, graph, convert_map):
-        """Convert RNN and its variant operators to Relay operators.
-        This converter read the input states of each layers and
-        also maintain the output states of each layer in a list.
-
-        Parameters
-        ----------
-        op_name : str
-            Operator name, such as LSTMBlockCell
-        inputs : list of relay.Expr
-            List of input symbols.
-        attrs : dict
-            Dict of operator attributes
-        params : dict
-            List of pretrained weights and bias
-        graph : Tensorflow graph object
-            Graph is to find the number of upcoming same operator to
-            calculate the number of layers.
-        convert_map : dict
-            Dict of name : callable, where name is the op's name that
-            require conversion to relay, callable are functions which
-            take attrs and return (new_op_name, new_attrs)
-
-        Returns
-        -------
-        sym : relay.Expr
-            Converted relay.Expr
-        """
-        if not self._num_rnn_layer:
-            self._out_rnn = []
-            self.rnn = RecurrentNetworks(self._nodes, self._out_rnn, graph, convert_map)
-            self._num_rnn_layer = True
-        sym = self.rnn.process_op(op_name, inputs, attrs, params, self._mod)
-        return sym
-
     def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_map):
         """
         Convert the Relay control flow primitive into corresponding component
@@ -3355,7 +3130,6 @@ def _convert_operator(self, op_name, inputs, attrs,
         """
         identity_list = identity_list if identity_list else _identity_list
         convert_map = convert_map if convert_map else _convert_map
-        convert_map_rnn = _convert_map_rnn
         if op_name in identity_list:
             sym = get_relay_op(op_name)(*inputs, **attrs)
         elif op_name in convert_map:
@@ -3363,12 +3137,6 @@ def _convert_operator(self, op_name, inputs, attrs,
                 sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
             else:
                 sym = convert_map[op_name](inputs, attrs, self._params, self._mod)
-
-        elif op_name in convert_map_rnn:
-            sym = self._convert_rnn_operator(op_name, inputs, attrs,
-                                             self._params, graph,
-                                             convert_map_rnn)
-
         elif op_name in ["PartitionedCall", "StatefulPartitionedCall"]:
             sym = self._partition_call_operator(inputs, attrs)
         else:
@@ -3482,8 +3250,12 @@ def _backtrack_construct(self, node_name):
                 if elem_shape:
                     attr["shape"] = elem_shape
                 if attr['identical_element_shapes'] or elem_shape:
-                    shape_node, wnode_op = self._tensor_array_shape_nodes[node.name]
-                    converted = self._backtrack_construct(shape_node.name)
+                    shape_node, wnode_op, output_index = \
+                        self._tensor_array_shape_nodes[node.name]
+                    name = shape_node.name
+                    if output_index > 0:
+                        name += ":" + str(output_index)
+                    converted = self._backtrack_construct(name)
                     shape = _infer_shape(converted, self._mod)
                     if wnode_op.startswith("TensorArraySplit"):
                         shape = (Any(),) + shape[1:]
diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py
index 988faa85dd081..9715bd7a44e13 100644
--- a/python/tvm/relay/testing/tf.py
+++ b/python/tvm/relay/testing/tf.py
@@ -197,7 +197,7 @@ def get_workload_official(model_url, model_sub_path):
         raise RuntimeError('Could not decompress the file: ' + model_path)
     return os.path.join(dir_path, model_sub_path)
 
-def get_workload(model_path, model_sub_path=None):
+def get_workload(model_path, model_sub_path=None, inputs_dict=None, output=None):
     """ Import workload from frozen protobuf
 
     Parameters
     ----------
@@ -226,8 +226,14 @@
     with tf_compat_v1.gfile.FastGFile(path_model, 'rb') as f:
         graph_def = tf_compat_v1.GraphDef()
         graph_def.ParseFromString(f.read())
-        graph = tf_compat_v1.import_graph_def(graph_def, name='')
-        return graph_def
+        graph = tf_compat_v1.import_graph_def(graph_def, name='', input_map=inputs_dict)
+
+        if inputs_dict is not None:
+            # graph is changed so generate graph_def again
+            with tf_compat_v1.Session(graph=graph) as sess:
+                graph_def = AddShapesToGraphDef(sess, output)
+
+        return graph_def
 
 #######################################################################
 # PTB LSTMBlockCell Model
@@ -266,7 +272,7 @@ def do_tf_sample(session, data, in_states, num_samples):
                         'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1:0',
                         'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros:0',
                        'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1:0']
-    state = session.run(state_input_name)
+    state = in_states
 
     #Graph nodes to be fetched as run output. Tensorflow LSTMBlockCell create internal
     #nodes for intermediate operations (gates) in the cell during run.
@@ -364,4 +370,27 @@ def get_workload_ptb():
         t.extractall(dir_path)
 
     word_to_id, id_to_word = _create_ptb_vocabulary(dir_path)
-    return word_to_id, id_to_word, get_workload(ptb_model_file)
+    dtype = 'float32'
+    shape = (1, 200)
+
+    # Convert states of LSTMBlockCell to placeholder, so TVM can feed data
+    state_name = [
+        'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros:0',
+        'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1:0',
+        'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros:0',
+        'Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1:0',
+    ]
+
+    inputs_dict = {
+        state_name[0]:
+        tf_compat_v1.placeholder(dtype, shape, state_name[0].split(':')[0]),
+        state_name[1]:
+        tf_compat_v1.placeholder(dtype, shape, state_name[1].split(':')[0]),
+        state_name[2]:
+        tf_compat_v1.placeholder(dtype, shape, state_name[2].split(':')[0]),
+        state_name[3]:
+        tf_compat_v1.placeholder(dtype, shape, state_name[3].split(':')[0]),
+    }
+    return word_to_id, id_to_word, get_workload(ptb_model_file,
+                                                inputs_dict=inputs_dict,
+                                                output='Model/Softmax')
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index a7371a885daa9..c44ae72c4b35a 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -2066,46 +2066,41 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
     input_size = num_hidden
     input_data = np.full((batch_size, input_size), 1., dtype=dtype)
     in_state_c = np.full(
-        (num_layers, batch_size, num_hidden), 0.1, dtype=dtype)
+        (batch_size, num_hidden), 0.1, dtype=dtype)
     in_state_h = np.full(
-        (num_layers, batch_size, num_hidden), 0.1, dtype=dtype)
+        (batch_size, num_hidden), 0.1, dtype=dtype)
 
     def _get_tensorflow_output():
         with tf.Session() as sess:
             with variable_scope.variable_scope(
                     "root", initializer=init_ops.constant_initializer(0.5)):
-                m0 = array_ops.zeros([batch_size, num_hidden])
-                m1 = array_ops.zeros([batch_size, num_hidden])
-                x = tf.placeholder(shape=(batch_size, input_size), dtype=dtype)
+                m0 = tf.placeholder(dtype, [batch_size, num_hidden], name="m0")
+                m1 = tf.placeholder(dtype, [batch_size, num_hidden], name="m1")
+                x = tf.placeholder(shape=(batch_size, input_size), dtype=dtype, name="input")
                 g, ((out_m0, out_m1)) = \
                     tensorflow.contrib.rnn.LSTMBlockCell(num_hidden, forget_bias=forget_bias)(x, (m0, m1))
                 sess.run([variables.global_variables_initializer()])
                 res = sess.run([g, out_m0, out_m1], {
                     x.name: np.array([[1., 1.]]),
-                    m0.name: 0.1 * np.ones([batch_size, num_hidden]),
-                    m1.name: 0.1 * np.ones([batch_size, num_hidden]),
+                    m0.name: in_state_c,
+                    m1.name: in_state_h,
                 })
                 graph_def = sess.graph.as_graph_def(add_shapes=True)
                 final_graph_def = graph_util.convert_variables_to_constants(
                     sess,
                     graph_def,
                     ['root/lstm_cell/LSTMBlockCell'])
+
                 return final_graph_def, res
 
     graph_def, tf_out = _get_tensorflow_output()
     tvm_output = run_tvm_graph(graph_def, [input_data, in_state_c, in_state_h],
-                               ['root/Placeholder', 'root/lstm_cell/LSTMBlockCell_c',
-                                'root/lstm_cell/LSTMBlockCell_h'], num_output=2)
+                               ['root/input', "root/m0", "root/m1"], num_output=7)
     assert isinstance(tvm_output, list)
-    out = tvm_output[0]
-    out_state = tvm_output[1]
-    out_state_tup = np.split(out_state, indices_or_sections=2, axis=1)
-    out_state_c = np.reshape(out_state_tup[0], (batch_size, num_hidden))
-    out_state_h = np.reshape(out_state_tup[1], (batch_size, num_hidden))
-    tvm_out = [out, out_state_c, out_state_h]
-    tvm.testing.assert_allclose(tf_out[0], tvm_out[0], rtol=1e-3, atol=1e-3)
+    tvm.testing.assert_allclose(tf_out[0], tvm_output[6], rtol=1e-3, atol=1e-3)
+    tvm.testing.assert_allclose(tf_out[1], tvm_output[1], rtol=1e-3, atol=1e-3)
 
 
 def test_forward_lstm():
@@ -2493,7 +2488,7 @@ def test_forward_ptb():
     batch_size = config.batch_size
     vocab_size = config.vocab_size
     out_sample_shape = (batch_size, vocab_size)
-    out_state_shape = (num_layers, 2, batch_size, num_hidden)
+    out_state_shape = (batch_size, num_hidden)
     # Sample input
     inpt = "we have no useful information on"
     cnt_sample = 20
@@ -2506,18 +2501,17 @@ def _pretty_print(items, is_char_model, id2word):
 
     def _get_tvm_graph_module(graph_def):
         # Cell inputs 'c and 'h' consist of all layers values
-        shape_dict = {'Model/Placeholder': (batch_size, num_steps),
-                      'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':
-                      (num_layers, batch_size, num_hidden),
-                      'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':
-                      (num_layers, batch_size, num_hidden)}
+        shape_dict = {'Model/Placeholder': (batch_size, num_steps)}
 
         mod, params = relay.frontend.from_tensorflow(
-            graph_def, shape=shape_dict)
+            graph_def, shape=shape_dict,
+            outputs=['Model/Softmax:0',
+                     'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:1',
+                     'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:6',
+                     'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:1',
+                     'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:6',
+                     ])
 
-        dtype_dict = {'Model/Placeholder': 'int32',
-                      'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c': 'float32',
-                      'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h': 'float32'}
         target = 'llvm'
         with tvm.transform.PassContext(opt_level=0):
             graph, lib, params = relay.build(mod,
@@ -2535,24 +2529,26 @@ def _do_tvm_sample(model, data, in_states, params, num_samples):
         def _get_sample(data, state):
             input_data = np.full((batch_size, num_steps), data, dtype="int32")
-            in_state_tup = np.split(state, indices_or_sections=2, axis=1)
-            in_state_c = np.reshape(
-                in_state_tup[0], (num_layers, batch_size, num_hidden))
-            in_state_h = np.reshape(
-                in_state_tup[1], (num_layers, batch_size, num_hidden))
 
             model.set_input('Model/Placeholder',
                             tvm.nd.array(input_data.astype("int32")))
-            model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c',
-                            tvm.nd.array(in_state_c.astype("float32")))
-            model.set_input('Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h',
-                            tvm.nd.array(in_state_h.astype("float32")))
+            model.set_input('Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros',
+                            tvm.nd.array(state[0].astype("float32")))
+            model.set_input('Model/MultiRNNCellZeroState/LSTMBlockCellZeroState/zeros_1',
+                            tvm.nd.array(state[1].astype("float32")))
+            model.set_input('Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros',
+                            tvm.nd.array(state[2].astype("float32")))
+            model.set_input('Model/MultiRNNCellZeroState/LSTMBlockCellZeroState_1/zeros_1',
+                            tvm.nd.array(state[3].astype("float32")))
             model.set_input(**params)
             model.run()
             tvm_output = model.get_output(0, tvm.nd.empty(out_sample_shape,
                                                           "float32")).asnumpy()
-            state_output = model.get_output(1, tvm.nd.empty(out_state_shape,
-                                                            "float32")).asnumpy()
+
+            state_output = []
+            for i in range(4):
+                state_output.append(model.get_output(i+1, tvm.nd.empty(out_state_shape,
+                                                                       "float32")).asnumpy())
             sample = tf_testing.pick_from_weight(tvm_output[0])
 
             return sample, state_output
@@ -2586,8 +2582,7 @@ def _get_sample(data, state):
     cnt_stm = 0
     while cnt_stm < 10:
         cnt_stm += 1
-        in_state = np.full(
-            (num_layers, 2, batch_size, num_hidden), 0, dtype="float32")
+        in_state = [np.full((batch_size, num_hidden), 0, dtype="float32")] * 2 * num_layers
         seed_for_sample = inpt.split()
         tvm_samples, tvm_state = _do_tvm_sample(m, [word_to_id[word] for word in seed_for_sample],
@@ -3764,6 +3759,94 @@ def test_forward_dynamic_input_shape():
         tvm.testing.assert_allclose(tvm_output[0], tf_output[0], rtol=1e-5, atol=1e-5)
 
 
+def test_forward_dynmaic_rnn_lstmblockcell():
+    if package_version.parse(tf.VERSION) >= package_version.parse('2.0.0'):
+        return
+
+    total_series_length = 50000
+    truncated_backprop_length = 15
+    state_size = 4
+    echo_step = 3
+    batch_size = 5
+    num_layers = 5
+
+    def generateData():
+        x = np.array(np.random.choice(2, total_series_length, p=[0.5, 0.5]))
+        y = np.roll(x, echo_step)
+        y[0:echo_step] = 0
+
+        x = x.reshape((batch_size, -1))  # The first index changing slowest, subseries as rows
+        y = y.reshape((batch_size, -1))
+
+        return (x, y)
+
+    batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
+
+    init_state = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
+
+    state_per_layer_list = tf.unstack(init_state, axis=0)
+    rnn_tuple_state = tuple(
+        [tf.nn.rnn_cell.LSTMStateTuple(state_per_layer_list[idx][0], state_per_layer_list[idx][1])
+         for idx in range(num_layers)]
+    )
+
+    # Forward passes
+    def lstm_cell():
+        return tensorflow.contrib.rnn.LSTMBlockCell(state_size)
+    cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell() for _ in range(num_layers)], state_is_tuple=True)
+    states_series, current_state = tf.nn.dynamic_rnn(cell, tf.expand_dims(batchX_placeholder, -1),
+                                                     initial_state=rnn_tuple_state)
+
+    with tf.Session() as sess:
+        sess.run(tf.global_variables_initializer())
+        x, y = generateData()
+        _current_state = np.zeros((num_layers, 2, batch_size, state_size))
+
+        start_idx = 0
+        end_idx = start_idx + truncated_backprop_length
+
+        batchX = x[:, start_idx:end_idx]
+
+        # Save current state for TVM
+        current_state_tvm = _current_state
+
+        _current_state, _states_series = sess.run(
+            [current_state, states_series],
+            feed_dict={
+                batchX_placeholder: batchX,
+                init_state: _current_state
+            })
+
+        # Organize results and corresponding names
+        tf_output = [_states_series]
+
+        for c in _current_state:
+            tf_output.append(c.c)
+            tf_output.append(c.h)
+
+        name = [states_series.name.split(':')[0]]
+
+        for t in current_state:
+            name.append(t.c.name.split(':')[0])
+            name.append(t.h.name.split(':')[0])
+
+        graph_def = sess.graph.as_graph_def(add_shapes=True)
+
+        final_graph_def = graph_util.convert_variables_to_constants(
+            sess,
+            graph_def,
+            name)
+
+        tvm_output = run_tvm_graph(final_graph_def,
+                                   [batchX.astype('float32'), current_state_tvm.astype('float32')],
+                                   ["Placeholder", "Placeholder_1"], out_names=name,
+                                   num_output=len(name), mode='vm', disabled_pass=["FoldScaleAxis"])
+
+        # Compare result
+        for i in range(len(tf_output)):
+            tvm.testing.assert_allclose(
+                tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
+
 
 #######################################################################
 # Main
@@ -3903,3 +3986,5 @@ def test_forward_dynamic_input_shape():
 
     # Test dynamic input shape
     test_forward_dynamic_input_shape()
+
+    test_forward_dynmaic_rnn_lstmblockcell()
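
For readers not familiar with the fused TF op, here is a minimal NumPy sketch of the gate math that the new _LSTMBlockCell converter emits: concatenate x with the previous hidden state, apply one dense layer against the transposed kernel, split the result into the i/ci/f/o gates, add forget_bias to the forget gate, and return the same seven tensors [i, cs, f, o, ci, co, h] that the converter wraps in a TupleWrapper. It is illustrative only and not part of the patch; the helper names and the smoke check are assumptions of this sketch, not code from the TVM or TensorFlow sources.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_block_cell_reference(x, cs_prev, h_prev, kernel, bias, forget_bias=0.0):
    # kernel: (input_size + num_hidden, 4 * num_hidden), the LSTMBlockCell weight
    # layout assumed by the converter (which transposes it for relay.nn.dense).
    xh = np.concatenate([x, h_prev], axis=1)        # ixh in the converter
    gates = xh.dot(kernel) + bias                   # dense + bias add
    i, ci, f, o = np.split(gates, 4, axis=1)        # gate_list[0..3]
    i = sigmoid(i)                                  # input gate
    ci = np.tanh(ci)                                # cell input (in_transform)
    f = sigmoid(f + forget_bias)                    # forget gate with forget_bias
    o = sigmoid(o)                                  # output gate
    cs = f * cs_prev + i * ci                       # next cell state (next_c)
    co = np.tanh(cs)
    h = o * co                                      # next hidden state (next_h)
    # Same ordering as the TupleWrapper returned by the converter.
    return [i, cs, f, o, ci, co, h]

# Shape-only smoke check with arbitrary values.
if __name__ == "__main__":
    batch, input_size, num_hidden = 1, 2, 2
    rng = np.random.RandomState(0)
    x = rng.rand(batch, input_size).astype("float32")
    cs_prev = np.full((batch, num_hidden), 0.1, "float32")
    h_prev = np.full((batch, num_hidden), 0.1, "float32")
    kernel = rng.rand(input_size + num_hidden, 4 * num_hidden).astype("float32")
    bias = np.zeros(4 * num_hidden, "float32")
    outs = lstm_block_cell_reference(x, cs_prev, h_prev, kernel, bias, forget_bias=1.0)
    assert all(o.shape == (batch, num_hidden) for o in outs)

The outputs at indices 1 (cs) and 6 (h) are the ones the updated tests feed back as the recurrent state, which is why test_forward_ptb requests the ':1' and ':6' outputs of each LSTMBlockCell node.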