Skip to content

Commit

Permalink
Add tests for TF matmul op
Browse files Browse the repository at this point in the history
  • Loading branch information
alexeyr committed May 17, 2019
1 parent ba748ba commit f01d7e9
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,38 @@ def test_forward_variable():
_test_variable(np.random.uniform(size=(32, 100)).astype('float32'))


#######################################################################
# MatMul
# ------

def _test_matmul(i, j, k, dtype, outer=None):
""" One iteration of matmul """

A_shape_init = [i, j]
B_shape_init = [j, k]

for transpose_a in [False, True]:
for transpose_b in [False, True]:
outer = outer or []
A_shape = outer + (A_shape_init[::-1] if transpose_a else A_shape_init)
B_shape = outer + (B_shape_init[::-1] if transpose_b else B_shape_init)

with tf.Graph().as_default():
A = tf.placeholder(shape=A_shape, dtype=dtype, name='A')
B = tf.placeholder(shape=B_shape, dtype=dtype, name='B')
result = tf.matmul(A, B, transpose_a=transpose_a, transpose_b=transpose_b)

A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name)

def test_forward_matmul():
""" Matmul op test"""
_test_matmul(1, 3, 6, 'int32')
_test_matmul(5, 3, 1, 'float64')
# TODO non-empty outer requires BatchMatMul (BatchMatMulV2 for some cases?) support


#######################################################################
# StridedSlice
# ------------
Expand Down Expand Up @@ -1744,3 +1776,6 @@ def test_placeholder():
test_forward_rel_ops()
test_forward_logical()
test_where()

test_forward_matmul()
# TODO missing tests: rank, range

0 comments on commit f01d7e9

Please sign in to comment.