From c299b68d6c0077ccbde1408536152b4277bd3118 Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Thu, 14 Mar 2019 10:44:26 +0530 Subject: [PATCH] Support for sign (#2775) --- docs/api/python/topi.rst | 2 ++ docs/langref/relay_op.rst | 2 ++ python/tvm/relay/op/_tensor.py | 1 + python/tvm/relay/op/tensor.py | 14 ++++++++++++++ src/relay/op/tensor/unary.cc | 10 ++++++++++ tests/python/relay/test_op_level3.py | 3 ++- topi/include/topi/elemwise.h | 22 ++++++++++++++++++++++ topi/python/topi/math.py | 15 +++++++++++++++ topi/src/topi.cc | 5 +++++ topi/tests/python/test_topi_math.py | 6 ++++-- 10 files changed, 77 insertions(+), 3 deletions(-) diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index 0b9d555ca6fa..84ded7b697d1 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -11,6 +11,7 @@ List of operators topi.negative topi.floor topi.ceil + topi.sign topi.trunc topi.round topi.abs @@ -96,6 +97,7 @@ topi .. autofunction:: topi.identity .. autofunction:: topi.floor .. autofunction:: topi.ceil +.. autofunction:: topi.sign .. autofunction:: topi.trunc .. autofunction:: topi.round .. autofunction:: topi.abs diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 6c30a5d68e72..bbb27ec83b48 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -81,6 +81,7 @@ This level enables additional math and transform operators. tvm.relay.squeeze tvm.relay.floor tvm.relay.ceil + tvm.relay.sign tvm.relay.trunc tvm.relay.clip tvm.relay.round @@ -213,6 +214,7 @@ Level 3 Definitions .. autofunction:: tvm.relay.squeeze .. autofunction:: tvm.relay.floor .. autofunction:: tvm.relay.ceil +.. autofunction:: tvm.relay.sign .. autofunction:: tvm.relay.trunc .. autofunction:: tvm.relay.clip .. autofunction:: tvm.relay.round diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 36dae03d1237..05d9acb27330 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -16,6 +16,7 @@ register_schedule("ceil", schedule_broadcast) register_schedule("trunc", schedule_broadcast) register_schedule("round", schedule_broadcast) +register_schedule("sign", schedule_broadcast) register_schedule("abs", schedule_broadcast) register_schedule("tanh", schedule_broadcast) register_schedule("logical_not", schedule_broadcast) diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index ffbc7459648e..d51208d478aa 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -158,6 +158,20 @@ def abs(data): """ return _make.abs(data) +def sign(data): + """Compute element-wise absolute of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.sign(data) def tanh(data): """Compute element-wise tanh of data. diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 720344c3340d..4befc4b664a8 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -146,6 +146,16 @@ RELAY_REGISTER_UNARY_OP("round") .set_support_level(3) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round)); +RELAY_REGISTER_UNARY_OP("sign") +.describe(R"code(Returns the sign of input array, computed element-wise. + +.. numpy:: + sign(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(3) +.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sign)); + RELAY_REGISTER_UNARY_OP("abs") .describe(R"code(Returns the abs of input array, computed element-wise. diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index eee0bcfab008..ed6a79e82b3f 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -25,7 +25,8 @@ def test_unary_identity(): (relay.round, np.round), (relay.abs, np.abs), (relay.copy, None), # np.copy - (relay.negative, np.negative)]: + (relay.negative, np.negative), + (relay.sign, np.sign)]: shape = (8, 9, 4) x = relay.var("x", relay.TensorType(shape, "float32")) y = op(x) diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index 40dffa09a9bf..e5d8778041b1 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -88,6 +88,28 @@ inline Tensor logical_not(const Tensor& x, }, name, tag); } +/*! +* \brief Returns the sign of the tensor +* +* \param x The input tensor +* \param name The name of the operation +* \param tag The tag to mark the operation +* +* \return A Tensor whose op member is the sign +*/ +inline Tensor sign(const Tensor& x, + std::string name = "tensor", + std::string tag = kElementWise) { + return compute(x->shape, [&](const Array& i) { + Expr zero = make_zero(x->dtype); + Expr one = make_const(x->dtype, 1); + Expr minus_one = make_const(x->dtype, -1); + auto s1 = tvm::ir::Select::make((x(i) < zero), minus_one, zero); + auto s2 = tvm::ir::Select::make((x(i) > zero), one, s1); + return s2; + }, name, tag); +} + /*! * \brief Creates an operation that clips each element of a tensor to * the interval [a_min, a_max] diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index a5d28d351719..faddc8ac8a90 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -3,6 +3,7 @@ from __future__ import absolute_import as _abs import tvm from . import tag +from . import cpp @tvm.tag_scope(tag=tag.ELEMWISE) def identity(x): @@ -107,6 +108,20 @@ def ceil(x): """ return tvm.compute(x.shape, lambda *i: tvm.ceil(x(*i))) +def sign(x): + """Returns -1, 0, 1 based on sign of x. + + Parameters + ---------- + x : tvm.Tensor + Input argument. + + Returns + ------- + y : tvm.Tensor + The result. + """ + return cpp.sign(x) @tvm.tag_scope(tag=tag.ELEMWISE) def trunc(x): diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 366f835d808d..aed2eab9c6bc 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -173,6 +173,11 @@ TVM_REGISTER_GLOBAL("topi.elemwise_sum") *rv = elemwise_sum(args[0]); }); +TVM_REGISTER_GLOBAL("topi.sign") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = sign(args[0]); + }); + TVM_REGISTER_GLOBAL("topi.full") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = full(args[0], args[1], args[2]); diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py index 8ba8c6d8d0f4..f2f2471c868f 100644 --- a/topi/tests/python/test_topi_math.py +++ b/topi/tests/python/test_topi_math.py @@ -18,10 +18,11 @@ def test_ewise(): shape = (20, 3) - def test_apply(func, name, f_numpy, low, high, check_round=False): + def test_apply(func, name, f_numpy, low, high, check_round=False, skip_name_check=False): B = func(A) assert tuple(B.shape) == tuple(A.shape) - assert B.op.body[0].name == name + if not skip_name_check: + assert B.op.body[0].name == name a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10 # avoid round check too close to boundary if check_round: @@ -49,6 +50,7 @@ def check_device(device): test_apply(topi.floor, "floor", np.floor, -100, 100) test_apply(topi.ceil, "ceil", np.ceil, -100, 100) + test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True) test_apply(topi.trunc, "trunc", np.trunc, -100, 100) test_apply(topi.abs, "fabs", np.abs, -100, 100) test_apply(topi.round, "round", np.round, -100, 100, check_round=True)