Skip to content

Commit

Permalink
add support for mxnet smooth_l1 (apache#2905)
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 authored and MarisaKirisame committed Apr 9, 2019
1 parent 823c327 commit f9e7d10
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,15 @@ def _mx_embedding(inputs, _):
return _op.take(weight, indices.astype('int32'), axis=0)


def _mx_smooth_l1(inputs, attrs):
scalar = attrs.get_float("scalar", 1.0)
scalar_sq = scalar * scalar
mask = _op.less(inputs[0], _expr.const(1.0 / scalar_sq, dtype='float32'))
return _op.where(mask,
_expr.const(scalar_sq / 2.0, dtype='float32') * inputs[0] * inputs[0],
_op.abs(inputs[0]) - _expr.const(0.5 / scalar_sq))


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
Expand Down Expand Up @@ -729,6 +738,7 @@ def _mx_embedding(inputs, _):
"Embedding" : _mx_embedding,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
"smooth_l1" : _mx_smooth_l1,
# vision
"_contrib_BilinearResize2D" : _mx_upsampling,
"_contrib_MultiBoxPrior" : _mx_multibox_prior,
Expand Down
9 changes: 9 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,14 @@ def verify(data_shape, weight_shape):
verify((2, 2), (4, 5))
verify((2, 3, 4), (4, 5))


def test_forward_smooth_l1():
data = mx.sym.var('data')
mx_sym = mx.sym.smooth_l1(data)
verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))
mx_sym = mx.sym.smooth_l1(data, scalar=1.0)
verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))

if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
Expand Down Expand Up @@ -498,3 +506,4 @@ def verify(data_shape, weight_shape):
test_forward_broadcast_axis()
test_forward_full()
test_forward_embedding()
test_forward_smooth_l1()

0 comments on commit f9e7d10

Please sign in to comment.