From 1147371a133d356bdf1648c50bd9a7c578e8092a Mon Sep 17 00:00:00 2001 From: cchung100m Date: Wed, 7 Aug 2019 00:25:32 +0800 Subject: [PATCH] add test for transpose oprator --- tests/python/frontend/tflite/test_forward.py | 29 ++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 2c356d8c9156e..3e9ac1a8fd29b 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -201,6 +201,31 @@ def test_forward_split(): _test_split((1, 3, 6, 5), -2, 3, 'float32') _test_split((1, 3, 5, 6), -1, 3, 'float32') +####################################################################### +# transpose +# --------- + + +def _test_forward_transpose(ishape, axes=None): + data = np.random.uniform(size=ishape).astype(np.float32) + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + + if axes is None: + out = array_ops.transpose(in_data) + else: + out = array_ops.transpose(in_data, perm=axes) + + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + + +def test_forward_transpose(): + _test_forward_transpose((2, 2)) + _test_forward_transpose((2, 3, 4)) + _test_forward_transpose((7, 8, 8, 10)) + + ####################################################################### # Pooling # ------- @@ -823,6 +848,10 @@ def test_forward_ssd_mobilenet_v1(): if __name__ == '__main__': # Split test_forward_split() + + # Transpose + test_forward_transpose() + # Transforms test_forward_concatenation() test_forward_pad()