From 0b58be5ddcc6b1da3ba234d7abf93ba0369e3fde Mon Sep 17 00:00:00 2001 From: Alexey Romanov Date: Tue, 9 Apr 2019 17:42:23 +0300 Subject: [PATCH] Add tests for TF matmul op --- .../frontend/tensorflow/test_forward.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 92edac732d935..466fe49c6e750 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -532,6 +532,39 @@ def test_forward_variable(): _test_variable(np.random.uniform(size=(32, 100)).astype('float32')) +####################################################################### +# MatMul +# ------ + +# TODO add tests for more than 2 dimensions +def _test_matmul(i, j, k, dtype, outer=None): + """ One iteration of matmul """ + + A_shape_init = [i, j] + B_shape_init = [j, k] + + for transpose_a in [False, True]: + for transpose_b in [False, True]: + outer = outer or [] + A_shape = outer + (A_shape_init[::-1] if transpose_a else A_shape_init) + B_shape = outer + (B_shape_init[::-1] if transpose_b else B_shape_init) + + with tf.Graph().as_default(): + A = tf.placeholder(shape=A_shape, dtype=dtype, name='A') + B = tf.placeholder(shape=B_shape, dtype=dtype, name='B') + result = tf.matmul(A, B, transpose_a=transpose_a, transpose_b=transpose_b) + + A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype) + B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) + compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name) + +def test_forward_matmul(): + """ Matmul op test""" + _test_matmul(1, 3, 6, 'int32') + _test_matmul(5, 3, 1, 'float64') + # TODO non-empty outer requires BatchMatMul (and BatchMatMulV2?) support + + ####################################################################### # StridedSlice # ------------ @@ -1428,3 +1461,6 @@ def test_forward_rel_ops(): test_forward_rel_ops() test_forward_logical() test_where() + + test_forward_matmul() + # TODO missing tests: rank, range \ No newline at end of file