From 717b019dff42dca6b0e0407dc02048752880f36a Mon Sep 17 00:00:00 2001 From: SasakiSaki Date: Tue, 12 Mar 2019 05:57:59 +0800 Subject: [PATCH] [Relay] Improve more operator mxnet frontend importer (#2772) --- python/tvm/relay/frontend/mxnet.py | 72 ++++++++++++++++++++++++++---- 1 file changed, 64 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 1585d55ac1b9..93bd8efc6752 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -298,6 +298,51 @@ def _mx_leaky_relu(inputs, attrs): raise RuntimeError("act_type: {} is not supported".format(act_type)) +def _mx_make_power(power): + def _impl(inputs, _): # Note: no attrs + assert len(inputs) == 1 + scalar = _expr.const(power, dtype=None) + # Note: int maps to "int32", float maps to "float32" + return _op.power(inputs[0], scalar) + return _impl + + +def _mx_make_exponent(base): + # exp(b, x) = e^b * e^x + def _impl(inputs, _): # Note: no attrs + assert len(inputs) == 1 + scalar = _op.exp(_expr.const(base, dtype="float32")) + return _op.multiply(inputs[0], scalar) + return _impl + + +def _mx_make_logarithm(base): + # log(b, x) = log(x) / log(b) + def _impl(inputs, _): # Note: no attrs + assert len(inputs) == 1 + scalar = _op.log(_expr.const(base, dtype="float32")) + return _op.divide(inputs[0], scalar) + return _impl + + +def _mx_expm1(): + # exp_minus_1 x = exp(x) - 1 + def _impl(inputs, _): # Note: no attrs + assert len(inputs) == 1 + one = _expr.const(1, dtype="float32") + return _op.log(_op.subtract(inputs[0], one)) + return _impl + + +def _mx_log1p(): + # 1_plus_log x = log(x + 1) + def _impl(inputs, _): # Note: no attrs + assert len(inputs) == 1 + one = _expr.const(1, dtype="float32") + return _op.log(_op.add(inputs[0], one)) + return _impl + + def _mx_lrn(inputs, attrs): new_attrs = {} new_attrs["alpha"] = attrs.get_float("alpha", 0.0001) @@ -450,7 +495,6 @@ def _mx_l2_normalize(inputs, attrs): "exp", "sigmoid", "tanh", - "exp", "negative", "reshape_like", "zeros_like", @@ -482,6 +526,20 @@ def _mx_l2_normalize(inputs, attrs): "_minimum" : _rename(_op.minimum), "flatten" : _rename(_op.nn.batch_flatten), "Flatten" : _rename(_op.nn.batch_flatten), + # scalar power + "square" : _mx_make_power(2), + "sqrt" : _mx_make_power(1/2), + "rsqrt" : _mx_make_power(-1/2), + "cbrt" : _mx_make_power(1/3), + "rcbrt" : _mx_make_power(-1/3), + "__pow_scalar__" : _binop_scalar(_op.power), + "_power_scalar" : _binop_scalar(_op.power), + "__rsub_scalar__" : _rbinop_scalar(_op.subtract), + "_rminus_scalar" : _rbinop_scalar(_op.subtract), + "__rdiv_scalar__" : _rbinop_scalar(_op.divide), + "_rdiv_scalar" : _rbinop_scalar(_op.divide), + "__rpow_scalar__" : _rbinop_scalar(_op.power), + # scalar op "__add_scalar__" : _binop_scalar(_op.add), "_plus_scalar" : _binop_scalar(_op.add), "__sub_scalar__" : _binop_scalar(_op.subtract), @@ -490,13 +548,10 @@ def _mx_l2_normalize(inputs, attrs): "_mul_scalar" : _binop_scalar(_op.multiply), "__div_scalar__" : _binop_scalar(_op.divide), "_div_scalar" : _binop_scalar(_op.divide), - "__pow_scalar__" : _binop_scalar(_op.power), - "_power_scalar" : _binop_scalar(_op.power), - "__rsub_scalar__" : _rbinop_scalar(_op.subtract), - "_rminus_scalar" : _rbinop_scalar(_op.subtract), - "__rdiv_scalar__" : _rbinop_scalar(_op.divide), - "_rdiv_scalar" : _rbinop_scalar(_op.divide), - "__rpow_scalar__" : _rbinop_scalar(_op.power), + "log2" : _mx_make_logarithm(2), + "log10" : _mx_make_logarithm(10), + "log1p" : _mx_log1p, + "expm1" : _mx_expm1, "_equal_scalar" : _mx_compare(_op.equal, _binop_scalar), "_not_equal_scalar" : _mx_compare(_op.not_equal, _binop_scalar), "_greater_scalar" : _mx_compare(_op.greater, _binop_scalar), @@ -506,6 +561,7 @@ def _mx_l2_normalize(inputs, attrs): "_maximum_scalar" : _binop_scalar(_op.maximum), "_minimum_scalar" : _binop_scalar(_op.minimum), # reduction ops + "mean" : _reduce(_op.mean), "max" : _reduce(_op.max), "min" : _reduce(_op.min), "sum" : _reduce(_op.sum),