Skip to content

Commit

Permalink
[MXNET]Softmin, trunc op support added (#5715)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored Jun 3, 2020
1 parent 4347b41 commit c1f3b2f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
7 changes: 7 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,11 @@ def _mx_softsign(inputs, attrs):
return inputs[0] / (_expr.const(1.0) + _op.abs(inputs[0]))


def _mx_softmin(inputs, attrs):
axis = attrs.get_int("axis", -1)
return _op.nn.softmax(_op.negative(inputs[0]), axis)


def _mx_hard_sigmoid(inputs, attrs):
x = (_expr.const(0.2) * inputs[0]) + _expr.const(0.5)
return _op.clip(x, a_min=0.0, a_max=1.0)
Expand Down Expand Up @@ -1829,6 +1834,7 @@ def impl(inputs, input_types):
"floor",
"ceil",
"round",
"trunc",
"sign",
"sigmoid",
"negative",
Expand Down Expand Up @@ -1938,6 +1944,7 @@ def impl(inputs, input_types):
"log_softmax" : _softmax_op(_op.nn.log_softmax),
"Softmax" : _softmax_op(_op.nn.softmax),
"softsign" : _mx_softsign,
"softmin" : _mx_softmin,
"hard_sigmoid" : _mx_hard_sigmoid,
"reciprocal" : _mx_reciprocal,
# per op specialization
Expand Down
12 changes: 11 additions & 1 deletion tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,17 @@ def test_forward_elemwise_ops():
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())


def test_forward_softmin():
data = mx.sym.var('data')
mx_sym = mx.sym.softmin(data)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 3, 100, 100))

mx_sym = mx.sym.softmin(data, axis=2)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 3, 100, 100))


def test_forward_unary_ops():
for op in ["abs", "sqrt", "ceil", "floor", "round", "reciprocal",
for op in ["abs", "sqrt", "ceil", "floor", "round", "reciprocal", "trunc",
"softsign", "hard_sigmoid",
"cos", "sin", "tan",
"cosh", "sinh", "tanh",
Expand Down Expand Up @@ -1191,6 +1200,7 @@ def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size
test_forward_rrelu()
test_forward_prelu()
test_forward_softrelu()
test_forward_softmin()
test_forward_fc_flatten()
test_forward_clip()
test_forward_split()
Expand Down

0 comments on commit c1f3b2f

Please sign in to comment.