Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[topi][relay] add operation tan to TVM #4938

Merged
merged 8 commits into from
Mar 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/frontend/tensorflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ Supported Ops
- ConcatV2
- Conv2D
- Cos
- Tan
- CropAndResize
- DecodeJpeg
- DepthwiseConv2dNative
Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(rsqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(popcount);
TVM_DECLARE_INTRIN_UNARY(tan);
TVM_DECLARE_INTRIN_UNARY(cos);
TVM_DECLARE_INTRIN_UNARY(sin);
TVM_DECLARE_INTRIN_UNARY(atan);
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,6 +1696,7 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
"ones_like",
"where",
"gather_nd",
"tan",
"cos",
"sin"
]
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1564,6 +1564,7 @@ def _impl(inputs, attr, params):
'LessEqual' : _broadcast('less_equal'),
'Log' : AttrCvt('log'),
'Log1p' : _log1p(),
'Tan' : AttrCvt('tan'),
'Cos' : AttrCvt('cos'),
'Sin' : AttrCvt('sin'),
'LogicalAnd' : _logical('logical_and'),
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(self, model, subgraph, exp_tab):
'LOG': self.convert_log,
'SIN': self.convert_sin,
'COS': self.convert_cos,
'TAN': self.convert_tan,
'SQRT': self.convert_sqrt,
'RSQRT': self.convert_rsqrt,
'NEG': self.convert_neg,
Expand Down Expand Up @@ -657,6 +658,13 @@ def convert_sin(self, op):
'TFlite quantized SIN operator is not supported yet.')
return self._convert_unary_elemwise(_op.sin, op)

def convert_tan(self, op):
"""Convert TFLite TAN"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized TAN operator is not supported yet.')
return self._convert_unary_elemwise(_op.tan, op)

def convert_cos(self, op):
"""Convert TFLite COS"""
if self.is_quantized(op):
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@


register_broadcast_schedule("log")
register_broadcast_schedule("tan")
kevinthesun marked this conversation as resolved.
Show resolved Hide resolved
register_broadcast_schedule("cos")
register_broadcast_schedule("sin")
register_broadcast_schedule("atan")
Expand Down Expand Up @@ -214,3 +215,4 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("sqrt", False, elemwise_shape_func)
register_shape_func("negative", False, elemwise_shape_func)
register_shape_func("exp", False, elemwise_shape_func)
register_shape_func("tan", False, elemwise_shape_func)
7 changes: 7 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def log_grad(orig, grad):
return [grad * ones_like(x) / x]


@register_gradient("tan")
def tan_grad(orig, grad):
"""Returns [grad / (cos^2(x))]"""
x = orig.args[0]
return [grad / (cos(x) * cos(x))]


@register_gradient("cos")
def cos_grad(orig, grad):
"""Returns [grad * (-sin(x))]"""
Expand Down
15 changes: 15 additions & 0 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,21 @@ def log(data):
"""
return _make.log(data)

def tan(data):
"""Compute elementwise tan of data.

Parameters
----------
data : relay.Expr
The input data

Returns
-------
result : relay.Expr
The computed result.
"""
return _make.tan(data)

def cos(data):
"""Compute elementwise cos of data.

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/te/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""
# expose all operators in tvm tir.op
from tvm.tir import any, all, min_value, max_value, trace
from tvm.tir import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil
from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
from tvm.tir import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from tvm.tir import comm_reducer, min, max, sum
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil
from .op import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,22 @@ def log(x):
"""
return call_pure_intrin(x.dtype, "log", x)

def tan(x):
"""Take tan of input x.

Parameters
----------
x : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "tan", x)


def cos(x):
"""Take cos of input x.

Expand Down
11 changes: 11 additions & 0 deletions src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ RELAY_REGISTER_UNARY_OP("log")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));


RELAY_REGISTER_UNARY_OP("tan")
.describe(R"code(Returns the tan of input array, computed element-wise.

.. math::
Y = tan(X)

)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tan));


RELAY_REGISTER_UNARY_OP("cos")
.describe(R"code(Returns the cos of input array, computed element-wise.

Expand Down
3 changes: 3 additions & 0 deletions src/target/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos")
.set_body(DispatchExtern<FloatSuffix>);

Expand Down
14 changes: 14 additions & 0 deletions src/target/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,20 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr);
const PrimExpr& x = call->args[0];
PrimExpr sin_x = tir::CallNode::make(
x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic);
PrimExpr cos_x = tir::CallNode::make(
x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic);
PrimExpr tan_x = sin_x / cos_x;
*rv = tan_x;
});

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);

Expand Down
3 changes: 3 additions & 0 deletions src/target/llvm/intrin_rule_nvptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos")
.set_body(DispatchExternLibDevice);

Expand Down
3 changes: 3 additions & 0 deletions src/target/llvm/intrin_rule_rocm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh")
.set_body(DispatchExternOCML);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan")
.set_body(DispatchExternOCML);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos")
.set_body(DispatchExternOCML);

Expand Down
19 changes: 19 additions & 0 deletions src/target/source/intrin_rule_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,22 @@ struct CUDAFastMath : public CUDAMath {
}
};

struct CUDAFastMathTan : public CUDAMath {
std::string operator()(DataType t, std::string name) const {
if (t.lanes() == 1 && t.is_float()) {
switch (t.bits()) {
case 64: return name;
// `__tanf` seems to produce some values too deviant from numpy tan version.
// So, let's use just `tanf` instead.
case 32: return name + 'f';
case 16: LOG(FATAL) << "cuda tan unsupported for float16";
default: return "";
}
}
return "";
}
};

