From f01d7e9447bee7688f182921c1b5c30f0e732592 Mon Sep 17 00:00:00 2001 From: Alexey Romanov Date: Tue, 9 Apr 2019 17:42:23 +0300 Subject: [PATCH] Add tests for TF matmul op --- .../frontend/tensorflow/test_forward.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 83536fba754d0..a1cfff840d34f 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -567,6 +567,38 @@ def test_forward_variable(): _test_variable(np.random.uniform(size=(32, 100)).astype('float32')) +####################################################################### +# MatMul +# ------ + +def _test_matmul(i, j, k, dtype, outer=None): + """ One iteration of matmul """ + + A_shape_init = [i, j] + B_shape_init = [j, k] + + for transpose_a in [False, True]: + for transpose_b in [False, True]: + outer = outer or [] + A_shape = outer + (A_shape_init[::-1] if transpose_a else A_shape_init) + B_shape = outer + (B_shape_init[::-1] if transpose_b else B_shape_init) + + with tf.Graph().as_default(): + A = tf.placeholder(shape=A_shape, dtype=dtype, name='A') + B = tf.placeholder(shape=B_shape, dtype=dtype, name='B') + result = tf.matmul(A, B, transpose_a=transpose_a, transpose_b=transpose_b) + + A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype) + B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) + compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name) + +def test_forward_matmul(): + """ Matmul op test""" + _test_matmul(1, 3, 6, 'int32') + _test_matmul(5, 3, 1, 'float64') + # TODO non-empty outer requires BatchMatMul (BatchMatMulV2 for some cases?) support + + ####################################################################### # StridedSlice # ------------ @@ -1744,3 +1776,6 @@ def test_placeholder(): test_forward_rel_ops() test_forward_logical() test_where() + + test_forward_matmul() + # TODO missing tests: rank, range \ No newline at end of file