From 96d027454d7833a8243964f8eed8fe156c0537c7 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Wed, 13 May 2020 22:04:16 +0530 Subject: [PATCH] [MXNET]abs, round, reciprocal, sign, softsign, hard_sigmoid --- python/tvm/relay/frontend/mxnet.py | 19 +++++++++++++++++++ tests/python/frontend/mxnet/test_forward.py | 4 +++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 4cb7a2a75bad..4c3144c4382a 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -789,6 +789,19 @@ def _mx_l2_normalize(inputs, attrs): return _op.nn.l2_normalize(inputs[0], **new_attrs) +def _mx_softsign(inputs, attrs): + return inputs[0] / (_expr.const(1.0) + _op.abs(inputs[0])) + + +def _mx_hard_sigmoid(inputs, attrs): + x = (_expr.const(0.2) * inputs[0]) + _expr.const(0.5) + return _op.clip(x, a_min=0.0, a_max=1.0) + + +def _mx_reciprocal(inputs, attrs): + return _expr.const(1.0) /inputs[0] + + def _mx_shape_array(inputs, attrs): assert len(inputs) == 1 if attrs.get_int("lhs_begin", None) is not None: @@ -1742,12 +1755,15 @@ def impl(inputs, input_types): # Note: due to attribute conversion constraint # ops in the identity set must be attribute free _identity_list = [ + "abs", "log", "exp", "erf", "sqrt", "floor", "ceil", + "round", + "sign", "sigmoid", "negative", "reshape_like", @@ -1856,6 +1872,9 @@ def impl(inputs, input_types): "softmax" : _softmax_op(_op.nn.softmax), "log_softmax" : _softmax_op(_op.nn.log_softmax), "Softmax" : _softmax_op(_op.nn.softmax), + "softsign" : _mx_softsign, + "hard_sigmoid" : _mx_hard_sigmoid, + "reciprocal" : _mx_reciprocal, # per op specialization "Reshape" : _reshape, "reshape" : _reshape, diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 3fb8e30acb88..9dd85065c885 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -365,7 +365,9 @@ def test_forward_elemwise_ops(): def test_forward_unary_ops(): - for op in ["cos", "sin", "tan", + for op in ["abs", "sqrt", "ceil", "floor", "round", "reciprocal", + "softsign", "hard_sigmoid", + "cos", "sin", "tan", "cosh", "sinh", "tanh", "arccos", "arcsin", "arctan", "arccosh", "arcsinh", "arctanh"]: