Skip to content

Commit

Permalink
[Relay][Frontend] Support tf.where
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Apr 1, 2019
1 parent eb1ed11 commit 1dc9e96
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
17 changes: 11 additions & 6 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,16 +674,16 @@ def _impl(inputs, attr, params):
return _impl

def _gather_v2():
"Tensorflow now support only gatherv2"
"Tensorflow now supports only GatherV2"
def _impl(inputs, attr, params):
axis = params[inputs.pop(2).name_hint].asnumpy()[0]
new_input = []
new_input.append(inputs.pop(0))
new_input.append(inputs.pop(0))
return AttrCvt(op_name="take",
extras={'axis': tvm.const(axis, 'int32')},
ignores=['Tindices', 'Tparams', 'validate_indices', \
'Taxis', '_class'])(new_input, attr)
return AttrCvt(op_name="take",
extras={'axis': tvm.const(axis, 'int32')},
ignores=['Tindices', 'Tparams', 'validate_indices', \
'Taxis', '_class'])(new_input, attr)
return _impl

def _infer_out_shapes(inputs, params):
Expand Down Expand Up @@ -815,7 +815,6 @@ def _impl(inputs, attr, params):
ignores=['Tpaddings'],)(new_inputs, attr)
return _impl


def _transpose():
def _impl(inputs, attr, params):
# If perm is not specified, axes is left empty,
Expand All @@ -828,6 +827,11 @@ def _impl(inputs, attr, params):
return _op.transpose(inputs[0], axes=axes)
return _impl

def _where():
def _impl(inputs, attr, params):
return AttrCvt(op_name="where")(inputs, attr)
return _impl

def _rank():
def _impl(inputs, attr, params):
input_shape = attr['_input_shapes'][inputs[0]]
Expand Down Expand Up @@ -1012,6 +1016,7 @@ def _impl(inputs, attr, params):
'DepthwiseConv2dNative' : _conv('depthwise'),
'Shape' : _shape(),
'Sigmoid' : AttrCvt('sigmoid'),
'Select' : _where(),
'Fill' : _fill(),
'GatherV2' : _gather_v2(),
'StridedSlice' : _stridedSlice(),
Expand Down
24 changes: 21 additions & 3 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,13 +477,13 @@ def test_forward_stridedslice():
# ------

def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
""" One iteration of a Gather """
""" One iteration of a GatherV2"""

tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
indices = tf.placeholder("int32", indice_shape, name="indices")
tf.gather(in_data, indices, axis=axis)
np_data = np.random.uniform(size=ip_shape).astype(dtype)
np_data = np.random.uniform(1, 10, size=ip_shape).astype(dtype)

def _fill_indices(indice_value):
indices = np.array(ip_shape, dtype=dtype)
Expand All @@ -497,7 +497,7 @@ def _fill_indices(indice_value):
compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'GatherV2:0')

def test_forward_gather():
'''test gather layer'''
'''test GatherV2 layer'''
_test_gather((4,), (1,), 1, 0, 'int32')
_test_gather((4,), (1,), 1, 0, 'float32')
_test_gather((1,4), (1,), [0], 0, 'int32')
Expand All @@ -509,6 +509,7 @@ def test_forward_gather():
_test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32')
_test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')


#######################################################################
# Split
# -----
Expand Down Expand Up @@ -825,6 +826,22 @@ def test_forward_logical():
test_logical_not()


#######################################################################
# Where, Select
# --------------------
def test_where():
''' Where: return elements depending on conditions'''
with tf.Graph().as_default():
with tf.Session() as sess:
input1 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name='input1')
input2 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name='input2')
mask = input1 > input2
tf.where(mask, input1 + 1, input2 * 2)
in_data1 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32")
in_data2 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32")
compare_tf_with_tvm([in_data1, in_data2], ['input1:0', 'input2:0'], 'Select:0')


#######################################################################
# Inception V3
# ------------
Expand Down Expand Up @@ -1260,3 +1277,4 @@ def test_forward_rel_ops():
# Relational ops
test_forward_rel_ops()
test_forward_logical()
test_where()

0 comments on commit 1dc9e96

Please sign in to comment.