Skip to content

Commit

Permalink
[Relay][TOPI] Add rsqrt operator (apache#2949)
Browse files Browse the repository at this point in the history
  • Loading branch information
makihiro authored and wweic committed May 13, 2019
1 parent 28664b2 commit 4ee0381
Show file tree
Hide file tree
Showing 16 changed files with 128 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ List of operators
topi.tanh
topi.log
topi.sqrt
topi.rsqrt
topi.sigmoid
topi.clip
topi.cast
Expand Down Expand Up @@ -122,6 +123,7 @@ topi
.. autofunction:: topi.tanh
.. autofunction:: topi.log
.. autofunction:: topi.sqrt
.. autofunction:: topi.rsqrt
.. autofunction:: topi.sigmoid
.. autofunction:: topi.clip
.. autofunction:: topi.cast
Expand Down
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ This level enables fully connected multi-layer perceptron.

tvm.relay.log
tvm.relay.sqrt
tvm.relay.rsqrt
tvm.relay.exp
tvm.relay.sigmoid
tvm.relay.add
Expand Down Expand Up @@ -186,6 +187,7 @@ Level 1 Definitions
-------------------
.. autofunction:: tvm.relay.log
.. autofunction:: tvm.relay.sqrt
.. autofunction:: tvm.relay.rsqrt
.. autofunction:: tvm.relay.exp
.. autofunction:: tvm.relay.sigmoid
.. autofunction:: tvm.relay.add
Expand Down
1 change: 1 addition & 0 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(rsqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(popcount);

Expand Down
17 changes: 17 additions & 0 deletions python/tvm/hybrid/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,22 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar
return numpy.zeros(shape).astype(dtype)


def rsqrt(x):
"""
Computes reciprocal of square root of x element-wise
Parameters
----------
x: Tensor
Returns
-------
res: Tensor
The result of reciprocal of square root of x
"""
return numpy.ones_like(x) / numpy.sqrt(x)


def popcount(x):
"""
Count ones in the binary representation of number x
Expand Down Expand Up @@ -103,6 +119,7 @@ def max_num_threads(allow_none=True):
'allocate' : allocate,
'output_tensor' : allocate,
'sqrt' : numpy.sqrt,
'rsqrt' : rsqrt,
'log' : numpy.log,
'tanh' : numpy.tanh,
'power' : numpy.power,
Expand Down
18 changes: 17 additions & 1 deletion python/tvm/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def log(x):


def sqrt(x):
"""Take log of input x.
"""Take square root of input x.
Parameters
----------
Expand All @@ -275,6 +275,22 @@ def sqrt(x):
return call_pure_intrin(x.dtype, "sqrt", x)


def rsqrt(x):
"""Take reciprocal of square root of input x.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "rsqrt", x)


def floor(x):
"""Take floor of float input x.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
register_schedule("log", schedule_broadcast)
register_schedule("exp", schedule_broadcast)
register_schedule("sqrt", schedule_broadcast)
register_schedule("rsqrt", schedule_broadcast)
register_schedule("sigmoid", schedule_broadcast)
register_schedule("floor", schedule_broadcast)
register_schedule("ceil", schedule_broadcast)
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,26 @@ def sqrt(data):
return _make.sqrt(data)


def rsqrt(data):
"""Compute elementwise rsqrt of data.
.. math::
1/sqrt(x)
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.rsqrt(data)


def sigmoid(data):
"""Compute elementwise sigmoid of data.
Expand Down
10 changes: 10 additions & 0 deletions src/codegen/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
Expr e = args[0];
const Call* call = e.as<Call>();
CHECK(call != nullptr);

auto one = make_const(call->args[0].type(), 1);
*rv = one / sqrt(call->args[0]);
});

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow")
.set_body(DispatchExtern<FloatSuffix>);

Expand Down
11 changes: 10 additions & 1 deletion src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ RELAY_REGISTER_UNARY_OP("exp")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));

RELAY_REGISTER_UNARY_OP("sqrt")
.describe(R"code(Returns the rsqrt input array, computed element-wise.
.describe(R"code(Returns the sqrt input array, computed element-wise.
.. math::
sqrt(x)
Expand All @@ -73,6 +73,15 @@ RELAY_REGISTER_UNARY_OP("sqrt")
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt));

RELAY_REGISTER_UNARY_OP("rsqrt")
.describe(R"code(Returns the rsqrt input array, computed element-wise.
.. math::
1/sqrt(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::rsqrt));

RELAY_REGISTER_UNARY_OP("zeros_like")
.describe(R"code(Returns an array of zeros, with same type and shape as the input.
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_ir_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test(x):
def test_op_level1():
x = relay.Var("x")

for op_name in ["log", "exp", "sqrt", "tanh"]:
for op_name in ["log", "exp", "sqrt", "rsqrt","tanh"]:
y = getattr(relay, op_name)(x)
assert y.op.name == op_name
assert y.op.support_level == 1
Expand Down
5 changes: 5 additions & 0 deletions tests/python/relay/test_op_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def relu(x):
np.maximum(x_copy, 0, x_copy)
return x_copy

def rsqrt(x):
one = np.ones_like(x)
return one / np.sqrt(x)

def test_unary_op():
def check_single_op(opfunc, ref):
shape = (10, 4)
Expand Down Expand Up @@ -57,6 +61,7 @@ def check_single_op(opfunc, ref):
for opfunc, ref in [(tvm.relay.log, np.log),
(tvm.relay.exp, np.exp),
(tvm.relay.sqrt, np.sqrt),
(tvm.relay.rsqrt, rsqrt),
(tvm.relay.sigmoid, sigmoid),
(tvm.relay.tanh, np.tanh),
(relay.nn.relu, relu)]:
Expand Down
18 changes: 18 additions & 0 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,24 @@ inline Tensor sign(const Tensor& x,
}, name, tag);
}

/*!
* \brief Creates an operation that returns rsqrt of a given tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the rsqrt operation
*/
inline Tensor rsqrt(const Tensor& x,
std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
Expr one = make_const(x->dtype, 1);
return one/tvm::sqrt(x(i));
}, name, tag);
}

/*!
* \brief Creates an operation that clips each element of a tensor to
* the interval [a_min, a_max]
Expand Down
17 changes: 17 additions & 0 deletions topi/python/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,23 @@ def sqrt(x):
return tvm.compute(x.shape, lambda *i: tvm.sqrt(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def rsqrt(x):
"""Take inverse square root of input x.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.rsqrt(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def sigmoid(x):
"""Take sigmoid tanh of input x.
Expand Down
5 changes: 5 additions & 0 deletions topi/src/topi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ TVM_REGISTER_GLOBAL("topi.sqrt")
*rv = sqrt(args[0]);
});

TVM_REGISTER_GLOBAL("topi.rsqrt")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = rsqrt(args[0]);
});

TVM_REGISTER_GLOBAL("topi.log")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = log(args[0]);
Expand Down
1 change: 1 addition & 0 deletions topi/tests/python/test_topi_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def test_apply(func, name):
test_apply(topi.sigmoid, "sigmoid")
test_apply(topi.log, "log")
test_apply(topi.sqrt, "sqrt")
test_apply(topi.rsqrt, "rsqrt")


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions topi/tests/python/test_topi_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def check_device(device):
test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1)
test_apply(topi.log, "log", np.log, 0, 100)
test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
test_apply(topi.rsqrt, "rsqrt", lambda x:np.ones_like(x)/np.sqrt(x), 0, 100, skip_name_check=True)

if __name__ == "__main__":
test_util()
Expand Down

0 comments on commit 4ee0381

Please sign in to comment.