Skip to content

Commit

Permalink
[Relay] add ClipByValue and Neg in tf frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed May 21, 2019
1 parent 78a0f47 commit 924245a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,13 @@ def _impl(inputs, attr, params):
return AttrCvt(op_name="where")(inputs, attr)
return _impl

def _clip_by_value():
def _impl(inputs, attr, params):
a_min = params.pop(inputs[1].name_hint).asnumpy()[0]
a_max = params.pop(inputs[2].name_hint).asnumpy()[0]
return _op.clip(inputs[0], a_min=a_min, a_max=a_max)
return _impl

def _reverse_v2():
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy()[0]
Expand Down Expand Up @@ -1190,6 +1197,7 @@ def _impl(inputs, attr, params):
'Cast' : _cast(),
'Ceil' : AttrCvt('ceil'),
'CheckNumerics' : _check_numerics(),
'ClipByValue' : _clip_by_value(),
'Concat' : _concat(),
'ConcatV2' : _concatV2(),
'Conv2D' : _conv('conv'),
Expand Down Expand Up @@ -1223,6 +1231,7 @@ def _impl(inputs, attr, params):
'Mean' : _mean(),
'Minimum' : _elemwise('minimum'),
'Mul' : _elemwise('multiply'),
'Neg' : AttrCvt('negative'),
'NotEqual' : _broadcast('not_equal'),
'Pack' : _pack(),
'Pad' : _pad('Pad'),
Expand Down
9 changes: 9 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,6 +1558,14 @@ def test_forward_log():
tf.log(in_data, name="log")
compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0')

def test_forward_negative():
"""test operator Neg """
np_data = np.random.uniform(-100, 100, size=(224, 224, 3)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (224, 224, 3), name="in_data")
tf.negative(in_data, name="negative")
compare_tf_with_tvm([np_data], ['in_data:0'], 'negative:0')

def test_forward_softplus():
"""test operator Softplus"""
np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32)
Expand Down Expand Up @@ -1708,6 +1716,7 @@ def test_placeholder():
test_forward_pow_exp()
test_forward_sign()
test_forward_log()
test_forward_negative()
test_forward_softplus()
test_forward_sqrt()
test_forward_rsqrt()
Expand Down

0 comments on commit 924245a

Please sign in to comment.