Skip to content

Commit

Permalink
[Relay][Frontend] TF Tile Round Sign Pow Exp Reverse (#2960)
Browse files Browse the repository at this point in the history
* [Relay][Frontend] TF Round Sign Pow Exp Reverse

* fix ci

* fix comments
  • Loading branch information
yongwww authored and srkreddy1238 committed Apr 19, 2019
1 parent 50484b3 commit 85a3ea0
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 37 deletions.
101 changes: 64 additions & 37 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,18 @@ def _impl(inputs, attr, params):
return _op.concatenate(inputs_reshaped, axis)
return _impl

def _tile():
def _impl(inputs, attr, params):
reps = params[inputs.pop().name_hint].asnumpy()
new_input = []
new_input.append(inputs.pop(0))

return AttrCvt(
op_name='tile',
extras={'reps': tuple(reps)},
ignores=['Tmultiples'])(new_input, attr)
return _impl

def _slice():
def _impl(inputs, attr, params):
begin = params.pop(_get_name_hint(inputs[1])).asnumpy().tolist()
Expand Down Expand Up @@ -851,6 +863,15 @@ def _impl(inputs, attr, params):
return AttrCvt(op_name="where")(inputs, attr)
return _impl

def _reverse_v2():
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy()[0]
return AttrCvt(
op_name="reverse",
ignores=['Tidx'],
extras={'axis': int(axis)})([inputs[0]], attr)
return _impl

def _rank():
def _impl(inputs, attr, params):
input_shape = attr['_input_shapes'][inputs[0]]
Expand Down Expand Up @@ -1078,6 +1099,7 @@ def _impl(inputs, attr, params):
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
_convert_map = {
'Add' : _elemwise('add'),
'ArgMax' : _argx(_op.argmax, 'argmax'),
'ArgMin' : _argx(_op.argmin, 'argmin'),
'AvgPool' : _pooling('avg_pool'),
Expand All @@ -1090,60 +1112,65 @@ def _impl(inputs, attr, params):
'ConcatV2' : _concatV2(),
'Conv2D' : _conv('conv'),
'DecodeJpeg' : _decode_image(),
'DepthwiseConv2dNative' : _conv('depthwise'),
'Equal' : _broadcast('equal'),
'Elu' : _elu(),
'Exp' : AttrCvt('exp'),
'ExpandDims' : _expand_dims(),
'Fill' : _fill(),
'Floor' : AttrCvt('floor'),
'FusedBatchNorm' : _fused_batch_norm(),
'FusedBatchNormV2' : _fused_batch_norm(),
'Gather' : _gather(),
'GatherV2' : _gather(),
'Greater' : _broadcast('greater'),
'GreaterEqual' : _broadcast('greater_equal'),
'Identity' : _identity(),
'LeakyRelu' : AttrCvt('leaky_relu'),
'Less' : _broadcast('less'),
'LessEqual' : _broadcast('less_equal'),
'LogicalAnd' : _logical('logical_and'),
'LogicalOr' : _logical('logical_or'),
'LogicalNot' : _logical('logical_not'),
'LRN' : _lrn(),
'MatMul' : _matmul(),
'MaxPool' : _pooling('max_pool'),
'Add' : _elemwise('add'),
'Sub' : _elemwise('subtract'),
'Mul' : _elemwise('multiply'),
'RealDiv' : _elemwise('div'),
'Maximum' : _elemwise('maximum'),
'Mean' : _mean(),
'Minimum' : _elemwise('minimum'),
'Sum' : _sum(),
'Square' : _square(),
'Mul' : _elemwise('multiply'),
'NotEqual' : _broadcast('not_equal'),
'Pack' : _pack(),
'Slice' : _slice(),
'LeakyRelu' : AttrCvt('leaky_relu'),
'Pad' : _pad('Pad'),
'PadV2' : _pad('PadV2'),
'Pow' : _elemwise('power'),
'Range' : _range(),
'Rank' : _rank(),
'RealDiv' : _elemwise('div'),
'Relu' : AttrCvt('relu'),
'Relu6' : _relu6(),
'Reshape' : _reshape(),
'ResizeBilinear' : _resize_bilinear(),
'Selu' : _selu(),
'Softmax' : _softmax(),
'ReverseV2' : _reverse_v2(),
'Round' : AttrCvt('round'),
'Rsqrt' : _rsqrt(),
'Squeeze' : _squeeze(),
'FusedBatchNorm' : _fused_batch_norm(),
'FusedBatchNormV2' : _fused_batch_norm(),
'Relu6' : _relu6(),
'DepthwiseConv2dNative' : _conv('depthwise'),
'Select' : _where(),
'Selu' : _selu(),
'Shape' : _shape(),
'Sigmoid' : AttrCvt('sigmoid'),
'Select' : _where(),
'Fill' : _fill(),
'GatherV2' : _gather(),
'Gather' : _gather(),
'StridedSlice' : _stridedSlice(),
'LRN' : _lrn(),
'Pad' : _pad('Pad'),
'PadV2' : _pad('PadV2'),
'Range' : _range(),
'Rank' : _rank(),
'Transpose' : _transpose(),
'Tanh' : AttrCvt('tanh'),
'Mean' : _mean(),
'LogicalAnd' : _logical('logical_and'),
'LogicalOr' : _logical('logical_or'),
'LogicalNot' : _logical('logical_not'),
'Less' : _broadcast('less'),
'Greater' : _broadcast('greater'),
'LessEqual' : _broadcast('less_equal'),
'GreaterEqual' : _broadcast('greater_equal'),
'Equal' : _broadcast('equal'),
'NotEqual' : _broadcast('not_equal'),
'Sign' : AttrCvt('sign'),
'Slice' : _slice(),
'Softmax' : _softmax(),
'Split' : _split(False),
'SplitV' : _split(True),
'Square' : _square(),
'Squeeze' : _squeeze(),
'StridedSlice' : _stridedSlice(),
'Sub' : _elemwise('subtract'),
'Sum' : _sum(),
'Tanh' : AttrCvt('tanh'),
'Tile' : _tile(),
'Transpose' : _transpose(),
'Unpack' : _unpack(),
'SpaceToBatchND' : _space_to_batch_nd(),
'BatchToSpaceND' : _batch_to_space_nd(),
Expand Down
73 changes: 73 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,24 @@ def test_forward_unstack():
_test_unstack((3, 6, 4), -2, 'float32')


#######################################################################
# Tile
# ----

def _test_tile(in_shape, multiples, dtype):
np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
tf.reset_default_graph()
in_data = tf.placeholder(dtype, in_shape, name="in_data")
tf.tile(in_data, multiples=multiples, name="tile")
compare_tf_with_tvm([np_data], ['in_data:0'], 'tile:0')

def test_forward_tile():
'''test Tile'''
_test_tile((2, ), (3, ), "int32")
_test_tile((2, 2), (2, 3), "float32")
_test_tile((2, 4, 6), (6, 7, 8), "float64")


#######################################################################
# Multi Input to graph
# --------------------
Expand Down Expand Up @@ -1353,6 +1371,53 @@ def test_forward_tanh():
tf.nn.tanh(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Tanh:0')

#######################################################################
# Tensor
# ------

def test_forward_round():
"""test Round"""
np_data = np.random.uniform(-10, 10, size=(5, 7)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (5, 7), name="in_data")
tf.round(in_data, name="round")
compare_tf_with_tvm([np_data], ['in_data:0'], 'round:0')

def _test_forward_reverse_v2(in_shape, axis, dtype):
np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype)
tf.reset_default_graph()
in_data = tf.placeholder(dtype, in_shape, name="in_data")
tf.reverse(in_data, axis=[axis], name="reverse")
compare_tf_with_tvm([np_data], ['in_data:0'], 'reverse:0')

def test_forward_reverse_v2():
"""test ReverseV2"""
_test_forward_reverse_v2((2, 3), 0, "int32")
_test_forward_reverse_v2((2, 3, 5), 2, "float32")
_test_forward_reverse_v2((2, 3, 5, 7), 1, "float32")
_test_forward_reverse_v2((2, 3, 5), -1, "float64")
_test_forward_reverse_v2((2, 3, 5), -3, "float64")

def test_forward_sign():
"""test Sign"""
np_data = np.random.uniform(-10, 10, size=(5, 7, 11)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
tf.sign(in_data, name="sign")
compare_tf_with_tvm([np_data], ['in_data:0'], 'sign:0')

def test_forward_pow_exp():
"""test Pow"""
np_in1 = np.random.uniform(-10, 10, size=(5, 7, 11)).astype(np.float32)
np_in2 = np.random.uniform(-10, 10, size=(5, 7, 11)).astype(np.float32)
tf.reset_default_graph()
in1 = tf.placeholder(tf.float32, (5, 7, 11), name="in1")
in2 = tf.placeholder(tf.float32, (5, 7, 11), name="in2")
out1 = tf.pow(in1, in2, name="pow")
out = tf.exp(out1, name='exp')
compare_tf_with_tvm([np_in1, np_in2], ['in1:0', 'in2:0'], 'pow:0')
compare_tf_with_tvm([np_in1, np_in2], ['in1:0', 'in2:0'], 'exp:0')

#######################################################################
# Mean
# ----
Expand Down Expand Up @@ -1394,6 +1459,7 @@ def test_forward_rel_ops():
# Main
# ----
if __name__ == '__main__':

# Transforms
test_forward_transpose()
test_forward_reshape()
Expand All @@ -1407,6 +1473,7 @@ def test_forward_rel_ops():
test_forward_stridedslice()
test_forward_split()
test_forward_unstack()
test_forward_tile()

# Activations
test_forward_sigmoid()
Expand All @@ -1416,6 +1483,12 @@ def test_forward_rel_ops():
test_forward_selu()
test_forward_tanh()

# Tensor
test_forward_round()
test_forward_reverse_v2()
test_forward_pow_exp()
test_forward_sign()

# Reductions
test_forward_argminmax()
test_forward_reduce()
Expand Down

0 comments on commit 85a3ea0

Please sign in to comment.