Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NNVM] Support argmax/argmin in tensorflow frontend #1514

Merged
merged 2 commits into from
Aug 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 48 additions & 8 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,20 @@ def _impl(inputs, attr, *args):
return AttrCvt(op_name="__pow_scalar__", extras={'scalar': -0.5})(inputs, attr)
return _impl

def _argx(func, func_name):
    """Build a converter for the TensorFlow argmin/argmax style operators.

    Parameters
    ----------
    func : callable
        Reduction to emit; it is invoked as
        ``func(data, axis=<int>, keepdims=False)``.
    func_name : str
        Human readable operator name, only used in the error message.

    Returns
    -------
    callable
        ``_impl(inputs, attr, params)`` which resolves the constant ``axis``
        input from ``params`` and returns the result of ``func``.

    Raises
    ------
    TypeError
        Raised by ``_impl`` when the ``axis`` input is missing or is not
        present in ``params`` (i.e. it is not a constant).
    """
    def _impl(inputs, attr, params):
        try:
            # In Tensorflow, `axis` argument is a Tensor, not attribute. We
            # support the case where it inputs from a scalar constant.
            axis_input_name = inputs[1].list_output_names()[0]
            axis_array = params[axis_input_name].asnumpy()
            # Flatten before indexing so that a 0-d scalar constant is
            # accepted as well as a one-element array, and hand a plain
            # Python int (not a numpy scalar) to the operator.
            axis_input_value = int(axis_array.ravel()[0])
        except (IndexError, KeyError):
            raise TypeError(
                "Unsupported argument for `{}` : `axis` should be a constant".format(func_name))
        return func(inputs[0], axis=axis_input_value, keepdims=False)
    return _impl

def _elemwise(name):
def _impl(inputs, attr, *args):
assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs))
Expand Down Expand Up @@ -650,6 +664,8 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
_convert_map = {
'ArgMax' : _argx(_sym.argmax, 'argmax'),
'ArgMin' : _argx(_sym.argmin, 'argmin'),
'AvgPool' : _pooling('avg_pool'),
'BatchNormWithGlobalNormalization' : _batch_norm(),
'BiasAdd' : _bias_add(),
Expand Down Expand Up @@ -864,6 +880,28 @@ def _get_abs_layer_name(node):
params, num_layers)
return sym


def _parse_import_prerequisites(graph):
    """Collect the prerequisites for importing a TensorFlow `graph`.

    Every node of ``graph.node`` is inspected. ``Placeholder`` and ``Const``
    nodes are ignored; any other node whose ``op`` is found in none of
    ``_identity_list``, ``_convert_map`` and ``_convert_map_rnn`` is
    reported as unsupported.

    Returns
    -------
    set
        Names of the operators that have no mapping, empty when every
        operator in the graph is covered.
    """
    lookup_tables = (_identity_list, _convert_map, _convert_map_rnn)
    unsupported = set()

    for node in graph.node:
        op_name = node.op

        # Inputs and parameters are not looked up in the converter tables.
        if op_name in ("Placeholder", "Const"):
            continue

        if not any(op_name in table for table in lookup_tables):
            unsupported.add(op_name)

    return unsupported


class GraphProto(object):
""" A helper class for handling nnvm graph copying from Tensorflow GraphDef.
Definition:
Expand All @@ -886,7 +924,7 @@ def from_tensorflow(self, graph):
Follow the tensorflow graph definition to parse and convert it to NNVM.
Some of the assumptions listed below.

-> First Const or Placeholder node will be considered as graph input.
-> First Placeholder or Const node will be considered as graph input.
-> Rest all Const nodes are params.
-> Last node is assumed as graph output.
-> _output_shapes : Attribute should be present in the TensorFlow frozen graph.
Expand All @@ -895,6 +933,7 @@ def from_tensorflow(self, graph):
-> CheckNumerics: No implementation as of now for this.
Just copies input to output.

TODO: Change algorithm to stop treating first 'Const' in a special way.

Parameters
----------
Expand All @@ -908,23 +947,25 @@ def from_tensorflow(self, graph):
params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights
"""
# Parse throught all nodes and start extracting
# params aka Const nodes
# input nodes : First const node
# normal nodes : other normal nodes

try:
from tensorflow.python.framework import tensor_util
except ImportError as e:
raise ImportError(
"Unable to import tensorflow which is required {}".format(e))

missing_operators = _parse_import_prerequisites(graph)

if missing_operators:
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))

# Parse the nodes to re-create TF graph using Symbol API of NNVM
for node in graph.node:
# Tensorflow doesn't have a separate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
input_shapes = {}
if node.op == "Placeholder":
# Assuming only one input graph with type 'Placeholder'
self._input_node = node.name
self._num_input += 1

Expand All @@ -939,7 +980,6 @@ def from_tensorflow(self, graph):
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")
elif node.op == "Const":
# Assuming first Const node as Graph Input node
if self._input_node == '':
self._input_node = node.name
self._num_input += 1
Expand Down Expand Up @@ -982,7 +1022,7 @@ def from_tensorflow(self, graph):
# Pass the node name too in attr
attr["_node_name"] = node.name

#ToDo: Some of the tensorflow operators maintain internaly maintain
#ToDo: Some of the tensorflow operators internally maintain
#execution layers and its output name will the layer number along with
#graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
#output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
Expand Down
31 changes: 31 additions & 0 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,37 @@ def test_forward_sigmoid():

_test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype('float32'))

#######################################################################
# Argmin/Argmax
# -------------

def _test_argx(func, data, **kwargs):
    """Check one arg-reduction operator against its converted TVM graph.

    A constant named ``c0`` is created from `data`, `func` is applied to it
    under the node name ``argx0`` (extra `kwargs` are forwarded to `func`),
    the graph is frozen with shape information, and the output obtained
    through ``run_tf_graph`` is compared with the one from ``run_tvm_graph``.
    """
    with tf.Graph().as_default():
        const_in = constant_op.constant(data, shape=data.shape, dtype=data.dtype, name="c0")

        # The returned tensor is not used directly; the node is addressed by
        # its name ("argx0") further down.
        # pylint: disable=unused-variable
        argx_node = func(const_in, name="argx0", **kwargs)
        # pylint: enable=unused-variable

        with tf.Session() as sess:
            shaped_graph_def = sess.graph.as_graph_def(add_shapes=True)
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=shaped_graph_def,
                output_node_names=["argx0"])

            expected = run_tf_graph(sess, data, input_node="c0:0", output_node="argx0:0")
            actual = run_tvm_graph(graph_def, data, "c0", expected.shape, output_dtype='int32')

            np.testing.assert_allclose(expected, actual, atol=1e-5, rtol=1e-5)

            sess.close()

def test_argmin_argmax():
    """Run the argmax/argmin comparison for `axis` in None, 0, 1 and 2.

    Fresh random float32 data of shape (8, 4, 9) is drawn for every axis
    value and shared by the argmax and the argmin check.
    """
    for axis in (None, 0, 1, 2):
        data = np.random.uniform(size=(8, 4, 9)).astype('float32')
        for tf_op in (tf.argmax, tf.argmin):
            _test_argx(tf_op, data=data, axis=axis)

#######################################################################
# Variable
Expand Down