Commit
* all good except LSTM
srkreddy1238 committed Dec 6, 2018
1 parent f11cfa9 commit 90b99db
Showing 2 changed files with 312 additions and 28 deletions.
322 changes: 303 additions & 19 deletions python/tvm/relay/frontend/tensorflow.py
@@ -655,9 +655,8 @@ def _impl(inputs, attr, params):

def _infer_out_shapes(inputs, params):
"""A method to get the output shape of an intermediate node in the NNVM graph."""
g = _graph.create(inputs)
shape_dict = {k: v.shape for k, v in params.items()}
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
out_type = ir_pass.infer_type(inputs)
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
return out_shapes
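# An illustrative sketch of the relay shape inference that the reworked
# _infer_out_shapes above relies on (assumes the ir_pass.infer_type /
# checked_type API used in this commit; variable names are made up):
from tvm import relay

x = relay.var("x", shape=(1, 8), dtype="float32")
w = relay.var("w", shape=(4, 8), dtype="float32")
y = relay.ir_pass.infer_type(relay.nn.dense(x, w, units=4))   # runs type/shape inference
print(y.checked_type.shape)                                   # [1, 4]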

def _stridedSlice():
@@ -923,6 +922,250 @@ def _impl(inputs, attr, params):
'NotEqual' : _broadcast('not_equal'),
}

def _LSTMBlockCell():
def _impl(inputs, in_state_c, in_state_h, attr, params):
"""LSTM Block cell.
Calculations are described in: https://github.com/tensorflow/tensorflow/blob/
r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114
Parameters
----------
inputs : nnvm.Symbol
Input data
in_state_c: list of nnvm.Symbol
Cell state input values for all the layers
in_state_h: list of nnvm.Symbol
Hidden state input values for all the layers
attrs : dict
Dict of operator attributes
params : dict
Dict of pretrained weights and biases
Returns
-------
sym : nnvm.Symbol
Converted nnvm Symbol
output: nnvm.Symbol
Output state value.
"""
in_data = inputs[0]
in_weight = inputs[3]
in_bias = inputs[7]
forget_bias = attr.pop('forget_bias')
input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]]
batch_size, input_size = input_shape[0][0], input_shape[0][1]
num_hidden_layers = weight_shape[0][1]
num_hidden = num_hidden_layers // 4

in_data = _op.reshape(in_data,
newshape=(batch_size, input_size))
ixh = _op.concatenate([in_data, in_state_h], axis=1)
in_weight = _op.transpose(in_weight, axes=None)
gates = _op.nn.dense(ixh, in_weight,
units=num_hidden_layers)
gates_bias = _op.add(gates, in_bias)
gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1)
in_gate = _op.sigmoid(gate_list[0])
in_transform = _op.tanh(gate_list[1])
forget_gate = _op.sigmoid(gate_list[2])
forget_gate = _op.add(forget_gate, tvm.relay.const(forget_bias))
out_gate = _op.sigmoid(gate_list[3])
next_c = _op.add(_op.multiply(forget_gate, in_state_c),
_op.multiply(in_gate, in_transform))
next_h = out_gate * _op.tanh(next_c)
out_state = _op.concatenate([next_c, next_h], axis=1)
out_state = _op.reshape(out_state,
newshape=(2, batch_size, num_hidden))
return next_h, out_state
return _impl
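# A NumPy reference (illustrative sketch mirroring the gate arithmetic of the
# converter above; sizes and names below are made up) that can be used to
# sanity check a converted LSTMBlockCell:
import numpy as np

def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_block_cell_ref(x, h_prev, c_prev, weight, bias, forget_bias=1.0):
    # gates laid out as [input, cell-input, forget, output], exactly as split above
    ixh = np.concatenate([x, h_prev], axis=1)      # (batch, input_size + num_hidden)
    gates = ixh.dot(weight) + bias                 # (batch, 4 * num_hidden)
    i, c_in, f, o = np.split(gates, 4, axis=1)
    i, c_in, o = _sigmoid(i), np.tanh(c_in), _sigmoid(o)
    f = _sigmoid(f) + forget_bias                  # forget bias added after the sigmoid, as above
    c_next = f * c_prev + i * c_in
    h_next = o * np.tanh(c_next)
    return h_next, c_next

batch, input_size, num_hidden = 1, 3, 4
x = np.zeros((batch, input_size), "float32")
h = np.zeros((batch, num_hidden), "float32")
c = np.zeros((batch, num_hidden), "float32")
w = np.zeros((input_size + num_hidden, 4 * num_hidden), "float32")
b = np.zeros((4 * num_hidden,), "float32")
print(lstm_block_cell_ref(x, h, c, w, b)[0].shape)   # (1, 4)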

# _convert_map_rnn defines maps of rnn operator name to
# converter functor(callable) for 1 to 1 mapping.
_convert_map_rnn = {
'LSTMBlockCell' : _LSTMBlockCell(),
}

class RecurrentNetworks(object):
"""Recurrent network layer handlers.
Handle Layer operations.
ToDo: Operators like RNN/GRU layer concepts can also be handled here
Parameters
----------
nodes : list
list of graph nodes used for tensorflow parsing.
out_rnn : list
List of RecurrentNetwork outputs. This output will be appended to the
'head' nodes of the graph.
graph : tensorflow graph definition object
The loaded tensorflow GraphDef
convert_map : dict
Dict of name : callable, where name is the op's name that
requires conversion to nnvm; callables are functions which
take attrs and return (new_op_name, new_attrs)
"""
def __init__(self, nodes, out_rnn, graph, convert_map):
self._graph = graph
self._convert_map = convert_map
self._nodes = nodes
self._out_rnn = out_rnn
self._cur_lstm_layer = 0
self._layer_name_list = []
self._recurrent_ops_layer_map = {
'LSTMBlockCell' : self._LSTMBlockCellLayer(),
}

def _LSTMBlockCellLayer(self):
"""LSTMBlockCell layer handler.
Parameters
----------
op_name : str
Operator name, e.g. LSTMBlockCell
layer_name : str list
Layer name is used for creating the state input placeholder.
inputs : nnvm.Symbol
Input data
attrs : dict
Dict of operator attributes
params : dict
Dict of pretrained weights and biases
num_layers : int
Total number of LSTM layers present in the graph
Returns
-------
sym : nnvm.sym.Symbol
The returned nnvm symbol
"""
def _impl(op_name, layer_name, inputs, attrs, params, num_layers):
in_state_c_name = layer_name+'_c'
in_state_h_name = layer_name+'_h'

def _init_state(num_layers, batch_size, num_hidden):
"""Create the initial states for the first layer in the graph."""
in_state_c = [_expr.var(in_state_c_name,
shape=(num_layers, batch_size, num_hidden),
dtype='float32')]

in_state_h = [_expr.var(in_state_h_name,
shape=(num_layers, batch_size, num_hidden),
dtype='float32')]
return in_state_c, in_state_h

def _get_cur_input_state(in_state_c, in_state_h, num_layers,
layer, batch_size, num_hidden):
"""Select the appropriate states for the current layer"""
in_state_c_tup = _op.split(in_state_c[0],
indices_or_sections=num_layers, axis=0)
in_state_h_tup = _op.split(in_state_h[0],
indices_or_sections=num_layers, axis=0)
cur_in_state_c = _op.reshape(in_state_c_tup[layer],
newshape=(batch_size, num_hidden))
cur_in_state_h = _op.reshape(in_state_h_tup[layer],
newshape=(batch_size, num_hidden))
return cur_in_state_c, cur_in_state_h

def _LSTMBlockCellWrapper(inputs, attr, params,
num_layers, layer):
"""LSTM cell warapper to prepare the inputs"""
input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]]

batch_size = input_shape[0][0]
num_hidden = weight_shape[0][1] // 4

if layer == 0:
#Create initial state placeholders for the first layer
in_state_c, in_state_h = _init_state(num_layers,
batch_size, num_hidden)
else:
in_state_c = self._nodes[in_state_c_name]
in_state_h = self._nodes[in_state_h_name]

cur_in_state_c, cur_in_state_h = _get_cur_input_state( \
in_state_c, in_state_h,
num_layers, layer,
batch_size, num_hidden)
output, out_state = self._convert_map[op_name](inputs, cur_in_state_c,
cur_in_state_h,
attr, params)
return output, out_state, in_state_c, in_state_h

sym, cur_out_state, in_state_c, in_state_h = \
_LSTMBlockCellWrapper(inputs, attrs, params,
num_layers, self._cur_lstm_layer)
self._nodes[in_state_c_name] = in_state_c
self._nodes[in_state_h_name] = in_state_h
cur_out_state = _op.expand_dims(cur_out_state, axis=0, num_newaxis=1)
self._out_rnn.append(cur_out_state)
self._cur_lstm_layer += 1
return sym
return _impl
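# Shape walk-through for the handler above (illustrative; L = num_layers,
# B = batch_size, H = num_hidden):
#   in_state_c / in_state_h placeholders  : (L, B, H), created once at layer 0
#   per-layer slice after split + reshape : (B, H)
#   out_state returned by the cell        : (2, B, H)  [next_c; next_h]
#   after expand_dims(axis=0)             : (1, 2, B, H)
#   concatenation of all layers (axis=0)  : (L, 2, B, H), appended to the graph outputs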

def process_op(self, op_name, inputs, attrs, params):
"""Process recurrent layer operators.
The '_recurrent_ops_layer_map' dict maps each layer-based operator to its
layer handler. The total number of layers is calculated to form the input
data shapes.
Parameters
----------
op_name : str
Operator name, such as LSTMBlockCell
inputs : nnvm.Symbol
Input data
attrs : dict
Dict of operator attributes
params : dict
Dict of pretrained weights and biases
Returns
-------
sym : nnvm.sym.Symbol
The returned nnvm symbol
"""
def _get_abs_layer_name(node):
"""Identify the layer name is already handled. Return the absolute name
"""
if not self._layer_name_list:
self._layer_name_list.append(node.name)
return node.name

for _name in self._layer_name_list:
if _name in node.name:
abs_name = _name
else:
self._layer_name_list.append(node.name)
abs_name = node.name
return abs_name

#Find the number of layers of this same operator in the graph
#and also resolve the layer name for the current op.
num_layers = 0
for _, node in enumerate(self._graph.node):
if node.op == op_name:
layer_name = _get_abs_layer_name(node)
num_layers += 1

sym = self._recurrent_ops_layer_map[op_name](op_name, layer_name, inputs, attrs,
params, num_layers)
return sym
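# Illustrative walk-through of the counting above, assuming a graph with two
# LSTMBlockCell nodes (node names are made up):
#   rnn/lstm_cell/LSTMBlockCell, rnn/lstm_cell/LSTMBlockCell_1
# num_layers counts every node whose op matches op_name (2 here), and
# layer_name is the shared prefix resolved by _get_abs_layer_name; the layer
# handler uses it to name the state placeholders layer_name + '_c' / '_h'.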

class GraphProto(object):
""" A helper class for handling relay graph copying from Tensorflow GraphDef.
Definition:
@@ -933,6 +1176,7 @@ def __init__(self):
self._params = {}
self._output_shapes = {}
self._num_param = 0
self._num_rnn_layer = False

def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct relay nodes from tensorflow graph definition - GraphDef.
@@ -997,8 +1241,8 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
tensor_value.tensor_shape)]
elif '_output_shapes' in attr:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \
for shape in attr['_output_shapes']]
[tensor_util.TensorShapeProtoToList(tshape) \
for tshape in attr['_output_shapes']]
elif shape:
# Keep the list indexable to avoid key error.
# Actual value will be filled after node creation.
@@ -1051,7 +1295,6 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
input_shapes[self._nodes[i][0]] = self._output_shapes[i]
attr['_input_shapes'] = input_shapes

#inputs = self._fix_extranodes(node.op, attr, inputs)
op = self._convert_operator(node.op, inputs, attr, graph)

# Check if op is converted to param
@@ -1075,7 +1318,18 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
out_type = ir_pass.infer_type(node_output[0])
self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)]

out = op

out = []
if outputs is None:
out = op
else:
out = [self._nodes[out_name][0] for out_name in outputs]

#Add the RNN outputs also to the 'head' nodes of the nnvm graph
if self._num_rnn_layer:
out_rnn = _op.concatenate(self._out_rnn, axis=0)
out.append(out_rnn)

out = out[0] if len(out) == 1 else _expr.Tuple(out)
func = _expr.Function(ir_pass.free_vars(out), out)

@@ -1094,7 +1348,7 @@ def _parse_import_prerequisites(self, graph):
elif node.op == "Const":
pass
else:
if any([node.op in t for t in [_identity_list, _convert_map]]):
if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
pass
else:
missing_operators.add(node.op)
@@ -1185,6 +1439,42 @@ def _parse_attr(self, attr_proto):

return attrs

def _convert_rnn_operator(self, op_name, inputs,
attrs, params, graph, convert_map):
"""Convert RNN and its variant operators to NNVM operators.
This converter read the input states of each layers and
also maintain the output states of each layer in a list.
Parameters
----------
op_name : str
Operator name, such as LSTMBlockCell
inputs : list of nnvm.Symbol
List of input symbols.
attrs : dict
Dict of operator attributes
params : dict
Dict of pretrained weights and biases
graph : Tensorflow graph object
The graph is used to find the number of occurrences of the same
operator, to calculate the number of layers.
convert_map : dict
Dict of name : callable, where name is the op's name that
requires conversion to nnvm; callables are functions which
take attrs and return (new_op_name, new_attrs)
Returns
-------
sym : nnvm.Symbol
Converted nnvm Symbol
"""
if not self._num_rnn_layer:
self._out_rnn = []
self.rnn = RecurrentNetworks(self._nodes, self._out_rnn, graph, convert_map)
self._num_rnn_layer = True
sym = self.rnn.process_op(op_name, inputs, attrs, params)
return sym

def _convert_operator(self, op_name, inputs, attrs,
graph, identity_list=None, convert_map=None):
"""Convert from Tensorflow operator to relay operator.
@@ -1213,25 +1503,19 @@ def _convert_operator(self, op_name, inputs, attrs,
"""
identity_list = identity_list if identity_list else _identity_list
convert_map = convert_map if convert_map else _convert_map
convert_map_rnn = _convert_map_rnn
if op_name in identity_list:
sym = _get_relay_op(op_name)(*inputs, **attrs)
elif op_name in convert_map:
sym = convert_map[op_name](inputs, attrs, self._params)
elif op_name in convert_map_rnn:
sym = self._convert_rnn_operator(op_name, inputs, attrs,
self._params, graph,
convert_map_rnn)
else:
raise NotImplementedError("Operator {} not implemented.".format(op_name))
return sym

def _fix_extranodes(self, op_name, attr, inputs):
if op_name == "Softmax":
# Require some times flatten of data before it goes to softmax
# Need to relook into this with latest softmax axis support.
op = AttrCvt(op_name='batch_flatten')(inputs, {})
node_output = op.name_hint
for k, i in zip(list(node_output), range(len(node_output))):
self._nodes[k] = op[i]
inputs = [op]

return inputs

def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
""" Load tensorflow graph which is a python tensorflow graph object into relay.
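# A hedged usage sketch of the public entry point above (assumes the frontend
# is exposed as tvm.relay.frontend.from_tensorflow and returns the converted
# expression together with the params dict; the file name, input/output names
# and shapes below are made up):
import tensorflow as tf
from tvm import relay

with tf.gfile.GFile("frozen_model.pb", "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

sym, params = relay.frontend.from_tensorflow(graph_def,
                                              layout="NHWC",
                                              shape={"input": (1, 224, 224, 3)},
                                              outputs=["output"])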