diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py
index b66cf60e3c04..143d995803d8 100644
--- a/nnvm/python/nnvm/frontend/tensorflow.py
+++ b/nnvm/python/nnvm/frontend/tensorflow.py
@@ -168,81 +168,7 @@ def _impl(inputs, attr, params):
         custom_check=_dimension_constraint())(inputs, attr)
     return _impl
 
-def _conv():
-    def _impl(inputs, attr, params):
-        attr['data_format'] = attr['data_format'].decode("utf-8")
-
-        # Extract kernel shape from params
-        conv_param_weights = params[inputs[1].list_output_names()[0]]
-
-        if attr['data_format'] == 'NHWC':
-            attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
-            attr['channels'] = conv_param_weights.shape[3]
-            if 'dilations' in attr:
-                attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
-        elif attr['data_format'] == 'NCHW':
-            attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
-            attr['channels'] = conv_param_weights.shape[1]
-            if 'dilations' in attr:
-                attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
-        else:
-            raise TypeError("Unsupported data format type : {}".format(attr['data_format']))
-
-        # Fix strides
-        attr['strides'] = (attr['strides'][1], attr['strides'][2])
-
-        # Fix padding
-        input_shapes = attr['_input_shapes'][inputs[0]]
-        attr['padding'] = attr['padding'].decode("utf-8")
-
-        if attr['padding'] == 'VALID':
-            attr['padding'] = [0, 0]
-        elif attr['padding'] == 'SAME':
-            stride_h, stride_w = attr['strides']
-            kernel_h, kernel_w = attr['kernel_shape']
-            if attr['data_format'] == 'NHWC':
-                in_h = input_shapes[0][1]
-                in_w = input_shapes[0][2]
-            else:
-                in_h = input_shapes[0][2]
-                in_w = input_shapes[0][3]
-
-            pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
-            pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
-
-            if attr['data_format'] == 'NHWC':
-                inputs[0] = _sym.pad(data=inputs[0],
-                                     pad_width=((0, 0),
-                                                (pad_v[0], pad_v[1]),
-                                                (pad_h[0], pad_h[1]),
-                                                (0, 0)))
-            else:
-                inputs[0] = _sym.pad(data=inputs[0],
-                                     pad_width=((0, 0),
-                                                (0, 0),
-                                                (pad_v[0], pad_v[1]),
-                                                (pad_h[0], pad_h[1])))
-
-            attr['padding'] = [0, 0]
-
-        else:
-            raise TypeError("Unsupported padding type : {}".format(attr['padding']))
-
-        if 'kernel_layout' not in attr:
-            attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
-
-        return AttrCvt(
-            op_name=_dimension_picker('conv'),
-            transforms={
-                'kernel_shape': 'kernel_size',
-                'data_format': 'layout',
-                'dilations': ('dilation', (0, 0)),
-                'group': ('groups', 1)},
-            extras={'use_bias': len(inputs) == 3},
-            custom_check=_dimension_constraint())(inputs, attr)
-    return _impl
-
-def _depthwise_conv():
+def _conv(opname):
     def _impl(inputs, attr, params):
         attr['data_format'] = attr['data_format'].decode("utf-8")
         input_shapes = attr['_input_shapes'][inputs[0]]
@@ -253,24 +179,33 @@ def _impl(inputs, attr, params):
         if attr['data_format'] == 'NHWC':
             kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
             attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
-            attr['channels'] = input_shapes[0][3] * depth_mult
+            if opname == 'conv':
+                attr['channels'] = conv_param_weights.shape[3]
+            else:
+                attr['channels'] = input_shapes[0][3] * depth_mult
+
             if 'dilations' in attr:
                 attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
         elif attr['data_format'] == 'NCHW':
             depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
             attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
-            attr['channels'] = input_shapes[0][1] * depth_mult
+            if opname == 'conv':
+                attr['channels'] = conv_param_weights.shape[1]
+            else:
+                attr['channels'] = input_shapes[0][1] * depth_mult
+
             if 'dilations' in attr:
                 attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
         else:
             raise TypeError("Unsupported data format type : {}".format(attr['data_format']))
+
+        if opname == 'depthwise':
+            attr['groups'] = attr['channels']
+
         # Fix strides
         attr['strides'] = (attr['strides'][1], attr['strides'][2])
 
-        # Fix groups
-        attr['groups'] = attr['channels']
-
         # Fix padding
         attr['padding'] = attr['padding'].decode("utf-8")
 
@@ -308,7 +243,10 @@ def _impl(inputs, attr, params):
             raise TypeError("Unsupported padding type : {}".format(attr['padding']))
 
         if 'kernel_layout' not in attr:
-            attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
+            if opname == 'conv':
+                attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
+            else:
+                attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
 
         return AttrCvt(
             op_name=_dimension_picker('conv'),
@@ -687,7 +625,7 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
     'CheckNumerics'                     : _check_numerics(),
     'Concat'                            : _concat(),
     'ConcatV2'                          : _concatV2(),
-    'Conv2D'                            : _conv(),
+    'Conv2D'                            : _conv('conv'),
     'DecodeJpeg'                        : _decode_image(),
     'ExpandDims'                        : _expand_dims(),
     'Identity'                          : _identity(),
@@ -704,7 +642,7 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
     'Squeeze'                           : _squeeze(),
     'FusedBatchNorm'                    : _fused_batch_norm(),
     'Relu6'                             : _relu6(),
-    'DepthwiseConv2dNative'             : _depthwise_conv(),
+    'DepthwiseConv2dNative'             : _conv('depthwise'),
     'Shape'                             : _shape(),
     'Sigmoid'                           : AttrCvt('sigmoid'),
     'Fill'                              : _fill(),
@@ -895,28 +833,6 @@ def _get_abs_layer_name(node):
                                           params, num_layers)
     return sym
 
-
-def _parse_import_prerequisites(graph):
-    """ Calculate the named preconditions from TensorFlow `graph`.
-        Return prerequisites for parsing:
-        a. Set of operator names which don't have their mapping in TVM, i.e.
-            which are not supported
-    """
-    missing_operators = set()
-    for node in graph.node:
-        if node.op == "Placeholder":
-            pass
-        elif node.op == "Const":
-            pass
-        else:
-            if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
-                pass
-            else:
-                missing_operators.add(node.op)
-
-    return missing_operators
-
-
 class GraphProto(object):
     """ A helper class for handling nnvm graph copying from Tensorflow GraphDef.
     Definition:
@@ -925,12 +841,8 @@ class GraphProto(object):
     def __init__(self):
         self._nodes = {}
         self._params = {}
-        self._renames = {}
-        self._replacements = {}
         self._output_shapes = {}
-        self._num_input = 0
         self._num_param = 0
-        self._input_node = ''
         self._num_rnn_layer = False
 
     def from_tensorflow(self, graph):
@@ -969,7 +881,7 @@ def from_tensorflow(self, graph):
             raise ImportError(
                 "Unable to import tensorflow which is required {}".format(e))
 
-        missing_operators = _parse_import_prerequisites(graph)
+        missing_operators = self._parse_import_prerequisites(graph)
 
         if missing_operators:
             raise NotImplementedError( \
@@ -979,58 +891,42 @@ def from_tensorflow(self, graph):
         for node in graph.node:
             # Tensorflow doesn't have a separate list for params extraction.
            # Operator name 'Const' is treated as a parameter to build NNVM params dict.
+            input_shapes = {}
+
+            attr = self._parse_attr(node.attr)
+
+            # Variable converted to Const will not have '_output_shapes', only the 'value' attr
+            if 'value' in attr and node.op == 'Const':
+                tensor_value = attr['value']
+                self._output_shapes[node.name] = \
+                    [tensor_util.TensorShapeProtoToList( \
+                        tensor_value.tensor_shape)]
+            elif '_output_shapes' in attr:
+                self._output_shapes[node.name] = \
+                    [tensor_util.TensorShapeProtoToList(shape) \
+                    for shape in attr['_output_shapes']]
+            else:
+                raise NotImplementedError( \
+                    "Please freeze the graph with add_shapes=True")
+
             if node.op == "Placeholder":
-                self._input_node = node.name
-                self._num_input += 1
+                self._nodes[node.name] = _sym.Variable(name=node.name,
+                                                       shape=self._output_shapes[node.name][0])
 
-                try:
-                    self._output_shapes[node.name] = \
-                        [tensor_util.TensorShapeProtoToList(shape) \
-                        for shape in self._parse_attr(node.attr)['_output_shapes']]
-                    self._nodes[node.name] = _sym.Variable(name=node.name,
-                                                           shape=self._output_shapes[node.name][0])
-                    input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
-                except KeyError:
-                    raise NotImplementedError( \
-                        "Please freeze the graph with add_shapes=True")
+                #input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
             elif node.op == "Const":
-                if self._input_node == '':
-                    self._input_node = node.name
-                    self._num_input += 1
-                    self._nodes[node.name] = _sym.Variable(name=node.name)
-                else:
-                    # Rest all nodes are Param nodes, lets parse
-                    self._num_param += 1
-                    for key, value in node.attr.items():
-                        self._parse_param(key, value, node.name)
-                    if node.name not in self._nodes:
-                        raise NotImplementedError( \
-                            "Const {} couldn't be converted to Param.".format(node.name))
-                attr = self._parse_attr(node.attr)
-                #Variable converted to Const will not have only value attr
-                if 'value' in attr:
-                    tensor_value = attr['value']
-                    self._output_shapes[node.name] = \
-                        [tensor_util.TensorShapeProtoToList( \
-                            tensor_value.tensor_shape)]
-                elif '_output_shapes' in attr:
-                    self._output_shapes[node.name] = \
-                        [tensor_util.TensorShapeProtoToList(shape) \
-                        for shape in self._parse_attr(node.attr)['_output_shapes']]
-                else:
+                # All Const nodes are Param nodes, let's parse
+                self._num_param += 1
+                for key, value in node.attr.items():
+                    self._parse_param(key, value, node.name)
+                if node.name not in self._nodes:
                     raise NotImplementedError( \
-                        "Please freeze the graph with add_shapes=True")
-            else:
+                        "Const {} couldn't be converted to Param.".format(node.name))
 
-                attr = self._parse_attr(node.attr)
-                try:
-                    self._output_shapes[node.name] = \
-                        [tensor_util.TensorShapeProtoToList(shape) \
-                        for shape in attr['_output_shapes']]
-                except KeyError:
-                    raise NotImplementedError( \
-                        "Please freeze the graph with add_shapes=True")
+            else:
 
                 # Pass the parsed shapes instead
                 attr["_output_shapes"] = self._output_shapes[node.name]
@@ -1045,11 +941,12 @@ def from_tensorflow(self, graph):
                 if ":" in node.input[0]:
                     in_name, _ = node.input[0].split(':')
                     node.input[0] = in_name
+
+                # Fill shapes for all inputs in a list
                 try:
                     inputs = [self._nodes[i] for i in node.input]
                     for i in node.input:
-                        if i not in self._params:
-                            input_shapes[self._nodes[i]] = self._output_shapes[i]
+                        input_shapes[self._nodes[i]] = self._output_shapes[i]
                     attr['_input_shapes'] = input_shapes
                 except KeyError:
                     # TODO: Need to find clean way to handle '^CheckNumerics'
@@ -1061,6 +958,7 @@ def from_tensorflow(self, graph):
 
                 # Assuming only one output.
                 self._nodes[node.name] = op
                 node_output = op
+
         # Assume the final node is the output node
         out = node_output
 
@@ -1068,11 +966,32 @@ def from_tensorflow(self, graph):
         if self._num_rnn_layer:
             out_rnn = _sym.concatenate(*self._out_rnn, axis=0)
             out = [out, out_rnn]
+
         if isinstance(out, list):
             out = _sym.Group(out)
 
         return out, self._params
 
+    def _parse_import_prerequisites(self, graph):
+        """ Calculate the named preconditions from TensorFlow `graph`.
+            Return prerequisites for parsing:
+            a. Set of operator names which don't have their mapping in TVM, i.e.
+                which are not supported
+        """
+        missing_operators = set()
+        for node in graph.node:
+            if node.op == "Placeholder":
+                pass
+            elif node.op == "Const":
+                pass
+            else:
+                if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
+                    pass
+                else:
+                    missing_operators.add(node.op)
+
+        return missing_operators
+
     def _parse_param(self, key, value, name):
         try:
             from tensorflow.python.framework import tensor_util
@@ -1082,6 +1001,13 @@ def _parse_param(self, key, value, name):
 
         if key == 'value':
             np_array = tensor_util.MakeNdarray(value.tensor)
+
+            if np_array.dtype == np.dtype(object):
+                # Object types are generally tensorflow DT_STRING (DecodeJpeg op).
+                # Just leave it as placeholder.
+                self._nodes[name] = _sym.Variable(name=name)
+                return
+
             array_ndim = len(np_array.shape)
             if array_ndim == 0:
                 new_array = np.empty([1], dtype=np_array.dtype)
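
Note on the core refactor: Conv2D and DepthwiseConv2dNative previously had two near-duplicate converters; this patch folds them into a single _conv(opname) factory whose closure captures the op name and branches only where the two ops actually differ (channels, groups, kernel layout). Below is a minimal, self-contained sketch of that pattern, not the frontend code itself; make_conv_converter, convert, and their toy signatures are made up for illustration (the real _conv() also normalizes strides, padding, and dilations before delegating to AttrCvt).

# Minimal sketch: one parameterized converter factory instead of two copies.
def make_conv_converter(opname):
    """Return a converter closure; opname is 'conv' or 'depthwise'."""
    def convert(weight_shape, in_channels, data_format='NHWC'):
        if opname == 'conv':
            # Regular conv: output channels come straight from the weights.
            channels = weight_shape[3] if data_format == 'NHWC' else weight_shape[1]
            kernel_layout = 'HWIO' if data_format == 'NHWC' else 'OIHW'
            groups = 1
        else:
            # Depthwise: channels = input channels * depth multiplier,
            # executed as a grouped conv with groups == channels.
            depth_mult = weight_shape[3] if data_format == 'NHWC' else weight_shape[0]
            channels = in_channels * depth_mult
            kernel_layout = 'HWOI' if data_format == 'NHWC' else 'OIHW'
            groups = channels
        return {'channels': channels, 'kernel_layout': kernel_layout, 'groups': groups}
    return convert

# Registration mirrors the updated _convert_map entries:
convert_map = {
    'Conv2D'                : make_conv_converter('conv'),
    'DepthwiseConv2dNative' : make_conv_converter('depthwise'),
}

print(convert_map['Conv2D']((3, 3, 16, 32), in_channels=16))
# {'channels': 32, 'kernel_layout': 'HWIO', 'groups': 1}
print(convert_map['DepthwiseConv2dNative']((3, 3, 16, 2), in_channels=16))
# {'channels': 32, 'kernel_layout': 'HWOI', 'groups': 32}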