struct CUDAPopcount {
std::string operator()(DataType t, std::string name) const {
if (t.lanes() == 1 && t.is_uint()) {
Expand Down Expand Up @@ -97,6 +113,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan")
.set_body(DispatchExtern<CUDAFastMathTan>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos")
.set_body(DispatchExtern<CUDAFastMath>);

Expand Down
2 changes: 1 addition & 1 deletion src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ PrimExpr LetNode::make(Var var, PrimExpr value, PrimExpr body) {

const char* CallNode::vectorizable_intrinsics[] = {
"floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt",
"log", "sin", "cos", "pow", tir::CallNode::shift_left, tir::CallNode::shift_right,
"log", "sin", "cos", "pow", "tan", tir::CallNode::shift_left, tir::CallNode::shift_right,
tir::CallNode::likely, tir::CallNode::popcount
};

Expand Down
10 changes: 10 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2628,6 +2628,15 @@ def test_forward_cos():
compare_tf_with_tvm([np_data], ['in_data:0'], 'cos:0')


def test_forward_tan():
"""test operator tan """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
tf.tan(in_data, name="tan")
compare_tf_with_tvm([np_data], ['in_data:0'], 'tan:0')


def test_forward_sin():
"""test operator sin """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
Expand Down Expand Up @@ -3031,6 +3040,7 @@ def test_forward_add_n():
test_forward_sign()
test_forward_log()
test_forward_log1p()
test_forward_tan()
test_forward_cos()
test_forward_sin()
test_forward_negative()
Expand Down
8 changes: 8 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,13 @@ def _test_cos(data):
""" One iteration of cos """
return _test_unary_elemwise(math_ops.cos, data)
#######################################################################
# Tan
# ---

def _test_tan(data):
""" One iteration of tan """
return _test_unary_elemwise(math_ops.tan, data)
#######################################################################
# Sqrt
# ----

Expand Down Expand Up @@ -772,6 +779,7 @@ def test_all_unary_elemwise():
if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
_test_forward_unary_elemwise(_test_ceil)
_test_forward_unary_elemwise(_test_cos)
_test_forward_unary_elemwise(_test_tan)

#######################################################################
# Element-wise
Expand Down
1 change: 1 addition & 0 deletions tests/python/relay/test_op_grad_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def check_single_op(opfunc, ref):
(relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x))),
(tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
(tvm.relay.sin, lambda x: np.cos(x)),
(tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)),
(tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0)))]:
check_single_op(opfunc, ref)

Expand Down
1 change: 1 addition & 0 deletions tests/python/relay/test_op_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def check_single_op(opfunc, ref, dtype):
(relay.nn.relu, relu),
(tvm.relay.cos, np.cos),
(tvm.relay.sin, np.sin),
(tvm.relay.tan, np.tan),
(tvm.relay.atan, np.arctan)]:
for dtype in ['float16', 'float32']:
check_single_op(opfunc, ref, dtype)
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_check_numerical_grads():
lambda x: (np.sign(np.sin(1/x)), np.zeros_like(x)),
lambda x: (x*np.sin(1/x), np.sin(1/x) - np.cos(1/x)/x),
lambda x: (np.sin(1/x), - np.cos(1/x)/(x*x)),
lambda x: (np.tan(x), 1.0 / (np.cos(x) * np.cos(x))),
]

# Avoid values too close to 0 since singularities of our functions are there
Expand Down
1 change: 1 addition & 0 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ TOPI_DECLARE_UNARY_OP(round);
TOPI_DECLARE_UNARY_OP(trunc);
TOPI_DECLARE_UNARY_OP(abs);
TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(tan);
TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan);
Expand Down
17 changes: 17 additions & 0 deletions topi/python/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,23 @@ def tanh(x):
return te.compute(x.shape, lambda *i: te.tanh(x(*i)))


@tvm.te.tag_scope(tag=tag.ELEMWISE)
def tan(x):
"""Take tan of input x.

Parameters
----------
x : tvm.te.Tensor
Input argument.

Returns
-------
y : tvm.te.Tensor
The result.
"""
return te.compute(x.shape, lambda *i: te.tan(x(*i)))


@tvm.te.tag_scope(tag=tag.ELEMWISE)
def cos(x):
"""Take cos of input x.
Expand Down
5 changes: 5 additions & 0 deletions topi/src/topi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ TVM_REGISTER_GLOBAL("topi.erf")
*rv = erf(args[0]);
});

TVM_REGISTER_GLOBAL("topi.tan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tan(args[0]);
});

TVM_REGISTER_GLOBAL("topi.cos")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = cos(args[0]);
Expand Down
1 change: 1 addition & 0 deletions topi/tests/python/test_topi_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_apply(func, name):
test_apply(topi.rsqrt, "rsqrt")
test_apply(topi.sin, "sin")
test_apply(topi.cos, "cos")
test_apply(topi.tan, "tan")
test_apply(topi.atan, "atan")


Expand Down
2 changes: 2 additions & 0 deletions topi/tests/python/test_topi_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ def check_device(device):
test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
test_apply(topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True)
test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi)
test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float32')
test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float64')
test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)
test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32")
test_isnan(-100, 100)
Expand Down