diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index fbf67d080634f..6c5a9cd079afd 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -93,31 +93,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan") .set_body([](const TVMArgs& targs, TVMRetValue* rv) { - using tir::make_const; PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); CHECK(call != nullptr); - PrimExpr x = call->args[0]; - PrimExpr y = x; - DataType dtype = x.dtype(); - const char* opName = nullptr; - - if (!dtype.is_float()) { - LOG(FATAL) << "tan expects floating input"; - } - - if (dtype.bits() == 64) { - opName = "tan"; - } else if (dtype.bits() == 32) { - opName = "tanf"; - } else if (dtype.bits() == 16) { - opName = "tanf"; - y = cast(DataType::Float(32, dtype.lanes()), x); - } else { - LOG(FATAL) << "tan cannot handle float" << dtype.bits(); - } - - PrimExpr tan_x = tir::CallNode::make(x.dtype(), opName, {y}, tir::CallNode::Extern); + const PrimExpr& x = call->args[0]; + PrimExpr sin_x = tir::CallNode::make( + x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic); + PrimExpr cos_x = tir::CallNode::make( + x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic); + PrimExpr tan_x = sin_x / cos_x; *rv = tan_x; }); diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py index 51ce3a93fa18e..3a48558ddd529 100644 --- a/topi/tests/python/test_topi_math.py +++ b/topi/tests/python/test_topi_math.py @@ -133,52 +133,6 @@ def check_device(device): test_isnan(-100, 100) -def test_ewise_tan(): - def test_apply( - func, - name, - f_numpy, - low, - high, - shape=(20, 3), - dtype='float32', - check_round=False, - skip_name_check=False, - ): - A = te.placeholder(dtype=dtype, name="A", shape=shape) - - B = func(A) - assert tuple(B.shape) == tuple(A.shape) - if not skip_name_check: - assert B.op.body[0].name == name - a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10 - # avoid round check too close to boundary - if check_round: - a_np += ((np.abs(np.fmod(a_np, 1)) - 0.5) < 1e-6) * 1e-5 - b_np = f_numpy(a_np) - - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.testing.get_injective_schedule(device)(B) - foo = tvm.build(s, [A, B], device, name=name) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros_like(b_np), ctx) - foo(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) - - for target in get_all_backend(): - check_device(target) - - test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float64') - test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float32') - test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float16') - - def test_cast(): def verify(from_dtype, to_dtype, low=-100, high=100): shape = (5, 4)