diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 0a714d8fc8c8..a30c3c989322 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -508,16 +508,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); } \ TVM_DECLARE_INTRIN_UNARY(exp); +TVM_DECLARE_INTRIN_UNARY(exp2); +TVM_DECLARE_INTRIN_UNARY(exp10); TVM_DECLARE_INTRIN_UNARY(erf); TVM_DECLARE_INTRIN_UNARY(tanh); TVM_DECLARE_INTRIN_UNARY(sigmoid); TVM_DECLARE_INTRIN_UNARY(sqrt); TVM_DECLARE_INTRIN_UNARY(rsqrt); TVM_DECLARE_INTRIN_UNARY(log); +TVM_DECLARE_INTRIN_UNARY(log2); +TVM_DECLARE_INTRIN_UNARY(log10); TVM_DECLARE_INTRIN_UNARY(popcount); TVM_DECLARE_INTRIN_UNARY(tan); TVM_DECLARE_INTRIN_UNARY(cos); +TVM_DECLARE_INTRIN_UNARY(cosh); TVM_DECLARE_INTRIN_UNARY(sin); +TVM_DECLARE_INTRIN_UNARY(sinh); TVM_DECLARE_INTRIN_UNARY(atan); namespace tir { diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index aa5871aa636a..fa244ac72103 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -33,7 +33,9 @@ from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern from .op import call_llvm_intrin, all, any, min_value, max_value, trace -from .op import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil +from .op import exp, exp2, exp10, log, log2, log10 +from .op import cos, sin, cosh, sinh, tan, tanh, atan +from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from .op import comm_reducer, min, max, sum diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index c5b1a0a1cbb5..d82724f43efc 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -330,6 +330,38 @@ def exp(x): return call_pure_intrin(x.dtype, "exp", x) +def exp2(x): + """Calculate 2**x + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "exp2", x) + + +def exp10(x): + """Calculate 10**x + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "exp10", x) + + def erf(x): """Take gauss error function of the input x. @@ -393,6 +425,38 @@ def log(x): """ return call_pure_intrin(x.dtype, "log", x) + +def log2(x): + """Take log2 of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "log2", x) + + +def log10(x): + """Take log10 of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "log10", x) + def tan(x): """Take tan of input x. @@ -424,6 +488,23 @@ def cos(x): """ return call_pure_intrin(x.dtype, "cos", x) + +def cosh(x): + """Take cosh of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "cosh", x) + + def sin(x): """Take sin of input x. @@ -439,6 +520,23 @@ def sin(x): """ return call_pure_intrin(x.dtype, "sin", x) + +def sinh(x): + """Take sin of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "sinh", x) + + def atan(x): """Take atan of input x. diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 6c5a9cd079af..880a0fe58000 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -35,12 +35,35 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp2") +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10") +.set_body([](const TVMArgs& targs, TVMRetValue* rv) { + using tir::make_const; + using tir::make_zero; + PrimExpr e = targs[0]; + const tir::CallNode* call = e.as(); + CHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr ln10 = make_const(x.dtype(), 2.302585093); + PrimExpr ret = tir::CallNode::make( + x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic); + *rv = ret; +}); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log2") +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log10") +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); @@ -108,9 +131,45 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cosh") +.set_body([](const TVMArgs& targs, TVMRetValue* rv) { + using tir::make_const; + using tir::make_zero; + PrimExpr e = targs[0]; + const tir::CallNode* call = e.as(); + CHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr two = make_const(x.dtype(), 2); + PrimExpr neg_one = make_const(x.dtype(), -1); + PrimExpr exp_negx = tir::CallNode::make( + x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_posx = tir::CallNode::make( + x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); + PrimExpr ret = (exp_posx + exp_negx) / two; + *rv = ret; +}); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sin") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh") +.set_body([](const TVMArgs& targs, TVMRetValue* rv) { + using tir::make_const; + using tir::make_zero; + PrimExpr e = targs[0]; + const tir::CallNode* call = e.as(); + CHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr two = make_const(x.dtype(), 2); + PrimExpr neg_one = make_const(x.dtype(), -1); + PrimExpr exp_negx = tir::CallNode::make( + x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_posx = tir::CallNode::make( + x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); + PrimExpr ret = (exp_posx - exp_negx) / two; + *rv = ret; +}); + } // namespace llvm } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index 6d41d132724b..0dc1272d7d49 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -63,6 +63,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs") TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp") .set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2") +.set_body(DispatchExternLibDevice); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10") +.set_body(DispatchExternLibDevice); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf") .set_body(DispatchExternLibDevice); @@ -72,6 +78,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma") TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log") .set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2") +.set_body(DispatchExternLibDevice); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10") +.set_body(DispatchExternLibDevice); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt") .set_body(DispatchExternLibDevice); @@ -87,9 +99,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan") TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos") .set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh") +.set_body(DispatchExternLibDevice); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin") .set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh") +.set_body(DispatchExternLibDevice); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan") .set_body(DispatchExternLibDevice); diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index 4e6a661f298d..3699c9f691b3 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -62,6 +62,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp") .set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2") +.set_body(DispatchExternOCML); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10") +.set_body(DispatchExternOCML); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf") .set_body(DispatchExternOCML); @@ -71,6 +77,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log") .set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2") +.set_body(DispatchExternOCML); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10") +.set_body(DispatchExternOCML); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt") .set_body(DispatchExternOCML); @@ -86,9 +98,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan") TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos") .set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh") +.set_body(DispatchExternOCML); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin") .set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh") +.set_body(DispatchExternOCML); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan") .set_body(DispatchExternOCML); diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 849e098d4ad2..d009110e0a6c 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -107,21 +107,39 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf") .set_body(DispatchExtern); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan") .set_body(DispatchExtern); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan") .set_body(DispatchExtern); diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc index 83514faeee58..8bc87d2b280f 100644 --- a/src/target/source/intrin_rule_metal.cc +++ b/src/target/source/intrin_rule_metal.cc @@ -45,9 +45,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round") TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh") .set_body(DispatchExtern); @@ -63,6 +75,18 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount") TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh") +.set_body(DispatchExtern); + } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index fcad11515eb0..1a4f52e4dfd1 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -45,9 +45,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh") .set_body(DispatchExtern); @@ -63,6 +75,18 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh") +.set_body(DispatchExtern); + // There is no warp shuffle instruction in standard OpenCL // When shuffle is used, we assume it is intel's shuffle extension struct IntelShuffle { diff --git a/src/target/source/intrin_rule_opengl.cc b/src/target/source/intrin_rule_opengl.cc index 78416473e517..1710d45d8bd6 100644 --- a/src/target/source/intrin_rule_opengl.cc +++ b/src/target/source/intrin_rule_opengl.cc @@ -36,9 +36,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.ceil") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp2") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp10") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log2") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log10") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.tanh") .set_body(DispatchExtern); @@ -51,6 +63,18 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.pow") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.popcount") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sin") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sinh") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cos") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cosh") +.set_body(DispatchExtern); + } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/target/source/intrin_rule_vhls.cc b/src/target/source/intrin_rule_vhls.cc index 28e686102751..41e76f260ff4 100644 --- a/src/target/source/intrin_rule_vhls.cc +++ b/src/target/source/intrin_rule_vhls.cc @@ -45,9 +45,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round") TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh") .set_body(DispatchExtern); @@ -60,6 +72,17 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow") TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount") .set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh") +.set_body(DispatchExtern); } // namespace intrin } // namespace codegen diff --git a/tests/python/unittest/test_tvm_intrin.py b/tests/python/unittest/test_tvm_intrin.py index 5bb1c6538750..0054273e6210 100644 --- a/tests/python/unittest/test_tvm_intrin.py +++ b/tests/python/unittest/test_tvm_intrin.py @@ -45,5 +45,33 @@ def test_nearbyint(): a_rounded.asnumpy(), np.rint(a.asnumpy())) +def test_unary_intrin(): + test_funcs = [ + (tvm.tir.exp10, lambda x : np.power(10, x)), + (tvm.tir.log2, lambda x : np.log2(x)), + (tvm.tir.log10, lambda x : np.log10(x)), + (tvm.tir.sinh, lambda x : np.sinh(x)), + (tvm.tir.cosh, lambda x : np.cosh(x)), + ] + def run_test(tvm_intrin, np_func): + m = te.var("m",) + A = te.placeholder((m,), name='A') + B = te.compute((m,), lambda *i: tvm_intrin(A(*i)), name='B') + s = te.create_schedule(B.op) + f = tvm.build(s, [A, B], "llvm") + ctx = tvm.cpu(0) + n = 10 + a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx) + b = tvm.nd.array( \ + np.random.uniform(size=n).astype(A.dtype), ctx) + f(a, b) + tvm.testing.assert_allclose( + b.asnumpy(), np_func(a.asnumpy()), atol=1e-5, rtol=1e-5) + + for func in test_funcs: + run_test(*func); + + if __name__ == "__main__": test_nearbyint() + test_unary_intrin()