diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 419befb1a494..ee78a7e523e8 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -126,7 +126,7 @@ def _impl(inputs, attr, params): def _elemwise(name): def _impl(inputs, attr, *args): - assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs)) + assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs)) op_name = _math_name_picker(name)(attr) return get_nnvm_op(op_name)(*inputs) return _impl @@ -1217,16 +1217,24 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): for node in graph.node: if node.op == 'Placeholder': + # Give priority to user argument. if shape and node.name in shape: self._input_shapes[node.name] = list(shape[node.name]) - continue - self._input_shapes[node.name] = \ - tensor_util.TensorShapeProtoToList(node.attr['shape'].shape) - for idx, dim in enumerate(self._input_shapes[node.name]): - if dim < 0: - self._input_shapes[node.name][idx] = 1 - warnings.warn("Use 1 instead of -1 in shape of operator %s." - % node.name) + else: + self._input_shapes[node.name] = \ + tensor_util.TensorShapeProtoToList(node.attr['shape'].shape) + for idx, dim in enumerate(self._input_shapes[node.name]): + if dim < 0: + self._input_shapes[node.name][idx] = 1 + warnings.warn("Use 1 instead of -1 in shape of operator %s." + % node.name) + + self._nodes[node.name] = _sym.Variable(name=node.name, + shape=self._input_shapes[node.name]) + self._output_shapes[node.name] = [self._input_shapes[node.name]] + self._outputs_are_0d[node.name] = [ \ + not tshape if isinstance(tshape, list) else False \ + for tshape in self._output_shapes[node.name]] # Ignore user's input shape for Non placeholder elif node.op == 'Const': @@ -1250,11 +1258,6 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): # Variable converted to Const will not have only value attr if 'value' in attr and node.op == 'Const': self._output_shapes[node.name] = [self._input_shapes[node.name]] - elif shape and node.name in shape: - # Give priority to user argument. - self._output_shapes[node.name] = [shape[node.name]] - elif node.op == 'Placeholder': - self._output_shapes[node.name] = [self._input_shapes[node.name]] elif '_output_shapes' in attr: self._output_shapes[node.name] = \ [tensor_util.TensorShapeProtoToList(tshape) \ @@ -1269,11 +1272,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): not tshape if isinstance(tshape, list) else False \ for tshape in self._output_shapes[node.name]] - if node.op == "Placeholder": - self._nodes[node.name] = _sym.Variable(name=node.name, - shape=self._input_shapes[node.name]) - - elif node.op == "Const": + if node.op == "Const": # All Const nodes are Param nodes, lets parse self._num_param += 1 for key, value in node.attr.items(): @@ -1284,7 +1283,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): attr = self._parse_attr(node.attr) - else: + elif node.op != "Placeholder": # Pass the parsed shapes instead attr["_output_shapes"] = output_shapes = self._output_shapes[node.name] diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index da89dc09408d..ad581fbbdec1 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -940,6 +940,29 @@ def test_forward_resnetv2(): tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', len(tf_output), target=device) tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) +####################################################################### +# Placeholder +# ----------- +def test_forward_placeholder(): + '''test a simple pb with Placeholder node in the end of GraphDef''' + with tf.Graph().as_default(): + graph_def = tf_testing.get_workload("Custom/placeholder.pb") + + # Call the utility to import the graph definition into default graph. + graph_def = tf_testing.ProcessGraphDefParam(graph_def) + + + data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') + out_node = 'mul' + + with tf.Session() as sess: + # Add shapes to the graph. + graph_def = tf_testing.AddShapesToGraphDef(sess, out_node) + tf_output = run_tf_graph(sess, data, 'Placeholder:0', out_node + ':0') + tvm_output = run_tvm_graph(graph_def, data, 'Placeholder') + print("tf_output is {}\ntvm_output is {}".format(tf_output, tvm_output)) + tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) + ####################################################################### # PTB # --- @@ -1261,6 +1284,7 @@ def test_forward_rel_ops(): test_forward_inception_v1() test_forward_mobilenet() test_forward_resnetv2() + test_forward_placeholder() test_forward_ptb() # RNN diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index b357a2fbff30..c3234418c33e 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -239,7 +239,7 @@ def _impl(inputs, attr, params): def _elemwise(name): def _impl(inputs, attr, *args): - assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs)) + assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs)) return _get_relay_op(name)(*inputs) return _impl @@ -1677,16 +1677,23 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): node_name_prefix = node.name.rsplit('/', 1)[0] control_flow_node_map[node_name_prefix].add(node.op) if node.op == 'Placeholder': + # Give priority to user argument. if shape and node.name in shape: self._input_shapes[node.name] = list(shape[node.name]) - continue - self._input_shapes[node.name] = \ - tensor_util.TensorShapeProtoToList(node.attr['shape'].shape) - for idx, dim in enumerate(self._input_shapes[node.name]): - if dim < 0: - self._input_shapes[node.name][idx] = 1 - warnings.warn("Use 1 instead of -1 in shape of operator %s." - % node.name) + else: + self._input_shapes[node.name] = \ + tensor_util.TensorShapeProtoToList(node.attr['shape'].shape) + for idx, dim in enumerate(self._input_shapes[node.name]): + if dim < 0: + self._input_shapes[node.name][idx] = 1 + warnings.warn("Use 1 instead of -1 in shape of operator %s." + % node.name) + + self._output_shapes[node.name] = [self._input_shapes[node.name]] + attr = self._parse_attr(node.attr) + self._nodes[node.name] = [_expr.var(node.name, + shape=self._input_shapes[node.name], + dtype=attr['dtype'].name)] # Ignore user's input shape for Non placeholder elif node.op == 'Const': @@ -1709,11 +1716,6 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): # Variable converted to Const will not have only value attr if 'value' in attr and node.op == 'Const': self._output_shapes[node.name] = [self._input_shapes[node.name]] - elif shape and node.name in shape: - # Give priority to user argument. - self._output_shapes[node.name] = [shape[node.name]] - elif node.op == 'Placeholder': - self._output_shapes[node.name] = [self._input_shapes[node.name]] elif '_output_shapes' in attr: self._output_shapes[node.name] = \ [tensor_util.TensorShapeProtoToList(tshape) \ @@ -1728,13 +1730,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): not shape if isinstance(tshape, list) else False \ for tshape in self._output_shapes[node.name]] - if node.op == "Placeholder": - self._output_shapes[node.name] = [self._input_shapes[node.name]] - self._nodes[node.name] = [_expr.var(node.name, - shape=self._input_shapes[node.name], - dtype=attr['dtype'].name)] - - elif node.op == "Const": + if node.op == "Const": # All Const nodes are Param nodes, lets parse self._num_param += 1 for key, value in node.attr.items(): @@ -1745,7 +1741,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): attr = self._parse_attr(node.attr) - else: + elif node.op != "Placeholder": # Pass the parsed shapes instead attr["_output_shapes"] = output_shapes = self._output_shapes[node.name] @@ -1789,7 +1785,8 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): input_shapes[in_sym[0]] = input_shape # This means the node is 1d in Relay and 0d in TF. # See `_expand_dims_0d_aware`. - if self._outputs_are_0d[node_name][tensor_slot] and input_shape: + if node_name in self._outputs_are_0d \ + and self._outputs_are_0d[node_name][tensor_slot] and input_shape: input_0d_mismatch.add(in_sym[0]) attr['_input_shapes'] = input_shapes diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 7e7c1510c60b..460159a14b69 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1115,6 +1115,27 @@ def test_forward_resnetv2(): tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', len(tf_output), target=device) tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) +####################################################################### +# Placeholder +# ----------- +def test_forward_placeholder(): + '''test a simple pb with Placeholder node in the end of GraphDef''' + with tf.Graph().as_default(): + graph_def = tf_testing.get_workload("Custom/placeholder.pb") + # Call the utility to import the graph definition into default graph. + graph_def = tf_testing.ProcessGraphDefParam(graph_def) + + data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') + out_node = 'mul' + + with tf.Session() as sess: + # Add shapes to the graph. + graph_def = tf_testing.AddShapesToGraphDef(sess, out_node) + tf_output = run_tf_graph(sess, data, 'Placeholder:0', out_node + ':0') + tvm_output = run_tvm_graph(graph_def, data, 'Placeholder') + print("tf_output is {}\ntvm_output is {}".format(tf_output, tvm_output)) + tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) + ####################################################################### # PTB # --- @@ -1441,6 +1462,7 @@ def test_forward_rel_ops(): test_forward_inception_v1() test_forward_mobilenet() test_forward_resnetv2() + test_forward_placeholder() test_forward_ptb() # RNN