Skip to content

Commit

Permalink
[Relay] Add TopK in tf converter
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww authored and icemelon committed Jun 3, 2019
1 parent c553dbb commit 8402367
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,15 @@ def _impl(inputs, attr, params):
return _get_relay_op('log')(add_out)
return _impl

def _topk():
def _impl(inputs, attr, params):
k = params.pop(inputs.pop(1).name_hint).asnumpy()
is_ascend = False if attr['sorted'] else None # TODO: arg value for original ordering
return AttrCvt(op_name='topk',
ignores=['sorted'],
extras={'k': int(k), 'is_ascend': is_ascend})(inputs, attr)
return _impl

def _logical(name):
def _impl(inputs, attr, params):
return AttrCvt(op_name=name)(inputs, attr)
Expand Down Expand Up @@ -1271,6 +1280,7 @@ def _impl(inputs, attr, params):
'Sum' : _sum(),
'Tanh' : AttrCvt('tanh'),
'Tile' : _tile(),
'TopKV2' : _topk(),
'Transpose' : _transpose(),
'Unpack' : _unpack(),

Expand Down
21 changes: 21 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,26 @@ def test_forward_split():
_test_split((3, 6, 4), -2, [1, 4, 1], 'float32')


######################################################################
# TopKV2
# ------

def _test_forward_top_k_v2(in_shape, k, sorted, dtype):
np_data = np.random.uniform(-100, 100, size=in_shape).astype(dtype)

tf.reset_default_graph()
in_data = tf.placeholder(dtype, in_shape, name="in_data")
tf.math.top_k(in_data, k, sorted, name='TopK')
compare_tf_with_tvm([np_data], ['in_data:0'], 'TopK:0')

def test_forward_top_k_v2():
#_test_forward_top_k_v2((3,), 0, True, 'int32')
_test_forward_top_k_v2((3,), 1, True, 'float32')
_test_forward_top_k_v2((3,), 3, True, 'float32')
_test_forward_top_k_v2((3, 5, 7), 3, True, 'float32')
#_test_forward_top_k_v2((3, 5, 13), 11, False, 'float32')


#######################################################################
# Unstack
# -------
Expand Down Expand Up @@ -1704,6 +1724,7 @@ def test_placeholder():
test_forward_split()
test_forward_unstack()
test_forward_tile()
test_forward_top_k_v2()

# Activations
test_forward_sigmoid()
Expand Down

0 comments on commit 8402367

Please sign in to comment.