Skip to content

Commit

Permalink
[INTRIN] Add support for floor and ceil (apache#1267)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and sergei-mironov committed Aug 8, 2018
1 parent 30cdd7e commit bd8cb9b
Show file tree
Hide file tree
Showing 16 changed files with 128 additions and 25 deletions.
3 changes: 3 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ set(USE_METAL OFF)
# Whether to enable the Vulkan runtime
set(USE_VULKAN OFF)

# Whether to enable the OpenGL runtime
set(USE_OPENGL OFF)

# Whether to enable the RPC runtime
set(USE_RPC ON)

Expand Down
2 changes: 2 additions & 0 deletions include/tvm/ir_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(log);
// Rounding intrinsics. Each backend supplies its own lowering rule under
// "tvm.intrin.rule.<target>.floor" / "tvm.intrin.rule.<target>.ceil".
TVM_DECLARE_INTRIN_UNARY(floor);
TVM_DECLARE_INTRIN_UNARY(ceil);

inline Expr pow(Expr x, Expr y) {
return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
Expand Down
32 changes: 32 additions & 0 deletions python/tvm/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,38 @@ def sqrt(x):
return call_pure_intrin(x.dtype, "sqrt", x)


def floor(x):
    """Take floor of float input x.

    Builds a pure intrinsic call named "floor" whose result dtype is
    taken from ``x.dtype``.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    intrin_name = "floor"
    result_dtype = x.dtype
    return call_pure_intrin(result_dtype, intrin_name, x)


def ceil(x):
    """Take ceil of float input x.

    Builds a pure intrinsic call named "ceil" whose result dtype is
    taken from ``x.dtype``.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    intrin_name = "ceil"
    result_dtype = x.dtype
    return call_pure_intrin(result_dtype, intrin_name, x)


def power(x, y):
"""x power y
Expand Down
6 changes: 6 additions & 0 deletions src/codegen/intrin_rule_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ struct CUDAShuffle {
}
};

// Lowering rules for the CUDA target.
// floor/ceil are dispatched through CUDAMath, whereas exp below uses
// CUDAFastMath.
// NOTE(review): presumably the rounding functions have no fast-math
// variant -- confirm against the CUDAMath / CUDAFastMath definitions.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>);

Expand Down
6 changes: 6 additions & 0 deletions src/codegen/intrin_rule_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ namespace tvm {
namespace codegen {
namespace intrin {

// Lowering rules for the Metal target. floor/ceil use the same Direct
// dispatch as exp below.
// NOTE(review): Direct presumably forwards the intrinsic name unchanged to
// an extern call -- confirm in the shared intrin rule header.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp")
.set_body(DispatchExtern<Direct>);

Expand Down
6 changes: 6 additions & 0 deletions src/codegen/intrin_rule_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ namespace tvm {
namespace codegen {
namespace intrin {

// Lowering rules for the OpenCL target. floor/ceil use the same Direct
// dispatch as exp below.
// NOTE(review): Direct presumably forwards the intrinsic name unchanged to
// an extern call -- confirm in the shared intrin rule header.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp")
.set_body(DispatchExtern<Direct>);

Expand Down
6 changes: 6 additions & 0 deletions src/codegen/intrin_rule_opengl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ namespace tvm {
namespace codegen {
namespace intrin {

// Lowering rules for the OpenGL target. floor/ceil use the same Direct
// dispatch as exp below.
// NOTE(review): Direct presumably forwards the intrinsic name unchanged to
// an extern call -- confirm in the shared intrin rule header.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.floor")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.ceil")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp")
.set_body(DispatchExtern<Direct>);

Expand Down
6 changes: 6 additions & 0 deletions src/codegen/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
// sqrt, floor and ceil each map onto the LLVM intrinsic of the same name.
// NOTE(review): the trailing template argument `1` looks like the number of
// operands passed to the intrinsic -- confirm in DispatchLLVMPureIntrin.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
Expand Down
6 changes: 6 additions & 0 deletions src/codegen/llvm/intrin_rule_rocm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {

namespace llvm {

// Lowering rules for the ROCm target.
// floor/ceil go straight to the LLVM intrinsics of the same name, while exp
// below is routed through DispatchExternOCML.
// The LLVM intrinsics are spelled with a leading `::` because these
// registrations sit inside a namespace that is itself named `llvm`.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp")
.set_body(DispatchExternOCML);

Expand Down
3 changes: 0 additions & 3 deletions src/codegen/spirv/build_vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
* \file build_vulkan.cc
* \brief Build SPIRV block
*/
#if TVM_VULKAN_RUNTIME

// Use libspirv for parsing and validating code.
#include <vulkan/libspirv.h>
#include <dmlc/memory_io.h>
Expand Down Expand Up @@ -92,4 +90,3 @@ TVM_REGISTER_API("codegen.build_vulkan")

} // namespace codegen
} // namespace tvm
#endif // TVM_VULKAN_RUNTIME
5 changes: 0 additions & 5 deletions src/codegen/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
* \file codegen_spirv.cc
* \brief Generate SPIRV block
*/

#if TVM_VULKAN_RUNTIME

#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "../codegen_common.h"
Expand Down Expand Up @@ -634,5 +631,3 @@ void CodeGenSPIRV::VisitStmt_(const ProducerConsumer* op) {

} // namespace codegen
} // namespace tvm

#endif // TVM_VULKAN_RUNTIME
13 changes: 9 additions & 4 deletions src/codegen/spirv/intrin_rule_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
* Copyright (c) 2017 by Contributors
* \file intrin_rule_spirv.cc
*/
#if TVM_VULKAN_RUNTIME

#include <tvm/packed_func_ext.h>
#include <tvm/ir.h>
#include <vulkan/GLSL.std.450.h>
Expand Down Expand Up @@ -31,6 +29,12 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
call->type, "spirv_glsl450", cargs, ir::Call::PureIntrinsic);
}

// Lowering rules for the Vulkan target. Each rule selects a GLSLstd450*
// opcode and hands it to DispatchGLSLPureIntrin, which builds a
// "spirv_glsl450" pure intrinsic call.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Floor>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.ceil")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Ceil>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);

Expand All @@ -43,8 +47,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt")
// pow and tanh are lowered to the GLSLstd450Pow and GLSLstd450Tanh opcodes.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Tanh>);

} // namespace spirv
} // namespace codegen
} // namespace tvm

#endif // TVM_VULKAN_RUNTIME
5 changes: 0 additions & 5 deletions src/codegen/spirv/ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
* \file ir_builder.cc
* \brief IRBuilder for SPIRV block
*/

#if TVM_VULKAN_RUNTIME

#include "./ir_builder.h"

namespace tvm {
Expand Down Expand Up @@ -555,5 +552,3 @@ Value IRBuilder::Select(Value cond, Value a, Value b) {
} // namespace spirv
} // namespace codegen
} // namespace tvm

#endif // TVM_VULKAN_RUNTIME
2 changes: 2 additions & 0 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ TOPI_DECLARE_UNARY_OP(tanh);
TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt);
TOPI_DECLARE_UNARY_OP(log);
// Rounding ops, declared through the same macro as the other unary ops above.
TOPI_DECLARE_UNARY_OP(floor);
TOPI_DECLARE_UNARY_OP(ceil);

/*!
* \brief Creates an operation that returns identity of a given tensor
Expand Down
34 changes: 34 additions & 0 deletions topi/python/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,40 @@ def tanh(x):
return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def floor(x):
    """Take floor of input x.

    The output has the same shape as ``x``; every element is produced by
    applying ``tvm.floor`` to the element of ``x`` at the same index.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    def _floor_at(*indices):
        # Variadic so the compute body works for any number of axes.
        return tvm.floor(x(*indices))

    return tvm.compute(x.shape, _floor_at)


@tvm.tag_scope(tag=tag.ELEMWISE)
def ceil(x):
    """Take ceil of input x.

    The output has the same shape as ``x``; every element is produced by
    applying ``tvm.ceil`` to the element of ``x`` at the same index.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    def _ceil_at(*indices):
        # Variadic so the compute body works for any number of axes.
        return tvm.ceil(x(*indices))

    return tvm.compute(x.shape, _ceil_at)


@tvm.tag_scope(tag=tag.ELEMWISE)
def log(x):
"""Take logarithm of input x.
Expand Down
18 changes: 10 additions & 8 deletions topi/tests/python/test_topi_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ def test_ewise():

shape = (20, 3)

def test_apply(func, name, f_numpy):
def test_apply(func, name, f_numpy, low, high):
B = func(A)
assert tuple(B.shape) == tuple(A.shape)
assert B.op.body[0].name == name
a_np = np.random.uniform(low=1e-5, size=shape).astype(A.dtype)
a_np = np.abs(a_np)
a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
b_np = f_numpy(a_np)

def check_device(device):
Expand All @@ -43,11 +42,14 @@ def check_device(device):
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm']:
check_device(device)

test_apply(topi.exp, "exp", np.exp)
test_apply(topi.tanh, "tanh", np.tanh)
test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)))
test_apply(topi.log, "log", np.log)
test_apply(topi.sqrt, "sqrt", np.sqrt)

test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.exp, "exp", np.exp, -1, 1)
test_apply(topi.tanh, "tanh", np.tanh, -10, 10)
test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1)
test_apply(topi.log, "log", np.log, 0, 100)
test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)

if __name__ == "__main__":
test_util()
Expand Down

0 comments on commit bd8cb9b

Please sign in to comment.