[Frontend][TF] Fix Placeholder issue (apache#2834)
* [Frontend][TF] Fix Placeholder issue

* Add test cases
yongwww authored and wweic committed May 13, 2019
1 parent 3da8e9a commit fc30bd3
Showing 4 changed files with 84 additions and 33 deletions.
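In short: both the NNVM and Relay TensorFlow frontends used to convert Placeholder nodes in the second pass over the GraphDef, together with Const and regular operators, so a Placeholder appearing after its consumers (at the end of the GraphDef) broke conversion. This commit moves Placeholder handling, shape resolution and variable creation included, into the first pass and makes the second pass skip Placeholders entirely.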
nnvm/python/nnvm/frontend/tensorflow.py (28 changes: 18 additions & 10 deletions)
@@ -127,7 +127,7 @@ def _impl(inputs, attr, params):
 
 def _elemwise(name):
     def _impl(inputs, attr, *args):
-        assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs))
+        assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs))
         op_name = _math_name_picker(name)(attr)
         return get_nnvm_op(op_name)(*inputs)
     return _impl
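For context, a standalone toy sketch (not TVM code) of what the sharper assertion buys: the failure message now names the offending operator instead of the generic "Math op".

import operator  # unused here; the real dispatch looks ops up by name

def _elemwise(name):
    def _impl(inputs):
        assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs))
        return inputs[0] + inputs[1]  # stand-in for the real op dispatch
    return _impl

try:
    _elemwise('Add')([1])
except AssertionError as err:
    print(err)  # -> Add take 2 inputs, 1 given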
@@ -1237,16 +1237,24 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
 
         for node in graph.node:
             if node.op == 'Placeholder' or node.op == 'PlaceholderWithDefault':
                 # Give priority to user argument.
                 if shape and node.name in shape:
                     self._input_shapes[node.name] = list(shape[node.name])
-                    continue
-                self._input_shapes[node.name] = \
-                    tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
-                for idx, dim in enumerate(self._input_shapes[node.name]):
-                    if dim < 0:
-                        self._input_shapes[node.name][idx] = 1
-                        warnings.warn("Use 1 instead of -1 in shape of operator %s."
-                                      % node.name)
+                else:
+                    self._input_shapes[node.name] = \
+                        tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
+                    for idx, dim in enumerate(self._input_shapes[node.name]):
+                        if dim < 0:
+                            self._input_shapes[node.name][idx] = 1
+                            warnings.warn("Use 1 instead of -1 in shape of operator %s."
+                                          % node.name)
+
+                self._nodes[node.name] = _sym.Variable(name=node.name,
+                                                       shape=self._input_shapes[node.name])
+                self._output_shapes[node.name] = [self._input_shapes[node.name]]
+                self._outputs_are_0d[node.name] = [ \
+                    not tshape if isinstance(tshape, list) else False \
+                    for tshape in self._output_shapes[node.name]]
+
             # Ignore user's input shape for Non placeholder
             elif node.op == 'Const':
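The Placeholder branch above reads as a small routine. Here is a standalone sketch of that logic (resolve_placeholder_shape is a hypothetical helper; it assumes the TensorShapeProto has already been converted to a plain list, which is what tensor_util.TensorShapeProtoToList returns): user-supplied shapes take priority, and any unknown (-1) dimension falls back to 1 with a warning.

import warnings

def resolve_placeholder_shape(name, proto_shape, user_shapes=None):
    """Pick a concrete input shape for a Placeholder node."""
    # Give priority to the user argument, as the frontend does.
    if user_shapes and name in user_shapes:
        return list(user_shapes[name])
    shape = list(proto_shape)
    for idx, dim in enumerate(shape):
        if dim < 0:  # unknown dimension, e.g. a dynamic batch size
            shape[idx] = 1
            warnings.warn("Use 1 instead of -1 in shape of operator %s." % name)
    return shape

print(resolve_placeholder_shape('Placeholder', [-1, 224, 224, 3]))
# -> [1, 224, 224, 3], plus a UserWarning about the -1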
@@ -1304,7 +1312,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
 
                 attr = self._parse_attr(node.attr)
 
-            else:
+            elif node.op != "Placeholder":
                 # Pass the parsed shapes instead
                 attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]

nnvm/tests/python/frontend/tensorflow/test_forward.py (24 changes: 24 additions & 0 deletions)
@@ -941,6 +941,29 @@ def test_forward_resnetv2():
                     tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', len(tf_output), target=device)
                     tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
 
+#######################################################################
+# Placeholder
+# -----------
+def test_forward_placeholder():
+    '''test a simple pb with Placeholder node in the end of GraphDef'''
+    with tf.Graph().as_default():
+        graph_def = tf_testing.get_workload("Custom/placeholder.pb")
+
+        # Call the utility to import the graph definition into default graph.
+        graph_def = tf_testing.ProcessGraphDefParam(graph_def)
+
+
+        data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
+        out_node = 'mul'
+
+        with tf.Session() as sess:
+            # Add shapes to the graph.
+            graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
+            tf_output = run_tf_graph(sess, data, 'Placeholder:0', out_node + ':0')
+            tvm_output = run_tvm_graph(graph_def, data, 'Placeholder')
+            print("tf_output is {}\ntvm_output is {}".format(tf_output, tvm_output))
+            tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
+
 #######################################################################
 # PTB
 # ---
@@ -1262,6 +1285,7 @@ def test_forward_rel_ops():
     test_forward_inception_v1()
     test_forward_mobilenet()
     test_forward_resnetv2()
+    test_forward_placeholder()
     test_forward_ptb()
 
     # RNN
python/tvm/relay/frontend/tensorflow.py (43 changes: 20 additions & 23 deletions)
@@ -239,7 +239,7 @@ def _impl(inputs, attr, params):
 
 def _elemwise(name):
     def _impl(inputs, attr, *args):
-        assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs))
+        assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs))
         return _get_relay_op(name)(*inputs)
     return _impl
 
@@ -1704,16 +1704,23 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             node_name_prefix = node.name.rsplit('/', 1)[0]
             control_flow_node_map[node_name_prefix].add(node.op)
             if node.op == 'Placeholder':
                 # Give priority to user argument.
                 if shape and node.name in shape:
                     self._input_shapes[node.name] = list(shape[node.name])
-                    continue
-                self._input_shapes[node.name] = \
-                    tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
-                for idx, dim in enumerate(self._input_shapes[node.name]):
-                    if dim < 0:
-                        self._input_shapes[node.name][idx] = 1
-                        warnings.warn("Use 1 instead of -1 in shape of operator %s."
-                                      % node.name)
+                else:
+                    self._input_shapes[node.name] = \
+                        tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
+                    for idx, dim in enumerate(self._input_shapes[node.name]):
+                        if dim < 0:
+                            self._input_shapes[node.name][idx] = 1
+                            warnings.warn("Use 1 instead of -1 in shape of operator %s."
+                                          % node.name)
+
+                self._output_shapes[node.name] = [self._input_shapes[node.name]]
+                attr = self._parse_attr(node.attr)
+                self._nodes[node.name] = [_expr.var(node.name,
+                                                    shape=self._input_shapes[node.name],
+                                                    dtype=attr['dtype'].name)]
+
             # Ignore user's input shape for Non placeholder
             elif node.op == 'Const':
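A sketch of what the new first-pass branch yields per Placeholder, assuming a TVM build with the Relay frontend importable: one free Relay variable carrying the resolved shape and the dtype read off the node's attributes.

from tvm import relay

# The equivalent of _expr.var(node.name, shape=..., dtype=attr['dtype'].name)
# for a float32 NHWC image input.
x = relay.var('Placeholder', shape=(1, 224, 224, 3), dtype='float32')
print(x)  # roughly: free_var %Placeholder: Tensor[(1, 224, 224, 3), float32]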
@@ -1736,11 +1743,6 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             # Variable converted to Const will not have only value attr
             if 'value' in attr and node.op == 'Const':
                 self._output_shapes[node.name] = [self._input_shapes[node.name]]
-            elif shape and node.name in shape:
-                # Give priority to user argument.
-                self._output_shapes[node.name] = [shape[node.name]]
-            elif node.op == 'Placeholder':
-                self._output_shapes[node.name] = [self._input_shapes[node.name]]
             elif '_output_shapes' in attr:
                 self._output_shapes[node.name] = \
                     [tensor_util.TensorShapeProtoToList(tshape) \
@@ -1755,13 +1757,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
                 not tshape if isinstance(tshape, list) else False \
                 for tshape in self._output_shapes[node.name]]
 
-            if node.op == "Placeholder":
-                self._output_shapes[node.name] = [self._input_shapes[node.name]]
-                self._nodes[node.name] = [_expr.var(node.name,
-                                                    shape=self._input_shapes[node.name],
-                                                    dtype=attr['dtype'].name)]
-
-            elif node.op == "Const":
+            if node.op == "Const":
                 # All Const nodes are Param nodes, lets parse
                 self._num_param += 1
                 for key, value in node.attr.items():
@@ -1772,7 +1768,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
 
                 attr = self._parse_attr(node.attr)
 
-            else:
+            elif node.op != "Placeholder":
                 # Pass the parsed shapes instead
                 attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]
 
@@ -1816,7 +1812,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
                     input_shapes[in_sym[0]] = input_shape
                 # This means the node is 1d in Relay and 0d in TF.
                 # See `_expand_dims_0d_aware`.
-                if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
+                if node_name in self._outputs_are_0d \
+                        and self._outputs_are_0d[node_name][tensor_slot] and input_shape:
                     input_0d_mismatch.add(in_sym[0])
 
             attr['_input_shapes'] = input_shapes
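A standalone sketch of the 0-d bookkeeping this guard consults: a TF scalar output is recorded with an empty shape list, and not [] is True, so the flag marks tensors that are 0-d in TF but become 1-d in Relay. Placeholders no longer get an entry, hence the added membership test to avoid a KeyError. (Toy dictionaries below; the real maps are filled per node during conversion.)

output_shapes = {'scalar_sum': [[]], 'conv': [[1, 224, 224, 3]]}

outputs_are_0d = {
    name: [not tshape if isinstance(tshape, list) else False
           for tshape in shapes]
    for name, shapes in output_shapes.items()
}
print(outputs_are_0d)  # {'scalar_sum': [True], 'conv': [False]}

node_name = 'Placeholder'  # never entered the map
if node_name in outputs_are_0d and outputs_are_0d[node_name][0]:
    print('would record a 0d/1d mismatch')  # not reached, and no KeyError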
tests/python/frontend/tensorflow/test_forward.py (22 changes: 22 additions & 0 deletions)
@@ -1133,6 +1133,27 @@ def test_forward_resnetv2():
                     tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', len(tf_output), target=device)
                     tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
 
+#######################################################################
+# Placeholder
+# -----------
+def test_forward_placeholder():
+    '''test a simple pb with Placeholder node in the end of GraphDef'''
+    with tf.Graph().as_default():
+        graph_def = tf_testing.get_workload("Custom/placeholder.pb")
+        # Call the utility to import the graph definition into default graph.
+        graph_def = tf_testing.ProcessGraphDefParam(graph_def)
+
+        data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
+        out_node = 'mul'
+
+        with tf.Session() as sess:
+            # Add shapes to the graph.
+            graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
+            tf_output = run_tf_graph(sess, data, 'Placeholder:0', out_node + ':0')
+            tvm_output = run_tvm_graph(graph_def, data, 'Placeholder')
+            print("tf_output is {}\ntvm_output is {}".format(tf_output, tvm_output))
+            tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
+
 #######################################################################
 # PTB
 # ---
@@ -1514,6 +1535,7 @@ def test_forward_rel_ops():
     test_forward_inception_v1()
     test_forward_mobilenet()
     test_forward_resnetv2()
+    test_forward_placeholder()
     test_forward_ptb()
 
     # RNN
