diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 7fe82ea7eac17..87cb91523a3f3 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1082,6 +1082,15 @@ def _impl(inputs, attr, params): return _get_relay_op('log')(add_out) return _impl +def _topk(): + def _impl(inputs, attr, params): + k = params.pop(inputs.pop(1).name_hint).asnumpy() + is_ascend = False if attr['sorted'] else None # TODO: arg value for original ordering + return AttrCvt(op_name='topk', + ignores=['sorted'], + extras={'k': int(k), 'is_ascend': is_ascend})(inputs, attr) + return _impl + def _logical(name): def _impl(inputs, attr, params): return AttrCvt(op_name=name)(inputs, attr) @@ -1271,6 +1280,7 @@ def _impl(inputs, attr, params): 'Sum' : _sum(), 'Tanh' : AttrCvt('tanh'), 'Tile' : _tile(), + 'TopKV2' : _topk(), 'Transpose' : _transpose(), 'Unpack' : _unpack(), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 023cdf5eb2615..59428d34f5cd2 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -754,6 +754,26 @@ def test_forward_split(): _test_split((3, 6, 4), -2, [1, 4, 1], 'float32') +###################################################################### +# TopKV2 +# ------ + +def _test_forward_top_k_v2(in_shape, k, sorted, dtype): + np_data = np.random.uniform(-100, 100, size=in_shape).astype(dtype) + + tf.reset_default_graph() + in_data = tf.placeholder(dtype, in_shape, name="in_data") + tf.math.top_k(in_data, k, sorted, name='TopK') + compare_tf_with_tvm([np_data], ['in_data:0'], 'TopK:0') + +def test_forward_top_k_v2(): + #_test_forward_top_k_v2((3,), 0, True, 'int32') + _test_forward_top_k_v2((3,), 1, True, 'float32') + _test_forward_top_k_v2((3,), 3, True, 'float32') + _test_forward_top_k_v2((3, 5, 7), 3, True, 'float32') + #_test_forward_top_k_v2((3, 5, 13), 11, False, 'float32') + + ####################################################################### # Unstack # ------- @@ -1704,6 +1724,7 @@ def test_placeholder(): test_forward_split() test_forward_unstack() test_forward_tile() + test_forward_top_k_v2() # Activations test_forward_sigmoid()