diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 71536517810d..ce047d27da07 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -513,6 +513,7 @@ def _impl(inputs, attr, params): 'Relu6' : _relu6(), 'DepthwiseConv2dNative' : _depthwise_conv(), 'Shape' : _shape(), + 'Sigmoid' : AttrCvt('sigmoid'), } diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index 6dc8cfab2ab4..f3dbc3cdbe40 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -13,6 +13,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops from tensorflow.core.framework import graph_pb2 import nnvm.testing.tf @@ -329,6 +330,42 @@ def _test_forward_concat_v2(): _test_concat_v2([t1, t2], 1) +####################################################################### +# Sigmoid +# ------- + +def _test_sigmoid(data): + """ One iteration of sigmoid """ + + with tf.Graph().as_default(): + in_data = constant_op.constant(data, shape=data.shape, dtype=data.dtype) + + # pylint: disable=unused-variable + sigmoid_out = math_ops.sigmoid(in_data) + # pylint: enable=unused-variable + + with tf.Session() as sess: + graph_def = tf.graph_util.convert_variables_to_constants( + sess, + sess.graph.as_graph_def(add_shapes=True), + ['Sigmoid'], + ) + + tf_output = run_tf_graph(sess, data, + 'Const:0', 'Sigmoid:0') + tvm_output = run_tvm_graph(graph_def, + data, + "Const", tf_output.shape, data.dtype) + + np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5) + + sess.close() + +def test_forward_sigmoid(): + """ Sigmoid """ + + _test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype('float32')) + ####################################################################### # Multi Input to graph # -------------------- @@ -437,6 +474,7 @@ def test_forward_mobilenet(): test_forward_pooling() test_forward_reshape() test_forward_squeeze() + test_forward_sigmoid() if tf.__version__ == '1.4.1': _test_forward_concat_v2() test_forward_multi_input()