Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][TOPI] Add rsqrt operator #2949

Merged
merged 9 commits into from
Apr 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ List of operators
topi.tanh
topi.log
topi.sqrt
topi.rsqrt
topi.sigmoid
topi.clip
topi.cast
Expand Down Expand Up @@ -105,6 +106,7 @@ topi
.. autofunction:: topi.tanh
.. autofunction:: topi.log
.. autofunction:: topi.sqrt
.. autofunction:: topi.rsqrt
.. autofunction:: topi.sigmoid
.. autofunction:: topi.clip
.. autofunction:: topi.cast
Expand Down
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ This level enables fully connected multi-layer perceptron.

tvm.relay.log
tvm.relay.sqrt
tvm.relay.rsqrt
tvm.relay.exp
tvm.relay.sigmoid
tvm.relay.add
Expand Down Expand Up @@ -168,6 +169,7 @@ Level 1 Definitions
-------------------
.. autofunction:: tvm.relay.log
.. autofunction:: tvm.relay.sqrt
.. autofunction:: tvm.relay.rsqrt
.. autofunction:: tvm.relay.exp
.. autofunction:: tvm.relay.sigmoid
.. autofunction:: tvm.relay.add
Expand Down
1 change: 1 addition & 0 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(rsqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(popcount);

Expand Down
17 changes: 17 additions & 0 deletions python/tvm/hybrid/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar
return numpy.zeros(shape).astype(dtype)


def rsqrt(x):
"""
Computes reciprocal of square root of x element-wise

Parameters
----------
x: Tensor

Returns
-------
res: Tensor
The result of reciprocal of square root of x
"""
return numpy.ones_like(x) / numpy.sqrt(x)


def popcount(x):
"""
Count ones in the binary representation of number x
Expand Down Expand Up @@ -87,6 +103,7 @@ def max_num_threads(allow_none=True):
'allocate' : allocate,
'output_tensor' : allocate,
'sqrt' : numpy.sqrt,
'rsqrt' : rsqrt,
'log' : numpy.log,
'tanh' : numpy.tanh,
'power' : numpy.power,
Expand Down
18 changes: 17 additions & 1 deletion python/tvm/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def log(x):


def sqrt(x):
"""Take log of input x.
"""Take square root of input x.

Parameters
----------
Expand All @@ -259,6 +259,22 @@ def sqrt(x):
return call_pure_intrin(x.dtype, "sqrt", x)


def rsqrt(x):
"""Take reciprocal of square root of input x.

Parameters
----------
x : Expr
Input argument.

Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "rsqrt", x)


def floor(x):
"""Take floor of float input x.

Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
register_schedule("log", schedule_broadcast)
register_schedule("exp", schedule_broadcast)
register_schedule("sqrt", schedule_broadcast)
register_schedule("rsqrt", schedule_broadcast)
register_schedule("sigmoid", schedule_broadcast)
register_schedule("floor", schedule_broadcast)
register_schedule("ceil", schedule_broadcast)
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,26 @@ def sqrt(data):
return _make.sqrt(data)


def rsqrt(data):
"""Compute elementwise rsqrt of data.
makihiro marked this conversation as resolved.
Show resolved Hide resolved

.. math::

1/sqrt(x)
makihiro marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
data : relay.Expr
The input data

Returns
-------
result : relay.Expr
The computed result.
"""
return _make.rsqrt(data)


def sigmoid(data):
"""Compute elementwise sigmoid of data.

Expand Down
10 changes: 10 additions & 0 deletions src/codegen/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
Expr e = args[0];
const Call* call = e.as<Call>();
CHECK(call != nullptr);

auto one = make_const(call->args[0].type(), 1);
*rv = one / sqrt(call->args[0]);
});

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow")
.set_body(DispatchExtern<FloatSuffix>);

Expand Down
11 changes: 10 additions & 1 deletion src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ RELAY_REGISTER_UNARY_OP("exp")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));

RELAY_REGISTER_UNARY_OP("sqrt")
.describe(R"code(Returns the rsqrt input array, computed element-wise.
.describe(R"code(Returns the sqrt input array, computed element-wise.

.. math::
sqrt(x)
Expand All @@ -54,6 +54,15 @@ RELAY_REGISTER_UNARY_OP("sqrt")
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt));

RELAY_REGISTER_UNARY_OP("rsqrt")
.describe(R"code(Returns the rsqrt input array, computed element-wise.

.. math::
1/sqrt(x)

)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::rsqrt));

RELAY_REGISTER_UNARY_OP("zeros_like")
.describe(R"code(Returns an array of zeros, with same type and shape as the input.
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_ir_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test(x):
def test_op_level1():
x = relay.Var("x")

for op_name in ["log", "exp", "sqrt", "tanh"]:
for op_name in ["log", "exp", "sqrt", "rsqrt","tanh"]:
y = getattr(relay, op_name)(x)
assert y.op.name == op_name
assert y.op.support_level == 1
Expand Down
5 changes: 5 additions & 0 deletions tests/python/relay/test_op_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ def relu(x):
np.maximum(x_copy, 0, x_copy)
return x_copy

def rsqrt(x):
one = np.ones_like(x)
return one / np.sqrt(x)

def test_unary_op():
def check_single_op(opfunc, ref):
shape = (10, 4)
Expand Down Expand Up @@ -41,6 +45,7 @@ def check_single_op(opfunc, ref):
for opfunc, ref in [(tvm.relay.log, np.log),
(tvm.relay.exp, np.exp),
(tvm.relay.sqrt, np.sqrt),
(tvm.relay.rsqrt, rsqrt),
(tvm.relay.sigmoid, sigmoid),
(tvm.relay.tanh, np.tanh),
(relay.nn.relu, relu)]:
Expand Down
18 changes: 18 additions & 0 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,24 @@ inline Tensor sign(const Tensor& x,
}, name, tag);
}

/*!
* \brief Creates an operation that returns rsqrt of a given tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the rsqrt operation
*/
inline Tensor rsqrt(const Tensor& x,
std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
Expr one = make_const(x->dtype, 1);
return one/tvm::sqrt(x(i));
}, name, tag);
}

/*!
* \brief Creates an operation that clips each element of a tensor to
* the interval [a_min, a_max]
Expand Down
17 changes: 17 additions & 0 deletions topi/python/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,23 @@ def sqrt(x):
return tvm.compute(x.shape, lambda *i: tvm.sqrt(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def rsqrt(x):
"""Take inverse square root of input x.

Parameters
----------
x : tvm.Tensor
Input argument.

Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.rsqrt(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def sigmoid(x):
"""Take sigmoid tanh of input x.
Expand Down
5 changes: 5 additions & 0 deletions topi/src/topi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ TVM_REGISTER_GLOBAL("topi.sqrt")
*rv = sqrt(args[0]);
});

TVM_REGISTER_GLOBAL("topi.rsqrt")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = rsqrt(args[0]);
});

TVM_REGISTER_GLOBAL("topi.log")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = log(args[0]);
Expand Down
1 change: 1 addition & 0 deletions topi/tests/python/test_topi_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def test_apply(func, name):
test_apply(topi.sigmoid, "sigmoid")
test_apply(topi.log, "log")
test_apply(topi.sqrt, "sqrt")
test_apply(topi.rsqrt, "rsqrt")


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions topi/tests/python/test_topi_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def check_device(device):
test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1)
test_apply(topi.log, "log", np.log, 0, 100)
test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
test_apply(topi.rsqrt, "rsqrt", lambda x:np.ones_like(x)/np.sqrt(x), 0, 100, skip_name_check=True)

if __name__ == "__main__":
test_util()
Expand Down