Skip to content

Commit

Permalink
[Relay][TensorFlow] Support tf.math.reduce_prod
Browse files Browse the repository at this point in the history
  • Loading branch information
Li Xiaoquan committed May 10, 2019
1 parent 95a323a commit 617e50d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,15 @@ def _impl(inputs, attr, params):

return _impl


def _prod():
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy()[0]
keepdims = attr['keep_dims']
return _op.prod(inputs[0], int(axis), keepdims=keepdims)
return _impl


# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -1136,6 +1145,7 @@ def _impl(inputs, attr, params):
'Pad' : _pad('Pad'),
'PadV2' : _pad('PadV2'),
'Pow' : _elemwise('power'),
'Prod' : _prod(),
'Range' : _range(),
'Rank' : _rank(),
'RealDiv' : _elemwise('div'),
Expand Down
20 changes: 19 additions & 1 deletion tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ def is_gpu_available():
else:
return False


#######################################################################
# Pooling
# -------
Expand Down Expand Up @@ -1509,6 +1508,24 @@ def test_forward_expand_dims():
_test_forward_expand_dims(np.array([[1], [2]]), 1)
_test_forward_expand_dims(np.array([[1], [2]]), -1)

#######################################################################
# Prod
# ----
def _test_forward_reduce_prod(shape, axis, keepdims):
inp_array1 = np.random.uniform(-5, 5, size=shape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype)
out = tf.math.reduce_prod(in1, axis, keepdims)
compare_tf_with_tvm(inp_array1, in1.name, out.name)

def test_forward_reduce_prod():
_test_forward_reduce_prod((5,), 0, False)
_test_forward_reduce_prod((5, 5), 0, False)
_test_forward_reduce_prod((5, 5), 1, False)
_test_forward_reduce_prod((5,), 0, True)
_test_forward_reduce_prod((5, 5), 0, True)
_test_forward_reduce_prod((5, 5), 1, True)

#######################################################################
# Main
# ----
Expand Down Expand Up @@ -1550,6 +1567,7 @@ def test_forward_expand_dims():
test_forward_argminmax()
test_forward_reduce()
test_forward_mean()
test_forward_reduce_prod()

# General
test_forward_multi_input()
Expand Down

0 comments on commit 617e50d

Please sign in to comment.