diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index 1a8d2cea9cd6..24bb727e6401 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -65,19 +65,47 @@ def sample_multinomial(attrs, inputs, proto_obj): # Arithmetic Operations def add(attrs, inputs, proto_obj): """Adding two tensors""" - return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_add') + new_attr = {} + + if 'broadcast' in attrs and attrs['broadcast'] == 1: + broadcast_axis = attrs['axis'] + op_value = translation_utils._fix_broadcast('broadcast_add', inputs, + broadcast_axis, proto_obj) + return op_value, new_attr, inputs + return 'broadcast_add', new_attr, inputs def subtract(attrs, inputs, proto_obj): """Subtracting two tensors""" - return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_sub') + new_attr = {} + + if 'broadcast' in attrs and attrs['broadcast'] == 1: + broadcast_axis = attrs['axis'] + op_value = translation_utils._fix_broadcast('broadcast_sub', inputs, + broadcast_axis, proto_obj) + return op_value, new_attr, inputs + return 'broadcast_sub', new_attr, inputs def multiply(attrs, inputs, proto_obj): """Multiply two tensors""" - return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_mul') + new_attr = {} + + if 'broadcast' in attrs and attrs['broadcast'] == 1: + broadcast_axis = attrs['axis'] + op_value = translation_utils._fix_broadcast('broadcast_mul', inputs, + broadcast_axis, proto_obj) + return op_value, new_attr, inputs + return 'broadcast_mul', new_attr, inputs def divide(attrs, inputs, proto_obj): """Divide two tensors""" - return translation_utils.broadcast_arithmetic_helper(attrs, inputs, proto_obj, 'broadcast_div') + new_attr = {} + + if 'broadcast' in attrs and attrs['broadcast'] == 1: + broadcast_axis = attrs['axis'] + op_value = translation_utils._fix_broadcast('broadcast_div', inputs, + broadcast_axis, proto_obj) + return op_value, new_attr, inputs + return 'broadcast_div', new_attr, inputs def mean(attrs, inputs, proto_obj): """Mean of all the input tensors.""" diff --git a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py index 0c6730513d4b..ce55a0b7d66a 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py @@ -245,17 +245,3 @@ def get_input_shape(sym, proto_obj): result = mod.get_outputs()[0].asnumpy() return result.shape - -def broadcast_arithmetic_helper(attrs, inputs, proto_obj, current_op_name): - """Helper function for broadcast arithmetic ops.""" - new_attr = {} - op_names = ['batchnorm, convolution, deconvolution'] - if 'broadcast' in attrs and attrs['broadcast'] == 1: - broadcast_axis = attrs['axis'] - for op_name in op_names: - # if input is bias which comes after conv, deconv, batchnorm operators - # then only reshape bias term - if inputs[0].name.startswith(op_name): - op_value = _fix_broadcast(current_op_name, inputs, broadcast_axis, proto_obj) - return op_value, new_attr, inputs - return current_op_name, new_attr, inputs