Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][MXNET][FRONTEND] add support for MXNET numpy operators #6054

Merged
merged 19 commits into from
Aug 21, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,6 +2071,93 @@ def impl(inputs, input_types):
return impl


def _mx_npi_transpose(inputs, attrs):
axes = attrs.get_int_tuple("axes", None)
# translate default case
axes = None if len(axes) == 0 or axes[0] is None else axes
return _op.transpose(inputs[0], axes=axes)


def _mx_npi_pad(inputs, attrs):
pad_mode = attrs.get_str('mode', None)
if pad_mode is None:
raise tvm.error.OpAttributeRequired(
'Attribute "mode" not found in operator pad.')
if pad_mode not in ['constant', 'edge', 'reflect']:
raise tvm.error.OpAttributeInvalid(
'Value ' + mode + ' in attribute "mode" is not valid')
pad_width = attrs.get_int_tuple('pad_width', None)
if pad_width is None:
raise tvm.error.OpAttributeRequired(
'Attribute "pad_width" not found in operator pad.')
if None in pad_width:
raise tvm.error.OpAttributeInvalid(
'Value None in attribute "pad_width" of operator Slice is not valid.')
constant_values = attrs.get_float('constant_values', 0.0)
padding = tuple(tuple((b, a)) for b, a in zip(pad_width[::2], pad_width[1::2]))

return _op.nn.pad(data=inputs[0],
pad_width=padding,
pad_value=constant_values,
pad_mode=pad_mode)


def _mx_npi_concatenate(inputs, attrs):
axis = attrs.get_str("axis", "0")
if axis == "None":
return _op.reshape(_op.concatenate(tuple(inputs), axis=0), (-1,))
else:
return _op.concatenate(tuple(inputs), axis=int(axis))


def _mx_npx_reshape(inputs, attrs):
shape = attrs.get_int_tuple("newshape")
reverse = attrs.get_bool("reverse", False)
shape_list = list(shape)
new_shape_list = []
for num in shape_list:
if num > 0 or num == -1:
new_shape_list.append(num)
elif num == -2:
new_shape_list.append(0)
elif num == -4:
new_shape_list.append(-2)
elif num == -5:
new_shape_list.append(-3)
elif num == -6:
new_shape_list.append(-4)
else:
raise tvm.error.OpAttributeInvalid(
'Shape dimension %d is not supported' % num)
shape = tuple(new_shape_list)
if reverse:
return _op.reverse_reshape(inputs[0], newshape=shape)
return _op.reshape(inputs[0], newshape=shape)


def _mx_split_v2(inputs, attrs):
axis = attrs.get_int("axis")
indices = list(attrs.get_int_tuple("indices", []))
# remove the prefix '0'
if len(indices) != 0 and indices[0] == 0:
indices.remove(0)
sections = attrs.get_int("sections", 0)
indices_or_sections = list(indices) if len(indices) != 0 else sections
res = _op.split(inputs[0], indices_or_sections=indices_or_sections, axis=axis)
if attrs.get_bool("squeeze_axis", False):
res = tuple([_op.squeeze(x, axis=[axis]) for x in res])
return res


def _mx_npi_where_rscalar(inputs, attrs):
scalar = attrs.get_float("scalar")
dtype = _infer_type(inputs[1]).checked_type.dtype
scalar = _expr.const(scalar, dtype=dtype)
ones = _op.ones_like(inputs[1])
scalar = _op.multiply(ones, scalar)
return _op.where(inputs[0], inputs[1], scalar)


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
Expand Down Expand Up @@ -2226,6 +2313,7 @@ def impl(inputs, input_types):
"slice_axis" : _mx_slice_axis,
"SliceChannel" : _mx_split,
"split" : _mx_split,
"_split_v2" : _mx_split_v2,
"SwapAxis" : _mx_swap_axis,
"expand_dims" : _mx_expand_dims,
"Concat" : _mx_concat,
Expand Down Expand Up @@ -2304,6 +2392,23 @@ def impl(inputs, input_types):
"_contrib_quantized_pooling": _qnn_pooling,
"_contrib_quantized_batch_norm" : _qnn_batch_norm,
"_sg_mkldnn_fully_connected": _qnn_fully_connected,
# numpy
"_np_transpose" : _mx_npi_transpose,
"_npi_transpose" : _mx_npi_transpose,
"_npi_pad" : _mx_npi_pad,
"_npi_concatenate" : _mx_npi_concatenate,
"_npx_reshape" : _mx_npx_reshape,
"_np_copy" : _rename(_op.copy),
"_npi_power" : _rename(_op.power),
"_npi_power_scalar" : _binop_scalar(_op.power),
"_npi_multiply" : _rename(_op.multiply),
"_npi_multiply_scalar" : _binop_scalar(_op.multiply),
"_npi_add" : _rename(_op.add),
"_npi_add_scalar" : _binop_scalar(_op.add),
"_npi_where_rscalar" : _mx_npi_where_rscalar,
"_npi_less" : _rename(_op.less),
"_npi_tanh" : _rename(_op.tanh),
"_npi_true_divide_scalar" : _binop_scalar(_op.divide),
}

# set identity list
Expand Down
Loading