Skip to content

Commit

Permalink
add test for transpose oprator
Browse files Browse the repository at this point in the history
  • Loading branch information
cchung100m committed Aug 6, 2019
1 parent 01b1864 commit 1147371
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,31 @@ def test_forward_split():
_test_split((1, 3, 6, 5), -2, 3, 'float32')
_test_split((1, 3, 5, 6), -1, 3, 'float32')

#######################################################################
# transpose
# ---------


def _test_forward_transpose(ishape, axes=None):
data = np.random.uniform(size=ishape).astype(np.float32)

with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)

if axes is None:
out = array_ops.transpose(in_data)
else:
out = array_ops.transpose(in_data, perm=axes)

compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])


def test_forward_transpose():
_test_forward_transpose((2, 2))
_test_forward_transpose((2, 3, 4))
_test_forward_transpose((7, 8, 8, 10))


#######################################################################
# Pooling
# -------
Expand Down Expand Up @@ -823,6 +848,10 @@ def test_forward_ssd_mobilenet_v1():
if __name__ == '__main__':
# Split
test_forward_split()

# Transpose
test_forward_transpose()

# Transforms
test_forward_concatenation()
test_forward_pad()
Expand Down

0 comments on commit 1147371

Please sign in to comment.