Skip to content

Commit

Permalink
[NNVM][TENSORFLOW] Cleanup redundant code. (apache#1551)
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 authored and tqchen committed Aug 4, 2018
1 parent 136061d commit 7b59b8e
Showing 1 changed file with 83 additions and 157 deletions.
240 changes: 83 additions & 157 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,81 +168,7 @@ def _impl(inputs, attr, params):
custom_check=_dimension_constraint())(inputs, attr)
return _impl

def _conv():
def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")

# Extract kernel shape from params
conv_param_weights = params[inputs[1].list_output_names()[0]]

if attr['data_format'] == 'NHWC':
attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
attr['channels'] = conv_param_weights.shape[3]
if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
elif attr['data_format'] == 'NCHW':
attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
attr['channels'] = conv_param_weights.shape[1]
if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
else:
raise TypeError("Unsupported data format type : {}".format(attr['data_format']))

# Fix strides
attr['strides'] = (attr['strides'][1], attr['strides'][2])

# Fix padding
input_shapes = attr['_input_shapes'][inputs[0]]
attr['padding'] = attr['padding'].decode("utf-8")

if attr['padding'] == 'VALID':
attr['padding'] = [0, 0]
elif attr['padding'] == 'SAME':
stride_h, stride_w = attr['strides']
kernel_h, kernel_w = attr['kernel_shape']
if attr['data_format'] == 'NHWC':
in_h = input_shapes[0][1]
in_w = input_shapes[0][2]
else:
in_h = input_shapes[0][2]
in_w = input_shapes[0][3]

pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)

if attr['data_format'] == 'NHWC':
inputs[0] = _sym.pad(data=inputs[0],
pad_width=((0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]),
(0, 0)))
else:
inputs[0] = _sym.pad(data=inputs[0],
pad_width=((0, 0),
(0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1])))

attr['padding'] = [0, 0]

else:
raise TypeError("Unsupported padding type : {}".format(attr['padding']))

if 'kernel_layout' not in attr:
attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'

return AttrCvt(
op_name=_dimension_picker('conv'),
transforms={
'kernel_shape': 'kernel_size',
'data_format': 'layout',
'dilations': ('dilation', (0, 0)),
'group': ('groups', 1)},
extras={'use_bias': len(inputs) == 3},
custom_check=_dimension_constraint())(inputs, attr)
return _impl

def _depthwise_conv():
def _conv(opname):
def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
input_shapes = attr['_input_shapes'][inputs[0]]
Expand All @@ -253,24 +179,33 @@ def _impl(inputs, attr, params):
if attr['data_format'] == 'NHWC':
kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
attr['channels'] = input_shapes[0][3] * depth_mult
if opname == 'conv':
attr['channels'] = conv_param_weights.shape[3]
else:
attr['channels'] = input_shapes[0][3] * depth_mult

if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
elif attr['data_format'] == 'NCHW':
depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
attr['channels'] = input_shapes[0][1] * depth_mult
if opname == 'conv':
attr['channels'] = conv_param_weights.shape[1]
else:
attr['channels'] = input_shapes[0][1] * depth_mult

if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
else:
raise TypeError("Unsupported data format type : {}".format(attr['data_format']))


if opname == 'depthwise':
attr['groups'] = attr['channels']

# Fix strides
attr['strides'] = (attr['strides'][1], attr['strides'][2])

# Fix groups
attr['groups'] = attr['channels']

# Fix padding
attr['padding'] = attr['padding'].decode("utf-8")

Expand Down Expand Up @@ -308,7 +243,10 @@ def _impl(inputs, attr, params):
raise TypeError("Unsupported padding type : {}".format(attr['padding']))

if 'kernel_layout' not in attr:
attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
if opname == 'conv':
attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
else:
attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'

return AttrCvt(
op_name=_dimension_picker('conv'),
Expand Down Expand Up @@ -687,7 +625,7 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
'CheckNumerics' : _check_numerics(),
'Concat' : _concat(),
'ConcatV2' : _concatV2(),
'Conv2D' : _conv(),
'Conv2D' : _conv('conv'),
'DecodeJpeg' : _decode_image(),
'ExpandDims' : _expand_dims(),
'Identity' : _identity(),
Expand All @@ -704,7 +642,7 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
'Squeeze' : _squeeze(),
'FusedBatchNorm' : _fused_batch_norm(),
'Relu6' : _relu6(),
'DepthwiseConv2dNative' : _depthwise_conv(),
'DepthwiseConv2dNative' : _conv('depthwise'),
'Shape' : _shape(),
'Sigmoid' : AttrCvt('sigmoid'),
'Fill' : _fill(),
Expand Down Expand Up @@ -895,28 +833,6 @@ def _get_abs_layer_name(node):
params, num_layers)
return sym


def _parse_import_prerequisites(graph):
""" Calculate the named preconditions from TensorFlow `graph`.
Return prerequisites for parsing:
a. Set of operator names which don't have their mapping in TVM, i.e.
which are not supported
"""
missing_operators = set()
for node in graph.node:
if node.op == "Placeholder":
pass
elif node.op == "Const":
pass
else:
if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
pass
else:
missing_operators.add(node.op)

