diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index b357a2fbff30..43e770c301d2 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -532,6 +532,18 @@ def _impl(inputs, attr, params): return _op.concatenate(inputs_reshaped, axis) return _impl +def _tile(): + def _impl(inputs, attr, params): + reps = params[inputs.pop().name_hint].asnumpy() + new_input = [] + new_input.append(inputs.pop(0)) + + return AttrCvt( + op_name='tile', + extras={'reps': tuple(reps)}, + ignores=['Tmultiples'])(new_input, attr) + return _impl + def _slice(): def _impl(inputs, attr, params): begin = params.pop(_get_name_hint(inputs[1])).asnumpy().tolist() @@ -851,6 +863,15 @@ def _impl(inputs, attr, params): return AttrCvt(op_name="where")(inputs, attr) return _impl +def _reverse_v2(): + def _impl(inputs, attr, params): + axis = params.pop(inputs[1].name_hint).asnumpy()[0] + return AttrCvt( + op_name="reverse", + ignores=['Tidx'], + extras={'axis': int(axis)})([inputs[0]], attr) + return _impl + def _rank(): def _impl(inputs, attr, params): input_shape = attr['_input_shapes'][inputs[0]] @@ -1078,6 +1099,7 @@ def _impl(inputs, attr, params): # for 1 to N mapping(composed), use custom callable functions # for N to 1 mapping, currently not supported(?) _convert_map = { + 'Add' : _elemwise('add'), 'ArgMax' : _argx(_op.argmax, 'argmax'), 'ArgMin' : _argx(_op.argmin, 'argmin'), 'AvgPool' : _pooling('avg_pool'), @@ -1090,60 +1112,65 @@ def _impl(inputs, attr, params): 'ConcatV2' : _concatV2(), 'Conv2D' : _conv('conv'), 'DecodeJpeg' : _decode_image(), + 'DepthwiseConv2dNative' : _conv('depthwise'), + 'Equal' : _broadcast('equal'), 'Elu' : _elu(), + 'Exp' : AttrCvt('exp'), 'ExpandDims' : _expand_dims(), + 'Fill' : _fill(), 'Floor' : AttrCvt('floor'), + 'FusedBatchNorm' : _fused_batch_norm(), + 'FusedBatchNormV2' : _fused_batch_norm(), + 'Gather' : _gather(), + 'GatherV2' : _gather(), + 'Greater' : _broadcast('greater'), + 'GreaterEqual' : _broadcast('greater_equal'), 'Identity' : _identity(), + 'LeakyRelu' : AttrCvt('leaky_relu'), + 'Less' : _broadcast('less'), + 'LessEqual' : _broadcast('less_equal'), + 'LogicalAnd' : _logical('logical_and'), + 'LogicalOr' : _logical('logical_or'), + 'LogicalNot' : _logical('logical_not'), + 'LRN' : _lrn(), 'MatMul' : _matmul(), 'MaxPool' : _pooling('max_pool'), - 'Add' : _elemwise('add'), - 'Sub' : _elemwise('subtract'), - 'Mul' : _elemwise('multiply'), - 'RealDiv' : _elemwise('div'), 'Maximum' : _elemwise('maximum'), + 'Mean' : _mean(), 'Minimum' : _elemwise('minimum'), - 'Sum' : _sum(), - 'Square' : _square(), + 'Mul' : _elemwise('multiply'), + 'NotEqual' : _broadcast('not_equal'), 'Pack' : _pack(), - 'Slice' : _slice(), - 'LeakyRelu' : AttrCvt('leaky_relu'), + 'Pad' : _pad('Pad'), + 'PadV2' : _pad('PadV2'), + 'Pow' : _elemwise('power'), + 'Range' : _range(), + 'Rank' : _rank(), + 'RealDiv' : _elemwise('div'), 'Relu' : AttrCvt('relu'), + 'Relu6' : _relu6(), 'Reshape' : _reshape(), 'ResizeBilinear' : _resize_bilinear(), - 'Selu' : _selu(), - 'Softmax' : _softmax(), + 'ReverseV2' : _reverse_v2(), + 'Round' : AttrCvt('round'), 'Rsqrt' : _rsqrt(), - 'Squeeze' : _squeeze(), - 'FusedBatchNorm' : _fused_batch_norm(), - 'FusedBatchNormV2' : _fused_batch_norm(), - 'Relu6' : _relu6(), - 'DepthwiseConv2dNative' : _conv('depthwise'), + 'Select' : _where(), + 'Selu' : _selu(), 'Shape' : _shape(), 'Sigmoid' : AttrCvt('sigmoid'), - 'Select' : _where(), - 'Fill' : _fill(), - 'GatherV2' : _gather(), - 'Gather' : _gather(), - 'StridedSlice' : _stridedSlice(), - 'LRN' : _lrn(), - 'Pad' : _pad('Pad'), - 'PadV2' : _pad('PadV2'), - 'Range' : _range(), - 'Rank' : _rank(), - 'Transpose' : _transpose(), - 'Tanh' : AttrCvt('tanh'), - 'Mean' : _mean(), - 'LogicalAnd' : _logical('logical_and'), - 'LogicalOr' : _logical('logical_or'), - 'LogicalNot' : _logical('logical_not'), - 'Less' : _broadcast('less'), - 'Greater' : _broadcast('greater'), - 'LessEqual' : _broadcast('less_equal'), - 'GreaterEqual' : _broadcast('greater_equal'), - 'Equal' : _broadcast('equal'), - 'NotEqual' : _broadcast('not_equal'), + 'Sign' : AttrCvt('sign'), + 'Slice' : _slice(), + 'Softmax' : _softmax(), 'Split' : _split(False), 'SplitV' : _split(True), + 'Square' : _square(), + 'Squeeze' : _squeeze(), + 'StridedSlice' : _stridedSlice(), + 'Sub' : _elemwise('subtract'), + 'Sum' : _sum(), + 'Tanh' : AttrCvt('tanh'), + 'Tile' : _tile(), + 'Transpose' : _transpose(), 'Unpack' : _unpack(), 'SpaceToBatchND' : _space_to_batch_nd(), 'BatchToSpaceND' : _batch_to_space_nd(), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 7e7c1510c60b..6894c5d46210 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -762,6 +762,24 @@ def test_forward_unstack(): _test_unstack((3, 6, 4), -2, 'float32') +####################################################################### +# Tile +# ---- + +def _test_tile(in_shape, multiples, dtype): + np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype) + tf.reset_default_graph() + in_data = tf.placeholder(dtype, in_shape, name="in_data") + tf.tile(in_data, multiples=multiples, name="tile") + compare_tf_with_tvm([np_data], ['in_data:0'], 'tile:0') + +def test_forward_tile(): + '''test Tile''' + _test_tile((2, ), (3, ), "int32") + _test_tile((2, 2), (2, 3), "float32") + _test_tile((2, 4, 6), (6, 7, 8), "float64") + + ####################################################################### # Multi Input to graph # -------------------- @@ -1353,6 +1371,53 @@ def test_forward_tanh(): tf.nn.tanh(in1) compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Tanh:0') +####################################################################### +# Tensor +# ------ + +def test_forward_round(): + """test Round""" + np_data = np.random.uniform(-10, 10, size=(5, 7)).astype(np.float32) + tf.reset_default_graph() + in_data = tf.placeholder(tf.float32, (5, 7), name="in_data") + tf.round(in_data, name="round") + compare_tf_with_tvm([np_data], ['in_data:0'], 'round:0') + +def _test_forward_reverse_v2(in_shape, axis, dtype): + np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype) + tf.reset_default_graph() + in_data = tf.placeholder(dtype, in_shape, name="in_data") + tf.reverse(in_data, axis=[axis], name="reverse") + compare_tf_with_tvm([np_data], ['in_data:0'], 'reverse:0') + +def test_forward_reverse_v2(): + """test ReverseV2""" + _test_forward_reverse_v2((2, 3), 0, "int32") + _test_forward_reverse_v2((2, 3, 5), 2, "float32") + _test_forward_reverse_v2((2, 3, 5, 7), 1, "float32") + _test_forward_reverse_v2((2, 3, 5), -1, "float64") + _test_forward_reverse_v2((2, 3, 5), -3, "float64") + +def test_forward_sign(): + """test Sign""" + np_data = np.random.uniform(-10, 10, size=(5, 7, 11)).astype(np.float32) + tf.reset_default_graph() + in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data") + tf.sign(in_data, name="sign") + compare_tf_with_tvm([np_data], ['in_data:0'], 'sign:0') + +def test_forward_pow_exp(): + """test Pow""" + np_in1 = np.random.uniform(-10, 10, size=(5, 7, 11)).astype(np.float32) + np_in2 = np.random.uniform(-10, 10, size=(5, 7, 11)).astype(np.float32) + tf.reset_default_graph() + in1 = tf.placeholder(tf.float32, (5, 7, 11), name="in1") + in2 = tf.placeholder(tf.float32, (5, 7, 11), name="in2") + out1 = tf.pow(in1, in2, name="pow") + out = tf.exp(out1, name='exp') + compare_tf_with_tvm([np_in1, np_in2], ['in1:0', 'in2:0'], 'pow:0') + compare_tf_with_tvm([np_in1, np_in2], ['in1:0', 'in2:0'], 'exp:0') + ####################################################################### # Mean # ---- @@ -1394,6 +1459,7 @@ def test_forward_rel_ops(): # Main # ---- if __name__ == '__main__': + # Transforms test_forward_transpose() test_forward_reshape() @@ -1407,6 +1473,7 @@ def test_forward_rel_ops(): test_forward_stridedslice() test_forward_split() test_forward_unstack() + test_forward_tile() # Activations test_forward_sigmoid() @@ -1416,6 +1483,12 @@ def test_forward_rel_ops(): test_forward_selu() test_forward_tanh() + # Tensor + test_forward_round() + test_forward_reverse_v2() + test_forward_pow_exp() + test_forward_sign() + # Reductions test_forward_argminmax() test_forward_reduce()