From bd8cb9b49595d2ed315e96a0a1c30d22543a1361 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 12 Jun 2018 13:31:16 -0700 Subject: [PATCH] [INTRIN] Add support for floor and ceil (#1267) --- cmake/config.cmake | 3 +++ include/tvm/ir_operator.h | 2 ++ python/tvm/intrin.py | 32 ++++++++++++++++++++++++ src/codegen/intrin_rule_cuda.cc | 6 +++++ src/codegen/intrin_rule_metal.cc | 6 +++++ src/codegen/intrin_rule_opencl.cc | 6 +++++ src/codegen/intrin_rule_opengl.cc | 6 +++++ src/codegen/llvm/intrin_rule_llvm.cc | 6 +++++ src/codegen/llvm/intrin_rule_rocm.cc | 6 +++++ src/codegen/spirv/build_vulkan.cc | 3 --- src/codegen/spirv/codegen_spirv.cc | 5 ---- src/codegen/spirv/intrin_rule_spirv.cc | 13 +++++++--- src/codegen/spirv/ir_builder.cc | 5 ---- topi/include/topi/elemwise.h | 2 ++ topi/python/topi/math.py | 34 ++++++++++++++++++++++++++ topi/tests/python/test_topi_math.py | 18 ++++++++------ 16 files changed, 128 insertions(+), 25 deletions(-) diff --git a/cmake/config.cmake b/cmake/config.cmake index b4896da79ef13..db7d800e918a4 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -38,6 +38,9 @@ set(USE_METAL OFF) # Whether enable Vulkan runtime set(USE_VULKAN OFF) +# Whether enable OpenGL runtime +set(USE_OPENGL OFF) + # Whether enable RPC runtime set(USE_RPC ON) diff --git a/include/tvm/ir_operator.h b/include/tvm/ir_operator.h index ee349c092ee68..9d72b655c484f 100644 --- a/include/tvm/ir_operator.h +++ b/include/tvm/ir_operator.h @@ -53,6 +53,8 @@ TVM_DECLARE_INTRIN_UNARY(tanh); TVM_DECLARE_INTRIN_UNARY(sigmoid); TVM_DECLARE_INTRIN_UNARY(sqrt); TVM_DECLARE_INTRIN_UNARY(log); +TVM_DECLARE_INTRIN_UNARY(floor); +TVM_DECLARE_INTRIN_UNARY(ceil); inline Expr pow(Expr x, Expr y) { return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic); diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index f8f65e25aa68a..9be7502b26eea 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -233,6 +233,38 @@ def sqrt(x): return call_pure_intrin(x.dtype, "sqrt", x) +def floor(x): + """Take floor of float input x. + + Parameters + ---------- + x : Expr + Input argument. + + Returns + ------- + y : Expr + The result. + """ + return call_pure_intrin(x.dtype, "floor", x) + + +def ceil(x): + """Take ceil of float input x. + + Parameters + ---------- + x : Expr + Input argument. + + Returns + ------- + y : Expr + The result. + """ + return call_pure_intrin(x.dtype, "ceil", x) + + def power(x, y): """x power y diff --git a/src/codegen/intrin_rule_cuda.cc b/src/codegen/intrin_rule_cuda.cc index 1d199fe5af281..dccbcf8ca2325 100644 --- a/src/codegen/intrin_rule_cuda.cc +++ b/src/codegen/intrin_rule_cuda.cc @@ -55,6 +55,12 @@ struct CUDAShuffle { } }; +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp") .set_body(DispatchExtern); diff --git a/src/codegen/intrin_rule_metal.cc b/src/codegen/intrin_rule_metal.cc index b0e41770ebff1..3de5cafc69c69 100644 --- a/src/codegen/intrin_rule_metal.cc +++ b/src/codegen/intrin_rule_metal.cc @@ -9,6 +9,12 @@ namespace tvm { namespace codegen { namespace intrin { +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp") .set_body(DispatchExtern); diff --git a/src/codegen/intrin_rule_opencl.cc b/src/codegen/intrin_rule_opencl.cc index b8b2412215d10..cdb432e5462f6 100644 --- a/src/codegen/intrin_rule_opencl.cc +++ b/src/codegen/intrin_rule_opencl.cc @@ -9,6 +9,12 @@ namespace tvm { namespace codegen { namespace intrin { +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp") .set_body(DispatchExtern); diff --git a/src/codegen/intrin_rule_opengl.cc b/src/codegen/intrin_rule_opengl.cc index 6ae2ee5d2b4e8..e9728a25b40cf 100644 --- a/src/codegen/intrin_rule_opengl.cc +++ b/src/codegen/intrin_rule_opengl.cc @@ -9,6 +9,12 @@ namespace tvm { namespace codegen { namespace intrin { +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.floor") +.set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.ceil") +.set_body(DispatchExtern); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp") .set_body(DispatchExtern); diff --git a/src/codegen/llvm/intrin_rule_llvm.cc b/src/codegen/llvm/intrin_rule_llvm.cc index 2e6bf061eb5e5..e5eff1f67ff51 100644 --- a/src/codegen/llvm/intrin_rule_llvm.cc +++ b/src/codegen/llvm/intrin_rule_llvm.cc @@ -25,6 +25,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor") +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil") +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") .set_body([](const TVMArgs& targs, TVMRetValue* rv) { Expr e = targs[0]; diff --git a/src/codegen/llvm/intrin_rule_rocm.cc b/src/codegen/llvm/intrin_rule_rocm.cc index 38211db0b9b1f..4546ee4f039b4 100644 --- a/src/codegen/llvm/intrin_rule_rocm.cc +++ b/src/codegen/llvm/intrin_rule_rocm.cc @@ -26,6 +26,12 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { namespace llvm { +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor") +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil") +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp") .set_body(DispatchExternOCML); diff --git a/src/codegen/spirv/build_vulkan.cc b/src/codegen/spirv/build_vulkan.cc index 719b1f30e2679..2d4b35daa006a 100644 --- a/src/codegen/spirv/build_vulkan.cc +++ b/src/codegen/spirv/build_vulkan.cc @@ -3,8 +3,6 @@ * \file build_vulkan.cc * \brief Build SPIRV block */ -#if TVM_VULKAN_RUNTIME - // Use libspirv for parsing and validating code. #include #include @@ -92,4 +90,3 @@ TVM_REGISTER_API("codegen.build_vulkan") } // namespace codegen } // namespace tvm -#endif // TVM_VULKAN_RUNTIME diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index d9a06ab13d346..b6582ec6c0d8e 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -3,9 +3,6 @@ * \file codegen_spirv.cc * \brief Generate SPIRV block */ - -#if TVM_VULKAN_RUNTIME - #include #include #include "../codegen_common.h" @@ -634,5 +631,3 @@ void CodeGenSPIRV::VisitStmt_(const ProducerConsumer* op) { } // namespace codegen } // namespace tvm - -#endif // TVM_VULKAN_RUNTIME diff --git a/src/codegen/spirv/intrin_rule_spirv.cc b/src/codegen/spirv/intrin_rule_spirv.cc index cef98266ca9e0..1f9c56c561f84 100644 --- a/src/codegen/spirv/intrin_rule_spirv.cc +++ b/src/codegen/spirv/intrin_rule_spirv.cc @@ -2,8 +2,6 @@ * Copyright (c) 2017 by Contributors * \file intrin_rule_spirv.cc */ -#if TVM_VULKAN_RUNTIME - #include #include #include @@ -31,6 +29,12 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { call->type, "spirv_glsl450", cargs, ir::Call::PureIntrinsic); } +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor") +.set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.ceil") +.set_body(DispatchGLSLPureIntrin); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp") .set_body(DispatchGLSLPureIntrin); @@ -43,8 +47,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt") TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow") .set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh") +.set_body(DispatchGLSLPureIntrin); + } // namespace spirv } // namespace codegen } // namespace tvm - -#endif // TVM_VULKAN_RUNTIME diff --git a/src/codegen/spirv/ir_builder.cc b/src/codegen/spirv/ir_builder.cc index 2bd1d0d5da7c7..f3e044596e5c4 100644 --- a/src/codegen/spirv/ir_builder.cc +++ b/src/codegen/spirv/ir_builder.cc @@ -3,9 +3,6 @@ * \file ir_builder.cc * \brief IRBuilder for SPIRV block */ - -#if TVM_VULKAN_RUNTIME - #include "./ir_builder.h" namespace tvm { @@ -555,5 +552,3 @@ Value IRBuilder::Select(Value cond, Value a, Value b) { } // namespace spirv } // namespace codegen } // namespace tvm - -#endif // TVM_VULKAN_RUNTIME diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index c3797197710bb..7500ce0c66b0e 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -29,6 +29,8 @@ TOPI_DECLARE_UNARY_OP(tanh); TOPI_DECLARE_UNARY_OP(sigmoid); TOPI_DECLARE_UNARY_OP(sqrt); TOPI_DECLARE_UNARY_OP(log); +TOPI_DECLARE_UNARY_OP(floor); +TOPI_DECLARE_UNARY_OP(ceil); /*! * \brief Creates an operation that returns identity of a given tensor diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index 2a1de9972bfe9..588c306136d9f 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -73,6 +73,40 @@ def tanh(x): return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i))) +@tvm.tag_scope(tag=tag.ELEMWISE) +def floor(x): + """Take floor of input x. + + Parameters + ---------- + x : tvm.Tensor + Input argument. + + Returns + ------- + y : tvm.Tensor + The result. + """ + return tvm.compute(x.shape, lambda *i: tvm.floor(x(*i))) + + +@tvm.tag_scope(tag=tag.ELEMWISE) +def ceil(x): + """Take ceil of input x. + + Parameters + ---------- + x : tvm.Tensor + Input argument. + + Returns + ------- + y : tvm.Tensor + The result. + """ + return tvm.compute(x.shape, lambda *i: tvm.ceil(x(*i))) + + @tvm.tag_scope(tag=tag.ELEMWISE) def log(x): """Take logarithm of input x. diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py index dc3c015d4f25e..b9305e6e6fdf4 100644 --- a/topi/tests/python/test_topi_math.py +++ b/topi/tests/python/test_topi_math.py @@ -18,12 +18,11 @@ def test_ewise(): shape = (20, 3) - def test_apply(func, name, f_numpy): + def test_apply(func, name, f_numpy, low, high): B = func(A) assert tuple(B.shape) == tuple(A.shape) assert B.op.body[0].name == name - a_np = np.random.uniform(low=1e-5, size=shape).astype(A.dtype) - a_np = np.abs(a_np) + a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10 b_np = f_numpy(a_np) def check_device(device): @@ -43,11 +42,14 @@ def check_device(device): for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm']: check_device(device) - test_apply(topi.exp, "exp", np.exp) - test_apply(topi.tanh, "tanh", np.tanh) - test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x))) - test_apply(topi.log, "log", np.log) - test_apply(topi.sqrt, "sqrt", np.sqrt) + + test_apply(topi.floor, "floor", np.floor, -100, 100) + test_apply(topi.ceil, "ceil", np.ceil, -100, 100) + test_apply(topi.exp, "exp", np.exp, -1, 1) + test_apply(topi.tanh, "tanh", np.tanh, -10, 10) + test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1) + test_apply(topi.log, "log", np.log, 0, 100) + test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100) if __name__ == "__main__": test_util()