From b1d8cf23ac1ac6b5d8dffc8b788761caa7c96e97 Mon Sep 17 00:00:00 2001
From: notoraptor
Date: Wed, 4 Mar 2020 12:07:55 -0500
Subject: [PATCH] Fix implementation of tan in cuda. Do not support tan for
 float16. Simplify topi/tests/python/test_topi_math. Add testing for tan with
 float32 and float64. Finally implement tan as sin/cos in llvm.

---
 src/target/llvm/intrin_rule_llvm.cc    | 28 ++++-----------
 src/target/source/intrin_rule_cuda.cc  | 18 +++++++++-
 topi/tests/python/test_topi_math.py    | 49 ++-------------------------
 3 files changed, 25 insertions(+), 70 deletions(-)

diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc
index fbf67d080634..6c5a9cd079af 100644
--- a/src/target/llvm/intrin_rule_llvm.cc
+++ b/src/target/llvm/intrin_rule_llvm.cc
@@ -93,31 +93,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount")
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan")
 .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
-  using tir::make_const;
   PrimExpr e = targs[0];
   const tir::CallNode* call = e.as<tir::CallNode>();
   CHECK(call != nullptr);
-  PrimExpr x = call->args[0];
-  PrimExpr y = x;
-  DataType dtype = x.dtype();
-  const char* opName = nullptr;
-
-  if (!dtype.is_float()) {
-    LOG(FATAL) << "tan expects floating input";
-  }
-
-  if (dtype.bits() == 64) {
-    opName = "tan";
-  } else if (dtype.bits() == 32) {
-    opName = "tanf";
-  } else if (dtype.bits() == 16) {
-    opName = "tanf";
-    y = cast(DataType::Float(32, dtype.lanes()), x);
-  } else {
-    LOG(FATAL) << "tan cannot handle float" << dtype.bits();
-  }
-
-  PrimExpr tan_x = tir::CallNode::make(x.dtype(), opName, {y}, tir::CallNode::Extern);
+  const PrimExpr& x = call->args[0];
+  PrimExpr sin_x = tir::CallNode::make(
+      x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic);
+  PrimExpr cos_x = tir::CallNode::make(
+      x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic);
+  PrimExpr tan_x = sin_x / cos_x;
   *rv = tan_x;
 });
 
diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc
index fc62dae7c9f9..849e098d4ad2 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -54,6 +54,22 @@ struct CUDAFastMath : public CUDAMath {
   }
 };
 
+struct CUDAFastMathTan : public CUDAMath {
+  std::string operator()(DataType t, std::string name) const {
+    if (t.lanes() == 1 && t.is_float()) {
+      switch (t.bits()) {
+        case 64: return name;
+        // `__tanf` seems to produce values that deviate too much from the
+        // numpy tan reference. So, let's use just `tanf` instead.
+        case 32: return name + 'f';
+        case 16: LOG(FATAL) << "cuda tan unsupported for float16";
+        default: return "";
+      }
+    }
+    return "";
+  }
+};
+
 struct CUDAPopcount {
   std::string operator()(DataType t, std::string name) const {
     if (t.lanes() == 1 && t.is_uint()) {
@@ -98,7 +114,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
 .set_body(DispatchExtern<CUDAFastMath>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan")
-.set_body(DispatchExtern<CUDAFastMath>);
+.set_body(DispatchExtern<CUDAFastMathTan>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos")
 .set_body(DispatchExtern<CUDAFastMath>);
diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py
index 51ce3a93fa18..3e58518ed4fe 100644
--- a/topi/tests/python/test_topi_math.py
+++ b/topi/tests/python/test_topi_math.py
@@ -127,58 +127,13 @@ def check_device(device):
     test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
     test_apply(topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True)
     test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi)
-    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi)
+    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float32')
+    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float64')
     test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)
     test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32")
     test_isnan(-100, 100)
 
 
-def test_ewise_tan():
-    def test_apply(
-            func,
-            name,
-            f_numpy,
-            low,
-            high,
-            shape=(20, 3),
-            dtype='float32',
-            check_round=False,
-            skip_name_check=False,
-    ):
-        A = te.placeholder(dtype=dtype, name="A", shape=shape)
-
-        B = func(A)
-        assert tuple(B.shape) == tuple(A.shape)
-        if not skip_name_check:
-            assert B.op.body[0].name == name
-        a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
-        # avoid round check too close to boundary
-        if check_round:
-            a_np += ((np.abs(np.fmod(a_np, 1)) - 0.5) < 1e-6) * 1e-5
-        b_np = f_numpy(a_np)
-
-        def check_device(device):
-            ctx = tvm.context(device, 0)
-            if not ctx.exist:
-                print("Skip because %s is not enabled" % device)
-                return
-            print("Running on target: %s" % device)
-            with tvm.target.create(device):
-                s = topi.testing.get_injective_schedule(device)(B)
-            foo = tvm.build(s, [A, B], device, name=name)
-            a = tvm.nd.array(a_np, ctx)
-            b = tvm.nd.array(np.zeros_like(b_np), ctx)
-            foo(a, b)
-            tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
-
-        for target in get_all_backend():
-            check_device(target)
-
-    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float64')
-    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float32')
-    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float16')
-
-
 def test_cast():
     def verify(from_dtype, to_dtype, low=-100, high=100):
         shape = (5, 4)
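
Note (illustrative, not part of the patch itself): the llvm rule above lowers tan(x) to sin(x)/cos(x).
The identity can be sanity-checked with plain numpy; the sample range below mirrors the tests in the
patch, while the shape, dtype, and tolerances are arbitrary choices for this sketch.

    import numpy as np

    # Compare the sin/cos identity used by the llvm lowering against numpy's tan.
    # Random float32 inputs over [-2*pi, 2*pi], away from exact poles with
    # overwhelming probability, so a relative tolerance of 1e-5 is comfortable.
    x = np.random.uniform(-2.0 * np.pi, 2.0 * np.pi, size=(20, 3)).astype("float32")
    np.testing.assert_allclose(np.sin(x) / np.cos(x), np.tan(x), rtol=1e-5, atol=1e-5)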