diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 5ac0de4335f7f..449c539dd5f83 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -674,16 +674,16 @@ def _impl(inputs, attr, params): return _impl def _gather_v2(): - "Tensorflow now support only gatherv2" + "Tensorflow now supports only GatherV2" def _impl(inputs, attr, params): axis = params[inputs.pop(2).name_hint].asnumpy()[0] new_input = [] new_input.append(inputs.pop(0)) new_input.append(inputs.pop(0)) - return AttrCvt(op_name="take", - extras={'axis': tvm.const(axis, 'int32')}, - ignores=['Tindices', 'Tparams', 'validate_indices', \ - 'Taxis', '_class'])(new_input, attr) + return AttrCvt(op_name="take", + extras={'axis': tvm.const(axis, 'int32')}, + ignores=['Tindices', 'Tparams', 'validate_indices', \ + 'Taxis', '_class'])(new_input, attr) return _impl def _infer_out_shapes(inputs, params): @@ -815,7 +815,6 @@ def _impl(inputs, attr, params): ignores=['Tpaddings'],)(new_inputs, attr) return _impl - def _transpose(): def _impl(inputs, attr, params): # If perm is not specified, axes is left empty, @@ -828,6 +827,11 @@ def _impl(inputs, attr, params): return _op.transpose(inputs[0], axes=axes) return _impl +def _where(): + def _impl(inputs, attr, params): + return AttrCvt(op_name="where")(inputs, attr) + return _impl + def _rank(): def _impl(inputs, attr, params): input_shape = attr['_input_shapes'][inputs[0]] @@ -1012,6 +1016,7 @@ def _impl(inputs, attr, params): 'DepthwiseConv2dNative' : _conv('depthwise'), 'Shape' : _shape(), 'Sigmoid' : AttrCvt('sigmoid'), + 'Select' : _where(), 'Fill' : _fill(), 'GatherV2' : _gather_v2(), 'StridedSlice' : _stridedSlice(), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 10368ea3d9aba..573a3c82ea0fa 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -477,13 +477,13 @@ def test_forward_stridedslice(): # ------ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype): - """ One iteration of a Gather """ + """ One iteration of a GatherV2""" tf.reset_default_graph() in_data = tf.placeholder(dtype, ip_shape, name="in_data") indices = tf.placeholder("int32", indice_shape, name="indices") tf.gather(in_data, indices, axis=axis) - np_data = np.random.uniform(size=ip_shape).astype(dtype) + np_data = np.random.uniform(1, 10, size=ip_shape).astype(dtype) def _fill_indices(indice_value): indices = np.array(ip_shape, dtype=dtype) @@ -497,7 +497,7 @@ def _fill_indices(indice_value): compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'GatherV2:0') def test_forward_gather(): - '''test gather layer''' + '''test GatherV2 layer''' _test_gather((4,), (1,), 1, 0, 'int32') _test_gather((4,), (1,), 1, 0, 'float32') _test_gather((1,4), (1,), [0], 0, 'int32') @@ -509,6 +509,7 @@ def test_forward_gather(): _test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32') _test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32') + ####################################################################### # Split # ----- @@ -825,6 +826,22 @@ def test_forward_logical(): test_logical_not() +####################################################################### +# Where, Select +# -------------------- +def test_where(): + ''' Where: return elements depending on conditions''' + with tf.Graph().as_default(): + with tf.Session() as sess: + input1 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name='input1') + input2 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name='input2') + mask = input1 > input2 + tf.where(mask, input1 + 1, input2 * 2) + in_data1 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32") + in_data2 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32") + compare_tf_with_tvm([in_data1, in_data2], ['input1:0', 'input2:0'], 'Select:0') + + ####################################################################### # Inception V3 # ------------ @@ -1260,3 +1277,4 @@ def test_forward_rel_ops(): # Relational ops test_forward_rel_ops() test_forward_logical() + test_where()