Skip to content

Commit

Permalink
Add support for absolute opeartion (apache#1406)
Browse files Browse the repository at this point in the history
  • Loading branch information
PariksheetPinjari909 authored and sergei-mironov committed Aug 8, 2018
1 parent 608e687 commit 83151bd
Show file tree
Hide file tree
Showing 19 changed files with 109 additions and 2 deletions.
3 changes: 2 additions & 1 deletion docs/api/python/intrin.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ tvm.intrin
tvm.ceil
tvm.trunc
tvm.round

tvm.abs

.. autofunction:: tvm.call_packed
.. autofunction:: tvm.call_pure_intrin
Expand All @@ -26,3 +26,4 @@ tvm.intrin
.. autofunction:: tvm.ceil
.. autofunction:: tvm.trunc
.. autofunction:: tvm.round
.. autofunction:: tvm.abs
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ List of operators
topi.ceil
topi.trunc
topi.round
topi.abs
topi.exp
topi.tanh
topi.log
Expand Down Expand Up @@ -84,6 +85,7 @@ topi
.. autofunction:: topi.ceil
.. autofunction:: topi.trunc
.. autofunction:: topi.round
.. autofunction:: topi.abs
.. autofunction:: topi.exp
.. autofunction:: topi.tanh
.. autofunction:: topi.log
Expand Down
2 changes: 2 additions & 0 deletions docs/nnvm_top.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ This level enables typical convnet models.
nnvm.symbol.ceil
nnvm.symbol.round
nnvm.symbol.trunc
nnvm.symbol.abs
nnvm.symbol.leaky_relu
nnvm.symbol.__add_scalar__
nnvm.symbol.__sub_scalar__
Expand Down Expand Up @@ -157,6 +158,7 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.ceil
.. autofunction:: nnvm.symbol.round
.. autofunction:: nnvm.symbol.trunc
.. autofunction:: nnvm.symbol.abs
.. autofunction:: nnvm.symbol.leaky_relu
.. autofunction:: nnvm.symbol.__add_scalar__
.. autofunction:: nnvm.symbol.__sub_scalar__
Expand Down
21 changes: 20 additions & 1 deletion include/tvm/ir_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ using HalideIR::likely_if_innermost;
using HalideIR::cast;
using HalideIR::min;
using HalideIR::max;
using HalideIR::abs;
using HalideIR::select;

/*!
Expand Down Expand Up @@ -71,6 +70,26 @@ inline Expr pow(Expr x, Expr y) {
return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
}

/*!
* \brief Calculate absolute value of x, elementwise
* \param x The input data
*
* \return The aboslute value of input data x
*/
inline Expr abs(Expr x) {
if (x.type().is_int()) {
return select(x >= make_zero(x.type()), x, -x);
} else if (x.type().is_float()) {
return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic);
} else if (x.type().is_uint()) {
return x;
} else {
LOG(WARNING) << "Warning: Data type " << x.type()
<<" not supported for absolute op. Skipping absolute op...";
return x;
}
}

} // namespace tvm

#endif // TVM_IR_OPERATOR_H_
4 changes: 4 additions & 0 deletions nnvm/python/nnvm/top/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def _compute(attrs, x, _):
reg.register_pattern("round", OpPattern.ELEMWISE)
reg.register_schedule("round", _fschedule_broadcast)

# abs
reg.register_pattern("abs", OpPattern.ELEMWISE)
reg.register_schedule("abs", _fschedule_broadcast)

# trunc
reg.register_pattern("trunc", OpPattern.ELEMWISE)
reg.register_schedule("trunc", _fschedule_broadcast)
Expand Down
12 changes: 12 additions & 0 deletions nnvm/src/top/tensor/elemwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(round)
return Array<Tensor>{ topi::round(inputs[0]) };
});

// abs
NNVM_REGISTER_ELEMWISE_UNARY_OP(abs)
.describe(R"code(Take absolute value of elements of the input.
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::abs(inputs[0]) };
});