return missing_operators


class GraphProto(object):
""" A helper class for handling nnvm graph copying from Tensorflow GraphDef.
Definition:
Expand All @@ -925,12 +841,8 @@ class GraphProto(object):
def __init__(self):
self._nodes = {}
self._params = {}
self._renames = {}
self._replacements = {}
self._output_shapes = {}
self._num_input = 0
self._num_param = 0
self._input_node = ''
self._num_rnn_layer = False

def from_tensorflow(self, graph):
Expand Down Expand Up @@ -969,7 +881,7 @@ def from_tensorflow(self, graph):
raise ImportError(
"Unable to import tensorflow which is required {}".format(e))

missing_operators = _parse_import_prerequisites(graph)
missing_operators = self._parse_import_prerequisites(graph)

if missing_operators:
raise NotImplementedError( \
Expand All @@ -979,58 +891,42 @@ def from_tensorflow(self, graph):
for node in graph.node:
# Tensorflow doesn't have seperate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.

input_shapes = {}

attr = self._parse_attr(node.attr)

#Variable converted to Const will not have only value attr
if 'value' in attr and node.op == 'Const':
tensor_value = attr['value']
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList( \
tensor_value.tensor_shape)]
elif '_output_shapes' in attr:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \
for shape in attr['_output_shapes']]
else:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")

if node.op == "Placeholder":
self._input_node = node.name
self._num_input += 1
self._nodes[node.name] = _sym.Variable(name=node.name,
shape=self._output_shapes[node.name][0])

try:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \
for shape in self._parse_attr(node.attr)['_output_shapes']]
self._nodes[node.name] = _sym.Variable(name=node.name,
shape=self._output_shapes[node.name][0])
input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
except KeyError:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")
#input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
elif node.op == "Const":
if self._input_node == '':
self._input_node = node.name
self._num_input += 1
self._nodes[node.name] = _sym.Variable(name=node.name)
else:
# Rest all nodes are Param nodes, lets parse
self._num_param += 1
for key, value in node.attr.items():
self._parse_param(key, value, node.name)
if node.name not in self._nodes:
raise NotImplementedError( \
"Const {} couldn't be converted to Param.".format(node.name))
attr = self._parse_attr(node.attr)
#Variable converted to Const will not have only value attr
if 'value' in attr:
tensor_value = attr['value']
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList( \
tensor_value.tensor_shape)]
elif '_output_shapes' in attr:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \
for shape in self._parse_attr(node.attr)['_output_shapes']]
else:
# All Const nodes are Param nodes, lets parse
self._num_param += 1
for key, value in node.attr.items():
self._parse_param(key, value, node.name)
if node.name not in self._nodes:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")
else:
"Const {} couldn't be converted to Param.".format(node.name))

attr = self._parse_attr(node.attr)
try:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(shape) \
for shape in attr['_output_shapes']]
except KeyError:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")

else:
# Pass the parsed shapes instead
attr["_output_shapes"] = self._output_shapes[node.name]

Expand All @@ -1045,11 +941,12 @@ def from_tensorflow(self, graph):
if ":" in node.input[0]:
in_name, _ = node.input[0].split(':')
node.input[0] = in_name

# Fill shapes for all inputs in a list
try:
inputs = [self._nodes[i] for i in node.input]
for i in node.input:
if i not in self._params:
input_shapes[self._nodes[i]] = self._output_shapes[i]
input_shapes[self._nodes[i]] = self._output_shapes[i]
attr['_input_shapes'] = input_shapes
except KeyError:
# TODO: Need to find clean way to handle '^CheckNumerics'
Expand All @@ -1061,18 +958,40 @@ def from_tensorflow(self, graph):
# Assuming only one output.
self._nodes[node.name] = op
node_output = op

# Assume the final node is the output node
out = node_output

#Add the RNN outputs also with 'head' nodes of the nnvm graph
if self._num_rnn_layer:
out_rnn = _sym.concatenate(*self._out_rnn, axis=0)
out = [out, out_rnn]

if isinstance(out, list):
out = _sym.Group(out)

return out, self._params

def _parse_import_prerequisites(self, graph):
""" Calculate the named preconditions from TensorFlow `graph`.
Return prerequisites for parsing:
a. Set of operator names which don't have their mapping in TVM, i.e.
which are not supported
"""
missing_operators = set()
for node in graph.node:
if node.op == "Placeholder":
pass
elif node.op == "Const":
pass
else:
if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
pass
else:
missing_operators.add(node.op)

return missing_operators

def _parse_param(self, key, value, name):
try:
from tensorflow.python.framework import tensor_util
Expand All @@ -1082,6 +1001,13 @@ def _parse_param(self, key, value, name):

if key == 'value':
np_array = tensor_util.MakeNdarray(value.tensor)

if np_array.dtype == np.dtype(object):
# Object types are generally tensorflow DT_STRING (DecodeJpeg op).
# Just leave it as placeholder.
self._nodes[name] = _sym.Variable(name=name)
return

array_ndim = len(np_array.shape)
if array_ndim == 0:
new_array = np.empty([1], dtype=np_array.dtype)
Expand Down

0 comments on commit 7b59b8e

Please sign in to comment.