Skip to content

Commit

Permalink
[Relay][TensorFlow Frontend] SoftPlus Sqrt (#3187)
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww authored and yzhliu committed May 15, 2019
1 parent 20ddd2b commit 93c8017
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
12 changes: 12 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,16 @@ def _impl(inputs, attr, params):
transforms={'axis': ('axis', 1)})([inputs[0]], attr)
return _impl

def _softplus():
# op description: https://www.tensorflow.org/api_docs/python/tf/math/softplus
def _impl(inputs, attr, params):
exp_out = AttrCvt('exp')(inputs, attr)
inputs.append(tvm.relay.const(1, attr['T'].name))
rh = tvm.relay.const(1, attr['T'].name)
add_out = _get_relay_op('add')(exp_out, rh)
return _get_relay_op('log')(add_out)
return _impl

def _logical(name):
def _impl(inputs, attr, params):
return AttrCvt(op_name=name)(inputs, attr)
Expand Down Expand Up @@ -1163,9 +1173,11 @@ def _impl(inputs, attr, params):
'Sign' : AttrCvt('sign'),
'Slice' : _slice(),
'Softmax' : _softmax(),
'Softplus' : _softplus(),
'SpaceToBatchND' : _space_to_batch_nd(),
'Split' : _split(False),
'SplitV' : _split(True),
'Sqrt' : AttrCvt('sqrt'),
'Square' : _square(),
'Squeeze' : _squeeze(),
'StridedSlice' : _stridedSlice(),
Expand Down
22 changes: 19 additions & 3 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,7 +1151,6 @@ def test_forward_placeholder():
graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, data, 'Placeholder:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'Placeholder')
print("tf_output is {}\ntvm_output is {}".format(tf_output, tvm_output))
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)

#######################################################################
Expand Down Expand Up @@ -1440,22 +1439,37 @@ def test_forward_pow_exp():
compare_tf_with_tvm([np_in1], ['in1:0'], 'exp:0')

def test_forward_log():
"""test Log """
"""test operator Log """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
tf.log(in_data, name="log")
compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0')

def test_forward_softplus():
"""test operator Softplus"""
np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
tf.nn.softplus(in_data, name="softplus")
compare_tf_with_tvm([np_data], ['in_data:0'], 'softplus:0')

def test_forward_rsqrt():
"""test Rsqrt """
np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
tf.rsqrt(in_data, name="rsqrt")
print(tf.get_default_graph().as_graph_def())
compare_tf_with_tvm([np_data], ['in_data:0'], 'rsqrt:0')

def test_forward_sqrt():
"""test Sqrt """
np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
tf.sqrt(in_data, name="sqrt")
compare_tf_with_tvm([np_data], ['in_data:0'], 'sqrt:0')

#######################################################################
# Mean
# ----
Expand Down Expand Up @@ -1561,6 +1575,8 @@ def test_forward_reduce_prod():
test_forward_pow_exp()
test_forward_sign()
test_forward_log()
test_forward_softplus()
test_forward_sqrt()
test_forward_rsqrt()
test_forward_expand_dims()

Expand Down

0 comments on commit 93c8017

Please sign in to comment.