[MATH][TOPI][NNVM] introduce trunc, round #1310

Merged (1 commit), Jun 21, 2018
8 changes: 8 additions & 0 deletions docs/api/python/intrin.rst
@@ -10,6 +10,10 @@ tvm.intrin
tvm.register_intrin_rule
tvm.exp
tvm.log
tvm.floor
tvm.ceil
tvm.trunc
tvm.round


.. autofunction:: tvm.call_packed
@@ -18,3 +22,7 @@ tvm.intrin
.. autofunction:: tvm.register_intrin_rule
.. autofunction:: tvm.exp
.. autofunction:: tvm.log
.. autofunction:: tvm.floor
.. autofunction:: tvm.ceil
.. autofunction:: tvm.trunc
.. autofunction:: tvm.round
8 changes: 8 additions & 0 deletions docs/api/python/topi.rst
@@ -9,6 +9,10 @@ List of operators

topi.identity
topi.negative
topi.floor
topi.ceil
topi.trunc
topi.round
topi.exp
topi.tanh
topi.log
@@ -68,6 +72,10 @@ topi
~~~~
.. autofunction:: topi.negative
.. autofunction:: topi.identity
.. autofunction:: topi.floor
.. autofunction:: topi.ceil
.. autofunction:: topi.trunc
.. autofunction:: topi.round
.. autofunction:: topi.exp
.. autofunction:: topi.tanh
.. autofunction:: topi.log
8 changes: 8 additions & 0 deletions docs/nnvm_top.rst
@@ -75,6 +75,10 @@ This level enables typical convnet models.
nnvm.symbol.reshape
nnvm.symbol.copy
nnvm.symbol.negative
nnvm.symbol.floor
nnvm.symbol.ceil
nnvm.symbol.round
nnvm.symbol.trunc
nnvm.symbol.leaky_relu
nnvm.symbol.__add_scalar__
nnvm.symbol.__sub_scalar__
@@ -147,6 +151,10 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.reshape
.. autofunction:: nnvm.symbol.copy
.. autofunction:: nnvm.symbol.negative
.. autofunction:: nnvm.symbol.floor
.. autofunction:: nnvm.symbol.ceil
.. autofunction:: nnvm.symbol.round
.. autofunction:: nnvm.symbol.trunc
.. autofunction:: nnvm.symbol.leaky_relu
.. autofunction:: nnvm.symbol.__add_scalar__
.. autofunction:: nnvm.symbol.__sub_scalar__
2 changes: 2 additions & 0 deletions include/tvm/ir_operator.h
@@ -55,6 +55,8 @@ TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(floor);
TVM_DECLARE_INTRIN_UNARY(ceil);
TVM_DECLARE_INTRIN_UNARY(round);
TVM_DECLARE_INTRIN_UNARY(trunc);

inline Expr pow(Expr x, Expr y) {
return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
16 changes: 16 additions & 0 deletions nnvm/python/nnvm/top/tensor.py
@@ -61,6 +61,22 @@ def compute_cast(attrs, inputs, _):
reg.register_pattern("cast", OpPattern.ELEMWISE)
reg.register_schedule("cast", _fschedule_broadcast)

# floor
reg.register_pattern("floor", OpPattern.ELEMWISE)
reg.register_schedule("floor", _fschedule_broadcast)

# ceil
reg.register_pattern("ceil", OpPattern.ELEMWISE)
reg.register_schedule("ceil", _fschedule_broadcast)

# round
reg.register_pattern("round", OpPattern.ELEMWISE)
reg.register_schedule("round", _fschedule_broadcast)

# trunc
reg.register_pattern("trunc", OpPattern.ELEMWISE)
reg.register_schedule("trunc", _fschedule_broadcast)

# exp
reg.register_pattern("exp", OpPattern.ELEMWISE)
reg.register_schedule("exp", _fschedule_broadcast)
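For reference, all four ops follow numpy's elementwise semantics; a quick comparison on sample values (note that np.round breaks ties to even, while most backend round intrinsics round half away from zero):

import numpy as np
x = np.array([-1.7, -0.5, 0.5, 1.7], dtype="float32")
np.floor(x)  # [-2., -1.,  0.,  1.]
np.ceil(x)   # [-1., -0.,  1.,  2.]
np.trunc(x)  # [-1., -0.,  0.,  1.]  (toward zero)
np.round(x)  # [-2., -0.,  0.,  2.]  (ties to even)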
48 changes: 48 additions & 0 deletions nnvm/src/top/tensor/elemwise.cc
@@ -31,6 +31,54 @@ Used to produce invalid node during optimization.
.set_num_outputs(1)
.set_num_inputs(0);

// floor
NNVM_REGISTER_ELEMWISE_UNARY_OP(floor)
.describe(R"code(Take floor input array, computed element-wise.
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::floor(inputs[0]) };
});

// ceil
NNVM_REGISTER_ELEMWISE_UNARY_OP(ceil)
.describe(R"code(Take ceil input array, computed element-wise.
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::ceil(inputs[0]) };
});

// trunc
NNVM_REGISTER_ELEMWISE_UNARY_OP(trunc)
.describe(R"code(Take truncated value of the input, element-wise.
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::trunc(inputs[0]) };
});

// round
NNVM_REGISTER_ELEMWISE_UNARY_OP(round)
.describe(R"code(Round elements of the input to nearest integer.
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::round(inputs[0]) };
});

// sigmoid
NNVM_REGISTER_ELEMWISE_UNARY_OP(sigmoid)
.describe(R"code(Computes sigmoid.
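As a sketch of how these registrations surface through the symbol API (assuming the usual nnvm.compiler build flow on an llvm-enabled build; shapes and values are illustrative):

import numpy as np
import tvm
import nnvm.symbol as sym
import nnvm.compiler
from tvm.contrib import graph_runtime

x = sym.Variable("x")
y = sym.trunc(x)  # one of the four new unary ops
shape = {"x": (1, 3, 32, 32)}
graph, lib, _ = nnvm.compiler.build(y, "llvm", shape)
m = graph_runtime.create(graph, lib, tvm.cpu(0))
data = np.random.uniform(-5, 5, size=shape["x"]).astype("float32")
m.run(x=data)
out = m.get_output(0, tvm.nd.empty(shape["x"]))
np.testing.assert_allclose(out.asnumpy(), np.trunc(data))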
36 changes: 36 additions & 0 deletions nnvm/tests/python/compiler/test_top_level3.py
@@ -0,0 +1,36 @@
import numpy as np
import tvm
from tvm.contrib import graph_runtime
import topi.testing
import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list
from test_top_level1 import helper

def check_map(symfunc, np_func, np_backward=None):
x = sym.Variable("x")
y = symfunc(x)
dtype = "float32"
dshape = (1, 3, 32, 32)
inputs = [('x', dshape, x)]
helper(y, inputs, dtype, np_func, np_backward)


def test_floor():
check_map(sym.floor, np.floor)

def test_ceil():
check_map(sym.ceil, np.ceil)

def test_trunc():
check_map(sym.trunc, np.trunc)

def test_round():
check_map(sym.round, np.round)


if __name__ == "__main__":
test_floor()
test_ceil()
test_round()
test_trunc()
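One caveat on the round test: np.round breaks ties to even, while the backend round intrinsics this lowers to (e.g. llvm.round) round half away from zero, so exact .5 inputs would disagree:

import numpy as np
np.round(0.5)  # 0.0 (ties to even); llvm.round(0.5) gives 1.0
np.round(1.5)  # 2.0 -- here both agree

The uniform random data used by helper makes exact ties vanishingly unlikely, so the comparison holds in practice.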
36 changes: 36 additions & 0 deletions python/tvm/intrin.py
@@ -1,4 +1,5 @@
"""Expression Intrinsics and math functions in TVM."""
# pylint: disable=redefined-builtin
from __future__ import absolute_import as _abs

from ._ffi.function import register_func as _register_func
@@ -265,6 +266,41 @@ def ceil(x):
return call_pure_intrin(x.dtype, "ceil", x)


def trunc(x):
"""Get truncated value of the input.

The truncated value of the scalar x is the
nearest integer i which is closer to zero than x is.

Parameters
----------
x : Expr
Input argument.

Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "trunc", x)


def round(x):
"""Round elements of the array to the nearest integer.

Parameters
----------
x : Expr
Input argument.

Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "round", x)


def power(x, y):
"""x power y

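A minimal usage sketch for the two new intrinsics at the expression level (0.x-era TVM API assumed):

import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A", dtype="float32")
# trunc rounds toward zero; round goes to the nearest integer:
T = tvm.compute(A.shape, lambda i: tvm.trunc(A[i]), name="T")
R = tvm.compute(A.shape, lambda i: tvm.round(A[i]), name="R")
# e.g. trunc(-1.7) -> -1.0, while round(-1.7) -> -2.0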
6 changes: 6 additions & 0 deletions src/codegen/intrin_rule_cuda.cc
@@ -61,6 +61,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>);

6 changes: 6 additions & 0 deletions src/codegen/intrin_rule_metal.cc
@@ -15,6 +15,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp")
.set_body(DispatchExtern<Direct>);

6 changes: 6 additions & 0 deletions src/codegen/intrin_rule_opencl.cc
@@ -15,6 +15,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round")
.set_body(DispatchExtern<Direct>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp")
.set_body(DispatchExtern<Direct>);

6 changes: 6 additions & 0 deletions src/codegen/llvm/intrin_rule_llvm.cc
@@ -31,6 +31,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.trunc")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
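The LLVM dispatch can be checked by inspecting the lowered module (a sketch, assuming an llvm-enabled build):

import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A", dtype="float32")
B = tvm.compute(A.shape, lambda i: tvm.round(A[i]), name="B")
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")
print(f.get_source())  # expect a call to llvm.round.f32 in the IR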
6 changes: 6 additions & 0 deletions src/codegen/llvm/intrin_rule_rocm.cc
@@ -32,6 +32,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp")
.set_body(DispatchExternOCML);

6 changes: 6 additions & 0 deletions src/codegen/spirv/intrin_rule_spirv.cc
@@ -35,6 +35,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.ceil")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Ceil>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.round")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Round>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.trunc")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Trunc>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);

2 changes: 2 additions & 0 deletions topi/include/topi/elemwise.h
@@ -31,6 +31,8 @@ TOPI_DECLARE_UNARY_OP(sqrt);
TOPI_DECLARE_UNARY_OP(log);
TOPI_DECLARE_UNARY_OP(floor);
TOPI_DECLARE_UNARY_OP(ceil);
TOPI_DECLARE_UNARY_OP(round);
TOPI_DECLARE_UNARY_OP(trunc);

/*!
* \brief Creates an operation that returns identity of a given tensor
35 changes: 35 additions & 0 deletions topi/python/topi/math.py
@@ -1,4 +1,5 @@
"""Elementwise operators"""
# pylint: disable=redefined-builtin
from __future__ import absolute_import as _abs
import tvm
from . import tag
@@ -107,6 +108,40 @@ def ceil(x):
return tvm.compute(x.shape, lambda *i: tvm.ceil(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def trunc(x):
"""Take truncated value of the input of x, element-wise.

Parameters
----------
x : tvm.Tensor
Input argument.

Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.trunc(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def round(x):
"""Round elements of x to nearest integer.

Parameters
----------
x : tvm.Tensor
Input argument.

Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.round(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def log(x):
"""Take logarithm of input x.
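A minimal sketch of driving the new topi ops directly, mirroring the test flow updated below (llvm-enabled build assumed):

import numpy as np
import tvm
import topi

A = tvm.placeholder((1, 3, 32, 32), name="A")
B = topi.trunc(A)
with tvm.target.create("llvm"):
    s = topi.generic.schedule_injective(B)
f = tvm.build(s, [A, B], "llvm", name="trunc")
ctx = tvm.cpu(0)
a_np = np.random.uniform(-100, 100, size=(1, 3, 32, 32)).astype(A.dtype)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros_like(a_np), ctx)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), np.trunc(a_np), rtol=1e-5)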
4 changes: 3 additions & 1 deletion topi/tests/python/test_topi_math.py
@@ -33,9 +33,9 @@ def check_device(device):
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
- foo = tvm.build(s, [A, B], device, name=name)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros_like(b_np), ctx)
+ foo = tvm.build(s, [A, B], device, name=name)
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)

@@ -45,6 +45,8 @@ def check_device(device):

test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.trunc, "trunc", np.trunc, -100, 100)
test_apply(topi.round, "round", np.round, -100, 100)
test_apply(topi.exp, "exp", np.exp, -1, 1)
test_apply(topi.tanh, "tanh", np.tanh, -10, 10)
test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1)