diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 0d76cf2afbdee..ae564c8917569 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -269,6 +269,24 @@ def test_forward_slice(): _test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1]) _test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1]) +####################################################################### +# Topk +# ---- +def _test_topk(in_shape, k=1): + """ One iteration of TOPK """ + data = np.random.uniform(size=in_shape).astype('float32') + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + out = nn_ops.top_k(in_data, k, name='TopK') + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out[0]]) + +def test_forward_topk(): + """ TOPK """ + _test_topk((3,), 1) + _test_topk((3,), 3) + _test_topk((3, 5, 7), 3) + _test_topk((3, 5, 7), 3) + ####################################################################### # transpose # --------- @@ -1738,6 +1756,7 @@ def test_forward_mediapipe_hand_landmark(): test_all_resize() test_forward_squeeze() test_forward_slice() + test_forward_topk() # NN test_forward_convolution()