// sigmoid
NNVM_REGISTER_ELEMWISE_UNARY_OP(sigmoid)
.describe(R"code(Computes sigmoid.
Expand Down
5 changes: 5 additions & 0 deletions nnvm/tests/python/compiler/test_top_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def test_trunc():
def test_round():
check_map(sym.round, np.round)

def test_abs():
check_map(sym.abs, np.abs)
check_map(sym.abs, np.abs, dtype = "int32")
check_map(sym.abs, np.abs, dtype = "int8")

def test_shift():
n = 3
Expand All @@ -40,4 +44,5 @@ def test_shift():
test_floor()
test_ceil()
test_round()
test_abs()
test_trunc()
16 changes: 16 additions & 0 deletions python/tvm/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,22 @@ def trunc(x):
return call_pure_intrin(x.dtype, "trunc", x)


def abs(x):
"""Get absolute value of the input element-wise.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return _make.abs(x)


def round(x):
"""Round elements of the array to the nearest integer.
Expand Down
6 changes: 6 additions & 0 deletions src/api/api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <tvm/ir.h>
#include <ir/IROperator.h>
#include <tvm/api_registry.h>
#include <tvm/ir_operator.h>

namespace tvm {
namespace ir {
Expand All @@ -16,6 +17,11 @@ TVM_REGISTER_API("_Var")
*ret = Variable::make(args[1], args[0]);
});

TVM_REGISTER_API("make.abs")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::abs(args[0]);
});

TVM_REGISTER_API("make._range_by_min_extent")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Range::make_by_min_extent(args[0], args[1]);
Expand Down
3 changes: 3 additions & 0 deletions src/codegen/intrin_rule_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round")
.set_body(DispatchExtern<CUDAMath>);

Expand Down
3 changes: 3 additions & 0 deletions src/codegen/intrin_rule_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round")
.set_body(DispatchExtern<Direct>);

Expand Down
3 changes: 3 additions & 0 deletions src/codegen/intrin_rule_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round")
.set_body(DispatchExtern<Direct>);

Expand Down
3 changes: 3 additions & 0 deletions src/codegen/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.trunc")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fabs")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);

Expand Down
3 changes: 3 additions & 0 deletions src/codegen/llvm/intrin_rule_nvptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp")
.set_body(DispatchExternLibDevice);

Expand Down
3 changes: 3 additions & 0 deletions src/codegen/llvm/intrin_rule_rocm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp")
.set_body(DispatchExternOCML);

Expand Down
3 changes: 3 additions & 0 deletions src/codegen/spirv/intrin_rule_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.trunc")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Trunc>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs")
.set_body(DispatchGLSLPureIntrin<GLSLstd450FAbs>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);

Expand Down
1 change: 1 addition & 0 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ TOPI_DECLARE_UNARY_OP(floor);
TOPI_DECLARE_UNARY_OP(ceil);
TOPI_DECLARE_UNARY_OP(round);
TOPI_DECLARE_UNARY_OP(trunc);
TOPI_DECLARE_UNARY_OP(abs);

/*!
* \brief Creates an operation that returns identity of a given tensor
Expand Down
17 changes: 17 additions & 0 deletions topi/python/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,23 @@ def trunc(x):
return tvm.compute(x.shape, lambda *i: tvm.trunc(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def abs(x):
"""Take absolute value of the input of x, element-wise.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.abs(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def round(x):
"""Round elements of x to nearest integer.
Expand Down
1 change: 1 addition & 0 deletions topi/tests/python/test_topi_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def check_device(device):
test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.trunc, "trunc", np.trunc, -100, 100)
test_apply(topi.abs, "fabs", np.abs, -100, 100)
test_apply(topi.round, "round", np.round, -100, 100)
test_apply(topi.exp, "exp", np.exp, -1, 1)
test_apply(topi.tanh, "tanh", np.tanh, -10, 10)
Expand Down

0 comments on commit 83151bd

Please sign in to comment.